~bonbon/gmcts

ref: 052df34072b069e95496e1379efc0c4d06a84672 gmcts/tree.go -rw-r--r-- 1.9 KiB
052df340bonbon reduce amount of IsTerminal calls 1 year, 9 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package gmcts

import (
	"context"
	"time"
)

//Search runs MCTS rounds on the tree until the given duration
//has elapsed. It is a convenience wrapper around SearchContext
//using a context that times out after duration.
//
//Search will panic if the Game's ApplyAction
//method returns an error
func (t *Tree) Search(duration time.Duration) {
	parent := context.Background()
	ctx, cancel := context.WithTimeout(parent, duration)
	defer cancel()
	t.SearchContext(ctx)
}

//SearchContext repeatedly runs MCTS rounds on the tree until ctx
//is done. The context is checked before every round, so a round
//that has already started is always completed. If ctx is never
//done (e.g. context.Background()), this never returns.
//
//SearchContext will panic if the Game's ApplyAction
//method returns an error
func (t *Tree) SearchContext(ctx context.Context) {
	done := ctx.Done()
	for {
		// Non-blocking check for cancellation before each round.
		select {
		case <-done:
			return
		default:
		}
		t.search()
	}
}

//SearchRounds runs exactly rounds MCTS rounds on the tree.
//A non-positive value of rounds performs no search.
//
//SearchRounds will panic if the Game's ApplyAction
//method returns an error
func (t *Tree) SearchRounds(rounds int) {
	for remaining := rounds; remaining > 0; remaining-- {
		t.search()
	}
}

//search runs a single MCTS round by starting a simulation
//from the tree's current root node.
func (t *Tree) search() {
	root := t.current
	root.runSimulation()
}

//Rounds returns the number of MCTS rounds that were performed
//on this tree, as counted by the visits to its current root node.
func (t Tree) Rounds() int {
	root := t.current
	return root.nodeVisits
}

//Nodes returns the number of nodes created on this tree,
//measured as the number of entries in its game-state table.
func (t Tree) Nodes() int {
	count := len(t.gameStates)
	return count
}

//MaxDepth returns the maximum depth of this tree, i.e. the largest
//turn value among all game states recorded on it (0 if there are
//none). The value can be thought of as the amount of moves ahead
//this tree searched through.
func (t Tree) MaxDepth() int {
	deepest := 0
	for state := range t.gameStates {
		if depth := state.turn; depth > deepest {
			deepest = depth
		}
	}
	return deepest
}

//bestAction returns the action at the tree's current root whose
//child has the highest win rate for the player to move at the root.
//
//The win rate of action i is the child's accumulated score for that
//player divided by the number of times action i was taken from the
//root (childVisits[i]). Actions with zero visits from the root are
//skipped. Dividing by zero would yield NaN or ±Inf, and a +Inf rate
//(a child with a positive score but no visits from this root) would
//otherwise be selected over every properly measured action.
//If no action qualifies (no actions, or none visited), the zero
//value of Action is returned.
func (t *Tree) bestAction() Action {
	root := t.current
	player := root.state.Player()

	var bestAction Action
	bestWinRate := -1.0
	for i := 0; i < root.actionCount; i++ {
		visits := root.childVisits[i]
		if visits == 0 {
			// No evidence for this action from this root. Checking before
			// touching children[i] also means an action that was never
			// taken never has its child dereferenced.
			continue
		}
		winRate := root.children[i].nodeScore[player] / visits
		if winRate > bestWinRate {
			bestAction = root.actions[i]
			bestWinRate = winRate
		}
	}

	return bestAction
}