~bonbon/gmcts

f7bd5eca4c27c40450cc297a73a247e43b90885a — bonbon 1 year, 2 months ago 1e8a724 master v2.0.0
move v2.0.0 changes to v2 folder

all changes post-v1.2.1 were reverted in the main directory
M README.md => README.md +2 -12
@@ 55,12 55,7 @@ func runGame() {
        mcts.AddTree(tree)

        //Get the best action based off of the trees collected from mcts.AddTree()
        bestAction, err := mcts.BestAction()
        if err != nil {
            //...
            //handle error
            //...
        }
        bestAction := mcts.BestAction()

        //Update the game state using the tree's best action
        gameState, _ = gameState.ApplyAction(bestAction)


@@ 89,12 84,7 @@ for i := 0; i < concurrentTrees; i++ {
//Wait for the 4 trees to finish searching
wait.Wait()

bestAction, err := mcts.BestAction()
if err != nil {
    //...
    //handle error
    //...
}
bestAction := mcts.BestAction()

gameState, _ = gameState.ApplyAction(bestAction)
```

A comparable_test.go => comparable_test.go +55 -0
@@ 0,0 1,55 @@
package gmcts

import "testing"

//comparableState is a comparable game state stub whose GetActions
//method returns a noncomparable action. It is used to test the
//comparable action requirement of gmcts.
type comparableState struct{}

//nonComparableState tests the comparable state requirement of gmcts.
//
//It embeds comparableState to reuse its methods, and carries a blank
//slice field, which makes the struct type itself not comparable.
type nonComparableState struct {
	comparableState
	_ []int
}

//GetActions returns a single action whose type is not comparable.
func (n comparableState) GetActions() []Action {
	var action Action = nonComparableState{}
	return []Action{action}
}

//ApplyAction ignores the given action and hands back the
//receiver unchanged, never reporting an error.
func (n comparableState) ApplyAction(a Action) (Game, error) {
	var next Game = n
	return next, nil
}

//IsTerminal always reports true for this stub state.
func (n comparableState) IsTerminal() bool {
	const terminal = true
	return terminal
}

//Player always returns the player with id 0.
func (n comparableState) Player() Player {
	return Player(0)
}

//Winners returns a nil list, as this stub state has no winners.
func (n comparableState) Winners() []Player {
	var winners []Player
	return winners
}

//TestNonComparableState expects NewMCTS to panic when it is handed a
//nonComparableState, whose type is (as the name suggests) not comparable.
func TestNonComparableState(t *testing.T) {
	defer func() {
		recovered := recover()
		if recovered != nil {
			//The expected panic happened
			return
		}
		t.FailNow()
	}()
	NewMCTS(nonComparableState{})
}

//TestNonComparableAction expects NewMCTS to panic when it is handed a
//comparableState, because the actions that state returns are not comparable.
func TestNonComparableAction(t *testing.T) {
	defer func() {
		recovered := recover()
		if recovered != nil {
			//The expected panic happened
			return
		}
		t.FailNow()
	}()
	NewMCTS(comparableState{})
}

M mcts.go => mcts.go +38 -31
@@ 1,24 1,31 @@
package gmcts

import (
	"errors"
	"math/rand"
	"reflect"
	"sync"
)

var (
	//ErrNoTrees notifies the callee that the MCTS wrapper has recieved to trees to analyze
	ErrNoTrees = errors.New("gmcts: mcts wrapper has collected to trees to analyze")

	//ErrTerminal notifies the callee that the given state is terminal
	ErrTerminal = errors.New("gmcts: given game state is a terminal state, therefore, it cannot return an action")

	//ErrNoActions notifies the callee that the given state has <= 0 actions
	ErrNoActions = errors.New("gmcts: given game state is not terminal, yet the state has <= 0 actions to search through")
)

//NewMCTS returns a new MCTS wrapper
//
//If either the Game or its Action types are not comparable,
//this function panics
func NewMCTS(initial Game) *MCTS {
	//Check if Game type if comparable
	if !reflect.TypeOf(initial).Comparable() {
		panic("gmcts: game type is not comparable")
	}

	//Check if Action type is comparable
	//We only need to check the actions that can affect the initial gamestate
	//as those are the only actions that need to be compared.
	actions := initial.GetActions()
	for i := range actions {
		if !reflect.TypeOf(actions[i]).Comparable() {
			panic("gmcts: action type is not comparable")
		}
	}

	return &MCTS{
		init:  initial,
		trees: make([]*Tree, 0),


@@ 33,7 40,7 @@ func (m *MCTS) SpawnTree() *Tree {
}

//SetSeed sets the seed of the next tree to be spawned.
//This value is initially set to 0, and increments on each
//This value is initially set to 1, and increments on each
//spawned tree.
func (m *MCTS) SetSeed(seed int64) {
	m.mutex.Lock()


@@ 47,11 54,11 @@ func (m *MCTS) SpawnCustomTree(explorationConst float64) *Tree {
	defer m.mutex.Unlock()

	t := &Tree{
		gameStates:       make(map[gameHash]*node),
		gameStates:       make(map[gameState]*node),
		explorationConst: explorationConst,
		randSource:       rand.New(rand.NewSource(m.seed)),
	}
	t.current = initializeNode(gameState{m.init, gameHash{m.init.Hash(), 0}}, t)
	t.current = initializeNode(gameState{m.init, 0}, t)

	m.seed++
	return t


@@ 67,34 74,34 @@ func (m *MCTS) AddTree(t *Tree) {
}

//BestAction takes all of the searched trees and returns
//the index of the best action based on the highest win
//percentage of each action.
//the best action based on the highest win percentage
//of each action.
//
//BestAction returns ErrNoTrees if it has received no trees
//to search through, ErrNoActions if the current state
//it's considering has no legal actions, or ErrTerminal
//if the current state it's considering is terminal.
func (m *MCTS) BestAction() (int, error) {
//BestAction returns nil if it has received no trees
//to search through or if the current state
//it's considering has no legal actions or is terminal.
func (m *MCTS) BestAction() Action {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	//Error checking
	if len(m.trees) == 0 {
		return -1, ErrNoTrees
	} else if m.init.IsTerminal() {
		return -1, ErrTerminal
	} else if m.init.Len() <= 0 {
		return -1, ErrNoActions
		return nil
	}

	//Safe guard set in place in case we're dealing
	//with a terminal state
	if m.init.IsTerminal() {
		return nil
	}

	//Democracy Section: each tree votes for an action
	actionScore := make([]int, m.init.Len())
	actionScore := make(map[Action]int)
	for _, t := range m.trees {
		actionScore[t.bestAction()]++
	}

	//Democracy Section: the action with the most votes wins
	var bestAction int
	var bestAction Action
	var mostVotes int
	for a, s := range actionScore {
		if s > mostVotes {


@@ 102,5 109,5 @@ func (m *MCTS) BestAction() (int, error) {
			mostVotes = s
		}
	}
	return bestAction, nil
	return bestAction
}

M mcts_test.go => mcts_test.go +29 -24
@@ 16,22 16,30 @@ func getPlayerID(ascii byte) Player {
}

type tttGame struct {
	game    tictactoe.Game
	actions []tictactoe.Move
	game tictactoe.Game
}

func (g tttGame) Len() int {
	return len(g.actions)
}
func (g tttGame) GetActions() []Action {
	gameActions := g.game.GetActions()

	actions := make([]Action, len(gameActions))

func (g tttGame) ApplyAction(i int) (Game, error) {
	game, err := g.game.ApplyAction(g.actions[i])
	for i, a := range gameActions {
		actions[i] = a
	}

	return tttGame{game, game.GetActions()}, err
	return actions
}

func (g tttGame) Hash() interface{} {
	return g.game
func (g tttGame) ApplyAction(a Action) (Game, error) {
	move, ok := a.(tictactoe.Move)
	if !ok {
		return nil, fmt.Errorf("action not correct type")
	}

	game, err := g.game.ApplyAction(move)

	return tttGame{game}, err
}

func (g tttGame) Player() Player {


@@ 52,7 60,7 @@ func (g tttGame) Winners() []Player {
}

//Global vars to be checked by other tests
var newGame, finishedGame tttGame
var finishedGame tttGame
var firstMove tictactoe.Move
var treeToTest *Tree



@@ 60,10 68,7 @@ var treeToTest *Tree
//the resulting terminal game state into global variables to be used by
//other tests.
func TestMain(m *testing.M) {
	newGame = tttGame{game: tictactoe.NewGame()}
	newGame.actions = newGame.game.GetActions()

	game := newGame
	game := tttGame{tictactoe.NewGame()}
	concurrentSearches := 1 //runtime.NumCPU()

	var setFirstMove sync.Once


@@ 89,14 94,14 @@ func TestMain(m *testing.M) {
		}
		wait.Wait()

		bestAction, _ := mcts.BestAction()
		bestAction := mcts.BestAction()
		nextState, _ := game.ApplyAction(bestAction)
		game = nextState.(tttGame)
		fmt.Println(game.game)

		//Save the first action taken
		setFirstMove.Do(func() {
			firstMove = newGame.actions[bestAction]
			firstMove = bestAction.(tictactoe.Move)
		})
	}
	//Save the terminal game state


@@ 125,8 130,8 @@ func TestTicTacToeMiddle(t *testing.T) {

func TestZeroTrees(t *testing.T) {
	mcts := NewMCTS(finishedGame)
	bestAction, _ := mcts.BestAction()
	if bestAction != -1 {
	bestAction := mcts.BestAction()
	if bestAction != nil {
		t.Errorf("gmcts: recieved a best action from no trees: %#v", bestAction)
		t.FailNow()
	}


@@ 135,15 140,15 @@ func TestZeroTrees(t *testing.T) {
func TestTerminalState(t *testing.T) {
	mcts := NewMCTS(finishedGame)
	mcts.AddTree(mcts.SpawnTree())
	bestAction, _ := mcts.BestAction()
	if bestAction != -1 {
	bestAction := mcts.BestAction()
	if bestAction != nil {
		t.Errorf("gmcts: recieved a best action from a terminal state: %#v", bestAction)
		t.FailNow()
	}
}

func BenchmarkTicTacToe1KRounds(b *testing.B) {
	mcts := NewMCTS(newGame)
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mcts.SpawnTree().SearchRounds(1000)


@@ 151,7 156,7 @@ func BenchmarkTicTacToe1KRounds(b *testing.B) {
}

func BenchmarkTicTacToe10KRounds(b *testing.B) {
	mcts := NewMCTS(newGame)
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mcts.SpawnTree().SearchRounds(10000)


@@ 159,7 164,7 @@ func BenchmarkTicTacToe10KRounds(b *testing.B) {
}

func BenchmarkTicTacToe100KRounds(b *testing.B) {
	mcts := NewMCTS(newGame)
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		mcts.SpawnTree().SearchRounds(100000)

M models.go => models.go +17 -18
@@ 5,24 5,26 @@ import (
	"sync"
)

//Action is the interface that represents an action that can be
//performed on a Game.
//
//Any implementation of Action should be comparable (i.e. be a key in a map)
type Action interface{}

//Player is an id for the player
type Player int

//Game is the interface that represents game states.
//
//Any implementation of Game should be immutable
//(state cannot change as this package calls any function).
//Any implementation of Game should be comparable (i.e. be a key in a map)
//and immutable (state cannot change as this package calls any function).
type Game interface {
	//Len returns the number of actions to consider.
	Len() int
	//GetActions returns a list of actions to consider
	GetActions() []Action

	//ApplyAction applies the ith action (0-indexed) to the game state,
	//ApplyAction applies the given action to the game state,
	//and returns a new game state and an error for invalid actions
	ApplyAction(i int) (Game, error)

	//Hash returns a unique representation of the state.
	//Any return value must be comparable.
	Hash() interface{}
	ApplyAction(Action) (Game, error)

	//Player returns the player that can take the next action
	Player() Player


@@ 37,16 39,12 @@ type Game interface {

type gameState struct {
	Game
	gameHash
}

type gameHash struct {
	hash interface{}

	//This is to separate states that seemingly look the same,
	//but actually occur on different turn orders. Without this,
	//the directed acyclic graph will become a directed cyclic graph,
	//which this MCTS implementation cannot handle properly.
	//the directed tree that multiple parent nodes will just
	//become a directed graph, which this MCTS implementation
	//cannot handle properly.
	turn int
}



@@ 62,6 60,7 @@ type node struct {
	state gameState
	tree  *Tree

	actions           []Action
	children          []*node
	unvisitedChildren []*node
	childVisits       []float64


@@ 74,7 73,7 @@ type node struct {
//Tree represents a game state tree
type Tree struct {
	current          *node
	gameStates       map[gameHash]*node
	gameStates       map[gameState]*node
	explorationConst float64
	randSource       *rand.Rand
}

M search.go => search.go +14 -16
@@ 20,13 20,12 @@ func initializeNode(g gameState, tree *Tree) *node {
	}
}

//UCT2 algorithm is described in this paper
//https://www.csse.uwa.edu.au/cig08/Proceedings/papers/8057.pdf
func (n *node) UCT2(i int, p Player) float64 {
	exploit := n.children[i].nodeScore[p] / float64(n.children[i].nodeVisits)

	explore := math.Log(float64(n.nodeVisits)) / n.childVisits[i]
	explore = math.Sqrt(explore)
	explore := math.Sqrt(
		math.Log(float64(n.nodeVisits)) / n.childVisits[i],
	)

	return exploit + n.tree.explorationConst*explore
}


@@ 87,28 86,29 @@ func (n *node) runSimulation() ([]Player, float64) {
}

func (n *node) expand() {
	n.actionCount = n.state.Len()
	n.actions = n.state.GetActions()
	n.actionCount = len(n.actions)
	n.unvisitedChildren = make([]*node, n.actionCount)
	n.children = n.unvisitedChildren
	n.childVisits = make([]float64, n.actionCount)
	for i := 0; i < n.actionCount; i++ {
		newGame, err := n.state.ApplyAction(i)
	for i, a := range n.actions {
		newGame, err := n.state.ApplyAction(a)
		if err != nil {
			panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
		}

		newState := gameState{newGame, gameHash{newGame.Hash(), n.state.turn + 1}}
		newState := gameState{newGame, n.state.turn + 1}

		//If we already have a copy in cache, use that and update
		//this node and its parents
		if cachedNode, made := n.tree.gameStates[newState.gameHash]; made {
		if cachedNode, made := n.tree.gameStates[newState]; made {
			n.unvisitedChildren[i] = cachedNode
		} else {
			newNode := initializeNode(newState, n.tree)
			n.unvisitedChildren[i] = newNode

			//Save node for reuse
			n.tree.gameStates[newState.gameHash] = newNode
			n.tree.gameStates[newState] = newNode
		}
	}
}


@@ 118,13 118,11 @@ func (n *node) simulate() []Player {
	for !game.IsTerminal() {
		var err error

		actions := game.Len()
		if actions <= 0 {
			panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
		}
		actions := game.GetActions()
		panicIfNoActions(game, actions)

		randomIndex := n.tree.randSource.Intn(actions)
		game, err = game.ApplyAction(randomIndex)
		randomIndex := n.tree.randSource.Intn(len(actions))
		game, err = game.ApplyAction(actions[randomIndex])
		if err != nil {
			panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
		}

M tree.go => tree.go +9 -12
@@ 8,8 8,7 @@ import (
//Search searches the tree for a specified time
//
//Search will panic if the Game's ApplyAction
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
//method returns an error
func (t *Tree) Search(duration time.Duration) {
	ctx, cancel := context.WithTimeout(context.Background(), duration)
	defer cancel()


@@ 19,8 18,7 @@ func (t *Tree) Search(duration time.Duration) {
//SearchContext searches the tree using a given context
//
//SearchContext will panic if the Game's ApplyAction
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
//method returns an error
func (t *Tree) SearchContext(ctx context.Context) {
	for {
		select {


@@ 35,8 33,7 @@ func (t *Tree) SearchContext(ctx context.Context) {
//SearchRounds searches the tree for a specified number of rounds
//
//SearchRounds will panic if the Game's ApplyAction
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
//method returns an error
func (t *Tree) SearchRounds(rounds int) {
	for i := 0; i < rounds; i++ {
		t.search()


@@ 64,25 61,25 @@ func (t Tree) Nodes() int {
//this tree searched through.
func (t Tree) MaxDepth() int {
	maxDepth := 0
	for _, node := range t.gameStates {
		if node.state.turn > maxDepth {
			maxDepth = node.state.turn
	for state := range t.gameStates {
		if state.turn > maxDepth {
			maxDepth = state.turn
		}
	}
	return maxDepth
}

func (t *Tree) bestAction() int {
func (t *Tree) bestAction() Action {
	root := t.current

	//Select the child with the highest winrate
	var bestAction int
	var bestAction Action
	bestWinRate := -1.0
	player := root.state.Player()
	for i := 0; i < root.actionCount; i++ {
		winRate := root.children[i].nodeScore[player] / root.childVisits[i]
		if winRate > bestWinRate {
			bestAction = i
			bestAction = root.actions[i]
			bestWinRate = winRate
		}
	}

M tree_test.go => tree_test.go +1 -2
@@ 37,8 37,7 @@ func TestDepth(t *testing.T) {
}

func TestSearch(t *testing.T) {
	newGame := tictactoe.NewGame()
	mcts := NewMCTS(tttGame{newGame, newGame.GetActions()})
	mcts := NewMCTS(tttGame{tictactoe.NewGame()})
	tree := mcts.SpawnTree()

	timeToSearch := 1 * time.Millisecond

A utils.go => utils.go +9 -0
@@ 0,0 1,9 @@
package gmcts

import "fmt"

//panicIfNoActions panics when the given list of actions is empty.
//The panic message includes a Go-syntax dump of the game state.
func panicIfNoActions(game Game, actions []Action) {
	if len(actions) > 0 {
		return
	}

	msg := fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game)
	panic(msg)
}

A v2/LICENSE => v2/LICENSE +25 -0
@@ 0,0 1,25 @@
Copyright (c) 2020 bonbon. All rights reserved.

Redistribution and use in source and binary forms, with or without 
modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright notice, 
this list of conditions and the following disclaimer.
    2. Redistributions in binary form must reproduce the above copyright 
notice, this list of conditions and the following disclaimer in the 
documentation and/or other materials provided with the distribution.
    3. Neither the name of the copyright holder nor the names of its 
contributors may be used to endorse or promote products derived from this 
software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


A v2/README.md => v2/README.md +120 -0
@@ 0,0 1,120 @@
[![Documentation](https://img.shields.io/badge/Documentation-GoDoc-green.svg)](https://godoc.org/git.sr.ht/~bonbon/gmcts/v2)

GMCTS - Monte-Carlo Tree Search (the g stands for whatever you want it to mean :^) )
====================================================================================

GMCTS is an implementation of the Monte-Carlo Tree Search algorithm
with support for any deterministic game.

How To Install
==============

This project requires Go 1.7+ to run. To install, use `go get`:

```bash
go get git.sr.ht/~bonbon/gmcts/v2
```

Alternatively, you can clone it yourself into your $GOPATH/src/git.sr.ht/~bonbon/ folder to get the latest dev build:

```bash
git clone https://git.sr.ht/~bonbon/gmcts
```

How To Use
==========

```go
package pkg

import (
    "git.sr.ht/~bonbon/gmcts/v2"
)

func NewGame() gmcts.Game {
    var game gmcts.Game
    //...
    //Setup a new game
    //...
    return game
}

func runGame() {
    gameState := NewGame()

    //MCTS algorithm will play against itself
    //until a terminal state has been reached
    for !gameState.IsTerminal() {
        mcts := gmcts.NewMCTS(gameState)

        //Spawn a new tree and play 1000 game simulations
        tree := mcts.SpawnTree()
        tree.SearchRounds(1000)

        //Add the searched tree into the mcts tree collection
        mcts.AddTree(tree)

        //Get the best action based off of the trees collected from mcts.AddTree()
        bestAction, err := mcts.BestAction()
        if err != nil {
            //...
            //handle error
            //...
        }

        //Update the game state using the tree's best action
        gameState, _ = gameState.ApplyAction(bestAction)
    }
}
```

If you choose to, you can run multiple trees concurrently.

```go
concurrentTrees := 4

mcts := gmcts.NewMCTS(gameState)

//Run 4 trees concurrently
var wait sync.WaitGroup
wait.Add(concurrentTrees)
for i := 0; i < concurrentTrees; i++ {
    go func(){
        tree := mcts.SpawnTree()
        tree.SearchRounds(1000)
        mcts.AddTree(tree)
        wait.Done()
    }()
}
//Wait for the 4 trees to finish searching
wait.Wait()

bestAction, err := mcts.BestAction()
if err != nil {
    //...
    //handle error
    //...
}

gameState, _ = gameState.ApplyAction(bestAction)
```

Testing
=======

You can test this package with `go test`. The test plays a game of tic-tac-toe against itself. The test should:

1. Start the game by placing an x piece in the middle, and
2. Finish in a draw.

If either of these fails, the test fails. It's a rather neat way to make sure everything works as intended!

Documentation
=============

Documentation for this package can be found either at [godoc.org](https://godoc.org/git.sr.ht/~bonbon/gmcts/v2) or [pkg.go.dev](https://pkg.go.dev/git.sr.ht/~bonbon/gmcts/v2)

Bug Reports
===========

Email me at bonbon@bonbon.moe :D
\ No newline at end of file

A v2/doc.go => v2/doc.go +12 -0
@@ 0,0 1,12 @@
//Package gmcts is a generic implementation of the
//Monte-Carlo Tree Search (mcts) algorithm.
//
//This package attempts to save memory and time by caching states so as not to
//have duplicate nodes in the search tree. This optimization is efficient for
//games like tic-tac-toe, checkers, and go among others.
//
//This package also supports tree parallelization. Trees may
//be spawned and run in their own goroutines. After searching, they may be
//compiled together to produce a more informed action than just searching
//through one tree.
package gmcts

A v2/go.mod => v2/go.mod +5 -0
@@ 0,0 1,5 @@
module git.sr.ht/~bonbon/gmcts/v2

go 1.14

require git.sr.ht/~bonbon/go-tic-tac-toe v0.2.2

A v2/go.sum => v2/go.sum +2 -0
@@ 0,0 1,2 @@
git.sr.ht/~bonbon/go-tic-tac-toe v0.2.2 h1:YnEUZuMybqGFu2uguNosoF3LCFsLrJ3pPyYI05KrhZA=
git.sr.ht/~bonbon/go-tic-tac-toe v0.2.2/go.mod h1:np1W9swZFfyLAax6aZg6Q+4766tZPLkRABmeH2J5HVE=

A v2/mcts.go => v2/mcts.go +106 -0
@@ 0,0 1,106 @@
package gmcts

import (
	"errors"
	"math/rand"
	"sync"
)

var (
	//ErrNoTrees notifies the caller that the MCTS wrapper has received no trees to analyze
	ErrNoTrees = errors.New("gmcts: mcts wrapper has collected no trees to analyze")

	//ErrTerminal notifies the caller that the given state is terminal
	ErrTerminal = errors.New("gmcts: given game state is a terminal state, therefore, it cannot return an action")

	//ErrNoActions notifies the caller that the given state has <= 0 actions
	ErrNoActions = errors.New("gmcts: given game state is not terminal, yet the state has <= 0 actions to search through")
)

//NewMCTS wraps the given initial game state in a new MCTS instance.
//The instance starts with an empty tree collection and its own mutex.
func NewMCTS(initial Game) *MCTS {
	wrapper := new(MCTS)
	wrapper.init = initial
	wrapper.trees = []*Tree{}
	wrapper.mutex = &sync.RWMutex{}
	return wrapper
}

//SpawnTree creates a new search tree that uses the default
//exploration constant (DefaultExplorationConst, i.e. Sqrt(2)).
func (m *MCTS) SpawnTree() *Tree {
	tree := m.SpawnCustomTree(DefaultExplorationConst)
	return tree
}

//SetSeed sets the seed that the next spawned tree will use.
//The seed starts out at 0 and is incremented every time
//a tree is spawned.
func (m *MCTS) SetSeed(seed int64) {
	m.mutex.Lock()
	m.seed = seed
	m.mutex.Unlock()
}

//SpawnCustomTree creates a new search tree that uses the given
//exploration constant.
//
//The tree's random source is seeded with the wrapper's current seed,
//which is then incremented for the next spawned tree.
func (m *MCTS) SpawnCustomTree(explorationConst float64) *Tree {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	tree := new(Tree)
	tree.gameStates = make(map[gameHash]*node)
	tree.explorationConst = explorationConst
	tree.randSource = rand.New(rand.NewSource(m.seed))

	//The root of the tree is the initial game state at turn 0
	rootHash := gameHash{hash: m.init.Hash(), turn: 0}
	rootState := gameState{Game: m.init, gameHash: rootHash}
	tree.current = initializeNode(rootState, tree)

	m.seed++
	return tree
}

//AddTree appends a searched tree to the collection of trees
//that BestAction consults when deciding upon an action to take.
func (m *MCTS) AddTree(t *Tree) {
	m.mutex.Lock()
	m.trees = append(m.trees, t)
	m.mutex.Unlock()
}

//BestAction lets every collected tree vote for the action it
//considers best and returns the index of the action with the
//most votes. On a tie, the lowest action index wins.
//
//BestAction returns -1 along with ErrNoTrees if no trees have
//been added, ErrTerminal if the initial state is terminal, or
//ErrNoActions if the initial state has no legal actions.
func (m *MCTS) BestAction() (int, error) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	//Error checking
	switch {
	case len(m.trees) == 0:
		return -1, ErrNoTrees
	case m.init.IsTerminal():
		return -1, ErrTerminal
	case m.init.Len() <= 0:
		return -1, ErrNoActions
	}

	//Democracy Section: each tree votes for an action
	votes := make([]int, m.init.Len())
	for _, tree := range m.trees {
		votes[tree.bestAction()]++
	}

	//Democracy Section: the action with the most votes wins
	winner, highest := 0, 0
	for action, count := range votes {
		if count > highest {
			winner, highest = action, count
		}
	}
	return winner, nil
}

A v2/mcts_test.go => v2/mcts_test.go +167 -0
@@ 0,0 1,167 @@
package gmcts

import (
	"fmt"
	"os"
	"sync"
	"testing"

	tictactoe "git.sr.ht/~bonbon/go-tic-tac-toe"
)

//getPlayerID maps the ascii piece 'x' or 'X' to player 0,
//and any other byte to player 1.
func getPlayerID(ascii byte) Player {
	switch ascii {
	case 'x', 'X':
		return Player(0)
	default:
		return Player(1)
	}
}

//tttGame adapts a tictactoe.Game to the index-based Game interface.
type tttGame struct {
	//game is the wrapped tic-tac-toe state
	game    tictactoe.Game
	//actions is the list of moves that action indices refer to
	actions []tictactoe.Move
}

//Len returns how many moves are stored for this state.
func (g tttGame) Len() int {
	count := len(g.actions)
	return count
}

//ApplyAction plays the ith stored move on the wrapped game and
//returns the resulting state, wrapped together with its own moves,
//along with any error from the underlying game.
func (g tttGame) ApplyAction(i int) (Game, error) {
	move := g.actions[i]
	next, err := g.game.ApplyAction(move)

	successor := tttGame{game: next, actions: next.GetActions()}
	return successor, err
}

//Hash uses the wrapped tic-tac-toe game itself as the
//representation of this state.
func (g tttGame) Hash() interface{} {
	var hash interface{} = g.game
	return hash
}

//Player converts the wrapped game's current player piece
//into a Player id.
func (g tttGame) Player() Player {
	piece := g.game.Player()
	return getPlayerID(piece)
}

//IsTerminal reports whether the wrapped game is in a terminal state.
func (g tttGame) IsTerminal() bool {
	done := g.game.IsTerminal()
	return done
}

//Winners returns both players when the wrapped game reports '_'
//as its winner, and otherwise only the player matching the
//winning piece.
func (g tttGame) Winners() []Player {
	switch winner, _ := g.game.Winner(); winner {
	case '_':
		return []Player{Player(0), Player(1)}
	default:
		return []Player{getPlayerID(winner)}
	}
}

//Global vars set by TestMain and checked by other tests.
//newGame is the starting state of the game played in TestMain,
//and finishedGame is the terminal state that game ended in.
var newGame, finishedGame tttGame
//firstMove is the first move chosen in the game played in TestMain.
var firstMove tictactoe.Move
//treeToTest is a searched tree that TestMain saves once.
var treeToTest *Tree

//TestMain runs through a tictactoe game, saving the first move made and
//the resulting terminal game state into global variables to be used by
//other tests.
//
//The process exits with the code returned by m.Run, so that failing tests
//are reported even on toolchains that do not do this automatically. If the
//game cannot be played to completion, the process exits with code 1 before
//any test runs.
func TestMain(m *testing.M) {
	newGame = tttGame{game: tictactoe.NewGame()}
	newGame.actions = newGame.game.GetActions()

	game := newGame
	concurrentSearches := 1 //runtime.NumCPU()

	var setFirstMove sync.Once
	var setTestingTree sync.Once

	for !game.IsTerminal() {
		mcts := NewMCTS(game)

		var wait sync.WaitGroup
		wait.Add(concurrentSearches)
		for i := 0; i < concurrentSearches; i++ {
			go func() {
				//Signal completion only after everything below has run, so
				//the write to treeToTest happens before wait.Wait() returns.
				defer wait.Done()

				tree := mcts.SpawnTree()
				tree.SearchRounds(10000)
				mcts.AddTree(tree)

				//Set the tree to perform benchmarks on
				setTestingTree.Do(func() {
					treeToTest = tree
				})
			}()
		}
		wait.Wait()

		//A failed selection returns the index -1, which must not be
		//used to index into the list of actions.
		bestAction, err := mcts.BestAction()
		if err != nil {
			fmt.Fprintf(os.Stderr, "gmcts: could not select an action: %s\n", err)
			os.Exit(1)
		}

		nextState, err := game.ApplyAction(bestAction)
		if err != nil {
			fmt.Fprintf(os.Stderr, "gmcts: could not apply the selected action: %s\n", err)
			os.Exit(1)
		}
		game = nextState.(tttGame)
		fmt.Println(game.game)

		//Save the first action taken
		setFirstMove.Do(func() {
			firstMove = newGame.actions[bestAction]
		})
	}
	//Save the terminal game state
	finishedGame = game

	os.Exit(m.Run())
}

//TestTicTacToeDraw fails if the game played in TestMain has a single
//winner. Because tic-tac-toe is a simple game, the algorithm should
//have played itself to a draw, in which both players are listed as winners.
func TestTicTacToeDraw(t *testing.T) {
	winners := finishedGame.Winners()
	if len(winners) == 2 {
		return
	}
	t.Fatalf("gmcts: tic-tac-toe game did not end in a draw")
}

//TestTicTacToeMiddle fails if the first move of the game played in
//TestMain is not the middle square. Because tic-tac-toe is a simple
//game, the algorithm should have picked the middle square.
func TestTicTacToeMiddle(t *testing.T) {
	formatted := fmt.Sprintf("%v", firstMove)
	if formatted == "{1 1}" {
		return
	}
	t.Fatalf("gmcts: first action is not to take the middle spot: %v", firstMove)
}

//TestZeroTrees checks that BestAction refuses to pick an action when
//no trees have been added: it must return the index -1 together
//with ErrNoTrees.
func TestZeroTrees(t *testing.T) {
	mcts := NewMCTS(finishedGame)
	bestAction, err := mcts.BestAction()
	if bestAction != -1 {
		t.Fatalf("gmcts: recieved a best action from no trees: %#v", bestAction)
	}
	if err != ErrNoTrees {
		t.Fatalf("gmcts: expected ErrNoTrees from no trees, got: %v", err)
	}
}

//TestTerminalState checks that BestAction refuses to pick an action
//for a terminal state, even when a tree has been added: it must return
//the index -1 together with ErrTerminal.
func TestTerminalState(t *testing.T) {
	mcts := NewMCTS(finishedGame)
	mcts.AddTree(mcts.SpawnTree())
	bestAction, err := mcts.BestAction()
	if bestAction != -1 {
		t.Fatalf("gmcts: recieved a best action from a terminal state: %#v", bestAction)
	}
	if err != ErrTerminal {
		t.Fatalf("gmcts: expected ErrTerminal from a terminal state, got: %v", err)
	}
}

//BenchmarkTicTacToe1KRounds measures spawning a tree from the
//starting tic-tac-toe state and searching it for 1000 rounds.
func BenchmarkTicTacToe1KRounds(b *testing.B) {
	const rounds = 1000

	search := NewMCTS(newGame)
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		tree := search.SpawnTree()
		tree.SearchRounds(rounds)
	}
}

//BenchmarkTicTacToe10KRounds measures spawning a tree from the
//starting tic-tac-toe state and searching it for 10000 rounds.
func BenchmarkTicTacToe10KRounds(b *testing.B) {
	const rounds = 10000

	search := NewMCTS(newGame)
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		tree := search.SpawnTree()
		tree.SearchRounds(rounds)
	}
}

//BenchmarkTicTacToe100KRounds measures spawning a tree from the
//starting tic-tac-toe state and searching it for 100000 rounds.
func BenchmarkTicTacToe100KRounds(b *testing.B) {
	const rounds = 100000

	search := NewMCTS(newGame)
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		tree := search.SpawnTree()
		tree.SearchRounds(rounds)
	}
}

A v2/models.go => v2/models.go +80 -0
@@ 0,0 1,80 @@
package gmcts

import (
	"math/rand"
	"sync"
)

//Player is an id for the player. Player values are used as the keys
//of each node's score table.
type Player int

//Game is the interface that represents game states.
//
//Actions are not passed around as values; they are referred to by
//their index, from 0 up to Len()-1.
//
//Any implementation of Game should be immutable
//(state cannot change as this package calls any function).
type Game interface {
	//Len returns the number of actions to consider.
	Len() int

	//ApplyAction applies the ith action (0-indexed) to the game state,
	//and returns a new game state and an error for invalid actions
	ApplyAction(i int) (Game, error)

	//Hash returns a unique representation of the state.
	//Any return value must be comparable, as it is stored
	//inside of a map key.
	Hash() interface{}

	//Player returns the player that can take the next action
	Player() Player

	//IsTerminal returns true if this game state is a terminal state
	IsTerminal() bool

	//Winners returns a list of players that have won the game if
	//IsTerminal() returns true
	Winners() []Player
}

//gameState pairs a user-supplied Game with the gameHash
//(hash value and turn number) that identifies it within a tree.
type gameState struct {
	Game
	gameHash
}

//gameHash is the key under which a tree stores the node of a state.
type gameHash struct {
	//hash is the value returned by the Game's Hash method
	hash interface{}

	//This is to separate states that seemingly look the same,
	//but actually occur on different turn orders. Without this,
	//the directed acyclic graph will become a directed cyclic graph,
	//which this MCTS implementation cannot handle properly.
	turn int
}

//MCTS contains functionality for the MCTS algorithm
type MCTS struct {
	//init is the game state that every spawned tree starts from
	init  Game
	//trees holds the trees collected through AddTree
	trees []*Tree
	//mutex is locked by the methods that read or write trees and seed
	mutex *sync.RWMutex
	//seed is the seed of the next spawned tree's random source
	seed  int64
}

//node is a single game state inside of a search tree, along with
//the statistics gathered for it while searching.
type node struct {
	//state is the game state this node represents
	state gameState
	//tree is the tree this node belongs to
	tree  *Tree

	//children holds one node per action, indexed by action
	children          []*node
	//unvisitedChildren is the tail of children that has not been
	//selected from this node yet
	unvisitedChildren []*node
	//childVisits counts, per action, the simulations run through that child from this node
	childVisits       []float64
	//actionCount is the number of actions of this state;
	//it stays 0 until the node is expanded
	actionCount       int

	//nodeScore accumulates the score credited to each player
	nodeScore  map[Player]float64
	//nodeVisits counts the visits to this node
	nodeVisits int
}

//Tree represents a game state tree
type Tree struct {
	//current is the root node, built from the MCTS wrapper's initial state
	current          *node
	//gameStates looks up nodes by the hash and turn of their state
	gameStates       map[gameHash]*node
	//explorationConst weighs the exploration term of the UCT2 score
	explorationConst float64
	//randSource is this tree's own source of random numbers
	randSource       *rand.Rand
}

A v2/search.go => v2/search.go +133 -0
@@ 0,0 1,133 @@
package gmcts

import (
	"fmt"
	"math"
)

const (
	//DefaultExplorationConst is the default exploration constant of the
	//UCB1 formula. Sqrt(2) is a frequent choice for this constant, as
	//specified by https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
	DefaultExplorationConst = math.Sqrt2
)

//initializeNode creates an unexpanded node for the given game state
//that belongs to the given tree. The node starts with an empty score table.
func initializeNode(g gameState, tree *Tree) *node {
	created := new(node)
	created.state = g
	created.tree = tree
	created.nodeScore = make(map[Player]float64)
	return created
}

//UCT2 scores the ith child of this node from the perspective of player p.
//The score is the child's average score for p, plus the tree's exploration
//constant times an exploration term that grows as the child is visited
//less often from this node.
//
//The UCT2 algorithm is described in this paper
//https://www.csse.uwa.edu.au/cig08/Proceedings/papers/8057.pdf
func (n *node) UCT2(i int, p Player) float64 {
	child := n.children[i]
	exploit := child.nodeScore[p] / float64(child.nodeVisits)

	explore := math.Sqrt(
		math.Log(float64(n.nodeVisits)) / n.childVisits[i],
	)

	return exploit + n.tree.explorationConst*explore
}

//runSimulation performs one search round starting at this node and
//returns the winners of the played out game along with the score that
//each of them is credited with. The same result is recorded on this
//node before returning, so every node on the selected path is updated
//as the recursion unwinds.
func (n *node) runSimulation() ([]Player, float64) {
	var (
		picked   int
		winners  []Player
		reward   float64
		terminal bool
	)

	//A node without actions is either terminal or has not been
	//expanded yet. Nodes that already have actions need neither check.
	if n.actionCount == 0 {
		terminal = n.state.IsTerminal()
		if !terminal {
			n.expand()
		}
	}

	switch {
	case terminal:
		//Nothing left to play: take the result of the game as is
		winners = n.simulate()
		reward = 1.0 / float64(len(winners))
	case len(n.unvisitedChildren) > 0:
		//Children are visited in action order, so the next unvisited
		//child sits right behind the ones already visited.
		picked = n.actionCount - len(n.unvisitedChildren)
		child := n.children[picked]
		child.nodeVisits++
		n.unvisitedChildren = n.unvisitedChildren[1:]

		//NOTE: only the visit is recorded on the child here. The
		//score of this playout is added to this node and its parents.
		winners = child.simulate()
		reward = 1.0 / float64(len(winners))
	default:
		//Every child was visited at least once: descend into the one
		//with the highest UCT2 score for the player to move.
		best := -1.0
		mover := n.state.Player()
		for i := 0; i < n.actionCount; i++ {
			if score := n.UCT2(i, mover); score > best {
				best, picked = score, i
			}
		}
		winners, reward = n.children[picked].runSimulation()
	}

	//Record the outcome on this node. The callers further up the path
	//do the same with the values returned from here.
	n.nodeVisits++
	if n.actionCount != 0 {
		n.childVisits[picked]++
	}
	for _, winner := range winners {
		n.nodeScore[winner] += reward
	}

	return winners, reward
}

//expand creates the children of this node, one per action of its game
//state, and marks all of them as unvisited. A child whose state is
//already known to the tree reuses the cached node, which is what turns
//the tree into a directed acyclic graph.
//
//expand is only meant for non-terminal states. It panics if the state
//reports no actions, or if applying one of the actions returns an error.
func (n *node) expand() {
	actions := n.state.Len()
	if actions <= 0 {
		//Same check and message as in simulate. Without it the search
		//would fail later with an index out of range panic instead.
		panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", n.state.Game))
	}

	n.actionCount = actions
	n.unvisitedChildren = make([]*node, actions)
	//children shares its storage with unvisitedChildren, so filling
	//one fills the other. Only unvisitedChildren is shortened later.
	n.children = n.unvisitedChildren
	n.childVisits = make([]float64, actions)

	for i := 0; i < actions; i++ {
		newGame, err := n.state.ApplyAction(i)
		if err != nil {
			panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
		}

		//Children are one turn further into the game than this node
		newState := gameState{newGame, gameHash{newGame.Hash(), n.state.turn + 1}}

		//If we already have a copy in cache, use that instead of
		//creating a second node for the same state
		if cachedNode, made := n.tree.gameStates[newState.gameHash]; made {
			n.unvisitedChildren[i] = cachedNode
			continue
		}

		newNode := initializeNode(newState, n.tree)
		n.unvisitedChildren[i] = newNode

		//Save node for reuse
		n.tree.gameStates[newState.gameHash] = newNode
	}
}

//simulate plays random actions, starting from this node's game state,
//until a terminal state is reached and returns the winners of that
//state. The node itself is not modified.
//
//simulate panics if a non-terminal state reports no actions or if
//applying an action returns an error.
func (n *node) simulate() []Player {
	for game := n.state.Game; ; {
		if game.IsTerminal() {
			return game.Winners()
		}

		actions := game.Len()
		if actions <= 0 {
			panic(fmt.Sprintf("gmcts: game returned no actions on a non-terminal state: %#v", game))
		}

		next, err := game.ApplyAction(n.tree.randSource.Intn(actions))
		if err != nil {
			panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
		}
		game = next
	}
}

A v2/tree.go => v2/tree.go +91 -0
@@ 0,0 1,91 @@
package gmcts

import (
	"context"
	"time"
)

//Search runs search rounds on the tree until the given duration
//has passed.
//
//Search will panic if the Game's ApplyAction
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
func (t *Tree) Search(duration time.Duration) {
	timeout, release := context.WithTimeout(context.Background(), duration)
	//Free the timer of the context once the search is over
	defer release()

	t.SearchContext(timeout)
}

//SearchContext runs search rounds on the tree until the given
//context is done. The context is checked before every round, so a
//context that is already done results in no rounds at all.
//
//SearchContext will panic if the Game's ApplyAction
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
func (t *Tree) SearchContext(ctx context.Context) {
	for {
		//Non-blocking check whether we were asked to stop
		select {
		case <-ctx.Done():
			return
		default:
		}

		t.search()
	}
}

//SearchRounds runs the given number of search rounds on the tree.
//Nothing happens if rounds is zero or negative.
//
//SearchRounds will panic if the Game's ApplyAction
//method returns an error or if any game state's Hash()
//method returns a noncomparable value.
func (t *Tree) SearchRounds(rounds int) {
	for remaining := rounds; remaining > 0; remaining-- {
		t.search()
	}
}

//search performs 1 round of the MCTS algorithm, starting at the
//tree's current node. The result of the round is already recorded
//on the nodes, so the returned values are not needed here.
func (t *Tree) search() {
	_, _ = t.current.runSimulation()
}

//Rounds returns the number of MCTS rounds that were performed
//on this tree, which is the visit count of its current node.
func (t Tree) Rounds() int {
	root := t.current
	return root.nodeVisits
}

//Nodes returns the number of nodes created on this tree, that is
//the number of distinct game states in the tree's cache.
func (t Tree) Nodes() int {
	created := len(t.gameStates)
	return created
}

//MaxDepth returns the maximum depth of this tree, which is the
//highest turn among all nodes created so far.
//The value can be thought of as the amount of moves ahead
//this tree searched through.
func (t Tree) MaxDepth() int {
	deepest := 0
	for _, known := range t.gameStates {
		if turn := known.state.turn; deepest < turn {
			deepest = turn
		}
	}
	return deepest
}

//bestAction returns the index of the action at the tree's current
//node with the highest win rate for the player to move. The win rate
//of an action is the score of the resulting child for that player
//divided by the number of times the action was chosen. On equal win
//rates the lowest index wins, and 0 is returned if no action
//qualifies.
func (t *Tree) bestAction() int {
	root := t.current
	mover := root.state.Player()

	choice, topRate := 0, -1.0
	for i := 0; i < root.actionCount; i++ {
		rate := root.children[i].nodeScore[mover] / root.childVisits[i]
		if rate > topRate {
			choice, topRate = i, rate
		}
	}

	return choice
}

A v2/tree_test.go => v2/tree_test.go +53 -0
@@ 0,0 1,53 @@
package gmcts

import (
	"testing"
	"time"

	tictactoe "git.sr.ht/~bonbon/go-tic-tac-toe"
)

//TestRounds checks that the shared test tree reports exactly the
//number of rounds it is expected to have been searched for.
func TestRounds(t *testing.T) {
	const want = 10000

	if got := treeToTest.Rounds(); got != want {
		//Report the value that is actually checked against
		t.Fatalf("Tree performed %d rounds: wanted %d", got, want)
	}
}

//TestNodes checks that the shared test tree did not create more
//nodes than it performed MCTS rounds.
func TestNodes(t *testing.T) {
	limit := treeToTest.Rounds()
	if got := treeToTest.Nodes(); got > limit {
		t.Fatalf("Tree has %d nodes: wanted <= %d", got, limit)
	}
}

//TestDepth checks the depth reached by the shared test tree.
//Because tictactoe is a simple game, the tree should have
//looked 9 moves ahead.
func TestDepth(t *testing.T) {
	const want = 9

	if got := treeToTest.MaxDepth(); got != want {
		//Report the value that is actually checked against
		t.Fatalf("Tree has depth %d: wanted %d", got, want)
	}
}

//TestSearch checks that Search keeps searching for at least the
//requested amount of time.
func TestSearch(t *testing.T) {
	board := tictactoe.NewGame()
	tree := NewMCTS(tttGame{board, board.GetActions()}).SpawnTree()

	const budget = 1 * time.Millisecond
	start := time.Now()
	tree.Search(budget)

	if elapsed := time.Since(start); elapsed < budget {
		t.Fatalf("Tree was searched for %s: wanted >= %s", elapsed, budget)
	}
}