def fprop(self, inputs):
    """Build the forward graph: ISAN recurrence followed by an affine readout.

    Token ids are one-hot encoded over ``self.n_tokens`` classes, run through
    an ``IsanCell`` of width ``self.hidden_dim`` via ``dynamic_rnn``, and each
    resulting hidden state is projected to ``self.target_dim`` logits by a
    shared weight matrix ``Wo`` and bias ``bo``.

    Args:
      inputs: rank-2 integer tensor of token ids laid out as ``[batch, time]``.
        Either dimension may be statically unknown (``None``); the runtime
        shape is used for any dimension that is not known at graph-build time.

    Returns:
      Float tensor of logits with shape ``[batch, time, self.target_dim]``.

    Raises:
      ValueError: if ``inputs`` has a known rank other than 2.
    """
    with tf.variable_scope('model', values=[inputs]):
        one_hot_inputs = tf.one_hot(inputs, self.n_tokens, axis=-1)
        with tf.variable_scope('rnn', values=[inputs]):
            # NOTE(review): assumes this dynamic_rnn is batch-major
            # (time_major=False), i.e. `states` is [batch, time, hidden_dim];
            # the reshape back to [batch, time, ...] below relies on that.
            states, _ = dynamic_rnn(cell=IsanCell(self.hidden_dim),
                                    inputs=one_hot_inputs,
                                    dtype=tf.float32)
        # NOTE(review): stddev = 1 / (fan_in + fan_out) ** 2 is far smaller
        # than the Glorot value sqrt(2 / (fan_in + fan_out)), so the readout
        # starts with tiny weights. Kept as-is to preserve training behavior;
        # confirm whether `** 0.5` was intended.
        Wo = tf.get_variable(
            'Wo', shape=[self.hidden_dim, self.target_dim],
            initializer=tf.random_normal_initializer(
                stddev=1.0 / (self.hidden_dim + self.target_dim) ** 2))
        bo = tf.get_variable('bo', shape=[1, self.target_dim],
                             initializer=tf.zeros_initializer())
        # Prefer static dims (keeps static shape info on the output), but fall
        # back to the runtime shape when a dim is None, e.g. a variable batch
        # size; the old `t * bs` on None dims raised a TypeError.
        static_bs, static_t = inputs.get_shape().with_rank(2).as_list()
        runtime_shape = tf.shape(inputs)
        bs = runtime_shape[0] if static_bs is None else static_bs
        t = runtime_shape[1] if static_t is None else static_t
        # Apply the readout to every (batch, time) position with a single
        # matmul; -1 equals bs * t and also works when they are not static.
        flat_states = tf.reshape(states, [-1, self.hidden_dim])
        logits = tf.matmul(flat_states, Wo) + bo
        logits = tf.reshape(logits, [bs, t, self.target_dim])
        return logits
评论列表
文章目录