linalg.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Parser-v1 作者: tdozat 项目源码 文件源码
def diagonal_bilinear(inputs1, inputs2, output_size, add_bias2=True, add_bias1=True, add_bias=False, initializer=None, scope=None, moving_params=None):
  """Score every (inputs1 position, inputs2 position) pair with `output_size` channels.

  The score is the sum of three `linear` projections:
    * `linear(broadcast_mult(inputs1, inputs2))` -- the pairwise term,
    * `linear(inputs1)`, broadcast across the inputs2 positions,
    * `linear(inputs2)`, broadcast across the inputs1 positions.

  Args:
    inputs1: tensor whose last axis is the feature axis (static size required)
      and whose second-to-last axis is the position ("bucket") axis. Any
      leading axes are treated as batch axes and flattened together.
    inputs2: tensor of the same rank as `inputs1`, with the same static
      feature size.
    output_size: number of output channels.
    add_bias2, add_bias1: not read by this function.
      NOTE(review): confirm whether they were meant to gate the
      Linear1/Linear2 terms.
    add_bias: forwarded only to the `linear` call for the pairwise term.
    initializer, moving_params: forwarded to every `linear` call.
    scope: outer variable scope name (defaults to 'Bilinear'). It is also
      forwarded as `scope` to each inner `linear` call.

  Returns:
    Tensor of shape (leading batch axes..., n1, output_size, n2), where n1 and
    n2 are the position-axis sizes of inputs1 and inputs2.

  Raises:
    AssertionError: if the static feature sizes of the two inputs differ.
  """

  with tf.variable_scope(scope or 'Bilinear'):
    # Reformat the inputs. The static rank locates the position axis
    # (second-to-last); everything before it is batch.
    ndims = len(inputs1.get_shape().as_list())
    inputs1_shape = tf.shape(inputs1)
    inputs2_shape = tf.shape(inputs2)
    inputs1_bucket_size = inputs1_shape[ndims-2]
    inputs2_bucket_size = inputs2_shape[ndims-2]

    # Feature sizes must be statically known and equal for the pairwise term.
    inputs1_size = inputs1.get_shape().as_list()[-1]
    inputs2_size = inputs2.get_shape().as_list()[-1]
    assert inputs1_size == inputs2_size

    # Collapse all leading axes into one batch axis. Record the original
    # leading axes in output_shape so the result can be restored at the end.
    output_shape = []
    batch_size = 1
    for i in xrange(ndims-2):
      batch_size *= inputs1_shape[i]
      output_shape.append(inputs1_shape[i])
    output_shape.append(inputs1_bucket_size)
    output_shape.append(output_size)
    output_shape.append(inputs2_bucket_size)
    output_shape = tf.pack(output_shape)
    inputs1 = tf.reshape(inputs1, tf.pack([batch_size, inputs1_bucket_size, inputs1_size]))
    inputs2 = tf.reshape(inputs2, tf.pack([batch_size, inputs2_bucket_size, inputs2_size]))
    # A reshape to a dynamic shape drops static info; re-attach the known feature size.
    inputs1.set_shape([tf.Dimension(None)]*2 + [tf.Dimension(inputs1_size)])
    inputs2.set_shape([tf.Dimension(None)]*2 + [tf.Dimension(inputs2_size)])

    # NOTE(review): assumes broadcast_mult yields pairwise elementwise products
    # of shape (batch, n1, n2, d) -- confirm against its definition.
    inputs = broadcast_mult(inputs1, inputs2)
    with tf.variable_scope('Bilinear'):
      bilin = linear(inputs, output_size, add_bias=add_bias, initializer=initializer, scope=scope, moving_params=moving_params)
    with tf.variable_scope('Linear1'):
      lin1 = linear(inputs1, output_size, add_bias=False, initializer=initializer, scope=scope, moving_params=moving_params)
      # Singleton axis at position 2 so lin1 broadcasts across inputs2 positions.
      lin1 = tf.expand_dims(lin1, 2)
    with tf.variable_scope('Linear2'):
      lin2 = linear(inputs2, output_size, add_bias=False, initializer=initializer, scope=scope, moving_params=moving_params)
      # Singleton axis at position 1 so lin2 broadcasts across inputs1 positions.
      lin2 = tf.expand_dims(lin2, 1)

    # Swap the last two axes, presumably (batch, n1, n2, r) -> (batch, n1, r, n2).
    bilin = tf.transpose(bilin+lin1+lin2, [0,1,3,2])
    # Restore the leading batch axes collapsed above (a no-op for rank-3 inputs).
    bilin = tf.reshape(bilin, output_shape)

    return bilin

#===============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号