mnist_net2net.py source code

python

Project: NetworkCompress  Author: luzai
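from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.optimizers import SGD

# input_shape and num_class are expected to be defined at module level.


def make_teacher_model(train_data, validation_data, epochs=3):
    '''Train a simple CNN as the teacher model.

    train_data and validation_data are (inputs, one-hot labels) tuples.
    Returns the trained model and its Keras History object.
    '''
    model = Sequential()
    # Two 3x3 conv layers with 64 filters each; 'same' padding keeps the
    # spatial size, and each 2x2 max pool halves it.
    model.add(Conv2D(64, 3, input_shape=input_shape,
                     padding='same', name='conv1'))
    model.add(MaxPooling2D(2, name='pool1'))
    model.add(Conv2D(64, 3, padding='same', name='conv2'))
    model.add(MaxPooling2D(2, name='pool2'))
    # Classifier head: one hidden dense layer, then softmax over num_class.
    model.add(Flatten(name='flatten'))
    model.add(Dense(64, activation='relu', name='fc1'))
    model.add(Dense(num_class, activation='softmax', name='fc2'))
    # Note: newer Keras releases spell the `lr` argument `learning_rate`.
    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(lr=0.01, momentum=0.9),
                  metrics=['accuracy'])

    train_x, train_y = train_data
    history = model.fit(train_x, train_y,
                        epochs=epochs,
                        validation_data=validation_data)
    return model, history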
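The function above takes `train_data` as an `(inputs, labels)` tuple, and the `categorical_crossentropy` loss expects one-hot labels. The following is a minimal NumPy-only sketch of how such a tuple might be prepared, plus a check of how many features reach `fc1`. The helper names `prepare_data` and `flatten_size`, the random stand-in images, and the MNIST-like values `(28, 28, 1)` and `10` are assumptions for illustration; they are not taken from the NetworkCompress project.

```python
import numpy as np


def prepare_data(images, labels, num_class):
    """Turn raw grayscale images and integer labels into an (x, y) tuple."""
    # Scale pixels to [0, 1] and add a channel axis: (n, h, w) -> (n, h, w, 1).
    x = images.astype("float32") / 255.0
    x = x[..., np.newaxis]
    # One-hot encode the labels, as categorical_crossentropy expects.
    y = np.eye(num_class, dtype="float32")[labels]
    return x, y


def flatten_size(input_shape, filters=64, num_pools=2, pool=2):
    """Number of features reaching fc1 for the conv/pool stack above."""
    height, width, _ = input_shape
    for _ in range(num_pools):
        # 'same'-padded convs keep height and width; each pool floor-halves them.
        height, width = height // pool, width // pool
    return height * width * filters


# Hypothetical stand-in data with MNIST-like shapes (random, not real digits).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
labels = rng.integers(0, 10, size=8)

train_data = prepare_data(images, labels, num_class=10)
print(train_data[0].shape, train_data[1].shape)  # (8, 28, 28, 1) (8, 10)
print(flatten_size((28, 28, 1)))                  # 3136
```

A tuple built this way could then be passed as `make_teacher_model(train_data, validation_data)`, with `validation_data` prepared the same way from held-out samples.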