def bn_backward(dout, cache):
    """Compute the gradients of a batch-normalization layer.

    Args:
        dout: Upstream gradient with respect to the layer output. It is
            multiplied elementwise with ``gamma`` and ``X_norm`` and
            summed over axis 0.
        cache: 6-tuple ``(X, X_norm, mu, var, gamma, beta)``. ``X`` must
            be 2-D (its shape is unpacked into two values); ``beta`` is
            unpacked but not used here.
            NOTE(review): presumably these are the values saved by the
            matching forward pass, with
            ``X_norm = (X - mu) / sqrt(var + c.eps)`` -- confirm.

    Returns:
        Tuple ``(dX, dgamma, dbeta)``: the gradient with respect to the
        input ``X``, and the axis-0 reductions giving the gradients for
        ``gamma`` and ``beta``.

    The epsilon added to ``var`` is read from ``c.eps``; ``c`` is defined
    outside this function.
    """
    X, X_norm, mu, var, gamma, _beta = cache
    # Two-element unpack: only the row count is needed below, but this
    # still requires X to be exactly 2-D.
    N, _ = X.shape

    centered = X - mu
    inv_std = 1. / np.sqrt(var + c.eps)

    # Gradient flowing into the normalized activations.
    grad_norm = dout * gamma

    # Gradient w.r.t. the variance.
    weighted_sum = np.sum(grad_norm * centered, axis=0)
    grad_var = weighted_sum * -.5 * inv_std**3

    # Gradient w.r.t. the mean: direct path plus the path through var.
    mu_direct = np.sum(grad_norm * -inv_std, axis=0)
    mu_via_var = grad_var * np.mean(-2. * centered, axis=0)
    grad_mu = mu_direct + mu_via_var

    # Input gradient: sum of the three paths (normalized value, variance,
    # mean), each of the latter two spread over the N rows.
    path_norm = grad_norm * inv_std
    path_var = grad_var * 2 * centered / N
    path_mu = grad_mu / N
    dX = path_norm + path_var + path_mu

    # Parameter gradients reduce the upstream gradient over axis 0.
    dgamma = np.sum(dout * X_norm, axis=0)
    dbeta = np.sum(dout, axis=0)

    return dX, dgamma, dbeta
评论列表
文章目录