https://www.zhihu.com/question/66782101
What is the difference between nn and nn.functional in PyTorch?
Author: 蒲嘉宸
https://www.zhihu.com/question/66782101/answer/246341271
These two are really much the same thing: one is a ready-made class wrapper, the other is a function you can call directly. We can go look at the concrete implementations of the two modules; below I take the convolution Conv1d as an example.
First, Conv1d under torch.nn:
<span style="color:#cc7832;">class </span>Conv1d(_ConvNd):<br /> <span style="color:#cc7832;">def </span><span style="color:#b200b2;">__init__</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>in_channels<span style="color:#cc7832;">, </span>out_channels<span style="color:#cc7832;">, </span>kernel_size<span style="color:#cc7832;">, </span>stride=<span style="color:#6897bb;">1</span><span style="color:#cc7832;">,<br /></span><span style="color:#cc7832;"> </span>padding=<span style="color:#6897bb;">0</span><span style="color:#cc7832;">, </span>dilation=<span style="color:#6897bb;">1</span><span style="color:#cc7832;">, </span>groups=<span style="color:#6897bb;">1</span><span style="color:#cc7832;">, </span>bias=<span style="color:#cc7832;">True</span>):<br /> kernel_size = _single(kernel_size)<br /> stride = _single(stride)<br /> padding = _single(padding)<br /> dilation = _single(dilation)<br /> <span style="color:#8888c6;">super</span>(Conv1d<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>).<span style="color:#b200b2;">__init__</span>(<br /> in_channels<span style="color:#cc7832;">, </span>out_channels<span style="color:#cc7832;">, </span>kernel_size<span style="color:#cc7832;">, </span>stride<span style="color:#cc7832;">, </span>padding<span style="color:#cc7832;">, </span>dilation<span style="color:#cc7832;">,<br /></span><span style="color:#cc7832;"> False, </span>_single(<span style="color:#6897bb;">0</span>)<span style="color:#cc7832;">, </span>groups<span style="color:#cc7832;">, </span>bias)<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#ffc66d;">forward</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>input):<br /> <span style="color:#cc7832;">return </span>F.conv1d(input<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.weight<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.bias<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.stride<span style="color:#cc7832;">,<br /></span><span style="color:#cc7832;"> </span><span style="color:#94558d;">self</span>.padding<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.dilation<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.groups)
And this is conv1d under torch.nn.functional:
<span style="color:#cc7832;">def </span><span style="color:#ffc66d;">conv1d</span>(input<span style="color:#cc7832;">, </span>weight<span style="color:#cc7832;">, </span>bias=<span style="color:#cc7832;">None, </span>stride=<span style="color:#6897bb;">1</span><span style="color:#cc7832;">, </span>padding=<span style="color:#6897bb;">0</span><span style="color:#cc7832;">, </span>dilation=<span style="color:#6897bb;">1</span><span style="color:#cc7832;">,<br /></span><span style="color:#cc7832;"> </span>groups=<span style="color:#6897bb;">1</span>):<br /> <span style="color:#cc7832;">if </span>input <span style="color:#cc7832;">is not None and </span>input.dim() != <span style="color:#6897bb;">3</span>:<br /> <span style="color:#cc7832;">raise </span><span style="color:#8888c6;">ValueError</span>(<span style="color:#6a8759;">"Expected 3D tensor as input, got {}D tensor instead."</span>.format(input.dim()))<br /><br /> f = ConvNd(_single(stride)<span style="color:#cc7832;">, </span>_single(padding)<span style="color:#cc7832;">, </span>_single(dilation)<span style="color:#cc7832;">, False,<br /></span><span style="color:#cc7832;"> </span>_single(<span style="color:#6897bb;">0</span>)<span style="color:#cc7832;">, </span>groups<span style="color:#cc7832;">, </span>torch.backends.cudnn.benchmark<span style="color:#cc7832;">,<br /></span><span style="color:#cc7832;"> </span>torch.backends.cudnn.deterministic<span style="color:#cc7832;">, </span>torch.backends.cudnn.enabled)<br /> <span style="color:#cc7832;">return </span>f(input<span style="color:#cc7832;">, </span>weight<span style="color:#cc7832;">, </span>bias)
You can see that the Conv1d class under torch.nn calls conv1d under nn.functional in its forward, and the actual computation is ultimately carried out by ConvNd in the C++-implemented THNN library. So the two interfaces really just call into one another, with the class delegating to the function.
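To make that relationship concrete, here is a minimal check of my own (not part of the quoted source): calling the module, and calling the functional interface with the module's own parameters, produce the same result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(8, 4, 100)             # (batch, in_channels, length)
m = nn.Conv1d(4, 6, kernel_size=3)     # the class owns weight and bias
y_cls = m(x)                           # class interface
y_fn = F.conv1d(x, m.weight, m.bias)   # functional interface, same math
print(torch.allclose(y_cls, y_fn))     # True
```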
You may wonder why two modules with such similar functionality are needed; there is a reason for this design. If only the functions under nn.functional were kept, then during training and inference we would have to maintain intermediate quantities like weight, bias, and stride by hand, which would obviously be inconvenient for users. If, on the other hand, only the classes under nn were kept, we would sacrifice some flexibility, because even simple computations would require creating a class, which does not fit PyTorch's style either.
Here are a few examples to make this easier to grasp:
<span style="color:#cc7832;">import </span>torch<br /><span style="color:#cc7832;">import </span>torch.nn <span style="color:#cc7832;">as </span>nn<br /><span style="color:#cc7832;">import </span>torch.nn.functional <span style="color:#cc7832;">as </span>F<br /><br /><br /><span style="color:#cc7832;">class </span>Net(nn.Module):<br /> <span style="color:#cc7832;">def </span><span style="color:#b200b2;">__init__</span>(<span style="color:#94558d;">self</span>):<br /> <span style="color:#8888c6;">super</span>(Net<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>).<span style="color:#b200b2;">__init__</span>()<br /> <span style="color:#94558d;">self</span>.fc1 = nn.Linear(<span style="color:#6897bb;">512</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">256</span>)<br /> <span style="color:#94558d;">self</span>.fc2 = nn.Linear(<span style="color:#6897bb;">256</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">256</span>)<br /> <span style="color:#94558d;">self</span>.fc3 = nn.Linear(<span style="color:#6897bb;">256</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">2</span>)<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#ffc66d;">forward</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>x):<br /> x = F.relu(<span style="color:#94558d;">self</span>.fc1(x))<br /> x = F.relu(F.dropout(<span style="color:#94558d;">self</span>.fc2(x)<span style="color:#cc7832;">, </span><span style="color:#6897bb;">0.5</span>))<br /> x = F.dropout(<span style="color:#94558d;">self</span>.fc3(x)<span style="color:#cc7832;">, </span><span style="color:#6897bb;">0.5</span>)<br /> <span style="color:#cc7832;">return </span>x
Take this simplest three-layer network as an example. What needs to keep state are mainly the three linear transformations, so three nn.Linear objects are defined when constructing the Module; at computation time, stateless operations like relu and dropout can simply be called directly.
Note: dropout has a gotcha here: with F.dropout you must set the training state yourself, which is why forward above passes training=self.training.
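To see why (a quick sketch of my own): F.dropout defaults to training=True, so it stays active unless told otherwise.

```python
import torch
import torch.nn.functional as F

x = torch.ones(5)
print(F.dropout(x))                   # entries randomly zeroed, survivors scaled by 2
print(F.dropout(x, training=False))   # identity: dropout is off
```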
Author: 有糖吃可好
https://www.zhihu.com/question/66782101/answer/579393790
What the two have in common:

- nn.Xxx and nn.functional.xxx do the same job in practice: nn.Conv2d and nn.functional.conv2d both perform convolution, nn.Dropout and nn.functional.dropout both perform dropout, and so on.
- Their runtime efficiency is nearly identical.
- nn.functional.xxx is the function interface, while nn.Xxx is a class wrapper around nn.functional.xxx, and every nn.Xxx inherits from the common ancestor nn.Module. As a result, nn.Xxx carries not only the functionality of nn.functional.xxx but also the attributes and methods that come with nn.Module, such as train(), eval(), load_state_dict, and state_dict (a short sketch of this point follows the list).
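A small sketch of that last point (my own example): the class version carries registered parameters and a state_dict, while the functional version owns nothing.

```python
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
print(list(conv.state_dict().keys()))   # ['weight', 'bias'], tracked for you
# nn.functional.conv2d has no state_dict: it holds no parameters at all,
# so saving/loading and train()/eval() bookkeeping exist only on nn.Xxx.
```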
Where the two differ:
- The two are invoked differently.

nn.Xxx must first be instantiated with its configuration arguments; the instance is then called like a function, passing in the input data.
```python
inputs = torch.rand(64, 3, 244, 244)
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
out = conv(inputs)
```
nn.functional.xxx takes the input data together with weight, bias, and the other parameters in a single call.
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span>
```python
weight = torch.rand(64, 3, 3, 3)
bias = torch.rand(64)
out = nn.functional.conv2d(inputs, weight, bias, padding=1)
```
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span><br />
- nn.Xxx inherits from nn.Module, so it composes well with nn.Sequential, whereas nn.functional.xxx cannot be used inside nn.Sequential (a common workaround is sketched after the example below).
```python
fm_layer = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_features=64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Dropout(0.2)
)
```
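If you do want a functional op inside nn.Sequential, one common workaround is to lift the function into a tiny wrapper Module. This is my own sketch, not something torch.nn provides out of the box:

```python
import torch.nn as nn
import torch.nn.functional as F

class Lambda(nn.Module):
    """Wrap an arbitrary function so it can live inside nn.Sequential."""
    def __init__(self, fn):
        super(Lambda, self).__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)

seq = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    Lambda(F.relu),   # a functional relu inside Sequential
)
```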
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span><br />
- nn.Xxx does not require you to define and manage weight yourself; with nn.functional.xxx you must define the weights yourself and pass them in by hand on every call, which is bad for code reuse.
Defining a CNN with nn.Xxx:
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span>
<span style="color:#cc7832;">class </span>CNN(nn.Module):<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#b200b2;">__init__</span>(<span style="color:#94558d;">self</span>):<br /> <span style="color:#8888c6;">super</span>(CNN<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>).<span style="color:#b200b2;">__init__</span>()<br /><br /> <span style="color:#94558d;">self</span>.cnn1 = nn.Conv2d(<span style="color:#aa4926;">in_channels</span>=<span style="color:#6897bb;">1</span><span style="color:#cc7832;">, </span><span style="color:#aa4926;">out_channels</span>=<span style="color:#6897bb;">16</span><span style="color:#cc7832;">, </span><span style="color:#aa4926;">kernel_size</span>=<span style="color:#6897bb;">5</span><span style="color:#cc7832;">, </span><span style="color:#aa4926;">padding</span>=<span style="color:#6897bb;">0</span>)<br /> <span style="color:#94558d;">self</span>.relu1 = nn.ReLU()<br /> <span style="color:#94558d;">self</span>.maxpool1 = nn.MaxPool2d(<span style="color:#aa4926;">kernel_size</span>=<span style="color:#6897bb;">2</span>)<br /><br /> <span style="color:#94558d;">self</span>.cnn2 = nn.Conv2d(<span style="color:#aa4926;">in_channels</span>=<span style="color:#6897bb;">16</span><span style="color:#cc7832;">, </span><span style="color:#aa4926;">out_channels</span>=<span style="color:#6897bb;">32</span><span style="color:#cc7832;">, </span><span style="color:#aa4926;">kernel_size</span>=<span style="color:#6897bb;">5</span><span style="color:#cc7832;">, </span><span style="color:#aa4926;">padding</span>=<span style="color:#6897bb;">0</span>)<br /> <span style="color:#94558d;">self</span>.relu2 = nn.ReLU()<br /> <span style="color:#94558d;">self</span>.maxpool2 = nn.MaxPool2d(<span style="color:#aa4926;">kernel_size</span>=<span style="color:#6897bb;">2</span>)<br /><br /> <span style="color:#94558d;">self</span>.linear1 = nn.Linear(<span style="color:#6897bb;">4 </span>* <span style="color:#6897bb;">4 </span>* <span style="color:#6897bb;">32</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">10</span>)<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#ffc66d;">forward</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>x):<br /> x = x.view(x.size(<span style="color:#6897bb;">0</span>)<span style="color:#cc7832;">, </span>-<span style="color:#6897bb;">1</span>)<br /> out = <span style="color:#94558d;">self</span>.maxpool1(<span style="color:#94558d;">self</span>.relu1(<span style="color:#94558d;">self</span>.cnn1(x)))<br /> out = <span style="color:#94558d;">self</span>.maxpool2(<span style="color:#94558d;">self</span>.relu2(<span style="color:#94558d;">self</span>.cnn2(out)))<br /> out = <span style="color:#94558d;">self</span>.linear1(out.view(x.size(<span style="color:#6897bb;">0</span>)<span style="color:#cc7832;">, </span>-<span style="color:#6897bb;">1</span>))<br /> <span style="color:#cc7832;">return </span>out
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span><br />
Defining the same CNN as above with nn.functional.xxx:
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span>
<span style="color:#cc7832;">class </span>CNN(nn.Module):<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#b200b2;">__init__</span>(<span style="color:#94558d;">self</span>):<br /> <span style="color:#8888c6;">super</span>(CNN<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>).<span style="color:#b200b2;">__init__</span>()<br /><br /> <span style="color:#94558d;">self</span>.cnn1_weight = nn.Parameter(torch.rand(<span style="color:#6897bb;">16</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">1</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">5</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">5</span>))<br /> <span style="color:#94558d;">self</span>.bias1_weight = nn.Parameter(torch.rand(<span style="color:#6897bb;">16</span>))<br /><br /> <span style="color:#94558d;">self</span>.cnn2_weight = nn.Parameter(torch.rand(<span style="color:#6897bb;">32</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">16</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">5</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">5</span>))<br /> <span style="color:#94558d;">self</span>.bias2_weight = nn.Parameter(torch.rand(<span style="color:#6897bb;">32</span>))<br /><br /> <span style="color:#94558d;">self</span>.linear1_weight = nn.Parameter(torch.rand(<span style="color:#6897bb;">4 </span>* <span style="color:#6897bb;">4 </span>* <span style="color:#6897bb;">32</span><span style="color:#cc7832;">, </span><span style="color:#6897bb;">10</span>))<br /> <span style="color:#94558d;">self</span>.bias3_weight = nn.Parameter(torch.rand(<span style="color:#6897bb;">10</span>))<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#ffc66d;">forward</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>x):<br /> x = x.view(x.size(<span style="color:#6897bb;">0</span>)<span style="color:#cc7832;">, </span>-<span style="color:#6897bb;">1</span>)<br /> out = F.conv2d(x<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.cnn1_weight<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.bias1_weight)<br /> out = F.relu(out)<br /> <span style="color:#808080;">out </span>= F.max_pool2d(out)<br /><br /> out = F.conv2d(x<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.cnn2_weight<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.bias2_weight)<br /> out = F.relu(out)<br /> <span style="color:#808080;">out </span>= F.max_pool2d(out)<br /><br /> out = F.linear(x<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.linear1_weight<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>.bias3_weight)<br /> <span style="color:#cc7832;">return </span>out
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span><br />
The two definitions above yield CNNs with identical functionality; which style you prefer is a matter of taste. The official PyTorch recommendation is: use nn.Xxx for operations that have learnable parameters (e.g., conv2d, linear, batch_norm), and for operations without learnable parameters (e.g., maxpool, loss functions, activation functions) use nn.functional.xxx or nn.Xxx according to personal preference.

For dropout, however, I strongly recommend the nn.Xxx form: normally dropout runs only during training and is disabled during eval. If dropout is defined via nn.Xxx, calling model.eval() switches off every dropout layer in the model; if it is defined via nn.functional.dropout, calling model.eval() does not switch it off.
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span>
<span style="color:#cc7832;">class </span>Model1(nn.Module):<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#b200b2;">__init__</span>(<span style="color:#94558d;">self</span>):<br /> <span style="color:#8888c6;">super</span>(Model1<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>).<span style="color:#b200b2;">__init__</span>()<br /> <span style="color:#94558d;">self</span>.dropout = nn.Dropout(<span style="color:#6897bb;">0.5</span>)<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#ffc66d;">forward</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>x):<br /> <span style="color:#cc7832;">return </span><span style="color:#94558d;">self</span>.dropout(x)<br /><br /><br /><span style="color:#cc7832;">class </span>Model2(nn.Module):<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#b200b2;">__init__</span>(<span style="color:#94558d;">self</span>):<br /> <span style="color:#8888c6;">super</span>(Model2<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>).<span style="color:#b200b2;">__init__</span>()<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#ffc66d;">forward</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>x):<br /> <span style="color:#cc7832;">return </span>F.dropout(x)<br /><br /><br />m1 = Model1()<br />m2 = Model2()<br />inputs = torch.rand(<span style="color:#6897bb;">10</span>)<br /><span style="color:#8888c6;">print</span>(m1(inputs))<br /><span style="color:#8888c6;">print</span>(m2(inputs))<br /><span style="color:#8888c6;">print</span>(<span style="color:#6897bb;">20 </span>* <span style="color:#6a8759;">'-' </span>+ <span style="color:#6a8759;">"eval model:" </span>+ <span style="color:#6897bb;">20 </span>* <span style="color:#6a8759;">'-' </span>+ <span style="color:#6a8759;">'</span><span style="color:#cc7832;"><br /></span><span style="color:#6a8759;">'</span>)<br />m1.eval()<br />m2.eval()<br /><span style="color:#8888c6;">print</span>(m1(inputs))<br /><span style="color:#8888c6;">print</span>(m2(inputs))
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span><br />
Output: (printed tensors omitted; after eval(), m1 returns its input unchanged, while m2's output still contains randomly zeroed entries.)

The output shows that even after m2 has called eval(), its dropout keeps working as usual. Of course, if you are determined to stick with nn.functional.dropout, you can patch it as follows.
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span>
<span style="color:#cc7832;">class </span>Model3(nn.Module):<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#b200b2;">__init__</span>(<span style="color:#94558d;">self</span>):<br /> <span style="color:#8888c6;">super</span>(Model3<span style="color:#cc7832;">, </span><span style="color:#94558d;">self</span>).<span style="color:#b200b2;">__init__</span>()<br /><br /> <span style="color:#cc7832;">def </span><span style="color:#ffc66d;">forward</span>(<span style="color:#94558d;">self</span><span style="color:#cc7832;">, </span>x):<br /> <span style="color:#cc7832;">return </span>F.dropout(x<span style="color:#cc7832;">, </span><span style="color:#aa4926;">training</span>=<span style="color:#94558d;">self</span>.training)
<span style="color: rgb(63, 63, 63);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.5px;text-align: start;background-color: rgb(255, 255, 255);"></span><br />
When should you use nn.functional.xxx, and when nn.Xxx?
That depends on the complexity of the problem you are solving and on personal style. When nn.Xxx cannot meet your functional needs, nn.functional.xxx is the better choice, because it is more flexible (closer to the lower levels) and you can build whatever you want on top of it.
Personally, I lean toward using nn.Xxx wherever it suffices and switching to nn.functional.xxx only when it does not. This makes the network's layer hierarchy more visible and feels purer (every layer and the model itself are all Modules, a harmonious, unified feel).