def __build_model(self):
    """Build the encoder-decoder graph and attach training/inference ops.

    Runs the encoder RNN over ``self.encoder_input_embedding``, starts the
    decoder RNN from the encoder's final state, and passes the decoder
    outputs to ``self.__build_ops``.

    Sets on ``self``:
        logit, cost, train_op: the three values returned by
            ``self.__build_ops(output)``.
        output: argmax of ``self.logit`` along axis 2 (int64 indices).
        merged: op merging every summary collected in the graph.
    """
    encoder_cell, decoder_cell = self.__build_rnn_cell()

    with tf.variable_scope('encoder_layer'):
        # NOTE(review): no sequence_length is passed, so encoder_state is the
        # state after the final time step, including any padding steps —
        # confirm inputs are unpadded or supply sequence lengths.
        encoder_output, encoder_state = tf.nn.dynamic_rnn(
            cell=encoder_cell,
            inputs=self.encoder_input_embedding,
            dtype=tf.float32,
        )
        tf.summary.histogram('encoder_output', encoder_output)

    with tf.variable_scope('decoder_layer'):
        # The decoder is seeded with the encoder's final state; its own
        # final state is not needed.
        output, _ = tf.nn.dynamic_rnn(
            cell=decoder_cell,
            inputs=self.decoder_input_embedding,
            initial_state=encoder_state,
            dtype=tf.float32,
        )
        # NOTE(review): tag is 'decoder_layer' rather than 'decoder_output';
        # kept as-is so existing TensorBoard runs remain comparable.
        tf.summary.histogram('decoder_layer', output)

    self.logit, self.cost, self.train_op = self.__build_ops(output)
    # tf.arg_max is deprecated; tf.argmax is the supported equivalent
    # (default output_type is int64 for both). Axis 2 presumably indexes the
    # vocabulary of (batch, time, vocab) logits — TODO confirm in __build_ops.
    self.output = tf.argmax(self.logit, axis=2)
    self.merged = tf.summary.merge_all()
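# Source: deepAPI_model.py (Python source file).
# Scraped page metadata: 31 views, 0 favorites, 0 likes, 0 comments
# (page sections: comment list, article table of contents).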