def get_transitions(states):
    """Compute transition indices for a cyclic 1 -> 2 -> 3 -> 4 state array.

    Args:
        states : numpy array (or array-like)
            States array of the form
            ...,4,1,1,...,1,2,2,...,2,3,3,....,3,4,...,4,1,...
            Singleton dimensions are squeezed away before processing.

    Returns:
        transitions : numpy array of int
            Indices into the (squeezed) ``states`` array at which the state
            value changes, restricted to the window running from the first
            ``1`` to one past the last ``4``. If ``states`` starts with ``1``,
            index 0 is included (start of the first 1-run); if it ends with
            ``4``, ``len(states)`` is included (end of the last 4-run).
            An empty array is returned when ``states`` is empty or when the
            window cannot be formed because there is no ``1`` or no ``4``.
    """
    # atleast_1d keeps a single-element input indexable after squeezing
    # (np.squeeze would otherwise produce a 0-d array).
    states = np.atleast_1d(np.squeeze(states))
    if states.size == 0:
        return np.array([], dtype=np.intp)

    # Edge cases when the array starts in 1 and/or ends in 4: pad with the
    # neighbouring cycle state so the boundary shows up as a transition.
    # Prepending shifts every index by one, so remember that offset and undo
    # it before returning; otherwise every returned index is off by one.
    offset = 0
    if states[0] == 1:
        states = np.concatenate(([4], states))
        offset = 1
    if states[-1] == 4:
        states = np.concatenate((states, [1]))

    ones = np.flatnonzero(states == 1)
    fours = np.flatnonzero(states == 4)
    if ones.size == 0 or fours.size == 0:
        # Without a 1 or a 4 the [first 1, last 4] window is undefined.
        return np.array([], dtype=np.intp)

    # diff[i] compares states[i] and states[i + 1]; +1 turns it into the
    # index of the first element of the new run.
    transitions = np.flatnonzero(np.diff(states) != 0) + 1
    first = ones[0]
    last = fours[-1] + 1
    in_window = (transitions >= first) & (transitions <= last)
    # Map indices from the padded array back to the caller's array.
    return transitions[in_window] - offset
评论列表
文章目录