【pytorch笔记】损失函数nll_loss

热门标签

,
GentleCP

发表文章数:52

首页 » 技术杂谈 » Python » 正文

使用场景

在用pytorch做训练、测试时经常要用到损失函数计算输出与目标结果的差距,例如下面的代码:

# 训练
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
# 测试
for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    output = model(data)
    test_loss += F.nll_loss(output, target, reduction = 'sum')

前一部分是训练过程,计算输出outputtarget的误差回传,后一部分是测试过程,计算outputtarget误差,并进行误差求和。

函数理解

在该函数中重要的参数主要有三个,分别是:

  • input:(N,C),其中C表示分类的数量,N表示数据的条数,由于数据的输入是按batch输入,所以N也是batch的大小。
  • target:(N)目标结果,即常见分类任务中的label,包含N个。
  • reduction:对计算结果采取的操作,通常我们用sum(对N个误差结果求和),mean(对N个误差结果取平均),默认是对所有样本求loss均值

例子演示

采用官方提供的演示代码如下:

为了方便对中间变量的理解,我添加了一些打印的代码

input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
print('input:{}\n target:{}'.format(input,target))
print('log softmax:{}'.format(F.log_softmax(input,dim=1)))
output = F.nll_loss(F.log_softmax(input,dim=1), target)
print('output:{}'.format(output))
output.backward()

jupyter notebook打印一下中间结果:
【pytorch笔记】损失函数nll_loss
这里做了一次log softmax操作,softmax实际上就是对输入tensor中的元素按照数值计算了比例,dim=1保证所有分类概率和为1,最后对每个数值取了log。最终要求的就是log softmax后结果与target的误差。

我们重点关注这一步计算:

  • log softmax tensor(模型output):

    tensor([[-3.2056, -1.7804, -0.4350, -3.9833, -2.0795],
        [-2.1543, -1.8606, -1.5360, -1.1057, -1.7025],
        [-2.3243, -0.7615, -1.1595, -2.5594, -3.1195]]
  • target:

    tensor([1, 0, 4])

    标签代表了tensor中每一行向量应该检查的位置,例如第一个标签是1,这表示在tensor第一行中应该选择1号位置的元素-1.7804(代表了模型将数据分为1类的概率)取出,同理取第2行0号位置元素-2.1543,取第三行4号位置元素-3.1195,将它们去除负号求和再取均值。
    则该模型输出outputtarget之间误差应为:(1.7804+2.1543+3.1195)/3 = 2.3514

回顾上文的output结果2.35138....与预期相符。

  • reduction
    同样是上面的输入,我们添加reductionsum,查看output结果:
    【pytorch笔记】损失函数nll_loss
    发现计算结果是7.054...,说明没有执行前面 (1.7804+2.1543+3.1195)/3 = 2.3514求均值的操作。直接将各个样本与label之间的误差求和返回。

总结

nll_loss 函数接收两个tensor第一个是模型的output,第二个是label targetoutput中每一行与一个标签中每一列的元素对应,根据target的取值找出output行中对应位置元素,求和取平均值。

标签:

未经本人允许不得转载!作者:GentleCP, 转载或复制请以 超链接形式 并注明出处 求索
原文地址:《【pytorch笔记】损失函数nll_loss》 发布于2020-06-07

分享到:
赞(0)

评论 抢沙发

评论前必须登录!

  注册



Vieu4.5主题
专业打造轻量级个人企业风格博客主题!专注于前端开发,全站响应式布局自适应模板。
切换注册

登录

忘记密码 ?

您也可以使用第三方帐号快捷登录

Q Q 登 录
微 博 登 录
切换登录

注册