def initialize(self, custom_getter):
    """Set up the network, the distributions and the KL-divergence template.

    The parent class's ``initialize`` runs first. This method then:
    builds ``self.network`` from ``self.network_spec``, creates
    ``self.distributions``, appends the network's internal inputs and
    initial values to this model's lists, and builds
    ``self.fn_kl_divergence`` as a TensorFlow template around
    ``self.tf_kl_divergence``.

    Args:
        custom_getter: Forwarded to the parent ``initialize`` and supplied
            as ``custom_getter_`` when building the KL-divergence template.
    """
    super(DistributionModel, self).initialize(custom_getter)

    # Network: built from the stored spec, forwarding this model's summary labels.
    network_kwargs = dict(summary_labels=self.summary_labels)
    self.network = Network.from_spec(spec=self.network_spec, kwargs=network_kwargs)

    # Distributions
    self.distributions = self.create_distributions()

    # Network internals: merge the network's internal-state inputs and their
    # initial values into this model's own lists.
    network_internals_input = self.network.internals_input()
    self.internals_input.extend(network_internals_input)
    network_internals_init = self.network.internals_init()
    self.internals_init.extend(network_internals_init)

    # KL divergence function: tf.make_template wraps the function so that
    # repeated calls share the variables it creates (per the TF docs), scoped
    # under "<scope>/kl-divergence".
    kl_template_name = self.scope + '/kl-divergence'
    self.fn_kl_divergence = tf.make_template(
        name_=kl_template_name,
        func_=self.tf_kl_divergence,
        custom_getter_=custom_getter,
    )
评论列表
文章目录