def crossEntropy(x, eps=None):
    """Build a symbolic binary cross-entropy loss for a single-output network.

    Creates a new symbolic target vector ``y`` and returns it together with
    the mean binary cross-entropy between ``y`` and the predictions ``x``.

    Parameters
    ----------
    x : symbolic tensor
        Predictions, either a vector (``ndim == 1``) or a matrix
        (``ndim == 2``). For a matrix, only the first column ``x[:, 0]`` is
        used, because the network is assumed to have exactly one output.
        Any extra columns are silently ignored.
        NOTE(review): values are presumably probabilities in [0, 1]
        (e.g. sigmoid outputs); confirm this with callers.
    eps : float or None, optional
        If given, ``x`` is clipped to ``[eps, 1 - eps]`` before the
        logarithms are taken. This stops saturated predictions of exactly
        0 or 1 from producing an infinite or NaN loss. The default ``None``
        disables clipping and keeps the original behaviour.

    Returns
    -------
    y : symbolic vector
        Newly created target variable named ``'y'``. It must be supplied
        when the loss is evaluated.
    L : symbolic scalar
        ``-mean(y * log(x) + (1 - y) * log(1 - x))``, named ``'loss'``.

    Raises
    ------
    ValueError
        If ``x`` is neither a vector nor a matrix.
    """
    if x.ndim == 2:
        # Single-output network: reduce the matrix to its first column.
        x = x[:, 0]
    elif x.ndim != 1:
        raise ValueError('x must be either a vector or a matrix.')
    if eps is not None:
        # log(0) is -inf, and 0 * -inf is nan, so keep x strictly inside
        # (0, 1) when the caller asks for it.
        x = tt.clip(x, eps, 1 - eps)
    y = tt.vector('y')
    L = -tt.mean(y * tt.log(x) + (1 - y) * tt.log(1 - x))
    L.name = 'loss'
    return y, L
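# Web-page navigation residue captured along with this snippet
# (originally in Chinese): "Comment list" / "Table of contents".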