nn1_stress_test.py source code

Project: YellowFin_Pytorch. Author: JianGoForIt.

```python
import numpy as np
import torch
from torch.autograd import Variable

# iterate_minibatches and pad_batch are project helpers not shown on this page.
def gen_minibatch(tokens, features, labels, mini_batch_size, shuffle=True):
    """Yield (padded tokens, features, labels) minibatches, dropping label-0.5 examples."""
    labels = np.asarray(labels)  # make `labels != 0.5` elementwise even for list input
    keep = np.where(labels != 0.5)[0]  # indices of examples to keep
    tokens = np.asarray(tokens)[keep]
    if isinstance(features, np.ndarray):
        features = features[keep]
    else:
        # e.g. a scipy.sparse matrix: densify before row indexing
        features = np.asarray(features.todense())[keep]
    labels = labels[keep]
    for token, feature, label in iterate_minibatches(
            tokens, features, labels, mini_batch_size, shuffle=shuffle):
        token = list(pad_batch(token))
        # Variable is a no-op wrapper since PyTorch 0.4; kept as in the original.
        yield (token,
               Variable(torch.from_numpy(feature)),
               Variable(torch.FloatTensor(label), requires_grad=False))
```
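
`gen_minibatch` calls two helpers from the YellowFin_Pytorch project, `iterate_minibatches` and `pad_batch`, whose definitions are not shown on this page. The sketch below is a hypothetical stand-in, not the project's code. It assumes `iterate_minibatches` yields aligned slices of tokens, features, and labels, and that `pad_batch` right-pads token-id sequences with 0 and returns the padded `LongTensor` together with a tensor of original lengths (both the padding id and the return shape are assumptions). The toy data also walks through the same label-0.5 filter that `gen_minibatch` applies.

```python
import numpy as np
import torch


def iterate_minibatches(tokens, features, labels, batch_size, shuffle=True):
    """Hypothetical helper: yield aligned (tokens, features, labels) slices.

    Keeps the final partial batch; tokens may be a list or an object array.
    """
    n = len(labels)
    assert len(tokens) == n and len(features) == n
    order = np.random.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield [tokens[i] for i in idx], features[idx], labels[idx]


def pad_batch(token_batch, pad_id=0):
    """Hypothetical helper: right-pad sequences of token ids to equal length.

    Returns (padded LongTensor of shape [batch, max_len], LongTensor of lengths).
    """
    lengths = [len(seq) for seq in token_batch]
    padded = torch.full((len(token_batch), max(lengths)), pad_id, dtype=torch.long)
    for row, seq in enumerate(token_batch):
        padded[row, :len(seq)] = torch.as_tensor(seq, dtype=torch.long)
    return padded, torch.tensor(lengths, dtype=torch.long)


if __name__ == "__main__":
    seqs = [[1, 2, 3], [4, 5], [6], [7, 8]]
    feats = np.arange(8, dtype=np.float32).reshape(4, 2)
    labels = np.array([1.0, 0.5, 0.0, 1.0])
    keep = np.where(labels != 0.5)[0]  # drops the second example (label 0.5)
    seqs = [seqs[i] for i in keep]
    for tok, feat, lab in iterate_minibatches(
            seqs, feats[keep], labels[keep], 2, shuffle=False):
        padded, lengths = pad_batch(tok)
        print(padded.tolist(), lengths.tolist(), feat.shape, lab.tolist())
```

Keeping the final partial batch is a design choice here; many widely copied versions of `iterate_minibatches` step with `range(0, n - batch_size + 1, batch_size)` and silently drop it instead.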