def build_inference(self):
    # Encode the input context with the CNN encoder.
    embed = self.embedding
    context = self.context
    hidden = self.hidden
    features = self._cnn_encoding(embedding=embed, context=context, hidden=hidden)

    # Run one LSTM step on the encoded features from a zero state to obtain
    # the initial (c, h) state exposed to the sampling loop.
    c = tf.zeros([tf.shape(self.context)[0], self.H])
    h = tf.zeros([tf.shape(self.context)[0], self.H])
    (self.init_c, self.init_h) = self._lstm(h, c, features, reuse=False)

    # Build the decoder and word-embedding variables once (outputs unused here)
    # so the per-step graph below can reuse them.
    _ = self._decode_lstm(self.init_h)
    _ = self._word_embedding(inputs=tf.fill([tf.shape(features)[0]], self._start))

    # Single-step decoding graph: the caller feeds the previous word and the
    # LSTM state, and reads back the new state and the log-softmax over the vocabulary.
    self.in_word = tf.placeholder(tf.int32, [None])
    x = self._word_embedding(inputs=self.in_word, reuse=True)
    self.c_feed = tf.placeholder(tf.float32, [None, self.H])
    self.h_feed = tf.placeholder(tf.float32, [None, self.H])
    (self.c, self.h) = self._lstm(self.h_feed, self.c_feed, x, reuse=True)
    self.log_softmax = self._decode_lstm(self.h, reuse=True)
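# A minimal sketch (not from the original code) of how this step-wise graph could
# drive greedy decoding. `sess`, `model`, `encoder_feed`, and `max_len` are
# hypothetical names for illustration; `model._start` is the start-token id used above,
# and `encoder_feed` is assumed to feed whatever the encoder inputs require.
import numpy as np

def greedy_decode(sess, model, encoder_feed, max_len=20):
    # One encoder pass: initial LSTM state conditioned on the context.
    c, h = sess.run([model.init_c, model.init_h], feed_dict=encoder_feed)
    word = np.full([c.shape[0]], model._start, dtype=np.int32)
    sampled = []
    for _ in range(max_len):
        # Feed the previous word and state, read back log-probabilities and the new state.
        log_probs, c, h = sess.run(
            [model.log_softmax, model.c, model.h],
            feed_dict={model.in_word: word,
                       model.c_feed: c,
                       model.h_feed: h})
        word = np.argmax(log_probs, axis=1).astype(np.int32)
        sampled.append(word)
    # Returns word ids with shape [batch_size, max_len].
    return np.stack(sampled, axis=1)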