def get_remain_count(classify, test_set_x, data, valid, global_count, curr_residue, start_frame):
    """Add the count of the lowest-entropy stride to ``global_count``.

    ``count_in_interval`` is run for stride 2 (always) and for strides 5
    and 8 (only when their flag in ``valid`` equals 1), in that order.  The
    stride whose reported entropy is smallest wins, and its count is added
    to ``global_count``.

    Parameters
    ----------
    classify, test_set_x :
        Passed through unchanged to ``count_in_interval``.
    data : tuple
        ``(ns_test_set_x_st2, ns_test_set_x_st5, ns_test_set_x_st8)``.  Each
        element must expose ``.shape[0]``, which is passed as the interval
        end.
    valid : tuple
        ``(valid_st2, valid_st5, valid_st8)``.  ``valid_st2`` is ignored
        because stride 2 is always evaluated; strides 5 and 8 are evaluated
        only when their flag ``== 1``.
    global_count :
        Running total that the winning stride's count is added to.
    curr_residue : tuple
        ``(curr_residue_st2, curr_residue_st5, curr_residue_st8)``, passed to
        ``count_in_interval`` for the matching stride.
    start_frame : int
        Mapped to a per-stride start as ``start_frame // stride - 19``.

    Returns
    -------
    ``global_count`` plus the count of the winning stride.

    Notes
    -----
    NaN entropies are ignored when choosing the winner, and ties go to the
    smaller stride (``numpy.nanargmin`` returns the first minimum).  Strides
    that were not evaluated can never win.  If every evaluated entropy is
    NaN, stride 2 wins.
    """
    # Explicit unpacking keeps the original's strict "exactly three" check.
    valid_st2, valid_st5, valid_st8 = valid
    ns_test_set_x_st2, ns_test_set_x_st5, ns_test_set_x_st8 = data
    curr_residue_st2, curr_residue_st5, curr_residue_st8 = curr_residue

    # (stride, evaluate?, per-stride data, per-stride residue)
    per_stride = (
        (2, True, ns_test_set_x_st2, curr_residue_st2),
        (5, valid_st5 == 1, ns_test_set_x_st5, curr_residue_st5),
        (8, valid_st8 == 1, ns_test_set_x_st8, curr_residue_st8),
    )

    counts = []
    entropies = []
    for stride, is_valid, ns_test_set_x, residue in per_stride:
        if not is_valid:
            continue
        # Floor division keeps the start an integer, matching Python 2's `/`
        # on ints; plain `/` produces a float on Python 3.  Presumably
        # count_in_interval uses it as an index -- TODO confirm.
        # NOTE(review): negative when start_frame < 19 * stride; confirm
        # count_in_interval handles that.
        start = start_frame // stride - 19
        count, _res, entropy = count_in_interval(
            classify, test_set_x, ns_test_set_x, residue,
            start, ns_test_set_x.shape[0])
        counts.append(count)
        entropies.append(entropy)

    # Only evaluated strides take part in the vote.  The old approach marked
    # invalid strides with inf, but nanargmin skips NaN.  So a NaN stride-2
    # entropy could make an unevaluated stride "win" and crash on its
    # undefined count.
    entropies = numpy.asarray(entropies, dtype=float)
    if numpy.all(numpy.isnan(entropies)):
        # nanargmin raises on an all-NaN input; fall back to stride 2, which
        # is always evaluated and always sits at index 0.
        winner = 0
    else:
        winner = int(numpy.nanargmin(entropies))
    return global_count + counts[winner]
评论列表
文章目录