ConditionalRandomField

Reason is the light and the light of life.

Jerry Su Feb 04, 2021 5 mins

A comparison of Hidden Markov Models and Conditional Random Fields, two kinds of probabilistic graphical models.

import torch
from torch import nn
# BMES four-tag sequence labelling scheme (Begin / Middle / End / Single).
# Integer ids for each tag; id 1 is '<unk>' in LABEL_VOCAB below.
PADDING = 0
B = 2
E = 3
S = 4
M = 5
# Sentinel tags placed at the start and end of every tag sequence.
START = 6
END = 7

# Id -> label string. NOTE: START (6) and END (7) have no entry here even
# though NUM_TAGS counts them.
LABEL_VOCAB = {0: '<pad>', 1: '<unk>', 2: 'B', 3: 'E', 4: 'S', 5: 'M'}

# Size of the tag dimension: ids 0-7, including START and END.
NUM_TAGS = 8

1.1 Initialize Input

# batch_size x max_len x num_tags : (3, 9, 8)

# Random scores, one per (sequence, position, tag).
logits = torch.randn(3, 9, NUM_TAGS)
print(f"logits:\n {logits}\n")
# Gold tag ids per sequence: START (6), the BMES tags, END (7), then PADDING (0).
tags = torch.tensor([[6, 4, 2, 5, 5, 3, 4, 7, 0],
                     [6, 4, 2, 3, 4, 4, 7, 0, 0],
                     [6, 2, 5, 3, 4, 7, 0, 0, 0]])
print(f"tags:\n {tags}\n")
# 1 for real positions (START through END), 0 for padded positions.
mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1, 0, 0, 0]])
print(f"mask:\n {mask}")
logits:
 tensor([[[-1.4452e-01, -1.9019e-01,  9.7182e-01, -1.5851e+00, -1.3361e+00,
          -1.4078e+00,  5.9501e-01,  1.1116e+00],
         [-1.2016e+00,  5.7204e-01, -1.7459e-01, -1.2101e+00, -1.5633e+00,
          -1.5958e+00,  7.4246e-01, -2.2454e-01],
         [-1.5298e-01,  7.3708e-01,  3.7166e-01,  2.4409e-01, -7.6002e-01,
          -1.1487e+00, -3.5016e-01,  1.7178e-01],
         [ 5.3908e-01,  1.7665e-01,  8.4331e-02,  1.2331e+00, -6.3707e-01,
           3.2750e-01, -9.5666e-01, -1.0764e+00],
         [-2.0275e-01, -4.7478e-01, -2.4096e-01,  3.4847e-01,  1.4107e+00,
          -6.7662e-01,  1.1356e+00, -8.8798e-01],
         [-5.2993e-01, -8.7384e-01, -7.1909e-01, -9.0088e-01, -1.0477e+00,
           5.7400e-01, -8.9259e-02, -9.4986e-01],
         [ 5.1900e-01, -1.3026e+00, -1.7043e+00, -2.8520e-01, -1.4247e+00,
           5.4460e-02, -6.3961e-01,  1.3025e-01],
         [ 1.2510e+00,  2.2883e-01,  3.4238e-01, -7.5308e-01,  2.5237e-03,
          -3.4200e-01, -2.2455e-01, -1.0249e+00],
         [ 4.8784e-01, -1.5034e+00, -1.6049e-01,  5.3555e-01,  2.2210e-03,
          -5.5209e-01,  2.6669e-01, -1.2266e-01]],

        [[ 4.8289e-01,  8.0970e-01, -2.6692e-01, -1.0974e+00,  2.5424e-01,
           1.4737e-01,  7.8376e-01, -1.4346e+00],
         [-2.6347e-01, -1.5178e+00,  1.5867e+00,  7.5319e-01, -2.9615e-01,
           6.0084e-01, -5.5508e-01, -1.8708e-01],
         [ 2.3153e+00, -1.7178e-01, -6.4674e-01,  1.1492e+00,  7.9638e-01,
           5.4677e-01,  1.6118e-01,  2.1595e-01],
         [ 1.6032e+00,  1.0249e+00, -1.1967e+00,  7.7639e-01,  1.0185e+00,
          -9.3879e-01, -1.4424e+00,  1.3888e+00],
         [-1.0615e-01,  4.5697e-01,  1.2877e-01,  1.3390e+00, -1.9707e+00,
          -5.8355e-01, -2.2570e+00, -1.1907e+00],
         [-6.4982e-01,  1.6911e-01, -3.0681e-01, -5.5137e-01, -6.9925e-01,
           1.9890e-01,  4.8145e-01,  5.3715e-01],
         [ 5.2184e-01,  5.8037e-01,  8.7149e-01, -2.0938e+00, -4.8395e-01,
          -2.0589e+00, -2.3386e+00,  7.5332e-02],
         [ 3.7932e-02,  1.3092e+00, -5.4694e-01,  1.6446e-01, -1.7779e+00,
          -6.6966e-01, -8.4106e-01,  9.3973e-02],
         [ 6.5991e-01,  6.3159e-01,  1.3538e+00, -2.7384e-01,  8.5952e-01,
          -3.7105e-01, -1.3350e-01, -2.1770e+00]],

        [[-5.1107e-01, -7.3291e-01,  8.8863e-01,  8.5765e-02,  2.0189e-01,
          -6.3774e-01,  1.1234e-01,  1.0721e+00],
         [-1.7006e-01, -2.5931e-01,  1.1782e+00, -2.6335e-01,  1.4209e+00,
           4.8150e-01,  1.1860e+00,  7.5185e-01],
         [-1.4374e+00,  1.0826e+00,  2.8144e-02,  6.7766e-01, -1.9165e-01,
           1.2558e-01,  1.2113e+00,  5.0190e-02],
         [ 3.9188e-01,  1.5474e+00, -1.5797e+00,  6.2201e-01, -3.8201e-01,
          -2.2004e+00, -1.2397e+00,  1.3466e+00],
         [-1.6803e+00,  1.5669e+00,  5.5389e-01, -8.0821e-01,  1.8370e+00,
          -2.8107e-02,  8.2856e-01,  7.0409e-01],
         [ 1.3986e+00, -1.1884e+00,  4.2444e-01,  1.3998e+00, -7.3532e-01,
          -9.4360e-01, -1.4932e-01, -1.2199e+00],
         [ 2.2620e-01, -3.7655e-01, -3.5929e-01, -1.8760e+00, -2.0527e+00,
           1.6695e-01,  2.2314e-01,  2.8254e-01],
         [ 1.6245e+00,  7.8641e-01,  1.4010e+00, -7.4482e-02, -1.4197e+00,
          -1.3425e+00, -8.8786e-01,  9.7691e-01],
         [ 3.7015e-01,  1.5858e+00, -4.2269e-01,  2.0528e+00,  3.1429e-01,
          -8.5222e-01, -5.5130e-01, -9.6802e-02]]])

