def sample_lstm_state(args):
    """Create a zero-filled initial (hidden, cell) state pair for an LSTM.

    NOTE(review): ``th`` is presumably ``torch`` and ``V`` an autograd
    ``Variable`` wrapper; confirm against the module-level imports.

    Args:
        args: Any object accepted by ``vars()`` (e.g. an
            ``argparse.Namespace``). Its ``layer_sizes`` attribute is passed
            as the second size argument to ``th.zeros``.

    Returns:
        A tuple ``(hx, cx)`` of two independently allocated zero tensors
        built by ``V(th.zeros(1, layer_sizes))``, when ``layer_sizes`` is
        present in ``vars(args)``; otherwise ``None``.
    """
    # Membership is tested against the instance __dict__ only, not via
    # attribute lookup, so class-level attributes do not count.
    if 'layer_sizes' not in vars(args):
        return None

    width = args.layer_sizes
    # Allocate twice (hidden first, then cell) so the two states never alias.
    hidden, cell = (V(th.zeros(1, width)) for _ in range(2))
    return hidden, cell