def get_cell_with_dropout(self, rng: RandomStreams, dropout_rate: float):
    """Build a copy of this cell whose two transition parameters go through dropout.

    The copy is allocated with ``__new__`` so ``__init__`` is NOT run; only the
    attributes assigned below exist on it.

    :param rng: random stream, forwarded unchanged to ``dropout_multiple``.
    :param dropout_rate: rate forwarded unchanged to ``dropout_multiple``.
    :return: a new instance of ``self.__class__``. It holds:
        - the outputs of ``dropout_multiple`` for ``__prev_hidden_to_next`` and
          ``__prediction_to_hidden``;
        - the very same ``__bias`` object as this cell (shared, not copied);
        - a name suffixed with ``":with_dropout"``.
    """
    # NOTE(review): the double-underscore attributes are name-mangled by Python
    # to the enclosing class (presumably SimpleRecurrentCell) -- confirm; this is
    # why they are assigned directly rather than via setattr with string names.
    dropped_cell = SimpleRecurrentCell.__new__(self.__class__)

    # Unpacking enforces that dropout_multiple yields exactly two results, taken
    # in the same order as the inputs passed to it. The meaning of the positional
    # ``True`` flag is defined by dropout_multiple -- TODO confirm there.
    hidden_transition, input_transition = dropout_multiple(
        dropout_rate,
        rng,
        True,
        self.__prev_hidden_to_next,
        self.__prediction_to_hidden,
    )
    dropped_cell.__prev_hidden_to_next = hidden_transition
    dropped_cell.__prediction_to_hidden = input_transition

    # The bias is shared by reference with this cell.
    dropped_cell.__bias = self.__bias

    # Shadow the method on the copy with None: calling it on the copy raises
    # TypeError. Presumably this prevents stacking dropout twice -- confirm intent.
    dropped_cell.get_cell_with_dropout = None

    dropped_cell.__name = self.__name + ":with_dropout"
    return dropped_cell
评论列表
文章目录