eew_rnn.py source code


Project: EarlyWarning  Author: wjlei1990
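```python
import numpy as np
import torch
import torch.nn as nn

# load_data, make_dataloader, RNN, input_size, hidden_size, num_layers and LR
# are assumed to be defined elsewhere in eew_rnn.py (not shown here).


def main():
    waveforms, magnitudes = load_data()
    loader = make_dataloader(waveforms, magnitudes)

    rnn = RNN(input_size, hidden_size, num_layers)
    print(rnn)

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    loss_func = nn.MSELoss()

    for epoch in range(3):
        loss_epoch = []
        for step, (batch_x, batch_y) in enumerate(loader):
            # Take the first trace of the batch (assumed (channels, time)),
            # transpose to (time, channels), add a batch axis:
            # x has shape (seq_len, 1, input_size).
            x = torch.unsqueeze(batch_x[0, :, :].t(), dim=1)
            print('Epoch: ', epoch, '| Step: ', step, '| x: ',
                  x.size(), '| y: ', batch_y.numpy())
            # Targets as a float tensor of shape (1, batch_size). Wrapping in
            # torch.autograd.Variable is unnecessary since PyTorch 0.4.
            y = batch_y.float().unsqueeze(0)
            prediction = rnn(x)
            loss = loss_func(prediction, y)
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()        # backpropagation, compute gradients
            optimizer.step()       # apply the Adam update
            # loss.data[0] raises IndexError on the 0-dim loss tensor in
            # current PyTorch; .item() returns the Python float instead.
            loss_epoch.append(loss.item())
            print("Current loss: %e --- loss mean: %f"
                  % (loss.item(), np.mean(loss_epoch)))
```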
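The RNN class and the data helpers are not shown on this page, so the loop cannot be run on its own. Below is a minimal, self-contained sketch of the same training step under stated assumptions: `MagnitudeRNN` is a hypothetical stand-in (an LSTM followed by a linear layer on the last time step), the waveforms are synthetic tensors stored as (channels, time), and the loader uses batch_size=1. The batch size matters because main() keeps only `batch_x[0]` for the input but builds the target from the whole `batch_y`, so the original loop implicitly assumes one trace per batch.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class MagnitudeRNN(nn.Module):
    """Hypothetical stand-in for the RNN class in eew_rnn.py: an LSTM over
    the waveform, then a linear layer on the last hidden state."""

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        # batch_first=False (the default): inputs are (seq_len, batch, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x):
        r_out, _ = self.lstm(x)        # (seq_len, batch, hidden_size)
        return self.head(r_out[-1])    # (batch, 1): one magnitude per trace


def train(loader, model, epochs=3, lr=1e-3):
    """Same update as main(), returning the mean MSE of each epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.MSELoss()
    epoch_means = []
    for _ in range(epochs):
        losses = []
        for batch_x, batch_y in loader:
            # (1, channels, time) -> (time, 1, channels)
            x = torch.unsqueeze(batch_x[0, :, :].t(), dim=1)
            y = batch_y.float().unsqueeze(0)   # (1, batch) == (1, 1) here
            loss = loss_func(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        epoch_means.append(float(np.mean(losses)))
    return epoch_means


if __name__ == "__main__":
    torch.manual_seed(0)
    # Synthetic data (hypothetical): 8 three-component traces, 100 samples each.
    waveforms = torch.randn(8, 3, 100)
    magnitudes = torch.rand(8) * 3 + 3        # fake magnitudes in [3, 6)
    loader = DataLoader(TensorDataset(waveforms, magnitudes),
                        batch_size=1, shuffle=True)
    model = MagnitudeRNN(input_size=3, hidden_size=16, num_layers=2)
    print(train(loader, model))
```

Reading the last LSTM output is one common choice for sequence-to-one regression; whether the original RNN class does the same is not visible from this file excerpt.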