def __call__(self, x, context):
    """Compute the skip-gram loss for a minibatch and report it.

    Each element of ``x`` is paired with every entry in the matching row of
    ``context``. Both sides are flattened so that pair ``i`` of the targets
    lines up with row ``i`` of the context representations, which are then
    scored by ``self.loss_func``.

    How the context representations are produced is selected by the
    module-level ``args.subword`` setting:

    * ``'rnn'``  -- ``context`` is flattened to 1-D and passed to
      ``self.rnn.charRNN``; its output is handed to ``self.loss_func``
      unchanged.
    * ``'none'`` -- ``context`` is looked up with ``self.embed`` and the
      3-D result has its two leading axes merged into one.

    Args:
        x: 1-D array broadcastable along the first axis of ``context``
            (length ``context.shape[0]``). Presumably the center-word ids --
            TODO confirm against the caller.
        context: Array with at least two leading axes ``(batch, window)``.
            Presumably the surrounding-word ids -- TODO confirm.

    Returns:
        Whatever ``self.loss_func(e, x)`` returns; the same value is reported
        under the ``'loss'`` key via ``reporter.report``.

    Raises:
        ValueError: If ``args.subword`` is neither ``'rnn'`` nor ``'none'``.
    """
    batch_size, window = context.shape[0], context.shape[1]
    n_pairs = batch_size * window

    # Repeat each target across its context window, then flatten so that
    # target i lines up with flattened context row i.
    x = F.broadcast_to(x[:, None], (batch_size, window))
    x = F.reshape(x, (n_pairs,))

    if args.subword == 'rnn':
        # Flatten the (batch, window) ids to 1-D before the char RNN; pass a
        # real 1-tuple, matching the shape used for ``x`` above.
        context = context.reshape((n_pairs,))
        e = self.rnn.charRNN(context)
    elif args.subword == 'none':
        e = self.embed(context)
        # Merge the two leading axes: (d0, d1, d2) -> (d0 * d1, d2), so the
        # rows align with the flattened targets in ``x``.
        e = F.reshape(e, (e.shape[0] * e.shape[1], e.shape[2]))
    else:
        # Without this branch an unsupported mode would leave ``e`` unbound
        # and fail later with an opaque UnboundLocalError.
        raise ValueError(
            "unsupported args.subword mode: {!r} (expected 'rnn' or 'none')"
            .format(args.subword))

    loss = self.loss_func(e, x)
    reporter.report({'loss': loss}, self)
    return loss
train_word2vec_subword_chainer_input.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录