def get_net_cost(model, cost_type, eye=True):
    """Get the train cost of the network.

    Args:
        model: Object exposing ``output_dropout`` and ``trg`` (the two
            tensors compared by the cost) and the ``l1`` / ``l2`` terms
            that are added to it.
        cost_type: A ``CostType`` member; ``CostType.MeanSquared`` and
            ``CostType.CrossEntropy`` are supported.
        eye: When true, the mean squared error is divided by ``d_eyes``,
            a quantity computed from columns 37 and 46 of ``model.trg``.
            It has no effect on the cross-entropy cost.

    Returns:
        The element-wise cost averaged over axis 1, plus ``model.l1`` and
        ``model.l2`` whenever they compare unequal to zero.

    Raises:
        ValueError: If ``cost_type`` is not one of the supported values.
    """
    if cost_type == CostType.MeanSquared:
        cost = T.mean(
            T.sqr(model.output_dropout - model.trg), axis=1)
        if eye:
            # NOTE(review): both summed terms are the same squared
            # difference, so this equals 2 * (trg[:, 37] - trg[:, 46])**2.
            # Presumably the second term was meant to use another pair of
            # columns -- confirm against the target layout before changing.
            d_eyes = (
                (model.trg[:, 37] - model.trg[:, 46])**2 +
                (model.trg[:, 37] - model.trg[:, 46])**2).T
            cost = cost / d_eyes
    elif cost_type == CostType.CrossEntropy:
        # The cross-entropy cost is not normalised by d_eyes, whatever the
        # value of `eye`.
        cost = T.mean(
            T.nnet.binary_crossentropy(
                model.output_dropout, model.trg), axis=1)
    else:
        raise ValueError("Unknown cost type: %r" % (cost_type,))

    # Add the model's l1 / l2 terms unless they compare equal to zero.
    if model.l1 != 0.:
        cost += model.l1
    if model.l2 != 0.:
        cost += model.l2
    return cost
评论列表
文章目录