def test_custom_repr_graph(self):
    """Check that TensorRec accepts a user-supplied representation graph.

    A tanh-activated linear projection is supplied as both the user and the
    item representation graph, and the test verifies that constructing the
    model with it succeeds.
    """

    def tanh_repr_graph(tf_features, n_components, n_features, node_name_ending):
        """Project sparse features through a tanh-activated dense layer.

        Creates a ``[n_features, n_components]`` weight variable initialised
        from a normal distribution (stddev 0.5), whose name is suffixed with
        ``node_name_ending``. Returns the pair ``(representation, variables)``:
        the tanh of ``tf_features`` (sparse) times the weights, and a
        one-element list holding the weight variable.
        """
        initial_weights = tf.random_normal([n_features, n_components], stddev=.5)
        weights = tf.Variable(initial_weights,
                              name='tanh_weights_%s' % node_name_ending)
        projected = tf.sparse_tensor_dense_matmul(tf_features, weights)
        # The caller expects the representation tensor plus the list of
        # variables this graph created.
        return tf.nn.tanh(projected), [weights]

    # The same factory is reused for both sides; the node_name_ending
    # argument is what gives each side's weight variable a distinct name.
    model = TensorRec(user_repr_graph=tanh_repr_graph,
                      item_repr_graph=tanh_repr_graph)
    self.assertIsNotNone(model)
评论列表
文章目录