deepAPI_model.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Simple-DeepAPISearcher 作者: dlcjfgmlnasa 项目源码 文件源码
def __init__(self,
                 encoder_size,
                 decoder_size,
                 encoder_vocab_size,
                 decoder_vocab_size,
                 encoder_layer_size,
                 decoder_layer_size,
                 RNN_type='LSTM',
                 encoder_input_keep_prob=1.0,
                 encoder_output_keep_prob=1.0,
                 decoder_input_keep_prob=1.0,
                 decoder_output_keep_prob=1.0,
                 learning_rate=0.01,
                 hidden_size=128):
        """Record hyper-parameters, create graph inputs/variables and build the model.

        Args:
            encoder_size: Second dimension of the ``encoder_input`` placeholder
                (presumably the encoder sequence length -- confirm with caller).
            decoder_size: Second dimension of the ``decoder_input`` and
                ``target_input`` placeholders (presumably the decoder
                sequence length -- confirm with caller).
            encoder_vocab_size: Last dimension of ``encoder_input``.
            decoder_vocab_size: Last dimension of ``decoder_input``; also the
                output dimension of the ``weight`` / ``bias`` variables.
            encoder_layer_size: Stored on the instance; presumably the number
                of stacked encoder cells used by ``build_model`` -- TODO confirm.
            decoder_layer_size: Stored on the instance; presumably the number
                of stacked decoder cells used by ``build_model`` -- TODO confirm.
            RNN_type: ``'LSTM'`` or ``'GRU'``; selects the cell class stored in
                ``self.RNNCell``.
            encoder_input_keep_prob: Stored on the instance (presumably a
                dropout keep probability -- TODO confirm in ``build_model``).
            encoder_output_keep_prob: See ``encoder_input_keep_prob``.
            decoder_input_keep_prob: See ``encoder_input_keep_prob``.
            decoder_output_keep_prob: See ``encoder_input_keep_prob``.
            learning_rate: Stored on the instance for later use.
            hidden_size: First dimension of the ``weight`` variable.

        Raises:
            ValueError: If ``RNN_type`` is neither ``'LSTM'`` nor ``'GRU'``.
        """
        # Validate the cell type before creating any graph nodes: failing
        # after the placeholders/variables exist would leave them behind in
        # the default graph, and a corrected retry could then collide on the
        # already-registered 'weight' / 'bias' variable names.
        if RNN_type == 'LSTM':
            self.RNNCell = rnn.LSTMCell
        elif RNN_type == 'GRU':
            self.RNNCell = rnn.GRUCell
        else:
            # ValueError is a subclass of Exception, so callers that catch
            # Exception keep working; the message text is unchanged.
            raise ValueError('not support {} RNN type'.format(RNN_type))

        # Hyper-parameters, kept on the instance for use by other methods.
        self.encoder_size = encoder_size
        self.decoder_size = decoder_size
        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        self.encoder_layer_size = encoder_layer_size
        self.decoder_layer_size = decoder_layer_size
        self.encoder_input_keep_prob = encoder_input_keep_prob
        self.encoder_output_keep_prob = encoder_output_keep_prob
        self.decoder_input_keep_prob = decoder_input_keep_prob
        self.decoder_output_keep_prob = decoder_output_keep_prob
        self.learning_rate = learning_rate
        self.hidden_size = hidden_size

        # Graph inputs. The leading None leaves the first dimension
        # unspecified; the float inputs carry a trailing vocab-sized axis
        # while the targets are integer ids without one.
        self.encoder_input = tf.placeholder(
            tf.float32,
            shape=(None, self.encoder_size, self.encoder_vocab_size))
        self.decoder_input = tf.placeholder(
            tf.float32,
            shape=(None, self.decoder_size, self.decoder_vocab_size))
        self.target_input = tf.placeholder(
            tf.int32,
            shape=(None, self.decoder_size))

        # Projection parameters mapping hidden_size -> decoder_vocab_size.
        self.weight = tf.get_variable(
            name='weight',
            shape=[self.hidden_size, self.decoder_vocab_size],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer())
        self.bias = tf.get_variable(
            name='bias',
            shape=[self.decoder_vocab_size],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer())

        # Initialised to None here; presumably populated by build_model()
        # (not visible in this chunk).
        self.logits = None
        self.cost = None
        self.train_op = None
        self.outputs = None
        self.merged = None

        self.build_model()
        # Created after build_model() so the saver covers every global
        # variable that exists once the model has been built.
        self.saver = tf.train.Saver(tf.global_variables())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号