def __init__(self, img_feat, tags_idx, a_tags_idx, test_tags_idx, z_dim, vocab_processor):
    """Initialise sampler state, vocabulary ids and fixed test inputs.

    Args:
        img_feat: Stored unchanged as ``self.img_feat`` (presumably image
            features -- TODO confirm against caller).
        tags_idx: Stored as ``self.tags_idx``; its ``len()`` becomes
            ``self.length``, the size of both index arrays built below.
        a_tags_idx: Stored unchanged as ``self.a_tags_idx``.
        test_tags_idx: Passed to ``self.gen_test_hot``; the result is stored
            as ``self.test_tags_idx``.
        z_dim: Forwarded to ``self.next_noise_batch`` when building
            ``self.fixed_z``; it is not stored on the instance.
        vocab_processor: Object exposing ``_mapping`` (indexed by the tokens
            '<UNK>', '<EOS>', 'hair' and 'eyes') and ``_reverse_mapping``
            (whose ``len()`` is taken as the vocabulary size).

    Note:
        ``gen_info``, ``gen_test_hot`` and ``next_noise_batch`` are defined
        elsewhere in the class. They run after every attribute assignment
        above them, and the shuffle of ``self.w_idx2`` runs after them, so
        this statement order should be preserved.
    """
    # Latent-noise sampler: normal(loc=0, scale=1) truncated to [-1, 1].
    # scipy's truncnorm takes its clip points in standardized units,
    # i.e. (bound - loc) / scale.
    z_low, z_high, z_loc, z_scale = -1, 1, 0., 1
    self.z_sampler = stats.truncnorm(
        (z_low - z_loc) / z_scale,
        (z_high - z_loc) / z_scale,
        loc=z_loc,
        scale=z_scale,
    )

    # Dataset references and cursor state.
    self.img_feat = img_feat
    self.tags_idx = tags_idx
    self.a_tags_idx = a_tags_idx
    self.length = len(tags_idx)
    self.current = 0
    self.epoch = 0
    self.tmp = 0  # NOTE(review): purpose not visible in this chunk.

    # Two distinct index arrays over range(self.length); ``w_idx2`` is
    # shuffled at the end of this method.
    self.w_idx = np.arange(self.length)
    self.w_idx2 = np.arange(self.length)

    # Vocabulary metadata.
    self.vocab_processor = vocab_processor
    self.vocab_size = len(vocab_processor._reverse_mapping)
    token_to_id = vocab_processor._mapping
    self.unk_id, self.eos_id, self.hair_id, self.eyes_id = (
        token_to_id[token] for token in ('<UNK>', '<EOS>', 'hair', 'eyes')
    )

    # Derived data; these helpers may read any attribute assigned above.
    self.gen_info()
    self.test_tags_idx = self.gen_test_hot(test_tags_idx)
    self.fixed_z = self.next_noise_batch(len(self.test_tags_idx), z_dim)

    # Shuffle ``w_idx2`` with a permutation drawn from NumPy's global RNG.
    # Keep this last so the random stream is consumed in the same order.
    order = np.random.permutation(self.length)
    self.w_idx2 = self.w_idx2[order]
评论列表
文章目录