作者 | hyk_1996 来源 | CSDN博客
编译 | 大白
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;">model = model.cuda() <br />model.cuda() <br /></p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;">model = create_a_model()<br />tensor = torch.zeros([2,3,10,10])<br />model.cuda()<br />tensor.cuda()<br />model(tensor) <span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># 会报错</span><br />tensor = tensor.cuda()<br />model(tensor) <span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># 正常运行</span><br /></p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;">total_loss += loss.item()</p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># torch.device object used throughout this script</span><br />device = torch.device(<span class="hljs-string" style="font-size: inherit;line-height: inherit;color: rgb(238, 220, 112);overflow-wrap: inherit !important;word-break: inherit !important;">"cuda"</span> if use_cuda <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">else</span> <span class="hljs-string" style="font-size: inherit;line-height: inherit;color: rgb(238, 220, 112);overflow-wrap: inherit !important;word-break: inherit !important;">"cpu"</span>)<br />model = MyRNN().to(device)<br /><br /><span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># train</span><br />total_loss= 0<br />for input, target in train_loader:<br /> input, target = input.to(device), target.to(device)<br /> hidden = input.new_zeros(*h_shape) <span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># has the same device & dtype as `input`</span><br /> ... <span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># get loss and optimize</span><br /> total_loss += loss.item()<br /><br /><span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># test</span><br />with torch.no_grad(): <span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># operations inside don't track history</span><br /> for input, targetin test_loader:<br /> ...<br /></p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;">Returns a <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">new</span> Tensor, detached <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> the current graph.<br /> The result will never <span class="hljs-built_in" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">require</span> gradient.<br /></p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span class="hljs-attr" style="font-size: inherit;line-height: inherit;color: rgb(165, 218, 45);overflow-wrap: inherit !important;word-break: inherit !important;">input_B</span> = output_A.detach()<br /></p>
以CrossEntropyLoss为例:
<p style="margin-right: 8px;margin-left: 8px;padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;">CrossEntropyLoss(self, weight=<span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>, size_average=<span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>, ignore_index=<span class="hljs-number" style="font-size: inherit;line-height: inherit;color: rgb(174, 135, 250);overflow-wrap: inherit !important;word-break: inherit !important;">-100</span>, reduce=<span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>, reduction=<span class="hljs-string" style="font-size: inherit;line-height: inherit;color: rgb(238, 220, 112);overflow-wrap: inherit !important;word-break: inherit !important;">'elementwise_mean'</span>)<br /></p>
-
若 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss,即batch中每个元素对应的loss.
-
若 reduce = True,那么 loss 返回的是标量:
-
如果 size_average = True,返回 loss.mean().
-
如果 size_average = False,返回 loss.sum().
-
weight : 输入一个1D的权值向量,为各个类别的loss加权,如下公式所示:
-
ignore_index : 选择要忽视的目标值,使其对输入梯度不作贡献。如果 size_average = True,那么只计算不被忽视的目标的loss的均值。
-
reduction : 可选的参数有:‘none’ | ‘elementwise_mean’ | ‘sum’, 正如参数的字面意思,不解释。
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"> <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">self</span>.training <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">and</span> <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">self</span>.track_running_stats:<br /> <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">self</span>.num_batches_tracked += <span class="hljs-number" style="font-size: inherit;line-height: inherit;color: rgb(174, 135, 250);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /> <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">self</span>.momentum is None: <span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># use cumulative moving average</span><br /> exponential_average_factor = <span class="hljs-number" style="font-size: inherit;line-height: inherit;color: rgb(174, 135, 250);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span> / <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">self</span>.num_batches_tracked.item()<br /> <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>: <span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># use exponential moving average</span><br /> exponential_average_factor = <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">self</span>.momentum<br /></p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;">load_state_dict(torch.load(weight_path), <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">strict</span>=<span class="hljs-literal" style="font-size: inherit;line-height: inherit;color: rgb(174, 135, 250);overflow-wrap: inherit !important;word-break: inherit !important;">False</span>)<br /></p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> numpy <span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> np<br /><br /><span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># 判断输入数据是否存在nan</span><br /><span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> np.any(np.isnan(input.cpu().numpy())):<br /> <span class="hljs-built_in" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">print</span>(<span class="hljs-string" style="font-size: inherit;line-height: inherit;color: rgb(238, 220, 112);overflow-wrap: inherit !important;word-break: inherit !important;">'Input data has NaN!'</span>)<br /><br /><span class="hljs-comment" style="font-size: inherit;line-height: inherit;color: rgb(128, 128, 128);overflow-wrap: inherit !important;word-break: inherit !important;"># 判断损失是否为nan</span><br /><span class="hljs-keyword" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> np.isnan(loss.item()):<br /> <span class="hljs-built_in" style="font-size: inherit;line-height: inherit;color: rgb(248, 35, 117);overflow-wrap: inherit !important;word-break: inherit !important;">print</span>(<span class="hljs-string" style="font-size: inherit;line-height: inherit;color: rgb(238, 220, 112);overflow-wrap: inherit !important;word-break: inherit !important;">'Loss value is NaN!'</span>)<br /></p>
<p style="line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;color: rgb(169, 183, 198);background: rgb(40, 43, 46);padding: 0.5em;margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"> <span class="hljs-attribute" style="font-size: inherit;line-height: inherit;color: rgb(238, 220, 112);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span class="hljs-string" style="font-size: inherit;line-height: inherit;color: rgb(238, 220, 112);overflow-wrap: inherit !important;word-break: inherit !important;">'Expected more than 1 value per channel when training, got input size {}'</span>.format(size))<br /></p>
<pre style="letter-spacing: 0.544px;"><section style="margin-right: 8px;margin-left: 8px;white-space: normal;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;text-align: center;widows: 1;line-height: 1.75em;"><strong><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong>完<strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong></span></strong></span></strong></section><section style="white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;text-align: center;widows: 1;color: rgb(255, 97, 149);"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="letter-spacing: 0.544px;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="margin-right: 8px;margin-bottom: 15px;margin-left: 8px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 25.5938px;letter-spacing: 3px;"><span style="color: rgb(0, 0, 0);"><strong><span style="font-size: 16px;font-family: 微软雅黑;caret-color: red;">为您推荐</span></strong></span></section><section style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;min-height: 1em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">Deep Learning 调参经验<br /></section><section style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;min-height: 1em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">最新:网传饶毅举报多位学者论文造假?国自然基金委和饶毅都回应了<br /></section><section style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;min-height: 1em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">程序员真的是太太太太太太太太难了!<br /></section><section style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;min-height: 1em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">深度学习必懂的13种概率分布<br /></section><section style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;min-height: 1em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">【微软】AI-神经网络基本原理简明教程</section></section></section></section></section></section></section></section></section>
本篇文章来源于: 深度学习这件小事
本文为原创文章,版权归知行编程网所有,欢迎分享本文,转载请保留出处!
内容反馈