def iaf(self, z, h, lin1, lin2):
    """Blend ``z`` with a learned candidate through a sigmoid gate.

    ``z`` and ``h`` are concatenated along axis 1, passed through ``lin1``,
    ``F.crelu`` and ``lin2``, and the result is split into two equal parts
    along axis 1: a candidate ``m`` and a gate pre-activation. With
    ``s = sigmoid(pre-activation)`` the updated value is
    ``s * z + (1 - s) * m``.

    NOTE(review): the name suggests an inverse-autoregressive-flow step, in
    which case the second return value is the negative log-determinant of
    the transform -- confirm against the caller.

    Args:
        z: Value to transform; concatenated with ``h`` along axis 1.
        h: Context input concatenated with ``z`` along axis 1.
        lin1: Callable applied to the concatenated input.
        lin2: Callable applied to the ``F.crelu`` output. Its result is
            split in two along axis 1 (presumably each half matches the
            shape of ``z`` -- confirm against caller).

    Returns:
        A pair ``(z_new, neg_log_s)`` where ``z_new`` is the gated blend and
        ``neg_log_s`` is ``-log(s)`` summed over axis 1.
    """
    hidden = F.crelu(lin1(F.concat((z, h), axis=1)))
    m, gate_logit = F.split_axis(lin2(hidden), 2, axis=1)
    s = F.sigmoid(gate_logit)
    z = s * z + (1 - s) * m
    # -log(sigmoid(x)) == softplus(-x). Computing it from the pre-activation
    # avoids log(0) = -inf when the sigmoid saturates to 0 in floating point.
    return z, F.sum(F.softplus(-gate_logit), axis=1)