from collections import deque
import math
import random

import rospy
import torch
import torch.nn.functional as F

from training_node import TrainingNode, device
from parameters_q_learning import *

# Index tensor [0, 1, ..., BATCH_SIZE - 1]; used to select the Q-value of the
# action taken in each sample of a minibatch.
BATCH_INDICES = torch.arange(0, BATCH_SIZE, device=device, dtype=torch.long)


class QLearningTrainingNode(TrainingNode):
    ''' ROS node to train the Q-Learning model
    '''

    def __init__(self):
        # The remaining base-class constructor arguments (policy network and
        # training hyperparameters) are elided in this excerpt.
        TrainingNode.__init__(self)

        # Replay memory of (state, action, reward, next_state, is_terminal) tuples.
        self.memory = deque(maxlen=MEMORY_SIZE)

    def replay(self):
        # Performs one optimization step on a random minibatch sampled from
        # the replay memory (method name assumed in this reconstruction).
        if len(self.memory) < 500 or len(self.memory) < BATCH_SIZE:
            # Not enough experience collected yet.
            return

        rospy.loginfo("Model optimization started.")

        transitions = random.sample(self.memory, BATCH_SIZE)
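
        # Convert the sampled transitions into batched tensors so the whole
        # minibatch can be evaluated in a single forward pass.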
        states, actions, rewards, next_states, is_terminal = tuple(zip(*transitions))

        states = torch.stack(states)
        actions = torch.tensor(actions, device=device, dtype=torch.long)
        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
        next_states = torch.stack(next_states)
        is_terminal = torch.tensor(is_terminal, device=device, dtype=torch.uint8)
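
        # Q-learning targets: reward + DISCOUNT_FACTOR * max_a' Q(next_state, a'),
        # with the bootstrap term dropped for terminal transitions.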
        next_state_values = self.policy.forward(next_states).max(1)[0].detach()
        q_updates = rewards + next_state_values * DISCOUNT_FACTOR
        q_updates[is_terminal] = rewards[is_terminal]
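
        # Gradient step on the Huber loss between the predicted Q-values of
        # the taken actions and the targets computed above.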
        self.optimizer.zero_grad()
        net_output = self.policy.forward(states)
        loss = F.smooth_l1_loss(net_output[BATCH_INDICES, actions], q_updates)
        loss.backward()
        # Clamp gradients to stabilize training, then apply the update.
        for param in self.policy.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
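
    # Exploration rate for epsilon-greedy action selection; decays from
    # EPS_START towards EPS_END as training progresses.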
    def get_epsilon_greedy_threshold(self):
        # The decay term below (exponential decay over training steps) is an
        # assumption of this reconstruction; the excerpt only shows the first factor.
        return EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1.0 * self.total_step_count / EPS_DECAY)

    def select_action(self, state):
        # Epsilon-greedy exploration: the guard is reconstructed, only the
        # random branch is shown in the excerpt.
        if random.random() < self.get_epsilon_greedy_threshold():
            return random.randrange(ACTION_COUNT)

        # Greedy action: pick the action with the highest predicted Q-value.
        with torch.no_grad():
            output = self.policy(state)
            # Q-values formatted for debug output (attribute name assumed).
            self.net_output_debug_string = ", ".join(
                ["{0:.1f}".format(v).rjust(5)
                 for v in output.tolist()])
            return output.max(0)[1].item()

    def get_reward(self):
        # Reward based on the distance to the track center line; the method
        # name and the rest of the computation are not shown in this excerpt.
        distance = abs(track_position.distance_to_center)

    def get_episode_summary(self):
        # Further summary fields appended in the full source are elided here.
        return TrainingNode.get_episode_summary(self) + ' ' \
            + ("memory: {0:d} / {1:d}, ".format(len(self.memory), MEMORY_SIZE)
               if len(self.memory) < MEMORY_SIZE else "")

    def on_complete_step(self, state, action, reward, next_state):
        self.memory.append(
            (state, action, reward, next_state, self.is_terminal_step))
        self.replay()  # assumed: optimize after every completed step


rospy.init_node('q_learning_training', anonymous=True)
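# Assumed: construct the node and keep it alive; the excerpt only shows the
# init_node call.
node = QLearningTrainingNode()
rospy.spin()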