~bonbon/gmcts

ref: 052df34072b069e95496e1379efc0c4d06a84672 gmcts/search.go -rw-r--r-- 3.4 KiB
052df340bonbon reduce amount of IsTerminal calls 1 year, 9 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package gmcts

import (
	"fmt"
	"math"
)

// DefaultExplorationConst is the exploration constant c used by default in the
// UCT/UCB1 selection formula. Sqrt(2) is the customary choice, see
// https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
const DefaultExplorationConst = math.Sqrt2

// initializeNode allocates a new search node holding state g and owned by
// tree. Only the state, the owning tree and an empty per-player score table
// are set; every other field keeps its zero value (no actions, no children,
// no visits).
func initializeNode(g gameState, tree *Tree) *node {
	fresh := new(node)
	fresh.state = g
	fresh.tree = tree
	fresh.nodeScore = map[Player]float64{}
	return fresh
}

// UCT2 returns the selection score of child i from the point of view of
// player p.
//
// The exploitation term is the child's accumulated score for p divided by the
// child's own visit count. Because children can be shared between parents
// through the tree's state cache, that count may include visits made through
// other parents. The exploration term instead uses this node's visit count
// and the per-edge count childVisits[i] (visits to child i made from n). It is
// scaled by the tree's exploration constant.
//
// NOTE(review): this split (node-level mean, edge-level counts) appears to be
// the "UCT2" transposition variant of Childs et al.; confirm before changing.
// Callers must ensure child i and this edge have been visited at least once,
// otherwise the divisions produce NaN/Inf.
func (n *node) UCT2(i int, p Player) float64 {
	child := n.children[i]
	mean := child.nodeScore[p] / float64(child.nodeVisits)

	edgeVisits := n.childVisits[i]
	parentLog := math.Log(float64(n.nodeVisits))
	bonus := n.tree.explorationConst * math.Sqrt(parentLog/edgeVisits)

	return mean + bonus
}

// runSimulation performs one Monte Carlo tree search iteration rooted at n and
// records its outcome in n's statistics before returning it to the caller
// (which records it in turn, so the whole path is updated).
//
// Nodes are expanded lazily. IsTerminal is only consulted while n has no
// actions yet, and a non-terminal node is expanded on that first visit. Then
// exactly one of the following happens:
//   - n is terminal: its winners are the result.
//   - some children have never been tried from n: the next one in action
//     order is taken and a random playout is run from it.
//   - every child has been tried: the child with the highest UCT2 score for
//     the player to move at n is descended into recursively.
//
// It returns the winners of the playout and the score credited to each of
// them (1/len(winners)). If winners is empty that share is +Inf, but nothing
// is credited because the range over winners is empty.
func (n *node) runSimulation() ([]Player, float64) {
	var (
		selectedChildIndex int
		winners            []Player
		scoreToAdd         float64
		terminalState      bool
	)

	// A node with actions has already been expanded, which only happens for
	// non-terminal states, so there is no need to ask IsTerminal again.
	if n.actionCount == 0 {
		terminalState = n.state.IsTerminal()
		if !terminalState {
			n.expand()
		}
	}

	switch {
	case terminalState:
		winners = n.simulate()
		scoreToAdd = 1.0 / float64(len(winners))

	case len(n.unvisitedChildren) > 0:
		// unvisitedChildren is a suffix of children (same backing array), so
		// its first element lives at this index in children.
		selectedChildIndex = n.actionCount - len(n.unvisitedChildren)
		child := n.children[selectedChildIndex]
		child.nodeVisits++
		n.unvisitedChildren = n.unvisitedChildren[1:]

		winners = child.simulate()
		scoreToAdd = 1.0 / float64(len(winners))

		// The playout started at child, so child must be credited with its
		// result as well as the visit counted above. UCT2 divides
		// child.nodeScore by child.nodeVisits. Without this, the first playout
		// from each child would count as a zero-score visit for every player.
		for _, p := range winners {
			child.nodeScore[p] += scoreToAdd
		}

	default:
		// Every child has been tried from here, so pick the one with the
		// highest UCT2 value for the player to move at n and descend into it.
		// -Inf guarantees the first candidate is always accepted.
		thisPlayer := n.state.Player()
		maxScore := math.Inf(-1)
		for i := 0; i < n.actionCount; i++ {
			if score := n.UCT2(i, thisPlayer); score > maxScore {
				maxScore = score
				selectedChildIndex = i
			}
		}
		winners, scoreToAdd = n.children[selectedChildIndex].runSimulation()
	}

	// Record the outcome in n itself; the caller does the same for its node.
	n.nodeVisits++
	if n.actionCount != 0 {
		// Per-edge count consumed by UCT2's exploration term.
		n.childVisits[selectedChildIndex]++
	}
	for _, p := range winners {
		n.nodeScore[p] += scoreToAdd
	}
	return winners, scoreToAdd
}

// expand fills in n's action list and creates one child per action, in the
// order returned by GetActions.
//
// Each child state is the result of applying the action with the turn counter
// advanced by one. States already present in the tree-wide gameStates cache
// reuse the cached node. New states get a fresh node, which is stored in the
// cache. children and unvisitedChildren start out as the same slice (one
// backing array); unvisitedChildren is later consumed from the front.
//
// Panics if the game returns an error from ApplyAction.
func (n *node) expand() {
	actions := n.state.GetActions()
	count := len(actions)

	n.actions = actions
	n.actionCount = count
	n.children = make([]*node, count)
	n.unvisitedChildren = n.children
	n.childVisits = make([]float64, count)

	for idx := range actions {
		next, err := n.state.ApplyAction(actions[idx])
		if err != nil {
			panic(fmt.Sprintf("gmcts: Game returned an error when exploring the tree: %s", err))
		}

		key := gameState{next, n.state.turn + 1}

		// Share nodes between identical (game, turn) positions.
		child, cached := n.tree.gameStates[key]
		if !cached {
			child = initializeNode(key, n.tree)
			n.tree.gameStates[key] = child
		}
		n.children[idx] = child
	}
}

// simulate plays a random playout from n's game state and returns the winners
// reported by the terminal state it reaches.
//
// At each step the next action is chosen by index using the tree's
// randSource.Intn. The node's statistics are not touched here. Before choosing,
// panicIfNoActions is called on the action list (presumably to reject a
// non-terminal state with no actions — verify). Panics if ApplyAction returns
// an error.
func (n *node) simulate() []Player {
	current := n.state.Game
	for {
		if current.IsTerminal() {
			return current.Winners()
		}

		moves := current.GetActions()
		panicIfNoActions(current, moves)

		pick := moves[n.tree.randSource.Intn(len(moves))]
		next, err := current.ApplyAction(pick)
		if err != nil {
			panic(fmt.Sprintf("gmcts: game returned an error while searching the tree: %s", err))
		}
		current = next
	}
}