def build_LPIRNN_model(self, train_phase):
    """Assemble the LPI-RNN graph: shared part, per-task heads, trainer, saver.

    Args:
        train_phase: flag forwarded to the graph-building helpers
            (presumably toggles training-only behavior such as dropout
            — TODO confirm against build_sharedTask_part).

    Side effects:
        Sets self.lpi_, self.loss_dict, optionally self.trace_dict,
        builds the training op via self.build_trainer, and creates
        self.saver.
    """
    config = self.config
    # Shared-task representation (the "LPI") consumed by each individual task.
    self.lpi_ = self.build_sharedTask_part(train_phase)
    loss_, loss_p_ = self.build_individualTask_part(train_phase, self.lpi_)
    if config.trace_hid_layer:
        # Collect the lpi tensor for a given input id for later inspection.
        self.trace_dict["lpi_" + str(config.trace_input_id)] = self.lpi_
    self.loss_dict["loss"] = loss_
    self.loss_dict["loss_p"] = loss_p_
    # Compute grads and update params.
    self.build_trainer(self.loss_dict["loss"], tf.trainable_variables())
    # Build the Saver once; only the checkpoint format version differs.
    # NOTE(review): tf.all_variables() is a deprecated TF1 alias of
    # tf.global_variables(); kept as-is to preserve exact behavior.
    write_version = (saver_pb2.SaverDef.V2 if config.use_v2_saver
                     else saver_pb2.SaverDef.V1)
    self.saver = tf.train.Saver(tf.all_variables(),
                                max_to_keep=config.max_ckpt_to_keep,
                                write_version=write_version)
# (blog-scrape residue, not code) 评论列表 = "Comment list"
# (blog-scrape residue, not code) 文章目录 = "Article table of contents"