def initial_count(classify, test_set_x, data, valid):
    """Count using whichever stride (2, 5 or 8) yields the lowest entropy.

    Stride 2 is always evaluated. Strides 5 and 8 are evaluated only when
    their flag in ``valid`` equals 1. A disabled stride gets an entropy of
    ``numpy.inf``, so it can never beat an evaluated stride.

    Parameters
    ----------
    classify, test_set_x
        Passed through unchanged to ``count_in_interval``.
    data : tuple
        ``(ns_test_set_x_st2, ns_test_set_x_st5, ns_test_set_x_st8)``, the
        per-stride inputs handed to ``count_in_interval``.
    valid : tuple
        ``(valid_st2, valid_st5, valid_st8)`` flags. ``valid_st2`` is ignored
        because stride 2 is always evaluated.

    Returns
    -------
    tuple
        ``(count, (r2, r5, r8))``. ``count`` is the count reported for the
        winning stride. Each ``rT`` is the winning stride's result rescaled
        as ``res * winning_stride / T``.

    Raises
    ------
    ValueError
        Propagated from ``numpy.nanargmin`` if all three strides are
        evaluated and every entropy is NaN.
    """
    strides = (2, 5, 8)
    # Last argument passed to count_in_interval for each stride. The original
    # note on the stride-2 value read "100 - 19 etc."; the meaning of these
    # bounds is not visible here -- TODO confirm.
    upper_bounds = (81, 21, 6)

    # Explicit unpacking keeps the original's length check on both tuples.
    _valid_st2, valid_st5, valid_st8 = valid
    ns_st2, ns_st5, ns_st8 = data
    inputs = (ns_st2, ns_st5, ns_st8)
    # Stride 2 is always evaluated, regardless of valid_st2.
    enabled = (True, valid_st5 == 1, valid_st8 == 1)

    results = []    # (count, res) per stride, or None if the stride is disabled
    entropies = []
    for stride_input, bound, is_enabled in zip(inputs, upper_bounds, enabled):
        if is_enabled:
            count, res, entropy = count_in_interval(
                classify, test_set_x, stride_input, 0, 0, bound)
            results.append((count, res))
            entropies.append(entropy)
        else:
            results.append(None)
            entropies.append(numpy.inf)

    winner = int(numpy.nanargmin(numpy.array(entropies)))
    if results[winner] is None:
        # nanargmin can land on a disabled (inf) slot when no evaluated
        # stride reported a finite, non-NaN entropy. Fall back to stride 2,
        # which is always evaluated, instead of failing on missing results.
        winner = 0

    count, res = results[winner]
    stride = strides[winner]
    # Express the winning result in terms of every stride; same operand
    # order as the original (res * stride) / target.
    return (count, tuple(res * stride / target for target in strides))
# (Scraped web-page residue removed from code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)