torch.nn.CrossEntropyLoss

Reason is the light and the light of life.

Jerry Su Jan 04, 2021 2 mins

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss

import torch
from torch.nn import CrossEntropyLoss

# Toy dimensions for the demo. The size of the last axis is used as the
# number of classes, so EMBEDDING_SIZE and VOCAB_SIZE are the same value.
BATCH_SIZE, MAX_SEQ_LENGTH = 2, 3
VOCAB_SIZE = 5
EMBEDDING_SIZE = VOCAB_SIZE
# Despite the name, this holds random float scores drawn from a standard
# normal distribution, shaped (BATCH_SIZE, MAX_SEQ_LENGTH, EMBEDDING_SIZE).
input_ids = torch.randn(BATCH_SIZE, MAX_SEQ_LENGTH, EMBEDDING_SIZE)
input_ids
tensor([[[-1.2207e-01,  9.4908e-01,  9.6387e-02,  1.4286e+00,  1.9476e+00],
         [ 6.7663e-01,  8.2334e-04, -2.6130e+00,  3.3753e-02, -3.8190e-01],
         [-2.3500e-01,  1.7386e-01, -9.2278e-01,  6.7210e-01, -5.9908e-01]],

        [[-7.6696e-01,  1.0766e+00, -3.9904e-01,  1.3112e-01,  1.4053e-02],
         [ 1.1152e+00,  1.4323e+00,  8.4845e-01,  8.5321e-01, -2.1357e-01],
         [ 5.5852e-01,  1.7036e-01,  4.6033e-01,  1.2075e+00, -9.5198e-01]]])
# Integer class-index targets, one per (batch, position) slot.
# torch.randint's `high` is exclusive, so high=VOCAB_SIZE draws from the full
# valid class range [0, VOCAB_SIZE - 1]; a hard-coded high=4 could never
# produce the last class (4).
labels = torch.randint(low=0, high=VOCAB_SIZE, size=(BATCH_SIZE, MAX_SEQ_LENGTH))
labels
tensor([[2, 0, 2],
        [2, 0, 3]])
loss_fct = CrossEntropyLoss()
# Expected shapes (excerpt from the PyTorch CrossEntropyLoss docs):
#   - Input:  (N, C) where C = number of classes, or
#             (N, C, d_1, d_2, ..., d_K) with K >= 1 in the case of
#             K-dimensional loss.
#   - Target: (N) where each value satisfies 0 <= targets[i] <= C-1, or
#             (N, d_1, d_2, ..., d_K) with K >= 1 in the case of
#             K-dimensional loss.
# NOTE: `input` holds unnormalized per-class scores (logits). It is not a
# probability distribution, and no argmax is taken. CrossEntropyLoss applies
# log_softmax over the class dimension (dim 1) and then the negative
# log-likelihood of each target. `target` holds integer class indices.
# Flattening gives input (BATCH_SIZE * MAX_SEQ_LENGTH, C) and
# target (BATCH_SIZE * MAX_SEQ_LENGTH,).
loss_fct(input=input_ids.view(-1, EMBEDDING_SIZE), target=labels.view(-1))
tensor(1.7686)
# EMBEDDING_SIZE is the number of classes C (here equal to the vocabulary
# size VOCAB_SIZE). Every label must be a class index in [0, C-1]. Drawing
# from [0, 2 * VOCAB_SIZE) = [0, 10) deliberately produces some labels >= C,
# which the next loss call rejects.
# NOTE: there are 6 draws, each in range with probability 1/2, so about 1 run
# in 64 draws only valid labels and raises no error.
labels = torch.randint(low=0, high=2 * VOCAB_SIZE, size=(BATCH_SIZE, MAX_SEQ_LENGTH))
labels
tensor([[6, 2, 5],
        [7, 7, 1]])
# Out of bounds: any target >= C (= EMBEDDING_SIZE) makes this raise IndexError.
loss_fct(input=input_ids.view(-1, EMBEDDING_SIZE), target=labels.view(-1))
---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-37-cbf2ac1ed94f> in <module>
      1 # 越界
----> 2 loss_fct(input=input_ids.view(-1, EMBEDDING_SIZE), target=labels.view(-1))


/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),


/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    959 
    960     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961         return F.cross_entropy(input, target, weight=self.weight,
    962                                ignore_index=self.ignore_index, reduction=self.reduction)
    963


/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2469 
   2470


/opt/conda/envs/blog/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
-> 2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2265     elif dim == 4:
   2266         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)


IndexError: Target 6 is out of bounds.


Read more:

Related posts: