聚合国内IT技术精华文章,分享IT技术精华,帮助IT从业人士成长

强化学习框架 rlpyt 源码分析:(6) 模型指标什么时候从 nan 变成有意义的值

2019-12-08 23:25 浏览: 1938146 次 我要评论(0 条) 字号:

转载需注明出处:https://www.codelast.com/

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。

▶▶ 观察训练日志引出的问题
以 example_1 为例,在训练的过程中,程序会不断打印出类似于下面的日志(部分内容):

2019-11-08 20:38:42.067188  | StepsInEval              3796
2019-11-08 20:38:42.067216  | TrajsInEval                 5
2019-11-08 20:38:42.067240  | CumEvalTime                23.1265
2019-11-08 20:38:42.067276  | CumTrainTime                2.64641
2019-11-08 20:38:42.067297  | Iteration                 249
2019-11-08 20:38:42.067315  | CumTime (s)                25.7729
2019-11-08 20:38:42.067333  | CumSteps                 1000
2019-11-08 20:38:42.067350  | CumCompletedTrajs           1
2019-11-08 20:38:42.067368  | CumUpdates                  0
2019-11-08 20:38:42.067385  | StepsPerSecond            386.079
2019-11-08 20:38:42.067402  | UpdatesPerSecond            0
2019-11-08 20:38:42.067419  | ReplayRatio                 0
2019-11-08 20:38:42.067436  | CumReplayRatio              0
2019-11-08 20:38:42.067453  | LengthAverage             759.2
2019-11-08 20:38:42.067480  | LengthStd                   1.16619
2019-11-08 20:38:42.067499  | LengthMedian              759
2019-11-08 20:38:42.067516  | LengthMin                 758
2019-11-08 20:38:42.067533  | LengthMax                 761
2019-11-08 20:38:42.067550  | ReturnAverage             -21
2019-11-08 20:38:42.067567  | ReturnStd                   0
2019-11-08 20:38:42.067584  | ReturnMedian              -21
2019-11-08 20:38:42.067601  | ReturnMin                 -21
2019-11-08 20:38:42.067618  | ReturnMax                 -21
2019-11-08 20:38:42.067635  | NonzeroRewardsAverage      21
2019-11-08 20:38:42.067652  | NonzeroRewardsStd           0
2019-11-08 20:38:42.067669  | NonzeroRewardsMedian       21
2019-11-08 20:38:42.067686  | NonzeroRewardsMin          21
2019-11-08 20:38:42.067703  | NonzeroRewardsMax          21
2019-11-08 20:38:42.067720  | DiscountedReturnAverage    -1.87771
2019-11-08 20:38:42.067737  | DiscountedReturnStd         0.0219605
2019-11-08 20:38:42.067754  | DiscountedReturnMedian     -1.88136
2019-11-08 20:38:42.067771  | DiscountedReturnMin        -1.90036
2019-11-08 20:38:42.067788  | DiscountedReturnMax        -1.84392
2019-11-08 20:38:42.067805  | lossAverage               nan
2019-11-08 20:38:42.067822  | lossStd                   nan
2019-11-08 20:38:42.067839  | lossMedian                nan
2019-11-08 20:38:42.067856  | lossMin                   nan
2019-11-08 20:38:42.067873  | lossMax                   nan
2019-11-08 20:38:42.067890  | gradNormAverage           nan
2019-11-08 20:38:42.067907  | gradNormStd               nan
2019-11-08 20:38:42.067924  | gradNormMedian            nan
2019-11-08 20:38:42.067941  | gradNormMin               nan
2019-11-08 20:38:42.067958  | gradNormMax               nan
2019-11-08 20:38:42.067975  | tdAbsErrAverage           nan
2019-11-08 20:38:42.067992  | tdAbsErrStd               nan
2019-11-08 20:38:42.068009  | tdAbsErrMedian            nan
2019-11-08 20:38:42.068026  | tdAbsErrMin               nan
2019-11-08 20:38:42.068043  | tdAbsErrMax               nan
文章来源:https://www.codelast.com/
仔细看就会发现,最后的若干个模型指标都是“nan”,在训练了一段时间之后,这些值就变成了有意义的值,例如:

2019-11-08 20:40:40.941580  | lossAverage                 0.0129165
2019-11-08 20:40:40.941597  | lossStd                     0.0137061
2019-11-08 20:40:40.941614  | lossMedian                  0.0150348
2019-11-08 20:40:40.941631  | lossMin                     0.000105323
2019-11-08 20:40:40.941648  | lossMax                     0.0602407
2019-11-08 20:40:40.941665  | gradNormAverage             0.0283939
2019-11-08 20:40:40.941682  | gradNormStd                 0.0168219
2019-11-08 20:40:40.941699  | gradNormMedian              0.0301482
2019-11-08 20:40:40.941716  | gradNormMin                 0.00661218
2019-11-08 20:40:40.941732  | gradNormMax                 0.086334
2019-11-08 20:40:40.941749  | tdAbsErrAverage             0.0529054
2019-11-08 20:40:40.941766  | tdAbsErrStd                 0.168416
2019-11-08 20:40:40.941783  | tdAbsErrMedian              0.0233203
2019-11-08 20:40:40.941800  | tdAbsErrMin                 8.33329e-05
2019-11-08 20:40:40.941817  | tdAbsErrMax                 1
所以这些值是在什么时候才会从“nan”变成有意义的值呢?为什么刚开始训练不久的时候,会获取不到这些值?理论上,只要开始训练了,哪怕这些数字错得再离谱,它们也是有数的,不应该是“nan”才对,对吧?所以这里为什么会显示“nan”?
文章来源:https://www.codelast.com/
▶▶ nan 日志在哪记下来的
为了弄清楚上面的问题,我们要找到根源——打印“nan”日志的地方。上面那些显示为“nan”的日志,是 rlpyt/utils/logging/logger.py 的 record_tabular_misc_stat() 函数记录下来的:

