def continuous_loss(self, y, y_hat):
    """Return the class-weighted continuous loss of relaxed marginals vs ``y``.

    For every node ``i`` the loss adds ``prop_cw_[y_i] * (1 - prop_marg[i, y_i])``,
    and for every link ``j`` it adds ``link_cw_[y_j] * (1 - link_marg[j, y_j])``,
    where ``y_i`` / ``y_j`` are the gold labels mapped to column indices by the
    fitted ``prop_encoder_`` / ``link_encoder_``.

    Parameters
    ----------
    y : object with ``nodes`` and ``links`` attributes
        Gold labeling; ``y.nodes`` / ``y.links`` are passed through
        ``self.prop_encoder_.transform`` / ``self.link_encoder_.transform``
        to obtain integer column indices.
    y_hat : tuple
        Either the pair ``(prop_marg, link_marg)`` or a tuple whose first
        element is that pair. Both marginals are indexed as
        ``marg[row, label]``, one row per node / link.

    Returns
    -------
    scalar
        Sum of the node loss and the link loss.

    Raises
    ------
    ValueError
        If ``y_hat`` is a discrete ``DocLabel`` rather than marginals.
    """
    if isinstance(y_hat, DocLabel):
        raise ValueError("continuous loss on discrete input")

    # y_hat may arrive wrapped as ((prop_marg, link_marg), ...); keep the pair.
    if isinstance(y_hat[0], tuple):
        y_hat = y_hat[0]
    prop_marg, link_marg = y_hat

    y_nodes = self.prop_encoder_.transform(y.nodes)
    y_links = self.link_encoder_.transform(y.links)

    # One row index per node / link, paired with its gold column below, so
    # marg[rows, y] picks the marginal assigned to each true label.
    # (np.arange replaces np.indices, which builds an (ndim, n) grid and only
    # worked here through broadcasting a (1, n) array.)
    node_rows = np.arange(len(y_nodes))
    link_rows = np.arange(len(y_links))

    # Relies on prop_marg and link_marg summing to 1 row-wise, so
    # 1 - marg[i, y_i] is the mass placed on the wrong labels.
    prop_loss = np.sum(self.prop_cw_[y_nodes] *
                       (1 - prop_marg[node_rows, y_nodes]))
    link_loss = np.sum(self.link_cw_[y_links] *
                       (1 - link_marg[link_rows, y_links]))

    return prop_loss + link_loss
评论列表
文章目录