gru_seq2seq.py source code

python

Project: tensorflow    Author: ananthpn
# Requires TensorFlow 1.x: tf.placeholder and tf.contrib were removed in TensorFlow 2.
import tensorflow as tf
from tensorflow.contrib import layers

def __init__(self, clf_pic_name, input_dims, hidden_size, num_decoder_symbols, learning_rate=0.01, maxlen_to_decode=16):
        self.clf_pic_name = clf_pic_name  # file name the trained model is saved to
        # set up some common variables
        self.start_of_sequence_id = 0  # symbol ID fed to the decoder as its first input
        self.end_of_sequence_id = 0    # symbol ID whose prediction terminates decoding
        self.encoder_hidden_size = hidden_size
        self.decoder_hidden_size = self.encoder_hidden_size
        self.learning_rate = learning_rate
        self.decoder_sequence_length = maxlen_to_decode  # maximum number of steps the decoder predicts before terminating

        # placeholders and variables
        self.encoder_length = tf.placeholder(tf.int32, [None])  # per-example encoder sequence lengths, for dynamic time unrolling
        self.decoder_length = tf.placeholder(tf.int32, [None])  # per-example decoder sequence lengths, for dynamic time unrolling
        self.encoder_embedding_size = input_dims
        self.decoder_embedding_size = self.encoder_embedding_size
        self.decoder_embeddings = tf.get_variable('decoder_embeddings',
                [self.decoder_embedding_size, self.decoder_embedding_size])
        self.num_decoder_symbols = num_decoder_symbols  # number of output classes of the decoder
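        with tf.variable_scope("rnn") as scope:
            # output projection from a decoder hidden state to num_decoder_symbols logits;
            # its weights are created in this scope the first time output_fn is called
            self.output_fn = lambda x: layers.linear(x, self.num_decoder_symbols,
                                                     scope=scope)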
        self.inputs = tf.placeholder("float", [None, None, self.encoder_embedding_size])
        self.decoder_inputs = tf.placeholder("float", [None, None, self.decoder_embedding_size])
        self.encoder_targets = tf.placeholder("float", [None, None, self.num_decoder_symbols])
        self.decoder_targets = tf.placeholder("float", [None, None, self.num_decoder_symbols])

        # build the model's computation graph
        self.encoder()
        self.decoder_train()
        self.decoder_inference()
        self.compute_cost()
        self.optimize()
        self.get_sm_outputs()
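The constructor only declares placeholders and an output projection; the methods it calls (encoder(), decoder_train(), get_sm_outputs() and so on) are not shown here. As a rough, framework-free illustration of the tensors involved, the NumPy sketch below assumes batch-major inputs of shape [batch, time, features], zero-padded to the longest sequence, with the true lengths supplied separately (the role of encoder_length and decoder_length). It then applies an affine projection from hidden_size to num_decoder_symbols, which is what layers.linear computes (a fully connected layer with no activation), followed by a softmax. That get_sm_outputs() does the same is an assumption based only on its name. The helpers pad_batch, linear and softmax are hypothetical, and the random hidden states merely stand in for real GRU outputs.

```python
import numpy as np


def pad_batch(sequences, input_dims):
    """Zero-pad a list of [time, input_dims] arrays into one
    [batch, max_time, input_dims] array; also return the true lengths,
    i.e. what an int32 length placeholder such as encoder_length would receive."""
    lengths = np.array([len(seq) for seq in sequences], dtype=np.int32)
    batch = np.zeros((len(sequences), int(lengths.max()), input_dims), dtype=np.float32)
    for i, seq in enumerate(sequences):
        batch[i, :len(seq)] = seq  # steps past len(seq) stay zero
    return batch, lengths


def linear(x, weights, bias):
    """Affine map over the last axis: [..., hidden_size] -> [..., num_symbols]."""
    return x @ weights + bias


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Hypothetical sizes, chosen only for illustration.
input_dims, hidden_size, num_decoder_symbols = 4, 8, 5
rng = np.random.default_rng(0)

# Two sequences of different lengths (3 and 5 time steps).
seqs = [rng.normal(size=(3, input_dims)), rng.normal(size=(5, input_dims))]
inputs, lengths = pad_batch(seqs, input_dims)
print(inputs.shape, lengths)  # (2, 5, 4) [3 5]

# Stand-in for per-step decoder hidden states: [batch, time, hidden_size].
hidden = rng.normal(size=(2, 5, hidden_size))
weights = rng.normal(size=(hidden_size, num_decoder_symbols))
bias = np.zeros(num_decoder_symbols)

probs = softmax(linear(hidden, weights, bias))  # [batch, time, num_decoder_symbols]
predicted = probs.argmax(axis=-1)               # greedy symbol choice per step
print(probs.shape, predicted.shape)  # (2, 5, 5) (2, 5)
```

Padded steps past each sequence's length hold zeros, so a loss over the decoder outputs would normally mask them, for example with np.arange(inputs.shape[1]) < lengths[:, None]. Whether compute_cost() in the original does so is not visible in this snippet.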