cache_embeddings.py source code

python

Project: instacart-basket-prediction | Author: colinmorris
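```python
# Export the trained product-embedding matrix from a saved checkpoint to a .npy file.
# Uses the TensorFlow 1.x graph-mode API (tf.InteractiveSession, tf.variable_scope).
import argparse

import numpy as np
import tensorflow as tf

# Assumed to be project-local modules of instacart-basket-prediction.
import dataset
import hypers
import rnnmodel
import utils

# NOTE: path_for_cached_embeddings(tag) is called below but is not defined
# in this snippet.


def main():
  parser = argparse.ArgumentParser()
  parser.add_argument('tag')  # names the model config/checkpoint to load
  args = parser.parse_args()
  tag = args.tag

  hps = hypers.hps_for_tag(tag)
  hps.is_training = 0
  hps.batch_size = 1
  # (dummy dataset, just so we have some placeholder values for the rnnmodel's input vars)
  dat = dataset.BasketDataset(hps, 'unit_tests.tfrecords')
  model = rnnmodel.RNNModel(hps, dat)  # builds the graph, including the embedding variable
  sess = tf.InteractiveSession()
  utils.load_checkpoint_for_tag(tag, sess)  # restore trained weights into the session

  def lookup(varname):
    # Fetch an existing variable from the 'instarnn' scope (reuse=True never creates one).
    with tf.variable_scope('instarnn', reuse=True):
      var = tf.get_variable(varname)
    return sess.run(var)

  emb = lookup('product_embeddings')
  outpath = path_for_cached_embeddings(tag)
  np.save(outpath, emb)  # np.save appends '.npy' if outpath lacks that extension
  print('Saved embeddings with shape {} to {}'.format(emb.shape, outpath))


if __name__ == '__main__':
  main()
```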
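The script only writes the embedding matrix to disk. Below is a minimal sketch of reading it back and using it. It assumes the cached file holds a 2-D `(n_products, embedding_dim)` array with one row per product index, which the snippet does not confirm. The helper names `load_cached_embeddings` and `nearest_products` are hypothetical. Cosine-similarity nearest neighbours is one illustrative use and is not something the original project is known to do. One detail is documented NumPy behaviour: `np.save` appends `.npy` to a filename that lacks it, so the loader tries the path as given and then falls back to `path + '.npy'`.

```python
import os
import tempfile

import numpy as np


def load_cached_embeddings(path):
    """Load an embedding matrix written by np.save (hypothetical helper).

    np.save appends '.npy' when the target filename lacks that extension,
    so try the path as given first and fall back to path + '.npy'.
    """
    if not os.path.exists(path) and os.path.exists(path + '.npy'):
        path = path + '.npy'
    return np.load(path)


def nearest_products(emb, product_idx, k=5):
    """Return indices of the k rows most cosine-similar to row product_idx.

    Assumes row i of `emb` is the embedding of product index i; the query
    row itself is excluded from the result.
    """
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.maximum(norms, 1e-12)  # guard against zero-length rows
    sims = unit @ unit[product_idx]        # cosine similarity to the query row
    sims[product_idx] = -np.inf            # exclude the query product itself
    return np.argsort(-sims)[:k]           # highest similarity first


if __name__ == '__main__':
    # Stand-in data: a random matrix instead of a real checkpoint export.
    rng = np.random.default_rng(0)
    fake_emb = rng.normal(size=(100, 16)).astype(np.float32)
    with tempfile.TemporaryDirectory() as d:
        outpath = os.path.join(d, 'embeddings')  # no extension on purpose
        np.save(outpath, fake_emb)               # actually writes embeddings.npy
        loaded = load_cached_embeddings(outpath)
    print(loaded.shape, nearest_products(loaded, 0, k=3))
```

Normalising rows once and taking a single matrix-vector product keeps each query at O(n_products × dim). For many repeated queries, an approximate-nearest-neighbour index would be the usual next step.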