~piotr-machura/sweep-ai

6b57599581b7b45470702ef2a1ab825b313f1fc9 — Piotr Machura 10 months ago 8aed484
91% accuracy sounds good to me
5 files changed, 232 insertions(+), 54 deletions(-)

M .gitignore
M poetry.lock
M pyproject.toml
M sweep_ai/ai.py
M sweep_ai/window.py
M .gitignore => .gitignore +1 -0
@@ 4,3 4,4 @@ __pycache__
*.pyc
.pytest_cache
.mypy_cache
fit.pdf

M poetry.lock => poetry.lock +135 -1
@@ 252,6 252,14 @@ colors = ["colorama (>=0.4.3,<0.5.0)"]
plugins = ["setuptools"]

[[package]]
name = "joblib"
version = "1.1.0"
description = "Lightweight pipelining with Python functions"
category = "main"
optional = false
python-versions = ">=3.6"

[[package]]
name = "keras"
version = "2.8.0rc1"
description = "Deep learning for humans."


@@ 652,6 660,37 @@ python-versions = ">=3.6,<4"
pyasn1 = ">=0.1.3"

[[package]]
name = "scikit-learn"
version = "1.0.2"
description = "A set of python modules for machine learning and data mining"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
joblib = ">=0.11"
numpy = ">=1.14.6"
scipy = ">=1.1.0"
threadpoolctl = ">=2.0.0"

