def build_skip_conn_attn(cnn_channels, h_cnn_time, x_time, timespan):
    """Collect skip-connection tensors for the attention-based model.

    Walks the CNN feature hierarchy from deep to shallow, gathering for each
    decoder level the matching encoder activations across all timesteps,
    stacked along a new time axis and flattened back to a 4-D batch.

    Args:
        cnn_channels: Per-layer channel counts; indexed to report the channel
            width of each skip tensor.
        h_cnn_time: Per-timestep list of per-layer CNN activations
            (h_cnn_time[t][l] is layer l at timestep t).
        x_time: Per-timestep raw inputs, used as the shallowest skip level.
        timespan: NOTE(review): this argument is immediately overwritten by
            len(h_cnn_time) below and therefore ignored — kept only for
            signature compatibility; confirm callers pass a matching value.

    Returns:
        (skip, skip_ch): a list of skip tensors (first entry None, meaning
        "no skip" for the deepest level) and the parallel list of their
        channel counts (first entry 0).
    """
    skip_layers = [None]
    skip_channels = [0]
    num_layers = len(h_cnn_time[0])
    # Overrides the `timespan` argument with the actual number of timesteps.
    timespan = len(h_cnn_time)
    for step in range(num_layers):
        layer_idx = num_layers - step - 2
        if layer_idx < 0:
            # Shallower than the first CNN layer: fall back to the raw input.
            per_time = x_time
        else:
            per_time = [h_cnn_time[t][layer_idx] for t in range(timespan)]
        # Stack the per-timestep tensors along a new time axis (axis 1)...
        stacked = tf.concat(1, [tf.expand_dims(t, 1) for t in per_time])
        shape = tf.shape(stacked)
        # ...then fold batch and time together into a single leading axis.
        flat = tf.reshape(stacked, tf.pack([-1, shape[2], shape[3], shape[4]]))
        skip_layers.append(flat)
        skip_channels.append(cnn_channels[layer_idx + 1])
    return skip_layers, skip_channels
评论列表
文章目录