linalg.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Parser-v1 作者: tdozat 项目源码 文件源码
def broadcast_add(inputs1, inputs2):
  """"""

  inputs1_shape = tf.shape(inputs1)
  inputs_size = inputs1.get_shape().as_list()[-1]
  inputs2_shape = tf.shape(inputs2)
  inputs1 = tf.transpose(inputs1, [0,2,1])
  inputs2 = tf.transpose(inputs2, [0,2,1])
  inputs1 = tf.reshape(inputs1, tf.pack([-1,inputs1_shape[1],1]))
  inputs2 = tf.reshape(inputs2, tf.pack([-1,1,inputs2_shape[1]]))
  inputs = inputs1 + inputs2
  inputs = tf.reshape(inputs, [inputs1_shape[0], inputs1_shape[2],  inputs1_shape[1], inputs2_shape[1]])
  inputs = tf.transpose(inputs, [0,2,3,1])
  inputs.set_shape([tf.Dimension(None)]*3 + [tf.Dimension(inputs_size)])
  return inputs

#===============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号