知行编程网知行编程网  2022-07-04 12:00 知行编程网 隐藏边栏 |   抢沙发  5 
文章评分 0 次,平均分 0.0

13个算法工程师必须掌握的PyTorch Tricks

来自 | 知乎   作者丨z.defying
链接丨https://zhuanlan.zhihu.com/p/76459295

仅作学术分享,如有侵权,请联系删文。

13个算法工程师必须掌握的PyTorch Tricks


   目录

1、指定GPU编号

2、查看模型每层输出详情

3、梯度裁剪

4、扩展单张图片维度

5、one hot编码

6、防止验证模型时爆显存

7、学习率衰减

8、冻结某些层的参数

9、对不同层使用不同学习率

10、模型相关操作

11、Pytorch内置one hot函数

12、网络参数初始化

13、加载内置预训练模型


   1、指定GPU编号

  • 设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  • 设置当前使用的GPU设备为0,1号两个设备,名称依次为 /gpu:0/gpu:1os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" ,根据顺序表示优先使用0号设备,然后使用1号设备。

指定GPU的命令需要放在和神经网络相关的一系列操作的前面。


   2、查看模型每层输出详情

Keras有一个简洁的API来查看模型的每一层输出尺寸,这在调试网络时非常有用。现在在PyTorch中也可以实现这个功能。
使用很简单,如下用法:
input_size 是根据你自己的网络模型的输入尺寸进行设置。


   3、梯度裁剪(Gradient Clipping)

nn.utils.clip_grad_norm_ 的参数:
  • parameters – 一个基于变量的迭代器,会进行梯度归一化
  • max_norm – 梯度的最大范数
  • norm_type – 规定范数的类型,默认为L2
@不椭的椭圆 提出:梯度裁剪在某些任务上会额外消耗大量的计算时间,可移步评论区查看详情。


   4、扩展单张图片维度

因为在训练时的数据维度一般都是 (batch_size, c, h, w),而在测试时只输入一张图片,所以需要扩展维度,扩展维度有多个方法:
或(感谢 @coldleaf 的补充)
tensor.unsqueeze(dim):扩展维度,dim指定扩展哪个维度。
tensor.squeeze(dim):去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度。


   5、独热编码

在PyTorch中使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码。
注:第11条有更简单的方法。


   6、防止验证模型时爆显存

验证模型时不需要求导,即不需要梯度计算,关闭autograd,可以提高速度,节约内存。如果不关闭可能会爆显存。
感谢@zhaz 的提醒,我把 torch.cuda.empty_cache() 的使用原因更新一下。
这是原回答:
Pytorch 训练时无用的临时变量可能会越来越多,导致 out of memory ,可以使用下面语句来清理这些不需要的变量。
官网 上的解释为:
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible innvidia-smi. torch.cuda.empty_cache()
意思就是PyTorch的缓存分配器会事先分配一些固定的显存,即使实际上tensors并没有使用完这些显存,这些显存也不能被其他应用使用。这个分配过程由第一次CUDA内存访问触发的。
torch.cuda.empty_cache() 的作用就是释放缓存分配器当前持有的且未占用的缓存显存,以便这些显存可以被其他GPU应用程序中使用,并且通过 nvidia-smi命令可见。注意使用此命令不会释放tensors占用的显存。
对于不用的数据变量,Pytorch 可以自动进行回收从而释放相应的显存。
更详细的优化可以查看 优化显存使用 和 显存利用问题。


   7、学习率衰减

可以随时查看学习率的值:optimizer.param_groups[0]['lr']
还有其他学习率更新的方式:
1、自定义更新公式:
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch:1/(epoch+1))
2、不依赖epoch更新学习率:
lr_scheduler.ReduceLROnPlateau()提供了基于训练中某些测量值使学习率动态下降的方法,它的参数说明到处都可以查到。
提醒一点就是参数 mode='min' 还是'max',取决于优化的的损失还是准确率,即使用
scheduler.step(loss)还是scheduler.step(acc)

   8、冻结某些层的参数

