def _construct_sequence(batch):
    """Run a GRU-style recurrence over per-frame detections, matching them across frames.

    ``states[0]`` is ``hidden[0]``. For every later step ``t`` in
    ``range(1, T)`` the previous step's states are first aligned to the
    current frame according to ``c.match_kind``:

    * ``'boxes'``: hard nearest-neighbour match on box distances;
    * ``'hidden'``: hard nearest-neighbour match on hidden-feature distances;
    * ``'hidden-soft'``: softmax-weighted blend of previous states in which
      smaller hidden-feature distances receive larger weights.

    The aligned previous state and ``hidden[t]`` then pass through sigmoid
    reset/update gates and a tanh candidate (fully connected layers of width
    ``NFH``, scopes ``'reset'``, ``'step'``, ``'state_up'``, whose variables
    are shared across steps). The new state is
    ``step * state_up + (1 - step) * state_prev``.

    Relies on names from the enclosing scope: ``T``, ``c``, ``NFH`` and the
    ``tf`` / ``slim`` / ``nnutil`` modules.

    Args:
        batch: pair ``(hidden, boxes)``, each indexable by time step.
            Presumably ``hidden[t]`` is (num_detections, NFH) and ``boxes[t]``
            is (num_detections, box_dim) -- TODO confirm against the caller.
            ``hidden[0]`` is stacked together with the ``NFH``-wide updated
            states, so its last dimension must equal ``NFH``.

    Returns:
        ``tf.pack`` of the per-step states.

    Raises:
        RuntimeError: if ``c.match_kind`` is not ``'boxes'``, ``'hidden'``
            or ``'hidden-soft'``.
    """
    hidden, boxes = batch

    # Validate the config once, before building any ops, so an invalid
    # match_kind is rejected even when the loop below has no iterations.
    if c.match_kind not in ('boxes', 'hidden', 'hidden-soft'):
        raise RuntimeError('Unknown match_kind: %s' % c.match_kind)

    # initializing the state with features
    states = [hidden[0]]

    def _match_previous(t):
        """Return states[t - 1] aligned row-wise with the detections at step t."""
        # find the matching boxes. TODO: try with the soft matching function
        if c.match_kind == 'boxes':
            feats_prev, feats_cur = boxes[t - 1], boxes[t]
        else:
            # TODO: actually it makes more sense to compare on states
            feats_prev, feats_cur = hidden[t - 1], hidden[t]
        # NOTE(review): if nnutil.cdist(a, b)[i, j] is dist(a[i], b[j]) (the
        # scipy.spatial.distance.cdist convention), rows here index step t-1
        # and the argmin/softmax below select over step-t detections, yet the
        # result indexes states[t-1] and is later paired row-wise with
        # hidden[t]. That looks transposed: cdist(feats_cur, feats_prev) would
        # pick a predecessor for each current detection. Confirm the
        # orientation of nnutil.cdist before changing it.
        dists = nnutil.cdist(feats_prev, feats_cur)
        if c.match_kind == 'hidden-soft':
            # Softmax over the *negated* distances so the nearest candidates
            # get the largest weights -- the soft counterpart of argmin.
            weights = slim.softmax(-dists)
            return tf.matmul(weights, states[t - 1])
        idxs = tf.argmin(dists, 1, 'idxs')
        return tf.gather(states[t - 1], idxs)

    def _construct_update(state_prev, hidden_t, reuse):
        """Build reset/update gates and the tanh candidate; return (state_up, step)."""
        state = tf.concat(1, [state_prev, hidden_t])
        # TODO: initialize jointly
        reset = slim.fully_connected(state, NFH, tf.nn.sigmoid,
                                     reuse=reuse, scope='reset')
        step = slim.fully_connected(state, NFH, tf.nn.sigmoid,
                                    reuse=reuse, scope='step')
        # The reset gate masks the previous state before computing the candidate.
        state_r = tf.concat(1, [reset * state_prev, hidden_t])
        state_up = slim.fully_connected(state_r, NFH, tf.nn.tanh,
                                        reuse=reuse, scope='state_up')
        return state_up, step

    # TODO: make this dependent on the data
    # TODO: make it with scan ?
    for t in range(1, T):
        state_prev = _match_previous(t)
        # Share gate weights across steps (and across repeated graph builds):
        # TF raises ValueError for reuse=True when the variables do not exist
        # yet, in which case they are created with reuse=False.
        try:
            state_up, step = _construct_update(state_prev, hidden[t], reuse=True)
        except ValueError:
            state_up, step = _construct_update(state_prev, hidden[t], reuse=False)
        # Convex combination of the candidate and the matched previous state.
        states.append(step * state_up + (1.0 - step) * state_prev)
    return tf.pack(states)
volleyball_train_stage_b.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录