def sample_embeddings(self, embeddings, filenames, class_id, sample_num,
                      rng=None):
    """Select one (or average several) caption embeddings per batch item.

    Args:
        embeddings: NumPy array. A 2-D array is taken to hold one embedding
            per item already and is returned unchanged. Otherwise it is
            expected to be 3-D, ``(batch_size, embedding_num, dim)``, where
            axis 1 indexes the candidate caption embeddings of each item.
        filenames: Sequence indexed in step with axis 0 of ``embeddings``;
            passed to ``self.readCaptions`` when ``sample_num == 1``.
        class_id: Sequence indexed in step with axis 0 of ``embeddings``;
            passed to ``self.readCaptions`` when ``sample_num == 1``.
        sample_num: Number of distinct candidates drawn per item (without
            replacement). With 1, the drawn embedding is used directly and
            the matching caption is collected. With more, the drawn
            embeddings are averaged and no captions are collected.
        rng: Optional random source exposing ``choice(a, size, replace=...)``
            (e.g. ``np.random.RandomState(seed)`` or
            ``np.random.default_rng(seed)``) for reproducible sampling.
            Defaults to the global ``np.random`` state.

    Returns:
        Tuple ``(sampled_embeddings, sampled_captions)``. Every code path
        returns this tuple, so callers can always unpack it.
        ``sampled_embeddings`` keeps its leading batch axis even when
        ``batch_size == 1``. ``sampled_captions`` has one entry per item
        only on the ``sample_num == 1`` sampling path and is empty otherwise.

    Raises:
        ValueError: Raised by ``choice`` if ``sample_num`` exceeds
            ``embedding_num``.
    """
    if rng is None:
        rng = np.random

    if len(embeddings.shape) == 2:
        # One embedding per item already: nothing to sample.
        return embeddings, []
    if embeddings.shape[1] == 1:
        # A single candidate per item. Index axis 1 instead of calling
        # np.squeeze(), so a batch of one keeps its batch axis.
        return embeddings[:, 0, :], []

    batch_size, embedding_num, _ = embeddings.shape
    sampled_embeddings = []
    sampled_captions = []
    for i in range(batch_size):
        # Distinct candidate indices for this item.
        randix = rng.choice(embedding_num, sample_num, replace=False)
        if sample_num == 1:
            # choice() returns a length-1 array. Index it explicitly because
            # int() on an array with ndim > 0 is deprecated (NumPy 1.25).
            idx = int(randix[0])
            # NOTE(review): assumes readCaptions returns captions in the same
            # order as axis 1 of the embeddings - confirm against its source.
            captions = self.readCaptions(filenames[i], class_id[i])
            sampled_captions.append(captions[idx])
            sampled_embeddings.append(embeddings[i, idx, :])
        else:
            # Average the drawn candidates into a single vector.
            e_sample = embeddings[i, randix, :]
            sampled_embeddings.append(np.mean(e_sample, axis=0))
    # Stack into (batch_size, dim) without squeezing, so batch_size == 1
    # keeps its leading axis.
    return np.array(sampled_embeddings), sampled_captions
datasets.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录