gen_model.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:TextGAN 作者: AustinStoneProjects 项目源码 文件源码
def __init__(self, args):
        """Set up the generator's recurrent cell, placeholders and variables.

        Everything except the cell-class selection is created inside the
        ``'GEN'`` variable scope.

        Args:
            args: configuration object. The attributes read here are
                ``gen_model`` (one of ``'rnn'``, ``'gru'``, ``'lstm'``),
                ``rnn_size``, ``num_layers``, ``batch_size``, ``seq_length``
                and ``latent_size``.

        Raises:
            Exception: if ``args.gen_model`` is not a supported cell type.
        """
        self.args = args

        if args.gen_model == 'rnn':
            cell_fn = rnn_cell.BasicRNNCell
        elif args.gen_model == 'gru':
            cell_fn = rnn_cell.GRUCell
        elif args.gen_model == 'lstm':
            cell_fn = rnn_cell.BasicLSTMCell
        else:
            # Report the attribute that was actually dispatched on; the
            # previous message read ``args.model``, which is not the option
            # checked above.
            raise Exception("model type not supported: {}".format(args.gen_model))

        with tf.variable_scope('GEN') as scope:
            # Build a separate cell object for every layer instead of
            # repeating one instance in the list.
            # NOTE(review): later TensorFlow 1.x releases appear to reject or
            # weight-share a reused cell instance -- confirm against the
            # TensorFlow version this project pins.
            layer_cells = [cell_fn(args.rnn_size) for _ in range(args.num_layers)]
            self.cell = rnn_cell.MultiRNNCell(layer_cells)
            # sequence of word tokens taken as input
            self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length], name='input_data')

            self.latent_state = tf.placeholder(tf.float32, [args.batch_size, args.latent_size])

            # weights to map the latent state into the (usually) bigger initial state
            # right now this only works for rnn (other more complex models have more than
            # one initial state which needs to be given a value)
            # Right now we support up to two layers (state1 and state2)
            # NOTE: the variable names below keep their historical spelling
            # ("intial"); they are left unchanged here.
            self.latent_to_initial_state1 = tf.Variable(tf.random_normal([args.latent_size, args.rnn_size], stddev=0.35, dtype=tf.float32), name='latent_to_intial_state1')
            self.latent_to_initial_state2 = tf.Variable(tf.random_normal([args.latent_size, args.rnn_size], stddev=0.35, dtype=tf.float32), name='latent_to_intial_state2')
            # Project the latent vector to one rnn_size-wide initial state per layer.
            self.initial_state1 = tf.matmul(self.latent_state, self.latent_to_initial_state1)
            self.initial_state2 = tf.matmul(self.latent_state, self.latent_to_initial_state2)
            # these are the actual approximate word vectors generated by the model
            self.outputs = tf.placeholder(tf.float32, [args.seq_length, args.batch_size, args.rnn_size])
            # Non-trainable learning-rate variable, initialised to 0.0.
            self.lr = tf.Variable(0.0, trainable=False, name='learning_rate')
            self.has_init_seq2seq = False
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号