gan.py source code


Project: SequentialData-GAN  Author: jaesik817
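```python
import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in 2.x
# seq_size, n_hidden, d_num_layers, batch_size and variable_summaries are
# defined elsewhere in gan.py (not shown in this excerpt).


def build_discriminator(x_data, x_generated, keep_prob):
    # Real input [batch_size, seq_size, features] -> seq_size x [batch_size, features].
    x_data = tf.unstack(x_data, seq_size, 1)
    x_generated = list(x_generated)  # generator output, same per-step shape
    # tf.concat packs each list to [seq_size, batch_size, features] and joins
    # them on the batch axis: [seq_size, 2 * batch_size, features].
    x_in = tf.concat([x_data, x_generated], 1)
    # static_rnn expects a list of seq_size tensors of shape [2 * batch_size, features].
    x_in = tf.unstack(x_in, seq_size, 0)
    # d_num_layers stacked LSTM layers, each with dropout on its output.
    lstm_cell = tf.contrib.rnn.MultiRNNCell([
        tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(n_hidden),
                                      output_keep_prob=keep_prob)
        for _ in range(d_num_layers)])
    with tf.variable_scope("dis") as dis:
        weights = tf.Variable(tf.random_normal([n_hidden, 1]))
        biases = tf.Variable(tf.random_normal([1]))
        outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x_in, dtype=tf.float32)
        # One logit per sequence from the last time step: [2 * batch_size, 1].
        res = tf.matmul(outputs[-1], weights) + biases
        # Rows [0, batch_size) are real samples; the remaining rows are generated.
        y_data = tf.nn.sigmoid(tf.slice(res, [0, 0], [batch_size, -1]))
        y_generated = tf.nn.sigmoid(tf.slice(res, [batch_size, 0], [-1, -1]))
        # Discriminator parameters: every variable created under the "dis" scope.
        d_params = [v for v in tf.global_variables() if v.name.startswith(dis.name)]
    # TensorBoard summaries for each discriminator parameter.
    with tf.name_scope("desc_params"):
        for param in d_params:
            variable_summaries(param)
    return y_data, y_generated, d_params
```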
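The discriminator scores real and generated sequences in a single pass. At every time step the generated rows are appended after the real rows along the batch axis, one logit per row is read out from the last time step, and the first `batch_size` sigmoid outputs are taken as the real scores while the rest are the generated scores. The pure-Python sketch below illustrates only that bookkeeping and is not code from the project. The LSTM stack is replaced by a plain linear read-out on the last time step, and the helper names `concat_batches`, `linear_readout` and `split_scores` are hypothetical.

```python
from typing import List
import math

# Hypothetical helpers illustrating the batching used by build_discriminator.
# Not code from SequentialData-GAN: the LSTM stack is replaced by a plain
# linear read-out on the last time step.
Step = List[List[float]]  # one time step: rows = batch entries, columns = features


def concat_batches(real: List[Step], fake: List[Step]) -> List[Step]:
    """At every time step, append the generated rows after the real rows."""
    if len(real) != len(fake):
        raise ValueError("real and fake must have the same number of time steps")
    return [r + f for r, f in zip(real, fake)]


def linear_readout(step: Step, weights: List[float], bias: float) -> List[float]:
    """One logit per row: dot(row, weights) + bias (stand-in for matmul on outputs[-1])."""
    return [sum(x * w for x, w in zip(row, weights)) + bias for row in step]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def split_scores(logits: List[float], batch_size: int):
    """First batch_size probabilities score real rows; the rest score generated rows."""
    probs = [sigmoid(z) for z in logits]
    return probs[:batch_size], probs[batch_size:]


# Usage: seq_size = 2, batch_size = 2, one feature per time step.
real = [[[1.0], [2.0]], [[3.0], [4.0]]]
fake = [[[5.0], [6.0]], [[7.0], [8.0]]]
combined = concat_batches(real, fake)  # each step now has 2 * batch_size rows
logits = linear_readout(combined[-1], [1.0], -5.0)
y_data, y_generated = split_scores(logits, batch_size=2)
```

Joining the two halves on the batch axis means one forward pass with one set of weights scores both real and generated data; slicing at `batch_size` then recovers the two groups of probabilities.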