tags:
 tensor([[6, 4, 2, 5, 5, 3, 4, 7, 0],
        [6, 4, 2, 3, 4, 4, 7, 0, 0],
        [6, 2, 5, 3, 4, 7, 0, 0, 0]])

mask:
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0]])
# Switch to a time-major layout: dim 0 becomes the sequence position and
# dim 1 the batch index.
logits = logits.transpose(0, 1)
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
# Learnable (NUM_TAGS, NUM_TAGS) transition scores, randomly initialised.
# It is later indexed as trans_matrix[previous_tag, next_tag].
trans_matrix = nn.Parameter(torch.randn(NUM_TAGS, NUM_TAGS))
trans_matrix
Parameter containing:
tensor([[ 0.3130,  0.6805, -1.6291,  1.1787,  0.2590,  0.0393,  0.3733, -1.0978],
        [ 0.0257, -0.7165,  0.1774, -0.4137, -0.9312,  1.1543, -1.7091,  0.1694],
        [-0.5890, -1.3841, -0.7339,  1.5916, -0.6873, -1.2056,  1.3192,  0.1930],
        [-1.0862, -0.3838,  0.3116, -1.0895, -0.5821, -1.2720,  0.2369,  0.5035],
        [-1.1792,  0.2191, -0.5459,  1.8000, -0.0737,  1.6784, -0.8590, -0.3808],
        [-0.1573, -1.7135, -0.2278,  1.8250, -0.4302,  2.0009, -1.1343,  0.4233],
        [-0.3599, -0.0824, -0.4446,  0.4188, -0.7154, -0.1829,  0.0595, -1.1767],
        [-1.2007,  2.5477,  0.0693, -0.9544, -0.8122,  0.1949, -0.0823,  0.3554]],
       requires_grad=True)

2. Compute the score for the gold path.

seq_len, batch_size, _ = logits.size()
print(f"seq_len: {seq_len}\nbatch_size: {batch_size}")
seq_len: 9
batch_size: 3
batch_idx = torch.arange(batch_size, dtype=torch.long)
batch_idx
tensor([0, 1, 2])
seq_idx = torch.arange(seq_len, dtype=torch.long)
seq_idx
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

2.1 Transition score

mask
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 0.],
        [1., 0., 0.],
        [0., 0., 0.]])
