task.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:cloud-ml-sdk 作者: XiaoMi 项目源码 文件源码
def main():
  """Fit a single-input linear model to synthetic noisy data and print it.

  The data is drawn as ``y = 7 * x + 10`` plus Gaussian noise scaled by 10,
  with ``x`` uniform in ``[0, 30)``.  A ``L.Linear(1, 1)`` model is trained
  with ``MomentumSGD`` on the mean squared error, and the learned slope and
  intercept are printed.

  NOTE(review): ``F``, ``L``, ``Variable``, ``optimizers`` and ``np`` are
  expected to be imported at module level (presumably from Chainer and
  NumPy) -- the imports are not visible in this part of the file.
  """
  n_samples = 1000

  # Define train function
  def linear_train(train_data, train_target, n_epochs=200):
    """Run ``n_epochs`` full-batch update steps on the enclosing model.

    Uses ``linear_function`` and ``optimizer`` from the enclosing scope;
    both are bound below, before this function is first called.
    """
    for _ in range(n_epochs):
      output = linear_function(train_data)
      loss = F.mean_squared_error(train_target, output)
      # NOTE(review): newer Chainer releases deprecate zerograds() in favour
      # of cleargrads(); kept as-is for compatibility -- confirm the target
      # version before switching.
      linear_function.zerograds()
      loss.backward()
      optimizer.update()

  # Construct train data
  x = 30 * np.random.rand(n_samples).astype(np.float32)
  y = 7 * x + 10
  y += 10 * np.random.randn(n_samples).astype(np.float32)

  linear_function = L.Linear(1, 1)

  # Reshape the flat sample vectors into (n_samples, 1) column matrices.
  x_var = Variable(x.reshape(n_samples, -1))
  y_var = Variable(y.reshape(n_samples, -1))

  optimizer = optimizers.MomentumSGD(lr=0.001)
  optimizer.setup(linear_function)

  # 150 rounds of 20 epochs each.  The original also computed an unused
  # prediction after every round; that dead forward pass has been removed.
  for _ in range(150):
    linear_train(x_var, y_var, n_epochs=20)

  slope = linear_function.W.data[0, 0]
  intercept = linear_function.b.data[0]

  print("Final Line: {0:.3}x + {1:.3}".format(slope, intercept))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号