def get_video_weights(video_id_batch):
    """Return the sample weight associated with each video ID in a batch.

    Builds the following graph pieces:

    * a string-to-index lookup table backed by ``FLAGS.sample_vocab_file``;
    * a non-trainable variable ``"sample_weights"`` initialised from the
      array returned by ``get_video_weights_array()``;
    * a placeholder ``"sample_weights_input"`` plus an assign op that
      overwrites the variable with whatever is fed to the placeholder.

    The placeholder and the assign op are registered in the graph
    collections ``"weights_input"`` and ``"weights_assignment"`` so that
    other code can retrieve them and refresh the weights at run time.

    Args:
        video_id_batch: Tensor of video IDs to look up.
            NOTE(review): presumably a string tensor of shape [batch] --
            confirm against the caller.

    Returns:
        A tensor holding, for every entry of ``video_id_batch``, the value
        of the weights variable at that ID's vocabulary index.
    """
    # Map every video ID to its index in the vocabulary file. IDs missing
    # from the vocabulary fall back to index 0.
    # NOTE(review): index 0 is also a regular vocabulary index, so unknown
    # IDs receive the weight stored at position 0 -- confirm this is
    # intended.
    video_id_to_index = tf.contrib.lookup.string_to_index_table_from_file(
        vocabulary_file=FLAGS.sample_vocab_file, default_value=0)
    indexes = video_id_to_index.lookup(video_id_batch)

    weights, length = get_video_weights_array()

    # Feed point used to push a new weights array into the graph.
    weights_input = tf.placeholder(
        tf.float32, shape=[length], name="sample_weights_input")

    # Non-trainable storage for the weights, seeded with the initial array.
    # NOTE(review): tf.get_variable is called without reuse, so this
    # function is presumably meant to be called once per variable scope --
    # confirm.
    weights_tensor = tf.get_variable(
        "sample_weights",
        shape=[length],
        trainable=False,
        dtype=tf.float32,
        initializer=tf.constant_initializer(weights))
    weights_assignment = tf.assign(weights_tensor, weights_input)

    # Expose the feed point and the assign op through graph collections so
    # they can be fetched without holding references to these locals.
    tf.add_to_collection("weights_input", weights_input)
    tf.add_to_collection("weights_assignment", weights_assignment)

    # Gather one weight per looked-up index.
    video_weight_batch = tf.nn.embedding_lookup(weights_tensor, indexes)
    return video_weight_batch
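# Web-page residue (not code): "评论列表" (comment list)
# Web-page residue (not code): "文章目录" (table of contents)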