def init_state(self, batch_size, K, dtype):
    """Build the initial state ``(h, pred, gamma)`` for K components.

    Args:
        batch_size: Batch size B (Python int or scalar int tensor; it is
            passed to ``tf.stack`` alongside Python ints).
        K: Number of components. It is compared against the Python int 1
            below, so it is presumably a Python int — TODO confirm.
        dtype: dtype used for the inner cell's zero state and for the
            ``pred`` and ``gamma`` tensors.

    Returns:
        A tuple ``(h, pred, gamma)``:
            h: ``self.cell.zero_state(batch_size * K, dtype)``. This is the
                inner RNN state, with the K components folded into the batch
                axis.
            pred: a tensor of shape ``[batch_size, K] + self.input_shape``
                filled with ``self.pred_init``.
            gamma: a tensor of shape ``[batch_size, K] + self.gamma_shape``.
                When ``K == 1`` it is all ones. Otherwise it holds absolute
                standard-normal samples normalized to sum to 1 over axis 1.
    """
    # Inner RNN hidden state: one state per (batch element, component) pair.
    with tf.name_scope('inner_RNN_init'):
        h = self.cell.zero_state(batch_size * K, dtype)

    # Initial prediction, constant everywhere. The original comment gives the
    # layout as (B, K, W, H, C), which assumes input_shape is (W, H, C) —
    # TODO confirm.
    with tf.name_scope('pred_init'):
        pred_shape = tf.stack([batch_size, K] + self.input_shape.as_list())
        pred = tf.ones(shape=pred_shape, dtype=dtype) * self.pred_init

    # Initial gamma. The original comment gives the layout as
    # (B, K, W, H, 1) — presumably gamma_shape is (W, H, 1).
    with tf.name_scope('gamma_init'):
        gamma_shape = tf.stack([batch_size, K] + self.gamma_shape.as_list())
        if K == 1:
            # A single component takes all the mass. Decide before building
            # any random ops instead of creating and then discarding them.
            gamma = tf.ones(gamma_shape, dtype=dtype)
        else:
            # Random non-negative weights, normalized over the component
            # axis (axis 1) so each position sums to 1 across the K components.
            gamma = tf.abs(tf.random_normal(gamma_shape, dtype=dtype))
            # `keepdims` replaces the deprecated `keep_dims` argument.
            gamma /= tf.reduce_sum(gamma, 1, keepdims=True)

    return h, pred, gamma
评论列表
文章目录