maxout.py source code

python

Project: c2w2c  Author: milankinen
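```python
from keras import backend as K
from keras.engine import InputSpec


# Method of a Keras 1.x layer that combines three 3D inputs of shape
# (batch, timesteps, features), presumably named h, y and c after the
# weight names, through a shared maxout.
def build(self, input_shape):
    # Expect three float tensors of shape (batch, timesteps, features).
    self.input_spec = [InputSpec(dtype=K.floatx(),
                                 shape=(None, input_shape[0][1], input_shape[0][2])),
                       InputSpec(dtype=K.floatx(),
                                 shape=(None, input_shape[1][1], input_shape[1][2])),
                       InputSpec(dtype=K.floatx(),
                                 shape=(None, input_shape[2][1], input_shape[2][2]))]

    # For each input: one (input_dim, output_dim) matrix per maxout piece.
    self.W_h = self.init((self.nb_feature, input_shape[0][2], self.output_dim),
                         name='{}_W_h'.format(self.name))

    self.W_y = self.init((self.nb_feature, input_shape[1][2], self.output_dim),
                         name='{}_W_y'.format(self.name))

    self.W_c = self.init((self.nb_feature, input_shape[2][2], self.output_dim),
                         name='{}_W_c'.format(self.name))

    trainable = [self.W_h, self.W_y, self.W_c]

    if self.bias:
        # One bias vector per maxout piece (not one per input).
        self.b = K.zeros((self.nb_feature, self.output_dim),
                         name='{}_b'.format(self.name))
        self.trainable_weights = trainable + [self.b]
    else:
        self.trainable_weights = trainable
```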
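The `build` method only allocates parameters; the layer's `call` method is not shown on this page. Assuming the usual maxout formulation (the same `(nb_feature, input_dim, output_dim)` weight layout that Keras 1's `MaxoutDense` uses), each of the `nb_feature` pieces is an affine map of the three inputs, and the output keeps the element-wise maximum over pieces. The NumPy sketch below shows the forward pass these shapes imply. `maxout_forward` and its argument names are hypothetical and are not part of c2w2c.

```python
import numpy as np


def maxout_forward(h, y, c, W_h, W_y, W_c, b=None):
    """Hypothetical maxout forward pass over three inputs.

    h: (batch, time, d_h), y: (batch, time, d_y), c: (batch, time, d_c)
    W_h/W_y/W_c: (nb_feature, d_*, output_dim); b: (nb_feature, output_dim) or None
    Returns an array of shape (batch, time, output_dim).
    """
    # Project every timestep with every piece's matrix and sum over inputs:
    # 'btd,kdo->btko' yields (batch, time, nb_feature, output_dim).
    z = (np.einsum('btd,kdo->btko', h, W_h)
         + np.einsum('btd,kdo->btko', y, W_y)
         + np.einsum('btd,kdo->btko', c, W_c))
    if b is not None:
        # (nb_feature, output_dim) broadcasts over batch and time.
        z = z + b
    # Maxout: keep the largest of the nb_feature pieces per output unit.
    return z.max(axis=2)


rng = np.random.default_rng(0)
h = rng.standard_normal((2, 3, 6))
y = rng.standard_normal((2, 3, 7))
c = rng.standard_normal((2, 3, 8))
W_h = rng.standard_normal((4, 6, 5))
W_y = rng.standard_normal((4, 7, 5))
W_c = rng.standard_normal((4, 8, 5))
b = rng.standard_normal((4, 5))
print(maxout_forward(h, y, c, W_h, W_y, W_c, b).shape)  # (2, 3, 5)
```

Summing the three projections before taking the maximum means the pieces are shared across inputs, which is why a single bias of shape `(nb_feature, output_dim)` suffices.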