~enan/ros-rl

5c359c95dcae5a591a4fc130d8c4ed69e1509659 — Enan Ajmain 2 years ago cebffd3
td3: plot running average score (last 50 episodes) overlaid on the per-episode total score
1 file changed, 11 insertions(+), 3 deletions(-)

M src/td3.py
M src/td3.py => src/td3.py +11 -3
@@ 288,12 288,15 @@ class TD3Agent:
        actor_losses = []
        critic_losses = []
        scores = []
        avgscores = []
        score = 0
        episode = 1
        prev_episode_steps = 0

        for self.total_steps in range(1, num_frames + 1):
            print("STEP:", self.total_steps)
            print("EPISODE:", episode)
            print("TOTAL STEP:", self.total_steps)
            print("EPISODE: %d - %d" % (episode, self.total_steps - prev_episode_steps))
            # print("SCORES:", scores)

            action = self.select_action(state)
            next_state, reward, done = self.step(action)


@@ 306,6 309,8 @@ class TD3Agent:
                state = self.env.reset()
                if self.start_storing:
                    scores.append(score)
                    # scores.append(score/(self.total_steps - prev_episode_steps))
                    avgscores.append(np.mean(scores[-50:]))
                if score > self.max_score and self.start_storing:
                    self.save(directory="./saves",
                              filename=self.model_filename+"_"+str(episode))


@@ 326,7 331,8 @@ class TD3Agent:
                    actor_losses.append(actor_loss)
                critic_losses.append(critic_loss)

        self._plot(self.total_steps, scores, actor_losses, critic_losses)
        self._plot(self.total_steps, scores, avgscores,
                   actor_losses, critic_losses)

        self.env.close()



@@ 374,6 380,7 @@ class TD3Agent:
        self,
        frame_idx: int,
        scores: List[float],
        avgscores: List[float],
        actor_losses: List[float],
        critic_losses: List[float],
    ):


@@ 382,6 389,7 @@ class TD3Agent:
        plt.subplot(131)
        plt.title("frame %s. score: %s" % (frame_idx, np.mean(scores[-10:])))
        plt.plot(scores)
        plt.plot(avgscores)
        plt.subplot(132)
        plt.title("actor_loss")
        plt.plot(actor_losses)