trajmodel.py source code

python

Project: RNN-TrajModel  Author: wuhao5688
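def build_LPIRNN_model(self, train_phase):
  # Module-level imports required (TensorFlow 1.x graph mode):
  #   import tensorflow as tf
  #   from tensorflow.core.protobuf import saver_pb2
  config = self.config
  # Shared part: hidden representation (lpi_) consumed by the task-specific part
  self.lpi_ = self.build_sharedTask_part(train_phase)
  # Task-specific part: returns the two losses built on top of lpi_
  loss_, loss_p_ = self.build_individualTask_part(train_phase, self.lpi_)
  if config.trace_hid_layer:
    # Collect lpi_ for a given state id (config.trace_input_id) for inspection
    self.trace_dict["lpi_" + str(config.trace_input_id)] = self.lpi_
  self.loss_dict["loss"] = loss_
  self.loss_dict["loss_p"] = loss_p_
  # Compute gradients and update params; only "loss" is passed to the trainer
  self.build_trainer(self.loss_dict["loss"], tf.trainable_variables())
  # Checkpoint format: V2 (.index + .data files) or legacy V1, chosen by config.
  # tf.global_variables() replaces the deprecated tf.all_variables().
  write_version = (saver_pb2.SaverDef.V2 if config.use_v2_saver
                   else saver_pb2.SaverDef.V1)
  self.saver = tf.train.Saver(tf.global_variables(),
                              max_to_keep=config.max_ckpt_to_keep,
                              write_version=write_version)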
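To make the data flow of this method easier to follow without TensorFlow, here is a minimal framework-free sketch. Config, LPIRNNSketch, saver_settings and the stub return values are hypothetical placeholders, not part of RNN-TrajModel; the sketch only mirrors the order of calls: shared part, then individual-task part, then the loss/trace dictionaries, then the checkpoint-format choice.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Config:
    # Hypothetical stand-in: only the fields build_LPIRNN_model reads.
    trace_hid_layer: bool = False
    trace_input_id: int = 0
    use_v2_saver: bool = True
    max_ckpt_to_keep: int = 5


class LPIRNNSketch:
    """Framework-free mock of build_LPIRNN_model's control flow (hypothetical)."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.trace_dict: Dict[str, List[float]] = {}
        self.loss_dict: Dict[str, float] = {}
        self.saver_settings: Dict[str, object] = {}

    def build_sharedTask_part(self, train_phase: bool) -> List[float]:
        # Hypothetical stub for the shared layers: returns a fixed "hidden state".
        return [1.0, 2.0, 3.0]

    def build_individualTask_part(
        self, train_phase: bool, lpi: List[float]
    ) -> Tuple[float, float]:
        # Hypothetical stub for the task-specific part: derives two losses from lpi.
        return sum(lpi), max(lpi)

    def build(self, train_phase: bool = True) -> "LPIRNNSketch":
        config = self.config
        self.lpi_ = self.build_sharedTask_part(train_phase)
        loss_, loss_p_ = self.build_individualTask_part(train_phase, self.lpi_)
        if config.trace_hid_layer:
            # Same key scheme as the original: "lpi_" + traced state id.
            self.trace_dict["lpi_" + str(config.trace_input_id)] = self.lpi_
        self.loss_dict["loss"] = loss_
        self.loss_dict["loss_p"] = loss_p_
        # Stands in for tf.train.Saver(..., write_version=...): just record the choice.
        self.saver_settings = {
            "max_to_keep": config.max_ckpt_to_keep,
            "write_version": "V2" if config.use_v2_saver else "V1",
        }
        return self


model = LPIRNNSketch(
    Config(trace_hid_layer=True, trace_input_id=7, use_v2_saver=False)
).build()
print(model.trace_dict)      # {'lpi_7': [1.0, 2.0, 3.0]}
print(model.loss_dict)       # {'loss': 6.0, 'loss_p': 3.0}
print(model.saver_settings)  # {'max_to_keep': 5, 'write_version': 'V1'}
```

With the default Config, trace_dict stays empty and the recorded write version is "V2", matching the branch the original method takes when use_v2_saver is true.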