def seq_linear(linear, x):
    """Apply ``linear`` independently at every position of a sequence.

    ``x`` is unpacked as a 4-D array ``(batch, units, length, 1)``. The
    trailing axis must have size 1, because it is folded away by the
    ``reshape(batch * length, units)`` before ``linear`` is called.
    ``linear`` is presumably a fully-connected layer mapping
    ``(N, units) -> (N, out_units)`` -- confirm against callers.

    Args:
        linear: Callable taking a 2-D ``(batch * length, units)`` input and
            returning a 2-D ``(batch * length, out_units)`` output.
        x: 4-D input of shape ``(batch, units, length, 1)``.

    Returns:
        Output of shape ``(batch, out_units, length, 1)``. ``out_units`` is
        read from the result of ``linear`` rather than assumed to equal
        ``units``, so layers that change the feature size are supported.
    """
    batch, units, length, _ = x.shape
    # Move the sequence axis next to batch, so that each (batch, position)
    # pair becomes one row of the 2-D matrix fed to ``linear``.
    rows = F.transpose(x, (0, 2, 1, 3)).reshape(batch * length, units)
    h = linear(rows)
    # Take the feature size from the output: it need not equal ``units``.
    out_units = h.shape[1]
    restored = h.reshape((batch, length, out_units, 1))
    # Undo the axis swap: back to (batch, features, length, 1).
    return F.transpose(restored, (0, 2, 1, 3))