disc_model.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:TextGAN 作者: AustinStoneProjects 项目源码 文件源码
def __init__(self, args):
    """Build the discriminator's TensorFlow variables, cell stack and inputs.

    Args:
        args: Configuration object. This constructor reads ``disc_model``
            (one of ``'rnn'``, ``'gru'``, ``'lstm'``), ``vocab_size``,
            ``rnn_size``, ``num_layers``, ``batch_size`` and ``seq_length``.

    Raises:
        ValueError: If ``args.disc_model`` is not a supported cell type.
            ``ValueError`` subclasses ``Exception``, so handlers written for
            the previous bare ``Exception`` still catch it.
    """
    self.args = args

    # Pick the cell class per branch so that an unsupported type is reported
    # before any TensorFlow object is created.
    if args.disc_model == 'rnn':
        cell_fn = rnn_cell.BasicRNNCell
    elif args.disc_model == 'gru':
        cell_fn = rnn_cell.GRUCell
    elif args.disc_model == 'lstm':
        cell_fn = rnn_cell.BasicLSTMCell
    else:
        # Report the option that was actually checked. The old message used
        # ``args.model``, which named the wrong setting (or raised
        # AttributeError when ``args`` had no ``model`` attribute).
        raise ValueError("model type not supported: {}".format(args.disc_model))

    # Token embedding matrix [vocab_size, rnn_size], uniform in [-0.05, 0.05].
    # Created before entering the 'DISC' variable scope below.
    self.embedding = tf.Variable(
        tf.random_uniform([args.vocab_size, args.rnn_size],
                          minval=-.05, maxval=.05, dtype=tf.float32),
        name='embedding')

    with tf.variable_scope('DISC'):
        # Build one cell instance per layer. Repeating a single instance
        # (``[cell] * num_layers``) is rejected by newer TensorFlow 1.x
        # releases. Distinct instances keep every layer's weights separate.
        self.cell = cell = rnn_cell.MultiRNNCell(
            [cell_fn(args.rnn_size) for _ in range(args.num_layers)])

        # If the input data is given as word tokens, feed this value
        # (int32 token ids, shape [batch_size, seq_length]).
        self.input_data_text = tf.placeholder(
            tf.int32, [args.batch_size, args.seq_length], name='input_data_text')

        # All-zero float32 starting state for the stacked cell.
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

        # Per the original author: fully connected layer applied to the final
        # state to determine the output class ([rnn_size, 1] weights).
        self.fc_layer = tf.Variable(
            tf.random_normal([args.rnn_size, 1], stddev=0.35, dtype=tf.float32),
            name='disc_fc_layer')

        # Non-trainable learning rate starting at 0.0; presumably assigned a
        # real value elsewhere before training (not visible here, confirm).
        self.lr = tf.Variable(0.0, trainable=False, name='learning_rate')

        # Flag starts False; whatever sets it lives outside this constructor.
        self.has_init_seq2seq = False
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号