D comparable_test.go => comparable_test.go +0 -59
@@ 1,59 0,0 @@
-package gmcts
-
-import "testing"
-
-//comparableState tests the comparable action requirement of gmcts, as
-//the GetActions method returns a noncomparable action.
-type comparableState struct{}
-
-//nonComparableState tests the comparable state requirement of gmcts.
-type nonComparableState struct {
- comparableState
- _ []int
-}
-
-func (n comparableState) GetActions() []Action {
- return []Action{nonComparableState{}}
-}
-
-func (n comparableState) ApplyAction(a Action) (Game, error) {
- return n, nil
-}
-
-func (n comparableState) IsTerminal() bool {
- return true
-}
-
-func (n comparableState) Hash() interface{} {
- return 0
-}
-
-func (n comparableState) Player() Player {
- return 0
-}
-
-func (n comparableState) Winners() []Player {
- return nil
-}
-
-func TestNonComparableState(t *testing.T) {
- //Calling NewMCTS should panic, as the nonComparableState is, as
- //the name suggests, not comparable.
- defer func() {
- if r := recover(); r == nil {
- t.FailNow()
- }
- }()
- NewMCTS(nonComparableState{})
-}
-
-func TestNonComparableAction(t *testing.T) {
- //Calling NewMCTS should panic, as the actions from comparableState
- //are not comparable.
- defer func() {
- if r := recover(); r == nil {
- t.FailNow()
- }
- }()
- NewMCTS(comparableState{})
-}
M mcts.go => mcts.go +5 -24
@@ 2,30 2,11 @@ package gmcts
import (
"math/rand"
- "reflect"
"sync"
)
//NewMCTS returns a new MCTS wrapper
-//
-//If either the Game or its Action types are not comparable,
-//this function panics
func NewMCTS(initial Game) *MCTS {
- //Check if Game type if comparable
- if !reflect.TypeOf(initial).Comparable() {
- panic("gmcts: game type is not comparable")
- }
-
- //Check if Action type is comparable
- //We only need to check the actions that can affect the initial gamestate
- //as those are the only actions that need to be compared.
- actions := initial.GetActions()
- for i := range actions {
- if !reflect.TypeOf(actions[i]).Comparable() {
- panic("gmcts: action type is not comparable")
- }
- }
-
return &MCTS{
init: initial,
trees: make([]*Tree, 0),
@@ 80,28 61,28 @@ func (m *MCTS) AddTree(t *Tree) {
-//BestAction returns nil if it has received no trees
+//BestAction returns -1 if it has received no trees
//to search through or if the current state
//it's considering has no legal actions or is terminal.
-func (m *MCTS) BestAction() Action {
+func (m *MCTS) BestAction() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
if len(m.trees) == 0 {
- return nil
+ return -1
}
//Safe guard set in place in case we're dealing
//with a terminal state
if m.init.IsTerminal() {
- return nil
+ return -1
}
//Democracy Section: each tree votes for an action
- actionScore := make(map[Action]int)
+ actionScore := make([]int, m.init.Len())
for _, t := range m.trees {
actionScore[t.bestAction()]++
}
//Democracy Section: the action with the most votes wins
- var bestAction Action
+ var bestAction int
var mostVotes int
for a, s := range actionScore {
if s > mostVotes {
M mcts_test.go => mcts_test.go +18 -27
@@ 16,30 16,18 @@ func getPlayerID(ascii byte) Player {
}
type tttGame struct {
- game tictactoe.Game
+ game tictactoe.Game
+ actions []tictactoe.Move
}
-func (g tttGame) GetActions() []Action {
- gameActions := g.game.GetActions()
-
- actions := make([]Action, len(gameActions))
-
- for i, a := range gameActions {
- actions[i] = a
- }
-
- return actions
+func (g tttGame) Len() int {
+ return len(g.actions)
}
-func (g tttGame) ApplyAction(a Action) (Game, error) {
- move, ok := a.(tictactoe.Move)
- if !ok {
- return nil, fmt.Errorf("action not correct type")
- }
-
- game, err := g.game.ApplyAction(move)
+func (g tttGame) ApplyAction(i int) (Game, error) {
+ game, err := g.game.ApplyAction(g.actions[i])
- return tttGame{game}, err
+ return tttGame{game, game.GetActions()}, err
}
func (g tttGame) Hash() interface{} {
@@ 64,7 52,7 @@ func (g tttGame) Winners() []Player {
}
//Global vars to be checked by other tests
-var finishedGame tttGame
+var newGame, finishedGame tttGame
var firstMove tictactoe.Move
var treeToTest *Tree
@@ 72,7 60,10 @@ var treeToTest *Tree
//the resulting terminal game state into global variables to be used by
//other tests.
func TestMain(m *testing.M) {
- game := tttGame{tictactoe.NewGame()}
+ newGame = tttGame{game: tictactoe.NewGame()}
+ newGame.actions = newGame.game.GetActions()
+
+ game := newGame
concurrentSearches := 1 //runtime.NumCPU()
var setFirstMove sync.Once
@@ 105,7 96,7 @@ func TestMain(m *testing.M) {
//Save the first action taken
setFirstMove.Do(func() {
- firstMove = bestAction.(tictactoe.Move)
+ firstMove = newGame.actions[bestAction]
})
}
//Save the terminal game state
@@ 135,7 126,7 @@ func TestTicTacToeMiddle(t *testing.T) {
func TestZeroTrees(t *testing.T) {
mcts := NewMCTS(finishedGame)
bestAction := mcts.BestAction()
- if bestAction != nil {
+ if bestAction != -1 {
t.Errorf("gmcts: recieved a best action from no trees: %#v", bestAction)
t.FailNow()
}
@@ 145,14 136,14 @@ func TestTerminalState(t *testing.T) {
mcts := NewMCTS(finishedGame)
mcts.AddTree(mcts.SpawnTree())
bestAction := mcts.BestAction()
- if bestAction != nil {
+ if bestAction != -1 {
t.Errorf("gmcts: recieved a best action from a terminal state: %#v", bestAction)
t.FailNow()
}
}
func BenchmarkTicTacToe1KRounds(b *testing.B) {
- mcts := NewMCTS(tttGame{tictactoe.NewGame()})
+ mcts := NewMCTS(newGame)
b.ResetTimer()
for i := 0; i < b.N; i++ {
mcts.SpawnTree().SearchRounds(1000)
@@ 160,7 151,7 @@ func BenchmarkTicTacToe1KRounds(b *testing.B) {
}
func BenchmarkTicTacToe10KRounds(b *testing.B) {
- mcts := NewMCTS(tttGame{tictactoe.NewGame()})
+ mcts := NewMCTS(newGame)
b.ResetTimer()
for i := 0; i < b.N; i++ {
mcts.SpawnTree().SearchRounds(10000)
@@ 168,7 159,7 @@ func BenchmarkTicTacToe10KRounds(b *testing.B) {
}
func BenchmarkTicTacToe100KRounds(b *testing.B) {
- mcts := NewMCTS(tttGame{tictactoe.NewGame()})
+ mcts := NewMCTS(newGame)
b.ResetTimer()
for i := 0; i < b.N; i++ {
mcts.SpawnTree().SearchRounds(100000)
M models.go => models.go +4 -11
@@ 5,12 5,6 @@ import (
"sync"
)
-//Action is the interface that represents an action that can be
-//performed on a Game.
-//
-//Any implementation of Action should be comparable (i.e. be a key in a map)
-type Action interface{}
-
//Player is an id for the player
type Player int
@@ 19,12 13,12 @@ type Player int
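-//Any implementation of Game should be comparable (i.e. be a key in a map)
-//and immutable (state cannot change as this package calls any function).
+//Any implementation of Game should be immutable (state cannot change as this
+//package calls any function); only the value returned by Hash must be comparable.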
type Game interface {
- //GetActions returns a list of actions to consider
- GetActions() []Action
+ //Len returns the number of actions to consider; valid action indices are 0 through Len()-1.
+ Len() int
- //ApplyAction applies the given action to the game state,
+ //ApplyAction applies the ith action (0-indexed) to the game state,
//and returns a new game state and an error for invalid actions
- ApplyAction(Action) (Game, error)
+ ApplyAction(i int) (Game, error)
//Hash returns a unique representation of the state.
//Any return value must be comparable.
@@ 69,7 63,6 @@ type node struct {
state gameState
tree *Tree
- actions []Action
children []*node
unvisitedChildren []*node
childVisits []float64
M search.go => search.go +9 -8
@@ 86,13 86,12 @@ func (n *node) runSimulation() ([]Player, float64) {
}
func (n *node) expand() {
- n.actions = n.state.GetActions()
- n.actionCount = len(n.actions)
+ n.actionCount = n.state.Len()
n.unvisitedChildren = make([]*node, n.actionCount)
n.children = n.unvisitedChildren
n.childVisits = make([]float64, n.actionCount)
- for i, a := range n.actions {
- newGame, err := n.state.ApplyAction(a)
+ for i := 0; i < n.actionCount; i++ {
+ newGame, err := n.state.ApplyAction(i)
if err != nil {
panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
}
@@ 118,11 117,13 @@ func (n *node) simulate() []Player {
for !game.IsTerminal() {
var err error
- actions := game.GetActions()
- panicIfNoActions(game, actions)
+ actions := game.Len()
+ if actions <= 0 {
+ panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
+ }
- randomIndex := n.tree.randSource.Intn(len(actions))
- game, err = game.ApplyAction(actions[randomIndex])
+ randomIndex := n.tree.randSource.Intn(actions)
+ game, err = game.ApplyAction(randomIndex)
if err != nil {
panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
}
M tree.go => tree.go +9 -6
@@ 8,7 8,8 @@ import (
//Search searches the tree for a specified time
//
//Search will panic if the Game's ApplyAction
-//method returns an error
+//method returns an error or if any game state's Hash()
+//method returns a noncomparable value.
func (t *Tree) Search(duration time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
@@ 18,7 19,8 @@ func (t *Tree) Search(duration time.Duration) {
//SearchContext searches the tree using a given context
//
//SearchContext will panic if the Game's ApplyAction
-//method returns an error
+//method returns an error or if any game state's Hash()
+//method returns a noncomparable value.
func (t *Tree) SearchContext(ctx context.Context) {
for {
select {
@@ 33,7 35,8 @@ func (t *Tree) SearchContext(ctx context.Context) {
//SearchRounds searches the tree for a specified number of rounds
//
//SearchRounds will panic if the Game's ApplyAction
-//method returns an error
+//method returns an error or if any game state's Hash()
+//method returns a noncomparable value.
func (t *Tree) SearchRounds(rounds int) {
for i := 0; i < rounds; i++ {
t.search()
@@ 69,17 72,17 @@ func (t Tree) MaxDepth() int {
return maxDepth
}
-func (t *Tree) bestAction() Action {
+func (t *Tree) bestAction() int {
root := t.current
//Select the child with the highest winrate
- var bestAction Action
+ var bestAction int
bestWinRate := -1.0
player := root.state.Player()
for i := 0; i < root.actionCount; i++ {
winRate := root.children[i].nodeScore[player] / root.childVisits[i]
if winRate > bestWinRate {
- bestAction = root.actions[i]
+ bestAction = i
bestWinRate = winRate
}
}
M tree_test.go => tree_test.go +2 -1
@@ 37,7 37,8 @@ func TestDepth(t *testing.T) {
}
func TestSearch(t *testing.T) {
- mcts := NewMCTS(tttGame{tictactoe.NewGame()})
+ newGame := tictactoe.NewGame()
+ mcts := NewMCTS(tttGame{newGame, newGame.GetActions()})
tree := mcts.SpawnTree()
timeToSearch := 1 * time.Millisecond
D utils.go => utils.go +0 -9
@@ 1,9 0,0 @@
-package gmcts
-
-import "fmt"
-
-func panicIfNoActions(game Game, actions []Action) {
- if len(actions) == 0 {
- panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
- }
-}