import tensorflow as tf
from keras import backend as K
from keras.layers import Activation
from keras.losses import kullback_leibler_divergence


def augmented_loss(self, y_true, y_pred):
    _y_pred = Activation("softmax")(y_pred)
    # standard cross-entropy term; K.categorical_crossentropy takes the
    # target first, then the prediction
    loss = K.categorical_crossentropy(y_true, _y_pred)
    # y_true is (batch x seq x vocab)
    y_indexes = K.argmax(y_true, axis=2)  # convert one-hot to indices: (batch x seq)
    y_vectors = self.embedding(y_indexes)  # look up the vectors: (batch x seq x vector_length)
    # dot products between each target vector and every embedding (batch x seq x vocab)
    y_t = tf.tensordot(y_vectors, K.transpose(self.embedding.embeddings), 1)
    y_t = K.reshape(y_t, (-1, self.sequence_size, self.vocab_size))  # explicitly set shape
    y_t = K.softmax(y_t / self.temperature)  # soften the target distribution
    _y_pred_t = Activation("softmax")(y_pred / self.temperature)  # soften the prediction
    # KL divergence between the softened target and prediction distributions
    aug_loss = kullback_leibler_divergence(y_t, _y_pred_t)
    loss += (self.gamma * self.temperature) * aug_loss
    return loss
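
For context, here is a minimal sketch of the host class this method would hang off, assuming `embedding`, `sequence_size`, `vocab_size`, `temperature`, and `gamma` are plain attributes set in the constructor. The class name `AugmentedModel` and all default values below are illustrative assumptions, not the original implementation.

```python
from keras.layers import Embedding


class AugmentedModel:
    def __init__(self, vocab_size=1000, vector_length=64,
                 sequence_size=20, temperature=10.0, gamma=1.0):
        # hypothetical defaults; the original settings are not shown above
        self.vocab_size = vocab_size
        self.sequence_size = sequence_size
        self.temperature = temperature  # softens both distributions in the KL term
        self.gamma = gamma              # weight of the augmented (KL) loss
        # embedding layer whose weight matrix produces the soft targets
        self.embedding = Embedding(vocab_size, vector_length)


# attach the function above as a method of the class
AugmentedModel.augmented_loss = augmented_loss
```

Because Keras accepts any callable of the form `(y_true, y_pred)` as a loss, a network whose output is the pre-softmax scores of shape `(batch, sequence_size, vocab_size)` can be compiled directly with `loss=model.augmented_loss`, provided `model` shares the same `embedding` layer with the network so the soft targets come from the tied weights.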