def combined_discriminate2(data, sae, discriminator, **kwargs):
    """Run ``data`` through decoder -> features -> discriminator and predict.

    A throwaway Keras ``Model`` is assembled on every call:

    1. An ``Input`` is created whose per-sample shape is ``data.shape[1:]``.
    2. The input is reshaped to gain a trailing axis of size 1, and that
       tensor ``x`` is concatenated with ``1 - x`` along the last axis.
    3. The concatenated tensor goes through the ``wrap`` helper under the
       name ``"categorical"``.
       NOTE(review): ``wrap`` is defined outside this block; presumably it
       turns the backend tensor into a proper Keras layer output -- confirm.
    4. The result is fed through ``sae.decoder``, ``sae.features`` and
       ``discriminator.net`` in that order.

    Args:
        data: Array-like exposing ``.shape``; its first axis is treated as
            the batch axis by ``Model.predict``.
        sae: Object providing callable ``decoder`` and ``features`` members.
        discriminator: Object providing a callable ``net`` member.
        **kwargs: Forwarded unchanged to ``Model.predict``.

    Returns:
        Whatever ``Model.predict(data, **kwargs)`` returns for the
        assembled model.
    """
    sample_shape = data.shape[1:]
    inputs = Input(shape=sample_shape)

    # Append a singleton axis so x and (1 - x) can be stacked side by side
    # along the last dimension.
    expanded = Reshape((*sample_shape, 1))(inputs)
    stacked = K.concatenate([expanded, 1 - expanded], -1)
    categorical = wrap(inputs, stacked, name="categorical")

    # Chain the three sub-networks: decoder -> features -> discriminator.
    scores = discriminator.net(sae.features(sae.decoder(categorical)))

    return Model(inputs, scores).predict(data, **kwargs)
评论列表
文章目录