def discriminator_finalstate(self, states):  # FIXME: last-step slice may not be each sentence's final state
    """Score sentences with a small MLP discriminator over their final states.

    Inside the "Discriminator" variable scope, ``self.latent`` is
    concatenated (along axis 1) with ``states[:, -1, :]``. The result is
    passed through two ELU-activated ``utils.linear`` layers of widths
    ``cfg.hidden_size`` and ``cfg.hidden_size // 2``, then through a final
    ``utils.linear`` layer with a single output unit.

    Args:
        states: Tensor of generated states. Assumed to be laid out as
            (batch, time, features); TODO confirm against the caller.

    Returns:
        The tensor produced by the 1-unit output layer. Presumably an
        unnormalized logit per example; verify against the loss that
        consumes it.
    """
    with tf.variable_scope("Discriminator"):
        # NOTE(review): this slices the last time step of every row. If the
        # sequences are padded, that is not necessarily each sentence's real
        # final state. The disabled alternative was
        #     utils.rowwise_lookup(states, self.lengths - 2)
        # which uses index lengths - 2 because the generated output skips <sos>.
        last_step = states[:, -1, :]

        # Axis-first argument order matches the pre-1.0 tf.concat signature.
        # TODO confirm the TensorFlow version this project targets.
        # TODO transform latent before concatenating it.
        hidden = tf.concat(1, [self.latent, last_step])

        # The positional (True, 0.0) passed to utils.linear are presumably the
        # bias flag and bias initial value; verify in utils.
        layer_specs = (
            (cfg.hidden_size, 'discriminator_lin1'),
            (cfg.hidden_size // 2, 'discriminator_lin2'),
        )
        for width, layer_scope in layer_specs:
            hidden = tf.nn.elu(utils.linear(hidden, width, True, 0.0,
                                            scope=layer_scope))

        logit = utils.linear(hidden, 1, True, 0.0,
                             scope='discriminator_output')
    return logit
# (Scraped web-page residue, not code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)