def main():
    """Save the 'product_embeddings' variable of a trained model to disk.

    Reads a single positional command-line argument, ``tag``, which selects
    both the hyperparameters (``hypers.hps_for_tag``) and the checkpoint
    (``utils.load_checkpoint_for_tag``). The variable's value is written
    with ``np.save`` to the path returned by
    ``path_for_cached_embeddings(tag)``, and a one-line summary is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('tag')
    args = parser.parse_args()
    tag = args.tag

    hps = hypers.hps_for_tag(tag)
    # NOTE(review): presumably these put the model in inference mode with
    # single-example batches; no training happens in this script.
    hps.is_training = 0
    hps.batch_size = 1
    # (dummy dataset, just so we have some placeholder values for the
    # rnnmodel's input vars)
    dat = dataset.BasketDataset(hps, 'unit_tests.tfrecords')
    # The model object is not used directly below. It is constructed because
    # lookup() calls tf.get_variable with reuse=True, which only succeeds
    # for variables that already exist in the graph.
    model = rnnmodel.RNNModel(hps, dat)

    sess = tf.InteractiveSession()
    try:
        utils.load_checkpoint_for_tag(tag, sess)

        def lookup(varname):
            """Return the current value of variable `varname` in 'instarnn'."""
            with tf.variable_scope('instarnn', reuse=True):
                var = tf.get_variable(varname)
            return sess.run(var)

        emb = lookup('product_embeddings')
    finally:
        # Release the session's resources even if loading or lookup fails.
        sess.close()

    outpath = path_for_cached_embeddings(tag)
    np.save(outpath, emb)
    # Single-argument parenthesized print emits the same text on Python 2
    # and Python 3.
    print('Saved embeddings with shape {} to {}'.format(emb.shape, outpath))
cache_embeddings.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录