mixture.py source code

python

Project: nn-patterns    Author: pikinder
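def _update_statistics(self, new_stats, stats):
    # Merge a new batch of per-layer statistics into the running statistics.
    # np (numpy), create_dict, subtypes, subtype_keys and
    # subtype_keys_no_aggregation are defined elsewhere in the project (not shown).
    new_stats = create_dict(new_stats)
    if stats is None:
        # First batch: there is nothing to merge yet.
        return new_stats

    # Update the stats layer by layer.
    for l_i in range(len(stats)):

        for subtype, _ in subtypes:
            # TODO: Have to check the type to see if this is needed
            cnt_old = 1.0 * stats[l_i][subtype]['cnt']
            # Accumulate the sample count of the new batch.
            stats[l_i][subtype]['cnt'] = (stats[l_i][subtype]['cnt']
                                          + new_stats[l_i][subtype]['cnt'])
            # Guard against division by zero when no samples have been seen.
            norm = np.maximum(stats[l_i][subtype]['cnt'], 1.0)

            for key in subtype_keys:
                if key not in subtype_keys_no_aggregation:
                    # Count-weighted average of the old and the new values.
                    tmp_old = cnt_old / norm * stats[l_i][subtype][key]
                    tmp_new = (new_stats[l_i][subtype]['cnt']
                               / norm * new_stats[l_i][subtype][key])
                    stats[l_i][subtype][key] = tmp_old + tmp_new
    return stats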
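The inner loop above merges two batches as a count-weighted mean: for each aggregated key, new_value = (n_old * old_value + n_new * new_value) / max(n_old + n_new, 1). The standalone sketch below isolates that update so it can be checked in isolation. It is not the project's code: the helper name `merge_layer_stats`, the explicit `subtypes` and `keys` arguments (standing in for the module-level `subtypes`, `subtype_keys` and `subtype_keys_no_aggregation`), and the `'all'` / `'mean'` names in the demo are hypothetical, and the `create_dict` conversion step is omitted.

```python
import numpy as np


def merge_layer_stats(stats, new_stats, subtypes, keys):
    """Hypothetical helper: merge per-layer running means weighted by counts.

    stats / new_stats: list with one dict per layer, mapping
    subtype -> {'cnt': sample count, <key>: mean value or array}.
    `stats` is updated in place and returned.
    """
    if stats is None:
        # First batch: nothing to merge yet.
        return new_stats
    for layer_old, layer_new in zip(stats, new_stats):
        for subtype in subtypes:
            old, new = layer_old[subtype], layer_new[subtype]
            cnt_old = 1.0 * old['cnt']
            cnt_new = 1.0 * new['cnt']
            total = cnt_old + cnt_new
            # Avoid dividing by zero when both batches are empty.
            norm = np.maximum(total, 1.0)
            for key in keys:
                # Count-weighted average of the old and the new means.
                old[key] = cnt_old / norm * old[key] + cnt_new / norm * new[key]
            old['cnt'] = total
    return stats


if __name__ == "__main__":
    # Merging the stats of two batches gives the mean over all samples.
    a = np.arange(6.0).reshape(2, 3)
    b = np.arange(6.0, 18.0).reshape(4, 3)
    stats = [{'all': {'cnt': 2, 'mean': a.mean(axis=0)}}]
    batch = [{'all': {'cnt': 4, 'mean': b.mean(axis=0)}}]
    merged = merge_layer_stats(stats, batch, subtypes=['all'], keys=['mean'])
    print(merged[0]['all']['cnt'], merged[0]['all']['mean'])
```

Weighting by counts, rather than averaging the two means directly, keeps the running value equal to the mean over every sample seen so far, regardless of batch sizes.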