参考:https://www.zhihu.com/question/311095447/answer/589307812
在加载预训练模型的时候,我们有时想冻结前面几层,使其参数在训练过程中不发生变化。
我们需要先知道每一层的名字,通过如下代码打印:
假设前几层信息如下:
后面的True表示该层的参数可训练,然后我们定义一个要冻结的层的列表:
冻结方法如下:
冻结后我们再打印每层的信息:
可以看到前两层的weight和bias的requires_grad都为False,表示它们不可训练。
最后在定义优化器时,只对requires_grad为True的层的参数进行更新。

   9、对不同层使用不同学习率

我们对模型的不同层使用不同的学习率。
还是使用这个模型作为例子:
对 convolution1 和 convolution2 设置不同的学习率,首先将它们分开,即放到不同的列表里:
我们将模型划分为两部分,存放到一个列表里,每部分就对应上面的一个字典,在字典里设置不同的学习率。当这两部分有相同的其他参数时,就将该参数放到列表外面作为全局参数,如上面的`weight_decay`。
也可以在列表外设置一个全局学习率,当各部分字典里设置了局部学习率时,就使用该学习率,否则就使用列表外的全局学习率。

   10、模型相关操作

这个内容比较多,我写成了一篇文章:https://zhuanlan.zhihu.com/p/73893187
  

   11、Pytorch内置one_hot函数

感谢@yangyangyang 补充:Pytorch 1.1后,one_hot可以直接用torch.nn.functional.one_hot
然后我将Pytorch升级到1.2版本,试用了下 one_hot 函数,确实很方便。
具体用法如下:
F.one_hot会自己检测不同类别个数,生成对应独热编码。我们也可以自己指定类别数:
升级 Pytorch (cpu版本)的命令:conda install pytorch torchvision -c pytorch
(希望Pytorch升级不会影响项目代码)

   12、网络参数初始化

神经网络的初始化是训练流程的重要基础环节,会对模型的性能、收敛性、收敛速度等产生重要的影响。

以下介绍两种常用的初始化操作。
(1) 使用pytorch内置的torch.nn.init方法。
常用的初始化操作,例如正态分布、均匀分布、xavier初始化、kaiming初始化等都已经实现,可以直接使用。具体详见PyTorch 中 torch.nn.init 中文文档。
(2) 对于一些更加灵活的初始化方法,可以借助numpy。
对于自定义的初始化方法,有时tensor的功能不如numpy强大灵活,故可以借助numpy实现初始化方法,再转换到tensor上使用。

 

   13、加载内置预训练模型

torchvision.models模块的子模块中包含以下模型:
  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
导入这些模型的方法为:
有一个很重要的参数为pretrained,默认为False,表示只导入模型的结构,其中的权重是随机初始化的。
如果pretrainedTrue,表示导入的是在ImageNet数据集上预训练的模型。

更多的模型可以查看:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/

<pre><section data-brushtype="text" style="padding-right: 0em;padding-left: 0em;white-space: normal;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);font-family: "Helvetica Neue", Helvetica, "Hiragino Sans GB", "Microsoft YaHei", Arial, sans-serif;font-size: 16px;widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);background-color: rgb(255, 255, 255);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>完<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></section><pre style="padding-right: 0em;padding-left: 0em;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);font-size: 16px;widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);background-color: rgb(255, 255, 255);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;max-width: 100%;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;font-size: 16px;font-family: 微软雅黑;caret-color: red;box-sizing: border-box !important;overflow-wrap: break-word !important;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">干货 | 算法工程师超实用技术路线图</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">Github上10个超好看的可视化面板,nice!</span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">吴恩达上新:生成对抗网络(GAN)专项课程</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">拿到2021灰飞烟灭算法岗offer的大佬们是啥样的</span><span style="font-size: 14px;">?</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;">你一定从未看过如此通俗易懂的YOLO系列解读 (下)</section></section></section></section></section></section></section></section></section>

13个算法工程师必须掌握的PyTorch Tricks

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享