def _labeling_batch_to_class_batch(y, y_labeling, num_classes,
                                   y_hat_mask=None):
    """Scatter-add per-label-position scores into per-class scores.

    For every batch element ``b`` and label position ``l``, the series
    ``y_labeling[:, b, l]`` is added to output column ``y[l, b]``. Label
    positions that refer to the same class are therefore summed
    ("sum over all repeated labels").

    Parameters
    ----------
    y : symbolic integer matrix, shape (L, B)
        Class index of each label position, read as ``y[l, b]``. Every
        entry must be a valid index into ``[0, num_classes)``.
    y_labeling : symbolic 3-tensor, shape (T, B, L)
        Scores per leading-axis step (presumably time -- confirm with
        callers), batch element and label position.
    num_classes : int or symbolic integer scalar
        Number of classes C (size of the last output axis).
    y_hat_mask : ignored
        Accepted for interface compatibility only; it is not used.

    Returns
    -------
    symbolic 3-tensor, shape (T, B, C)
        Per-class scores, with the same dtype as ``y_labeling``.
    """
    # FIXME: y_hat_mask is currently not used.
    n_steps = y_labeling.shape[0]
    batch_size = y.shape[1]
    n_labels = y.shape[0]
    # Batch indices are identical in every scan step, so build them once.
    batch_idx = T.arange(batch_size)

    # Accumulator laid out as (C, B, T): one (class, batch) pair selects a
    # whole series along the T axis. Use y_labeling's dtype so the sums are
    # not silently cast to floatX (the inc_subtensor output keeps the
    # accumulator's dtype).
    init = T.zeros((num_classes, batch_size, n_steps),
                   dtype=y_labeling.dtype)
    # (T, B, L) -> (L, B, T): indexing by a label position then gives (B, T).
    labeling_lbt = y_labeling.dimshuffle((2, 1, 0))

    def scan_step(label_pos, acc, labeling_in, y_in, idx_in):
        # For every b: acc[y[label_pos, b], b, :] += labeling[label_pos, b, :].
        # Within one step each (class, b) pair is hit at most once because
        # the batch index b is unique; repeats across steps accumulate.
        return T.inc_subtensor(acc[y_in[label_pos, idx_in], idx_in],
                               labeling_in[label_pos, idx_in])

    result, _ = theano.scan(scan_step,
                            sequences=[T.arange(n_labels)],
                            non_sequences=[labeling_lbt, y, batch_idx],
                            outputs_info=[init])
    # The last accumulator state is (C, B, T); return it as (T, B, C).
    return result[-1].dimshuffle(2, 1, 0)
评论列表
文章目录