network.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:cocktail-party 作者: avivga 项目源码 文件源码
def build(video_shape, audio_spectrogram_size):
    """Assemble the video-to-spectrogram Keras model and wrap it.

    The stack is:
      * three 3-D convolution stages, each made of
        ZeroPadding3D -> Convolution3D -> BatchNormalization -> LeakyReLU
        -> MaxPooling3D -> Dropout(0.25);
      * a ``TimeDistributed(Flatten())`` layer named ``'time'``;
      * two dense stages of 1024 units (``dense1``, ``dense2``);
      * a plain ``Flatten``;
      * two dense stages of 2048 units (``dense3``, ``dense4``);
      * a final ``Dense`` layer named ``'output'`` with no activation argument.

    Each dense stage is Dense -> BatchNormalization -> LeakyReLU -> Dropout(0.25).

    Args:
        video_shape: passed as ``input_shape`` to the first layer of the model.
        audio_spectrogram_size: number of units in the ``'output'`` layer.

    Returns:
        A ``VideoToSpeechNet`` constructed around the built model. The model
        summary is printed (``model.summary()``) before returning.
    """
    net = Sequential()

    # One row per convolutional stage:
    # (pad layer name, padding, filters, kernel size, conv strides, conv name, pool name)
    conv_stages = (
        ('zero1', (1, 2, 2), 32, (3, 5, 5), (1, 2, 2), 'conv1', 'max1'),
        ('zero2', (1, 2, 2), 64, (3, 5, 5), (1, 1, 1), 'conv2', 'max2'),
        ('zero3', (1, 1, 1), 128, (3, 3, 3), (1, 1, 1), 'conv3', 'max3'),
    )

    for stage_no, spec in enumerate(conv_stages):
        pad_name, padding, n_filters, kernel, conv_strides, conv_name, pool_name = spec

        # Only the very first layer of a Sequential model declares the input shape.
        extra = {'input_shape': video_shape} if stage_no == 0 else {}
        net.add(ZeroPadding3D(padding=padding, name=pad_name, **extra))
        net.add(Convolution3D(n_filters, kernel, strides=conv_strides,
                              kernel_initializer='he_normal', name=conv_name))
        net.add(BatchNormalization())
        net.add(LeakyReLU())
        net.add(MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2), name=pool_name))
        net.add(Dropout(0.25))

    def add_dense_stage(units, layer_name):
        # Dense -> BatchNormalization -> LeakyReLU -> Dropout, same tail as the conv stages.
        net.add(Dense(units, kernel_initializer='he_normal', name=layer_name))
        net.add(BatchNormalization())
        net.add(LeakyReLU())
        net.add(Dropout(0.25))

    net.add(TimeDistributed(Flatten(), name='time'))
    for units, layer_name in ((1024, 'dense1'), (1024, 'dense2')):
        add_dense_stage(units, layer_name)

    net.add(Flatten())
    for units, layer_name in ((2048, 'dense3'), (2048, 'dense4')):
        add_dense_stage(units, layer_name)

    net.add(Dense(audio_spectrogram_size, name='output'))

    net.summary()

    return VideoToSpeechNet(net)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号