知行编程网知行编程网  2022-07-14 12:00 知行编程网 隐藏边栏 |   抢沙发  245 
文章评分 0 次,平均分 0.0

在 C++ 平台上部署 PyTorch 模型流程+踩坑实录

来自 | 知乎   作者丨火星少女

链接丨https://zhuanlan.zhihu.com/p/146453159

编辑丨极市平台


最近因为工作需要,要把pytorch的模型部署到c++平台上,基本过程主要参照官网的教学示例,期间发现了不少坑,特此记录。


1.模型转换

libtorch不依赖于python,python训练的模型,需要转换为script model才能由libtorch加载,并进行推理。在这一步官网提供了两种方法:


方法一:Tracing

这种方法操作比较简单,只需要给模型一组输入,走一遍推理网络,然后由torch.ji.trace记录一下路径上的信息并保存即可。示例如下:

缺点是如果模型中存在控制流比如if-else语句,一组输入只能遍历一个分支,这种情况下就没办法完整的把模型信息记录下来。


方法二:Scripting

直接在Torch脚本中编写模型并相应地注释模型,通过torch.jit.script编译模块,将其转换为ScriptModule。示例如下:

forward方法会被默认编译,forward中被调用的方法也会按照被调用的顺序被编译

如果想要编译一个forward以外且未被forward调用的方法,可以添加 @torch.jit.export.


如果想要方法不被编译,可使用

@torch.jit.ignore(https://pytorch.org/docs/master/generated/torch.jit.ignore.html#torch.jit.ignore) 

或者 @torch.jit.unused(https://pytorch.org/docs/master/generated/torch.jit.unused.html#torch.jit.unused)


在这一步遇到好多坑,主要原因可归为一下两点


1. 不支持的操作

TorchScript支持的操作是python的子集,大部分torch中用到的操作都可以找到对应实现,但也存在一些尴尬的不支持操作,详细列表可见https://pytorch.org/docs/master/jit_unsupported.html#jit-unsupported,下面列一些我自己遇到的操作:

1)参数/返回值不支持可变个数,例如

或者


2)各种iteration操作

eg1.

报错torch.jit.frontend.UnsupportedNodeError: ListComp aren’t supported

可以改成:

eg2.

eg3.


3)不支持的语句

eg1. 不支持continue

torch.jit.frontend.UnsupportedNodeError: continue statements aren’t supported

eg2. 不支持try-catch

torch.jit.frontend.UnsupportedNodeError: try blocks aren’t supported

eg3. 不支持with语句


4)其他常见op/module

eg1. torch.autograd.Variable

解决:使用torch.ones/torch.randn等初始化+.float()/.long()等指定数据类型。

eg2. torch.Tensor/torch.LongTensor etc.

解决:同上

eg3. requires_grad参数只在torch.tensor中支持,torch.ones/torch.zeros等不可用

eg4. tensor.numpy()

eg5. tensor.bool()

解决:tensor.bool()用tensor>0代替

eg6. self.seg_emb(seg_fea_ids).to(embeds.device)

解决:需要转gpu的地方显示调用.cuda()

总之一句话:除了原生python和pytorch以外的库,比如numpy什么的能不用就不用,尽量用pytorch的各种API


2. 指定数据类型

1)属性,大部分的成员数据类型可以根据值来推断,空的列表/字典则需要预先指定

2)常量,使用Final关键字

3)变量。默认是tensor类型且不可变,所以非tensor类型必须要指明


方法三:Tracing and Scriptin混合

一种是在trace模型中调用script,适合模型中只有一小部分需要用到控制流的情况,使用实例如下:

另一种情况是在script module中用tracing生成子模块,对于一些存在script module不支持的python feature的layer,就可以把相关layer封装起来,用trace记录相关layer流,其他layer不用修改。使用示例如下:


2.保存序列化模型

如果上一步的坑都踩完,那么模型保存就非常简单了,只需要调用save并传递一个文件名即可,需要注意的是如果想要在gpu上训练模型,在cpu上做inference,一定要在模型save之前转化,再就是记得调用model.eval(),形如


3.C++ load训练好的模型

要在C ++中加载序列化的PyTorch模型,必须依赖于PyTorch C ++ API(也称为LibTorch)。libtorch的安装非常简单,只需要在pytorch官网(https://pytorch.org/)下载对应版本,解压即可。会得到一个结构如下的文件夹。

然后就可以构建应用程序了,一个简单的示例目录结构如下:

example-app.cpp和CMakeLists.txt的示例代码分别如下:


至此,就可以运行以下命令从example-app/文件夹中构建应用程序啦:

其中/path/to/libtorch是之前下载后的libtorch文件夹所在的路径。这一步如果顺利能够看到编译完成100%的提示,下一步运行编译生成的可执行文件,会看到“ok”的输出,可喜可贺!


4. 执行Script Module

终于到最后一步啦!下面只需要按照构建输入传给模型,执行forward就可以得到输出啦。一个简单的示例如下:

前两行创建一个torch::jit::IValue的向量,并添加单个输入. 使用torch::ones()创建输入张量,等效于C ++ API中的torch.ones。然后,运行script::Moduleforward方法,通过调用toTensor()将返回的IValue值转换为张量。C++对torch的各种操作还是比较友好的,通过torch::或者后加_的方法都可以找到对应实现,例如

最后check一下确保c++端的输出和pytorch是一致的就大功告成啦~

踩了无数坑,薅掉了无数头发,很多东西也是自己一点点摸索的,如果有错误欢迎指正!


参考资料:

PyTorch C++ API - PyTorch master document

Torch Script - PyTorch master documentation

文章地址:

https://pytorch.org/cppdocs/

https://pytorch.org/tutorials/advanced/cpp_export.html


<section data-brushtype="text" style="padding-right: 0em;padding-left: 0em;white-space: normal;letter-spacing: 0.544px;color: rgb(62, 62, 62);font-family: "Helvetica Neue", Helvetica, "Hiragino Sans GB", "Microsoft YaHei", Arial, sans-serif;widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;"><strong style="color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong>完<strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong></span></strong></span></strong></section><pre><pre><section style="letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="letter-spacing: 0.544px;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;text-align: center;"><span style="color: rgb(0, 0, 0);"><strong><span style="font-size: 16px;font-family: 微软雅黑;caret-color: red;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">一文了解深度推荐算法的演进</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">干货 | 算法工程师超实用技术路线图</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;"><span style="font-size: 14px;">13个算法工程师必须掌握的PyTorch Tricks</span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;"><span style="font-size: 14px;">吴恩达上新:生成对抗网络(GAN)专项课程</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">拿到2021灰飞烟灭算法岗offer的大佬们是啥样的<span style="font-size: 14px;">?</span></section></section></section></section></section></section></section></section></section>

在 C++ 平台上部署 PyTorch 模型流程+踩坑实录

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享