adverse_models.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:nli_generation 作者: jstarc 项目源码 文件源码
def adverse_model(discriminator):

    train_input = Input(shape=(None,), dtype='int32')
    hypo_input = Input(shape=(None,), dtype='int32')

    def margin_opt(inputs):
        assert len(inputs) == 2, ('Margin Output needs '
                              '2 inputs, %d given' % len(inputs))
        return K.log(inputs[0]) + K.log(1-inputs[1])

    margin = Lambda(margin_opt, output_shape=(lambda s : (None, 1)))\
               ([discriminator(train_input), discriminator(hypo_input)])
    adverserial = Model([train_input, hypo_input], margin)

    adverserial.compile(loss=minimize, optimizer='adam')
    return adverserial
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号