From 73493dface28f87d1a6c0fd5f18ca256d445dda2 Mon Sep 17 00:00:00 2001 From: bonbon Date: Sat, 8 Aug 2020 07:13:28 -0500 Subject: [PATCH] rework action selection system We offload the action selection to the user of this package. Theoretically, gmcts only needs to know how many actions there are, and pick random actions from there. This change is a rather massive change for current implementations, but significant speed gains are available due to not having to convert actions from values to interfaces back to values. This change also removes the need for an Action interface, as the package no longer needs to hold a list of random actions. --- comparable_test.go | 59 ---------------------------------------------- mcts.go | 29 ++++------------------- mcts_test.go | 45 ++++++++++++++--------------------- models.go | 15 ++++-------- search.go | 17 ++++++------- tree.go | 15 +++++++----- tree_test.go | 3 ++- utils.go | 9 ------- 8 files changed, 47 insertions(+), 145 deletions(-) delete mode 100644 comparable_test.go delete mode 100644 utils.go diff --git a/comparable_test.go b/comparable_test.go deleted file mode 100644 index 01af375..0000000 --- a/comparable_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package gmcts - -import "testing" - -//comparableState tests the comparable action requirement of gmcts, as -//the GetActions method returns a noncomparable action. -type comparableState struct{} - -//nonComparableState tests the comparable state requirement of gmcts. 
-type nonComparableState struct { - comparableState - _ []int -} - -func (n comparableState) GetActions() []Action { - return []Action{nonComparableState{}} -} - -func (n comparableState) ApplyAction(a Action) (Game, error) { - return n, nil -} - -func (n comparableState) IsTerminal() bool { - return true -} - -func (n comparableState) Hash() interface{} { - return 0 -} - -func (n comparableState) Player() Player { - return 0 -} - -func (n comparableState) Winners() []Player { - return nil -} - -func TestNonComparableState(t *testing.T) { - //Calling NewMCTS should panic, as the nonComparableState is, as - //the name suggests, not comparable. - defer func() { - if r := recover(); r == nil { - t.FailNow() - } - }() - NewMCTS(nonComparableState{}) -} - -func TestNonComparableAction(t *testing.T) { - //Calling NewMCTS should panic, as the actions from comparableState - //are not comparable. - defer func() { - if r := recover(); r == nil { - t.FailNow() - } - }() - NewMCTS(comparableState{}) -} diff --git a/mcts.go b/mcts.go index b39a52a..882c44d 100644 --- a/mcts.go +++ b/mcts.go @@ -2,30 +2,11 @@ package gmcts import ( "math/rand" - "reflect" "sync" ) //NewMCTS returns a new MCTS wrapper -// -//If either the Game or its Action types are not comparable, -//this function panics func NewMCTS(initial Game) *MCTS { - //Check if Game type if comparable - if !reflect.TypeOf(initial).Comparable() { - panic("gmcts: game type is not comparable") - } - - //Check if Action type is comparable - //We only need to check the actions that can affect the initial gamestate - //as those are the only actions that need to be compared. 
- actions := initial.GetActions() - for i := range actions { - if !reflect.TypeOf(actions[i]).Comparable() { - panic("gmcts: action type is not comparable") - } - } - return &MCTS{ init: initial, trees: make([]*Tree, 0), @@ -80,28 +61,28 @@ func (m *MCTS) AddTree(t *Tree) { //BestAction returns nil if it has received no trees //to search through or if the current state //it's considering has no legal actions or is terminal. -func (m *MCTS) BestAction() Action { +func (m *MCTS) BestAction() int { m.mutex.RLock() defer m.mutex.RUnlock() if len(m.trees) == 0 { - return nil + return -1 } //Safe guard set in place in case we're dealing //with a terminal state if m.init.IsTerminal() { - return nil + return -1 } //Democracy Section: each tree votes for an action - actionScore := make(map[Action]int) + actionScore := make([]int, m.init.Len()) for _, t := range m.trees { actionScore[t.bestAction()]++ } //Democracy Section: the action with the most votes wins - var bestAction Action + var bestAction int var mostVotes int for a, s := range actionScore { if s > mostVotes { diff --git a/mcts_test.go b/mcts_test.go index 34f83a0..393c77e 100644 --- a/mcts_test.go +++ b/mcts_test.go @@ -16,30 +16,18 @@ func getPlayerID(ascii byte) Player { } type tttGame struct { - game tictactoe.Game + game tictactoe.Game + actions []tictactoe.Move } -func (g tttGame) GetActions() []Action { - gameActions := g.game.GetActions() - - actions := make([]Action, len(gameActions)) - - for i, a := range gameActions { - actions[i] = a - } - - return actions +func (g tttGame) Len() int { + return len(g.actions) } -func (g tttGame) ApplyAction(a Action) (Game, error) { - move, ok := a.(tictactoe.Move) - if !ok { - return nil, fmt.Errorf("action not correct type") - } - - game, err := g.game.ApplyAction(move) +func (g tttGame) ApplyAction(i int) (Game, error) { + game, err := g.game.ApplyAction(g.actions[i]) - return tttGame{game}, err + return tttGame{game, game.GetActions()}, err } func (g tttGame) 
Hash() interface{} { @@ -64,7 +52,7 @@ func (g tttGame) Winners() []Player { } //Global vars to be checked by other tests -var finishedGame tttGame +var newGame, finishedGame tttGame var firstMove tictactoe.Move var treeToTest *Tree @@ -72,7 +60,10 @@ var treeToTest *Tree //the resulting terminal game state into global variables to be used by //other tests. func TestMain(m *testing.M) { - game := tttGame{tictactoe.NewGame()} + newGame = tttGame{game: tictactoe.NewGame()} + newGame.actions = newGame.game.GetActions() + + game := newGame concurrentSearches := 1 //runtime.NumCPU() var setFirstMove sync.Once @@ -105,7 +96,7 @@ func TestMain(m *testing.M) { //Save the first action taken setFirstMove.Do(func() { - firstMove = bestAction.(tictactoe.Move) + firstMove = newGame.actions[bestAction] }) } //Save the terminal game state @@ -135,7 +126,7 @@ func TestTicTacToeMiddle(t *testing.T) { func TestZeroTrees(t *testing.T) { mcts := NewMCTS(finishedGame) bestAction := mcts.BestAction() - if bestAction != nil { + if bestAction != -1 { t.Errorf("gmcts: recieved a best action from no trees: %#v", bestAction) t.FailNow() } @@ -145,14 +136,14 @@ func TestTerminalState(t *testing.T) { mcts := NewMCTS(finishedGame) mcts.AddTree(mcts.SpawnTree()) bestAction := mcts.BestAction() - if bestAction != nil { + if bestAction != -1 { t.Errorf("gmcts: recieved a best action from a terminal state: %#v", bestAction) t.FailNow() } } func BenchmarkTicTacToe1KRounds(b *testing.B) { - mcts := NewMCTS(tttGame{tictactoe.NewGame()}) + mcts := NewMCTS(newGame) b.ResetTimer() for i := 0; i < b.N; i++ { mcts.SpawnTree().SearchRounds(1000) @@ -160,7 +151,7 @@ func BenchmarkTicTacToe1KRounds(b *testing.B) { } func BenchmarkTicTacToe10KRounds(b *testing.B) { - mcts := NewMCTS(tttGame{tictactoe.NewGame()}) + mcts := NewMCTS(newGame) b.ResetTimer() for i := 0; i < b.N; i++ { mcts.SpawnTree().SearchRounds(10000) @@ -168,7 +159,7 @@ func BenchmarkTicTacToe10KRounds(b *testing.B) { } func 
BenchmarkTicTacToe100KRounds(b *testing.B) { - mcts := NewMCTS(tttGame{tictactoe.NewGame()}) + mcts := NewMCTS(newGame) b.ResetTimer() for i := 0; i < b.N; i++ { mcts.SpawnTree().SearchRounds(100000) diff --git a/models.go b/models.go index 8a62bbf..4934cb1 100644 --- a/models.go +++ b/models.go @@ -5,12 +5,6 @@ import ( "sync" ) -//Action is the interface that represents an action that can be -//performed on a Game. -// -//Any implementation of Action should be comparable (i.e. be a key in a map) -type Action interface{} - //Player is an id for the player type Player int @@ -19,12 +13,12 @@ type Player int //Any implementation of Game should be comparable (i.e. be a key in a map) //and immutable (state cannot change as this package calls any function). type Game interface { - //GetActions returns a list of actions to consider - GetActions() []Action + //Len returns the number of actions to consider. + Len() int - //ApplyAction applies the given action to the game state, + //ApplyAction applies the ith action (0-indexed) to the game state, //and returns a new game state and an error for invalid actions - ApplyAction(Action) (Game, error) + ApplyAction(i int) (Game, error) //Hash returns a unique representation of the state. //Any return value must be comparable. 
@@ -69,7 +63,6 @@ type node struct { state gameState tree *Tree - actions []Action children []*node unvisitedChildren []*node childVisits []float64 diff --git a/search.go b/search.go index 8ca97fd..0e73ef7 100644 --- a/search.go +++ b/search.go @@ -86,13 +86,12 @@ func (n *node) runSimulation() ([]Player, float64) { } func (n *node) expand() { - n.actions = n.state.GetActions() - n.actionCount = len(n.actions) + n.actionCount = n.state.Len() n.unvisitedChildren = make([]*node, n.actionCount) n.children = n.unvisitedChildren n.childVisits = make([]float64, n.actionCount) - for i, a := range n.actions { - newGame, err := n.state.ApplyAction(a) + for i := 0; i < n.actionCount; i++ { + newGame, err := n.state.ApplyAction(i) if err != nil { panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err)) } @@ -118,11 +117,13 @@ func (n *node) simulate() []Player { for !game.IsTerminal() { var err error - actions := game.GetActions() - panicIfNoActions(game, actions) + actions := game.Len() + if actions <= 0 { + panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game)) + } - randomIndex := n.tree.randSource.Intn(len(actions)) - game, err = game.ApplyAction(actions[randomIndex]) + randomIndex := n.tree.randSource.Intn(actions) + game, err = game.ApplyAction(randomIndex) if err != nil { panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err)) } diff --git a/tree.go b/tree.go index 9f5f5b8..6ef5a85 100644 --- a/tree.go +++ b/tree.go @@ -8,7 +8,8 @@ import ( //Search searches the tree for a specified time // //Search will panic if the Game's ApplyAction -//method returns an error +//method returns an error or if any game state's Hash() +//method returns a noncomparable value. 
func (t *Tree) Search(duration time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), duration) defer cancel() @@ -18,7 +19,8 @@ func (t *Tree) Search(duration time.Duration) { //SearchContext searches the tree using a given context // //SearchContext will panic if the Game's ApplyAction -//method returns an error +//method returns an error or if any game state's Hash() +//method returns a noncomparable value. func (t *Tree) SearchContext(ctx context.Context) { for { select { @@ -33,7 +35,8 @@ func (t *Tree) SearchContext(ctx context.Context) { //SearchRounds searches the tree for a specified number of rounds // //SearchRounds will panic if the Game's ApplyAction -//method returns an error +//method returns an error or if any game state's Hash() +//method returns a noncomparable value. func (t *Tree) SearchRounds(rounds int) { for i := 0; i < rounds; i++ { t.search() @@ -69,17 +72,17 @@ func (t Tree) MaxDepth() int { return maxDepth } -func (t *Tree) bestAction() Action { +func (t *Tree) bestAction() int { root := t.current //Select the child with the highest winrate - var bestAction Action + var bestAction int bestWinRate := -1.0 player := root.state.Player() for i := 0; i < root.actionCount; i++ { winRate := root.children[i].nodeScore[player] / root.childVisits[i] if winRate > bestWinRate { - bestAction = root.actions[i] + bestAction = i bestWinRate = winRate } } diff --git a/tree_test.go b/tree_test.go index e5ca635..906a85b 100644 --- a/tree_test.go +++ b/tree_test.go @@ -37,7 +37,8 @@ func TestDepth(t *testing.T) { } func TestSearch(t *testing.T) { - mcts := NewMCTS(tttGame{tictactoe.NewGame()}) + newGame := tictactoe.NewGame() + mcts := NewMCTS(tttGame{newGame, newGame.GetActions()}) tree := mcts.SpawnTree() timeToSearch := 1 * time.Millisecond diff --git a/utils.go b/utils.go deleted file mode 100644 index fc6f434..0000000 --- a/utils.go +++ /dev/null @@ -1,9 +0,0 @@ -package gmcts - -import "fmt" - -func panicIfNoActions(game Game, 
actions []Action) { - if len(actions) == 0 { - panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game)) - } -} -- 2.34.2