Slicing Tensor With List - Tensorflow
Is there a way to accomplish this method of slicing in TensorFlow (example shown using NumPy)?
z = np.random.random((3,7,7,12))
x = z[...,[0,5]]
such that x_hat = np.concatenate([
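As a reference point, the NumPy behaviour being asked about can be checked on its own. The concatenation form below is the one used in Solution 1's test, not a reconstruction of the truncated expression in the question.

```python
import numpy as np

# Same array shape as in the question.
z = np.random.random((3, 7, 7, 12))

# Fancy indexing with a list on the last axis keeps that axis and
# selects positions 0 and 5, so the result has shape (3, 7, 7, 2).
x = z[..., [0, 5]]

# Equivalent construction: a one-element list index keeps the axis with
# size 1, and the pieces are concatenated along axis 3.
x_hat = np.concatenate([z[..., [0]], z[..., [5]]], axis=3)

print(x.shape)  # (3, 7, 7, 2)
assert np.array_equal(x, x_hat)
```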
Solution 1:
How about:
x = tf.stack([tfz[..., i] for i in [0,5]], axis=-1)
This works for me:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

z = np.random.random((3,7,7,12))
tfz = tf.constant(z)
x = tf.stack([tfz[..., i] for i in [0,5]], axis=-1)
x_hat = np.concatenate([z[...,[0]], z[...,[5]]], 3)
with tf.Session() as sess:
    x_run = sess.run(x)
assert np.all(x_run == x_hat)
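The answer above targets the TensorFlow 1.x `tf.Session` API. Assuming TensorFlow 2.x instead, eager execution lets the same `tf.stack` approach be checked without a session; a minimal sketch:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

z = np.random.random((3, 7, 7, 12))
tfz = tf.constant(z)

# Each tfz[..., i] has shape (3, 7, 7); tf.stack adds a new last axis,
# giving shape (3, 7, 7, 2).
x = tf.stack([tfz[..., i] for i in [0, 5]], axis=-1)

x_hat = z[..., [0, 5]]
# In eager mode the tensor converts to NumPy directly; no Session is needed.
assert np.array_equal(x.numpy(), x_hat)
```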
Solution 2:
Indexing with a single integer (e.g. `tfz[:,:,:,1]`) drops the last axis, so you need a reshape to add it back as a size-1 axis before concatenating. The result then keeps the original first 3 dimensions.
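import numpy as np
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

z = np.arange(36)
tfz = tf.reshape(tf.constant(z), [2, 3, 2, 3])
# Integer indexing drops the last axis; reshape restores it with size 1
slice1 = tf.reshape(tfz[:,:,:,1], [2, 3, -1, 1])
slice2 = tf.reshape(tfz[:,:,:,2], [2, 3, -1, 1])
sliced = tf.concat([slice1, slice2], axis=3)  # avoid shadowing the built-in slice
with tf.Session() as sess:
    print(sess.run([tfz, sliced]))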
> [[[[ 0, 1, 2],
     [ 3, 4, 5]],
    [[ 6, 7, 8],
     [ 9, 10, 11]],
    [[12, 13, 14],
     [15, 16, 17]]],
   [[[18, 19, 20],
     [21, 22, 23]],
    [[24, 25, 26],
     [27, 28, 29]],
    [[30, 31, 32],
     [33, 34, 35]]]]
# Get the last two columns
> [[[[ 1, 2],
     [ 4, 5]],
    [[ 7, 8],
     [10, 11]],
    [[13, 14],
     [16, 17]]],
   [[[19, 20],
     [22, 23]],
    [[25, 26],
     [28, 29]],
    [[31, 32],
     [34, 35]]]]
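TensorFlow's basic indexing follows the same axis rules as NumPy here, so the reason for the reshape can be seen with NumPy alone: an integer index drops the axis, while a length-1 slice such as `1:2` keeps it. A minimal sketch:

```python
import numpy as np

z = np.arange(36).reshape(2, 3, 2, 3)

# An integer index removes the indexed axis.
print(z[..., 1].shape)    # (2, 3, 2)
# A length-1 slice keeps it, ready to concatenate on axis 3.
print(z[..., 1:2].shape)  # (2, 3, 2, 1)

# Concatenating the size-1 slices reproduces the "last two columns" output.
last_two = np.concatenate([z[..., 1:2], z[..., 2:3]], axis=3)
assert np.array_equal(last_two, z[..., 1:])
```

By the same rule, `tfz[:,:,:,1:2]` already has shape `[2, 3, 2, 1]`, so slicing instead of integer indexing avoids the `tf.reshape` calls.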
Solution 3:
It is a shape error, as greeness said. Unfortunately, there doesn't seem to be as simple a way of doing it as I had hoped, but this is the generalized solution I came up with:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

def list_slice(tensor, indices, axis):
    """
    Args
    ----
    tensor (Tensor) : input tensor to slice
    indices ( [int] ) : list of indices of where to perform slices
    axis (int) : the axis to perform the slice on
    """
    slices = []
    # Set the shape of the output tensor.
    # Set any unknown dimensions to -1, so that reshape can infer it correctly
    # (tf.reshape accepts at most one -1).
    # Set the dimension in the slice direction to be 1, so that overall dimensions are preserved during the operation.
    shape = tensor.get_shape().as_list()
    shape = [-1 if d is None else d for d in shape]
    shape[axis] = 1
    nd = len(shape)
    for i in indices:
        _slice = [slice(None)]*nd
        _slice[axis] = slice(i,i+1)
        slices.append(tf.reshape(tensor[tuple(_slice)], shape))
    return tf.concat(slices, axis=axis)
z = np.random.random(size=(3, 7, 7, 12))
x = z[...,[0,5]]
tfz = tf.constant(z)
tfx_hat = list_slice(tfz, [0, 5], axis=3)
with tf.Session() as sess:
    x_hat = tfx_hat.eval()  # eval() uses the default session opened above
assert np.all(x == x_hat)
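None of the answers uses it, but `tf.gather` with an `axis` argument selects a list of positions along one axis in a single op, which matches `z[..., [0, 5]]` directly. A minimal sketch, assuming TensorFlow 2.x with eager execution:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

z = np.random.random((3, 7, 7, 12))
tfz = tf.constant(z)

# tf.gather picks the listed positions along axis 3, like z[..., [0, 5]].
x = tf.gather(tfz, [0, 5], axis=3)

assert np.array_equal(x.numpy(), z[..., [0, 5]])
```

Unlike the stack/concat approaches, this needs no per-index Python loop and no reshape, since the gathered axis keeps its rank automatically.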