M README.md => README.md +2 -12
@@ 55,12 55,7 @@ func runGame() {
mcts.AddTree(tree)
//Get the best action based off of the trees collected from mcts.AddTree()
- bestAction, err := mcts.BestAction()
- if err != nil {
- //...
- //handle error
- //...
- }
+ bestAction := mcts.BestAction()
//Update the game state using the tree's best action
gameState, _ = gameState.ApplyAction(bestAction)
@@ 89,12 84,7 @@ for i := 0; i < concurrentTrees; i++ {
//Wait for the 4 trees to finish searching
wait.Wait()
-bestAction, err := mcts.BestAction()
-if err != nil {
- //...
- //handle error
- //...
-}
+bestAction := mcts.BestAction()
gameState, _ = gameState.ApplyAction(bestAction)
```
A comparable_test.go => comparable_test.go +55 -0
@@ 0,0 1,55 @@
+package gmcts
+
+import "testing"
+
+//comparableState tests the comparable action requirement of gmcts, as
+//the GetActions method returns a noncomparable action.
+type comparableState struct{}
+
+//nonComparableState tests the comparable state requirement of gmcts.
+type nonComparableState struct {
+ comparableState
+ _ []int
+}
+
+func (n comparableState) GetActions() []Action {
+ return []Action{nonComparableState{}}
+}
+
+func (n comparableState) ApplyAction(a Action) (Game, error) {
+ return n, nil
+}
+
+func (n comparableState) IsTerminal() bool {
+ return true
+}
+
+func (n comparableState) Player() Player {
+ return 0
+}
+
+func (n comparableState) Winners() []Player {
+ return nil
+}
+
+func TestNonComparableState(t *testing.T) {
+ //Calling NewMCTS should panic, as the nonComparableState is, as
+ //the name suggests, not comparable.
+ defer func() {
+ if r := recover(); r == nil {
+ t.FailNow()
+ }
+ }()
+ NewMCTS(nonComparableState{})
+}
+
+func TestNonComparableAction(t *testing.T) {
+ //Calling NewMCTS should panic, as the actions from comparableState
+ //are not comparable.
+ defer func() {
+ if r := recover(); r == nil {
+ t.FailNow()
+ }
+ }()
+ NewMCTS(comparableState{})
+}
M mcts.go => mcts.go +38 -31
@@ 1,24 1,31 @@
package gmcts
import (
- "errors"
"math/rand"
+ "reflect"
"sync"
)
-var (
- //ErrNoTrees notifies the callee that the MCTS wrapper has recieved to trees to analyze
- ErrNoTrees = errors.New("gmcts: mcts wrapper has collected to trees to analyze")
-
- //ErrTerminal notifies the callee that the given state is terminal
- ErrTerminal = errors.New("gmcts: given game state is a terminal state, therefore, it cannot return an action")
-
- //ErrNoActions notifies the callee that the given state has <= 0 actions
- ErrNoActions = errors.New("gmcts: given game state is not terminal, yet the state has <= 0 actions to search through")
-)
-
//NewMCTS returns a new MCTS wrapper
+//
+//If either the Game or its Action types are not comparable,
+//this function panics
func NewMCTS(initial Game) *MCTS {
+ //Check if Game type if comparable
+ if !reflect.TypeOf(initial).Comparable() {
+ panic("gmcts: game type is not comparable")
+ }
+
+ //Check if Action type is comparable
+ //We only need to check the actions that can affect the initial gamestate
+ //as those are the only actions that need to be compared.
+ actions := initial.GetActions()
+ for i := range actions {
+ if !reflect.TypeOf(actions[i]).Comparable() {
+ panic("gmcts: action type is not comparable")
+ }
+ }
+
return &MCTS{
init: initial,
trees: make([]*Tree, 0),
@@ 33,7 40,7 @@ func (m *MCTS) SpawnTree() *Tree {
}
//SetSeed sets the seed of the next tree to be spawned.
-//This value is initially set to 0, and increments on each
+//This value is initially set to 1, and increments on each
//spawned tree.
func (m *MCTS) SetSeed(seed int64) {
m.mutex.Lock()
@@ 47,11 54,11 @@ func (m *MCTS) SpawnCustomTree(explorationConst float64) *Tree {
defer m.mutex.Unlock()
t := &Tree{
- gameStates: make(map[gameHash]*node),
+ gameStates: make(map[gameState]*node),
explorationConst: explorationConst,
randSource: rand.New(rand.NewSource(m.seed)),
}
- t.current = initializeNode(gameState{m.init, gameHash{m.init.Hash(), 0}}, t)
+ t.current = initializeNode(gameState{m.init, 0}, t)
m.seed++
return t
@@ 67,34 74,34 @@ func (m *MCTS) AddTree(t *Tree) {
}
//BestAction takes all of the searched trees and returns
-//the index of the best action based on the highest win
-//percentage of each action.
+//the best action based on the highest win percentage
+//of each action.
//
-//BestAction returns ErrNoTrees if it has received no trees
-//to search through, ErrNoActions if the current state
-//it's considering has no legal actions, or ErrTerminal
-//if the current state it's considering is terminal.
-func (m *MCTS) BestAction() (int, error) {
+//BestAction returns nil if it has received no trees
+//to search through or if the current state
+//it's considering has no legal actions or is terminal.
+func (m *MCTS) BestAction() Action {
m.mutex.RLock()
defer m.mutex.RUnlock()
- //Error checking
if len(m.trees) == 0 {
- return -1, ErrNoTrees
- } else if m.init.IsTerminal() {
- return -1, ErrTerminal
- } else if m.init.Len() <= 0 {
- return -1, ErrNoActions
+ return nil
+ }
+
+ //Safe guard set in place in case we're dealing
+ //with a terminal state
+ if m.init.IsTerminal() {
+ return nil
}
//Democracy Section: each tree votes for an action
- actionScore := make([]int, m.init.Len())
+ actionScore := make(map[Action]int)
for _, t := range m.trees {
actionScore[t.bestAction()]++
}
//Democracy Section: the action with the most votes wins
- var bestAction int
+ var bestAction Action
var mostVotes int
for a, s := range actionScore {
if s > mostVotes {
@@ 102,5 109,5 @@ func (m *MCTS) BestAction() (int, error) {
mostVotes = s
}
}
- return bestAction, nil
+ return bestAction
}
M mcts_test.go => mcts_test.go +29 -24
@@ 16,22 16,30 @@ func getPlayerID(ascii byte) Player {
}
type tttGame struct {
- game tictactoe.Game
- actions []tictactoe.Move
+ game tictactoe.Game
}
-func (g tttGame) Len() int {
- return len(g.actions)
-}
+func (g tttGame) GetActions() []Action {
+ gameActions := g.game.GetActions()
+
+ actions := make([]Action, len(gameActions))
-func (g tttGame) ApplyAction(i int) (Game, error) {
- game, err := g.game.ApplyAction(g.actions[i])
+ for i, a := range gameActions {
+ actions[i] = a
+ }
- return tttGame{game, game.GetActions()}, err
+ return actions
}
-func (g tttGame) Hash() interface{} {
- return g.game
+func (g tttGame) ApplyAction(a Action) (Game, error) {
+ move, ok := a.(tictactoe.Move)
+ if !ok {
+ return nil, fmt.Errorf("action not correct type")
+ }
+
+ game, err := g.game.ApplyAction(move)
+
+ return tttGame{game}, err
}
func (g tttGame) Player() Player {
@@ 52,7 60,7 @@ func (g tttGame) Winners() []Player {
}
//Global vars to be checked by other tests
-var newGame, finishedGame tttGame
+var finishedGame tttGame
var firstMove tictactoe.Move
var treeToTest *Tree
@@ 60,10 68,7 @@ var treeToTest *Tree
//the resulting terminal game state into global variables to be used by
//other tests.
func TestMain(m *testing.M) {
- newGame = tttGame{game: tictactoe.NewGame()}
- newGame.actions = newGame.game.GetActions()
-
- game := newGame
+ game := tttGame{tictactoe.NewGame()}
concurrentSearches := 1 //runtime.NumCPU()
var setFirstMove sync.Once
@@ 89,14 94,14 @@ func TestMain(m *testing.M) {
}
wait.Wait()
- bestAction, _ := mcts.BestAction()
+ bestAction := mcts.BestAction()
nextState, _ := game.ApplyAction(bestAction)
game = nextState.(tttGame)
fmt.Println(game.game)
//Save the first action taken
setFirstMove.Do(func() {
- firstMove = newGame.actions[bestAction]
+ firstMove = bestAction.(tictactoe.Move)
})
}
//Save the terminal game state
@@ 125,8 130,8 @@ func TestTicTacToeMiddle(t *testing.T) {
func TestZeroTrees(t *testing.T) {
mcts := NewMCTS(finishedGame)
- bestAction, _ := mcts.BestAction()
- if bestAction != -1 {
+ bestAction := mcts.BestAction()
+ if bestAction != nil {
t.Errorf("gmcts: recieved a best action from no trees: %#v", bestAction)
t.FailNow()
}
@@ 135,15 140,15 @@ func TestZeroTrees(t *testing.T) {
func TestTerminalState(t *testing.T) {
mcts := NewMCTS(finishedGame)
mcts.AddTree(mcts.SpawnTree())
- bestAction, _ := mcts.BestAction()
- if bestAction != -1 {
+ bestAction := mcts.BestAction()
+ if bestAction != nil {
t.Errorf("gmcts: recieved a best action from a terminal state: %#v", bestAction)
t.FailNow()
}
}
func BenchmarkTicTacToe1KRounds(b *testing.B) {
- mcts := NewMCTS(newGame)
+ mcts := NewMCTS(tttGame{tictactoe.NewGame()})
b.ResetTimer()
for i := 0; i < b.N; i++ {
mcts.SpawnTree().SearchRounds(1000)
@@ 151,7 156,7 @@ func BenchmarkTicTacToe1KRounds(b *testing.B) {
}
func BenchmarkTicTacToe10KRounds(b *testing.B) {
- mcts := NewMCTS(newGame)
+ mcts := NewMCTS(tttGame{tictactoe.NewGame()})
b.ResetTimer()
for i := 0; i < b.N; i++ {
mcts.SpawnTree().SearchRounds(10000)
@@ 159,7 164,7 @@ func BenchmarkTicTacToe10KRounds(b *testing.B) {
}
func BenchmarkTicTacToe100KRounds(b *testing.B) {
- mcts := NewMCTS(newGame)
+ mcts := NewMCTS(tttGame{tictactoe.NewGame()})
b.ResetTimer()
for i := 0; i < b.N; i++ {
mcts.SpawnTree().SearchRounds(100000)
M models.go => models.go +17 -18
@@ 5,24 5,26 @@ import (
"sync"
)
+//Action is the interface that represents an action that can be
+//performed on a Game.
+//
+//Any implementation of Action should be comparable (i.e. be a key in a map)
+type Action interface{}
+
//Player is an id for the player
type Player int
//Game is the interface that represents game states.
//
-//Any implementation of Game should be immutable
-//(state cannot change as this package calls any function).
+//Any implementation of Game should be comparable (i.e. be a key in a map)
+//and immutable (state cannot change as this package calls any function).
type Game interface {
- //Len returns the number of actions to consider.
- Len() int
+ //GetActions returns a list of actions to consider
+ GetActions() []Action
- //ApplyAction applies the ith action (0-indexed) to the game state,
+ //ApplyAction applies the given action to the game state,
//and returns a new game state and an error for invalid actions
- ApplyAction(i int) (Game, error)
-
- //Hash returns a unique representation of the state.
- //Any return value must be comparable.
- Hash() interface{}
+ ApplyAction(Action) (Game, error)
//Player returns the player that can take the next action
Player() Player
@@ 37,16 39,12 @@ type Game interface {
type gameState struct {
Game
- gameHash
-}
-
-type gameHash struct {
- hash interface{}
//This is to separate states that seemingly look the same,
//but actually occur on different turn orders. Without this,
- //the directed acyclic graph will become a directed cyclic graph,
- //which this MCTS implementation cannot handle properly.
+ //the directed tree that multiple parent nodes will just
+ //become a directed graph, which this MCTS implementation
+ //cannot handle properly.
turn int
}
@@ 62,6 60,7 @@ type node struct {
state gameState
tree *Tree
+ actions []Action
children []*node
unvisitedChildren []*node
childVisits []float64
@@ 74,7 73,7 @@ type node struct {
//Tree represents a game state tree
type Tree struct {
current *node
- gameStates map[gameHash]*node
+ gameStates map[gameState]*node
explorationConst float64
randSource *rand.Rand
}
M search.go => search.go +14 -16
@@ 20,13 20,12 @@ func initializeNode(g gameState, tree *Tree) *node {
}
}
-//UCT2 algorithm is described in this paper
-//https://www.csse.uwa.edu.au/cig08/Proceedings/papers/8057.pdf
func (n *node) UCT2(i int, p Player) float64 {
exploit := n.children[i].nodeScore[p] / float64(n.children[i].nodeVisits)
- explore := math.Log(float64(n.nodeVisits)) / n.childVisits[i]
- explore = math.Sqrt(explore)
+ explore := math.Sqrt(
+ math.Log(float64(n.nodeVisits)) / n.childVisits[i],
+ )
return exploit + n.tree.explorationConst*explore
}
@@ 87,28 86,29 @@ func (n *node) runSimulation() ([]Player, float64) {
}
func (n *node) expand() {
- n.actionCount = n.state.Len()
+ n.actions = n.state.GetActions()
+ n.actionCount = len(n.actions)
n.unvisitedChildren = make([]*node, n.actionCount)
n.children = n.unvisitedChildren
n.childVisits = make([]float64, n.actionCount)
- for i := 0; i < n.actionCount; i++ {
- newGame, err := n.state.ApplyAction(i)
+ for i, a := range n.actions {
+ newGame, err := n.state.ApplyAction(a)
if err != nil {
panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
}
- newState := gameState{newGame, gameHash{newGame.Hash(), n.state.turn + 1}}
+ newState := gameState{newGame, n.state.turn + 1}
//If we already have a copy in cache, use that and update
//this node and its parents
- if cachedNode, made := n.tree.gameStates[newState.gameHash]; made {
+ if cachedNode, made := n.tree.gameStates[newState]; made {
n.unvisitedChildren[i] = cachedNode
} else {
newNode := initializeNode(newState, n.tree)
n.unvisitedChildren[i] = newNode
//Save node for reuse
- n.tree.gameStates[newState.gameHash] = newNode
+ n.tree.gameStates[newState] = newNode
}
}
}
@@ 118,13 118,11 @@ func (n *node) simulate() []Player {
for !game.IsTerminal() {
var err error
- actions := game.Len()
- if actions <= 0 {
- panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
- }
+ actions := game.GetActions()
+ panicIfNoActions(game, actions)
- randomIndex := n.tree.randSource.Intn(actions)
- game, err = game.ApplyAction(randomIndex)
+ randomIndex := n.tree.randSource.Intn(len(actions))
+ game, err = game.ApplyAction(actions[randomIndex])
if err != nil {
panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
}
M tree.go => tree.go +9 -12
@@ 8,8 8,7 @@ import (
//Search searches the tree for a specified time
//
//Search will panic if the Game's ApplyAction
-//method returns an error or if any game state's Hash()
-//method returns a noncomparable value.
+//method returns an error
func (t *Tree) Search(duration time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
@@ 19,8 18,7 @@ func (t *Tree) Search(duration time.Duration) {
//SearchContext searches the tree using a given context
//
//SearchContext will panic if the Game's ApplyAction
-//method returns an error or if any game state's Hash()
-//method returns a noncomparable value.
+//method returns an error
func (t *Tree) SearchContext(ctx context.Context) {
for {
select {
@@ 35,8 33,7 @@ func (t *Tree) SearchContext(ctx context.Context) {
//SearchRounds searches the tree for a specified number of rounds
//
//SearchRounds will panic if the Game's ApplyAction
-//method returns an error or if any game state's Hash()
-//method returns a noncomparable value.
+//method returns an error
func (t *Tree) SearchRounds(rounds int) {
for i := 0; i < rounds; i++ {
t.search()
@@ 64,25 61,25 @@ func (t Tree) Nodes() int {
//this tree searched through.
func (t Tree) MaxDepth() int {
maxDepth := 0
- for _, node := range t.gameStates {
- if node.state.turn > maxDepth {
- maxDepth = node.state.turn
+ for state := range t.gameStates {
+ if state.turn > maxDepth {
+ maxDepth = state.turn
}
}
return maxDepth
}
-func (t *Tree) bestAction() int {
+func (t *Tree) bestAction() Action {
root := t.current
//Select the child with the highest winrate
- var bestAction int
+ var bestAction Action
bestWinRate := -1.0
player := root.state.Player()
for i := 0; i < root.actionCount; i++ {
winRate := root.children[i].nodeScore[player] / root.childVisits[i]
if winRate > bestWinRate {
- bestAction = i
+ bestAction = root.actions[i]
bestWinRate = winRate
}
}
M tree_test.go => tree_test.go +1 -2
@@ 37,8 37,7 @@ func TestDepth(t *testing.T) {
}
func TestSearch(t *testing.T) {
- newGame := tictactoe.NewGame()
- mcts := NewMCTS(tttGame{newGame, newGame.GetActions()})
+ mcts := NewMCTS(tttGame{tictactoe.NewGame()})
tree := mcts.SpawnTree()
timeToSearch := 1 * time.Millisecond
A utils.go => utils.go +9 -0
@@ 0,0 1,9 @@
+package gmcts
+
+import "fmt"
+
+func panicIfNoActions(game Game, actions []Action) {
+ if len(actions) == 0 {
+ panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
+ }
+}
A v2/LICENSE => v2/LICENSE +25 -0
@@ 0,0 1,25 @@
+Copyright (c) 2020 bonbon. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+ 2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+ 3. Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
A v2/README.md => v2/README.md +120 -0
@@ 0,0 1,120 @@
+[](https://godoc.org/git.sr.ht/~bonbon/gmcts/v2)
+
+GMCTS - Monte-Carlo Tree Search (the g stands for whatever you want it to mean :^) )
+====================================================================================
+
+GMCTS is an implementation of the Monte-Carlo Tree Search algorithm
+with support for any deterministic game.
+
+How To Install
+==============
+
+This project requires Go 1.7+ to run. To install, use `go get`:
+
+```bash
+go get git.sr.ht/~bonbon/gmcts
+```
+
+Alternatively, you can clone it yourself into your $GOPATH/src/git.sr.ht/~bonbon/ folder to get the latest dev build:
+
+```bash
+git clone https://git.sr.ht/~bonbon/gmcts
+```
+
+How To Use
+==========
+
+```go
+package pkg
+
+import (
+ "git.sr.ht/~bonbon/gmcts/v2"
+)
+
+func NewGame() gmcts.Game {
+ var game gmcts.Game
+ //...
+ //Setup a new game
+ //...
+ return game
+}
+
+func runGame() {
+ gameState := NewGame()
+
+ //MCTS algorithm will play against itself
+ //until a terminal state has been reached
+ for !gameState.IsTerminal() {
+ mcts := gmcts.NewMCTS(gameState)
+
+ //Spawn a new tree and play 1000 game simulations
+ tree := mcts.SpawnTree()
+ tree.SearchRounds(1000)
+
+ //Add the searched tree into the mcts tree collection
+ mcts.AddTree(tree)
+
+ //Get the best action based off of the trees collected from mcts.AddTree()
+ bestAction, err := mcts.BestAction()
+ if err != nil {
+ //...
+ //handle error
+ //...
+ }
+
+ //Update the game state using the tree's best action
+ gameState, _ = gameState.ApplyAction(bestAction)
+ }
+}
+```
+
+If you choose to, you can run multiple trees concurrently.
+
+```go
+concurrentTrees := 4
+
+mcts := gmcts.NewMCTS(gameState)
+
+//Run 4 trees concurrently
+var wait sync.WaitGroup
+wait.Add(concurrentTrees)
+for i := 0; i < concurrentTrees; i++ {
+ go func(){
+ tree := mcts.SpawnTree()
+ tree.SearchRounds(1000)
+ mcts.AddTree(tree)
+ wait.Done()
+ }()
+}
+//Wait for the 4 trees to finish searching
+wait.Wait()
+
+bestAction, err := mcts.BestAction()
+if err != nil {
+ //...
+ //handle error
+ //...
+}
+
+gameState, _ = gameState.ApplyAction(bestAction)
+```
+
+Testing
+=======
+
+You can test this package with `go test`. The test plays a game of tic-tac-toe against itself. The test should:
+
+1. Start the game by placing an x piece in the middle, and
+2. Finish in a draw.
+
+If either of these fail, the test fails. It's a rather neat way to make sure everything works as intended!
+
+Documentation
+=============
+
+Documentation for this package can be found either at [godoc.org](https://godoc.org/git.sr.ht/~bonbon/gmcts/v2) or [pkg.go.dev](https://pkg.go.dev/git.sr.ht/~bonbon/gmcts/v2)
+
+Bug Reports
+===========
+
+Email me at bonbon@bonbon.moe :D<
\ No newline at end of file
A v2/doc.go => v2/doc.go +12 -0
@@ 0,0 1,12 @@
+//Package gmcts is a generic implementation of the
+//Monte-Carlo Tree Search (mcts) algorithm.
+//
+//This package attempts to save memory and time by caching states as to not
+//have duplicate nodes in the search tree. This optimization is efficient for
+//games like tic-tac-toe, checkers, and go among others.
+//
+//This package also allows support for tree parallelization. Trees may
+//be spawned and ran in their own goroutine. After searching, they may be
+//compiled together to produce a more informed action than just searching
+//through one tree.
+package gmcts
A v2/go.mod => v2/go.mod +5 -0
@@ 0,0 1,5 @@
+module git.sr.ht/~bonbon/gmcts/v2
+
+go 1.14
+
+require git.sr.ht/~bonbon/go-tic-tac-toe v0.2.2
A v2/go.sum => v2/go.sum +2 -0
@@ 0,0 1,2 @@
+git.sr.ht/~bonbon/go-tic-tac-toe v0.2.2 h1:YnEUZuMybqGFu2uguNosoF3LCFsLrJ3pPyYI05KrhZA=
+git.sr.ht/~bonbon/go-tic-tac-toe v0.2.2/go.mod h1:np1W9swZFfyLAax6aZg6Q+4766tZPLkRABmeH2J5HVE=
A v2/mcts.go => v2/mcts.go +106 -0
@@ 0,0 1,106 @@
+package gmcts
+
+import (
+ "errors"
+ "math/rand"
+ "sync"
+)
+
+var (
+ //ErrNoTrees notifies the callee that the MCTS wrapper has recieved to trees to analyze
+ ErrNoTrees = errors.New("gmcts: mcts wrapper has collected to trees to analyze")
+
+ //ErrTerminal notifies the callee that the given state is terminal
+ ErrTerminal = errors.New("gmcts: given game state is a terminal state, therefore, it cannot return an action")
+
+ //ErrNoActions notifies the callee that the given state has <= 0 actions
+ ErrNoActions = errors.New("gmcts: given game state is not terminal, yet the state has <= 0 actions to search through")
+)
+
+//NewMCTS returns a new MCTS wrapper
+func NewMCTS(initial Game) *MCTS {
+ return &MCTS{
+ init: initial,
+ trees: make([]*Tree, 0),
+ mutex: new(sync.RWMutex),
+ }
+}
+
+//SpawnTree creates a new search tree. The tree returned uses Sqrt(2) as the
+//exploration constant.
+func (m *MCTS) SpawnTree() *Tree {
+ return m.SpawnCustomTree(DefaultExplorationConst)
+}
+
+//SetSeed sets the seed of the next tree to be spawned.
+//This value is initially set to 0, and increments on each
+//spawned tree.
+func (m *MCTS) SetSeed(seed int64) {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ m.seed = seed
+}
+
+//SpawnCustomTree creates a new search tree with a given exploration constant.
+func (m *MCTS) SpawnCustomTree(explorationConst float64) *Tree {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ t := &Tree{
+ gameStates: make(map[gameHash]*node),
+ explorationConst: explorationConst,
+ randSource: rand.New(rand.NewSource(m.seed)),
+ }
+ t.current = initializeNode(gameState{m.init, gameHash{m.init.Hash(), 0}}, t)
+
+ m.seed++
+ return t
+}
+
+//AddTree adds a searched tree to its list of trees to consider
+//when deciding upon an action to take.
+func (m *MCTS) AddTree(t *Tree) {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ m.trees = append(m.trees, t)
+}
+
+//BestAction takes all of the searched trees and returns
+//the index of the best action based on the highest win
+//percentage of each action.
+//
+//BestAction returns ErrNoTrees if it has received no trees
+//to search through, ErrNoActions if the current state
+//it's considering has no legal actions, or ErrTerminal
+//if the current state it's considering is terminal.
+func (m *MCTS) BestAction() (int, error) {
+ m.mutex.RLock()
+ defer m.mutex.RUnlock()
+
+ //Error checking
+ if len(m.trees) == 0 {
+ return -1, ErrNoTrees
+ } else if m.init.IsTerminal() {
+ return -1, ErrTerminal
+ } else if m.init.Len() <= 0 {
+ return -1, ErrNoActions
+ }
+
+ //Democracy Section: each tree votes for an action
+ actionScore := make([]int, m.init.Len())
+ for _, t := range m.trees {
+ actionScore[t.bestAction()]++
+ }
+
+ //Democracy Section: the action with the most votes wins
+ var bestAction int
+ var mostVotes int
+ for a, s := range actionScore {
+ if s > mostVotes {
+ bestAction = a
+ mostVotes = s
+ }
+ }
+ return bestAction, nil
+}
A v2/mcts_test.go => v2/mcts_test.go +167 -0
@@ 0,0 1,167 @@
+package gmcts
+
+import (
+ "fmt"
+ "sync"
+ "testing"
+
+ tictactoe "git.sr.ht/~bonbon/go-tic-tac-toe"
+)
+
+func getPlayerID(ascii byte) Player {
+ if ascii == 'x' || ascii == 'X' {
+ return Player(0)
+ }
+ return Player(1)
+}
+
+type tttGame struct {
+ game tictactoe.Game
+ actions []tictactoe.Move
+}
+
+func (g tttGame) Len() int {
+ return len(g.actions)
+}
+
+func (g tttGame) ApplyAction(i int) (Game, error) {
+ game, err := g.game.ApplyAction(g.actions[i])
+
+ return tttGame{game, game.GetActions()}, err
+}
+
+func (g tttGame) Hash() interface{} {
+ return g.game
+}
+
+func (g tttGame) Player() Player {
+ return getPlayerID(g.game.Player())
+}
+
+func (g tttGame) IsTerminal() bool {
+ return g.game.IsTerminal()
+}
+
+func (g tttGame) Winners() []Player {
+ winner, _ := g.game.Winner()
+ if winner == '_' {
+ return []Player{Player(0), Player(1)}
+ }
+
+ return []Player{getPlayerID(winner)}
+}
+
+//Global vars to be checked by other tests
+var newGame, finishedGame tttGame
+var firstMove tictactoe.Move
+var treeToTest *Tree
+
+//TestMain runs through a tictactoe game, saving the first move made and
+//the resulting terminal game state into global variables to be used by
+//other tests.
+func TestMain(m *testing.M) {
+ newGame = tttGame{game: tictactoe.NewGame()}
+ newGame.actions = newGame.game.GetActions()
+
+ game := newGame
+ concurrentSearches := 1 //runtime.NumCPU()
+
+ var setFirstMove sync.Once
+ var setTestingTree sync.Once
+
+ for !game.IsTerminal() {
+ mcts := NewMCTS(game)
+
+ var wait sync.WaitGroup
+ wait.Add(concurrentSearches)
+ for i := 0; i < concurrentSearches; i++ {
+ go func() {
+ tree := mcts.SpawnTree()
+ tree.SearchRounds(10000)
+ mcts.AddTree(tree)
+ wait.Done()
+
+ //Set the tree to perform benchmarks on
+ setTestingTree.Do(func() {
+ treeToTest = tree
+ })
+ }()
+ }
+ wait.Wait()
+
+ bestAction, _ := mcts.BestAction()
+ nextState, _ := game.ApplyAction(bestAction)
+ game = nextState.(tttGame)
+ fmt.Println(game.game)
+
+ //Save the first action taken
+ setFirstMove.Do(func() {
+ firstMove = newGame.actions[bestAction]
+ })
+ }
+ //Save the terminal game state
+ finishedGame = game
+
+ m.Run()
+}
+
+func TestTicTacToeDraw(t *testing.T) {
+ //Fail if there's a winner. Because tic-tac-toe is a simple game,
+ //this algorithm should've finished in a draw.
+ if len(finishedGame.Winners()) != 2 {
+ t.Errorf("gmcts: tic-tac-toe game did not end in a draw")
+ t.FailNow()
+ }
+}
+
+func TestTicTacToeMiddle(t *testing.T) {
+ //Fail if the first move doesn't pick the middle square. Because tic-tac-toe
+ //is a simple game, this algorithm should've picked the middle square.
+ if fmt.Sprintf("%v", firstMove) != "{1 1}" {
+ t.Errorf("gmcts: first action is not to take the middle spot: %v", firstMove)
+ t.FailNow()
+ }
+}
+
+func TestZeroTrees(t *testing.T) {
+ mcts := NewMCTS(finishedGame)
+ bestAction, _ := mcts.BestAction()
+ if bestAction != -1 {
+ t.Errorf("gmcts: recieved a best action from no trees: %#v", bestAction)
+ t.FailNow()
+ }
+}
+
+func TestTerminalState(t *testing.T) {
+ mcts := NewMCTS(finishedGame)
+ mcts.AddTree(mcts.SpawnTree())
+ bestAction, _ := mcts.BestAction()
+ if bestAction != -1 {
+ t.Errorf("gmcts: recieved a best action from a terminal state: %#v", bestAction)
+ t.FailNow()
+ }
+}
+
+func BenchmarkTicTacToe1KRounds(b *testing.B) {
+ mcts := NewMCTS(newGame)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mcts.SpawnTree().SearchRounds(1000)
+ }
+}
+
+func BenchmarkTicTacToe10KRounds(b *testing.B) {
+ mcts := NewMCTS(newGame)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mcts.SpawnTree().SearchRounds(10000)
+ }
+}
+
+func BenchmarkTicTacToe100KRounds(b *testing.B) {
+ mcts := NewMCTS(newGame)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mcts.SpawnTree().SearchRounds(100000)
+ }
+}
A v2/models.go => v2/models.go +80 -0
@@ 0,0 1,80 @@
+package gmcts
+
+import (
+ "math/rand"
+ "sync"
+)
+
+//Player is an id for the player
+type Player int
+
+//Game is the interface that represents game states.
+//
+//Any implementation of Game should be immutable
+//(state cannot change as this package calls any function).
+type Game interface {
+ //Len returns the number of actions to consider.
+ Len() int
+
+ //ApplyAction applies the ith action (0-indexed) to the game state,
+ //and returns a new game state and an error for invalid actions
+ ApplyAction(i int) (Game, error)
+
+ //Hash returns a unique representation of the state.
+ //Any return value must be comparable.
+ Hash() interface{}
+
+ //Player returns the player that can take the next action
+ Player() Player
+
+ //IsTerminal returns true if this game state is a terminal state
+ IsTerminal() bool
+
+ //Winners returns a list of players that have won the game if
+ //IsTerminal() returns true
+ Winners() []Player
+}
+
+type gameState struct {
+ Game
+ gameHash
+}
+
+type gameHash struct {
+ hash interface{}
+
+ //This is to separate states that seemingly look the same,
+ //but actually occur on different turn orders. Without this,
+ //the directed acyclic graph will become a directed cyclic graph,
+ //which this MCTS implementation cannot handle properly.
+ turn int
+}
+
+//MCTS contains functionality for the MCTS algorithm
+type MCTS struct {
+ init Game
+ trees []*Tree
+ mutex *sync.RWMutex
+ seed int64
+}
+
+type node struct {
+ state gameState
+ tree *Tree
+
+ children []*node
+ unvisitedChildren []*node
+ childVisits []float64
+ actionCount int
+
+ nodeScore map[Player]float64
+ nodeVisits int
+}
+
+//Tree represents a game state tree
+type Tree struct {
+ current *node
+ gameStates map[gameHash]*node
+ explorationConst float64
+ randSource *rand.Rand
+}
A v2/search.go => v2/search.go +133 -0
@@ 0,0 1,133 @@
+package gmcts
+
+import (
+ "fmt"
+ "math"
+)
+
+const (
+ //DefaultExplorationConst is the default exploration constant of UCB1 Formula
+ //Sqrt(2) is a frequent choice for this constant as specified by
+ //https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
+ DefaultExplorationConst = math.Sqrt2
+)
+
+func initializeNode(g gameState, tree *Tree) *node {
+ return &node{
+ state: g,
+ tree: tree,
+ nodeScore: make(map[Player]float64),
+ }
+}
+
+//UCT2 algorithm is described in this paper
+//https://www.csse.uwa.edu.au/cig08/Proceedings/papers/8057.pdf
+func (n *node) UCT2(i int, p Player) float64 {
+ exploit := n.children[i].nodeScore[p] / float64(n.children[i].nodeVisits)
+
+ explore := math.Log(float64(n.nodeVisits)) / n.childVisits[i]
+ explore = math.Sqrt(explore)
+
+ return exploit + n.tree.explorationConst*explore
+}
+
+func (n *node) runSimulation() ([]Player, float64) {
+ var selectedChildIndex int
+ var winners []Player
+ var scoreToAdd float64
+ var terminalState bool
+
+ //If we have actions, then there's no need to expand.
+ if n.actionCount == 0 {
+ //If we don't have any actions, then either the state
+ //is terminal, or we haven't expanded the node yet.
+ terminalState = n.state.IsTerminal()
+ if !terminalState {
+ n.expand()
+ }
+ }
+
+ if terminalState {
+ //Get the result of the game
+ winners = n.simulate()
+ scoreToAdd = 1.0 / float64(len(winners))
+ } else if len(n.unvisitedChildren) > 0 {
+ //Grab the first unvisited child and run a simulation from that point
+ selectedChildIndex = n.actionCount - len(n.unvisitedChildren)
+ n.children[selectedChildIndex].nodeVisits++
+ n.unvisitedChildren = n.unvisitedChildren[1:]
+
+ winners = n.children[selectedChildIndex].simulate()
+ scoreToAdd = 1.0 / float64(len(winners))
+ } else {
+ //Select the child with the max UCT2 score with the current player
+ //and get the results to add from its selection
+ maxScore := -1.0
+ thisPlayer := n.state.Player()
+ for i := 0; i < n.actionCount; i++ {
+ score := n.UCT2(i, thisPlayer)
+ if score > maxScore {
+ maxScore = score
+ selectedChildIndex = i
+ }
+ }
+ winners, scoreToAdd = n.children[selectedChildIndex].runSimulation()
+ }
+
+ //Update this node along with each parent in this path recursively
+ n.nodeVisits++
+ if n.actionCount != 0 {
+ n.childVisits[selectedChildIndex]++
+ }
+
+ for _, p := range winners {
+ n.nodeScore[p] += scoreToAdd
+ }
+ return winners, scoreToAdd
+}
+
+func (n *node) expand() {
+ n.actionCount = n.state.Len()
+ n.unvisitedChildren = make([]*node, n.actionCount)
+ n.children = n.unvisitedChildren
+ n.childVisits = make([]float64, n.actionCount)
+ for i := 0; i < n.actionCount; i++ {
+ newGame, err := n.state.ApplyAction(i)
+ if err != nil {
+ panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
+ }
+
+ newState := gameState{newGame, gameHash{newGame.Hash(), n.state.turn + 1}}
+
+ //If we already have a copy in cache, use that and update
+ //this node and its parents
+ if cachedNode, made := n.tree.gameStates[newState.gameHash]; made {
+ n.unvisitedChildren[i] = cachedNode
+ } else {
+ newNode := initializeNode(newState, n.tree)
+ n.unvisitedChildren[i] = newNode
+
+ //Save node for reuse
+ n.tree.gameStates[newState.gameHash] = newNode
+ }
+ }
+}
+
+func (n *node) simulate() []Player {
+ game := n.state.Game
+ for !game.IsTerminal() {
+ var err error
+
+ actions := game.Len()
+ if actions <= 0 {
+ panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
+ }
+
+ randomIndex := n.tree.randSource.Intn(actions)
+ game, err = game.ApplyAction(randomIndex)
+ if err != nil {
+ panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
+ }
+ }
+ return game.Winners()
+}
A v2/tree.go => v2/tree.go +91 -0
@@ 0,0 1,91 @@
+package gmcts
+
+import (
+ "context"
+ "time"
+)
+
+//Search searches the tree for a specified time
+//
+//Search will panic if the Game's ApplyAction
+//method returns an error or if any game state's Hash()
+//method returns a noncomparable value.
+func (t *Tree) Search(duration time.Duration) {
+ ctx, cancel := context.WithTimeout(context.Background(), duration)
+ defer cancel()
+ t.SearchContext(ctx)
+}
+
+//SearchContext searches the tree using a given context
+//
+//SearchContext will panic if the Game's ApplyAction
+//method returns an error or if any game state's Hash()
+//method returns a noncomparable value.
+func (t *Tree) SearchContext(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ t.search()
+ }
+ }
+}
+
+//SearchRounds searches the tree for a specified number of rounds
+//
+//SearchRounds will panic if the Game's ApplyAction
+//method returns an error or if any game state's Hash()
+//method returns a noncomparable value.
+func (t *Tree) SearchRounds(rounds int) {
+ for i := 0; i < rounds; i++ {
+ t.search()
+ }
+}
+
+//search performs 1 round of the MCTS algorithm
+func (t *Tree) search() {
+ t.current.runSimulation()
+}
+
+//Rounds returns the number of MCTS rounds were performed
+//on this tree.
+func (t Tree) Rounds() int {
+ return t.current.nodeVisits
+}
+
+//Nodes returns the number of nodes created on this tree.
+func (t Tree) Nodes() int {
+ return len(t.gameStates)
+}
+
+//MaxDepth returns the maximum depth of this tree.
+//The value can be thought of as the amount of moves ahead
+//this tree searched through.
+func (t Tree) MaxDepth() int {
+ maxDepth := 0
+ for _, node := range t.gameStates {
+ if node.state.turn > maxDepth {
+ maxDepth = node.state.turn
+ }
+ }
+ return maxDepth
+}
+
+func (t *Tree) bestAction() int {
+ root := t.current
+
+ //Select the child with the highest winrate
+ var bestAction int
+ bestWinRate := -1.0
+ player := root.state.Player()
+ for i := 0; i < root.actionCount; i++ {
+ winRate := root.children[i].nodeScore[player] / root.childVisits[i]
+ if winRate > bestWinRate {
+ bestAction = i
+ bestWinRate = winRate
+ }
+ }
+
+ return bestAction
+}
A v2/tree_test.go => v2/tree_test.go +53 -0
@@ 0,0 1,53 @@
+package gmcts
+
+import (
+ "testing"
+ "time"
+
+ tictactoe "git.sr.ht/~bonbon/go-tic-tac-toe"
+)
+
+func TestRounds(t *testing.T) {
+ rounds := treeToTest.Rounds()
+ if rounds != 10000 {
+ t.Errorf("Tree performed %d rounds: wanted 1", rounds)
+ t.FailNow()
+ }
+}
+
+func TestNodes(t *testing.T) {
+ //The amount of nodes in the tree should not exceed the
+ //amount of mcts rounds performed on the tree.
+ rounds := treeToTest.Rounds()
+ nodes := treeToTest.Nodes()
+ if nodes > rounds {
+ t.Errorf("Tree has %d nodes: wanted <= %d", nodes, rounds)
+ t.FailNow()
+ }
+}
+
+func TestDepth(t *testing.T) {
+ //Because tictactoe is a simple game, the
+ //tree should have looked 9 moves ahead.
+ depth := treeToTest.MaxDepth()
+ if depth != 9 {
+ t.Errorf("Tree has depth %d: wanted 0", depth)
+ t.FailNow()
+ }
+}
+
+func TestSearch(t *testing.T) {
+ newGame := tictactoe.NewGame()
+ mcts := NewMCTS(tttGame{newGame, newGame.GetActions()})
+ tree := mcts.SpawnTree()
+
+ timeToSearch := 1 * time.Millisecond
+ t0 := time.Now()
+ tree.Search(timeToSearch)
+ td := time.Now().Sub(t0)
+
+ if td < timeToSearch {
+ t.Errorf("Tree was searched for %s: wanted >= %s", td, timeToSearch)
+ t.FailNow()
+ }
+}