def FrameProcessor(frames):
    """
    Embed the samples of every frame and project each frame to DIM features.

    frames.shape: (batch size, n frames, FRAME_SIZE)
    output.shape: (batch size, n frames, DIM)
    """
    # Width of one frame once the per-sample embeddings are laid side by side.
    flat_size = FRAME_SIZE * Q_LEVELS

    # NOTE(review): presumably `frames` holds integer quantization levels in
    # [0, Q_LEVELS) and the embedding appends a trailing axis of size
    # Q_LEVELS -- confirm against lib.ops.Embedding and the caller.
    embedded = lib.ops.Embedding('FrameEmbedding', Q_LEVELS, Q_LEVELS, frames)

    # Collapse everything belonging to one frame into a single feature vector
    # so the MLP sees a (batch size, n frames, flat_size) input.
    embedded = embedded.reshape((frames.shape[0], frames.shape[1], flat_size))

    # NOTE: an earlier variant skipped the embedding entirely. It cast the
    # frames to float32, rescaled them with (frames / (Q_LEVELS / 2) - 1) * 2
    # and fed the result to an MLP whose input size was FRAME_SIZE.
    return MLP('FrameProcessor', flat_size, DIM, embedded)
评论列表
文章目录