train.py source code

python

Project: TF_MemN2N-tableQA, author: vendi12

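import numpy as np


def categorical_accuracy(y_true, y_pred, mask=None):
    '''
    categorical_accuracy adjusted for padding mask.

    y_true, y_pred: arrays of shape (..., n_classes) holding one-hot targets
    and predicted scores. mask: array of shape (...) whose nonzero entries
    mark real (non-padding) positions; None means every position counts.
    '''
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Collapse all leading axes so each row is one prediction over n_classes
    eval_shape = (-1, y_true.shape[-1])
    y_true_ = np.reshape(y_true, eval_shape)
    y_pred_ = np.reshape(y_pred, eval_shape)
    # True wherever the predicted class matches the target class
    comped = np.equal(np.argmax(y_true_, axis=-1),
                      np.argmax(y_pred_, axis=-1))
    if mask is not None:
        ## not sure how to do this in TensorFlow
        flat_mask = np.ravel(mask)
        # Indices of the real (non-padding) positions
        good_entries = flat_mask.nonzero()[0]
        return np.mean(np.take(comped, good_entries))
    # No mask: plain categorical accuracy over every position
    return np.mean(comped)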
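The same masked metric can also be written as a weighted mean instead of gathering the unmasked indices: padded positions get weight 0 and drop out of both the numerator and the denominator. Because this form uses only an elementwise product and two sums, it carries over more directly to frameworks that provide sum reductions, which is one possible answer to the "not sure how to do this in TensorFlow" note, though no TensorFlow port is shown here. This is a minimal NumPy sketch assuming one-hot targets of shape `(batch, time, n_classes)` and a 0/1 mask of shape `(batch, time)`; the name `masked_categorical_accuracy` and the toy data are hypothetical and not part of the original project.

```python
import numpy as np


def masked_categorical_accuracy(y_true, y_pred, mask):
    """Fraction of unmasked positions where argmax(y_pred) == argmax(y_true).

    Hypothetical helper: y_true and y_pred have shape (..., n_classes);
    mask has shape (...) with 1 for real tokens and 0 for padding.
    """
    hits = np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)
    weights = np.asarray(mask, dtype=float)
    # Weighted mean: padded positions have weight 0 and drop out of both sums
    return float(np.sum(hits * weights) / np.sum(weights))


# Two sequences of length 3 over 2 classes; the last step of the second
# sequence is padding (and happens to be a wrong prediction).
y_true = np.array([[[1, 0], [0, 1], [1, 0]],
                   [[0, 1], [1, 0], [1, 0]]])
y_pred = np.array([[[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]],
                   [[0.4, 0.6], [0.6, 0.4], [0.1, 0.9]]])
mask = np.array([[1, 1, 1],
                 [1, 1, 0]])

print(masked_categorical_accuracy(y_true, y_pred, mask))  # 0.8
# With an all-ones mask the padded miss is counted: 4 hits out of 6
print(masked_categorical_accuracy(y_true, y_pred, np.ones_like(mask)))
```

With the mask applied, 4 of the 5 real positions are correct; without it, the padded step drags the score down to 4 of 6, which is the distortion the padding mask is meant to avoid.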