def SimpleSparseTensorFrom(x):
    """Create a very simple SparseTensor with dimensions (batch, time).

    Every element of every row in `x` becomes one entry of the sparse
    tensor, indexed by (row position, position within the row).

    Args:
      x: a list of lists of type int. Rows may have different lengths and
        may be empty; `x` itself may be empty.

    Returns:
      A `tf.SparseTensor` with int64 indices, int32 values and an int64
      dense shape of `[len(x), length of the longest row]` (the second
      entry is 0 when there are no values at all).
    """
    x_ix = []
    x_val = []
    max_time = 0
    for batch_i, batch in enumerate(x):
        for time, val in enumerate(batch):
            x_ix.append([batch_i, time])
            x_val.append(val)
            # Track the longest row here instead of reducing over the index
            # list afterwards, which fails when there are no indices.
            max_time = max(max_time, time + 1)
    x_shape = [len(x), max_time]
    # Force a (num_values, 2) layout so that input without any values still
    # yields rank-2 indices of shape (0, 2) rather than a rank-1 empty array.
    x_ix = np.asarray(x_ix, dtype=np.int64).reshape(-1, 2)
    x_ix = tf.constant(x_ix, tf.int64)
    x_val = tf.constant(x_val, tf.int32)
    x_shape = tf.constant(x_shape, tf.int64)
    return tf.SparseTensor(x_ix, x_val, x_shape)
评论列表
文章目录