[package.extras]
benchmark = ["matplotlib (>=2.2.3)", "pandas (>=0.25.0)", "memory-profiler (>=0.57.0)"]
docs = ["matplotlib (>=2.2.3)", "scikit-image (>=0.14.5)", "pandas (>=0.25.0)", "seaborn (>=0.9.0)", "memory-profiler (>=0.57.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "numpydoc (>=1.0.0)", "Pillow (>=7.1.2)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"]
examples = ["matplotlib (>=2.2.3)", "scikit-image (>=0.14.5)", "pandas (>=0.25.0)", "seaborn (>=0.9.0)"]
tests = ["matplotlib (>=2.2.3)", "scikit-image (>=0.14.5)", "pandas (>=0.25.0)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "flake8 (>=3.8.2)", "black (>=21.6b0)", "mypy (>=0.770)", "pyamg (>=4.0.0)"]

[[package]]
name = "scipy"
version = "1.7.3"
description = "SciPy: Scientific Library for Python"
category = "main"
optional = false
python-versions = ">=3.7,<3.11"

[package.dependencies]
numpy = ">=1.16.5,<1.23.0"

[[package]]
name = "setuptools-scm"
version = "6.4.2"
description = "the blessed package to manage your versions by scm tags"


@@ 676,6 715,17 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"

[[package]]
name = "sklearn"
version = "0.0"
description = "A set of python modules for machine learning and data mining"
category = "main"
optional = false
python-versions = "*"

[package.dependencies]
scikit-learn = "*"

[[package]]
name = "snowballstemmer"
version = "2.2.0"
description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms."


@@ 781,6 831,14 @@ optional = false
python-versions = "*"

[[package]]
name = "threadpoolctl"
version = "3.0.0"
description = "threadpoolctl"
category = "main"
optional = false
python-versions = ">=3.6"

[[package]]
name = "toml"
version = "0.10.2"
description = "Python Library for Tom's Obvious, Minimal Language"


@@ 859,7 917,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.11"
content-hash = "71d09bed82225317a6fb4531770e645a39d566d8ce8f53ee5967ee01c6d32f89"
content-hash = "a8736cc0a5e18e65daf72234f2284f5d88807c703a13c612b9a7eb28ce966327"

[metadata.files]
absl-py = [


@@ 1007,6 1065,10 @@ isort = [
    {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"},
    {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"},
]
joblib = [
    {file = "joblib-1.1.0-py2.py3-none-any.whl", hash = "sha256:f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6"},
    {file = "joblib-1.1.0.tar.gz", hash = "sha256:4158fcecd13733f8be669be0683b96ebdbbd38d23559f54dca7205aea1bf1e35"},
]
keras = [
    {file = "keras-2.8.0rc1-py2.py3-none-any.whl", hash = "sha256:ce86dd1133bc9a8b1cafab48013c28f8651e5093e9429209b0f8bb7acf1310ac"},
]


@@ 1433,6 1495,71 @@ rsa = [
    {file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"},
    {file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"},
]
scikit-learn = [
    {file = "scikit-learn-1.0.2.tar.gz", hash = "sha256:b5870959a5484b614f26d31ca4c17524b1b0317522199dc985c3b4256e030767"},
    {file = "scikit_learn-1.0.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:da3c84694ff693b5b3194d8752ccf935a665b8b5edc33a283122f4273ca3e687"},
    {file = "scikit_learn-1.0.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:75307d9ea39236cad7eea87143155eea24d48f93f3a2f9389c817f7019f00705"},
    {file = "scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f14517e174bd7332f1cca2c959e704696a5e0ba246eb8763e6c24876d8710049"},
    {file = "scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9aac97e57c196206179f674f09bc6bffcd0284e2ba95b7fe0b402ac3f986023"},
    {file = "scikit_learn-1.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:d93d4c28370aea8a7cbf6015e8a669cd5d69f856cc2aa44e7a590fb805bb5583"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:85260fb430b795d806251dd3bb05e6f48cdc777ac31f2bcf2bc8bbed3270a8f5"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a053a6a527c87c5c4fa7bf1ab2556fa16d8345cf99b6c5a19030a4a7cd8fd2c0"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:245c9b5a67445f6f044411e16a93a554edc1efdcce94d3fc0bc6a4b9ac30b752"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158faf30684c92a78e12da19c73feff9641a928a8024b4fa5ec11d583f3d8a87"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08ef968f6b72033c16c479c966bf37ccd49b06ea91b765e1cc27afefe723920b"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16455ace947d8d9e5391435c2977178d0ff03a261571e67f627c8fee0f9d431a"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-win32.whl", hash = "sha256:2f3b453e0b149898577e301d27e098dfe1a36943f7bb0ad704d1e548efc3b448"},
    {file = "scikit_learn-1.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:46f431ec59dead665e1370314dbebc99ead05e1c0a9df42f22d6a0e00044820f"},
    {file = "scikit_learn-1.0.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:ff3fa8ea0e09e38677762afc6e14cad77b5e125b0ea70c9bba1992f02c93b028"},
    {file = "scikit_learn-1.0.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:9369b030e155f8188743eb4893ac17a27f81d28a884af460870c7c072f114243"},
    {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7d6b2475f1c23a698b48515217eb26b45a6598c7b1840ba23b3c5acece658dbb"},
    {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:285db0352e635b9e3392b0b426bc48c3b485512d3b4ac3c7a44ec2a2ba061e66"},
    {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cb33fe1dc6f73dc19e67b264dbb5dde2a0539b986435fdd78ed978c14654830"},
    {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1391d1a6e2268485a63c3073111fe3ba6ec5145fc957481cfd0652be571226d"},
    {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc3744dabc56b50bec73624aeca02e0def06b03cb287de26836e730659c5d29c"},
    {file = "scikit_learn-1.0.2-cp38-cp38-win32.whl", hash = "sha256:a999c9f02ff9570c783069f1074f06fe7386ec65b84c983db5aeb8144356a355"},
    {file = "scikit_learn-1.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:7626a34eabbf370a638f32d1a3ad50526844ba58d63e3ab81ba91e2a7c6d037e"},
    {file = "scikit_learn-1.0.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:a90b60048f9ffdd962d2ad2fb16367a87ac34d76e02550968719eb7b5716fd10"},
    {file = "scikit_learn-1.0.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7a93c1292799620df90348800d5ac06f3794c1316ca247525fa31169f6d25855"},
    {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:eabceab574f471de0b0eb3f2ecf2eee9f10b3106570481d007ed1c84ebf6d6a1"},
    {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:55f2f3a8414e14fbee03782f9fe16cca0f141d639d2b1c1a36779fa069e1db57"},
    {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80095a1e4b93bd33261ef03b9bc86d6db649f988ea4dbcf7110d0cded8d7213d"},
    {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fa38a1b9b38ae1fad2863eff5e0d69608567453fdfc850c992e6e47eb764e846"},
    {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff746a69ff2ef25f62b36338c615dd15954ddc3ab8e73530237dd73235e76d62"},
    {file = "scikit_learn-1.0.2-cp39-cp39-win32.whl", hash = "sha256:e174242caecb11e4abf169342641778f68e1bfaba80cd18acd6bc84286b9a534"},
    {file = "scikit_learn-1.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:b54a62c6e318ddbfa7d22c383466d38d2ee770ebdb5ddb668d56a099f6eaf75f"},
]
scipy = [
    {file = "scipy-1.7.3-1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:c9e04d7e9b03a8a6ac2045f7c5ef741be86727d8f49c45db45f244bdd2bcff17"},
    {file = "scipy-1.7.3-1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b0e0aeb061a1d7dcd2ed59ea57ee56c9b23dd60100825f98238c06ee5cc4467e"},
    {file = "scipy-1.7.3-1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b78a35c5c74d336f42f44106174b9851c783184a85a3fe3e68857259b37b9ffb"},
    {file = "scipy-1.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:173308efba2270dcd61cd45a30dfded6ec0085b4b6eb33b5eb11ab443005e088"},
    {file = "scipy-1.7.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:21b66200cf44b1c3e86495e3a436fc7a26608f92b8d43d344457c54f1c024cbc"},
    {file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceebc3c4f6a109777c0053dfa0282fddb8893eddfb0d598574acfb734a926168"},
    {file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7eaea089345a35130bc9a39b89ec1ff69c208efa97b3f8b25ea5d4c41d88094"},
    {file = "scipy-1.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:304dfaa7146cffdb75fbf6bb7c190fd7688795389ad060b970269c8576d038e9"},
    {file = "scipy-1.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:033ce76ed4e9f62923e1f8124f7e2b0800db533828c853b402c7eec6e9465d80"},
    {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4d242d13206ca4302d83d8a6388c9dfce49fc48fdd3c20efad89ba12f785bf9e"},
    {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8499d9dd1459dc0d0fe68db0832c3d5fc1361ae8e13d05e6849b358dc3f2c279"},
    {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca36e7d9430f7481fc7d11e015ae16fbd5575615a8e9060538104778be84addf"},
    {file = "scipy-1.7.3-cp37-cp37m-win32.whl", hash = "sha256:e2c036492e673aad1b7b0d0ccdc0cb30a968353d2c4bf92ac8e73509e1bf212c"},
    {file = "scipy-1.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:866ada14a95b083dd727a845a764cf95dd13ba3dc69a16b99038001b05439709"},
    {file = "scipy-1.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:65bd52bf55f9a1071398557394203d881384d27b9c2cad7df9a027170aeaef93"},
    {file = "scipy-1.7.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f99d206db1f1ae735a8192ab93bd6028f3a42f6fa08467d37a14eb96c9dd34a3"},
    {file = "scipy-1.7.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5f2cfc359379c56b3a41b17ebd024109b2049f878badc1e454f31418c3a18436"},
    {file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb7ae2c4dbdb3c9247e07acc532f91077ae6dbc40ad5bd5dca0bb5a176ee9bda"},
    {file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c2d250074cfa76715d58830579c64dff7354484b284c2b8b87e5a38321672c"},
    {file = "scipy-1.7.3-cp38-cp38-win32.whl", hash = "sha256:87069cf875f0262a6e3187ab0f419f5b4280d3dcf4811ef9613c605f6e4dca95"},
    {file = "scipy-1.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:7edd9a311299a61e9919ea4192dd477395b50c014cdc1a1ac572d7c27e2207fa"},
    {file = "scipy-1.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eef93a446114ac0193a7b714ce67659db80caf940f3232bad63f4c7a81bc18df"},
    {file = "scipy-1.7.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb326658f9b73c07081300daba90a8746543b5ea177184daed26528273157294"},
    {file = "scipy-1.7.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93378f3d14fff07572392ce6a6a2ceb3a1f237733bd6dcb9eb6a2b29b0d19085"},
    {file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edad1cf5b2ce1912c4d8ddad20e11d333165552aba262c882e28c78bbc09dbf6"},
    {file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d1cc2c19afe3b5a546ede7e6a44ce1ff52e443d12b231823268019f608b9b12"},
    {file = "scipy-1.7.3-cp39-cp39-win32.whl", hash = "sha256:2c56b820d304dffcadbbb6cbfbc2e2c79ee46ea291db17e288e73cd3c64fefa9"},
    {file = "scipy-1.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:3f78181a153fa21c018d346f595edd648344751d7f03ab94b398be2ad083ed3e"},
    {file = "scipy-1.7.3.tar.gz", hash = "sha256:ab5875facfdef77e0a47d5fd39ea178b58e60e454a4c85aa1e52fcb80db7babf"},
]
setuptools-scm = [
    {file = "setuptools_scm-6.4.2-py3-none-any.whl", hash = "sha256:acea13255093849de7ccb11af9e1fb8bde7067783450cee9ef7a93139bddf6d4"},
    {file = "setuptools_scm-6.4.2.tar.gz", hash = "sha256:6833ac65c6ed9711a4d5d2266f8024cfa07c533a0e55f4c12f6eff280a5a9e30"},


@@ 1441,6 1568,9 @@ six = [
    {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
    {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
sklearn = [
    {file = "sklearn-0.0.tar.gz", hash = "sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31"},
]
snowballstemmer = [
    {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"},
    {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},


@@ 1490,6 1620,10 @@ termcolor = [
tf-estimator-nightly = [
    {file = "tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl", hash = "sha256:0065a04e396b2890bd19761fc1de7559ceafeba12839f8db2c7e7473afaaf612"},
]
threadpoolctl = [
    {file = "threadpoolctl-3.0.0-py3-none-any.whl", hash = "sha256:4fade5b3b48ae4b1c30f200b28f39180371104fccc642e039e0f2435ec8cc211"},
    {file = "threadpoolctl-3.0.0.tar.gz", hash = "sha256:d03115321233d0be715f0d3a5ad1d6c065fe425ddc2d671ca8e45e9fd5d7a52a"},
]
toml = [
    {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
    {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},

M pyproject.toml => pyproject.toml +1 -0
@@ 11,6 11,7 @@ pygame-menu = "^4.2.2"
numpy = "^1.22.0"
tensorflow = "^2.8.0rc1"
pygame = "^2.1.2"
sklearn = "^0.0"

[tool.poetry.dev-dependencies]
yapf = "^0.32.0"

M sweep_ai/ai.py => sweep_ai/ai.py +88 -46
@@ 2,8 2,9 @@
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.utils import class_weight, shuffle
from tensorflow import keras

from .logic import State


@@ 13,17 14,29 @@ class Player:
    """The AI player, capable of making informed decisions.

    Fields:
        fitness: how fit the brain is. 0 for a perfect player, `-inf` for one
            that does not have an initiated brain.
        brain: the neural network powering the decision-making process.
        size: the size of the board on which the player has been trained.
    """

    def __init__(self, size: int):
        """Constructs a new `Player` for a board of provided `size`."""
    def __init__(self, size: int, difficulty: float):
        """Constructs a new `Player`."""
        self.size = size
        self.fitness = -np.inf
        self.brain: Optional[keras.models.Sequential] = None
        self.difficulty = difficulty
        self.trained = False
        self.brain = keras.models.Sequential(
            [
                keras.layers.InputLayer(input_shape=(24, )),
                keras.layers.Dense(120, activation='relu'),
                keras.layers.Dense(60, activation='relu'),
                keras.layers.Dense(30, activation='relu'),
                keras.layers.Dense(15, activation='relu'),
                keras.layers.Dense(1, activation='sigmoid'),
            ])
        self.brain.compile(
            keras.optimizers.Adam(learning_rate=0.0025),
            loss='binary_crossentropy',
            metrics=['accuracy'],
        )

    @staticmethod
    def surround(state: State, x: int, y: int) -> np.ndarray:


@@ 31,75 44,104 @@ class Player:

        If the neighbor is empty (or is beyond the border) the array contains
        0. If it's hidden the state contains -1. If it's revealed, but near a
        bomb then it's value.
        bomb then it's value is the number of bombs near it.
        """
        surround = np.zeros((24, ), dtype=float)
        for x_n, y_n in state.neighbors(x, y, radius=2):
            if state.hidden[x_n, y_n]:
                surround[x_n + y_n] = -1
            if state.revealed[x_n, y_n] and state.near[x_n, y_n] > 0:
                surround[x_n + y_n] = state.near[x_n, y_n]
        return surround

    def predict(self, state: State, x: int, y: int) -> float:
        """Predict if the tile `(x, y)` is safe."""
        return float(
            self.brain.predict(
                np.array([self.surround(state, x, y)]),
            )[0][0],
        )

    @staticmethod
    def prosess_tile(state: State, x: int, y: int) -> bool:
        """Returns `true` if `(x, y)` is suitable for making a move."""
        return not (state.revealed[x, y] or state.flagged[x, y])

    def move(self, state: State) -> Optional[Tuple[int, int]]:
        """Make a move.

        Returns `None` if the brain has not been compiled yet or if the board
        size does not match players capabilities, and a suggested `(x, y)` move
        otherwise.
        Returns `None` if the brain has not been compiled yet , and a suggested
        `(x, y)` move otherwise.
        """
        if state.size != self.size:
            return None
        if self.brain is None:
            return None
        return (1, 1)
        if self.trained:
            pos = max(
                [
                    (x, y)
                    for x in range(state.size)
                    for y in range(state.size)
                    if self.prosess_tile(state, x, y)
                ],
                key=lambda pos: self.predict(state, *pos),
            )
            return pos
        return None

    def train(self):
        """Recrete and train the AI brain of this player."""
        self.brain = keras.models.Sequential(
            [
                keras.layers.Dense(120, activation='relu', input_shape=(24, )),
                keras.layers.Dropout(0.4),
                keras.layers.Dense(60, activation='relu'),
                keras.layers.Dropout(0.2),
                keras.layers.Dense(1, activation='sigmoid'),
            ])
        self.brain.compile(
            keras.optimizers.Adam(learning_rate=0.0005),
            loss='binary_crossentropy',
            metrics=['accuracy'],
        )
        self.brain.summary()

        x_train = []
        y_train = []
        weights = []
        for _ in range(100):
            state = State(self.size, 0.15)
            state.click(*state.cheat())
            state.click(*state.cheat())
        while len(x_train) < 1e4:
            state = State(self.size, 0.2)
            # CLick on a guaranteed empty space
            x_click, y_click = np.transpose(np.nonzero(state.near == 0))[0]
            state.click(x_click, y_click)
            # Play the game until won
            while state.won is None:
                state.click(*state.cheat())
                for x in range(self.size):
                    for y in range(self.size):
                        if state.hidden[x, y]:
                        if self.prosess_tile(state, x, y):
                            # The neural network input - 24 surrouding tiles
                            sur = self.surround(state, x, y)
                            x_train.append(sur)
                            # Network should output 1 on safe tiles
                            y_train.append(state.safe[x, y])
                            weights.append(np.sum(sur))
                state.click(*state.cheat())

        x_train = np.array(x_train)
        x_train = StandardScaler().fit_transform(np.array(x_train))
        print(state.bomb_n)
        y_train = np.array(y_train)
        weights = np.array(weights)
        x_train, y_train = shuffle(x_train, y_train)
        print(y_train)
        weights = dict(
            zip(
                np.unique(y_train),
                class_weight.compute_class_weight(
                    'balanced',
                    classes=np.unique(y_train),
                    y=y_train,
                ),
            ),
        )

        history = self.brain.fit(
            x_train,
            y_train,
            sample_weight=weights,
            epochs=10,
            class_weight=weights,
            epochs=50,
            batch_size=50,
            verbose=1,
            validation_split=0.2,
            shuffle=True,
            validation_split=0.25,
            use_multiprocessing=True,
        )
        plt.plot(history.history['val_accuracy'])
        plt.plot(history.history['val_loss'])
        plt.show()
        _, axs = plt.subplots(1, 2)
        axs[0].set_title('Accuracy')
        axs[0].plot(history.history['val_accuracy'])
        axs[0].plot(history.history['accuracy'])

        axs[1].set_title('Loss')
        axs[1].plot(history.history['val_loss'])
        axs[1].plot(history.history['loss'])
        plt.savefig('fit.pdf')
        self.trained = True

M sweep_ai/window.py => sweep_ai/window.py +7 -7
@@ 40,7 40,7 @@ class Game:
            (self.display_width, self.display_height),
        )
        self.events = []
        self.players: Dict[int, Player] = {}
        self.player = Player(self.size, self.difficulty)

        self.sprites = {}
        self.sprites['flag'] = pygame.image.load('assets/flag.png')


@@ 173,12 173,12 @@ class Game:
        self.reset()

    def get_hint(self):
        """Highlight the three safest."""
        if self.players.get(self.size) is None:
            player = Player(self.size)
            player.train()
            self.players[self.size] = player
        # self.hint = self.players[self.size].move()
        """Highlight the position suggested by the AI player."""
        hint = self.player.move(self.state)
        if hint is None:
            self.player.train()
            hint = self.player.move(self.state)
        self.hint = hint

    def within_board(self, pos_x: float, pos_y: float) -> bool:
        """Returns `true` if `pos_x`, `pos_y` is within the board."""