def create(self):
    """Assemble the text branch, the visual branch and the decoder head.

    Every size and recurrent layer class is read from ``self._config``.

    Side effects:
      * ``self.language_model`` is set to a new ``Sequential``. It holds the
        textual embedding (added with ``mask_zero=True``), whatever
        ``self.stacked_RNN`` adds, and a final ``recurrent_encoder`` layer
        built with ``return_sequences=False`` and the configured
        ``go_backwards`` flag.
      * ``self.visual_model`` is set to the model returned by the factory
        registered in ``select_sequential_visual_model`` under
        ``trainable_perception_name``. That model is then passed to
        ``self.visual_embedding`` together with the factory's dimensionality.
      * Layers are appended to ``self`` in this order:
          1. ``Merge`` of the two branches;
          2. Dropout(0.5);
          3. Dense(output_dim);
          4. RepeatVector(max_output_time_steps);
          5. recurrent_decoder(hidden_state_dim, return_sequences=True);
          6. Dropout(0.5);
          7. TimeDistributedDense(output_dim);
          8. softmax.

    Returns:
        None.
    """
    cfg = self._config

    # Textual branch: embedding -> stacked RNN -> encoder that emits only
    # its final output (return_sequences=False).
    text_branch = Sequential()
    self.textual_embedding(text_branch, mask_zero=True)
    self.stacked_RNN(text_branch)
    encoder_layer_cls = cfg.recurrent_encoder
    text_branch.add(encoder_layer_cls(
        cfg.hidden_state_dim,
        return_sequences=False,
        go_backwards=cfg.go_backwards,
    ))
    self.language_model = text_branch

    # Visual branch: the perception model is chosen by name from the
    # registry. Its reported dimensionality sizes the visual embedding.
    perception_factory = select_sequential_visual_model[
        cfg.trainable_perception_name](cfg.visual_dim)
    image_branch = perception_factory.create()
    self.visual_embedding(image_branch, perception_factory.get_dimensionality())
    self.visual_model = image_branch

    # Fuse the branches. The 'dot' mode additionally needs dot_axes, which
    # pairs axis 1 of the first input with axis 1 of the second.
    merge_mode = cfg.multimodal_merge_mode
    if merge_mode == 'dot':
        merge_options = {'mode': 'dot', 'dot_axes': [(1,), (1,)]}
    else:
        merge_options = {'mode': merge_mode}
    self.add(Merge([text_branch, image_branch], **merge_options))

    # Decoder head: the fused vector is projected, repeated once per output
    # time step, decoded as a sequence and classified per step.
    self.add(Dropout(0.5))
    self.add(Dense(cfg.output_dim))
    self.add(RepeatVector(cfg.max_output_time_steps))
    self.add(cfg.recurrent_decoder(cfg.hidden_state_dim, return_sequences=True))
    self.add(Dropout(0.5))
    self.add(TimeDistributedDense(cfg.output_dim))
    self.add(Activation('softmax'))
###
# Graph-based models
###
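# NOTE: the lines originally here appear to be web-page residue from the site
# this file was copied from, not code. Their English text was:
# "model_zoo.py source file", "python", "views 24", "bookmarks 0", "likes 0",
# "comments 0", "comment list", "table of contents".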