def record_tabular_misc_stat(key, values, placement='back'):
    if placement == 'front':
        prefix = ""
        suffix = key
    else:
        prefix = key
        suffix = ""
    if len(values) > 0:
        record_tabular(prefix + "Average" + suffix, np.average(values))
        record_tabular(prefix + "Std" + suffix, np.std(values))
        record_tabular(prefix + "Median" + suffix, np.median(values))
        record_tabular(prefix + "Min" + suffix, np.min(values))
        record_tabular(prefix + "Max" + suffix, np.max(values))
    else:
        record_tabular(prefix + "Average" + suffix, np.nan)
        record_tabular(prefix + "Std" + suffix, np.nan)
        record_tabular(prefix + "Median" + suffix, np.nan)
        record_tabular(prefix + "Min" + suffix, np.nan)
        record_tabular(prefix + "Max" + suffix, np.nan)

文章来源:https://www.codelast.com/
这个函数用来计算某些模型指标,这些模型指标有一个共同的特征:它们都可以计算平均值标准差等统计值。这是什么意思?举个例子,有一个指标“CumTrainTime”(累积的训练时间),它就没有“平均值”的概念;而像 loss(损失函数的值)这种指标,它在多轮训练迭代过程中,是可以有“平均值”的概念的。
而类似于 loss 这种指标,还不止一个。为了简化代码,这里采用了拼接模型指标名称的做法,例如日志里的"lossAverage","gradNormAverage"之类的名称都是拼出来的,而不是直接写死,正如你上面看到的代码一样。
从上面的代码可见,当传入的“values”为空的时候,记下来的某些模型指标就会变成“nan”。
所以现在的问题变成了:在什么时候,传入的“values”会为空?
文章来源:https://www.codelast.com/
▶▶ logger的调用者 MinibatchRlEval 更新模型指标的逻辑
example_1 使用的 runner 是 MinibatchRlEval,它就是 logger 的调用者。在 MinibatchRlEval.train() 函数中定义了模型的训练、评估流程。
下面这句代码:

opt_info = self.algo.optimize_agent(itr, samples)

会把 loss 等参数收集到 opt_info 对象中,而下面这句代码:

self.store_diagnostics(itr, traj_infos, opt_info)

则会把 opt_info 更新到内存里。最后,这一句代码:

self.log_diagnostics(itr, eval_traj_infos, eval_time)

会把内存里的信息记录到日志,以及print到屏幕上。

所以,其实我们只要弄清楚 self.algo.optimize_agent() 返回 opt_info 的逻辑,就知道在什么情况下 loss 等指标为“nan”了。
文章来源:https://www.codelast.com/
▶▶ 找到根本原因:algorithm类更新模型指标的逻辑
example_1 使用的algorithm类是:

class DQN(RlAlgorithm):

它的 optimize_agent() 函数里有这样一段代码:

opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
if itr < self.min_itr_learn:
    return opt_info

这里的 opt_info 其实就是一个各字段为空list的 namedtuple 对象:

OptInfo(loss=[], gradNorm=[], tdAbsErr=[])

答案已经很明显了,当前模型训练的迭代次数 < self.min_itr_learn 的时候,就会造成 loss 等模型指标为“nan”。
self.min_itr_learn 是在 DQN.initialize() 函数里初始化的:

self.min_itr_learn = int(self.min_steps_learn // sampler_bs)

不用去管这个看似有点奇怪的逻辑,只需要知道:self.min_steps_learn 越大,“nan”打印出的次数就越多。
而 self.min_steps_learn 这个参数,是在 DQN 类对象构造的时候传入的(example_1.py):

algo = DQN(min_steps_learn=1e3)

所以,你只要改小这个值,就可以让“nan”出现的次数减少。
文章来源:https://www.codelast.com/
▶▶ 为什么要这样做,以及调整 min_steps_learn 参数的注意事项
rlpyt 为什么要用一个参数来控制模型指标的计算过程?其实它不是为了控制什么时候不显示“nan”,看 DQN.optimize_agent() 函数的这几句代码:

if samples is not None:
    samples_to_buffer = self.samples_to_buffer(samples)
    self.replay_buffer.append_samples(samples_to_buffer)
opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
if itr < self.min_itr_learn:
    return opt_info

就会发现:当训练迭代次数没有达到 self.min_itr_learn 的时候,算法会一直把与environment交互得到的采样数据收集到 Replay Buffer 里面,如果 Replay Buffer 里的数据太少,没有达到预设的数量,那么开始优化策略网络也是没有意义的。当满足 irt >= self.min_itr_learn 的条件之后,后面才会进行反向传播之类的工作。
所以我认为,min_steps_learn 的值确实不能设置得太小。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。



网友评论已有0条评论, 我也要评论

发表评论

*

* (保密)

Ctrl+Enter 快捷回复