知行编程网, 2022-04-08 16:00

In PyTorch, what is the difference between nn and nn.functional?

Editor: 忆臻

https://www.zhihu.com/question/66782101

This article is shared for academic purposes only and will be removed in case of copyright infringement.

Reported by 机器学习算法与自然语言处理 (Machine Learning Algorithms and Natural Language Processing)




Author: 蒲嘉宸
https://www.zhihu.com/question/66782101/answer/246341271

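The two are actually very similar: one is a packaged class, the other is a function you can call directly. We can look at how each is implemented; below I use the 1D convolution Conv1d as an example.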

First, Conv1d in torch.nn:

```python
# Excerpt from torch.nn (torch/nn/modules/conv.py) in an early PyTorch release
class Conv1d(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        super(Conv1d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias)

    def forward(self, input):
        # The module only stores state; the computation is delegated to F.conv1d
        return F.conv1d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```

And this is conv1d in torch.nn.functional:

```python
# Excerpt from torch.nn.functional in the same early PyTorch release;
# ConvNd is an internal backend function, not a public API
def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
           groups=1):
    if input is not None and input.dim() != 3:
        raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))

    f = ConvNd(_single(stride), _single(padding), _single(dilation), False,
               _single(0), groups, torch.backends.cudnn.benchmark,
               torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
    return f(input, weight, bias)
```

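As you can see, the Conv1d class in torch.nn simply calls conv1d from nn.functional in its forward method, and the actual computation is ultimately carried out by ConvNd in the C/C++ backend (built on the THNN library at the time). So the class is a thin wrapper around the function rather than a separate implementation.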
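The wrapper relationship can be checked directly. The sketch below assumes a current PyTorch release (the internals differ from the excerpts above, but the public `nn.Conv1d` and `F.conv1d` APIs are the same): it calls `F.conv1d` with an `nn.Conv1d` instance's own weight, bias and hyperparameters and compares the two results. The tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3, stride=1, padding=1)
x = torch.randn(2, 4, 16)  # (batch, channels, length)

# Module call: weight, bias and hyperparameters are stored on the instance
out_module = conv(x)

# Functional call: the same weight, bias and hyperparameters passed explicitly
out_functional = F.conv1d(x, conv.weight, conv.bias, stride=conv.stride,
                          padding=conv.padding, dilation=conv.dilation,
                          groups=conv.groups)

print(out_module.shape)                            # torch.Size([2, 8, 16])
print(torch.allclose(out_module, out_functional))  # True
```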

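You may wonder why PyTorch needs two modules with such similar functionality. There is a reason for this design. If only the functions in nn.functional existed, we would have to maintain intermediate values such as weight, bias and stride by hand during training and inference, which is clearly inconvenient. If only the classes in nn existed, we would lose some flexibility, because even a simple computation would require creating a class, which does not fit PyTorch's style.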
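To make the bookkeeping point concrete, here is a minimal sketch (the class name `ManualLinear` is hypothetical). If the tensors passed to `F.linear` are stored as plain attributes, `nn.Module` does not register them, so `parameters()` (and therefore any optimizer) and `state_dict()` never see them; `nn.Linear` registers its weight and bias automatically.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ManualLinear(nn.Module):
    """Hypothetical layer that keeps its own tensors for F.linear."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Plain tensors are NOT registered as parameters of the module
        self.weight = torch.randn(out_features, in_features, requires_grad=True)
        self.bias = torch.zeros(out_features, requires_grad=True)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)


manual = ManualLinear(4, 2)
wrapped = nn.Linear(4, 2)

print(len(list(manual.parameters())))   # 0: an optimizer would never update them
print(len(list(wrapped.parameters())))  # 2: weight and bias
print(list(manual.state_dict()))        # []: they would not be saved either
print(sorted(wrapped.state_dict()))     # ['bias', 'weight']
```

Wrapping the tensors in `nn.Parameter` fixes the registration; `nn.Linear` simply does that for you.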

Here is an example to help illustrate this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Only the linear layers hold state (weights and biases)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        # Stateless ops are called as functions; F.dropout needs the
        # module's training flag, otherwise it stays active after eval()
        x = F.relu(self.fc1(x))
        x = F.relu(F.dropout(self.fc2(x), 0.5, training=self.training))
        x = F.dropout(self.fc3(x), 0.5, training=self.training)
        return x
```

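Take the simplest three-layer network as an example. The only components that need to keep state are the three linear transforms, so when constructing the Module we define three nn.Linear objects. Stateless operations such as relu and dropout can be called directly during the forward computation.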

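Note: dropout has a pitfall. F.dropout does not follow the module's train()/eval() state, so you must pass the training state yourself, e.g. training=self.training.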
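A quick check of both points, as a sketch that redefines the three-layer `Net` with `training=self.training` passed to `F.dropout`: only the three `nn.Linear` layers contribute parameters (512·256 + 256 + 256·256 + 256 + 256·2 + 2 = 197634), and in eval mode two forward passes on the same input agree because dropout is disabled.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        # Pass the module's flag so model.eval() turns dropout off
        x = F.relu(F.dropout(self.fc2(x), 0.5, training=self.training))
        x = F.dropout(self.fc3(x), 0.5, training=self.training)
        return x


net = Net()
# relu and dropout own no state; only the Linear layers register parameters
names = [name for name, _ in net.named_parameters()]
print(names)  # ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias']
print(sum(p.numel() for p in net.parameters()))  # 197634

net.eval()
x = torch.randn(8, 512)
print(torch.equal(net(x), net(x)))  # True: no randomness in eval mode
```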



Author: 有糖吃可好
https://www.zhihu.com/question/66782101/answer/579393790

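What the two have in common:

  • nn.Xxx and nn.functional.xxx perform the same operation: nn.Conv2d and nn.functional.conv2d both compute a convolution, and nn.Dropout and nn.functional.dropout both apply dropout.

  • Their runtime efficiency is also nearly identical.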


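nn.functional.xxx is a function interface, while nn.Xxx is a class wrapper around nn.functional.xxx, and every nn.Xxx class inherits from a common ancestor, nn.Module. As a result, in addition to the functionality of nn.functional.xxx, nn.Xxx also carries the attributes and methods of nn.Module, such as train(), eval(), load_state_dict() and state_dict().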
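A short sketch of what the nn.Module inheritance adds, using `nn.Conv2d` with arbitrary sizes: the layer exposes its state through `state_dict()`, tracks a `training` flag that `eval()` switches off, and can be restored from another instance with `load_state_dict()`.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)
print(sorted(conv.state_dict()))  # ['bias', 'weight']

print(conv.training)  # True: modules start in training mode
conv.eval()
print(conv.training)  # False

# load_state_dict copies the saved tensors into a fresh module of the same shape
clone = nn.Conv2d(3, 8, kernel_size=3)
clone.load_state_dict(conv.state_dict())
print(torch.equal(clone.weight, conv.weight))  # True
```

A bare function such as `nn.functional.conv2d` has none of these hooks; its caller owns the tensors.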

How the two differ:


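  • They are called differently.

nn.Xxx must first be instantiated with its configuration arguments; the instance is then called like a function on the input data.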

```python
import torch
import torch.nn as nn

inputs = torch.rand(64, 3, 244, 244)  # (batch, channels, height, width)
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
out = conv(inputs)
```


nn.functional.xxx takes the input data together with weight, bias and the other arguments in a single call.

```python
# Reuses `inputs` from the snippet above
weight = torch.rand(64, 3, 3, 3)  # (out_channels, in_channels, kH, kW)
bias = torch.rand(64)
out = nn.functional.conv2d(inputs, weight, bias, padding=1)
```

  • nn.Xxx inherits from nn.Module, so it composes naturally with nn.Sequential, whereas nn.functional.xxx cannot be placed in an nn.Sequential.

```python
fm_layer = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_features=64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Dropout(0.2)
)
```

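A sketch of the nn.Sequential point, assuming a 32×32 RGB input chosen only for illustration: the block above runs end to end (the padded convolution keeps 32×32 and the 2×2 max pool halves it), while passing a bare function such as `F.relu` to `nn.Sequential` is rejected with a `TypeError` because it is not an `nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

fm_layer = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_features=64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Dropout(0.2),
)
out = fm_layer(torch.rand(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 64, 16, 16])

# A plain function is not an nn.Module, so nn.Sequential refuses it
try:
    nn.Sequential(F.relu)
    rejected = False
except TypeError:
    rejected = True
print(rejected)  # True
```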
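  • With nn.Xxx you do not have to define and manage the weights yourself; with nn.functional.xxx you must define the weights yourself and pass them in manually on every call, which makes the code harder to reuse.

Defining a CNN with nn.Xxx: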

```python
import torch.nn as nn


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=0)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, padding=0)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.linear1 = nn.Linear(4 * 4 * 32, 10)

    def forward(self, x):
        # x: (N, 1, 28, 28), the input size implied by 4 * 4 * 32; do not flatten before the convs
        out = self.maxpool1(self.relu1(self.cnn1(x)))    # -> (N, 16, 12, 12)
        out = self.maxpool2(self.relu2(self.cnn2(out)))  # -> (N, 32, 4, 4)
        out = self.linear1(out.view(x.size(0), -1))      # -> (N, 10)
        return out
```


Defining the same CNN with nn.functional.xxx:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        # torch.rand init differs from the nn layers' default initialization
        self.cnn1_weight = nn.Parameter(torch.rand(16, 1, 5, 5))
        self.bias1_weight = nn.Parameter(torch.rand(16))

        self.cnn2_weight = nn.Parameter(torch.rand(32, 16, 5, 5))
        self.bias2_weight = nn.Parameter(torch.rand(32))

        # F.linear expects a weight of shape (out_features, in_features)
        self.linear1_weight = nn.Parameter(torch.rand(10, 4 * 4 * 32))
        self.bias3_weight = nn.Parameter(torch.rand(10))

    def forward(self, x):
        out = F.conv2d(x, self.cnn1_weight, self.bias1_weight)
        out = F.relu(out)
        out = F.max_pool2d(out, kernel_size=2)

        out = F.conv2d(out, self.cnn2_weight, self.bias2_weight)
        out = F.relu(out)
        out = F.max_pool2d(out, kernel_size=2)

        out = F.linear(out.view(x.size(0), -1), self.linear1_weight, self.bias3_weight)
        return out
```

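The two CNN definitions can be checked against each other. The sketch below uses hypothetical class names `ModuleCNN` and `FunctionalCNN` (same layers as the article's two CNNs, with the `nn.Xxx` layers grouped in an `nn.Sequential` for brevity) and a hypothetical helper `copy_weights`. It copies the module version's parameters into the functional version and confirms that both produce the same output on a 28×28 single-channel batch and have the same 18378 parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModuleCNN(nn.Module):
    """nn.Xxx version: layers own their parameters."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.linear1 = nn.Linear(4 * 4 * 32, 10)

    def forward(self, x):
        return self.linear1(self.features(x).flatten(1))


class FunctionalCNN(nn.Module):
    """nn.functional version: parameters declared by hand."""
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(16, 1, 5, 5) * 0.1)
        self.b1 = nn.Parameter(torch.zeros(16))
        self.w2 = nn.Parameter(torch.randn(32, 16, 5, 5) * 0.1)
        self.b2 = nn.Parameter(torch.zeros(32))
        self.w3 = nn.Parameter(torch.randn(10, 4 * 4 * 32) * 0.1)
        self.b3 = nn.Parameter(torch.zeros(10))

    def forward(self, x):
        out = F.max_pool2d(F.relu(F.conv2d(x, self.w1, self.b1)), 2)
        out = F.max_pool2d(F.relu(F.conv2d(out, self.w2, self.b2)), 2)
        return F.linear(out.flatten(1), self.w3, self.b3)


def copy_weights(src, dst):
    """Hypothetical helper: copy ModuleCNN parameters into FunctionalCNN."""
    with torch.no_grad():
        dst.w1.copy_(src.features[0].weight)
        dst.b1.copy_(src.features[0].bias)
        dst.w2.copy_(src.features[3].weight)
        dst.b2.copy_(src.features[3].bias)
        dst.w3.copy_(src.linear1.weight)
        dst.b3.copy_(src.linear1.bias)


m, f = ModuleCNN(), FunctionalCNN()
copy_weights(m, f)
x = torch.rand(8, 1, 28, 28)
print(m(x).shape)                              # torch.Size([8, 10])
print(torch.allclose(m(x), f(x), atol=1e-6))   # True
print(sum(p.numel() for p in m.parameters()))  # 18378
```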

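Both definitions above produce CNNs with the same structure, and which style you prefer is a matter of taste. PyTorch's official recommendation, however, is to use nn.Xxx for operations with learnable parameters (e.g. conv2d, linear, batch_norm), and to choose freely between nn.functional.xxx and nn.Xxx for operations without learnable parameters (e.g. max pooling, loss functions, activation functions). For dropout, though, I strongly recommend nn.Xxx, because dropout is normally applied only during training and not during evaluation. If dropout is defined with nn.Dropout, calling model.eval() switches off every dropout layer in the model; if it is written as nn.functional.dropout with default arguments, calling model.eval() does not switch it off, because the function's training argument defaults to True.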

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model1(nn.Module):

    def __init__(self):
        super(Model1, self).__init__()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(x)


class Model2(nn.Module):

    def __init__(self):
        super(Model2, self).__init__()

    def forward(self, x):
        # training defaults to True, so this ignores model.eval()
        return F.dropout(x)


m1 = Model1()
m2 = Model2()
inputs = torch.rand(10)
print(m1(inputs))
print(m2(inputs))
print(20 * '-' + "eval model:" + 20 * '-' + '\n')
m1.eval()
m2.eval()
print(m1(inputs))
print(m2(inputs))
```


Output:

[Screenshot of the console output; image not preserved]


The output shows that even after m2.eval() is called, its dropout is still active. If you are determined to keep using nn.functional.dropout, you can fix this as follows:

```python
class Model3(nn.Module):

    def __init__(self):
        super(Model3, self).__init__()

    def forward(self, x):
        # self.training is toggled by model.train() / model.eval()
        return F.dropout(x, training=self.training)
```

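A sketch that checks the Model3 fix (here `Model3` omits the empty `__init__` and relies on nn.Module's default constructor, and `p=0.5` is written out explicitly). In training mode each element of an all-ones input is either zeroed or scaled by 1 / (1 - p) = 2; after `eval()` the module returns its input unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model3(nn.Module):
    def forward(self, x):
        return F.dropout(x, p=0.5, training=self.training)


m3 = Model3()
x = torch.ones(1000)

m3.train()
y_train = m3(x)
# Training mode: every element is either 0 or scaled by 1 / (1 - 0.5) = 2
print(set(y_train.unique().tolist()) <= {0.0, 2.0})  # True

m3.eval()
print(torch.equal(m3(x), x))  # True: dropout is the identity in eval mode
```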

When should you use nn.functional.xxx, and when nn.Xxx?

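It depends on the complexity of the problem you are solving and on your personal style. When nn.Xxx cannot provide what you need, nn.functional.xxx is the better choice: it is more flexible (closer to the underlying operations), and you can build exactly the functionality you want on top of it.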
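One common case where the functional form is the more flexible choice is weight sharing. The sketch below is a hypothetical tied-weight autoencoder (class name and sizes are illustrative): `nn.Linear` has no option for reusing another layer's weight in transposed form, but `F.linear` accepts any weight tensor of shape (out_features, in_features), so the decoder can use `encoder.weight.t()` directly and only needs its own bias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedAutoencoder(nn.Module):
    """Hypothetical example: the decoder reuses the encoder weight, transposed."""

    def __init__(self, n_in=784, n_hidden=64):
        super().__init__()
        self.encoder = nn.Linear(n_in, n_hidden)  # weight shape: (n_hidden, n_in)
        self.decoder_bias = nn.Parameter(torch.zeros(n_in))

    def forward(self, x):
        h = F.relu(self.encoder(x))
        # encoder.weight.t() has shape (n_in, n_hidden) = (out, in) for the decoder
        return F.linear(h, self.encoder.weight.t(), self.decoder_bias)


model = TiedAutoencoder()
x = torch.rand(4, 784)
print(model(x).shape)                              # torch.Size([4, 784])
print(sum(p.numel() for p in model.parameters()))  # 51024: one shared weight matrix
```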

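Personally, I prefer to use nn.Xxx wherever possible and fall back to nn.functional.xxx only when it cannot do the job. This makes the layered structure of the network more visible and keeps the code more uniform: every layer, and the model itself, is a Module.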
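The official advice to use nn.Xxx for layers with state is easiest to see with batch normalization, which keeps running statistics as well as learnable parameters. The sketch below (channel count, input shapes and the three warm-up passes are arbitrary) shows that `nn.BatchNorm2d` stores everything in its `state_dict`, while the equivalent `F.batch_norm` call in eval mode needs every piece of that state, plus the training flag, passed in by hand.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
# Buffers (running statistics) live in the module alongside the parameters
print(sorted(bn.state_dict()))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']

# A few training-mode passes update the running statistics
for _ in range(3):
    bn(torch.randn(8, 4, 5, 5) * 3 + 1)

bn.eval()
x = torch.randn(2, 4, 5, 5)
y_module = bn(x)
# The functional form needs all of that state, and the mode, supplied manually
y_functional = F.batch_norm(x, bn.running_mean, bn.running_var,
                            weight=bn.weight, bias=bn.bias,
                            training=False, eps=bn.eps)
print(torch.allclose(y_module, y_functional))  # True
```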


(End)


This article originally appeared in: 深度学习这件小事
