In [1]:
import torch
from torch import nn
In [2]:
# BMES (Begin/Middle/End/Single) four-tag sequence labeling scheme
PADDING = 0
B = 2
E = 3
S = 4
M = 5
START = 6
END = 7

LABEL_VOCAB = {0: '<pad>', 1: '<unk>', 2: 'B', 3: 'E', 4: 'S', 5: 'M'}

NUM_TAGS = 8
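For intuition, here is a minimal sketch (not from the original notebook) of how a segmented sentence maps onto these tags: B/M/E mark the first, middle, and last character of a multi-character word, and S marks a single-character word.

def words_to_bmes(words):
    # 'B' + zero or more 'M' + 'E' for a multi-character word, 'S' for a singleton
    tags = []
    for w in words:
        if len(w) == 1:
            tags.append('S')
        else:
            tags.extend(['B'] + ['M'] * (len(w) - 2) + ['E'])
    return tags

print(words_to_bmes(['中国', '人']))  # ['B', 'E', 'S']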

1.1 Initialize Input

In [32]:
# batch_size x max_len x num_tags : (3, 9, 8)

logits = torch.randn(3, 9, NUM_TAGS)
print(f"logits:\n {logits}\n")
tags = torch.tensor([[6, 4, 2, 5, 5, 3, 4, 7, 0],
                     [6, 4, 2, 3, 4, 4, 7, 0, 0],
                     [6, 2, 5, 3, 4, 7, 0, 0, 0]])
print(f"tags:\n {tags}\n")
mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1, 0, 0, 0]])
print(f"mask:\n {mask}")
logits:
 tensor([[[-1.4452e-01, -1.9019e-01,  9.7182e-01, -1.5851e+00, -1.3361e+00,
          -1.4078e+00,  5.9501e-01,  1.1116e+00],
         [-1.2016e+00,  5.7204e-01, -1.7459e-01, -1.2101e+00, -1.5633e+00,
          -1.5958e+00,  7.4246e-01, -2.2454e-01],
         [-1.5298e-01,  7.3708e-01,  3.7166e-01,  2.4409e-01, -7.6002e-01,
          -1.1487e+00, -3.5016e-01,  1.7178e-01],
         [ 5.3908e-01,  1.7665e-01,  8.4331e-02,  1.2331e+00, -6.3707e-01,
           3.2750e-01, -9.5666e-01, -1.0764e+00],
         [-2.0275e-01, -4.7478e-01, -2.4096e-01,  3.4847e-01,  1.4107e+00,
          -6.7662e-01,  1.1356e+00, -8.8798e-01],
         [-5.2993e-01, -8.7384e-01, -7.1909e-01, -9.0088e-01, -1.0477e+00,
           5.7400e-01, -8.9259e-02, -9.4986e-01],
         [ 5.1900e-01, -1.3026e+00, -1.7043e+00, -2.8520e-01, -1.4247e+00,
           5.4460e-02, -6.3961e-01,  1.3025e-01],
         [ 1.2510e+00,  2.2883e-01,  3.4238e-01, -7.5308e-01,  2.5237e-03,
          -3.4200e-01, -2.2455e-01, -1.0249e+00],
         [ 4.8784e-01, -1.5034e+00, -1.6049e-01,  5.3555e-01,  2.2210e-03,
          -5.5209e-01,  2.6669e-01, -1.2266e-01]],

        [[ 4.8289e-01,  8.0970e-01, -2.6692e-01, -1.0974e+00,  2.5424e-01,
           1.4737e-01,  7.8376e-01, -1.4346e+00],
         [-2.6347e-01, -1.5178e+00,  1.5867e+00,  7.5319e-01, -2.9615e-01,
           6.0084e-01, -5.5508e-01, -1.8708e-01],
         [ 2.3153e+00, -1.7178e-01, -6.4674e-01,  1.1492e+00,  7.9638e-01,
           5.4677e-01,  1.6118e-01,  2.1595e-01],
         [ 1.6032e+00,  1.0249e+00, -1.1967e+00,  7.7639e-01,  1.0185e+00,
          -9.3879e-01, -1.4424e+00,  1.3888e+00],
         [-1.0615e-01,  4.5697e-01,  1.2877e-01,  1.3390e+00, -1.9707e+00,
          -5.8355e-01, -2.2570e+00, -1.1907e+00],
         [-6.4982e-01,  1.6911e-01, -3.0681e-01, -5.5137e-01, -6.9925e-01,
           1.9890e-01,  4.8145e-01,  5.3715e-01],
         [ 5.2184e-01,  5.8037e-01,  8.7149e-01, -2.0938e+00, -4.8395e-01,
          -2.0589e+00, -2.3386e+00,  7.5332e-02],
         [ 3.7932e-02,  1.3092e+00, -5.4694e-01,  1.6446e-01, -1.7779e+00,
          -6.6966e-01, -8.4106e-01,  9.3973e-02],
         [ 6.5991e-01,  6.3159e-01,  1.3538e+00, -2.7384e-01,  8.5952e-01,
          -3.7105e-01, -1.3350e-01, -2.1770e+00]],

        [[-5.1107e-01, -7.3291e-01,  8.8863e-01,  8.5765e-02,  2.0189e-01,
          -6.3774e-01,  1.1234e-01,  1.0721e+00],
         [-1.7006e-01, -2.5931e-01,  1.1782e+00, -2.6335e-01,  1.4209e+00,
           4.8150e-01,  1.1860e+00,  7.5185e-01],
         [-1.4374e+00,  1.0826e+00,  2.8144e-02,  6.7766e-01, -1.9165e-01,
           1.2558e-01,  1.2113e+00,  5.0190e-02],
         [ 3.9188e-01,  1.5474e+00, -1.5797e+00,  6.2201e-01, -3.8201e-01,
          -2.2004e+00, -1.2397e+00,  1.3466e+00],
         [-1.6803e+00,  1.5669e+00,  5.5389e-01, -8.0821e-01,  1.8370e+00,
          -2.8107e-02,  8.2856e-01,  7.0409e-01],
         [ 1.3986e+00, -1.1884e+00,  4.2444e-01,  1.3998e+00, -7.3532e-01,
          -9.4360e-01, -1.4932e-01, -1.2199e+00],
         [ 2.2620e-01, -3.7655e-01, -3.5929e-01, -1.8760e+00, -2.0527e+00,
           1.6695e-01,  2.2314e-01,  2.8254e-01],
         [ 1.6245e+00,  7.8641e-01,  1.4010e+00, -7.4482e-02, -1.4197e+00,
          -1.3425e+00, -8.8786e-01,  9.7691e-01],
         [ 3.7015e-01,  1.5858e+00, -4.2269e-01,  2.0528e+00,  3.1429e-01,
          -8.5222e-01, -5.5130e-01, -9.6802e-02]]])

tags:
 tensor([[6, 4, 2, 5, 5, 3, 4, 7, 0],
        [6, 4, 2, 3, 4, 4, 7, 0, 0],
        [6, 2, 5, 3, 4, 7, 0, 0, 0]])

mask:
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0]])
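The mask is written out by hand here, but it could equally be derived from the padded tags, since PADDING (0) only appears at padded positions. A quick sketch:

derived_mask = tags.ne(PADDING).long()
assert torch.equal(derived_mask, mask)  # matches the hand-written 0/1 mask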
In [33]:
logits = logits.transpose(0, 1)
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
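The transposes switch everything from batch-first to time-first layout, so each step of the CRF computation can slice out one time step across the whole batch. A quick sanity check:

# logits: (seq_len, batch_size, num_tags) = (9, 3, 8)
# tags, mask: (seq_len, batch_size) = (9, 3)
assert logits.shape == (9, 3, NUM_TAGS)
assert tags.shape == mask.shape == (9, 3)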
In [34]:
trans_matrix = nn.Parameter(torch.randn(NUM_TAGS, NUM_TAGS))
trans_matrix
Out[34]:
Parameter containing:
tensor([[ 0.3130,  0.6805, -1.6291,  1.1787,  0.2590,  0.0393,  0.3733, -1.0978],
        [ 0.0257, -0.7165,  0.1774, -0.4137, -0.9312,  1.1543, -1.7091,  0.1694],
        [-0.5890, -1.3841, -0.7339,  1.5916, -0.6873, -1.2056,  1.3192,  0.1930],
        [-1.0862, -0.3838,  0.3116, -1.0895, -0.5821, -1.2720,  0.2369,  0.5035],
        [-1.1792,  0.2191, -0.5459,  1.8000, -0.0737,  1.6784, -0.8590, -0.3808],
        [-0.1573, -1.7135, -0.2278,  1.8250, -0.4302,  2.0009, -1.1343,  0.4233],
        [-0.3599, -0.0824, -0.4446,  0.4188, -0.7154, -0.1829,  0.0595, -1.1767],
        [-1.2007,  2.5477,  0.0693, -0.9544, -0.8122,  0.1949, -0.0823,  0.3554]],
       requires_grad=True)
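Following the indexing used in section 2.1 below, trans_matrix[i, j] holds the unnormalized score of transitioning from tag i to tag j. For example, START (6) -> S (4), the first transition of the first two gold sequences:

# trans_matrix[from_tag, to_tag]
print(trans_matrix[START, S])  # tensor(-0.7154, ...)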

2. Compute the score for the gold path.

In [35]:
seq_len, batch_size, _ = logits.size()
print(f"seq_len: {seq_len}\nbatch_size: {batch_size}")
seq_len: 9
batch_size: 3
In [36]:
batch_idx = torch.arange(batch_size, dtype=torch.long)
batch_idx
Out[36]:
tensor([0, 1, 2])
In [37]:
seq_idx = torch.arange(seq_len, dtype=torch.long)
seq_idx
Out[37]:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

2.1 transition probability score

In [38]:
mask
Out[38]:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 0.],
        [1., 0., 0.],
        [0., 0., 0.]])
In [39]:
mask = mask.eq(True)    # convert the 0/1 float mask to a bool mask
mask
Out[39]:
tensor([[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True, False],
        [ True, False, False],
        [False, False, False]])
In [40]:
flip_mask = mask.eq(False)    # True at padded positions
flip_mask
Out[40]:
tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])
In [41]:
tags
Out[41]:
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
In [42]:
tags[:seq_len - 1]
Out[42]:
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0]])
In [43]:
tags[1:]
Out[43]:
tensor([[4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
In [44]:
trans_matrix
Out[44]:
Parameter containing:
tensor([[ 0.3130,  0.6805, -1.6291,  1.1787,  0.2590,  0.0393,  0.3733, -1.0978],
        [ 0.0257, -0.7165,  0.1774, -0.4137, -0.9312,  1.1543, -1.7091,  0.1694],
        [-0.5890, -1.3841, -0.7339,  1.5916, -0.6873, -1.2056,  1.3192,  0.1930],
        [-1.0862, -0.3838,  0.3116, -1.0895, -0.5821, -1.2720,  0.2369,  0.5035],
        [-1.1792,  0.2191, -0.5459,  1.8000, -0.0737,  1.6784, -0.8590, -0.3808],
        [-0.1573, -1.7135, -0.2278,  1.8250, -0.4302,  2.0009, -1.1343,  0.4233],
        [-0.3599, -0.0824, -0.4446,  0.4188, -0.7154, -0.1829,  0.0595, -1.1767],
        [-1.2007,  2.5477,  0.0693, -0.9544, -0.8122,  0.1949, -0.0823,  0.3554]],
       requires_grad=True)
In [45]:
tags
Out[45]:
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
In [46]:
trans_score = trans_matrix[tags[:seq_len - 1], tags[1:]]    # 2-D advanced indexing into the transition matrix
trans_score
Out[46]:
tensor([[-0.7154, -0.7154, -0.4446],
        [-0.5459, -0.5459, -1.2056],
        [-1.2056,  1.5916,  1.8250],
        [ 2.0009, -0.5821, -0.5821],
        [ 1.8250, -0.0737, -0.3808],
        [-0.5821, -0.3808, -1.2007],
        [-0.3808, -1.2007,  0.3130],
        [-1.2007,  0.3130,  0.3130]], grad_fn=<IndexBackward>)
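The advanced indexing above looks up every adjacent-tag transition in one shot; an equivalent explicit loop (a sketch, for clarity only) would be:

# trans_score[t, b] = trans_matrix[tags[t, b], tags[t + 1, b]]
manual = torch.empty(seq_len - 1, batch_size)
for t in range(seq_len - 1):
    for b in range(batch_size):
        manual[t, b] = trans_matrix[tags[t, b], tags[t + 1, b]]
assert torch.allclose(manual, trans_score)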
In [47]:
flip_mask
Out[47]:
tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])
In [48]:
flip_mask[1:, :]
Out[48]:
tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])
In [49]:
# The final transition score matrix for the gold path!
trans_score = trans_score.masked_fill(flip_mask[1:, :], 0)
trans_score
Out[49]:
tensor([[-0.7154, -0.7154, -0.4446],
        [-0.5459, -0.5459, -1.2056],
        [-1.2056,  1.5916,  1.8250],
        [ 2.0009, -0.5821, -0.5821],
        [ 1.8250, -0.0737, -0.3808],
        [-0.5821, -0.3808,  0.0000],
        [-0.3808,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]], grad_fn=<MaskedFillBackward0>)
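With transitions into padded positions zeroed out, summing over the time dimension gives each sequence's total gold transition score (a sketch):

total_trans = trans_score.sum(dim=0)  # shape: (batch_size,)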

2.2 emission probability score

In [50]:
# Preview: the full emission-score computation, decomposed step by step below
# emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
# emit_score
In [51]:
logits.size()
Out[51]:
torch.Size([9, 3, 8])
In [52]:
seq_idx
Out[52]:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
In [53]:
seq_idx.size()
Out[53]:
torch.Size([9])
In [54]:
seq_idx.view(-1, 1)
Out[54]:
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8]])
In [55]:
batch_idx
Out[55]:
tensor([0, 1, 2])
In [56]:
batch_idx.size()
Out[56]:
torch.Size([3])
In [58]:
batch_idx.view(1, -1)
Out[58]:
tensor([[0, 1, 2]])
In [59]:
tags
Out[59]:
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
In [62]:
logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags]
Out[62]:
tensor([[ 0.5950,  0.7838,  0.1123],
        [-1.5633, -0.2962,  1.1782],
        [ 0.3717, -0.6467,  0.1256],
        [ 0.3275,  0.7764,  0.6220],
        [-0.6766, -1.9707,  1.8370],
        [-0.9009, -0.6992, -1.2199],
        [-1.4247,  0.0753,  0.2262],
        [-1.0249,  0.0379,  1.6245],
        [ 0.4878,  0.6599,  0.3702]])
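This gathers emit[t, b] = logits[t, b, tags[t, b]], i.e. each time step's emission score for its gold tag. An equivalent formulation (a sketch) uses torch.gather along the tag dimension:

emit = logits.gather(2, tags.unsqueeze(2)).squeeze(2)  # shape: (seq_len, batch_size)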
In [66]:
logits
Out[66]:
tensor([[[-1.4452e-01, -1.9019e-01,  9.7182e-01, -1.5851e+00, -1.3361e+00,
          -1.4078e+00,  5.9501e-01,  1.1116e+00],
         [ 4.8289e-01,  8.0970e-01, -2.6692e-01, -1.0974e+00,  2.5424e-01,
           1.4737e-01,  7.8376e-01, -1.4346e+00],
         [-5.1107e-01, -7.3291e-01,  8.8863e-01,  8.5765e-02,  2.0189e-01,
          -6.3774e-01,  1.1234e-01,  1.0721e+00]],

        [[-1.2016e+00,  5.7204e-01, -1.7459e-01, -1.2101e+00, -1.5633e+00,
          -1.5958e+00,  7.4246e-01, -2.2454e-01],
         [-2.6347e-01, -1.5178e+00,  1.5867e+00,  7.5319e-01, -2.9615e-01,
           6.0084e-01, -5.5508e-01, -1.8708e-01],
         [-1.7006e-01, -2.5931e-01,  1.1782e+00, -2.6335e-01,  1.4209e+00,
           4.8150e-01,  1.1860e+00,  7.5185e-01]],

        [[-1.5298e-01,  7.3708e-01,  3.7166e-01,  2.4409e-01, -7.6002e-01,
          -1.1487e+00, -3.5016e-01,  1.7178e-01],
         [ 2.3153e+00, -1.7178e-01, -6.4674e-01,  1.1492e+00,  7.9638e-01,
           5.4677e-01,  1.6118e-01,  2.1595e-01],
         [-1.4374e+00,  1.0826e+00,  2.8144e-02,  6.7766e-01, -1.9165e-01,
           1.2558e-01,  1.2113e+00,  5.0190e-02]],

        [[ 5.3908e-01,  1.7665e-01,  8.4331e-02,  1.2331e+00, -6.3707e-01,
           3.2750e-01, -9.5666e-01, -1.0764e+00],
         [ 1.6032e+00,  1.0249e+00, -1.1967e+00,  7.7639e-01,  1.0185e+00,
          -9.3879e-01, -1.4424e+00,  1.3888e+00],
         [ 3.9188e-01,  1.5474e+00, -1.5797e+00,  6.2201e-01, -3.8201e-01,
          -2.2004e+00, -1.2397e+00,  1.3466e+00]],

        [[-2.0275e-01, -4.7478e-01, -2.4096e-01,  3.4847e-01,  1.4107e+00,
          -6.7662e-01,  1.1356e+00, -8.8798e-01],
         [-1.0615e-01,  4.5697e-01,  1.2877e-01,  1.3390e+00, -1.9707e+00,
          -5.8355e-01, -2.2570e+00, -1.1907e+00],
         [-1.6803e+00,  1.5669e+00,  5.5389e-01, -8.0821e-01,  1.8370e+00,
          -2.8107e-02,  8.2856e-01,  7.0409e-01]],

        [[-5.2993e-01, -8.7384e-01, -7.1909e-01, -9.0088e-01, -1.0477e+00,
           5.7400e-01, -8.9259e-02, -9.4986e-01],
         [-6.4982e-01,  1.6911e-01, -3.0681e-01, -5.5137e-01, -6.9925e-01,
           1.9890e-01,  4.8145e-01,  5.3715e-01],
         [ 1.3986e+00, -1.1884e+00,  4.2444e-01,  1.3998e+00, -7.3532e-01,
          -9.4360e-01, -1.4932e-01, -1.2199e+00]],

        [[ 5.1900e-01, -1.3026e+00, -1.7043e+00, -2.8520e-01, -1.4247e+00,
           5.4460e-02, -6.3961e-01,  1.3025e-01],
         [ 5.2184e-01,  5.8037e-01,  8.7149e-01, -2.0938e+00, -4.8395e-01,
          -2.0589e+00, -2.3386e+00,  7.5332e-02],
         [ 2.2620e-01, -3.7655e-01, -3.5929e-01, -1.8760e+00, -2.0527e+00,
           1.6695e-01,  2.2314e-01,  2.8254e-01]],

        [[ 1.2510e+00,  2.2883e-01,  3.4238e-01, -7.5308e-01,  2.5237e-03,
          -3.4200e-01, -2.2455e-01, -1.0249e+00],
         [ 3.7932e-02,  1.3092e+00, -5.4694e-01,  1.6446e-01, -1.7779e+00,
          -6.6966e-01, -8.4106e-01,  9.3973e-02],
         [ 1.6245e+00,  7.8641e-01,  1.4010e+00, -7.4482e-02, -1.4197e+00,
          -1.3425e+00, -8.8786e-01,  9.7691e-01]],

        [[ 4.8784e-01, -1.5034e+00, -1.6049e-01,  5.3555e-01,  2.2210e-03,
          -5.5209e-01,  2.6669e-01, -1.2266e-01],
         [ 6.5991e-01,  6.3159e-01,  1.3538e+00, -2.7384e-01,  8.5952e-01,
          -3.7105e-01, -1.3350e-01, -2.1770e+00],
         [ 3.7015e-01,  1.5858e+00, -4.2269e-01,  2.0528e+00,  3.1429e-01,
          -8.5222e-01, -5.5130e-01, -9.6802e-02]]])
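Putting the pieces together: masking the emission scores the same way (exactly the commented-out line from section 2.2) and summing both terms over time yields the gold-path score for each sequence. A closing sketch:

emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
gold_score = trans_score.sum(dim=0) + emit_score.sum(dim=0)  # shape: (batch_size,)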