conv_encoder_utils.py source code

python

Project: conv_seq2seq  Author: tobyyouup
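import math

import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)


def linear_mapping_stupid(inputs, out_dim, in_dim=None, dropout=1.0, var_scope_name="linear_mapping"):
  # Applies one dense layer to the last axis of a rank-3 tensor [batch, time, in_dim].
  # in_dim is accepted but unused; the input width is read from the static shape of inputs.
  with tf.variable_scope(var_scope_name):
    print('name', tf.get_variable_scope().name)
    input_shape_tensor = tf.shape(inputs)   # dynamic shape, never contains None
    input_shape = inputs.get_shape().as_list()    # static shape, may contain None
    print('input_shape', input_shape)
    assert len(input_shape) == 3
    # Flatten to [batch * time, in_dim] so a single matmul covers every position.
    # The last dimension must be statically known because it also sizes the weight matrix.
    inputs = tf.reshape(inputs, [-1, input_shape[-1]])

    # Weights are drawn from N(0, dropout / in_dim); math.sqrt keeps stddev a plain Python float.
    linear_mapping_w = tf.get_variable("linear_mapping_w", [input_shape[-1], out_dim], initializer=tf.random_normal_initializer(mean=0, stddev=math.sqrt(dropout*1.0/input_shape[-1])))
    linear_mapping_b = tf.get_variable("linear_mapping_b", [out_dim], initializer=tf.zeros_initializer())

    output = tf.matmul(inputs, linear_mapping_w) + linear_mapping_b
    print('xxxxx_params', input_shape, out_dim)
    # Restore [batch, time, out_dim] using the dynamic batch size, since the static one may be None.
    output = tf.reshape(output, [input_shape_tensor[0], -1, out_dim])

  return output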
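
The flatten, matmul, reshape pattern used by linear_mapping_stupid can be checked without TensorFlow. The following is a minimal NumPy sketch, not code from the conv_seq2seq project: the helper name linear_mapping_np, the example shapes, and the random seed are all assumptions made for illustration.

```python
import numpy as np


def linear_mapping_np(inputs, weights, bias):
    """Hypothetical NumPy analogue: dense layer over the last axis of a 3-D array."""
    assert inputs.ndim == 3
    batch, _, in_dim = inputs.shape
    # Flatten [batch, time, in_dim] to [batch * time, in_dim].
    flat = inputs.reshape(-1, in_dim)
    # One matrix product handles every (batch, time) position at once.
    out = flat @ weights + bias
    # Restore the leading batch axis; the time axis is inferred.
    return out.reshape(batch, -1, weights.shape[1])


# Illustrative shapes (assumed): batch=2, time=5, in_dim=4, out_dim=3.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 4))
# Same scaling as the initializer above with dropout=1.0: stddev = sqrt(1 / in_dim).
w = rng.normal(scale=np.sqrt(1.0 / 4), size=(4, 3))
b = np.zeros(3)
y = linear_mapping_np(x, w, b)
print(y.shape)  # (2, 5, 3)
```

Flattening first means a single 2-D matrix product is enough; the result matches contracting the last axis of the input directly against the weight matrix.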