srez_train.py source code

python

Project: tensorflow-srgan  Author: olgaliak
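import os

import scipy.misc
import tensorflow as tf

# TensorFlow 1.x flags object; the train_dir flag is defined elsewhere in the project.
FLAGS = tf.app.flags.FLAGS


def _summarize_progress(train_data, feature, label, gene_output, batch, suffix, max_samples=8):
    """Save a vertical strip of generator outputs for this batch as a PNG."""
    td = train_data

    # Target (height, width), taken from the high-resolution labels.
    size = [label.shape[1], label.shape[2]]

    # Baseline upscalings of the low-resolution input, clipped to [0, 1].
    # They are only used by the commented-out side-by-side comparison below.
    nearest = tf.image.resize_nearest_neighbor(feature, size)
    nearest = tf.maximum(tf.minimum(nearest, 1.0), 0.0)

    bicubic = tf.image.resize_bicubic(feature, size)
    bicubic = tf.maximum(tf.minimum(bicubic, 1.0), 0.0)

    # Clip the generator output to the valid [0, 1] range.
    clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)

    # image = tf.concat([nearest, bicubic, clipped, label], 2)
    image = clipped

    # Stack the first printCnt samples along the height axis.
    # Note: max_samples is unused here, and the batch must hold at least printCnt images.
    printCnt = 5
    image = image[0:printCnt]
    image = tf.concat([image[i, :, :, :] for i in range(printCnt)], 0)
    image = td.sess.run(image)

    # scipy.misc.toimage was removed in SciPy 1.2, so this call needs an older SciPy.
    filename = 'batch%06d_%s.png' % (batch, suffix)
    filename = os.path.join(FLAGS.train_dir, filename)
    scipy.misc.toimage(image, cmin=0., cmax=1.).save(filename)
    print("    Saved %s" % (filename,))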
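
The clip-and-stack step in the function above can also be tried outside a TensorFlow session. The following is a minimal NumPy sketch, not code from the project: `stack_samples` and its `count` parameter are hypothetical names, and the input is assumed to be a batch of shape (N, H, W, C) like `gene_output`. Unlike the original, it caps `count` at the batch size, and its 8-bit conversion only roughly approximates what `toimage(cmin=0., cmax=1.)` does.

```python
import numpy as np


def stack_samples(batch, count=5):
    """Clip a batch to [0, 1] and stack its first `count` images vertically.

    Hypothetical helper for illustration; `batch` is assumed to be (N, H, W, C).
    """
    batch = np.asarray(batch, dtype=np.float32)
    count = min(count, batch.shape[0])  # guard against batches smaller than count
    clipped = np.clip(batch[:count], 0.0, 1.0)
    # Joining the per-image arrays along axis 0 stacks them along the height axis.
    strip = np.concatenate(list(clipped), axis=0)  # shape (count * H, W, C)
    # Map [0, 1] to 8-bit pixel values.
    return np.round(strip * 255.0).astype(np.uint8)


# Eight random 4x4 RGB "images" with values deliberately outside [0, 1].
demo = np.random.uniform(-0.5, 1.5, size=(8, 4, 4, 3))
strip = stack_samples(demo)
print(strip.shape)  # (20, 4, 3)
```

With Pillow installed, the resulting array could presumably be written out with `Image.fromarray(strip).save(filename)`, which may serve as a substitute for the `scipy.misc.toimage` call on newer SciPy versions.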