bpnn.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:lotto 作者: hhh5460 项目源码 文件源码
def calculate_loss(self, X, y, model):
        num_examples = len(X)
        lamda = 0.01 # regularization strength

        Wi, bh, Wh, bo = model['Wi'], model['bh'], model['Wh'], model['bo']
        # Forward propagation to calculate our predictions
        neth = np.dot(X, Wi) + bh
        lh = np.tanh(neth)
        neto = np.dot(lh, Wh) + bo
        lo = np.exp(neto)
        probs = lo / np.sum(lo, axis=1, keepdims=True)
        # Calculating the loss
        corect_logprobs = -np.log(probs[range(num_examples), y])
        data_loss = np.sum(corect_logprobs)
        # Add regulatization term to loss (optional)
        data_loss += lamda/2 * (np.sum(np.square(Wi)) + np.sum(np.square(Wh)))
        return 1./num_examples * data_loss

    # ??
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号