modellib.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:rec-attend-public 作者: renmengye 项目源码 文件源码
def build_skip_conn_attn(cnn_channels, h_cnn_time, x_time, timespan):
  """Build skip connection for attention based model."""
  skip = [None]
  skip_ch = [0]
  nlayers = len(h_cnn_time[0])
  timespan = len(h_cnn_time)
  for jj in range(nlayers):
    lidx = nlayers - jj - 2
    if lidx >= 0:
      ll = [h_cnn_time[tt][lidx] for tt in range(timespan)]
    else:
      ll = x_time
    layer = tf.concat(1, [tf.expand_dims(l, 1) for l in ll])
    ss = tf.shape(layer)
    layer = tf.reshape(layer, tf.pack([-1, ss[2], ss[3], ss[4]]))
    skip.append(layer)
    ch_idx = lidx + 1
    skip_ch.append(cnn_channels[ch_idx])
  return skip, skip_ch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号