聚合国内IT技术精华文章,分享IT技术精华,帮助IT从业人士成长

强化学习框架 rlpyt 源码分析:(5) 为model类提供额外参数的Mixin类

2019-12-01 23:25 浏览: 2096116 次 我要评论(0 条) 字号:

转载需注明出处:https://www.codelast.com/

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。

▶▶ Mixin类简介
rlpyt 里面有大量的 *Mixin 类,例如 AtariMixin,MujocoMixin,RecurrentAgentMixin 等,作者并没有为这些名字很怪的class写任何注释,仅从使用的地方来看,很多Mixin类都与agent类有关联。

▶▶ 分析具体实例:AtariMixin
要充分理解Mixin类的设计意图,可以从一个具体的class来分析:AtariMixin。它是 AtariDqnAgent 的其中一个父类:

class AtariDqnAgent(AtariMixin, DqnAgent):
    def __init__(self, ModelCls=AtariDqnModel, **kwargs):
        super().__init__(ModelCls=ModelCls, **kwargs)

其中,另一个父类 DqnAgent 是实现了agent逻辑的类。AtariMixin 里面只实现了一个非常简单的函数,返回了一个字典:

class AtariMixin:
    def make_env_to_model_kwargs(self, env_spaces):
        return dict(image_shape=env_spaces.observation.shape,
                    output_size=env_spaces.action.n)

这个函数是在哪里被调用的?这就有点tricky了:它是在 DqnAgent 的父类 BaseAgent 的 initialize() 函数里被调用的:

self.env_model_kwargs = self.make_env_to_model_kwargs(env_spaces)

文章来源:https://www.codelast.com/
我们来理一下,这个调用链很有意思:
rlpyt mixin class hierarchy
从这幅图可以看到,在agent类 initialize() 的时候,它调用的 make_env_to_model_kwargs() 函数,实际上调用的是 Mixin 类实现的 make_env_to_model_kwargs() 函数。
看上面的继承关系图,如果你产生一种疑问:“Python还能这样做的?” 那么我建议你可以自己去写几个简单的class实验一下——确实可以这样。
然而这个绕了一大圈的逻辑,是不是太麻烦了?
文章来源:https://www.codelast.com/
▶▶ 为什么要插入一个Mixin类
一开始我在想,为什么不直接在 DqnAgent 类中实现其父类 BaseAgent 定义的接口 make_env_to_model_kwargs() 呢?那样不就可以少写一个Mixin类?
为了想明白这个问题,我们来看看 BaseAgent 类在调用了 make_env_to_model_kwargs() 函数后干了什么事情:

self.env_model_kwargs = self.make_env_to_model_kwargs(env_spaces)
self.model = self.ModelCls(**self.env_model_kwargs, **self.model_kwargs)

可见,它用返回的字典(dict) self.env_model_kwargs 来实例化 model 类。
要知道,rlpyt 是一个强化学习的框架,而不是一个专用于Atari游戏的强化学习库,我们可以用它来实现跟游戏毫不相关的强化学习应用。每一种强化学习应用,都有其对应的model类,而model类的参数(通常是跟environment space相关)因应用而异,我们不可能强行规定这些model类的参数必须叫什么名字,而是应该具有普适性:由应用的开发者自己去定义。
以 AtariMixin 为例,它返回的dict里包含两个参数:image_shape 和 output_size,即输入图像的shape以及输出的size,如果我自己的强化学习应用不是游戏应用、完全没有image这种东西呢?
在这个时候,我就需要几个更合适的名字来描述它们。
文章来源:https://www.codelast.com/
所以,看似半路杀出来的无厘头 Mixin 类,其实是为了 rlpyt 框架的良好扩展性而设计的一个类,它用于向model类提供实例化所需的特殊参数。
不过,在 rlpyt 中,并不是所有 Mixin 类都是为model类服务的,例如 EpsilonGreedy 类的父类 DiscreteMixin,就和model类无关。但这个类它也带了“为子类提供一些额外的功能,但放在子类中实现又不太好”的思想。



网友评论已有0条评论, 我也要评论

发表评论

*

* (保密)

Ctrl+Enter 快捷回复