train2.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:chainer-stack-gan 作者: dsanno 项目源码 文件源码
def update(gen1, gen2, dis, optimizer_gen, optimizer_dis, x_batch, margin):
    """Perform one optimisation step for ``gen2`` and then for ``dis``.

    A batch of latent vectors is sampled, pushed through ``gen1`` and
    ``gen2`` to obtain generated samples, and ``dis`` is evaluated on both
    the generated samples and ``x_batch``.  ``dis`` is expected to return a
    pair ``(y, h)``; ``y`` is compared element-wise with the input of
    ``dis`` (presumably a reconstruction, EBGAN-style -- confirm against the
    model definition) and ``h`` is an intermediate feature.

    Losses:
        * generator: mean squared error between the generated samples and
          ``y`` plus ``0.1`` times the mean pairwise dot product of the
          L2-normalised, flattened ``h`` rows.
        * discriminator: ``relu(margin**2 - per-sample squared L2 error)``
          on generated samples, summed and divided by the total number of
          elements in the generated batch, plus the mean squared error
          between ``x_batch`` and its ``y``.

    Args:
        gen1: First-stage generator; must expose ``xp``.  It is called with
            a volatile ``Variable`` and ``train=False`` and only its raw
            ``.data`` output is passed on to ``gen2``.
        gen2: Second-stage generator whose gradients are cleared before the
            generator loss is back-propagated.
        dis: Discriminator returning ``(y, h)``; its gradients are cleared
            before the discriminator loss is back-propagated.
        optimizer_gen: Optimizer stepped after the generator backward pass
            (presumably set up on ``gen2`` -- verify at the call site).
        optimizer_dis: Optimizer stepped after the discriminator backward
            pass (presumably set up on ``dis`` -- verify at the call site).
        x_batch: Batch of real samples; ``len(x_batch)`` sets the batch size.
        margin: Margin of the discriminator hinge term (used squared).

    Returns:
        Tuple ``(generator_loss, discriminator_loss)`` as Python floats.
    """
    xp = gen1.xp
    n_samples = len(x_batch)

    # ---- generated samples ------------------------------------------------
    # Gaussian latent codes, each row rescaled to (almost) unit L2 norm; the
    # epsilon keeps the division finite for an all-zero row.
    noise = xp.random.normal(0, 1, (n_samples, latent_size)).astype(np.float32)
    row_norm = xp.linalg.norm(noise, axis=1, keepdims=True)
    noise = noise / (row_norm + 1e-12)

    stage1_out = gen1(Variable(noise, volatile=True), train=False)
    # Hand over the bare array so a fresh graph starts at gen2.
    fake = gen2(stage1_out.data)
    del noise, row_norm, stage1_out
    n_elements = np.prod(fake.shape)

    recon_fake, feature = dis(fake)
    feature = F.normalize(F.reshape(feature, (n_samples, -1)))
    gram = F.matmul(feature, feature, transb=True)
    del feature
    similarity = F.sum(gram) / (n_samples * n_samples)
    del gram

    loss_gen = F.mean_squared_error(fake, recon_fake) + 0.1 * similarity
    fake_error = F.batch_l2_norm_squared(fake - recon_fake)
    # Drop local references that are no longer needed by this function.
    del fake, recon_fake, similarity
    hinge = F.relu(margin * margin - fake_error)
    del fake_error
    loss_dis_fake = F.sum(hinge) / n_elements
    del hinge

    # ---- real samples -----------------------------------------------------
    real = xp.asarray(x_batch)
    recon_real, _ = dis(real)
    loss_dis = loss_dis_fake + F.mean_squared_error(real, recon_real)

    # ---- generator step ---------------------------------------------------
    gen2.cleargrads()
    loss_gen.backward()
    optimizer_gen.update()
    gen_loss_value = float(loss_gen.data)
    del loss_gen

    # ---- discriminator step -----------------------------------------------
    # Clearing here also discards what the generator backward pass left in
    # the discriminator's gradients.
    dis.cleargrads()
    loss_dis.backward()
    optimizer_dis.update()
    dis_loss_value = float(loss_dis.data)

    return gen_loss_value, dis_loss_value
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号