data.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:dtnn 作者: atomistic-machine-learning 项目源码 文件源码
def __init__(self, asedb, kvp=None, data=None, batch_size=1,
                 selection=None, shuffle=True, prefetch=False,
                 block_size=150000,
                 capacity=5000, num_epochs=np.inf, floatX=np.float32):
        """Create a TensorFlow FIFO queue that is fed from an ASE database.

        The queue's element dtypes are inferred by converting the first
        database row matching ``selection`` with ``self.convert_atoms``.

        Args:
            asedb: Database argument passed to ``connect`` (presumably a
                path/URL of an ASE database -- confirm against callers).
            kvp: Mapping of extra feature name -> shape. Names are appended
                to ``self.feat_names`` and shapes to ``self.shapes`` after
                the fixed ``numbers``/``positions``/``cell``/``pbc``
                entries. ``None`` (the default) means no extra features.
            data: Mapping of extra feature name -> shape, appended after the
                ``kvp`` entries. ``None`` (the default) means none.
            batch_size: Forwarded to the base-class constructor.
            selection: Query forwarded to ``con.select`` when looking up the
                first row; also stored on the instance.
            shuffle, prefetch, block_size: Stored on the instance; not used
                by this constructor.
            capacity: Capacity of the underlying ``tf.FIFOQueue``.
            num_epochs: Stored on the instance; defaults to infinity.
            floatX: Stored on the instance (NOTE(review): presumably the
                float dtype ``convert_atoms`` casts to -- confirm).

        Raises:
            IndexError: If no row of ``asedb`` matches ``selection``.
        """
        super(ASEDataProvider, self).__init__(batch_size)

        # Fresh dicts per instance: ``{}`` defaults would be shared by every
        # provider constructed without explicit kvp/data.
        if kvp is None:
            kvp = {}
        if data is None:
            data = {}

        self.asedb = asedb
        self.prefetch = prefetch
        self.selection = selection
        self.block_size = block_size
        self.shuffle = shuffle
        self.kvp = kvp
        self.data = data
        self.floatX = floatX
        self.feat_names = ['numbers', 'positions', 'cell',
                           'pbc'] + list(kvp.keys()) + list(data.keys())
        # ``None`` marks a dimension whose size varies between rows.
        self.shapes = [(None,), (None, 3), (3, 3),
                       (3,)] + list(kvp.values()) + list(data.values())

        self.epoch = 0
        self.num_epochs = num_epochs
        self.n_rows = 0

        # Fetch one sample row; it is only used to infer the queue dtypes.
        with connect(self.asedb) as con:
            row = next(iter(con.select(self.selection, limit=1)), None)
        if row is None:
            raise IndexError(
                'No rows in {!r} match selection {!r}'.format(
                    self.asedb, self.selection))

        feats = self.convert_atoms(row)
        dtypes = [np.array(feat).dtype for feat in feats]
        # No static shapes are given to the queue, so rows of different
        # sizes can be enqueued.
        self.queue = tf.FIFOQueue(capacity, dtypes)

        # One feed placeholder per feature. Pairing dtypes with feat_names
        # assumes convert_atoms returns features in feat_names order
        # (NOTE(review): confirm).
        self.placeholders = [
            tf.placeholder(dt, name=name)
            for dt, name in zip(dtypes, self.feat_names)
        ]
        self.enqueue_op = self.queue.enqueue(self.placeholders)
        self.dequeue_op = self.queue.dequeue()

        self.preprocs = []
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号