tf.concat/split/stack

Reason is the light and the light of life.

Jerry Su Nov 23, 2020 1 mins

Merge and Split

  • tf.concat

  • tf.split

  • tf.stack

  • tf.unstack

1. tf.concat

原有维度concat,如果新维度concat,则stack

  • concat:拼接的维度可以不同,其他维度必须相同。 concat([3, 35, 8], [2, 35, 8]) -> [5, 35, 8]

  • stack:所有维度必须相同。 stack([3, 35, 8], [3, 35, 8]) -> [2, 3, 35, 8]

import tensorflow as tf
# [classes, students, scores]
a = tf.ones([4, 35, 8])
b = tf.ones([2, 35, 8])
# class维度concat
c = tf.concat([a, b], axis=0)
c.shape
TensorShape([6, 35, 8])
# [classes, students, scores]
a = tf.ones([4, 30, 8])
b = tf.ones([4, 5, 8])
# students维度concat
tf.concat([a, b], axis=1).shape
TensorShape([4, 35, 8])

2. tf.stack

stack: create new dim

会创建一个新的维度,设计新的schools概念,在新的维度上表达。

[classes, students, scores] -> [schools, classes, students, scores] 新建schools维度并concat

# [classes, students, scores]
a = tf.ones([4, 35, 8])
b = tf.ones([4, 35, 8])
tf.stack([a, b], axis=0).shape
TensorShape([2, 4, 35, 8])
tf.stack([a, b], axis=3).shape
TensorShape([4, 35, 8, 2])

3. tf.split vs unstack

e = tf.ones([2, 4, 35, 8])
res = tf.unstack(e, axis=3)
len(res), res[0].shape
(8, TensorShape([2, 4, 35]))
res = tf.split(e, axis=3, num_or_size_splits=2)
len(res), res[0].shape
(2, TensorShape([2, 4, 35, 4]))
res = tf.split(e, axis=3, num_or_size_splits=[2, 2, 4])
len(res), res[0].shape, res[2].shape
(3, TensorShape([2, 4, 35, 2]), TensorShape([2, 4, 35, 4]))


Read more:

Related posts: