def call(self, inputs):
    """Classify a target embedding by softmax attention over a support set.

    Each support embedding is scored by its negative Euclidean distance to
    the target embedding. The scores are softmax-normalised across the
    support set and used as weights for a sum over the rows of the last
    input tensor.

    Args:
        inputs: Sequence of tensors laid out as
            ``[support_0, ..., support_{k-1}, target, support_rows]``:

            * ``inputs[:-2]`` -- the k support-set embeddings, each assumed
              to be ``(batch, dim)`` (TODO confirm against caller);
            * ``inputs[-2]`` -- the target embedding, same shape as each
              support embedding;
            * ``inputs[-1]`` -- a ``(batch, k, self.nway)`` tensor with one
              row per support embedding (presumably the support set's one-hot
              labels -- verify against caller).

    Returns:
        A ``(batch, self.nway)`` tensor: the attention-weighted combination
        of the rows of ``inputs[-1]``.

    Raises:
        ValueError: if ``inputs`` does not contain at least one support
            embedding in addition to the target and ``inputs[-1]``.
    """
    if len(inputs) < 3:
        raise ValueError(
            "call() expects at least one support embedding followed by the "
            "target embedding and the support-set tensor")

    support_embeddings = inputs[:-2]
    target_embedding = inputs[-2]
    support_rows = inputs[-1]

    # d/dx sqrt(x) is infinite at x == 0. Without this floor, the gradient
    # becomes NaN whenever a support embedding coincides with the target.
    min_squared_distance = 1e-12

    similarities = []
    for support_embedding in support_embeddings:
        squared_distance = tf.reduce_sum(
            tf.square(support_embedding - target_embedding),
            axis=1,
            keepdims=True,  # `keep_dims` is deprecated and gone in TF 2
        )
        # Negate the distance so that closer support embeddings score higher.
        similarities.append(
            tf.negative(tf.sqrt(tf.maximum(squared_distance, min_squared_distance))))
    similarities = tf.concat(similarities, axis=1)  # (batch, k)

    attention = tf.nn.softmax(similarities)
    # (batch, 1, k) @ (batch, k, nway) -> (batch, 1, nway). Squeeze only
    # axis 1. A bare squeeze would also drop the batch axis when batch == 1
    # (or the class axis when nway == 1), and set_shape below would then fail.
    preds = tf.squeeze(
        tf.matmul(tf.expand_dims(attention, 1), support_rows), axis=1)
    preds.set_shape((inputs[0].shape[0], self.nway))
    return preds
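# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): web-page
# navigation text captured along with this snippet; not part of the code.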