@@ 381,18 381,12 @@ class TD3Agent:
actor_losses: List[float],
critic_losses: List[float],
):
- """Plot the training progresses."""
- plt.figure(figsize=(30, 5))
- plt.subplot(131)
- plt.title("frame %s. score: %s" % (frame_idx, np.mean(scores[-10:])))
- plt.plot(scores)
- plt.plot(avgscores)
- plt.subplot(132)
- plt.title("actor_loss")
- plt.plot(actor_losses)
- plt.subplot(133)
- plt.title("critic_loss")
- plt.plot(critic_losses)
+
+ fig, ax = plt.subplots(3, 1, figsize=(30, 15))
+ ax[0].plot(scores)
+ ax[0].plot(avgscores)
+ ax[1].plot(actor_losses)
+ ax[2].plot(critic_losses)
plt.show()
def save(self, directory, filename):