cwgan.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Dialog-System-with-GAN-model 作者: drcut 项目源码 文件源码
def generator(encoder_inputs, decoder_inputs, target_weights, bucket_id, seq_len):
    """Build the bucketed attention seq2seq graph and return token distributions.

    One seq2seq model is built per entry of the module-level ``buckets`` via
    ``tf.contrib.legacy_seq2seq.model_with_buckets``. The per-step outputs of
    the bucket selected by ``bucket_id`` are padded with all-zero steps up to
    ``max_len``, stacked, reshaped to ``[max_len, batch_size, num_symbols]``
    and passed through a softmax.

    Relies on names defined outside this function: ``tf``, ``embedding_size``,
    ``num_layers``, ``num_symbols``, ``buckets``, ``batch_size``, ``max_len``.

    Args:
        encoder_inputs: Forwarded unchanged to ``model_with_buckets``.
        decoder_inputs: Forwarded to ``model_with_buckets``; also used as the
            ``targets`` argument.
        target_weights: Forwarded unchanged to ``model_with_buckets``.
        bucket_id: Value compared with each bucket index through ``tf.equal``
            to choose which bucket's outputs are returned.
        seq_len: Unused here; kept so existing callers keep working.

    Returns:
        The softmax of a tensor reshaped to
        ``[max_len, batch_size, num_symbols]``. Padded steps are all-zero
        logits, so after the softmax they are uniform distributions.
    """
    def seq2seq_f(encoder, decoder):
        """Build one attention seq2seq model and project outputs to the vocabulary."""
        def make_cell():
            return tf.contrib.rnn.BasicLSTMCell(embedding_size)

        if num_layers > 1:
            # Every layer needs its own cell object. Repeating one instance
            # ([cell] * num_layers) makes the layers share a single cell,
            # which either fails on variable-scope reuse or ties the layers'
            # weights together, depending on the TensorFlow 1.x release.
            cell = tf.contrib.rnn.MultiRNNCell(
                [make_cell() for _ in range(num_layers)])
        else:
            cell = make_cell()

        w = tf.get_variable("proj_w", [embedding_size, num_symbols])
        b = tf.get_variable("proj_b", [num_symbols])
        output_projection = (w, b)
        outputs, state = tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
            encoder, decoder, cell, num_symbols, num_symbols, embedding_size,
            output_projection=output_projection,
            feed_previous=True)
        # With output_projection set, the decoder returns unprojected cell
        # outputs, so the projection to vocabulary size is applied here.
        trans_output = [tf.matmul(output, w) + b for output in outputs]
        return trans_output, state

    # NOTE(review): targets are the decoder inputs themselves (not shifted by
    # one step); the resulting losses are not used in this function.
    targets = decoder_inputs
    outputs, losses = tf.contrib.legacy_seq2seq.model_with_buckets(
        encoder_inputs, decoder_inputs, targets,
        target_weights, buckets, seq2seq_f,
        softmax_loss_function=None,
        per_example_loss=False, name='model_with_buckets')

    # A single all-zero step used to pad shorter buckets up to max_len.
    patch = tf.zeros([batch_size, num_symbols], dtype=tf.float32)

    def make_branch(index):
        """Return a callable producing bucket ``index``'s outputs padded to max_len."""
        pad_steps = max_len - buckets[index][1]

        def branch():
            # Build a new list instead of appending to outputs[index], so the
            # lists returned by model_with_buckets are left untouched.
            padded = list(outputs[index]) + [patch] * pad_steps
            return tf.convert_to_tensor(padded, dtype=tf.float32)

        return branch

    # One branch per bucket, so the selection works for any number of buckets
    # rather than exactly three.
    branches = [make_branch(index) for index in range(len(buckets))]
    if len(branches) == 1:
        r = branches[0]()
    else:
        # The last bucket is the default branch, as in the three-bucket
        # layout where bucket 2 was the default. A list of pairs is used so
        # tensors never have to serve as dict keys.
        pred_fn_pairs = [(tf.equal(bucket_id, index), fn)
                         for index, fn in enumerate(branches[:-1])]
        r = tf.case(pred_fn_pairs, default=branches[-1], exclusive=True)
    return tf.nn.softmax(tf.reshape(r, [max_len, batch_size, num_symbols]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号