def buildEmbedding(self):
    """Create the shared embedding variable and look up the q/a input embeddings.

    Reads the initial embedding values from ``self.embedding_params['weights']``
    and the trainability flag from ``self.params['embedding_trainable']``
    (default ``False``). Inside the ``'embedding'`` name scope it creates a
    float32 ``tf.Variable`` named ``'embedding'``. It then stores
    ``tf.nn.embedding_lookup`` results in ``self.tensors``:
    ``'q_input'`` -> ``'q_embedding'`` and ``'a_input'`` -> ``'a_embedding'``.
    Both lookups use the same variable.

    NOTE(review): ``weights`` is presumably a (vocab_size, embedding_dim)
    matrix, and ``q_input``/``a_input`` are presumably integer id tensors
    (question/answer?). This is not shown here, so confirm against the caller.

    Raises:
        ValueError: if ``embedding_params`` has no ``'weights'`` entry, or it
            is ``None``.
    """
    weights = self.embedding_params.get('weights')
    # Explicit ``is None`` check in place of the old commented-out
    # ``assert weights``. Asserts are stripped under ``python -O``, and
    # truth-testing a NumPy array (if that is what is passed) raises an
    # "ambiguous truth value" error. Failing here gives a clear message
    # instead of an opaque error from inside ``tf.Variable``.
    if weights is None:
        raise ValueError(
            "embedding_params['weights'] is missing or None; an initial "
            "embedding matrix is required to build the embedding layer.")

    trainable = self.params.get('embedding_trainable', False)
    if trainable:
        logging.info('Embedding Weights is Trainable!')
    else:
        logging.info('Embedding Weights is Not Trainable!')

    with tf.name_scope('embedding'):
        # One variable shared by both lookups below. When trainable is False,
        # it is excluded from the trainable-variables collection.
        W = tf.Variable(
            weights,
            name='embedding',
            trainable=trainable,
            dtype=tf.float32,
        )
        self.tensors['q_embedding'] = tf.nn.embedding_lookup(W, self.tensors['q_input'])
        self.tensors['a_embedding'] = tf.nn.embedding_lookup(W, self.tensors['a_input'])
评论列表
文章目录