def cross_entropy(self, raw_network_output, target_signal_data):
    """Compute softmax cross-entropy between a network output and class targets.

    The output is flattened so that every (sample, time step) pair becomes one
    row of class scores, and the targets are flattened in the same row-major
    order so the rows line up.

    Args:
        raw_network_output: Output laid out as (batchsize, channels, 1,
            time_step); wrapped with ``self.to_variable`` before use.
        target_signal_data: Raw (non-``Variable``) array of class indices laid
            out as (batchsize, time_step).

    Returns:
        The loss returned by ``F.softmax_cross_entropy``.

    Raises:
        TypeError: If ``target_signal_data`` is already a ``Variable``.
        ValueError: If the two inputs have incompatible ranks or shapes.
    """
    if isinstance(target_signal_data, Variable):
        raise TypeError("target_signal_data cannot be Variable")

    raw_network_output = self.to_variable(raw_network_output)
    output_shape = raw_network_output.data.shape
    target_shape = target_signal_data.shape

    # Validate up front so mismatches produce a clear message instead of an
    # obscure failure inside transpose/reshape/softmax_cross_entropy.
    if len(output_shape) != 4:
        raise ValueError(
            "raw_network_output must be 4-D (batchsize, channels, 1, time_step), "
            "got shape {}".format(output_shape))
    if len(target_shape) != 2:
        raise ValueError(
            "target_signal_data must be 2-D (batchsize, time_step), "
            "got shape {}".format(target_shape))

    batchsize = output_shape[0]
    target_width = target_shape[1]
    if output_shape[3] != target_width:
        raise ValueError("raw_network_output.width != target.width")
    if target_shape[0] != batchsize:
        raise ValueError(
            "raw_network_output batchsize ({}) != target batchsize ({})".format(
                batchsize, target_shape[0]))

    # (batchsize * time_step,) <- (batchsize, time_step)
    target_signal = self.to_variable(target_signal_data.reshape((-1,)))

    # (batchsize * time_step, channels) <- (batchsize, channels, 1, time_step)
    # Moving time next to batch makes row b * time_step + t of the reshaped
    # output correspond to element b * time_step + t of the flattened target.
    raw_network_output = F.transpose(raw_network_output, (0, 3, 2, 1))
    raw_network_output = F.reshape(raw_network_output, (batchsize * target_width, -1))

    return F.softmax_cross_entropy(raw_network_output, target_signal)
评论列表
文章目录