# Convert the float 0/1 mask to a boolean mask (True where the value equals 1).
mask = mask.eq(True)
mask
tensor([[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True, False],
        [ True, False, False],
        [False, False, False]])
# Inverted mask: True at padded positions, False at real positions.
flip_mask = mask.eq(False)
flip_mask
tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])
tags
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
tags[: seq_len -1]
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0]])
tags[1:]
tensor([[4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
trans_matrix
Parameter containing:
tensor([[ 0.3130,  0.6805, -1.6291,  1.1787,  0.2590,  0.0393,  0.3733, -1.0978],
        [ 0.0257, -0.7165,  0.1774, -0.4137, -0.9312,  1.1543, -1.7091,  0.1694],
        [-0.5890, -1.3841, -0.7339,  1.5916, -0.6873, -1.2056,  1.3192,  0.1930],
        [-1.0862, -0.3838,  0.3116, -1.0895, -0.5821, -1.2720,  0.2369,  0.5035],
        [-1.1792,  0.2191, -0.5459,  1.8000, -0.0737,  1.6784, -0.8590, -0.3808],
        [-0.1573, -1.7135, -0.2278,  1.8250, -0.4302,  2.0009, -1.1343,  0.4233],
        [-0.3599, -0.0824, -0.4446,  0.4188, -0.7154, -0.1829,  0.0595, -1.1767],
        [-1.2007,  2.5477,  0.0693, -0.9544, -0.8122,  0.1949, -0.0823,  0.3554]],
       requires_grad=True)
tags
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
trans_score = trans_matrix[tags[:seq_len - 1], tags[1:]]    # 2-D tensor indexing: [tag at step t, tag at step t+1]
trans_score
tensor([[-0.7154, -0.7154, -0.4446],
        [-0.5459, -0.5459, -1.2056],
        [-1.2056,  1.5916,  1.8250],
        [ 2.0009, -0.5821, -0.5821],
        [ 1.8250, -0.0737, -0.3808],
        [-0.5821, -0.3808, -1.2007],
        [-0.3808, -1.2007,  0.3130],
        [-1.2007,  0.3130,  0.3130]], grad_fn=<IndexBackward>)
flip_mask
tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])
flip_mask[1:, :]
tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])
# Final transition scores: steps whose destination position is padding are set to 0.
trans_score = trans_score.masked_fill(flip_mask[1:, :], 0)
trans_score
tensor([[-0.7154, -0.7154, -0.4446],
        [-0.5459, -0.5459, -1.2056],
        [-1.2056,  1.5916,  1.8250],
        [ 2.0009, -0.5821, -0.5821],
        [ 1.8250, -0.0737, -0.3808],
        [-0.5821, -0.3808,  0.0000],
        [-0.3808,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]], grad_fn=<MaskedFillBackward0>)

2.2 Emission score

# emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
# emit_score
logits.size()
torch.Size([9, 3, 8])
seq_idx
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
seq_idx.size()
torch.Size([9])
seq_idx.view(-1, 1)
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8]])
batch_idx
tensor([0, 1, 2])
batch_idx.size()
torch.Size([3])
batch_idx.view(1, -1)
tensor([[0, 1, 2]])
tags
tensor([[6, 6, 6],
        [4, 4, 2],
        [2, 2, 5],
        [5, 3, 3],
        [5, 4, 4],
        [3, 4, 7],
        [4, 7, 0],
        [7, 0, 0],
        [0, 0, 0]])
logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags]
tensor([[ 0.5950,  0.7838,  0.1123],
        [-1.5633, -0.2962,  1.1782],
        [ 0.3717, -0.6467,  0.1256],
        [ 0.3275,  0.7764,  0.6220],
        [-0.6766, -1.9707,  1.8370],
        [-0.9009, -0.6992, -1.2199],
        [-1.4247,  0.0753,  0.2262],
        [-1.0249,  0.0379,  1.6245],
        [ 0.4878,  0.6599,  0.3702]])
