def create_test_set(x_lst):
    """Build zero-padded input/target arrays for one-step-ahead prediction.

    Every sequence ``vec`` in ``x_lst`` is split into an input ``vec[:-1]``
    (all but the last step) and a target ``vec[1:]`` (all but the first
    step). Both are written into the leading rows of a padded array, and a
    mask records which time steps hold real data.

    Args:
        x_lst: A sized collection of non-empty sequences. Each sequence
            (minus one step) is assigned into a slice of shape
            ``(len(vec) - 1, OUTDIM)``, so it must be broadcastable to that
            shape. ``OUTDIM`` is a module-level constant defined outside
            this function.

    Returns:
        A tuple ``(u_out, x_out, mask)``:
        ``u_out`` and ``x_out`` are float32 arrays of shape
        ``(len(x_lst), max_len, OUTDIM)`` where ``max_len`` is the longest
        sequence length minus one; padding positions are 0.
        ``mask`` is a float32 array of shape ``(len(x_lst), max_len)`` with
        1.0 where ``x_out`` holds real data (judged on feature 0) and 0.0
        on padding.

    Raises:
        ValueError: If ``x_lst`` is empty or contains a zero-length sequence.
        AssertionError: If the number of unmasked steps per row does not
            match the sequence length minus one (for example when the
            input data itself contains NaN in feature 0).
    """
    n = len(x_lst)
    if n == 0:
        raise ValueError("x_lst must contain at least one sequence")

    # A list comprehension (not a bare ``map``) so that this is a real
    # integer array on Python 3, where ``map`` returns a lazy iterator.
    x_lens = np.array([len(vec) for vec in x_lst])
    if x_lens.min() < 1:
        raise ValueError("every sequence in x_lst must be non-empty")

    # One step is lost when shifting inputs against targets.
    max_len = int(x_lens.max()) - 1

    # NaN marks "not filled yet"; it is used below to derive the mask.
    u_out = np.full((n, max_len, OUTDIM), np.nan, dtype='float32')
    x_out = np.full((n, max_len, OUTDIM), np.nan, dtype='float32')

    for row, vec in enumerate(x_lst):
        steps = len(vec) - 1
        x_out[row, :steps] = vec[1:]   # targets: all but first element
        u_out[row, :steps] = vec[:-1]  # inputs: all but last element

    # Positions still NaN were never written, i.e. they are padding.
    mask = np.invert(np.isnan(x_out))
    x_out[np.isnan(x_out)] = 0
    u_out[np.isnan(u_out)] = 0

    # Collapse the feature axis: feature 0 decides whether a step is valid.
    mask = mask[:, :, 0]

    # Internal consistency check: each row must expose len(vec) - 1 steps.
    assert np.all((mask.sum(axis=1) + 1) == x_lens)
    return u_out, x_out, mask.astype('float32')
评论列表
文章目录