train.py source code

python

Project: chainer-image-caption  Author: dsanno
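import numpy as np
import chainer
import chainer.functions as F

# xp (numpy or cupy) and eos (the end-of-sentence token ID) are module-level
# names defined elsewhere in train.py.


def forward(net, image_batch, sentence_batch, train=True):
    images = xp.asarray(image_batch)
    n, sentence_length = sentence_batch.shape
    net.initialize(images)
    loss = 0
    acc = 0
    size = 0
    for i in range(sentence_length - 1):
        # 1.0 where the current token is not EOS, 0.0 for finished sentences
        target = xp.where(xp.asarray(sentence_batch[:, i]) != eos, 1, 0).astype(np.float32)
        if (target == 0).all():
            break  # every sentence in the batch has ended
        with chainer.using_config('train', train):
            with chainer.using_config('enable_backprop', train):
                x = xp.asarray(sentence_batch[:, i])      # current tokens
                t = xp.asarray(sentence_batch[:, i + 1])  # next tokens (labels)
                y = net(x)
                y_max_index = xp.argmax(y.data, axis=1)
                # Zero the logits of finished sentences so they send no gradient back
                mask = target.reshape((len(target), 1)).repeat(y.data.shape[1], axis=1)
                y = y * mask
                loss += F.softmax_cross_entropy(y, t)
                acc += xp.sum((y_max_index == t) * target)
                size += xp.sum(target)
    # Normalize the summed loss and the correct count by the number of unfinished tokens
    return loss / size, float(acc) / size, float(size)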
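
The masking and accuracy bookkeeping inside the loop can be tried in isolation. The following is a minimal sketch in plain NumPy, not part of the original train.py: the helper name `masked_step_stats`, the `EOS = 0` value, and the toy logits are all assumptions made for illustration.

```python
import numpy as np

EOS = 0  # hypothetical end-of-sentence token ID; the real value is set elsewhere in train.py


def masked_step_stats(logits, x, t, eos=EOS):
    """Return (masked_logits, n_correct, n_active) for one time step."""
    # 1.0 where the current input token is not EOS, i.e. the sentence is still running
    target = np.where(x != eos, 1, 0).astype(np.float32)
    # Broadcast the per-sample flag across the vocabulary axis
    masked_logits = logits * target[:, None]
    # Predictions are taken from the unmasked logits, as in forward()
    y_max_index = np.argmax(logits, axis=1)
    n_correct = float(np.sum((y_max_index == t) * target))
    n_active = float(np.sum(target))
    return masked_logits, n_correct, n_active


logits = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 0.1],
                   [0.0, 0.4, 3.0]], dtype=np.float32)
x = np.array([2, 0, 1])  # the second sample has already reached EOS
t = np.array([1, 0, 2])  # next-token labels

masked, n_correct, n_active = masked_step_stats(logits, x, t)
print(masked[1])            # logits of the finished sentence are zeroed
print(n_correct, n_active)  # only unfinished sentences are counted
```

The second sample predicts its label correctly but is excluded from both counts because its input token is EOS, which mirrors how `acc` and `size` are accumulated in `forward`.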