知行编程网知行编程网  2022-06-25 19:00 知行编程网 隐藏边栏 |   抢沙发  51 
文章评分 0 次,平均分 0.0

那些用pytorch踩过的坑

来自 | 知乎   作者 | 土豆

链接 | https://zhuanlan.zhihu.com/p/180020358

编辑 | 深度学习这件小事

本文仅作学术交流,如有侵权,请联系后台删除。


   pytorch中的交叉熵

pytorch的交叉熵nn.CrossEntropyLoss在训练阶段,里面是内置了softmax操作的,因此只需要喂入原始的数据结果即可,不需要在之前再添加softmax层。这个和tensorflow的tf.softmax_cross_entropy_with_logits如出一辙.[1][2]pytorch的交叉熵nn.CrossEntropyLoss在训练阶段,里面是内置了softmax操作的,因此只需要喂入原始的数据结果即可,不需要在之前再添加softmax层。这个和tensorflow的tf.softmax_cross_entropy_with_logits如出一辙.[1][2]

   pytorch中的MSELoss和KLDivLoss

在深度学习中,MSELoss均方差损失和KLDivLossKL散度是经常使用的两种损失,在pytorch中,也有这两个函数,如:

这个时候我们要注意到,我们的标签target是需要一个不能被训练的,也就是requires_grad=False的值,否则将会报错,出现如:

我们注意到,其实不只是MSELoss,其他很多loss,比如交叉熵,KL散度等,其target都需要是一个不能被训练的值的,这个和TensorFlow中的tf.nn.softmax_cross_entropy_with_logits_v2不太一样,后者可以使用可训练的target,具体见[3]


   在验证和测试阶段取消掉梯度(no_grad)

一般来说,我们在进行模型训练的过程中,因为要监控模型的性能,在跑完若干个epoch训练之后,需要进行一次在验证集[4]上的性能验证。一般来说,在验证或者是测试阶段,因为只是需要跑个前向传播(forward)就足够了,因此不需要保存变量的梯度。保存梯度是需要额外显存或者内存进行保存的,占用了空间,有时候还会在验证阶段导致OOM(Out Of Memory)错误,因此我们在验证和测试阶段,最好显式地取消掉模型变量的梯度。 在pytroch 0.4及其以后的版本中,用torch.no_grad()这个上下文管理器就可以了,例子如下:

如上,我们只需要在加上上下文管理器就可以很方便的取消掉梯度。这个功能在pytorch以前的版本中,通过设置volatile=True生效,不过现在这个用法已经被抛弃了。


   显式指定model.train()和model.eval()

我们的模型中经常会有一些子模型,其在训练时候和测试时候的参数是不同的,比如dropout[6]中的丢弃率和Batch Normalization[5]中的那些用pytorch踩过的坑那些用pytorch踩过的坑等,这个时候我们就需要显式地指定不同的阶段(训练或者测试),在pytorch中我们通过model.train()和model.eval()进行显式指定,具体如:

注意,在模型中有BN层或者dropout层时,在训练阶段和测试阶段必须显式指定train()和eval()。


   关于retain_graph的使用

在对一个损失进行反向传播时,在pytorch中调用out.backward()即可实现,给个小例子如:
对loss进行反向传播就可以求得那些用pytorch踩过的坑,即是损失对于每个叶子节点的梯度。我们注意到,在.backward()这个API的文档中,有几个参数,如:
这里我们关注的是retain_graph这个参数,这个参数如果为False或者None则在反向传播完后,就释放掉构建出来的graph,如果为True则不对graph进行释放[7][8]。
我们这里就有个问题,我们既然已经计算忘了梯度了,为什么还要保存graph呢?直接释放掉等待下一个迭代不就好了吗,不释放掉不会白白浪费内存吗?我们这里根据[7]中的讨论,简要介绍下为什么在某些情况下需要保留graph。如下图所示,我们用代码构造出此graph:
那些用pytorch踩过的坑
如果我们第一次需要对末节点d进行求梯度,我们有:
问题是在执行完反向传播之后,因为没有显式地要求它保留graph,系统对graph内存进行释放,如果下一步需要对节点e进行求梯度,那么将会因为没有这个graph而报错。因此有例子:

利用这个性质在某些场景是有作用的,比如在对抗生成网络GAN中需要先对某个模块比如生成器进行训练,后对判别器进行训练,这个时候整个网络就会存在两个以上的loss,例子如:
这个时候就可以对网络中多个loss进行分步的训练了。


   进行梯度累积,实现内存紧张情况下的大batch_size训练

在上面讨论的retain_graph参数中,还可以用于累积梯度,在GPU显存紧张的情况下使用可以等价于用更大的batch_size进行训练。首先我们要明白,当调用.backward()时,其实是对损失到各个节点的梯度进行计算,计算结果将会保存在各个节点上,如果不用opt.zero_grad()对其进行清0,那么只要你一直调用.backward()梯度就会一直累积,相当于是在大的batch_size下进行的训练。我们给出几个例子阐述我们的观点。

