def __call__(self, x):
    """Run the forward pass and return the output of ``self.ln``.

    ``x`` is passed through ``self.l0`` and reshaped to
    ``(x.shape[0],) + self.embed_shape``.  It then flows through
    ``self.n_blocks`` blocks, each applying ``self.block_size`` child
    layers named ``c<k>`` followed by an ELU.  Between consecutive blocks
    (i.e. after every block except the last) the feature map is enlarged
    with a 2x2, stride-2 unpooling.  The result is finally passed through
    ``self.ln``.

    NOTE(review): this assumes the constructor registers the layers
    ``c0`` .. ``c{n_blocks * block_size - 1}`` -- confirm against
    ``__init__``, which is not visible from here.

    Args:
        x: Input batch; ``x.shape[0]`` is used as the batch size.

    Returns:
        Whatever ``self.ln`` returns for the final feature map.
    """
    h = F.reshape(self.l0(x), (x.shape[0],) + self.embed_shape)
    for i in range(self.n_blocks):
        for j in range(self.block_size):
            # Row-major flat index so every (block, position) pair maps to
            # its own layer.  The previous expression ``i*j + j`` equals
            # ``j*(i+1)``, which reused ``c0`` at the start of every block
            # and skipped other layers entirely.
            layer = getattr(self, 'c{}'.format(i * self.block_size + j))
            h = F.elu(layer(h))
        if i < self.n_blocks - 1:
            # Upsample between blocks only; the last block keeps its size.
            h = F.unpooling_2d(h, ksize=2, stride=2, cover_all=False)
    return self.ln(h)
评论列表
文章目录