~bonbon/gmcts

ref: f4368f9a11fb1e5e4d23b8cb9a0873b147c98012 gmcts/search.go -rw-r--r-- 3.5 KiB
f4368f9abonbon add Hash() as a required method on Game 1 year, 9 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package gmcts

import (
	"fmt"
	"math"
)

//DefaultExplorationConst is the exploration constant used by default in
//the UCB1 formula. Sqrt(2) is a common choice for this value, see
//https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
const DefaultExplorationConst = math.Sqrt2

//initializeNode builds a new search node holding the game state g and a
//reference to the tree it belongs to. The per-player score table is
//allocated up front so it can be written to immediately.
func initializeNode(g gameState, tree *Tree) *node {
	fresh := new(node)
	fresh.state = g
	fresh.tree = tree
	fresh.nodeScore = map[Player]float64{}
	return fresh
}

//UCT2 returns the selection score of the i-th child of n as seen by
//player p.
//
//The exploitation part is the child's accumulated score for p divided by
//the child's own visit count. The exploration part is based on n's visit
//count and on childVisits[i], the counter n keeps for that child. The
//exploration part is weighted by the tree's exploration constant.
func (n *node) UCT2(i int, p Player) float64 {
	child := n.children[i]
	meanScore := child.nodeScore[p] / float64(child.nodeVisits)

	logParentVisits := math.Log(float64(n.nodeVisits))
	bonus := math.Sqrt(logParentVisits / n.childVisits[i])

	return meanScore + n.tree.explorationConst*bonus
}

//runSimulation performs one search iteration starting at n. It returns the
//winners of the playout and the score credited to each of those winners.
//
//Depending on the node it does one of three things:
//  - terminal state: takes the winners directly from the game,
//  - at least one child is still unvisited: runs a random rollout from the
//    next unvisited child,
//  - otherwise: recurses into the child with the highest UCT2 score for
//    the player to move.
//
//Before returning, n's visit count, the visit count of the selected child
//slot, and n's score for every winner are updated, so the result is
//propagated along the whole path as the recursion unwinds.
//
//It panics if the game reports no actions for a non-terminal state.
func (n *node) runSimulation() ([]Player, float64) {
	var selectedChildIndex int
	var winners []Player
	var scoreToAdd float64
	var terminalState bool

	//If we have actions, then there's no need to expand.
	if n.actionCount == 0 {
		//If we don't have any actions, then either the state
		//is terminal, or we haven't expanded the node yet.
		terminalState = n.state.IsTerminal()
		if !terminalState {
			n.expand()
			if n.actionCount == 0 {
				//Fail with a clear message here; otherwise the selection
				//step below would index into an empty children slice.
				panic("gmcts: game returned no actions on a non-terminal state")
			}
		}
	}

	switch {
	case terminalState:
		//Get the result of the game. simulate returns right away
		//because the state is already terminal.
		winners = n.simulate()
		scoreToAdd = 1.0 / float64(len(winners))
	case len(n.unvisitedChildren) > 0:
		//Grab the first unvisited child and run a simulation from that point
		selectedChildIndex = n.actionCount - len(n.unvisitedChildren)
		child := n.children[selectedChildIndex]
		child.nodeVisits++
		n.unvisitedChildren = n.unvisitedChildren[1:]

		winners = child.simulate()
		scoreToAdd = 1.0 / float64(len(winners))

		//Credit the rollout to the child too. Its visit count was just
		//incremented, so skipping the score would make this visit count
		//as a zero-score result in the child's UCT2 exploitation term.
		for _, p := range winners {
			child.nodeScore[p] += scoreToAdd
		}
	default:
		//Select the child with the max UCT2 score with the current player
		//and get the results to add from its selection
		maxScore := -1.0
		thisPlayer := n.state.Player()
		for i := 0; i < n.actionCount; i++ {
			if score := n.UCT2(i, thisPlayer); score > maxScore {
				maxScore = score
				selectedChildIndex = i
			}
		}
		winners, scoreToAdd = n.children[selectedChildIndex].runSimulation()
	}

	//Update this node; the callers up the path do the same as the
	//recursion returns.
	n.nodeVisits++
	if n.actionCount != 0 {
		n.childVisits[selectedChildIndex]++
	}

	//With no winners scoreToAdd is +Inf, but nothing is added in that case.
	for _, p := range winners {
		n.nodeScore[p] += scoreToAdd
	}
	return winners, scoreToAdd
}

//expand queries the game for the actions available from n's state and
//attaches one child node per action, in the same order as the actions.
//
//A child whose game hash and turn number are already present in the
//tree's gameStates cache reuses the cached node. Otherwise a new node is
//created and stored in that cache. All children start out as unvisited.
//
//It panics if the game fails to apply one of its own actions.
func (n *node) expand() {
	n.actions = n.state.GetActions()
	n.actionCount = len(n.actions)

	//children and unvisitedChildren share the same backing array.
	kids := make([]*node, n.actionCount)
	n.unvisitedChildren = kids
	n.children = kids
	n.childVisits = make([]float64, n.actionCount)

	nextTurn := n.state.turn + 1
	for idx := range n.actions {
		nextGame, err := n.state.ApplyAction(n.actions[idx])
		if err != nil {
			panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
		}

		key := gameHash{nextGame.Hash(), nextTurn}

		//Reuse a cached node when there is one, otherwise create the
		//node and remember it for later reuse.
		child, cached := n.tree.gameStates[key]
		if !cached {
			child = initializeNode(gameState{nextGame, key}, n.tree)
			n.tree.gameStates[key] = child
		}
		kids[idx] = child
	}
}

//simulate plays the game forward from n's state until it is terminal and
//returns the winners of the final state. Each step applies an action
//picked with the tree's random source.
//
//It panics if the game returns an error while applying an action.
func (n *node) simulate() []Player {
	current := n.state.Game
	for {
		if current.IsTerminal() {
			return current.Winners()
		}

		choices := current.GetActions()
		panicIfNoActions(current, choices)

		pick := choices[n.tree.randSource.Intn(len(choices))]
		next, err := current.ApplyAction(pick)
		if err != nil {
			panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
		}
		current = next
	}
}