from keras import backend as K
from keras.layers import Input, Dense, Lambda, add, concatenate, multiply
from keras.initializers import Constant
from recurrentshop import RecurrentModel


def QRNcell():
    # Per-timestep input: context vector and query vector, concatenated
    xq = Input(batch_shape=(batch_size, embedding_dim * 2))
    # Split back into context (xt) and query (qt)
    xt = Lambda(lambda x, dim: x[:, :dim], arguments={'dim': embedding_dim},
                output_shape=lambda s: (s[0], s[1] // 2))(xq)
    qt = Lambda(lambda x, dim: x[:, dim:], arguments={'dim': embedding_dim},
                output_shape=lambda s: (s[0], s[1] // 2))(xq)
    # Previous hidden (reduced query) state
    h_tm1 = Input(batch_shape=(batch_size, embedding_dim))

    # Update gate z_t, broadcast from a scalar to embedding_dim
    zt = Dense(1, activation='sigmoid', bias_initializer=Constant(2.5))(multiply([xt, qt]))
    zt = Lambda(lambda x, dim: K.repeat_elements(x, dim, axis=1), arguments={'dim': embedding_dim})(zt)
    # Candidate hidden state from the concatenated context and query
    ch = Dense(embedding_dim, activation='tanh')(concatenate([xt, qt], axis=-1))
    # Reset gate r_t, broadcast the same way as z_t
    rt = Dense(1, activation='sigmoid')(multiply([xt, qt]))
    rt = Lambda(lambda x, dim: K.repeat_elements(x, dim, axis=1), arguments={'dim': embedding_dim})(rt)
    # h_t = z_t * r_t * candidate + (1 - z_t) * h_{t-1}
    ht = add([multiply([zt, ch, rt]),
              multiply([Lambda(lambda x: 1 - x, output_shape=lambda s: s)(zt), h_tm1])])

    return RecurrentModel(input=xq, output=ht, initial_states=[h_tm1],
                          final_states=[ht], return_sequences=True)
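
# Minimal usage sketch (not part of the original file): the cell returned by
# QRNcell() behaves like any Keras recurrent layer, so it can be applied to a
# whole sequence of concatenated [context; query] vectors. `story_maxlen` is a
# hypothetical sequence length; batch_size and embedding_dim are assumed to be
# the same globals the cell itself uses.
from keras.models import Model

qrn = QRNcell()
seq_in = Input(batch_shape=(batch_size, story_maxlen, embedding_dim * 2))
# With return_sequences=True the output is (batch_size, story_maxlen, embedding_dim)
seq_out = qrn(seq_in)
qrn_model = Model(inputs=seq_in, outputs=seq_out)
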
#
# Load data
#