def backward(self):
    """Build the gradient tensor derived from ``self.probs``.

    Subtracts 1 from ``self.probs`` at the (row, ``self.y_flat[row]``)
    positions, divides the result by ``self.N``, scales each row by
    ``self.mask_flat`` and reshapes to ``[self.N, self.T, self.V]``.

    NOTE(review): this looks like the gradient of a masked temporal
    softmax cross-entropy loss with respect to the scores -- confirm
    against the forward pass, which is not visible here.

    Returns:
        The reshaped gradient tensor.
    """
    probs_flat = self.probs

    # One (row, column) index pair per flattened row; the column comes
    # from y_flat.
    row_ids = tf.range(self.N * self.T)
    target_coords = tf.transpose(tf.pack([row_ids, self.y_flat]))

    # Dense 1/0 indicator with the same shape as probs_flat, converted to
    # True/False so it can drive the element-wise selection below.
    is_target = tf.cast(
        tf.sparse_to_dense(target_coords, probs_flat.get_shape(), 1),
        tf.bool,
    )

    # Take the decremented value where the indicator is set and the
    # untouched value everywhere else.
    grad_flat = tf.select(is_target, probs_flat - 1, probs_flat)

    grad_flat = grad_flat / self.N
    # The added trailing axis lets the per-row mask broadcast over columns.
    grad_flat = grad_flat * self.mask_flat[:, None]

    return tf.reshape(grad_flat, [self.N, self.T, self.V])
评论列表
文章目录