Source code from model.py

python

Project: tf_han · Author: AlbertXiebnu
# Imports assumed by this snippet (legacy TF 0.x RNN API)
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell

def sentence_embedding(self, inputs, keep_prob, w):
        # Look up word embeddings on the CPU:
        # batch_size x max_len x embedding_dim
        with tf.device('/cpu:0'):
            embedding_layer = tf.nn.embedding_lookup(w['word_embedding_w'], inputs)
        # Reorder to max_len x batch_size x embedding_dim, then split into a
        # Python list of max_len tensors of shape batch_size x embedding_dim,
        # as the TF 0.x rnn API expects. Note this assumes the embedding
        # width equals self.hiddensize.
        cell_input = tf.transpose(embedding_layer, [1, 0, 2])
        cell_input = tf.reshape(cell_input, [-1, self.hiddensize])
        cell_input = tf.split(0, self.max_len, cell_input)  # TF 0.x split(axis, num, value)
        # Forward and backward LSTM cells with dropout on inputs and outputs
        with tf.variable_scope('forward'):
            lstm_fw_cell = rnn_cell.DropoutWrapper(
                rnn_cell.BasicLSTMCell(self.rnnsize, forget_bias=1.0, state_is_tuple=True),
                input_keep_prob=keep_prob, output_keep_prob=keep_prob)
        with tf.variable_scope('backward'):
            lstm_bw_cell = rnn_cell.DropoutWrapper(
                rnn_cell.BasicLSTMCell(self.rnnsize, forget_bias=1.0, state_is_tuple=True),
                input_keep_prob=keep_prob, output_keep_prob=keep_prob)
        outputs, _, _ = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, cell_input, dtype=tf.float32)
        # outputs: list of max_len tensors, each batch_size x (fw_cell_size + bw_cell_size)
        att = self.attention_layer(outputs, w)
        return att
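The transpose/reshape/split sequence converts a batch-major tensor (batch_size x max_len x embedding_dim) into the time-major list of per-timestep matrices that the legacy bidirectional RNN API consumes. A minimal NumPy sketch of just that data movement, with hypothetical sizes:

```python
import numpy as np

# Hypothetical sizes for illustration
batch_size, max_len, embed_dim = 2, 4, 3
x = np.arange(batch_size * max_len * embed_dim).reshape(batch_size, max_len, embed_dim)

# transpose: (batch, time, embed) -> (time, batch, embed)
t = x.transpose(1, 0, 2)
# flatten the leading two axes: (time * batch, embed)
flat = t.reshape(-1, embed_dim)
# split along axis 0 into max_len chunks, each (batch, embed)
steps = np.split(flat, max_len, axis=0)

assert len(steps) == max_len
assert steps[0].shape == (batch_size, embed_dim)
# step i holds the i-th word of every sentence in the batch
assert np.array_equal(steps[1], x[:, 1, :])
```

Because the transpose makes the data time-major before flattening, each chunk after the split lines up with one timestep across the whole batch, which is exactly what the step-wise RNN loop needs.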