~bonbon/gmcts

73493dface28f87d1a6c0fd5f18ca256d445dda2 — bonbon 2 years ago f4368f9
rework action selection system

We offload the action selection to the user of this package.
Theoritically, gmcts only needs to know how many actions there are,
and pick random actions from there.

This change is a rather massive change for current implementations,
but significant speed gains are available due to not having to convert
actions from values to interfaces back to values.

This change also removes the need for an Action interface, as the package
no longer needs to hold a list of random actions.
8 files changed, 47 insertions(+), 145 deletions(-)

D comparable_test.go
M mcts.go
M mcts_test.go
M models.go
M search.go
M tree.go
M tree_test.go
D utils.go
D comparable_test.go => comparable_test.go +0 -59
@@ 1,59 0,0 @@
package gmcts

import "testing"

//comparableState tests the comparable action requirement of gmcts, as
//the GetActions method returns a noncomparable action.
type comparableState struct{}

//nonComparableState tests the comparable state requirement of gmcts.
type nonComparableState struct {
	comparableState
	_ []int
}

func (n comparableState) GetActions() []Action {
	return []Action{nonComparableState{}}
}

func (n comparableState) ApplyAction(a Action) (Game, error) {
	return n, nil
}

func (n comparableState) IsTerminal() bool {
	return true
}

func (n comparableState) Hash() interface{} {
	return 0
}

func (n comparableState) Player() Player {
	return 0
}

func (n comparableState) Winners() []Player {
	return nil
}

func TestNonComparableState(t *testing.T) {
	//Calling NewMCTS should panic, as the nonComparableState is, as
	//the name suggests, not comparable.
	defer func() {
		if r := recover(); r == nil {
			t.FailNow()
		}
	}()
	NewMCTS(nonComparableState{})
}

func TestNonComparableAction(t *testing.T) {
	//Calling NewMCTS should panic, as the actions from comparableState
	//are not comparable.
	defer func() {
		if r := recover(); r == nil {
			t.FailNow()
		}
	}()
	NewMCTS(comparableState{})
}

M mcts.go => mcts.go +5 -24
@@ 2,30 2,11 @@ package gmcts

import (
	"math/rand"
	"reflect"
	"sync"
)

