vocab.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:Parser-v1 作者: tdozat 项目源码 文件源码
def weighted_average(self, inputs, moving_params=None):
    """"""

    input_shape = tf.shape(inputs)
    batch_size = input_shape[0]
    bucket_size = input_shape[1]
    input_size = len(self)

    if moving_params is not None:
      trainable_embeddings = moving_params.average(self.trainable_embeddings)
    else:
      trainable_embeddings = self.trainable_embeddings

    embed_input = tf.matmul(tf.reshape(inputs, [-1, input_size]),
                            trainable_embeddings)
    embed_input = tf.reshape(embed_input, tf.pack([batch_size, bucket_size, self.embed_size]))
    embed_input.set_shape([tf.Dimension(None), tf.Dimension(None), tf.Dimension(self.embed_size)]) 
    if moving_params is None:
      tf.add_to_collection('Weights', embed_input)
    return embed_input

  #=============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号