trajmodel.py source code

python

Project: RNN-TrajModel   Author: wuhao5688
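Note: this snippet targets the pre-1.0 TensorFlow API (tf.pack/tf.unpack rather than tf.stack/tf.unstack, tf.nn.rnn rather than tf.nn.static_rnn, and tf.concat with the axis argument first), so it will not run unmodified on TensorFlow 1.x or later. A sketch of the 1.x equivalent follows the listing.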
def build_rnn_layer(self, inputs_, train_phase):
    """
    Build the computation graph from inputs to outputs of the RNN layer.
    :param inputs_: [batch, t, emb], float
    :param train_phase: bool
    :return: rnn_outputs_: [batch, t, hid_dim], float
    """
    config = self.config

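    # Four interchangeable sequence-processing strategies follow: unrolled_rnn
    # and rnn require config.fix_seq_len, dynamic_rnn handles variable-length
    # sequences, and bidirectional_rnn runs forward and backward passes and
    # concatenates their outputs.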
    def unrolled_rnn(cell, emb_inputs_, initial_state_, seq_len_):
      if not config.fix_seq_len:
        raise Exception("`config.fix_seq_len` should be set to `True` if using unrolled_rnn()")
      outputs = []
      state = initial_state_
      with tf.variable_scope("unrolled_rnn"):
        for t in range(config.max_seq_len):
          if t > 0:
            tf.get_variable_scope().reuse_variables()  # reuse the same weights at every time step
          output, state = cell(emb_inputs_[:, t], state)  # [batch, hid_dim]
          outputs.append(output)
        rnn_outputs_ = tf.pack(outputs, axis=1)  # [batch, t, hid_dim]
      return rnn_outputs_
    def dynamic_rnn(cell, emb_inputs_, initial_state_, seq_len_):
      rnn_outputs_, last_states_ = tf.nn.dynamic_rnn(cell, emb_inputs_, initial_state=initial_state_,
                                                     sequence_length=seq_len_,
                                                     dtype=config.float_type)  # dtype must be specified when initial_state is not provided
      return rnn_outputs_
    def bidirectional_rnn(cell, emb_inputs_, initial_state_, seq_len_):
      # the same cell drives both the forward and the backward pass
      rnn_outputs_, output_states = tf.nn.bidirectional_dynamic_rnn(cell, cell, emb_inputs_,
                                                                    sequence_length=seq_len_,
                                                                    initial_state_fw=initial_state_,
                                                                    initial_state_bw=initial_state_,
                                                                    dtype=config.float_type)
      return tf.concat(2, rnn_outputs_)  # concatenate fw/bw outputs along the feature axis
    def rnn(cell, emb_inputs_, initial_state_, seq_len_):
      if not config.fix_seq_len:
        raise Exception("`config.fix_seq_len` should be set to `True` if using rnn()")
      inputs_ = tf.unpack(emb_inputs_, axis=1)  # list of t tensors, each [batch, emb]
      outputs_, states_ = tf.nn.rnn(cell, inputs_, initial_state=initial_state_,
                                    dtype=config.float_type, sequence_length=seq_len_)
      return outputs_

    if config.rnn == 'rnn':
      cell = tf.nn.rnn_cell.BasicRNNCell(config.hidden_dim)
    elif config.rnn == 'lstm':
      cell = tf.nn.rnn_cell.BasicLSTMCell(config.hidden_dim)
    elif config.rnn == 'gru':
      cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
    else:
      raise Exception("`config.rnn` should be one of 'rnn', 'lstm' or 'gru'.")

    if train_phase and config.keep_prob < 1:
      cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=config.keep_prob)  # dropout on outputs, training only

    if config.num_layers is not None and config.num_layers > 1:
      cell = tf.nn.rnn_cell.MultiRNNCell([cell] * config.num_layers)  # pre-TF1.0 idiom; later versions need a distinct cell per layer

    initial_state_ = cell.zero_state(config.batch_size, dtype=config.float_type)
    if config.use_seq_len_in_rnn:
      seq_len_ = self.seq_len_
    else:
      seq_len_ = None
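    # Only the dynamic_rnn strategy is wired up here; the other helpers above
    # are drop-in alternatives with the same signature.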
    rnn_outputs_ = dynamic_rnn(cell, inputs_, initial_state_, seq_len_)  # [batch, time, hid_dim]
    return rnn_outputs_
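For reference, here is a minimal sketch of the dynamic_rnn path (the only one actually used above) against the TensorFlow 1.x API. hidden_dim, num_layers, keep_prob, batch_size and emb_dim are stand-in values, not the project's actual configuration:

import tensorflow as tf  # TensorFlow 1.x

# Stand-in hyperparameters; in the original these come from self.config.
hidden_dim, num_layers, keep_prob = 400, 2, 0.9
batch_size, emb_dim = 32, 100

emb_inputs = tf.placeholder(tf.float32, [batch_size, None, emb_dim])  # [batch, t, emb]
seq_len = tf.placeholder(tf.int32, [batch_size])

def make_cell():
  # TF 1.x needs a distinct cell object per layer, unlike the
  # [cell] * num_layers idiom in the listing above.
  cell = tf.nn.rnn_cell.GRUCell(hidden_dim)
  return tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)

cell = tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(num_layers)])
initial_state = cell.zero_state(batch_size, dtype=tf.float32)
rnn_outputs, last_state = tf.nn.dynamic_rnn(cell, emb_inputs,
                                            sequence_length=seq_len,
                                            initial_state=initial_state)  # [batch, t, hidden_dim]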