@@ -69,7 +69,6 @@ class GaussianNoise:
class Actor(nn.Module):
def __init__(self, in_dim: int, out_dim: int, init_w: float = 3e-3):
- """Initialize."""
super(Actor, self).__init__()
self.hidden1 = nn.Linear(in_dim, 128)
@@ -80,7 +79,6 @@ class Actor(nn.Module):
self.out.bias.data.uniform_(-init_w, init_w)
def forward(self, state: torch.Tensor) -> torch.Tensor:
- """Forward method implementation."""
x = F.relu(self.hidden1(state))
x = F.relu(self.hidden2(x))
action = self.out(x).tanh()
@@ -90,7 +88,6 @@ class Actor(nn.Module):
class Critic(nn.Module):
def __init__(self, in_dim: int, init_w: float = 3e-3):
- """Initialize."""
super(Critic, self).__init__()
self.hidden1 = nn.Linear(in_dim, 128)
@@ -101,7 +98,6 @@ class Critic(nn.Module):
self.out.bias.data.uniform_(-init_w, init_w)
def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
- """Forward method implementation."""
x = torch.cat((state, action), dim=-1)
x = F.relu(self.hidden1(x))
x = F.relu(self.hidden2(x))
@@ -110,7 +106,7 @@ class Critic(nn.Module):
return value
-class TD3Agent:
+class TD3Agent():
def __init__(
self,
env: gym.Env,
@@ -128,7 +124,6 @@ class TD3Agent:
wd_actor: float = 1e-2,
model_filename: str = "td3"
):
- """Initialize."""
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
@@ -140,7 +135,6 @@ class TD3Agent:
self.initial_random_steps = initial_random_steps
self.policy_update_freq = policy_update_freq
- # device: cpu / gpu
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(self.device)
@@ -169,20 +163,12 @@ class TD3Agent:
# concat critic parameters to use one optim
critic_parameters = list(self.critic1.parameters()) + list(self.critic2.parameters())
- # optimizer
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
self.critic_optimizer = optim.Adam(critic_parameters, lr=lr_critic)
- # transition to store in memory
self.transition = list()
-
- # total steps count
self.total_steps = 0
-
- # update step for actor
self.update_step = 0
-
- # mode: train / test
self.is_test = False
self.max_score = float('-inf')
@@ -190,8 +176,6 @@ class TD3Agent:
self.start_storing = False
def select_action(self, state: np.ndarray) -> np.ndarray:
- """Select an action from the input state."""
- # if initial random action should be conducted
if self.total_steps < self.initial_random_steps and not self.is_test:
selected_action = self.env.action_space.sample()
else:
@@ -202,7 +186,6 @@ class TD3Agent:
.numpy()
)
- # add noise for exploration during training
if not self.is_test:
noise = self.exploration_noise.sample()
selected_action = np.clip(
@@ -214,7 +197,6 @@ class TD3Agent:
return selected_action
def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
- """Take an action and return the response of the env."""
next_state, reward, done, _ = self.env.step(action)
if not self.is_test:
@@ -224,7 +206,6 @@ class TD3Agent:
return next_state, reward, done
def update_model(self) -> torch.Tensor:
- """Update the model by gradient descent."""
device = self.device # for shortening the following lines
samples = self.memory.sample_batch()
@@ -235,7 +216,6 @@ class TD3Agent:
dones = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
masks = 1 - dones
- # get actions with noise
noise = torch.FloatTensor(self.target_policy_noise.sample()).to(device)
clipped_noise = torch.clamp(
noise, -self.target_policy_noise_clip, self.target_policy_noise_clip
@@ -281,7 +261,6 @@ class TD3Agent:
return actor_loss.data, critic_loss.data
def train(self, num_frames: int, plotting_interval: int = 200):
- """Train the agent."""
self.is_test = False
state = self.env.reset()
@@ -304,7 +283,6 @@ class TD3Agent:
state = next_state
score += reward
- # if episode ends
if done:
state = self.env.reset()
if self.start_storing:
@@ -324,7 +302,6 @@ class TD3Agent:
and self.total_steps > self.initial_random_steps):
self.start_storing = True
- # if training is ready
if len(self.memory) >= self.batch_size and self.start_storing:
actor_loss, critic_loss = self.update_model()
if self.total_steps % self.policy_update_freq == 0:
@@ -335,7 +312,6 @@ class TD3Agent:
actor_losses, critic_losses)
def test(self):
- """Test the agent."""
self.is_test = True
state = self.env.reset()