//NewMCTS returns a new MCTS wrapper
//
//If either the Game or its Action types are not comparable,
//this function panics
func NewMCTS(initial Game) *MCTS {
	//Check if Game type if comparable
	if !reflect.TypeOf(initial).Comparable() {
		panic("gmcts: game type is not comparable")
	}

	//Check if Action type is comparable
	//We only need to check the actions that can affect the initial gamestate
	//as those are the only actions that need to be compared.
	actions := initial.GetActions()
	for i := range actions {
		if !reflect.TypeOf(actions[i]).Comparable() {
			panic("gmcts: action type is not comparable")
		}
	}

	return &MCTS{
		init:  initial,
		trees: make([]*Tree, 0),


@@ 80,28 61,28 @@ func (m *MCTS) AddTree(t *Tree) {
//BestAction returns nil if it has received no trees
//to search through or if the current state
//it's considering has no legal actions or is terminal.
func (m *MCTS) BestAction() Action {
func (m *MCTS) BestAction() int {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	if len(m.trees) == 0 {
		return nil
		return -1
	}

	//Safe guard set in place in case we're dealing
	//with a terminal state
	if m.init.IsTerminal() {
		return nil
		return -1
	}

	//Democracy Section: each tree votes for an action
	actionScore := make(map[Action]int)
	actionScore := make([]int, m.init.Len())
	for _, t := range m.trees {
		actionScore[t.bestAction()]++
	}

	//Democracy Section: the action with the most votes wins
	var bestAction Action
	var bestAction int
	var mostVotes int
	for a, s := range actionScore {
		if s > mostVotes {

M mcts_test.go => mcts_test.go +18 -27
@@ 16,30 16,18 @@ func getPlayerID(ascii byte) Player {
}

type tttGame struct {
	game tictactoe.Game
	game    tictactoe.Game
	actions []tictactoe.Move
}

func (g tttGame) GetActions() []Action {
	gameActions := g.game.GetActions()

	actions := make([]Action, len(gameActions))

	for i, a := range gameActions {
		actions[i] = a
	}

	return actions
func (g tttGame) Len() int {
	return len(g.actions)
}

func (g tttGame) ApplyAction(a Action) (Game, error) {
	move, ok := a.(tictactoe.Move)
	if !ok {
		return nil, fmt.Errorf("action not correct type")
	}

	game, err := g.game.ApplyAction(move)
func (g tttGame) ApplyAction(i int) (Game, error) {
	game, err := g.game.ApplyAction(g.actions[i])

	return tttGame{game}, err
	return tttGame{game, game.GetActions()}, err
}

func (g tttGame) Hash() interface{} {


@@ 64,7 52,7 @@ func (g tttGame) Winners() []Player {
}

//Global vars to be checked by other tests
var finishedGame tttGame
var newGame, finishedGame tttGame
var firstMove tictactoe.Move
var treeToTest *Tree



@@ 72,7 60,10 @@ var treeToTest *Tree
//the resulting terminal game state into global variables to be used by
//other tests.
func TestMain(m *testing.M) {
	game := tttGame{tictactoe.NewGame()}
	newGame = tttGame{game: tictactoe.NewGame()}
	newGame.actions = newGame.game.GetActions()

	game := newGame
	concurrentSearches := 1 //runtime.NumCPU()

	var setFirstMove sync.Once


@@ 105,7 96,7 @@ func TestMain(m *testing.M) {

		//Save the first action taken
		setFirstMove.Do(func() {
			firstMove = bestAction.(tictactoe.Move)
			firstMove = newGame.actions[bestAction]
		})
	}
	//Save the terminal game state


@@ 135,7 126,7 @@ func TestTicTacToeMiddle(t *testing.T) {
func TestZeroTrees(t *testing.T) {
	mcts := NewMCTS(finishedGame)
	bestAction := mcts.BestAction()
	if bestAction != nil {
	if bestAction != -1 {
		t.Errorf("gmcts: recieved a best action from no trees: %#v", bestAction)
		t.FailNow()
	}


@@ 145,14 136,14 @@ func TestTerminalState(t *testing.T) {
	mcts := NewMCTS(finishedGame)
	mcts.AddTree(mcts.SpawnTree())
	bestAction := mcts.BestAction()
	if bestAction != nil {
	if bestAction != -1 {
		t.Errorf("gmcts: recieved a best action from a terminal state: %#v", bestAction)
		t.FailNow()
	}
}

func BenchmarkTicTacToe1KRounds(b *testing.B) {
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	mcts := NewMCTS(newGame)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mcts.SpawnTree().SearchRounds(1000)


@@ 160,7 151,7 @@ func BenchmarkTicTacToe1KRounds(b *testing.B) {
}

func BenchmarkTicTacToe10KRounds(b *testing.B) {
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	mcts := NewMCTS(newGame)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mcts.SpawnTree().SearchRounds(10000)


@@ 168,7 159,7 @@ func BenchmarkTicTacToe10KRounds(b *testing.B) {
}

func BenchmarkTicTacToe100KRounds(b *testing.B) {
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	mcts := NewMCTS(newGame)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mcts.SpawnTree().SearchRounds(100000)

M models.go => models.go +4 -11
@@ 5,12 5,6 @@ import (
	"sync"
)

//Action is the interface that represents an action that can be
//performed on a Game.
//
//Any implementation of Action should be comparable (i.e. be a key in a map)
type Action interface{}

//Player is an id for the player
type Player int



@@ 19,12 13,12 @@ type Player int
//Any implementation of Game should be comparable (i.e. be a key in a map)
//and immutable (state cannot change as this package calls any function).
type Game interface {
	//GetActions returns a list of actions to consider
	GetActions() []Action
	//Len returns the number of actions to consider.
	Len() int

	//ApplyAction applies the given action to the game state,
	//ApplyAction applies the ith action (0-indexed) to the game state,
	//and returns a new game state and an error for invalid actions
	ApplyAction(Action) (Game, error)
	ApplyAction(i int) (Game, error)

	//Hash returns a unique representation of the state.
	//Any return value must be comparable.


@@ 69,7 63,6 @@ type node struct {
	state gameState
	tree  *Tree

	actions           []Action
	children          []*node
	unvisitedChildren []*node
	childVisits       []float64

M search.go => search.go +9 -8
@@ 86,13 86,12 @@ func (n *node) runSimulation() ([]Player, float64) {
}

func (n *node) expand() {
	n.actions = n.state.GetActions()
	n.actionCount = len(n.actions)
	n.actionCount = n.state.Len()
	n.unvisitedChildren = make([]*node, n.actionCount)
	n.children = n.unvisitedChildren
	n.childVisits = make([]float64, n.actionCount)
	for i, a := range n.actions {
		newGame, err := n.state.ApplyAction(a)
	for i := 0; i < n.actionCount; i++ {
		newGame, err := n.state.ApplyAction(i)
		if err != nil {
			panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
		}


@@ 118,11 117,13 @@ func (n *node) simulate() []Player {
	for !game.IsTerminal() {
		var err error

		actions := game.GetActions()
		panicIfNoActions(game, actions)
		actions := game.Len()
		if actions <= 0 {
			panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
		}

		randomIndex := n.tree.randSource.Intn(len(actions))
		game, err = game.ApplyAction(actions[randomIndex])
		randomIndex := n.tree.randSource.Intn(actions)
		game, err = game.ApplyAction(randomIndex)
		if err != nil {
			panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
		}

M tree.go => tree.go +9 -6
@@ 8,7 8,8 @@ import (
//Search searches the tree for a specified time
//
//Search will panic if the Game's ApplyAction
//method returns an error
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
func (t *Tree) Search(duration time.Duration) {
	ctx, cancel := context.WithTimeout(context.Background(), duration)
	defer cancel()


@@ 18,7 19,8 @@ func (t *Tree) Search(duration time.Duration) {
//SearchContext searches the tree using a given context
//
//SearchContext will panic if the Game's ApplyAction
//method returns an error
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
func (t *Tree) SearchContext(ctx context.Context) {
	for {
		select {


@@ 33,7 35,8 @@ func (t *Tree) SearchContext(ctx context.Context) {
//SearchRounds searches the tree for a specified number of rounds
//
//SearchRounds will panic if the Game's ApplyAction
//method returns an error
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
func (t *Tree) SearchRounds(rounds int) {
	for i := 0; i < rounds; i++ {
		t.search()


@@ 69,17 72,17 @@ func (t Tree) MaxDepth() int {
	return maxDepth
}

func (t *Tree) bestAction() Action {
func (t *Tree) bestAction() int {
	root := t.current

	//Select the child with the highest winrate
	var bestAction Action
	var bestAction int
	bestWinRate := -1.0
	player := root.state.Player()
	for i := 0; i < root.actionCount; i++ {
		winRate := root.children[i].nodeScore[player] / root.childVisits[i]
		if winRate > bestWinRate {
			bestAction = root.actions[i]
			bestAction = i
			bestWinRate = winRate
		}
	}

M tree_test.go => tree_test.go +2 -1
@@ 37,7 37,8 @@ func TestDepth(t *testing.T) {
}

func TestSearch(t *testing.T) {
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	newGame := tictactoe.NewGame()
	mcts := NewMCTS(tttGame{newGame, newGame.GetActions()})
	tree := mcts.SpawnTree()

	timeToSearch := 1 * time.Millisecond

D utils.go => utils.go +0 -9
@@ 1,9 0,0 @@
package gmcts

import "fmt"

func panicIfNoActions(game Game, actions []Action) {
	if len(actions) == 0 {
		panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
	}
}