def inputs_train():
    """Build the training input batches, optionally casting them to float16.

    Reads ``FLAGS.data_dir`` and ``FLAGS.batch_size`` and delegates the actual
    data loading to ``in_data.inputs_train``.

    Returns:
        img3_batch: Image batch produced by ``in_data.inputs_train``;
            presumably a 5-D tensor of shape [batch_size, h, w, d, c]
            (the shape is decided by ``in_data`` -- TODO confirm). Cast to
            ``tf.float16`` when ``FLAGS.use_fp16`` is set; otherwise returned
            with whatever dtype ``in_data`` produced.
        label_batch: Label batch produced by ``in_data.inputs_train``;
            presumably 1-D of shape [batch_size] (TODO confirm). Cast to
            ``tf.float16`` when ``FLAGS.use_fp16`` is set; otherwise returned
            unchanged.

    Raises:
        ValueError: If ``FLAGS.data_dir`` is empty or unset.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    # The flag value is used as-is; no subdirectory is appended to it.
    data_dir = FLAGS.data_dir
    img3_batch, label_batch = in_data.inputs_train(
        data_dir=data_dir, batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        img3_batch = tf.cast(img3_batch, tf.float16)
        # NOTE(review): casting labels to float16 is only appropriate if the
        # downstream loss consumes float labels; integer class ids would
        # normally stay integral. Confirm against the loss function.
        label_batch = tf.cast(label_batch, tf.float16)
    return img3_batch, label_batch
# Comment list
# Table of contents