项目作者 | thu-ml 转自 | 机器之心
参与 | 思、肖清
训练模型的极速,与 1500 行源代码的精简,清华大学新开源强化学习平台「天授」。值得注意的是,该项目的两位主要作者目前都是清华大学的本科生。
是否你也有这样的感觉,成熟 ML 工具的源码很难懂,各种继承与处理关系需要花很多时间一点点理清。在清华大学开源的「天授」项目中,它以极简的代码实现了很多极速的强化学习算法。重点是,天授框架的源码很容易懂,不会有太复杂的逻辑关系。
项目地址:https://github.com/thu-ml/tianshou
天授(Tianshou)是纯 基于 PyTorch 代码的强化学习框架,与目前现有基于 TensorFlow 的强化学习库不同,天授的类继承并不复杂,API 也不是很繁琐。最重要的是,天授的训练速度非常快,我们试用 Pythonic 的 API 就能快速构建与训练 RL 智能体。
目前天授支持的 RL 算法有如下几种:
-
Policy Gradient (PG)
-
Deep Q-Network (DQN)
-
Double DQN (DDQN) with n-step returns
-
Advantage Actor-Critic (A2C)
-
Deep Deterministic Policy Gradient (DDPG)
-
Proximal Policy Optimization (PPO)
-
Twin Delayed DDPG (TD3)
-
Soft Actor-Critic (SAC)
另外,对于以上代码天授还支持并行收集样本,并且所有算法均统一改写为基于 replay-buffer 的形式。
速度与轻量:「天授」的灵魂
天授旨在提供一个高速、轻量化的 RL 开源平台。下图为天授与各大知名 RL 开源平台在 CartPole 与 Pendulum 环境下的速度对比。所有代码均在配置为 i7-8750H + GTX1060 的同一台笔记本电脑上进行测试。值得注意的是,天授实现的 VPG(vanilla policy gradient)算法在 CartPole-v0 任务中,训练用时仅为 3 秒。
<section style="box-sizing: border-box;padding: 0.5em;font-size: 14px;color: rgb(169, 183, 198);line-height: 18px;border-radius: 0px;background: rgb(40, 43, 46);display: block;font-family: Consolas, Inconsolata, Courier, monospace;overflow: auto;letter-spacing: 0px;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;"><span style="box-sizing: border-box;font-size: inherit;color: rgb(165, 218, 45);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">result</span> = collector.collect(n_step=n)</section>
<section style="box-sizing: border-box;padding: 0.5em;font-size: 14px;color: rgb(169, 183, 198);line-height: 18px;border-radius: 0px;background: rgb(40, 43, 46);display: block;font-family: Consolas, Inconsolata, Courier, monospace;overflow: auto;letter-spacing: 0px;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;"><span style="box-sizing: border-box;font-size: inherit;color: rgb(165, 218, 45);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">result</span> = policy.learn(collector.sample(batch_size))</section>
-
__init__:初始化策略
-
process_fn:从 replay buffer 中处理数据
-
__call__:给定环境观察结果计算对应行动
-
learn:给定批量数据学习策略
<section style="box-sizing: border-box;padding: 0.5em;font-size: 14px;color: rgb(169, 183, 198);line-height: 18px;border-radius: 0px;background: rgb(40, 43, 46);display: block;font-family: Consolas, Inconsolata, Courier, monospace;overflow: auto;letter-spacing: 0px;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;">!git clone https:<span style="box-sizing: border-box;font-size: inherit;color: rgb(128, 128, 128);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">//github.com/thu-ml/tianshou</span><br style="box-sizing: border-box;font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;" />!pip3 install tianshou<br style="box-sizing: border-box;font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;" /><span style="box-sizing: border-box;font-size: inherit;color: rgb(248, 35, 117);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">import</span> os<br style="box-sizing: border-box;font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;" />os.chdir(<span style="box-sizing: border-box;font-size: inherit;color: rgb(238, 220, 112);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">'tianshou'</span>)</section>
<section style="box-sizing: border-box;padding: 0.5em;font-size: 14px;color: rgb(169, 183, 198);line-height: 18px;border-radius: 0px;background: rgb(40, 43, 46);display: block;font-family: Consolas, Inconsolata, Courier, monospace;overflow: auto;letter-spacing: 0px;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;">!python <span style="box-sizing: border-box;font-size: inherit;color: rgb(248, 35, 117);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">test</span>/discrete/test_pg.py<br style="box-sizing: border-box;font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;" /></section>
<section style="box-sizing: border-box;padding: 0.5em;font-size: 14px;color: rgb(169, 183, 198);line-height: 18px;border-radius: 0px;background: rgb(40, 43, 46);display: block;font-family: Consolas, Inconsolata, Courier, monospace;overflow: auto;letter-spacing: 0px;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;">!python <span style="box-sizing: border-box;font-size: inherit;color: rgb(248, 35, 117);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">test</span>/discrete/test_ppo.py<br style="box-sizing: border-box;font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;" /></section>
<section style="box-sizing: border-box;padding: 0.5em;font-size: 14px;color: rgb(169, 183, 198);line-height: 18px;border-radius: 0px;background: rgb(40, 43, 46);display: block;font-family: Consolas, Inconsolata, Courier, monospace;overflow: auto;letter-spacing: 0px;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;">!python <span style="box-sizing: border-box;font-size: inherit;color: rgb(248, 35, 117);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">test</span>/discrete/test_a2c.py</section>
<section style="box-sizing: border-box;padding: 0.5em;font-size: 14px;color: rgb(169, 183, 198);line-height: 18px;border-radius: 0px;background: rgb(40, 43, 46);display: block;font-family: Consolas, Inconsolata, Courier, monospace;overflow: auto;letter-spacing: 0px;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;">!python <span style="box-sizing: border-box;font-size: inherit;color: rgb(248, 35, 117);line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">test</span>/discrete/test_dqn.py</section>
-
Prioritized replay buffer
-
RNN support
-
Imitation Learning
-
Multi-agent
-
Distributed training
<pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><p style="max-width: 100%;min-height: 1em;letter-spacing: 0.544px;white-space: normal;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;widows: 1;line-height: 1.75em;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>完<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></p><section style="max-width: 100%;letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;widows: 1;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><p style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;font-size: 16px;font-family: 微软雅黑;caret-color: red;box-sizing: border-box !important;overflow-wrap: break-word !important;">为您推荐</span></strong></span></p><p style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">一个AI PhD的毕业随感<br style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;" /></p><p style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(87, 107, 149);font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;">MIT最新深度学习入门课,安排起来!</span></p><p style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(87, 107, 149);font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;">有了这个神器,轻松用 Python 写个 App</span></p><p style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(87, 107, 149);font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;">「最全」实至名归,NumPy 官方早有中文教程</span></p><section style="margin: 5px 16px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">北大团队成功实现精准删除特定记忆,马斯克脑机接口有望今年人体测试</span><br /></section></section></section></section></section></section></section></section></section>
本篇文章来源于: 深度学习这件小事
本文为原创文章,版权归知行编程网所有,欢迎分享本文,转载请保留出处!
内容反馈