tf.broadcast_to


Jerry Su Nov 23, 2020 1 mins

Broadcast

  • expand_dims

  • broadcasting without copying data vs. tf.tile

  • tf.broadcast_to

Key idea

  • If tensors a and b do not have the same number of dimensions, insert dims of size 1 into the lower-rank tensor, aligning from the trailing (rightmost) dimensions.

  • Expand each size-1 dim so both tensors end up with the same shape.

e.g. a = [4, 32, 32, 3]

b = [3] -> [1, 1, 1, 3] -> [4, 32, 32, 3]
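To make the rule concrete, here is a minimal runnable sketch of the example above (the constant values in b are arbitrary):

import tensorflow as tf
a = tf.random.normal([4, 32, 32, 3])
b = tf.constant([1., 2., 3.])          # shape [3]
# b is treated as [1, 1, 1, 3] and then expanded to [4, 32, 32, 3]
(a + b).shape
TensorShape([4, 32, 32, 3])
tf.broadcast_to(b, [4, 32, 32, 3]).shape
TensorShape([4, 32, 32, 3])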

tf.broadcast_to

import tensorflow as tf
x = tf.random.normal([4, 32, 32, 3])
y = tf.random.normal([4, 1, 1, 1])
x.shape, y.shape
(TensorShape([4, 32, 32, 3]), TensorShape([4, 1, 1, 1]))
# implicit broadcast
(x + y).shape
TensorShape([4, 32, 32, 3])
# explicit broadcast
(x + tf.broadcast_to(y, [4, 32, 32, 3])).shape
TensorShape([4, 32, 32, 3])
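Broadcasting only succeeds when, for each aligned axis, the two sizes are equal or one of them is 1; otherwise eager TensorFlow raises an InvalidArgumentError. A minimal sketch of a failing case (the [2] shape is just an illustrative mismatch):

z = tf.random.normal([2])   # trailing dim 2 cannot align with x's trailing dim 3
try:
    (x + z).shape
except tf.errors.InvalidArgumentError as e:
    print("cannot broadcast:", e)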

Broadcast vs. Tile

broadcast does not actually copy the data, whereas tile makes a real copy, so broadcast is more efficient and saves memory.

a = tf.ones([3, 4])
a.shape
TensorShape([3, 4])
# broadcast
a1 = tf.broadcast_to(a, [2, 3, 4])
a1
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)>
# tile
a2 = tf.expand_dims(a, axis=0)
a2.shape
TensorShape([1, 3, 4])
a2 = tf.tile(a2, [2, 1, 1])   # [2, 1, 1] gives the number of copies along each axis
a2
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)>
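Although only tile actually materializes the copies, the two results hold the same values; a quick illustrative check:

# broadcast_to and tile yield identical values here
tf.reduce_all(tf.equal(a1, a2))
<tf.Tensor: shape=(), dtype=bool, numpy=True>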

