base_trainer.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:ssgan 作者: samrussell 项目源码 文件源码
def save_results(self, filename):
    """Render a 4x4 grid of generator samples and save it to ``filename``.

    Each generator input is a 100-value row: a one-hot encoding (width
    ``self.num_classes + 1``) of a randomly chosen class index in
    ``[0, self.num_classes)``, followed by uniform noise in ``[0, 1)``
    filling the remaining columns.

    Args:
        filename: Destination handed straight to ``plt.savefig``.
    """
    num_samples = 16   # one image per cell of the 4x4 subplot grid below
    input_width = 100  # total length of each generator input row
    one_hot_width = self.num_classes + 1

    # Build the generator input: one-hot class part + random part.
    fake_categories = np.random.choice(self.num_classes, num_samples)
    fake_vectors = to_categorical(fake_categories, one_hot_width)
    random_value_part = np.random.uniform(
        0, 1, size=[num_samples, input_width - one_hot_width])
    fake_values = np.concatenate((fake_vectors, random_value_part), axis=1)
    images = self.generator.predict(fake_values)

    plt.figure(figsize=(10, 10))
    try:
      for i in range(images.shape[0]):
        plt.subplot(4, 4, i + 1)
        image = images[i, :, :, :]
        if self.img_channels == 1:
          # Drop the singleton channel axis so imshow gets a 2-D array.
          image = np.reshape(image, [self.img_rows, self.img_cols])
        elif K.image_data_format() == 'channels_first':
          # imshow expects (rows, cols, channels).
          image = image.transpose(1, 2, 0)
        # implicit no need to transpose if channels are last
        plt.imshow(image, cmap='gray')
        plt.axis('off')
      plt.tight_layout()

      plt.savefig(filename)
    finally:
      # Close even when plotting or saving fails, so the figure does not
      # stay registered with pyplot.
      plt.close('all')

  #def test_results(self, testing_values, testing_labels):
    #predictions = self.model.predict(testing_values)
    #df = pandas.DataFrame(data=np.argmax(predictions, axis=1), columns=['Label'])
    #df.insert(0, 'ImageId', range(1, 1 + len(df)))

    # save results
    #df.to_csv(self.commandline_args.output, index=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号