第一次输出:

在运行一次loss.backward(retain_graph=True),输出为:

同理,第三次:

运行一次opt.zero_grad(),输出为:

现在明白为什么我们一般在求梯度时要用opt.zero_grad()了吧,那是为什么不要这次的梯度结果被上一次给影响,但是在某些情况下这个‘影响’是可以利用的。


   调皮的dropout

这个在利用torch.nn.functional.dropout的时候,其参数为:

注意这里有个training指明了是否是在训练阶段,是否需要对神经元输出进行随机丢弃,这个是需要自行指定的,即便是用了model.train()或者model.eval()都是如此,这个和torch.nn.dropout不同,因为后者是一个层(Layer),而前者只是一个函数,不能纪录状态[9]。


   嘿,检查自己,说你呢, index_select

torch.index_select()是一个用于索引给定张量中某一个维度中元素的方法,其API手册如:

其作用很简单,比如我现在的输入张量为1000 * 10的尺寸大小,其中1000为样本数量,10为特征数目,如果我现在需要指定的某些样本,比如第1-100,300-400等等样本,我可以用一个index进行索引,然后应用torch.index_select()就可以索引了,例子如:

然而有一个问题是,pytorch似乎在使用GPU的情况下,不检查index是否会越界,因此如果你的index越界了,但是报错的地方可能不在使用index_select()的地方,而是在后续的代码中,这个似乎就需要留意下你的index了。同时,index是一个LongTensor,这个也是要留意的。


   悄悄地更新,BN层就是个小可爱

在trainning状态下,BN层的统计参数running_mean和running_var是在调用forward()后就更新的,这个和一般的参数不同,容易造成疑惑,考虑到篇幅较长,请移步到[11]。

   F.interpolate的问题

我们经常需要对图像进行插值,而pytorch的确也是提供对以tensor形式表示的图像进行插值的功能,那就是函数torch.nn.functional.interpolate[12],但是我们注意到这个插值函数有点特别,它是对以batch为单位的图像进行插值的,如果你想要用以下的代码去插值:

那么这样就会报错,因为此处的size只接受一个整数,其对W这个维度进行缩放,这里,interpolate会认为3是batch_size,因此如果需要对图像的H和W进行插值,那么我们应该如下操作:


Reference
[1]. Why does CrossEntropyLoss include the softmax function?https://discuss.pytorch.org/t/why-does-crossentropyloss-include-the-softmax-function/4420
[2]. Do I need to use softmax before nn.CrossEntropyLoss()?https://discuss.pytorch.org/t/do-i-need-to-use-softmax-before-nn-crossentropyloss/16739/2
[3]. tf.nn.softmax_cross_entropy_with_logits 将在未来弃用,https://blog.csdn.net/LoseInVain/article/details/80932605
[4]. 训练集,测试集,检验集的区别与交叉检验,https://blog.csdn.net/LoseInVain/article/details/78108955
[5]. Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[J]. arXiv preprint arXiv:1502.03167, 2015.
[6]. Hinton G E, Srivastava N, Krizhevsky A, et al. Improving neural networks by preventing co-adaptation of feature detectors[J]. arXiv preprint arXiv:1207.0580, 2012. 
[7]. What does the parameter retain_graph mean in the Variable's backward() method?https://stackoverflow.com/questions/46774641/what-does-the-parameter-retain-graph-mean-in-the-variables-backward-method
[8]. https://pytorch.org/docs/stable/autograd.html?highlight=backward#torch.Tensor.backward
[9] https://pytorch.org/docs/stable/nn.html?highlight=dropout#torch.nn.functional.dropout
[10]. https://github.com/pytorch/pytorch/issues/571
[11]. Pytorch的BatchNorm层使用中容易出现的问题,https://blog.csdn.net/LoseInVain/article/details/86476010
[12]. https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate


<section style="white-space: normal;line-height: 1.75em;text-align: center;"><strong style="color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;widows: 1;background-color: rgb(255, 255, 255);max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>完<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></section><pre><pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;widows: 1;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;max-width: 100%;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;font-size: 16px;font-family: 微软雅黑;caret-color: red;box-sizing: border-box !important;overflow-wrap: break-word !important;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;">吴恩达推荐笔记:22 张图总结深度学习全部知识</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;">图灵奖得主Yann LeCun《深度学习》春季课程</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;">你一定从未看过如此通俗易懂的YOLO系列解读 (下)</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;">22课时、19大主题,CS 231n进阶版课程视频上线</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="color: rgb(87, 107, 149);font-size: 14px;">震惊!这个街道办招8人,全是清华北大博士硕士!</span><br  /></section></section></section></section></section></section></section></section></section>
那些用pytorch踩过的坑

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享