tflittle_gru.py source code

python
Project: identifiera-sarkasm  Author: risnejunior
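# Assumed context (not shown in this excerpt): TensorFlow 1.x, with
# `import tensorflow as tf` and `from tensorflow.contrib import rnn`;
# `net` is a helper module from the same project.
def feed_network(self, data, keep_prob, chunk_size, n_chunks, dynamic):
        # This code is copied from tflearn
        # With dynamic=True, compute per-sequence lengths so the RNN can stop at each sequence's end.
        sequence_lengths = None
        if dynamic:
            sequence_lengths = net.calc_seqlenth(data if isinstance(data, tf.Tensor) else tf.stack(data))
        # Apply dropout to the output-layer weights and to the GRU cell outputs.
        weight_dropout = tf.nn.dropout(self._layer_weights, keep_prob)
        rnn_dropout = rnn.core_rnn_cell.DropoutWrapper(self._gru_cell, output_keep_prob=keep_prob)

        # Calculation Begin
        # Reorder [batch, time, ...] to [time, batch, ...] and split into one tensor per timestep.
        input_shape = data.get_shape().as_list()
        ndim = len(input_shape)
        axis = [1, 0] + list(range(2, ndim))
        data = tf.transpose(data, axis)
        sequence = tf.unstack(data)
        outputs, states = rnn.static_rnn(rnn_dropout, sequence, dtype=tf.float32, sequence_length=sequence_lengths)
        if dynamic:
            # Back to [batch, time, units], then select the output at each sequence's
            # last valid step (as tflearn's helper of the same name does).
            outputs = tf.transpose(tf.stack(outputs), [1, 0, 2])
            output = net.advanced_indexing_op(outputs, sequence_lengths)
        else:
            # Fixed-length input: use the final timestep's output.
            output = outputs[-1]
        # Linear output layer: output * W + b.
        output = tf.add(tf.matmul(output, weight_dropout), self._layer_biases)
        return output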
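
The dynamic branch above depends on two project helpers, `net.calc_seqlenth` and `net.advanced_indexing_op`, whose source is not shown here. The sketch below illustrates, in plain Python on nested lists, what such helpers would compute if they follow the tflearn functions they appear to be copied from: a sequence's length is the number of timesteps with at least one non-zero feature, and the selected output is the one at the last such step. The function names and the sample batch are hypothetical and are not part of the project.

```python
# Illustrative only: plain-Python stand-ins for the tensor helpers used above.
# Assumption: padding steps are all-zero vectors at the end of each sequence.

def sequence_lengths(batch):
    """Count, per sequence, the timesteps that have at least one non-zero feature."""
    return [sum(1 for step in seq if any(v != 0 for v in step)) for seq in batch]


def last_relevant_outputs(outputs, lengths):
    """For each batch item b, pick outputs[b][lengths[b] - 1]."""
    return [seq[n - 1] for seq, n in zip(outputs, lengths)]


def to_time_major(batch):
    """Swap the batch and time axes: [batch, time, ...] -> [time, batch, ...]."""
    return [list(step) for step in zip(*batch)]


# Two sequences of three timesteps with two features each.
batch = [
    [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],  # two real steps, one padding step
    [[5.0, 6.0], [7.0, 8.0], [9.0, 1.0]],  # three real steps
]

lengths = sequence_lengths(batch)
# Here the inputs stand in for RNN outputs, purely to show the indexing.
last = last_relevant_outputs(batch, lengths)
time_major = to_time_major(batch)
```

With `dynamic` disabled, the method simply takes `outputs[-1]`, which for padded input would be the output at a padding step; the length-based selection avoids that.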