def test(self, model, cases):
    """Verify the embeddings that ``model.embeds`` produces for ``cases``.

    Each of the first three rows of the computed embeddings is compared
    against a concatenation of three 6-wide segments. The segments are built
    from lookups into ``RLongPrimitiveEmbeddings(6)``, zero vectors, and an
    expected object embedding derived from the model's projection layer.

    Args:
        model: model under test; must expose ``embeds``, ``compute`` and the
            private ``_object_projection_layer``.
        cases: inputs handed to ``model.compute``; rows 0, 1 and 2 of the
            result are checked.
    """
    session = tf.get_default_session()
    # Helper from this project; presumably initializes any not-yet-initialized
    # variables in the session before the model is run — confirm.
    guarantee_initialized_variables(session)
    actual = model.compute(model.embeds, cases)
    prims = RLongPrimitiveEmbeddings(6)

    # Expected object embedding: an all-ones length-10 vector pushed through
    # the object projection layer (weights [10, 6], bias [6]).
    weights, bias = model._object_projection_layer.get_weights()
    projected_obj = np.dot(np.ones(10), weights) + bias

    row0 = np.concatenate([np.zeros(6), prims['r'], prims[-1]])
    assert_array_almost_equal(actual[0], row0)

    row1 = np.concatenate([np.zeros(6), np.zeros(6), prims['X1/1']])
    assert_array_almost_equal(actual[1], row1)

    row2 = np.concatenate([prims['b'], projected_obj, projected_obj])
    assert_array_almost_equal(actual[2], row2)
评论列表
文章目录