Scatter_nd

Reason is the light and the light of life.

Jerry Su Mar 04, 2021 3 mins
import torch
import tensorflow as tf
tf.scatter_nd(
       indices,
       updates,
       shape,
       name=None
)

根据indices将updates散布到新的(初始为零)张量.

根据索引对给定shape的零张量中的单个值或切片应用稀疏updates来创建新的张量.此运算符是tf.gather_nd运算符的反函数,它从给定的张量中提取值或切片.

警告:如果indices包含重复项,对应的updates会被累加(例如下面的示例中,索引1得到9+11=20).更新应用的顺序是非确定性的,因此对于浮点类型,含重复索引时的输出可能因求和顺序不同而略有差异.

indices是一个整数张量,其中的索引指向形状为shape的新张量.indices最后一维的大小最多可以等于shape的秩: indices.shape[-1] <= shape.rank.当indices.shape[-1]等于shape的秩时,每个索引对应一个元素;小于shape的秩时,每个索引对应一个切片.






scatter_nd

# Add the scalar 1.23 to one column per row (column 2 in row 0, column 3 in row 1)
# of a (2, 4) tensor filled with 2.0.
base = torch.full((2, 4), 2.)
columns = torch.tensor([[2], [3]])
base.scatter_(1, columns, 1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

# Inputs for a rank-1 scatter: four scalar updates written into a length-8 vector.
# Index 1 appears twice, so two of the updates (9 and 11) target the same slot.
indices = tf.constant([[1], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
print(f"indices: {indices}")
print(f"update: {updates}")
print(f"shape: {shape}")
indices: [[1]
 [3]
 [1]
 [7]]
update: [ 9 10 11 12]
shape: [8]
# Updates that share an index are accumulated: position 1 receives 9 + 11 = 20.
scatter = tf.scatter_nd(indices=indices, updates=updates, shape=shape)
scatter
<tf.Tensor: shape=(8,), dtype=int32, numpy=array([ 0, 20,  0, 10,  0,  0,  0, 12], dtype=int32)>
# PyTorch counterpart of a rank-1 scatter with accumulation.
# `updates` holds the float values to add; `indices` (int64) gives, for each value,
# the column of the (1, 8) target it is added to.
updates = torch.tensor([[9., 10., 11., 12.]])
indices = torch.tensor([[4, 3, 4, 7]])
# scatter_add_ replaces scatter_(..., reduce='add'), which is deprecated when the
# source is a tensor. Column 4 is targeted twice, so it receives 9 + 11 = 20.
out = torch.zeros(1, 8).scatter_add_(1, indices, updates)
out
tensor([[ 0.,  0.,  0., 10., 20.,  0.,  0., 12.]])


indices.dtype
torch.int64

scatter_nd

# Scatter whole (4, 4) slices: the two identical update slices are written to
# rows 0 and 2 of a zero-initialised (4, 4, 4) tensor; rows 1 and 3 stay zero.
slice_values = [[5, 5, 5, 5], [6, 6, 6, 6],
                [7, 7, 7, 7], [8, 8, 8, 8]]
indices = tf.constant([[0], [2]])
updates = tf.constant([slice_values, slice_values])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices=indices, updates=updates, shape=shape)
scatter
<tf.Tensor: shape=(4, 4, 4), dtype=int32, numpy=
array([[[5, 5, 5, 5],
        [6, 6, 6, 6],
        [7, 7, 7, 7],
        [8, 8, 8, 8]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[5, 5, 5, 5],
        [6, 6, 6, 6],
        [7, 7, 7, 7],
        [8, 8, 8, 8]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]], dtype=int32)>

# Values to scatter, one row per batch element. Defined before first use so the
# cells run top to bottom; a later cell re-creates and displays the same tensor.
update = tf.constant([[11, 12, 13, 14, 15],
                      [21, 22, 23, 24, 25],
                      [31, 32, 33, 34, 35],
                      [41, 42, 43, 44, 45]])
# Row index of every batch element, as a column vector of shape (batch_size, 1).
batch_size = update.shape[0]
batch_nums = tf.range(batch_size)
batch_nums = tf.expand_dims(batch_nums, axis=1)
batch_nums
<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
array([[0],
       [1],
       [2],
       [3]], dtype=int32)>
# Repeat each row index across all columns so it lines up element-wise with `update`.
num_cols = update.shape[-1]
batch_nums = tf.tile(batch_nums, multiples=[1, num_cols])
batch_nums
<tf.Tensor: shape=(4, 5), dtype=int32, numpy=
array([[0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3]], dtype=int32)>
# Column indices for the scatter. Each of the two input rows is duplicated so the
# result has one row per batch element: (2, 5) -> (2, 1, 5) -> (2, 2, 5) -> (4, 5).
encode_inputs = tf.constant([[1, 2, 3, 2, 2], [1, 1, 1, 1, 1]])
encode_inputs = tf.reshape(
    tf.tile(tf.expand_dims(encode_inputs, axis=1), multiples=[1, 2, 1]),
    shape=(-1, update.shape[-1]),
)
encode_inputs
<tf.Tensor: shape=(4, 5), dtype=int32, numpy=
array([[1, 2, 3, 2, 2],
       [1, 2, 3, 2, 2],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1]], dtype=int32)>
# Values to scatter: a (4, 5) integer tensor, one row per batch element.
update = tf.constant([
    [11, 12, 13, 14, 15],
    [21, 22, 23, 24, 25],
    [31, 32, 33, 34, 35],
    [41, 42, 43, 44, 45],
])
update
<tf.Tensor: shape=(4, 5), dtype=int32, numpy=
array([[11, 12, 13, 14, 15],
       [21, 22, 23, 24, 25],
       [31, 32, 33, 34, 35],
       [41, 42, 43, 44, 45]], dtype=int32)>
# Pair every row index with its column index along a new last axis, giving one
# (row, column) coordinate per element of `update`.
indices = tf.stack(values=[batch_nums, encode_inputs], axis=-1)
indices
<tf.Tensor: shape=(4, 5, 2), dtype=int32, numpy=
array([[[0, 1],
        [0, 2],
        [0, 3],
        [0, 2],
        [0, 2]],

       [[1, 1],
        [1, 2],
        [1, 3],
        [1, 2],
        [1, 2]],

       [[2, 1],
        [2, 1],
        [2, 1],
        [2, 1],
        [2, 1]],

       [[3, 1],
        [3, 1],
        [3, 1],
        [3, 1],
        [3, 1]]], dtype=int32)>
# Scatter into a zero-initialised (batch_size, 10) tensor. Updates that share a
# (row, column) coordinate are summed, e.g. row 0, column 2: 12 + 14 + 15 = 41.
tf.scatter_nd(indices=indices, updates=update, shape=(batch_size, 10))
<tf.Tensor: shape=(4, 10), dtype=int32, numpy=
array([[  0,  11,  41,  13,   0,   0,   0,   0,   0,   0],
       [  0,  21,  71,  23,   0,   0,   0,   0,   0,   0],
       [  0, 165,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0, 215,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=int32)>

# Values to scatter in the PyTorch version: a (4, 5) integer tensor, one row per
# batch element.
update_torch = torch.tensor([
    [11, 12, 13, 14, 15],
    [21, 22, 23, 24, 25],
    [31, 32, 33, 34, 35],
    [41, 42, 43, 44, 45],
])
update_torch
tensor([[11, 12, 13, 14, 15],
        [21, 22, 23, 24, 25],
        [31, 32, 33, 34, 35],
        [41, 42, 43, 44, 45]])
# Zero-initialised (batch, 10) scatter target. The batch size is read directly from
# update_torch because batch_size_torch is only assigned in a later cell.
torch.zeros((update_torch.shape[0], 10))
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
# Row index of every element: arange over the batch, turned into a column vector
# and repeated across all columns of update_torch.
batch_size_torch = update_torch.shape[0]
batch_nums_torch = (
    torch.arange(batch_size_torch)
    .unsqueeze(dim=1)
    .repeat([1, update_torch.shape[-1]])
)
batch_nums_torch
tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]])
# Column indices for the scatter. Each of the two input rows is duplicated so the
# result has one row per batch element: (2, 5) -> (2, 1, 5) -> (2, 2, 5) -> (4, 5).
encode_inputs_torch = (
    torch.tensor([[1, 2, 3, 2, 2], [1, 1, 1, 1, 1]])
    .unsqueeze(dim=1)
    .repeat([1, 2, 1])
    .view(-1, update_torch.shape[-1])
)
encode_inputs_torch
tensor([[1, 2, 3, 2, 2],
        [1, 2, 3, 2, 2],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])
