def build(self, input_shape):
    self.input_spec = [InputSpec(shape=input_shape)]
    input_dim = input_shape[4] - 1  # ignore sense prior parameter
    self.input_dim = input_dim
    # Saving onto-lstm weights to set them later. This way, LSTM's build method
    # won't delete them.
    initial_ontolstm_weights = self.initial_weights
    self.initial_weights = None
    lstm_input_shape = input_shape[:2] + (input_dim,)  # removing senses and hyps
    # Now calling LSTM's build to initialize the LSTM weights
    super(OntoAttentionLSTM, self).build(lstm_input_shape)
    # This would have changed the input shape and ndim. Reset it again.
    self.input_spec = [InputSpec(shape=input_shape)]
    if self.use_attention:
        # Following are the attention parameters
        self.input_hyp_projector = self.inner_init(
            (input_dim, self.output_dim),
            name='{}_input_hyp_projector'.format(self.name))  # projects synset embeddings
        self.context_hyp_projector = self.inner_init(
            (self.output_dim, self.output_dim),
            name='{}_context_hyp_projector'.format(self.name))  # projects hidden state (context)
        self.hyp_projector2 = self.inner_init(
            (self.output_dim, self.output_dim),
            name='{}_hyp_projector2'.format(self.name))  # second projection of the hidden state
        self.hyp_scorer = self.init((self.output_dim,),
                                    name='{}_hyp_scorer'.format(self.name))
        # LSTM's build method would have initialized trainable_weights. Add to it.
        self.trainable_weights.extend([self.input_hyp_projector,
                                       self.context_hyp_projector,
                                       self.hyp_projector2, self.hyp_scorer])
    if initial_ontolstm_weights is not None:
        self.set_weights(initial_ontolstm_weights)
        del initial_ontolstm_weights
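
# --- Illustrative usage (not from the original source): a minimal sketch
# assuming the Keras 1.x custom-recurrent-layer convention this class follows.
# The constructor arguments shown (output_dim, use_attention) are assumptions
# inferred from the attributes referenced above, not a confirmed signature.
if __name__ == '__main__':
    # The layer expects 5D input: (batch, timesteps, senses, hypernyms,
    # embedding_dim + 1); the trailing +1 slot holds the sense prior that
    # build() strips off via input_shape[4] - 1.
    lstm = OntoAttentionLSTM(output_dim=50, use_attention=True)
    lstm.build(input_shape=(None, 20, 3, 5, 51))  # input_dim becomes 50
    # After build(), lstm.trainable_weights holds the standard LSTM weights
    # plus the four attention parameters added above.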