def save(artist, model_path, num_save):
    """Generate ``num_save`` text samples from a trained model and write them to disk.

    Builds an ``LSTMModel`` in test mode over the vocabulary returned by
    ``DataReader(artist).get_vocab()``, restores its weights from the
    checkpoint at ``model_path``, then for each ``i`` in ``0 .. num_save - 1``
    generates a sample, post-processes it with ``process_sample`` and writes
    it to ``<sample_save_dir>/<artist>/<i>.txt``.

    :param artist: Artist whose data/vocabulary the model is built over; also
        used as the name of the output subdirectory.
    :param model_path: Checkpoint path passed to ``tf.train.Saver.restore``.
    :param num_save: Number of samples to generate and write.
    """
    # NOTE(review): relative path, so it is resolved against the current
    # working directory; c.get_dir presumably creates it if missing — confirm.
    sample_save_dir = c.get_dir('../save/samples/')

    # Manage the session with a context manager so its resources are
    # released when sampling finishes or if anything below raises
    # (the original never closed it).
    with tf.Session() as sess:
        print(artist)
        data_reader = DataReader(artist)
        vocab = data_reader.get_vocab()

        print('Init model...')
        model = LSTMModel(sess,
                          vocab,
                          c.BATCH_SIZE,
                          c.SEQ_LEN,
                          c.CELL_SIZE,
                          c.NUM_LAYERS,
                          test=True)

        # tf.train.Saver() with no var_list captures the variables that exist
        # when it is constructed, so it must be created after the model.
        saver = tf.train.Saver()
        sess.run(tf.initialize_all_variables())
        saver.restore(sess, model_path)
        print('Model restored from ' + model_path)

        artist_save_dir = c.get_dir(join(sample_save_dir, artist))
        for i in range(num_save):
            print(i)  # progress indicator
            path = join(artist_save_dir, str(i) + '.txt')
            sample = model.generate()
            processed_sample = process_sample(sample)
            with open(path, 'w') as f:
                f.write(processed_sample)
评论列表
文章目录