def squareError(x):
    """Build a symbolic squared-error loss against a freshly created target.

    A new target variable named ``'y'`` with the same rank as ``x`` is
    created. The loss is then derived from the element-wise squared
    differences ``(x - y) ** 2``:

    * rank 1: the mean of the squared differences;
    * rank 2: the squared differences are summed along axis 1, and the mean
      of those per-row sums is taken.

    Assumes ``tt`` is the symbolic tensor module imported elsewhere in the
    file (presumably ``theano.tensor``; verify against the file's imports).

    Args:
        x: Symbolic tensor variable of rank 1 or 2.

    Returns:
        A ``(y, L)`` tuple. ``y`` is the new target variable and ``L`` is the
        loss expression, whose ``name`` is set to ``'loss'``.

    Raises:
        ValueError: If ``x.ndim`` is neither 1 nor 2.
    """
    rank = x.ndim
    if rank not in (1, 2):
        raise ValueError('x must be either a vector or a matrix.')

    if rank == 1:
        target = tt.vector('y')
        terms = (x - target) ** 2
    else:
        target = tt.matrix('y')
        # Collapse each row to its total squared error before averaging rows.
        terms = tt.sum((x - target) ** 2, axis=1)

    loss = tt.mean(terms)
    loss.name = 'loss'
    return target, loss
评论列表
文章目录