def get_shapes_and_dtypes(data):
    """Derive a shape and a TensorFlow dtype for every key in ``data``.

    Only the first element of each ``data[key]`` is inspected; it is taken
    as representative of the whole collection stored under that key.

    Args:
        data: Mapping from key to an indexable collection. ``data[key][0]``
            must exist for every key.

    Returns:
        A ``(shapes, dtypes)`` tuple of two dicts keyed like ``data``:

        * ``str`` sample        -> shape ``[]``, dtype ``tf.string``
        * ``np.ndarray`` sample -> shape ``sample.shape``, dtype ``tf.uint8``
        * ``np.bool_`` sample   -> shape ``[]``, dtype ``tf.string``

    Raises:
        TypeError: If the first element of a collection is none of the
            types listed above.

    Any error raised by ``data[key][0]`` itself (e.g. ``IndexError`` for an
    empty list) is not caught here and propagates to the caller.
    """
    shapes = {}
    dtypes = {}
    for key, values in data.items():
        # Look the representative sample up once instead of once per branch.
        sample = values[0]
        if isinstance(sample, str):
            shapes[key] = []
            dtypes[key] = tf.string
        elif isinstance(sample, np.ndarray):
            shapes[key] = sample.shape
            # NOTE(review): dtype is fixed to uint8 regardless of
            # sample.dtype -- presumably arrays are 8-bit images; confirm.
            dtypes[key] = tf.uint8
        elif isinstance(sample, np.bool_):
            shapes[key] = []
            # NOTE(review): booleans are mapped to tf.string, not tf.bool.
            # Kept as in the original; confirm this is intentional.
            dtypes[key] = tf.string
        else:
            # Single formatted message; the former two-argument form
            # rendered as a tuple repr when the exception was printed.
            raise TypeError(f'Unknown data type: {type(sample)!r}')
    return shapes, dtypes
评论列表
文章目录