Date Tags Paddle
In [1]:
import torch
In [92]:
x = torch.randn(2, 3, 4, 5)
print(x.shape)
x
torch.Size([2, 3, 4, 5])
Out[92]:
tensor([[[[-0.0760, -0.9579, -1.4367, -1.9330,  1.9332],
          [-0.5192, -1.2808, -0.9056, -1.8171,  2.4717],
          [-0.3558,  0.1629, -2.8641,  1.3472, -0.9211],
          [ 0.0932,  0.6473,  1.0439,  1.2783,  0.1088]],

         [[-0.1191, -0.8518,  0.9753, -0.7256, -0.7781],
          [-0.5159, -0.0078,  0.3808,  2.2084,  0.5428],
          [ 1.6506,  0.6946,  0.1436,  0.5960,  0.4892],
          [-2.6200,  0.0104,  0.1307, -1.1882,  1.1010]],

         [[ 0.0119,  0.0548, -0.1476, -0.3266, -0.3430],
          [ 0.0218, -1.3204, -1.4213,  0.5170,  0.2243],
          [ 1.6161,  1.3656, -1.6352,  0.9008, -0.4133],
          [ 0.1467,  0.6184,  2.1848, -0.2661, -0.7767]]],


        [[[-0.0625,  0.0730,  0.5231,  0.2452, -0.3780],
          [-0.3511,  0.7749, -2.4241,  0.4164,  0.7710],
          [-0.5274, -2.0765, -0.3656, -0.8910,  0.0848],
          [-1.6999, -0.0182, -0.1018,  1.0403, -1.1671]],

         [[ 1.5376,  1.5196, -0.0153, -0.7229, -1.3471],
          [-0.1645,  0.8966, -0.5429, -0.2707,  0.7849],
          [-0.0886, -1.0376,  1.2382,  0.2717, -0.3969],
          [-0.3015, -0.6165,  0.2013, -0.0116,  0.8203]],

         [[-0.7642, -1.6964,  1.4862,  0.8707, -0.3923],
          [ 1.4053, -0.2089,  1.5149,  1.3437, -0.7319],
          [ 0.5642,  0.5984,  0.3218, -1.3115,  0.3141],
          [-1.3803,  0.0154,  0.4372,  0.4740, -1.0228]]]])
In [93]:
dim0 = torch.tensor([0, 1])
dim2 = torch.tensor([0, 2])
In [101]:
x_ = x[dim0, :, dim2]
print(x_.shape)
x_
torch.Size([2, 3, 5])
Out[101]:
tensor([[[-0.0760, -0.9579, -1.4367, -1.9330,  1.9332],
         [-0.1191, -0.8518,  0.9753, -0.7256, -0.7781],
         [ 0.0119,  0.0548, -0.1476, -0.3266, -0.3430]],

        [[-0.5274, -2.0765, -0.3656, -0.8910,  0.0848],
         [-0.0886, -1.0376,  1.2382,  0.2717, -0.3969],
         [ 0.5642,  0.5984,  0.3218, -1.3115,  0.3141]]])

paddle实现

In [95]:
import paddle
In [105]:
xx = paddle.to_tensor(x.numpy())
# xx
In [117]:
axis0 = paddle.to_tensor([0, 1])
axis1 = paddle.to_tensor([0, 2])
In [121]:
out1 = paddle.gather(xx, index=axis0, axis=0)
In [139]:
index =paddle.stack([axis0, axis1], axis=-1)
index 
Out[139]:
Tensor(shape=[2, 2], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[0, 0],
        [1, 2]])
In [140]:
paddle.gather_nd(out1, index=index)#.transpose([0, 1, 3, 2])
Out[140]:
Tensor(shape=[2, 4, 5], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[[-0.07601893, -0.95786285, -1.43674135, -1.93300319,  1.93319392],
         [-0.51922202, -1.28075957, -0.90560406, -1.81705618,  2.47172594],
         [-0.35581169,  0.16289686, -2.86410379,  1.34724343, -0.92106235],
         [ 0.09319570,  0.64727783,  1.04392993,  1.27829981,  0.10877635]],

        [[-0.76419204, -1.69635403,  1.48616660,  0.87069529, -0.39233911],
         [ 1.40529430, -0.20886442,  1.51489639,  1.34371901, -0.73188514],
         [ 0.56418574,  0.59838581,  0.32179439, -1.31146288,  0.31410775],
         [-1.38033020,  0.01542827,  0.43722150,  0.47403052, -1.02283168]]])
In [ ]:
 

Comments