@@ 223,7 223,7 @@ class TD3Agent():
else:
actor_loss = torch.zeros(1)
- return actor_loss.data, critic_loss.data
+ return actor_loss.data.detach().cpu().numpy(), critic_loss.data.detach().cpu().numpy()
def train(self, num_frames: int, plotting_interval: int = 200):
self.is_test = False