volleyball_train_stage_b.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:social-scene-understanding 作者: cvlab-epfl 项目源码 文件源码
def _construct_sequence(batch):
    """Run a GRU-style recurrence over time, matching detections frame to frame.

    At every step t >= 1, each detection in ``hidden[t]`` is paired with a
    previous state, chosen according to ``c.match_kind``. A GRU-style update
    then combines that previous state with the detection's features.

    Args:
      batch: pair ``(hidden, boxes)`` indexed by time step. Row i of
        ``hidden[t]`` holds the features of detection i at step t, and
        ``boxes[t]`` holds the corresponding boxes. (Exact shapes are not
        visible here; presumably ``(N, D)`` and ``(N, 4)`` per step. TODO
        confirm against the caller.)

    Returns:
      ``tf.pack`` of the ``T`` per-step states. The first state is
      ``hidden[0]`` and each later state has width ``NFH``.

    Relies on names from the enclosing scope: ``T`` (number of steps),
    ``c.match_kind``, ``NFH``, and the ``tf`` / ``slim`` / ``nnutil``
    modules.

    Raises:
      RuntimeError: if ``c.match_kind`` is not ``'boxes'``, ``'hidden'`` or
        ``'hidden-soft'``.
    """

    def _gru_update(state_prev, inputs, reuse):
        """Build the reset/step gates and the candidate state for one step."""
        gate_input = tf.concat(1, [state_prev, inputs])
        # TODO: initialize jointly
        reset = slim.fully_connected(gate_input, NFH, tf.nn.sigmoid,
                                     reuse=reuse, scope='reset')
        step = slim.fully_connected(gate_input, NFH, tf.nn.sigmoid,
                                    reuse=reuse, scope='step')
        candidate_input = tf.concat(1, [reset * state_prev, inputs])
        state_up = slim.fully_connected(candidate_input, NFH, tf.nn.tanh,
                                        reuse=reuse, scope='state_up')
        return state_up, step

    hidden, boxes = batch
    # The state at the first step is initialized with the features.
    states = [hidden[0]]
    # TODO: make this dependent on the data
    # TODO: make it with scan ?
    for t in range(1, T):
        # Find, for every detection at step t, its predecessor at step t-1.
        # The predecessor's state must end up in the same row as hidden[t][i],
        # because the two are concatenated row-wise in _gru_update.
        # NOTE(review): assumes nnutil.cdist(a, b) returns a
        # (len(a), len(b)) distance matrix where smaller means closer.
        # Confirm against nnutil.
        # TODO: try with the soft matching function
        if c.match_kind == 'boxes':
            # Row i: distances from current box i to every previous box.
            dists = nnutil.cdist(boxes[t], boxes[t-1])
            idxs = tf.argmin(dists, 1, 'idxs')
            state_prev = tf.gather(states[t-1], idxs)
        elif c.match_kind == 'hidden':
            # TODO: actually it makes more sense to compare on states
            dists = nnutil.cdist(hidden[t], hidden[t-1])
            idxs = tf.argmin(dists, 1, 'idxs')
            state_prev = tf.gather(states[t-1], idxs)
        elif c.match_kind == 'hidden-soft':
            dists = nnutil.cdist(hidden[t], hidden[t-1])
            # Negate so that closer predecessors get larger weights. The
            # softmax runs over the last axis, i.e. over previous detections.
            weights = slim.softmax(-dists)
            state_prev = tf.matmul(weights, states[t-1])
        else:
            raise RuntimeError('Unknown match_kind: %s' % c.match_kind)

        # Share the 'reset'/'step'/'state_up' weights across steps and calls:
        # first try to reuse existing variables, and create them only if that
        # raises (presumably because they do not exist yet).
        # NOTE(review): any unrelated ValueError raised while building the
        # graph is also caught here and retried with reuse=False, which can
        # hide the original error.
        try:
            state_up, step = _gru_update(state_prev, hidden[t], reuse=True)
        except ValueError:
            state_up, step = _gru_update(state_prev, hidden[t], reuse=False)

        # The step gate interpolates between the candidate and previous state.
        state = step * state_up + (1.0 - step) * state_prev
        states.append(state)
    return tf.pack(states)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号