train.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:ml_gans 作者: imironhead 项目源码 文件源码
def build_dataset_reader():
    """Build a TF1 queue-based input pipeline over a directory of PNG portraits.

    Every ``*.png`` file in ``FLAGS.portraits_dir_path`` is fed through a
    filename queue and read whole with ``tf.WholeFileReader``. Each file is
    decoded as a 3-channel uint8 image and optionally cropped. It is then
    resized to ``FLAGS.image_size`` x ``FLAGS.image_size``, randomly flipped
    left/right, and rescaled to float32 values in [-1, 1]. The results are
    grouped into batches of ``FLAGS.batch_size`` with ``tf.train.batch``.

    Returns:
        The result of ``tf.train.batch(tensors=[image], ...)``. Presumably a
        float32 tensor of shape
        ``[FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3]``;
        NOTE(review): confirm whether callers receive a bare tensor or a
        one-element list.

    Raises:
        ValueError: if no PNG file matches the wildcard under
            ``FLAGS.portraits_dir_path``.
    """
    paths_png_wildcards = os.path.join(FLAGS.portraits_dir_path, '*.png')

    # glob returns files in arbitrary order; sort so the file list is
    # reproducible from run to run.
    paths_png = sorted(glob.glob(paths_png_wildcards))

    # Fail while the graph is being built, with a clear message. Otherwise an
    # empty file list is only reported later, when the input queue runs.
    if not paths_png:
        raise ValueError(
            'no PNG files match {}'.format(paths_png_wildcards))

    file_name_queue = tf.train.string_input_producer(paths_png)

    reader = tf.WholeFileReader()

    # The reader key (the file name) is not needed; keep only the file bytes.
    _, reader_val = reader.read(file_name_queue)

    image = tf.image.decode_png(reader_val, channels=3, dtype=tf.uint8)

    # Input images are assumed to be either 128x128x3 or 64x64x3.

    if FLAGS.crop_image:
        # First cut a fixed m x m window at (offset_y, offset_x)...
        image = tf.image.crop_to_bounding_box(
            image,
            FLAGS.crop_image_offset_y,
            FLAGS.crop_image_offset_x,
            FLAGS.crop_image_size_m,
            FLAGS.crop_image_size_m)

        # ...then take a random n x n patch inside it. This requires n <= m.
        image = tf.random_crop(
            image, size=[FLAGS.crop_image_size_n, FLAGS.crop_image_size_n, 3])

    image = tf.image.resize_images(image, [FLAGS.image_size, FLAGS.image_size])

    image = tf.image.random_flip_left_right(image)

    # Map pixel values from [0, 255] to [-1, 1].
    image = tf.cast(image, dtype=tf.float32) / 127.5 - 1.0

    return tf.train.batch(
        tensors=[image],
        batch_size=FLAGS.batch_size,
        capacity=FLAGS.batch_size)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号