~enan/ros-rl

eadc4410c08040ab094ccfe51b7a06123aee531f — Enan Ajmain 2 years ago 2b1043b
td3,ddpg: let algo choose action dimension in env
2 files changed, 32 insertions(+), 20 deletions(-)

M src/envs.py
M src/td3.py
M src/envs.py => src/envs.py +31 -19
@@ -27,24 +27,27 @@ class GazeboAutoVehicleEnv(gym.Env):
         # "render_fps": 30,
     }

-    def __init__(self, H, W):
+    def __init__(self, H, W, action_dim=1, use_pause=False):
         self.IMAGE_TOPIC = "/vehicle_camera/image_raw"
         self.CMDVEL_TOPIC = "vehicle/cmd_vel"
         self.GZRESET_TOPIC = "/gazebo/reset_world"
-        # self.GZPAUSE_TOPIC = '/gazebo/pause_physics'
-        # self.GZUNPAUSE_TOPIC = '/gazebo/unpause_physics'
+        self.GZPAUSE_TOPIC = '/gazebo/pause_physics'
+        self.GZUNPAUSE_TOPIC = '/gazebo/unpause_physics'
         self.MODEL_TOPIC = '/gazebo/model_states'

         self.H,self.W = H,W
         self.finished = False
+        self.use_pause = use_pause
+        self.action_dim = action_dim

-        # self.action_space = gym.spaces.Box(np.array([-1]),
-        #                                    np.array([1]),
-        #                                    dtype=np.float32)
-
-        self.action_space = gym.spaces.Box(np.array([-1, 0.5]),
-                                           np.array([1, 1.0]),
-                                           dtype=np.float32)
+        if self.action_dim == 1:
+            self.action_space = gym.spaces.Box(np.array([-1]),
+                                               np.array([1]),
+                                               dtype=np.float32)
+        else:
+            self.action_space = gym.spaces.Box(np.array([-1, 1.0]),
+                                               np.array([1, 1.5]),
+                                               dtype=np.float32)

         self.observation_space = gym.spaces.Box(np.array([-1]),
                                                 np.array([1]),


@@ -57,8 +60,8 @@ class GazeboAutoVehicleEnv(gym.Env):

         rospy.wait_for_service(self.GZRESET_TOPIC)
         self.reset_proxy = rospy.ServiceProxy(self.GZRESET_TOPIC, Empty)
-        # self.pause = rospy.ServiceProxy(self.GZPAUSE_TOPIC, Empty)
-        # self.unpause = rospy.ServiceProxy(self.GZUNPAUSE_TOPIC, Empty)
+        self.pause = rospy.ServiceProxy(self.GZPAUSE_TOPIC, Empty)
+        self.unpause = rospy.ServiceProxy(self.GZUNPAUSE_TOPIC, Empty)

         self.state = None



@@ -73,15 +76,18 @@ class GazeboAutoVehicleEnv(gym.Env):
             self.finished = True

     def step(self, action):
-        self.speed = 0.5
         self.turn = action[0].item()
-        self.speed = action[1].item()
+        if self.action_dim == 1:
+            self.speed = 0.5
+        else:
+            self.speed = action[1].item()

         twist = Twist()
         twist.linear.x = self.speed
         twist.angular.z = self.turn

-        # self.unpause()
+        if self.use_pause:
+            self.unpause()
         self.vel_pub.publish(twist)

         state = self.state


@@ -89,7 +95,10 @@ class GazeboAutoVehicleEnv(gym.Env):

         if state != None:
             done = False
-            reward = 1 - abs(state) + self.speed
+            if self.action_dim == 1:
+                reward = 1 - abs(state)
+            else:
+                reward = (1 - abs(state) + self.speed) / (1+1)
         else:
             done = True
             reward = -1


@@ -108,7 +117,8 @@ class GazeboAutoVehicleEnv(gym.Env):
         print("REWARD:", reward)
         print("-----------------------------")

-        # self.pause()
+        if self.use_pause:
+            self.pause()


         return obs, reward, done, {}


@@ -118,9 +128,11 @@ class GazeboAutoVehicleEnv(gym.Env):

         self.reset_proxy()
         self.finished = False
-        # self.unpause()
+        if self.use_pause:
+            self.unpause()
         time.sleep(0.5)
-        # self.pause()
+        if self.use_pause:
+            self.pause()

         state = None
         while state is None:

M src/td3.py => src/td3.py +1 -1
@@ -385,7 +385,7 @@ class ActionNormalizer(gym.ActionWrapper):


 signal.signal(signal.SIGINT, interrupt_handler)
-env = GazeboAutoVehicleEnv(600, 800)
+env = GazeboAutoVehicleEnv(600, 800, 2, 1)
 env = ActionNormalizer(env)