def main():
    """Build readers, model and TF session, then run the configured phase.

    Expects exactly two command-line arguments, which are forwarded
    unchanged to ``Config`` (their meaning is defined by ``Config`` --
    TODO confirm).

    Config keys read: ``phase``, ``batch_size``, ``lr``, ``snap_path`` and,
    optionally, ``model`` (a checkpoint path passed to ``model.saver.restore``).
    ``config.items['starting']`` is always set to 0.

    Only the ``'ctc'`` phase is handled (via ``train_valid``); any other
    phase logs a warning and does nothing.

    Raises:
        SystemExit: if the number of command-line arguments is not exactly 2.
    """
    # Validate explicitly: an ``assert`` is removed under ``python -O``,
    # which would leave ``config`` unbound and fail later with a NameError.
    if len(sys.argv) != 3:
        raise SystemExit(
            'expected exactly 2 command-line arguments (forwarded to Config), '
            'got %d' % max(len(sys.argv) - 1, 0))
    config = Config(sys.argv[1], sys.argv[2])
    phase = config.items['phase']

    from reader import Reader
    train_set = Reader(phase='train', batch_size=config.items['batch_size'], do_shuffle=True)
    valid_set = Reader(phase='val', batch_size=10, do_shuffle=False)
    test_set = Reader(phase='test', batch_size=10, do_shuffle=False)

    glog.info('generating model...')
    from model import Model
    model = Model(config.items['lr'])

    # Always 0, even when a checkpoint is restored below
    # (presumably the starting step/epoch -- TODO confirm).
    config.items['starting'] = 0

    # Directory where snapshots are written.
    mkdir_safe(config.items['snap_path'])

    # device_count caps how many GPUs TF may use; soft placement lets ops
    # without a GPU kernel fall back to CPU; allow_growth allocates GPU
    # memory on demand instead of reserving it all up front.
    sess_config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 4})
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # Runs in ``sess``, which is the default session inside this ``with``.
        tf.global_variables_initializer().run()
        if 'model' in config.items:
            # Log first so a failing restore is attributable to this path.
            glog.info('loading model: %s...' % config.items['model'])
            model.saver.restore(sess, config.items['model'])

        if phase == 'ctc':
            glog.info('ctc training...')
            train_valid(sess, model, train_set, valid_set, test_set, config)
        else:
            glog.warning('unsupported phase: %s, nothing to do' % phase)
    glog.info('end')
评论列表
文章目录