~enan/ros-rl

f2037d805bc9ad2936bd3814059e0d8e80e7dcc3 — Enan Ajmain 2 years ago bdde790
td3: plot scores/losses in rows instead of columns
1 files changed, 6 insertions(+), 12 deletions(-)

M src/td3.py
M src/td3.py => src/td3.py +6 -12
@@ 381,18 381,12 @@ class TD3Agent:
        actor_losses: List[float],
        critic_losses: List[float],
    ):
        """Plot the training progresses."""
        plt.figure(figsize=(30, 5))
        plt.subplot(131)
        plt.title("frame %s. score: %s" % (frame_idx, np.mean(scores[-10:])))
        plt.plot(scores)
        plt.plot(avgscores)
        plt.subplot(132)
        plt.title("actor_loss")
        plt.plot(actor_losses)
        plt.subplot(133)
        plt.title("critic_loss")
        plt.plot(critic_losses)

        fig, ax = plt.subplots(3, 1, figsize=(30, 15))
        ax[0].plot(scores)
        ax[0].plot(avgscores)
        ax[1].plot(actor_losses)
        ax[2].plot(critic_losses)
        plt.show()

    def save(self, directory, filename):