def encode_tags(tags, output):
    """Encode each space-separated tag string as a one-hot float32 vector.

    For every string in *tags*, the string is split on single spaces into a
    set of tokens. The members of the enum class *output* are scanned in
    definition order. A member matches when its value is the empty string
    (a catch-all) or when its value is one of the tokens. The LAST matching
    member wins, and its position becomes the hot index.

    Args:
        tags: Iterable of tag strings.
        output: An ``enum.Enum`` subclass whose canonical members define the
            vector layout (aliases are not counted).

    Returns:
        A list with one ``np.float32`` vector of length ``len(output)`` per
        input tag string.

    Raises:
        TypeError: If *output* is not an ``enum.Enum`` subclass.
        ValueError: If no member matches a tag string (i.e. the enum has no
            ``''`` catch-all and none of its values appear in the tokens).
    """
    # Validate once, up front. A plain ``assert`` would be stripped under -O.
    if not (isinstance(output, type) and issubclass(output, enum.Enum)):
        raise TypeError(f"output must be an enum.Enum subclass, got {output!r}")

    # Iterating the enum class yields canonical members only (no aliases).
    # This keeps indices consistent with len(output).
    # ``__members__`` would include aliases and shift indices.
    members = list(output)
    width = len(members)

    result = []
    for tag in tags:
        tokens = set(tag.split(' '))
        hot = None
        for idx, member in enumerate(members):
            # Last match wins. NOTE(review): a '' catch-all member defined
            # after the specific members will always win — confirm intended.
            if member.value == '' or member.value in tokens:
                hot = idx
        if hot is None:
            raise ValueError(
                f"no member of {output.__name__} matches tag {tag!r}"
            )
        vec = np.zeros(width, dtype=np.float32)
        vec[hot] = 1.0
        result.append(vec)
    return result
# (web-page residue) 评论列表 — "Comment list"
# (web-page residue) 文章目录 — "Table of contents"