def get_data(item='train',id=1,is_shuffle=False,is_subtrain=1):
file_path=os.path.join(metadata_root,item+'_list0'+id+'.txt')
files=[]
labels=[]
with open(file_path,'r')as fp:
lines=fp.readlines()
if is_shuffle==True:
np.random.shuffle(lines)
if not is_subtrain==1:
lines=random.sample(lines,int(len(lines)*is_subtrain))
for line in lines:
tmp_prefix=line.strip().split('.')[0].split('/')[1]
label_tmp=line.strip().split(' ')[1]
files.append(os.path.join(feature_root,tmp_prefix+'.npy'))
labels.append(int(label_tmp)-1)
return files,np.array(labels,dtype=np.float64)
svm_classification.py 文件源码
python
阅读 33
收藏 0
点赞 0
评论 0
评论列表
文章目录