eew_rnn_cuda.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:EarlyWarning 作者: wjlei1990 项目源码 文件源码
def predict_on_test(rnn, test_x):
    print("Predict...")
    pred_y = []
    for idx in range(len(test_x)):
        x = test_x[idx, :, :]
        x = Variable(torch.unsqueeze(torch.Tensor(x).t(), dim=1)).cuda()
        y_p = rnn(x)
        _y = float(y_p.cpu().data.numpy()[0])
        # print("pred %d: %f | true y: %f" % (idx, _y, test_y[idx]))
        pred_y.append(_y)

    return pred_y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号