~bonbon/gmcts

3ec7baa7d4d7c7e325d3bed3c7211660d9156612 — bonbon 1 year, 6 months ago 698483e
add statistical methods on trees
3 files changed, 68 insertions(+), 0 deletions(-)

M mcts_test.go
M tree.go
A tree_test.go
M mcts_test.go => mcts_test.go +7 -0
@@ 62,6 62,7 @@ func (g tttGame) Winners() []Player {
//Global vars to be checked by other tests
var finishedGame tttGame
var firstMove tictactoe.Move
var treeToTest *Tree

//TestMain runs through a tictactoe game, saving the first move made and
//the resulting terminal game state into global variables to be used by


@@ 71,6 72,7 @@ func TestMain(m *testing.M) {
	concurrentSearches := 1 //runtime.NumCPU()

	var setFirstMove sync.Once
	var setTestingTree sync.Once

	for !game.IsTerminal() {
		mcts := NewMCTS(game)


@@ 83,6 85,11 @@ func TestMain(m *testing.M) {
				tree.SearchRounds(10000)
				mcts.AddTree(tree)
				wait.Done()

				//Set the tree to perform benchmarks on
				setTestingTree.Do(func() {
					treeToTest = tree
				})
			}()
		}
		wait.Wait()

M tree.go => tree.go +24 -0
@@ 45,6 45,30 @@ func (t *Tree) search() {
	t.current.runSimulation()
}

//Rounds returns the number of MCTS rounds were performed
//on this tree.
func (t Tree) Rounds() int {
	return t.current.nodeVisits
}

//Nodes returns the number of nodes created on this tree.
func (t Tree) Nodes() int {
	return len(t.gameStates)
}

//MaxDepth returns the maximum depth of this tree.
//The value can be thought of as the amount of moves ahead
//this tree searched through.
func (t Tree) MaxDepth() int {
	maxDepth := 0
	for state := range t.gameStates {
		if state.turn > maxDepth {
			maxDepth = state.turn
		}
	}
	return maxDepth
}

func (t *Tree) bestAction() Action {
	root := t.current


A tree_test.go => tree_test.go +37 -0
@@ 0,0 1,37 @@
package gmcts

import (
	"testing"
	"time"

	tictactoe "git.sr.ht/~bonbon/go-tic-tac-toe"
)

func TestRounds(t *testing.T) {
	rounds := treeToTest.Rounds()
	if rounds != 10000 {
		t.Errorf("Tree performed %d rounds: wanted 1", rounds)
		t.FailNow()
	}
}

func TestNodes(t *testing.T) {
	//The amount of nodes in the tree should not exceed the
	//amount of mcts rounds performed on the tree.
	rounds := treeToTest.Rounds()
	nodes := treeToTest.Nodes()
	if nodes > rounds {
		t.Errorf("Tree has %d nodes: wanted <= %d", nodes, rounds)
		t.FailNow()
	}
}

func TestDepth(t *testing.T) {
	//Because tictactoe is a simple game, the
	//tree should have looked 9 moves ahead.
	depth := treeToTest.MaxDepth()
	if depth != 9 {
		t.Errorf("Tree has depth %d: wanted 0", depth)
		t.FailNow()
	}
}