eval_dtnn_gdb9.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:dtnn 作者: atomistic-machine-learning 项目源码 文件源码
def predict(dbpath, features, sess, y):
    U0 = []
    U0_pred = []
    count = 0
    with connect(dbpath) as conn:
        n_structures = conn.count()
        for row in conn.select():
            U0.append(row['U0'])

            at = row.toatoms()
            feed_dict = {
                features['numbers']:
                    np.array(at.numbers).astype(np.int64),
                features['positions']:
                    np.array(at.positions).astype(np.float32)
            }
            U0_p = sess.run(y, feed_dict=feed_dict)
            U0_pred.append(U0_p)
            if count % 1000 == 0:
                print(str(count) + ' / ' + str(n_structures))
            count += 1
    return U0, U0_pred
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号