logits
tensor([[[-1.4452e-01, -1.9019e-01,  9.7182e-01, -1.5851e+00, -1.3361e+00,
          -1.4078e+00,  5.9501e-01,  1.1116e+00],
         [ 4.8289e-01,  8.0970e-01, -2.6692e-01, -1.0974e+00,  2.5424e-01,
           1.4737e-01,  7.8376e-01, -1.4346e+00],
         [-5.1107e-01, -7.3291e-01,  8.8863e-01,  8.5765e-02,  2.0189e-01,
          -6.3774e-01,  1.1234e-01,  1.0721e+00]],

        [[-1.2016e+00,  5.7204e-01, -1.7459e-01, -1.2101e+00, -1.5633e+00,
          -1.5958e+00,  7.4246e-01, -2.2454e-01],
         [-2.6347e-01, -1.5178e+00,  1.5867e+00,  7.5319e-01, -2.9615e-01,
           6.0084e-01, -5.5508e-01, -1.8708e-01],
         [-1.7006e-01, -2.5931e-01,  1.1782e+00, -2.6335e-01,  1.4209e+00,
           4.8150e-01,  1.1860e+00,  7.5185e-01]],

        [[-1.5298e-01,  7.3708e-01,  3.7166e-01,  2.4409e-01, -7.6002e-01,
          -1.1487e+00, -3.5016e-01,  1.7178e-01],
         [ 2.3153e+00, -1.7178e-01, -6.4674e-01,  1.1492e+00,  7.9638e-01,
           5.4677e-01,  1.6118e-01,  2.1595e-01],
         [-1.4374e+00,  1.0826e+00,  2.8144e-02,  6.7766e-01, -1.9165e-01,
           1.2558e-01,  1.2113e+00,  5.0190e-02]],

        [[ 5.3908e-01,  1.7665e-01,  8.4331e-02,  1.2331e+00, -6.3707e-01,
           3.2750e-01, -9.5666e-01, -1.0764e+00],
         [ 1.6032e+00,  1.0249e+00, -1.1967e+00,  7.7639e-01,  1.0185e+00,
          -9.3879e-01, -1.4424e+00,  1.3888e+00],
         [ 3.9188e-01,  1.5474e+00, -1.5797e+00,  6.2201e-01, -3.8201e-01,
          -2.2004e+00, -1.2397e+00,  1.3466e+00]],

        [[-2.0275e-01, -4.7478e-01, -2.4096e-01,  3.4847e-01,  1.4107e+00,
          -6.7662e-01,  1.1356e+00, -8.8798e-01],
         [-1.0615e-01,  4.5697e-01,  1.2877e-01,  1.3390e+00, -1.9707e+00,
          -5.8355e-01, -2.2570e+00, -1.1907e+00],
         [-1.6803e+00,  1.5669e+00,  5.5389e-01, -8.0821e-01,  1.8370e+00,
          -2.8107e-02,  8.2856e-01,  7.0409e-01]],

        [[-5.2993e-01, -8.7384e-01, -7.1909e-01, -9.0088e-01, -1.0477e+00,
           5.7400e-01, -8.9259e-02, -9.4986e-01],
         [-6.4982e-01,  1.6911e-01, -3.0681e-01, -5.5137e-01, -6.9925e-01,
           1.9890e-01,  4.8145e-01,  5.3715e-01],
         [ 1.3986e+00, -1.1884e+00,  4.2444e-01,  1.3998e+00, -7.3532e-01,
          -9.4360e-01, -1.4932e-01, -1.2199e+00]],

        [[ 5.1900e-01, -1.3026e+00, -1.7043e+00, -2.8520e-01, -1.4247e+00,
           5.4460e-02, -6.3961e-01,  1.3025e-01],
         [ 5.2184e-01,  5.8037e-01,  8.7149e-01, -2.0938e+00, -4.8395e-01,
          -2.0589e+00, -2.3386e+00,  7.5332e-02],
         [ 2.2620e-01, -3.7655e-01, -3.5929e-01, -1.8760e+00, -2.0527e+00,
           1.6695e-01,  2.2314e-01,  2.8254e-01]],

        [[ 1.2510e+00,  2.2883e-01,  3.4238e-01, -7.5308e-01,  2.5237e-03,
          -3.4200e-01, -2.2455e-01, -1.0249e+00],
         [ 3.7932e-02,  1.3092e+00, -5.4694e-01,  1.6446e-01, -1.7779e+00,
          -6.6966e-01, -8.4106e-01,  9.3973e-02],
         [ 1.6245e+00,  7.8641e-01,  1.4010e+00, -7.4482e-02, -1.4197e+00,
          -1.3425e+00, -8.8786e-01,  9.7691e-01]],

        [[ 4.8784e-01, -1.5034e+00, -1.6049e-01,  5.3555e-01,  2.2210e-03,
          -5.5209e-01,  2.6669e-01, -1.2266e-01],
         [ 6.5991e-01,  6.3159e-01,  1.3538e+00, -2.7384e-01,  8.5952e-01,
          -3.7105e-01, -1.3350e-01, -2.1770e+00],
         [ 3.7015e-01,  1.5858e+00, -4.2269e-01,  2.0528e+00,  3.1429e-01,
          -8.5222e-01, -5.5130e-01, -9.6802e-02]]])


Read more:

Related posts: