recurrentNetwork.py source file

python

Project: TikZ  Author: ellisk42
import tensorflow as tf

# Method of the recurrent network class; relies on self.hardOutputs,
# self.lengthPlaceholder and self.maximumLength being defined on the instance.
def decodesIntoAccuracy(self, labels, perSymbol = True):
    # accuracyMatrix has dimensions None x L (batch size x maximum length)
    accuracyMatrix = tf.equal(self.hardOutputs, labels)

    # zero out anything past the labeled length
    accuracyMatrix = tf.logical_and(accuracyMatrix,
                                    tf.sequence_mask(self.lengthPlaceholder, maxlen = self.maximumLength))

    # Sum across all of the time steps to get the total number of correct predictions in each batch entry
    accuracyVector = tf.reduce_sum(tf.cast(accuracyMatrix, tf.int32), axis = 1)
    if perSymbol:
        # Normalize each entry by its sequence length: fraction of symbols correct
        accuracyVector = tf.divide(tf.cast(accuracyVector, tf.float32),
                                   tf.cast(self.lengthPlaceholder, tf.float32))
    else:
        # Accuracy is measured per sequence: 1.0 only if every labeled symbol is correct
        accuracyVector = tf.cast(tf.equal(accuracyVector, self.lengthPlaceholder), tf.float32)
    # Average over the batch
    return tf.reduce_mean(accuracyVector)
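
To make the two accuracy modes concrete, here is a minimal pure-Python sketch of the same masked computation on a tiny hand-made batch. The function name `masked_accuracy` and the sample data are hypothetical and not part of the original project; the sketch only mirrors the logic of the method above and assumes every length is at least 1.

```python
def masked_accuracy(predictions, labels, lengths, per_symbol=True):
    """Illustrative stand-in for decodesIntoAccuracy, without TensorFlow.

    predictions, labels: rows padded to a common width.
    lengths: number of labeled (non-padding) positions in each row.
    """
    scores = []
    for predicted_row, label_row, length in zip(predictions, labels, lengths):
        # Count matches only inside the labeled length; padding is ignored.
        correct = sum(
            1
            for predicted, label in zip(predicted_row[:length], label_row[:length])
            if predicted == label
        )
        if per_symbol:
            # Fraction of labeled symbols predicted correctly.
            scores.append(correct / length)
        else:
            # 1.0 only when the whole labeled sequence is correct.
            scores.append(1.0 if correct == length else 0.0)
    # Average over the batch.
    return sum(scores) / len(scores)


# Hypothetical batch: row 0 has 3 labeled symbols, row 1 has 2.
predictions = [[1, 2, 3, 0], [4, 5, 9, 9]]
labels = [[1, 2, 7, 0], [4, 5, 0, 0]]
lengths = [3, 2]

per_symbol_accuracy = masked_accuracy(predictions, labels, lengths, per_symbol=True)
per_sequence_accuracy = masked_accuracy(predictions, labels, lengths, per_symbol=False)
print(per_symbol_accuracy, per_sequence_accuracy)
```

Row 0 gets two of its three labeled symbols right and row 1 gets both of its two right, so the per-symbol score averages 2/3 and 1, while the per-sequence score counts only row 1 as correct. The mismatches in row 1's padding positions do not affect either number.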