花式索引

组合索引

与numpy索引方式几乎一致,参考numpy学习

In [4]:
import torch
In [8]:
x = torch.randn(1, 3, 4, 5)
print(x.shape)
x
torch.Size([1, 3, 4, 5])
Out[8]:
tensor([[[[ 2.2271,  2.0278, -0.5272,  1.4764, -0.6787],
          [-0.7569, -0.8737,  0.0031,  0.7208,  1.9315],
          [ 0.6217,  0.3060, -0.3342,  0.1870,  1.5624],
          [-1.2602, -1.7282, -0.1555,  2.0486,  0.2562]],

         [[-0.3865, -0.4514,  0.6373, -0.2836, -0.0389],
          [ 1.1489, -1.4416,  0.2477,  1.7385,  0.7891],
          [ 0.7877,  0.6736, -1.0933,  0.3111,  0.2772],
          [ 0.4253,  1.3894, -1.1572,  1.6345,  0.1543]],

         [[-1.6844,  0.5027, -0.4037,  1.5505, -0.6599],
          [ 0.8898, -1.4016, -0.2114,  1.2523,  0.3860],
          [-0.0034,  0.9842, -0.2550, -0.6099, -0.5720],
          [-1.4554,  0.3559,  1.4612, -0.4953, -1.1379]]]])
In [6]:
dim0 = torch.tensor([0, 0])
dim2 = torch.tensor([0, 1])
In [10]:
x_ = x[dim0, :, dim2]
x_
Out[10]:
tensor([[[ 2.2271,  2.0278, -0.5272,  1.4764, -0.6787],
         [-0.3865, -0.4514,  0.6373, -0.2836, -0.0389],
         [-1.6844,  0.5027, -0.4037,  1.5505, -0.6599]],

        [[-0.7569, -0.8737,  0.0031,  0.7208,  1.9315],
         [ 1.1489, -1.4416,  0.2477,  1.7385,  0.7891],
         [ 0.8898, -1.4016, -0.2114,  1.2523,  0.3860]]])
In [ ]:
 

Comments