def main(args):
    """Predict one next frame per action and write each as a PNG file.

    Builds an ``ActionConditionalVideoPredictionModel`` in inference mode,
    restores its weights, and for every action index in
    ``range(args.num_act)`` feeds the loaded state together with a one-hot
    action vector to ``model.predict``. Each post-processed prediction is
    written to ``pred-<index>.png`` (two-digit, zero-padded index) relative
    to the current working directory.

    Args:
        args: Object exposing the attributes read here:
            ``num_act`` (number of actions), ``data`` (path passed to
            ``np.load`` for the input state), ``mean`` (path passed to
            ``np.load`` for the mean image) and ``load`` (checkpoint
            location passed to ``model.restore``). It is also forwarded
            unchanged to ``get_config``.
    """
    # Imported lazily so the module can be loaded without tfacvp installed.
    from tfacvp.model import ActionConditionalVideoPredictionModel
    from tfacvp.util import post_process_rgb

    with tf.Graph().as_default():
        logging.info('Create model [num_act = %d] for testing', args.num_act)
        model = ActionConditionalVideoPredictionModel(num_act=args.num_act, is_train=False)
        config = get_config(args)

        s = np.load(args.data)
        mean = np.load(args.mean)
        scale = 255.0

        # Row i of the identity matrix is the one-hot encoding of action i.
        # Built once here instead of once per loop iteration.
        one_hot_actions = np.identity(args.num_act)

        with tf.Session(config=config) as sess:
            logging.info('Loading weights from %s', args.load)
            model.restore(sess, args.load)

            for i in range(args.num_act):
                logging.info('Predict next frame condition on action %d', i)
                a = one_hot_actions[i]

                # A leading axis of length 1 is added to both inputs; the
                # first element of predict()'s result is then indexed again
                # with [0] -- presumably to pick the single sample out of
                # the batch (TODO confirm against model.predict).
                x_t_1_pred_batch = model.predict(sess, s[np.newaxis, :], a[np.newaxis, :])[0]
                img = x_t_1_pred_batch[0]

                # NOTE(review): `post_process_rgb` is imported above but
                # `post_process` is what gets called. Unless `post_process`
                # is defined elsewhere in this module this raises NameError;
                # confirm which helper is intended.
                img = post_process(img, mean, scale)

                # cv2.imwrite expects BGR channel order, so convert from RGB.
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                cv2.imwrite('pred-%02d.png' % i, img)
example.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录