def __init__(self, args):
    """Build the discriminator's variables, placeholders and RNN cell.

    The recurrent cell type is chosen from ``args.disc_model``. The stacked
    cell, the token-input placeholder, the zero initial state, the output
    projection weights and a learning-rate variable are stored on ``self``.

    Args:
        args: Configuration object. Fields read here: ``disc_model``
            (``'rnn'``, ``'gru'`` or ``'lstm'``), ``vocab_size``,
            ``rnn_size``, ``num_layers``, ``batch_size`` and ``seq_length``.

    Raises:
        ValueError: If ``args.disc_model`` is not a supported cell type.
            (Subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.)
    """
    self.args = args
    if args.disc_model == 'rnn':
        cell_fn = rnn_cell.BasicRNNCell
    elif args.disc_model == 'gru':
        cell_fn = rnn_cell.GRUCell
    elif args.disc_model == 'lstm':
        cell_fn = rnn_cell.BasicLSTMCell
    else:
        # Report the field that was actually checked. The old message read
        # ``args.model``, which may not exist and then masked this error.
        raise ValueError(
            "model type not supported: {}".format(args.disc_model))

    # Token embedding table, shape [vocab_size, rnn_size], initialised
    # uniformly in [-0.05, 0.05]. It is created outside the 'DISC' scope.
    # NOTE(review): moving it inside would likely change its variable name
    # (and checkpoint compatibility); confirm before relocating.
    self.embedding = tf.Variable(
        tf.random_uniform([args.vocab_size, args.rnn_size],
                          minval=-.05, maxval=.05, dtype=tf.float32),
        name='embedding')

    with tf.variable_scope('DISC'):
        # Create an independent cell object for each layer. The previous
        # ``[cell] * num_layers`` reused one instance for every layer, which
        # newer TensorFlow versions reject or treat as cross-layer sharing.
        self.cell = cell = rnn_cell.MultiRNNCell(
            [cell_fn(args.rnn_size) for _ in range(args.num_layers)])
        # If the input data is given as word tokens, feed this value.
        self.input_data_text = tf.placeholder(
            tf.int32, [args.batch_size, args.seq_length],
            name='input_data_text')
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)
        # The fully connected layer is applied to the final state to
        # determine the output class. Weights only, shape [rnn_size, 1].
        self.fc_layer = tf.Variable(
            tf.random_normal([args.rnn_size, 1], stddev=0.35,
                             dtype=tf.float32),
            name='disc_fc_layer')
        # Non-trainable and starts at 0.0. Presumably the training code
        # assigns it before use; confirm with the caller.
        self.lr = tf.Variable(0.0, trainable=False, name='learning_rate')
        # Not read in this method. Presumably consumed by other methods of
        # the class; confirm.
        self.has_init_seq2seq = False
评论列表
文章目录