pytorch中的交叉熵
pytorch的交叉熵nn.CrossEntropyLoss
在训练阶段,里面是内置了softmax
操作的,因此只需要喂入原始的数据结果即可,不需要在之前再添加softmax
层。这个和tensorflow的tf.softmax_cross_entropy_with_logits
如出一辙.[1][2]pytorch的交叉熵nn.CrossEntropyLoss
在训练阶段,里面是内置了softmax
操作的,因此只需要喂入原始的数据结果即可,不需要在之前再添加softmax
层。这个和tensorflow的tf.softmax_cross_entropy_with_logits
如出一辙.[1][2]
pytorch中的MSELoss和KLDivLoss
在深度学习中,MSELoss
均方差损失和KLDivLoss
KL散度是经常使用的两种损失,在pytorch中,也有这两个函数,如:
这个时候我们要注意到,我们的标签target
是需要一个不能被训练的,也就是requires_grad=False
的值,否则将会报错,出现如:
我们注意到,其实不只是MSELoss
,其他很多loss,比如交叉熵,KL散度等,其target
都需要是一个不能被训练的值的,这个和TensorFlow中的tf.nn.softmax_cross_entropy_with_logits_v2
不太一样,后者可以使用可训练的target
,具体见[3]
在验证和测试阶段取消掉梯度(no_grad)
一般来说,我们在进行模型训练的过程中,因为要监控模型的性能,在跑完若干个epoch
训练之后,需要进行一次在验证集[4]上的性能验证。一般来说,在验证或者是测试阶段,因为只是需要跑个前向传播(forward)就足够了,因此不需要保存变量的梯度。保存梯度是需要额外显存或者内存进行保存的,占用了空间,有时候还会在验证阶段导致OOM(Out Of Memory)错误,因此我们在验证和测试阶段,最好显式地取消掉模型变量的梯度。 在pytroch 0.4
及其以后的版本中,用torch.no_grad()
这个上下文管理器就可以了,例子如下:
如上,我们只需要在加上上下文管理器就可以很方便的取消掉梯度。这个功能在pytorch
以前的版本中,通过设置volatile=True
生效,不过现在这个用法已经被抛弃了。
显式指定model.train()
和model.eval()
dropout
[6]中的丢弃率和Batch Normalization
[5]中的和等,这个时候我们就需要显式地指定不同的阶段(训练或者测试),在pytorch
中我们通过model.train()
和model.eval()
进行显式指定,具体如:
关于retain_graph
的使用
pytorch
中调用out.backward()
即可实现,给个小例子如:
loss
进行反向传播就可以求得,即是损失对于每个叶子节点的梯度。我们注意到,在.backward()
这个API的文档中,有几个参数,如:
retain_graph
这个参数,这个参数如果为False
或者None
则在反向传播完后,就释放掉构建出来的graph,如果为True
则不对graph进行释放[7][8]。
d
进行求梯度,我们有:
e
进行求梯度,那么将会因为没有这个graph而报错。因此有例子:
loss
,例子如:
loss
进行分步的训练了。进行梯度累积,实现内存紧张情况下的大batch_size
训练
retain_graph
参数中,还可以用于累积梯度,在GPU显存紧张的情况下使用可以等价于用更大的batch_size
进行训练。首先我们要明白,当调用.backward()
时,其实是对损失到各个节点的梯度进行计算,计算结果将会保存在各个节点上,如果不用opt.zero_grad()
对其进行清0,那么只要你一直调用.backward()
梯度就会一直累积,相当于是在大的batch_size
下进行的训练。我们给出几个例子阐述我们的观点。
loss.backward(retain_graph=True)
,输出为:
opt.zero_grad()
,输出为:
opt.zero_grad()
了吧,那是为什么不要这次的梯度结果被上一次给影响,但是在某些情况下这个‘影响’是可以利用的。调皮的dropout
torch.nn.functional.dropout
的时候,其参数为:
training
指明了是否是在训练阶段,是否需要对神经元输出进行随机丢弃,这个是需要自行指定的,即便是用了model.train()
或者model.eval()
都是如此,这个和torch.nn.dropout
不同,因为后者是一个层(Layer),而前者只是一个函数,不能纪录状态[9]。嘿,检查自己,说你呢, index_select
torch.index_select()
是一个用于索引给定张量中某一个维度中元素的方法,其API手册如:
1000 * 10
的尺寸大小,其中1000
为样本数量,10
为特征数目,如果我现在需要指定的某些样本,比如第1-100
,300-400
等等样本,我可以用一个index
进行索引,然后应用torch.index_select()
就可以索引了,例子如:
pytorch
似乎在使用GPU
的情况下,不检查index
是否会越界,因此如果你的index
越界了,但是报错的地方可能不在使用index_select()
的地方,而是在后续的代码中,这个似乎就需要留意下你的index
了。同时,index
是一个LongTensor
,这个也是要留意的。悄悄地更新,BN层就是个小可爱
running_mean
和running_var
是在调用forward()
后就更新的,这个和一般的参数不同,容易造成疑惑,考虑到篇幅较长,请移步到[11]。F.interpolate
的问题
pytorch
的确也是提供对以tensor
形式表示的图像进行插值的功能,那就是函数torch.nn.functional.interpolate
[12],但是我们注意到这个插值函数有点特别,它是对以batch
为单位的图像进行插值的,如果你想要用以下的代码去插值:
size
只接受一个整数,其对W这个维度进行缩放,这里,interpolate
会认为3是batch_size
,因此如果需要对图像的H和W进行插值,那么我们应该如下操作:
Reference
<section style="white-space: normal;line-height: 1.75em;text-align: center;"><strong style="color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;widows: 1;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong>完<strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong></span></strong></span></strong></section><pre><pre><section style="letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;widows: 1;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="letter-spacing: 0.544px;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;text-align: center;"><span style="color: rgb(0, 0, 0);"><strong><span style="font-size: 16px;font-family: 微软雅黑;caret-color: red;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">图灵奖得主Yann LeCun《深度学习》春季课程</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">你一定从未看过如此通俗易懂的YOLO系列解读 (下)</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;color: rgb(0, 0, 0);">22课时、19大主题,CS 231n进阶版课程视频上线<br /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;color: rgb(0, 0, 0);">数据分析入门常用的23个牛逼Pandas代码</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;"><span style="color: rgb(87, 107, 149);font-size: 14px;">如何在科研论文中画出漂亮的插图?</span><br /></section></section></section></section></section></section></section></section></section>
本篇文章来源于: 深度学习这件小事
本文为原创文章,版权归知行编程网所有,欢迎分享本文,转载请保留出处!
你可能也喜欢
- ♥ 十大反直觉的数学结论06/29
- ♥ 知乎热议:高数、线代应该成为计算机专业学习的重心吗?07/27
- ♥ 限时报名 | 带学《百面机器学习》葫芦书,算法+leetcode一应俱全03/14
- ♥ NeurIPS 2020 所有RL papers全扫荡01/26
- ♥ 博士最“惨”能到什么程度?04/30
- ♥ 如何教老婆学Python?03/15
内容反馈