def __init__(self, asedb, kvp=None, data=None, batch_size=1,
             selection=None, shuffle=True, prefetch=False,
             block_size=150000,
             capacity=5000, num_epochs=np.inf, floatX=np.float32):
    """Store the provider configuration and build the TensorFlow queue.

    One row is read from the database and converted with
    ``self.convert_atoms`` so that the dtype of every feature is known
    before the queue and its placeholders are created.

    Args:
        asedb: Database handle passed to ``connect``.
        kvp: Mapping of extra feature name -> shape; names and shapes are
            appended after the four fixed features. Defaults to an empty
            mapping.
        data: Mapping of extra feature name -> shape, appended after the
            ``kvp`` entries. Defaults to an empty mapping.
        batch_size: Forwarded to the parent class constructor.
        selection: Passed to ``con.select`` to pick the probed row; also
            stored on the instance.
        shuffle: Stored as ``self.shuffle``.
        prefetch: Stored as ``self.prefetch``.
        block_size: Stored as ``self.block_size``.
        capacity: Capacity of the ``tf.FIFOQueue``.
        num_epochs: Stored as ``self.num_epochs``; unbounded by default.
        floatX: Stored as ``self.floatX``.

    Raises:
        IndexError: If ``selection`` matches no row in the database.
    """
    super(ASEDataProvider, self).__init__(batch_size)

    # Use a fresh dict per instance: a ``{}`` default would be one object
    # shared (and mutable) across every provider created without it.
    kvp = {} if kvp is None else kvp
    data = {} if data is None else data

    self.asedb = asedb
    self.prefetch = prefetch
    self.selection = selection
    self.block_size = block_size
    self.shuffle = shuffle
    self.kvp = kvp
    self.data = data
    self.floatX = floatX

    # Four fixed features first, then the user-declared ones.
    self.feat_names = ['numbers', 'positions', 'cell',
                       'pbc'] + list(kvp.keys()) + list(data.keys())
    self.shapes = [(None,), (None, 3), (3, 3),
                   (3,)] + list(kvp.values()) + list(data.values())

    self.epoch = 0
    self.num_epochs = num_epochs
    self.n_rows = 0

    # initialize queue
    # Probe a single row to learn the dtype of each feature.
    with connect(self.asedb) as con:
        rows = list(con.select(self.selection, limit=1))
        if not rows:
            raise IndexError(
                'no rows in database match selection %r' % (self.selection,)
            )
        feats = self.convert_atoms(rows[0])
    dtypes = [np.array(feat).dtype for feat in feats]

    self.queue = tf.FIFOQueue(capacity, dtypes)
    # NOTE(review): zip() pairs dtypes with feat_names positionally, so
    # this assumes convert_atoms returns features in feat_names order --
    # confirm against convert_atoms.
    self.placeholders = [
        tf.placeholder(dt, name=name)
        for dt, name in zip(dtypes, self.feat_names)
    ]
    self.enqueue_op = self.queue.enqueue(self.placeholders)
    self.dequeue_op = self.queue.dequeue()
    self.preprocs = []
评论列表
文章目录