def connect(self, inputs, is_train):
    """Apply dropout to ``inputs``.

    When ``self.fix_mask`` is false, ``self.generate_mask(inputs.shape,
    is_train)`` is called first (presumably to refresh
    ``self.dropout_mask`` for this input shape — TODO confirm against
    ``generate_mask``).

    When ``self.fast_predict`` is set, the result is simply
    ``inputs * (1 - self.dropout_prob)``; no train/test conditional is built.

    Otherwise the result is ``ifelse(is_train, inputs * self.dropout_mask,
    inputs * (1 - self.dropout_prob))``: the masked inputs when ``is_train``
    holds, the keep-probability-scaled inputs otherwise.

    Args:
        inputs: Tensor to apply dropout to. Must expose ``.shape`` when
            ``self.fix_mask`` is false.
        is_train: Training flag, used as the ``ifelse`` condition and passed
            to ``generate_mask``. Presumably a symbolic scalar (Theano-style
            ``ifelse``) — TODO confirm.

    Returns:
        ``inputs`` multiplied by the dropout mask (training branch) or by
        ``1 - self.dropout_prob`` (inference branch / fast path).
    """
    if not self.fix_mask:
        # NOTE(review): this still runs when fast_predict is set, even though
        # the fast path below never reads the mask. Kept to preserve any side
        # effects generate_mask has; confirm whether it can be skipped there.
        self.generate_mask(inputs.shape, is_train)

    # Inference-time scaling by the keep probability. Whether the training
    # mask itself is rescaled depends on generate_mask, not visible here.
    keep_prob = 1 - self.dropout_prob

    if self.fast_predict:
        # Trick to speed up model compiling at decoding time: return the
        # inference expression directly instead of building the more
        # complicated conditional computation graph.
        return inputs * keep_prob

    # Presumably Theano's lazy ifelse: only the selected branch is evaluated.
    return ifelse(is_train, inputs * self.dropout_mask, inputs * keep_prob)
# (Stray page-navigation labels from a web scrape: "Comment list", "Table of contents".)