本文旨在探究将 PyTorch Lightning 应用于激动人心的强化学习(RL)领域。在这里,我们将使用经典的倒立摆 gym 环境来构建一个标准的深度 Q 网络(DQN)模型,以说明如何开始使用 Lightning 来构建 RL 模型。

    在本文中,我们将讨论:

    · 什么是 lighting 以及为什么要将它应用于 RL

    · 标准 DQN 模型简介

    · 使用 Lightning 构建 DQN 的步骤

    · 结果和结论

    DQN:强化学习原理及实现 - 图1

    Lightning 是一个最近发布的 Pythorch 库,它可以清晰地抽象和自动化 ML 模型所附带的所有日常样板代码,允许您专注于实际的 ML 部分(这些也往往是最有趣的部分)。

    除了自动化样板代码外,Lightning 还可以作为一种样式指南,用于构建干净且可复制的 ML 系统。

    这非常吸引人,原因如下:

    1. 通过抽象出样板工程代码,可以更容易地识别和理解 ML 代码。
    2. Lightning 的统一结构使得在现有项目的基础上进行构建和理解变得非常容易。
    3. Lightning 自动化的代码是用经过全面测试、定期维护并遵循 ML 最佳实践的高质量代码构建的。

    DQN:强化学习原理及实现 - 图2

    在我们进入代码之前,让我们快速回顾一下 DQN 的功能。DQN 通过学习在特定状态下执行每个操作的值来学习给定环境的最佳策略。这些值称为 Q 值。

    最初,智能体对其环境的理解非常差,因为它没有太多的经验。因此,它的 Q 值将非常不准确。然而,随着时间的推移,当智能体探索其环境时,它会学习到更精确的 Q 值,然后可以做出正确的决策。这允许它进一步改进,直到它最终收敛到一个最优策略(理想情况下)。

    我们感兴趣的大多数环境,如现代电子游戏和模拟环境,都过于复杂和庞大,无法存储每个状态 / 动作对的值。这就是为什么我们使用深度神经网络来近似这些值。

    智能体的一般生命周期如下所述:

    1. 智能体获取环境的当前状态并将其通过网络进行运算。然后,网络输出给定状态的每个动作的 Q 值。
    2. 接下来,我们决定是使用由网络给出智能体所认为的最优操作,还是采取随机操作,以便进一步探索。
    3. 这个动作被传递到环境中并得到反馈,告诉智能体它处于的下一个状态是什么,在上一个状态中执行上一个动作所得到的奖励,以及该步骤中的事件是否完成。
    4. 我们以元组(状态, 行为, 奖励, 下一状态, 已经完成的事件)的形式获取在最后一步中获得的经验,并将其存储在智能体内存中。
    5. 最后,我们从智能体内存中抽取一小批重复经验,并使用这些过去的经验计算智能体的损失。

    这是 DQN 功能的一个高度概述。

    DQN:强化学习原理及实现 - 图3

    启蒙时代是一场支配思想世界的智力和哲学运动

    让我们看看构成我们的 DQN 的组成部分

    模型:用来逼近 Q 值的神经网络

    重播缓冲区:这是我们智能体的内存,用于存储以前的经验 智能体:智能体本身就是与环境和重播缓冲区交互的东西 Lightning 模块:处理智能体的所有训练

    对于这个例子,我们可以使用一个非常简单的多层感知器(MLP)。这意味着我们没有使用任何花哨的东西,像卷积层或递归层,只是正常的线性层。这样做的原因是由于卡倒立摆环境的简单性,任何比这更复杂的东西都是过度复杂的。

    1. class DQN(nn.Module):
    2. """
    3. Simple MLP network
    4. Args:
    5. obs_size: observation/state size of the environment
    6. n_actions: number of discrete actions available in the environment
    7. hidden_size: size of hidden layers
    8. """
    9. def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
    10. super(DQN, self).__init__()
    11. self.net = nn.Sequential(
    12. nn.Linear(obs_size, hidden_size),
    13. nn.ReLU(),
    14. nn.Linear(hidden_size, n_actions)
    15. )
    16. def forward(self, x):
    17. return self.net(x.float()

    重播缓冲区的构建相当直接,我们只需要某种类型的数据结构来存储元组。我们需要能够对这些元组进行采样并添加新的元组。本例中的缓冲区基于 Lapins 重播缓冲区,因为它是迄今为止我发现的最简洁并且最快的实现。代码如下

    1. # Named tuple for storing experience steps gathered in training
    2. Experience = collections.namedtuple(
    3. 'Experience', field_names=['state', 'action', 'reward',
    4. 'done', 'new_state'])
    5. class ReplayBuffer:
    6. """
    7. Replay Buffer for storing past experiences allowing the agent to learn from them
    8. Args:
    9. capacity: size of the buffer
    10. """
    11. def __init__(self, capacity: int) -> None:
    12. self.buffer = collections.deque(maxlen=capacity)
    13. def __len__(self) -> None:
    14. return len(self.buffer)
    15. def append(self, experience: Experience) -> None:
    16. """
    17. Add experience to the buffer
    18. Args:
    19. experience: tuple (state, action, reward, done, new_state)
    20. """
    21. self.buffer.append(experience)
    22. def sample(self, batch_size: int) -> Tuple:
    23. indices = np.random.choice(len(self.buffer), batch_size, replace=False)
    24. states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
    25. return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
    26. np.array(dones, dtype=np.bool), np.array(next_states))

    但我们还没有完成。如果您在知道它的结构是基于创建数据加载器的思想创建的,然后使用它将小批量的经验传递给每个训练步骤这些原理之前使用过 Lightning;那么对于大多数 ML 系统(如监督模型),这一切如何生效的是很清楚的。但是当我们在生成数据集时,它又是如何生效的呢?

    我们需要创建自己的可迭代数据集,它使用不断更新的重播缓冲区来采样以前的经验。然后,我们有一小批经验被传递到训练步骤中用于计算我们的损失,就像其他任何模型一样。除了包含输入和标签之外,我们的小批量包含(状态, 行为, 奖励, 下一状态, 已经完成的事件)

    1. class RLDataset(IterableDataset):
    2. """
    3. Iterable Dataset containing the ReplayBuffer
    4. which will be updated with new experiences during training
    5. Args:
    6. buffer: replay buffer
    7. sample_size: number of experiences to sample at a time
    8. """
    9. def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
    10. self.buffer = buffer
    11. self.sample_size = sample_size
    12. def __iter__(self) -> Tuple:
    13. states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
    14. for i in range(len(dones)):
    15. yield states[i], actions[i], rewards[i], dones[i], new_states[i]

    您可以看到,在创建数据集时,我们传入重播缓冲区,然后可以从中采样,以允许数据加载器将批处理传递给 Lightning 模块。

    智能体类将处理与环境的交互。智能体类主要有三种方法:

    get_action:使用传递的ε值,智能体决定是使用随机操作,还是从网络输出中执行 Q 值最高的操作。

    play_step:在这里,智能体通过从 get_action 中选择的操作在环境中执行一个步骤。从环境中获得反馈后,经验将存储在重播缓冲区中。如果环境已完成该步骤,则环境将重置。最后,返回当前的奖励和完成标志。

    reset:重置环境并更新存储在代理中的当前状态。

    1. class Agent:
    2. """
    3. Base Agent class handeling the interaction with the environment
    4. Args:
    5. env: training environment
    6. replay_buffer: replay buffer storing experiences
    7. """
    8. def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
    9. self.env = env
    10. self.replay_buffer = replay_buffer
    11. self.reset()
    12. self.state = self.env.reset()
    13. def reset(self) -> None:
    14. """ Resents the environment and updates the state"""
    15. self.state = self.env.reset()
    16. def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
    17. """
    18. Using the given network, decide what action to carry out
    19. using an epsilon-greedy policy
    20. Args:
    21. net: DQN network
    22. epsilon: value to determine likelihood of taking a random action
    23. device: current device
    24. Returns:
    25. action
    26. """
    27. if np.random.random() < epsilon:
    28. action = self.env.action_space.sample()
    29. else:
    30. state = torch.tensor([self.state])
    31. if device not in ['cpu']:
    32. state = state.cuda(device)
    33. q_values = net(state)
    34. _, action = torch.max(q_values, dim=1)
    35. action = int(action.item())
    36. return action
    37. @torch.no_grad()
    38. def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
    39. """
    40. Carries out a single interaction step between the agent and the environment
    41. Args:
    42. net: DQN network
    43. epsilon: value to determine likelihood of taking a random action
    44. device: current device
    45. Returns:
    46. reward, done
    47. """
    48. action = self.get_action(net, epsilon, device)
    49. # do step in the environment
    50. new_state, reward, done, _ = self.env.step(action)
    51. exp = Experience(self.state, action, reward, done, new_state)
    52. self.replay_buffer.append(exp)
    53. self.state = new_state
    54. if done:
    55. self.reset()
    56. return reward, done

    现在我们已经为 DQN 建立了核心类,我们可以开始考虑训练 DQN 智能体。这就是 lighting 要介入的地方。我们将通过构建一个 lighting 模块,以一种干净和结构化的方式布置我们所有的训练逻辑。

    Lightning 提供了很多接口和可重写的函数,以获得最大的灵活性,但是我们必须实现 4 个关键方法才能使项目运行。就是下面的:

    1. forward()
    2. configure_optimizers
    3. train_dataloader
    4. train_step

    有了这 4 种方法的填充,我们可以使我们遇到的任何 ML 模型都得到很好的训练。任何需要超过这些方法的东西都可以很好地与 Lightning 中剩余的接口和回调配合。有关这些可用接口的完整列表,请查看 Lightning 文档。现在,让我们看看我们的轻量化模型。

    首先,我们需要初始化我们的环境、网络、智能体和重播缓冲区。我们还调用 populate 函数,它将以随机方式填充重播缓冲区(populate 函数在下面的完整代码示例中显示)。

    DQN:强化学习原理及实现 - 图4

    我们在这里所做的就是封装我们的 DQN 网络的前向传递函数。

    DQN:强化学习原理及实现 - 图5

    在开始训练智能体之前,我们需要定义损失函数。这里使用的损失函数是基于 Lapan 的实现,可以在这里找到。

    这是一个简单的均方误差(MSE)损失,将我们的 DQN 网络的当前状态动作值与下一个状态的预期状态动作值进行比较。在 RL 中我们没有完美的标签可以学习;相反,智能体从它期望的下一个状态的值的目标值中学习。

    然而,通过使用同一个网络来预测当前状态的值和下一个状态的值,结果会成为一个不稳定的运动目标。为了对抗这种情况,我们使用目标网络。此网络是主网络的副本,并定期与主网络同步。这提供了一个临时固定的目标,允许代理计算更稳定的损失函数。

    DQN:强化学习原理及实现 - 图6

    如您所见,状态操作值使用主网络计算,而下一个状态值(相当于我们的目标 / 标签)使用目标网络。

    这是另外一个简单的补充,只是告诉 lighting 什么优化器将在反向传递期间使用。我们将使用标准的 Adam 优化器。

    DQN:强化学习原理及实现 - 图7

    接下来,我们需要向 Lightning 提供我们的训练数据加载器。如您所料,我们初始化了先前创建的 IterableDataset。然后像往常一样把这个传递给数据加载器。Lightning 将在培训期间处理提供的批次,并将这些批次转换为 Pythorch 张量,并将它们移动到正确的设备。

    DQN:强化学习原理及实现 - 图8

    最后我们有了训练的步骤。在这里,我们输入了每个训练迭代要执行的所有逻辑。

    在每次训练迭代过程中,我们希望智能体通过调用前面定义的 agent.play_step() 并传入当前设备和ε值,在环境中执行一步。这将返回该步骤的奖励,以及本次迭代否在该步骤中完成。我们将步骤奖励添加到整个事件中,以便跟踪智能体在该事件中的成功程度。

    接下来,我们使用 lighting 提供的当前小批量,计算我们的损失。

    如果我们已经到了本次迭代的结尾,用 done 标志表示,我们将用 session reward 更新当前的 total_reward 变量。

    在步骤的最后,我们检查是否是同步主网络和目标网络的时间。通常在只更新一部分权重的情况下使用软更新,但对于这个简单的示例来说,完全更新就足够了。

    最后,我们需要返回一个 Dict,其中包含 Lightning 将用于反向传播的损耗,一个 Dict 包含我们要记录的值(注意:这些值必须是张量),另一个 Dict 包含我们要在进度条上显示的任何值。

    DQN:强化学习原理及实现 - 图9

    就这样,我们现在有了运行 DQN 智能体所需的一切。

    现在要做的就是初始化并适应我们的 lighting 模型。在我们的主 python 文件中,我们将设置种子,并提供一个 arg 解析器,其中包含我们要传递给模型的任何必要的超参数。

    DQN:强化学习原理及实现 - 图10

    然后在我们的主方法中,我们用指定的参数初始化 dqnlighting 模型。接下来是 Lightning 训练器的设置。

    在这里,我们设置教练过程使用 GPU。如果您没有访问 GPU 的权限,请从培训器中删除 “GPU” 和 “distributed_backend” 参数。这种模式训练非常快,即使是使用 CPU,所以为了在运行过程中观察 Lightning,我们将关闭早停机制。

    最后,因为我们使用的是可迭代数据集,所以需要指定 val_check_interval。通常,此间隔是根据数据集的长度自动设置的。然而,可迭代数据集没有一个长度函数。因此,我们需要自己设置这个值,即使我们没有执行验证步骤。

    DQN:强化学习原理及实现 - 图11

    最后一步是调用我们的模型上的 trainer.fit(),并观看它的训练。下面你可以看到完整的 lighting 代码:

    1. class DQNLightning(pl.LightningModule):
    2. """ Basic DQN Model """
    3. def __init__(self, hparams: argparse.Namespace) -> None:
    4. super().__init__()
    5. self.hparams = hparams
    6. self.env = gym.make(self.hparams.env)
    7. obs_size = self.env.observation_space.shape[0]
    8. n_actions = self.env.action_space.n
    9. self.net = DQN(obs_size, n_actions)
    10. self.target_net = DQN(obs_size, n_actions)
    11. self.buffer = ReplayBuffer(self.hparams.replay_size)
    12. self.agent = Agent(self.env, self.buffer)
    13. self.total_reward = 0
    14. self.episode_reward = 0
    15. self.populate(self.hparams.warm_start_steps)
    16. def populate(self, steps: int = 1000) -> None:
    17. """
    18. Carries out several random steps through the environment to initially fill
    19. up the replay buffer with experiences
    20. Args:
    21. steps: number of random steps to populate the buffer with
    22. """
    23. for i in range(steps):
    24. self.agent.play_step(self.net, epsilon=1.0)
    25. def forward(self, x: torch.Tensor) -> torch.Tensor:
    26. """
    27. Passes in a state x through the network and gets the q_values of each action as an output
    28. Args:
    29. x: environment state
    30. Returns:
    31. q values
    32. """
    33. output = self.net(x)
    34. return output
    35. def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    36. """
    37. Calculates the mse loss using a mini batch from the replay buffer
    38. Args:
    39. batch: current mini batch of replay data
    40. Returns:
    41. loss
    42. """
    43. states, actions, rewards, dones, next_states = batch
    44. state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
    45. with torch.no_grad():
    46. next_state_values = self.target_net(next_states).max(1)[0]
    47. next_state_values[dones] = 0.0
    48. next_state_values = next_state_values.detach()
    49. expected_state_action_values = next_state_values * self.hparams.gamma + rewards
    50. return nn.MSELoss()(state_action_values, expected_state_action_values)
    51. def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
    52. """
    53. Carries out a single step through the environment to update the replay buffer.
    54. Then calculates loss based on the minibatch recieved
    55. Args:
    56. batch: current mini batch of replay data
    57. nb_batch: batch number
    58. Returns:
    59. Training loss and log metrics
    60. """
    61. device = self.get_device(batch)
    62. epsilon = max(self.hparams.eps_end, self.hparams.eps_start -
    63. self.global_step + 1 / self.hparams.eps_last_frame)
    64. # step through environment with agent
    65. reward, done = self.agent.play_step(self.net, epsilon, device)
    66. self.episode_reward += reward
    67. # calculates training loss
    68. loss = self.dqn_mse_loss(batch)
    69. if self.trainer.use_dp or self.trainer.use_ddp2:
    70. loss = loss.unsqueeze(0)
    71. if done:
    72. self.total_reward = self.episode_reward
    73. self.episode_reward = 0
    74. # Soft update of target network
    75. if self.global_step % self.hparams.sync_rate == 0:
    76. self.target_net.load_state_dict(self.net.state_dict())
    77. log = {'total_reward': torch.tensor(self.total_reward).to(device),
    78. 'reward': torch.tensor(reward).to(device),
    79. 'steps': torch.tensor(self.global_step).to(device)}
    80. return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})
    81. def configure_optimizers(self) -> List[Optimizer]:
    82. """ Initialize Adam optimizer"""
    83. optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr)
    84. return [optimizer]
    85. def train_dataloader(self) -> DataLoader:
    86. """Initialize the Replay Buffer dataset used for retrieving experiences"""
    87. dataset = RLDataset(self.buffer, self.hparams.episode_length)
    88. dataloader = DataLoader(dataset=dataset,
    89. batch_size=self.hparams.batch_size,
    90. )
    91. return dataloader
    92. def get_device(self, batch) -> str:
    93. """Retrieve device currently being used by minibatch"""
    94. return batch[0].device.index if self.on_gpu else 'cpu'
    95. def main(hparams) -> None:
    96. model = DQNLightning(hparams)
    97. trainer = pl.Trainer(
    98. gpus=1,
    99. distributed_backend='dp',
    100. max_epochs=10000,
    101. early_stop_callback=False,
    102. val_check_interval=100
    103. )
    104. trainer.fit(model)
    105. if __name__ == '__main__':
    106. torch.manual_seed(0)
    107. np.random.seed(0)
    108. parser = argparse.ArgumentParser()
    109. parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    110. parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    111. parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
    112. parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    113. parser.add_argument("--sync_rate", type=int, default=10,
    114. help="how many frames do we update the target network")
    115. parser.add_argument("--replay_size", type=int, default=1000,
    116. help="capacity of the replay buffer")
    117. parser.add_argument("--warm_start_size", type=int, default=1000,
    118. help="how many samples do we use to fill our buffer at the start of training")
    119. parser.add_argument("--eps_last_frame", type=int, default=1000,
    120. help="what frame should epsilon stop decaying")
    121. parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
    122. parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
    123. parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
    124. parser.add_argument("--max_episode_reward", type=int, default=200,
    125. help="max episode reward in the environment")
    126. parser.add_argument("--warm_start_steps", type=int, default=1000,
    127. help="max episode reward in the environment")
    128. args = parser.parse_args()
    129. main(args)

    大约 1200 代后,您将看到智能体的总奖励达到最大得分 200。为了看到正在绘制的奖励指标,调用 tensorboards。

    tensorboard —logdir lightning_logs

    DQN:强化学习原理及实现 - 图12

    在左边的图中你可以看到每一步的奖励。由于环境的性质,这将始终是 1,因为智能体每一步都会得到 + 1 的奖励,极点从没有下降(这就是全部奖励)。在右边的途中我们可以看到每一步的总奖励。智能体很快就达到了最高奖励,然后在好的状态和不好的状态之间波动。

    现在您已经看到了在强化学习项目中利用 PyTorch Lightning 的力量是多么简单和实用。

    这是一个非常简单的例子,只是为了说明 lighting 在 RL 中的使用,所以这里有很多改进的空间。如果您想将此代码作为模板,并尝试实现自己的代理,下面是一些我会尝试的事情。

    降低学习率或许更好。通过在 configure_optimizer 方法中初始化学习率调度程序来使用它。

    提高目标网络的同步速率或使用软更新而不是完全更新

    在更多步骤的过程中使用更渐进的ε衰减。

    通过在训练器中设置 max_epochs 来增加训练的代数。

    除了跟踪 tensorboard 日志中的总奖励,还跟踪平均总奖励。

    使用 test/val Lightning hook 添加测试和验证步骤

    最后,尝试一些更复杂的模型和环境

    我希望这篇文章是有帮助的,将有助于启动您使用 lighting 启动自己的项目。快乐编码!

    作者:Donal Byrne

    Deephub 翻译组:tensor-zhang
    https://www.toutiao.com/i6812769570873410055/