def __init__(self, input_dim, z_dim, class_num, batch_size, *, lr=0.0001):
    """Store the model hyper-parameters and build the four sub-components.

    Args:
        input_dim: Size used as the first entry of the encoder's layer
            list and the last entry of the decoder's layer list.
        z_dim: Latent size; passed to the encoder, used as the first
            entry of the decoder's layer list, and contributes to the
            discriminator's input width.
        class_num: Number of classes; passed to the sampler and used to
            size the discriminator input.
        batch_size: Stored on the instance; not otherwise used here.
        lr: Value stored as ``self.lr``. Keyword-only; defaults to the
            previously hard-coded 0.0001 so existing callers are unaffected.
    """
    self.input_dim = input_dim
    self.z_dim = z_dim
    self.class_num = class_num
    self.batch_size = batch_size
    self.lr = lr

    # -- encoder: layer sizes input_dim -> 1200 -> 600 -> 100, plus z_dim --
    self.encoder = Encoder([input_dim, 1200, 600, 100], z_dim)

    # -- decoder: mirrors the encoder sizes, z_dim -> ... -> input_dim --
    self.decoder = Decoder([z_dim, 100, 600, 1200, input_dim])

    # -- discriminator: input width is z_dim plus (class_num + 1) --
    # NOTE(review): the extra "+ 1" presumably reserves a slot beyond the
    # known classes (e.g. "unlabeled") -- confirm against Sampler's output.
    self.discriminator = Discriminator(
        [z_dim + (class_num + 1), 50, 20, 10, 1]
    )

    # -- sampler: constructed from the number of classes --
    self.sampler = Sampler(class_num)
评论列表
文章目录