1 files changed, 3 insertions(+), 6 deletions(-)
M src/td3.py
M src/td3.py => src/td3.py +3 -6
@@ 143,12 143,9 @@ class TD3Agent():
if self.total_steps < self.initial_random_steps and not self.is_test:
selected_action = self.env.action_space.sample()
else:
- selected_action = (
- self.actor(torch.FloatTensor(state).to(self.device))[0]
- .detach()
- .cpu()
- .numpy()
- )
+ selected_action = self.actor(
+ torch.FloatTensor(state).to(self.device)
+ ).detach().cpu().numpy()
if not self.is_test:
if self.exploration_noise is not None: