def multiply_along_batch_dim(batch_tt, weights):
    """Multiply each TensorTrain in a batch by a number.

    The i-th element of the batch is scaled by ``weights[i]``. Only the
    first TT-core is modified: a TT representation is linear in each of its
    cores, so scaling one core scales the whole represented tensor/matrix.

    Args:
        batch_tt: TensorTrainBatch object, TT-matrices or TT-tensors.
        weights: 1-D tf.Tensor (or something convertible to it like
            np.array) of size batch_tt.batch_size with weights. It is
            converted to the dtype of ``batch_tt``'s cores, so e.g. a
            float64 numpy array can be used with float32 cores.

    Returns:
        TensorTrainBatch with the same raw shape, TT-ranks and batch size
        as ``batch_tt``.

    Raises:
        ValueError: e.g. from ``tf.convert_to_tensor`` when ``weights`` is
            already a tf.Tensor whose dtype differs from the cores' dtype.
    """
    tt_cores = list(batch_tt.tt_cores)
    # Cast to the cores' dtype. Without this, float64 weights (numpy's
    # default) times float32 cores fails with a dtype mismatch below.
    weights = tf.convert_to_tensor(weights, dtype=tt_cores[0].dtype)
    # Append singleton axes so the weights broadcast along the leading
    # (batch) axis of the first core: 5-D for TT-matrices, 4-D for TT-tensors.
    if batch_tt.is_tt_matrix():
        weights = weights[:, tf.newaxis, tf.newaxis, tf.newaxis, tf.newaxis]
    else:
        weights = weights[:, tf.newaxis, tf.newaxis, tf.newaxis]
    tt_cores[0] = weights * tt_cores[0]
    out_shape = batch_tt.get_raw_shape()
    out_ranks = batch_tt.get_tt_ranks()
    out_batch_size = batch_tt.batch_size
    return TensorTrainBatch(tt_cores, out_shape, out_ranks, out_batch_size)
评论列表
文章目录