def shift(self, model_input, shift_width, **unused_params):
    """Concatenate progressively delayed copies of `model_input` along axis 2.

    Copy `i` (for `i` in `0 .. shift_width - 1`) is `model_input` with `i`
    padding rows inserted at the front of axis 1 and then cut back to the
    original axis-1 length, i.e. the input delayed by `i` positions. The
    copies are joined along the last axis, so the result's last dimension
    is `shift_width` times that of the input.

    Args:
      model_input: rank-3 tensor whose axis-1 size is statically known
        (it is read through `get_shape()`). Presumably laid out as
        (batch, frames, features) -- confirm against the caller.
      shift_width: number of copies to stack, including the undelayed one.
        Expected to be >= 1; with 0 an empty list reaches `tf.concat`.
      **unused_params: accepted and ignored.

    Returns:
      The tensor produced by concatenating the delayed copies on axis 2.
    """
    max_frames = model_input.get_shape().as_list()[1]

    # `range` rather than the Python-2-only `xrange`, so the method runs on
    # both Python 2 and Python 3. Copy 0 is the input itself (no pad/slice
    # ops are added for it); later copies are front-padded by `i` along
    # axis 1 and truncated so every copy keeps the same axis-1 length.
    shift_inputs = [
        model_input if i == 0 else
        tf.pad(model_input, paddings=[[0, 0], [i, 0], [0, 0]])[:, :max_frames, :]
        for i in range(shift_width)
    ]
    return tf.concat(shift_inputs, axis=2)
评论列表
文章目录