Source: 机器学习算法与自然语言处理 (Machine Learning Algorithms and NLP)
Editor: 忆臻
https://www.zhihu.com/question/68384370
This article is shared for academic purposes only; it will be taken down on request in case of copyright infringement.
How do you extract intermediate-layer features in PyTorch?
Answer by 涩醉
https://www.zhihu.com/question/68384370/answer/751212803
Here is a simple implementation using PyTorch's hook mechanism; it visualizes only the feature maps of the conv layers.
```python
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.models import resnet18


def viz(module, input):
    x = input[0][0]
    # Show at most 4 feature maps
    min_num = np.minimum(4, x.size()[0])
    for i in range(min_num):
        plt.subplot(1, 4, i + 1)
        plt.imshow(x[i].cpu())
    plt.show()


def main():
    t = transforms.Compose([transforms.ToPILImage(),
                            transforms.Resize((224, 224)),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = resnet18(pretrained=True).to(device)
    for name, m in model.named_modules():
        # Only visualize the feature maps of the conv layers
        if isinstance(m, torch.nn.Conv2d):
            m.register_forward_pre_hook(viz)
    img = cv2.imread('/Users/edgar/Desktop/cat.jpeg')
    img = t(img).unsqueeze(0).to(device)
    with torch.no_grad():
        model(img)


if __name__ == '__main__':
    main()
```
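If you just want to verify that the forward pre-hook mechanism works, here is a minimal self-contained sketch that needs no image file or pretrained download; the names `feature_maps` and `save_input` are illustrative, not part of any API:

```python
import torch
import torch.nn as nn

feature_maps = {}

def save_input(name):
    # A forward pre-hook is called with (module, inputs) *before* forward runs,
    # so it captures the module's input, i.e. the previous layer's output.
    def hook(module, inputs):
        feature_maps[name] = inputs[0].detach()
    return hook

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1),
)

for name, m in model.named_modules():
    if isinstance(m, nn.Conv2d):
        m.register_forward_pre_hook(save_input(name))

with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))

print(sorted(feature_maps))      # ['0', '2']
print(feature_maps['2'].shape)   # torch.Size([1, 8, 32, 32])
```

Note that the hook on conv `'2'` captures the ReLU output of the first block, because a pre-hook sees the layer's input rather than its output.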
The printed feature maps look roughly like this; the ones shown were taken from the first and fourth layers.
Answer by 袁坤
https://www.zhihu.com/question/68384370/answer/419741762
I recommend using hooks: they let you extract the features or gradients you need without modifying the network's forward function. Just register them on the relevant modules, and the features or gradients are captured when the model is called.
```python
inter_feature = {}
inter_gradient = {}


def make_hook(name, flag):
    if flag == 'forward':
        def hook(m, input, output):
            inter_feature[name] = input
        return hook
    elif flag == 'backward':
        def hook(m, input, output):
            inter_gradient[name] = output
        return hook
    else:
        assert False


# For each module (name, m) of interest, e.g. from model.named_modules():
m.register_forward_hook(make_hook(name, 'forward'))
m.register_backward_hook(make_hook(name, 'backward'))
```
The hooks fire during the forward and backward passes, so the intermediate variables end up stored in inter_feature and inter_gradient.
```python
output = model(input)             # forward hooks capture intermediate features
loss = criterion(output, target)
loss.backward()                   # backward hooks capture intermediate gradients
```
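Note that `register_backward_hook` is deprecated in recent PyTorch versions in favor of `register_full_backward_hook`, which passes the gradients with respect to the module's inputs and outputs. A minimal sketch of the same idea with both hook types (the dict and layer names here are illustrative):

```python
import torch
import torch.nn as nn

inter_feature = {}
inter_gradient = {}

layer = nn.Linear(4, 2)

def fwd_hook(m, inputs, output):
    # inputs is a tuple of the tensors fed into the module
    inter_feature['fc'] = inputs[0].detach()

def bwd_hook(m, grad_input, grad_output):
    # grad_output holds gradients w.r.t. the module's outputs
    inter_gradient['fc'] = grad_output[0].detach()

layer.register_forward_hook(fwd_hook)
layer.register_full_backward_hook(bwd_hook)

x = torch.randn(3, 4, requires_grad=True)
loss = layer(x).sum()
loss.backward()

print(inter_feature['fc'].shape)   # torch.Size([3, 4])
print(inter_gradient['fc'].shape)  # torch.Size([3, 2])
```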
Finally, you can release the hook once it is no longer needed:
```python
hook.remove()
```
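Each `register_*_hook` call returns a removable handle, so in practice you keep the handle and call `remove()` on it. A small sketch showing that a removed hook stops firing (variable names are illustrative):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []

# The returned handle is what you call .remove() on later
handle = layer.register_forward_hook(lambda m, i, o: calls.append(o.shape))

layer(torch.randn(1, 4))   # hook fires
handle.remove()            # detach the hook
layer(torch.randn(1, 4))   # hook no longer fires

print(len(calls))  # 1
```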
Answer by 罗一成
https://www.zhihu.com/question/68384370/answer/263120790
By "extracting intermediate features," do you mean pulling out the intermediate weights? If so, why not just access those matrices directly? When PyTorch saves parameters, it essentially gives every weight, bias, and so on a name and stores them in a dictionary. Just look at state_dict.keys(), find the corresponding key, and fetch it.
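As a quick illustration of that lookup (the model here is an arbitrary example):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 2, 1))
sd = model.state_dict()

# Only parameterized layers appear; the ReLU at index 1 has no entries
print(list(sd.keys()))   # ['0.weight', '0.bias', '2.weight', '2.bias']

w = sd['0.weight']       # first conv kernel, shape [8, 3, 3, 3]
print(w.shape)
```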
Also, the "use with caution" concern you raise is a strange one.
Even if you use the classes under modules, your activation functions carry no parameters of their own, so they are not saved with the model anyway. If you don't believe it, try swapping relu for sigmoid inside a Sequential: you can still load the previously saved state_dict.
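That swap can be checked directly; a small sketch of the experiment described above:

```python
import torch
import torch.nn as nn

net_relu = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
sd = net_relu.state_dict()

# Same parameter shapes, different activation: loading still succeeds,
# because parameter-free activation modules contribute nothing to state_dict
net_sigmoid = nn.Sequential(nn.Linear(4, 4), nn.Sigmoid(), nn.Linear(4, 2))
net_sigmoid.load_state_dict(sd)

print(torch.equal(net_sigmoid[0].weight, net_relu[0].weight))  # True
```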
So I wouldn't say "use functional with caution"; I just think those other settings should be saved separately as well (assuming you treat them as hyperparameters).
Disclosure: I have submitted PRs to PyTorch.
This article originates from: 深度学习这件小事
This is an original article and its copyright belongs to 知行编程网; feel free to share it, and please retain the source when reposting.