def _initialize_weights(self):
    """Create the model's weight Variables, optionally from a pretrained checkpoint.

    When ``self.pretrain_flag > 0``, the saved meta-graph at ``self.save_file``
    is loaded, the tensors named ``feature_embeddings:0`` / ``feature_bias:0`` /
    ``bias:0`` are evaluated, and fresh Variables are seeded with those values.
    Otherwise embeddings are initialized with small Gaussian noise and the
    biases with zeros.

    Returns:
        dict: {'feature_embeddings': Variable (features_M x hidden_factor),
               'feature_bias':       Variable (features_M x 1),
               'bias':               Variable (scalar)}
    """
    all_weights = dict()
    if self.pretrain_flag > 0:
        # Pull the pretrained values out of the checkpoint graph by tensor name.
        weight_saver = tf.train.import_meta_graph(self.save_file + '.meta')
        pretrain_graph = tf.get_default_graph()
        feature_embeddings = pretrain_graph.get_tensor_by_name('feature_embeddings:0')
        feature_bias = pretrain_graph.get_tensor_by_name('feature_bias:0')
        bias = pretrain_graph.get_tensor_by_name('bias:0')
        with tf.Session() as sess:
            weight_saver.restore(sess, self.save_file)
            fe, fb, b = sess.run([feature_embeddings, feature_bias, bias])
        # FIX: give the restored Variables the same names as the cold-start
        # branch. Without explicit names a later checkpoint of this model could
        # not itself be restored via get_tensor_by_name('feature_embeddings:0').
        all_weights['feature_embeddings'] = tf.Variable(
            fe, dtype=tf.float32, name='feature_embeddings')
        all_weights['feature_bias'] = tf.Variable(
            fb, dtype=tf.float32, name='feature_bias')
        all_weights['bias'] = tf.Variable(b, dtype=tf.float32, name='bias')
    else:
        # Cold start: small Gaussian embeddings, zero biases.
        all_weights['feature_embeddings'] = tf.Variable(
            tf.random_normal([self.features_M, self.hidden_factor], 0.0, 0.01),
            name='feature_embeddings')  # features_M * K
        # FIX: the original used tf.random_uniform(shape, 0.0, 0.0) — a
        # zero-width interval, i.e. all zeros written obscurely. Say so.
        all_weights['feature_bias'] = tf.Variable(
            tf.zeros([self.features_M, 1]), name='feature_bias')  # features_M * 1
        all_weights['bias'] = tf.Variable(tf.constant(0.0), name='bias')  # scalar global bias
    return all_weights
# (removed non-code web-page extraction residue: "评论列表" / "文章目录")