dvd.py source code

python

Project: dvd  Author: ajayrfhp
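import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.placeholder)
# vgg16 is the project's own VGG-16 wrapper class; it is assumed to be
# defined or imported elsewhere in the project.


def get_embedding_X(img):
    '''
    Args    : NumPy array of images, shape (samples, height, width)
    Returns : Embedding matrix of shape (samples, 4096)
    '''
    # Add a trailing channel axis: (samples, height, width, 1)
    img = img.reshape((img.shape[0], img.shape[1], img.shape[2], 1))
    sess = tf.Session()
    imgs = tf.placeholder(tf.float32, [None, None, None, None])
    vgg = vgg16(imgs, '/tmp/vgg16_weights.npz', sess)
    embs = []
    # Split into batches of roughly 1000 images; np.array_split needs an
    # integer count of at least 1, so guard against fewer than 1000 images.
    n_batches = max(1, img.shape[0] // 1000)
    for cnt, img_batch in enumerate(np.array_split(img, n_batches), start=1):
        emb = sess.run(vgg.emb, feed_dict={vgg.imgs: img_batch})
        embs.extend(emb)
        # Report progress as a percentage of batches processed
        progress = round(100.0 * cnt / n_batches, 2)
        if progress % 10 == 0:
            print(progress)
    embs = np.array(embs)
    print(embs.shape)
    # Flatten each per-image embedding into a single row
    embs = embs.reshape((embs.shape[0], -1))
    return embs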
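
The batching and flattening pattern in get_embedding_X can be tried without TensorFlow or the VGG weights. The following is only a minimal sketch: `fake_embed` and `embed_in_batches` are hypothetical names that do not appear in the project, and `fake_embed` merely stands in for the `sess.run(vgg.emb, ...)` call by returning a small dummy feature map per image.

```python
import numpy as np


def fake_embed(batch):
    # Hypothetical stand-in for sess.run(vgg.emb, ...): returns one
    # (2, 2, 3) feature map per image, filled with that image's mean.
    means = batch.reshape((batch.shape[0], -1)).mean(axis=1)
    return np.ones((batch.shape[0], 2, 2, 3)) * means[:, None, None, None]


def embed_in_batches(img, embed_fn, batch_size=1000):
    # Add the channel axis, as get_embedding_X does.
    img = img.reshape((img.shape[0], img.shape[1], img.shape[2], 1))
    # Use at least one batch, even with fewer images than batch_size.
    n_batches = max(1, img.shape[0] // batch_size)
    embs = []
    for img_batch in np.array_split(img, n_batches):
        embs.extend(embed_fn(img_batch))
    embs = np.array(embs)
    # Flatten every per-image feature map into one row.
    return embs.reshape((embs.shape[0], -1))


images = np.random.rand(25, 8, 8)
features = embed_in_batches(images, fake_embed, batch_size=10)
print(features.shape)  # (25, 12)
```

Because np.array_split accepts uneven splits, the 25 images above are processed as two batches of 13 and 12, and the rows of the result stay in the original image order.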