~piotr-machura/sweep-ai

c926d8d99ba9c1df35540a9726e0644d4da25be0 — Piotr Machura 2 years ago 15237b1 conv
Docstring fix
1 files changed, 7 insertions(+), 6 deletions(-)

M sweep_ai/ai.py
M sweep_ai/ai.py => sweep_ai/ai.py +7 -6
@@ 81,16 81,14 @@ class Player:
    def expect(state: State, x: int, y: int) -> float:
        """What should the network output at `(x, y)` be?

        Returns `# tiles / # safe & hidden ` if `(x, y)` is hidden, next to
        a revealed spot and safe, `# tiles / # bombs` if `(x, y)` is a bomb,
        `# tiles / # revealed ` otherwise.
        Returns `1` if `(x, y)` is hidden, next to a revealed spot and safe,
        `0` otherwise.
        """
        if state.safe[x, y] and state.hidden[x, y]:
            for x_n, y_n in state.neighbors(x, y):
                if state.revealed[x_n, y_n]:
                    return state.size * state.size / (
                        state.safe_n - state.revealed_n)
        return state.size * state.size / (state.revealed_n + 1)
                    return 1
        return 0

    def training_data(self, state: State) -> Tuple[np.ndarray, np.ndarray]:
        """Returns a tuple of `(neural network input, expected output)`."""


@@ 115,12 113,15 @@ class Player:
        output = self.brain.predict(
            np.stack([self.training_data(state)[0]]),
        )[0]
        # print(output)
        output = np.where(
            (state.revealed == 1) | (state.flagged == 1),
            0,
            output,
        )
        # print(output)
        pos = np.unravel_index(np.argmax(output, axis=None), output.shape)
        # print(pos)
        return pos

    def train(self):