def main():
    """Command-line entry point: parse options, then run the SVI training loop.

    Options:
        -n / --num-epochs: number of epochs to run (default 1000).
        -b / --batch-size: mini-batch size (default ``N``; when it equals
            ``N`` each epoch is a single step on the whole data set).
        --cuda: move the data, ``softplus`` and ``regression_model`` to CUDA.

    On every 100th epoch (starting with epoch 0) the loss accumulated over
    that epoch, divided by ``N``, is printed.

    Relies on module-level names defined outside this function: ``N``, ``p``,
    ``build_linear_dataset``, ``get_batch_indices``, ``softplus``,
    ``regression_model`` and ``svi``.
    """
    cli = argparse.ArgumentParser(description="parse args")
    cli.add_argument('-n', '--num-epochs', default=1000, type=int)
    cli.add_argument('-b', '--batch-size', default=N, type=int)
    cli.add_argument('--cuda', action='store_true')
    args = cli.parse_args()

    data = build_linear_dataset(N, p)
    if args.cuda:
        # Move the data tensor and both modules onto the GPU.
        data = data.cuda()
        softplus.cuda()
        regression_model.cuda()

    for epoch in range(args.num_epochs):
        if args.batch_size != N:
            # Mini-batch mode: reshuffle the rows, then step once per batch.
            # Note that `data` is rebound, so each epoch permutes the result
            # of the previous epoch's shuffle.
            shuffle_order = torch.randperm(N)
            if args.cuda:
                shuffle_order = shuffle_order.cuda()
            data = data[shuffle_order]

            # Consecutive entries of the returned sequence are used as the
            # [start, stop) bounds of one batch.
            boundaries = get_batch_indices(N, args.batch_size)
            epoch_loss = 0.0
            for start, stop in zip(boundaries[:-1], boundaries[1:]):
                epoch_loss += svi.step(data[start:stop])
        else:
            # Full-batch mode: one step over the entire data set.
            epoch_loss = svi.step(data)

        if epoch % 100 == 0:
            avg_loss = epoch_loss / float(N)
            print("epoch avg loss {}".format(avg_loss))
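# Comment list  (NOTE: appears to be leftover web-page navigation text, not code)
# Article table of contents  (NOTE: appears to be leftover web-page navigation text, not code)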