# Only the column indices are passed to the scatter_add_ call that follows (along
# dim=1); the row of each update is implied by its position, so batch_nums_torch
# is not needed for the PyTorch version.
indices = encode_inputs_torch
indices
tensor([[1, 2, 3, 2, 2],
        [1, 2, 3, 2, 2],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])
# Accumulate update_torch into a zero (batch_size_torch, 10) tensor along dim=1;
# entries of a row that share a column index are summed.
zeros = torch.zeros((batch_size_torch, 10), dtype=update_torch.dtype)
zeros.scatter_add_(1, indices, update_torch)
tensor([[  0,  11,  41,  13,   0,   0,   0,   0,   0,   0],
        [  0,  21,  71,  23,   0,   0,   0,   0,   0,   0],
        [  0, 165,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0, 215,   0,   0,   0,   0,   0,   0,   0,   0]])

# Source values for the dim=0 example: a (2, 5) tensor of ones.
src = torch.ones(2, 5)
src
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
# Target row for every element of the (2, 5) source, used with scatter along dim=0.
index = torch.tensor([
    [0, 1, 2, 0, 0],
    [0, 1, 2, 2, 2],
])
index
tensor([[0, 1, 2, 0, 0],
        [0, 1, 2, 2, 2]])
# Along dim=0 each src[i][j] is added to row index[i][j] of column j in a zero
# (3, 5) tensor, e.g. column 0 sends both of its ones to row 0, giving 2.
target = torch.zeros(3, 5, dtype=src.dtype)
target.scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])
5e-5 
5e-05


Read more:

Related posts: