def preprocess_input(self, inputs, training=None):
    """Preprocess the layer input while keeping the appended cell mask intact.

    The trailing ``self.units`` entries along axis 2 of ``inputs`` are
    treated as a cell mask. When ``self.implementation == 0``, the mask is
    split off first. Only the remaining features go through the parent
    class's ``preprocess_input``, and the mask is then concatenated after
    that result along axis 2. For any other ``implementation`` value,
    ``inputs`` is returned unchanged, mask channels included.

    NOTE(review): assumes ``inputs`` is a tensor laid out as
    (batch, timesteps, features + units). The ``[:, :, ...]`` indexing only
    fixes that there are at least three axes; confirm the axis meaning
    against the caller.

    Args:
        inputs: Input tensor whose last ``self.units`` channels on axis 2
            hold the cell mask.
        training: Forwarded positionally to the parent implementation.

    Returns:
        The parent-preprocessed features with the mask re-appended on
        axis 2, or ``inputs`` itself when ``implementation`` is not 0.
    """
    # Only implementation 0 runs the parent preprocessing step; every other
    # mode hands the raw (features + mask) tensor straight to the caller.
    if self.implementation != 0:
        return inputs

    mask_width = self.units
    # Peel the mask channels off the end of axis 2 so the parent only ever
    # processes the non-mask features.
    cell_mask = inputs[:, :, -mask_width:]
    features = inputs[:, :, :-mask_width]

    parent = super(CellMaskedLSTM, self)
    preprocessed = parent.preprocess_input(features, training)
    # Re-attach the untouched mask after the preprocessed features.
    return K.concatenate([preprocessed, cell_mask], axis=2)
评论列表
文章目录