def get_next_count(classify, test_set_x, data, valid, global_count, curr_residue, start_frame):
    """Count within the next window using the lowest-entropy valid stride.

    The window anchored at ``start_frame`` is evaluated at strides 2, 5 and 8
    via ``count_in_interval`` (stride 2 always, strides 5 and 8 only when
    their validity flag equals 1). Among the evaluated strides, the one whose
    result has the lowest entropy (NaNs ignored) wins. Its count is added to
    ``global_count``, and its residue is rescaled into the units of every
    stride for the next call.

    Args:
        classify, test_set_x: passed through unchanged to ``count_in_interval``.
        data: ``(ns_test_set_x_st2, ns_test_set_x_st5, ns_test_set_x_st8)``,
            the per-stride inputs handed to ``count_in_interval``.
        valid: ``(valid_st2, valid_st5, valid_st8)`` flags. A stride is
            evaluated only when its flag equals 1. Stride 2 is always
            evaluated, so ``valid_st2`` is not consulted.
        global_count: running count to which the winner's count is added.
        curr_residue: ``(res_st2, res_st5, res_st8)``, one residue per stride,
            forwarded to ``count_in_interval``.
        start_frame: frame the window is anchored on. Presumably an int
            (TODO confirm); it is floor-divided by each stride to derive the
            interval bounds.

    Returns:
        ``(global_count + winning_count, (r2, r5, r8))``, where each ``r`` is
        the winner's residue converted to that stride's units
        (``res * winning_stride / stride``).

    Raises:
        ValueError: propagated from ``numpy.nanargmin`` when every evaluated
            entropy is NaN.
    """
    strides = (2, 5, 8)
    _valid_st2, valid_st5, valid_st8 = valid
    ns_test_set_x_st2, ns_test_set_x_st5, ns_test_set_x_st8 = data
    curr_residue_st2, curr_residue_st5, curr_residue_st8 = curr_residue

    # (stride, per-stride input, incoming residue, window width, evaluate?)
    # Widths 20/8/5 times strides 2/5/8 all give 40, so presumably each window
    # covers ~40 frames; TODO confirm against count_in_interval.
    specs = (
        (2, ns_test_set_x_st2, curr_residue_st2, 20, True),
        (5, ns_test_set_x_st5, curr_residue_st5, 8, valid_st5 == 1),
        (8, ns_test_set_x_st8, curr_residue_st8, 5, valid_st8 == 1),
    )

    # (stride, count, residue, entropy) for every stride actually evaluated,
    # in the same order the original code called count_in_interval.
    candidates = []
    for stride, stride_data, residue, width, evaluate in specs:
        if not evaluate:
            continue
        # Floor division keeps the bounds integral on Python 3, matching the
        # Python 2 meaning of ``/`` on ints.
        # NOTE(review): every stride uses the same -19 offset although the
        # widths differ (20/8/5), so the windows do not end at the same
        # frame. Confirm this is intended.
        lo = start_frame // stride - 19
        count, res, entropy = count_in_interval(
            classify, test_set_x, stride_data, residue, lo, lo + width
        )
        candidates.append((stride, count, res, entropy))

    # Only evaluated strides compete. An invalid stride therefore can never be
    # selected, even when the stride-2 entropy is NaN (nanargmin would
    # otherwise skip it and land on a placeholder).
    entropies = numpy.array([entropy for _, _, _, entropy in candidates])
    winner = int(numpy.nanargmin(entropies))
    win_stride, win_count, win_res, _ = candidates[winner]

    # Express the winner's residue in the units of each stride.
    rescaled = tuple(win_res * win_stride / target for target in strides)
    return (global_count + win_count, rescaled)
# Residue from the web page this code was copied from:
# "评论列表" (comment list), "文章目录" (table of contents).