nn_test.py source code

python

Project: imperative  Author: yaroslavvb
def _upsample_filters(self, filters, rate):
    """Upsamples the filters by a factor of rate along the spatial dimensions.

    Args:
      filters: [h, w, in_depth, out_depth]. Original filters.
      rate: An int, specifying the upsampling rate.

    Returns:
      filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with
        h_up = h + (h - 1) * (rate - 1)
        w_up = w + (w - 1) * (rate - 1)
        containing (rate - 1) zeros between consecutive filter values along
        the filters' spatial dimensions.
    """
    if rate == 1:
      return filters
    # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
    filters_up = np.transpose(filters, [2, 3, 0, 1])
    ker = np.zeros([rate, rate])
    ker[0, 0] = 1
    filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)]
    # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
    filters_up = np.transpose(filters_up, [2, 3, 0, 1])
    self.assertEqual(np.sum(filters), np.sum(filters_up))
    return filters_up
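The method above inserts zeros by taking a Kronecker product with a `rate x rate` kernel that has a single 1 in its top-left corner. It then trims the trailing `rate - 1` rows and columns. A standalone sketch (not from the original project) gets the same result with strided slice assignment: every `rate`-th spatial position receives an original tap and all other positions stay zero. The function names `upsample_filters` and `upsample_filters_kron` and the example `filters` array are hypothetical, chosen only for illustration.

```python
import numpy as np


def upsample_filters(filters, rate):
    """Hypothetical slicing-based equivalent of _upsample_filters.

    filters: array of shape [h, w, in_depth, out_depth].
    Returns shape [h_up, w_up, in_depth, out_depth] with
    h_up = h + (h - 1) * (rate - 1) and w_up = w + (w - 1) * (rate - 1).
    """
    if rate == 1:
        return filters
    h, w, in_depth, out_depth = filters.shape
    h_up = h + (h - 1) * (rate - 1)
    w_up = w + (w - 1) * (rate - 1)
    filters_up = np.zeros((h_up, w_up, in_depth, out_depth), dtype=filters.dtype)
    # Place original taps at rows/cols 0, rate, 2*rate, ...; the gaps stay zero.
    filters_up[::rate, ::rate] = filters
    return filters_up


def upsample_filters_kron(filters, rate):
    """The Kronecker-product approach from the test, without the unittest assertion."""
    if rate == 1:
        return filters
    # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
    f = np.transpose(filters, [2, 3, 0, 1])
    ker = np.zeros([rate, rate])
    ker[0, 0] = 1
    # np.kron broadcasts the 2-D kernel over the last two axes, then the
    # trailing (rate - 1) rows and columns of padding are sliced off.
    f = np.kron(f, ker)[:, :, :-(rate - 1), :-(rate - 1)]
    # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
    return np.transpose(f, [2, 3, 0, 1])


# Usage: a 2x3 filter with 1 input channel and 2 output channels.
filters = np.arange(2 * 3 * 1 * 2, dtype=np.float64).reshape(2, 3, 1, 2) + 1
up = upsample_filters(filters, 2)
print(up.shape)  # (3, 5, 1, 2)
print(np.array_equal(up, upsample_filters_kron(filters, 2)))  # True
```

The slicing version allocates only the final array and avoids the two transposes. The Kronecker version builds an intermediate of shape `[in_depth, out_depth, h * rate, w * rate]` before trimming it.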