batch_ops.py source code

python

Project: t3f  Author: Bihaqo
import tensorflow as tf

from t3f.tensor_train_batch import TensorTrainBatch


def multiply_along_batch_dim(batch_tt, weights):
  """Multiply each TensorTrain in a batch by its own scalar weight.

  Args:
    batch_tt: TensorTrainBatch object, TT-matrices or TT-tensors.
    weights: 1-D tf.Tensor (or something convertible to it, like np.array) of
     size batch_tt.batch_size with weights.

  Returns:
    TensorTrainBatch
  """
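  tt_cores = list(batch_tt.tt_cores)
  # Match the cores' dtype so that, e.g., a float64 np.array can scale
  # float32 cores instead of failing with a dtype mismatch.
  weights = tf.convert_to_tensor(weights, dtype=tt_cores[0].dtype)
  # Reshape weights from (batch_size,) so they broadcast against the first core:
  # TT-matrix cores have 5 axes (batch, r_0, m_1, n_1, r_1),
  # TT-tensor cores have 4 axes (batch, r_0, n_1, r_1).
  if batch_tt.is_tt_matrix():
    weights = weights[:, tf.newaxis, tf.newaxis, tf.newaxis, tf.newaxis]
  else:
    weights = weights[:, tf.newaxis, tf.newaxis, tf.newaxis]
  # A TT object is the product of its cores, so scaling only the first core
  # scales every element of the full tensor by that batch element's weight.
  tt_cores[0] = weights * tt_cores[0]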
  out_shape = batch_tt.get_raw_shape()
  out_ranks = batch_tt.get_tt_ranks()
  out_batch_size = batch_tt.batch_size
  return TensorTrainBatch(tt_cores, out_shape, out_ranks, out_batch_size)
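To see why scaling only the first core is enough, here is a minimal NumPy sketch that does not require TensorFlow or t3f. It assumes batched TT-tensor cores laid out as (batch_size, r_{k-1}, n_k, r_k) with boundary ranks of 1, which is the 4-axis layout the function above broadcasts against; the helpers `random_tt_batch`, `full_batch` and `multiply_along_batch_dim_np` are hypothetical names for illustration, not part of the t3f API, and only the TT-tensor (not TT-matrix) case is covered.

```python
import numpy as np


def random_tt_batch(batch_size, shape, ranks, seed=0):
    """Hypothetical helper: random TT-cores for a batch of TT-tensors.

    Core k has shape (batch_size, r_{k-1}, n_k, r_k), with r_0 = r_d = 1.
    """
    rng = np.random.default_rng(seed)
    full_ranks = [1] + list(ranks) + [1]
    return [rng.standard_normal((batch_size, full_ranks[k], n, full_ranks[k + 1]))
            for k, n in enumerate(shape)]


def full_batch(cores):
    """Contract TT-cores into dense tensors of shape (batch_size, n_1, ..., n_d)."""
    result = cores[0]  # (b, 1, n_1, r_1)
    for core in cores[1:]:
        # Contract the trailing rank axis of `result` with the leading rank
        # axis of `core`, independently for every batch element b.
        result = np.einsum('b...r,brns->b...ns', result, core)
    # Drop the boundary rank axes r_0 = 1 and r_d = 1.
    return result[:, 0, ..., 0]


def multiply_along_batch_dim_np(cores, weights):
    """Hypothetical NumPy analogue of multiply_along_batch_dim for TT-tensors."""
    weights = np.asarray(weights, dtype=cores[0].dtype)
    new_cores = list(cores)
    # (b,) -> (b, 1, 1, 1): each batch element's first core gets its own weight.
    new_cores[0] = weights[:, np.newaxis, np.newaxis, np.newaxis] * cores[0]
    return new_cores


# Usage: the dense result equals each dense tensor times its weight.
cores = random_tt_batch(batch_size=3, shape=(4, 5, 6), ranks=(2, 3))
weights = np.array([2.0, -1.0, 0.5])
scaled = multiply_along_batch_dim_np(cores, weights)
assert np.allclose(full_batch(scaled),
                   weights[:, None, None, None] * full_batch(cores))
```

Because the full tensor is multilinear in the cores, scaling any single core would work equally well; the first core is simply a convenient choice, and the remaining cores, shape, ranks and batch size pass through unchanged.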