bench_classify_online.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DeepRepICCV2015 作者: tomrunia 项目源码 文件源码
def get_next_count(classify, test_set_x, data, valid, global_count, curr_residue, start_frame):
    """Advance a running count using whichever stride yields the lowest entropy.

    Runs ``count_in_interval`` for stride 2 (always) and for strides 5 and 8
    (only when flagged valid). It then picks the stride whose returned entropy
    is smallest; NaN entropies are ignored by ``numpy.nanargmin``. That
    stride's count is added to ``global_count``.

    Args:
        classify: Passed straight through to ``count_in_interval``.
        test_set_x: Passed straight through to ``count_in_interval``.
        data: 3-tuple of per-stride inputs (stride 2, 5, 8) for
            ``count_in_interval``.
        valid: 3-tuple of per-stride validity flags. Strides 5 and 8 are
            evaluated only when their flag equals 1; the stride-2 flag is
            unpacked but not read.
        global_count: Running count accumulated so far.
        curr_residue: 3-tuple of per-stride residues (stride 2, 5, 8) carried
            over from the previous call.
        start_frame: Frame position used to derive each stride's window;
            assumed to be an integer (TODO confirm against callers).

    Returns:
        ``(global_count + winner_count, (res_st2, res_st5, res_st8))``, where
        the winning stride's residue is rescaled by ``stride / s`` for each
        ``s`` in (2, 5, 8).

    Raises:
        ValueError: from ``numpy.nanargmin`` if all three entropies are NaN.
    """
    (valid_st2, valid_st5, valid_st8) = valid
    (ns_test_set_x_st2, ns_test_set_x_st5, ns_test_set_x_st8) = data
    (curr_residue_st2, curr_residue_st5, curr_residue_st8) = curr_residue

    # Window bounds are indices into the strided sequences, so use floor
    # division. For ints this is identical to Python 2's `/`, but it stays an
    # int under Python 3, where `/` would produce a float.
    # The window lengths 20, 8 and 5 each cover 40 original frames
    # (20*2 == 8*5 == 5*8).
    # NOTE(review): the -19 start offset is applied in strided units for every
    # stride, i.e. a different number of original frames each; confirm intent.
    st2_start = start_frame // 2 - 19
    st5_start = start_frame // 5 - 19
    st8_start = start_frame // 8 - 19

    # Stride 2 is always evaluated.
    (st2_count, st2_res, st2_entropy) = count_in_interval(
        classify, test_set_x, ns_test_set_x_st2, curr_residue_st2,
        st2_start, st2_start + 20)

    # Strides 5 and 8 are evaluated only when flagged valid. Otherwise they get
    # +inf entropy. nanargmin returns the first minimum, so an invalid stride
    # can never be chosen ahead of stride 2.
    if valid_st5 == 1:
        (st5_count, st5_res, st5_entropy) = count_in_interval(
            classify, test_set_x, ns_test_set_x_st5, curr_residue_st5,
            st5_start, st5_start + 8)
    else:
        st5_entropy = numpy.inf

    if valid_st8 == 1:
        (st8_count, st8_res, st8_entropy) = count_in_interval(
            classify, test_set_x, ns_test_set_x_st8, curr_residue_st8,
            st8_start, st8_start + 5)
    else:
        st8_entropy = numpy.inf

    winner = numpy.nanargmin(numpy.array([st2_entropy, st5_entropy, st8_entropy]))

    # winner indexes a length-3 array, so it is always 0, 1 or 2.
    if winner == 0:
        stride, count, res = 2, st2_count, st2_res
    elif winner == 1:
        stride, count, res = 5, st5_count, st5_res
    else:
        stride, count, res = 8, st8_count, st8_res

    # Re-express the winning residue in each stride's units.
    # NOTE(review): with integer residues, Python 2's `/` floors but Python 3's
    # yields a float. It is left as `/` because the expected residue type is
    # not visible here.
    return (global_count + count,
            (res * stride / 2, res * stride / 5, res * stride / 8))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号