fast_weights.py source code

python

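Project: meta-learning  Author: ioanachelu

import tensorflow as tf  # TensorFlow 1.x graph-mode APIs
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest


def _fwlinear(self, args, output_size, scope=None):
    """Return args[0] @ MatrixW + args[1] @ MatrixC.

    args must be a pair [h, x]: h is [batch, output_size], x is [batch, input_size].
    MatrixW is [output_size, output_size]; MatrixC is [input_size, output_size].
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
        args = [args]
    # Exactly two inputs are expected, and the first must already be output_size wide.
    assert len(args) == 2
    assert args[0].get_shape().as_list()[1] == output_size

    dtype = args[0].dtype

    with vs.variable_scope(scope or "Linear"):
        # Disabled alternative: initialize MatrixW to 0.05 * identity.
        # matrixW = vs.get_variable(
        #     "MatrixW", dtype=dtype, initializer=tf.convert_to_tensor(np.eye(output_size, dtype=np.float32) * .05))
        matrixW = vs.get_variable("MatrixW", [output_size, output_size], dtype=dtype)
        matrixC = vs.get_variable(
            "MatrixC", [args[1].get_shape().as_list()[1], output_size], dtype=dtype)

        # Sum of the two projections, shape [batch, output_size].
        res = tf.matmul(args[0], matrixW) + tf.matmul(args[1], matrixC)
        return res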
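The arithmetic behind _fwlinear can be checked without TensorFlow. The NumPy sketch below reproduces only the math (h @ W + x @ C) and the shape contract. The names fw_linear_np and scaled_identity are hypothetical helpers, not part of the original project. scaled_identity mirrors the commented-out 0.05 * identity initializer. Reading h as the previous hidden state and x as the current input is an assumption based on the file name, not something the snippet states.

```python
import numpy as np


def fw_linear_np(h, x, W, C):
    """Hypothetical NumPy sketch of the math in _fwlinear: h @ W + x @ C.

    h: (batch, output_size)  -- role of args[0]
    x: (batch, input_size)   -- role of args[1]
    W: (output_size, output_size), C: (input_size, output_size)
    """
    output_size = h.shape[1]
    # Same contract as the TF code: W is square in output_size, C maps input_size -> output_size.
    if W.shape != (output_size, output_size):
        raise ValueError(f"W must be ({output_size}, {output_size}), got {W.shape}")
    if C.shape != (x.shape[1], output_size):
        raise ValueError(f"C must be ({x.shape[1]}, {output_size}), got {C.shape}")
    return h @ W + x @ C


def scaled_identity(n, scale=0.05):
    # Mirrors the disabled initializer: np.eye(output_size) * .05
    return np.eye(n, dtype=np.float32) * scale


rng = np.random.default_rng(0)
h = rng.standard_normal((2, 4)).astype(np.float32)   # batch=2, output_size=4
x = rng.standard_normal((2, 3)).astype(np.float32)   # batch=2, input_size=3
W = scaled_identity(4)
C = np.zeros((3, 4), dtype=np.float32)
out = fw_linear_np(h, x, W, C)
print(out.shape)  # (2, 4)
```

With W = 0.05 * I and C = 0, the output is just 0.05 * h. Under the disabled initializer, the recurrent term therefore starts as a small, scaled copy of the first input, not a random projection.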