def update_memories_with_extra_features_(self, memory_lengths, memories):
memory_lengths = memory_lengths.data
memories = memories.data
if self.extra_features_slots > 0:
num_nonempty_memories = memory_lengths.ne(0).sum()
updated_memories = memories.new(memories.numel() + num_nonempty_memories * self.extra_features_slots)
src_offset = 0
dst_offset = 0
for i in range(memory_lengths.size(0)):
for j in range(self.opt['mem_size']):
length = memory_lengths[i, j]
if length > 0:
if self.opt['time_features']:
updated_memories[dst_offset] = self.time_feature(j)
dst_offset += 1
updated_memories[dst_offset:dst_offset + length] = memories[src_offset:src_offset + length]
src_offset += length
dst_offset += length
memory_lengths += memory_lengths.ne(0).long() * self.extra_features_slots
memories.set_(updated_memories)
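# 评论列表 ("Comments") -- scraped web-page navigation text, not code
# 文章目录 ("Table of contents") -- scraped web-page navigation text, not code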