lexicons.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:nmtrain 作者: philip30 项目源码 文件源码
def __call__(self, y, a, ht, y_lex):
    y_dict = F.squeeze(F.batch_matmul(y_lex, a, transa=True), axis=2)
    return (y + F.log(y_dict + self.alpha))

#class LinearInterpolationLexicon(chainer.Chain):
#  def __init__(self, hidden_size):
#    super(LinearInterpolationLexicon, self).__init__(
#      perceptron = chainer.links.Linear(hidden_size, 1)
#    )
#
#  def __call__(self, y, a, ht, y_lex):
#    y      = F.softmax(y)
#    y_dict = F.squeeze(F.batch_matmul(y_lex, a, transa=True), axis=2)
#    gamma  = F.broadcast_to(F.sigmoid(self.perceptron(ht)), y_dict.data.shape)
#    return (gamma * y_dict + (1-gamma) * y)
#
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号