~bonbon/gmcts

ref: 698483e3eafccc0ae6e52debbe1f23622a95129c gmcts/search.go -rw-r--r-- 3.5 KiB
698483e3bonbon switch node.nodeVisits from float64 to int 1 year, 10 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package gmcts

import (
	"fmt"
	"math"
)

const (
	//DefaultExplorationConst is the default exploration constant used in the
	//UCB1 formula. Sqrt(2) is a frequent choice for this constant, as described in
	//https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
	DefaultExplorationConst = math.Sqrt2
)

//initializeNode builds a fresh node for game state g that belongs to tree.
//The node starts with an empty (but writable) per-player score table; every
//other field is left at its zero value.
func initializeNode(g gameState, tree *Tree) *node {
	fresh := new(node)
	fresh.state = g
	fresh.tree = tree
	fresh.nodeScore = map[Player]float64{}
	return fresh
}

//UCT2 scores child i of this node from the point of view of player p.
//The result is the sum of two terms:
//  - exploitation: the child's accumulated score for p divided by the
//    child's own visit count;
//  - exploration: sqrt(ln(this node's visits) / childVisits[i]), scaled by
//    the tree's exploration constant.
//No guard is applied against zero visit counts; the caller is expected to
//only score children that have already been visited.
func (n *node) UCT2(i int, p Player) float64 {
	child := n.children[i]
	meanScore := child.nodeScore[p] / float64(child.nodeVisits)

	parentVisits := float64(n.nodeVisits)
	bonus := math.Sqrt(math.Log(parentVisits) / n.childVisits[i])

	return meanScore + n.tree.explorationConst*bonus
}

//runSimulation performs one search iteration rooted at this node and
//returns the winners found together with the score credited to each winner
//(1 divided by the number of winners).
//
//A non-terminal node that has no actions recorded yet is expanded first.
//Then exactly one of three things happens:
//  - terminal node: the result is read directly from this node's state;
//  - some child is still unvisited: the next unvisited child (in action
//    order) is marked visited and a playout is run from it;
//  - otherwise: the child with the highest UCT2 value for the player to
//    move is selected (first one wins ties) and the iteration recurses.
//
//On the way back this node's visit count, the visit count of the chosen
//child slot, and the winners' scores are updated.
func (n *node) runSimulation() ([]Player, float64) {
	var (
		picked  int
		winners []Player
		reward  float64
	)

	isTerminal := n.state.IsTerminal()
	if n.actionCount == 0 && !isTerminal {
		n.expand()
	}

	switch {
	case isTerminal:
		//Game is over here, so the "playout" just reports the result
		winners = n.simulate()
		reward = 1.0 / float64(len(winners))

	case len(n.unvisitedChildren) > 0:
		//unvisitedChildren is a shrinking tail of children, so the index
		//of its head is the number of children already consumed
		picked = n.actionCount - len(n.unvisitedChildren)
		child := n.children[picked]
		child.nodeVisits++
		n.unvisitedChildren = n.unvisitedChildren[1:]

		winners = child.simulate()
		reward = 1.0 / float64(len(winners))

	default:
		//Every child has been tried at least once: descend into the one
		//with the best UCT2 value for the player whose turn it is
		best := -1.0
		mover := n.state.Player()
		for i := 0; i < n.actionCount; i++ {
			if candidate := n.UCT2(i, mover); candidate > best {
				best = candidate
				picked = i
			}
		}
		winners, reward = n.children[picked].runSimulation()
	}

	//Record the outcome on this node; callers further up the recursion do
	//the same with the values returned here
	n.nodeVisits++
	if n.actionCount != 0 {
		n.childVisits[picked]++
	}
	for _, w := range winners {
		n.nodeScore[w] += reward
	}

	return winners, reward
}

//isParentOf reports whether potentialChild is already stored in this
//node's children. Empty (nil) child slots never match, so asking about a
//nil node always yields false.
func (n *node) isParentOf(potentialChild *node) bool {
	if potentialChild == nil {
		return false
	}
	for i := range n.children {
		if n.children[i] == potentialChild {
			return true
		}
	}
	return false
}

//expand asks this node's game state for its available actions and attaches
//one child node per action, in action order, so that children[i] is the
//state reached by applying actions[i].
//
//Child nodes are looked up in the tree's gameStates cache (keyed by the
//resulting game plus turn number) and reused when present; otherwise a new
//node is created and stored in the cache.
//
//Every child slot is always filled. When two actions lead to the same
//state, both slots point at the same shared node. Leaving such a slot nil
//would make runSimulation dereference a nil child as soon as it reaches
//that index.
//
//unvisitedChildren and children start out as the same slice; childVisits
//is reset to zeros. Panics if the game returns an error for one of its own
//actions.
func (n *node) expand() {
	n.actions = n.state.GetActions()
	n.actionCount = len(n.actions)
	n.unvisitedChildren = make([]*node, n.actionCount)
	n.children = n.unvisitedChildren
	n.childVisits = make([]float64, n.actionCount)
	for i, a := range n.actions {
		newGame, err := n.state.ApplyAction(a)
		if err != nil {
			panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
		}

		newState := gameState{newGame, n.state.turn + 1}

		//Reuse the cached node for this state if there is one, otherwise
		//create it and save it for later reuse
		child, cached := n.tree.gameStates[newState]
		if !cached {
			child = initializeNode(newState, n.tree)
			n.tree.gameStates[newState] = child
		}
		n.unvisitedChildren[i] = child
	}
}

//simulate plays the game out from this node's state by repeatedly applying
//an action chosen through the tree's random source until a terminal state
//is reached, and returns that state's winners. The node itself is not
//modified. If the node's state is already terminal its winners are
//returned right away.
//
//Panics if the game rejects one of its own actions; panicIfNoActions is
//consulted before each pick for the no-actions case.
func (n *node) simulate() []Player {
	current := n.state.Game
	for {
		if current.IsTerminal() {
			return current.Winners()
		}

		moves := current.GetActions()
		panicIfNoActions(current, moves)

		pick := n.tree.randSource.Intn(len(moves))
		next, err := current.ApplyAction(moves[pick])
		if err != nil {
			panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
		}
		current = next
	}
}