model.py source code

python

Project: SmileGO, Author: kongjiellx
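from keras.callbacks import EarlyStopping, ModelCheckpoint

# pre_x_y() is a data-loading helper defined elsewhere in the project (not shown here).


def train(self):
    # Write the model to ./best.model whenever val_loss improves.
    checkpointer = ModelCheckpoint(filepath='./best.model',
                                   verbose=1,
                                   monitor='val_loss',
                                   save_best_only=True)
    # Stop once val_loss has not improved for 10 consecutive epochs.
    earlystop = EarlyStopping(monitor='val_loss',
                              patience=10,
                              verbose=1,
                              mode='auto')

    self.model.compile(loss='categorical_crossentropy',
                       optimizer='adadelta',
                       metrics=['accuracy'])

    # Save the architecture (no weights) as JSON; the file is closed automatically.
    with open('model.json', 'w') as fjson:
        fjson.write(self.model.to_json())
    print('model_json_saved!')

    # Note: training and validation data come from the same directory,
    # so val_loss does not measure performance on unseen games.
    train_x, train_y = pre_x_y(path='kgs-19-2017-01-new/')
    valid_x, valid_y = pre_x_y(path='kgs-19-2017-01-new/')
    print('train_data_len:', len(train_x))
    print('valid_data_len:', len(valid_x))

    # Inputs and outputs are keyed by layer name ('x' and 'out').
    # With a single epoch, EarlyStopping (patience=10) can never trigger.
    self.model.fit({'x': train_x}, {'out': train_y},
                   batch_size=32,
                   epochs=1,  # called nb_epoch in Keras 1.x
                   shuffle=True,
                   verbose=1,
                   callbacks=[checkpointer, earlystop],
                   validation_data=({'x': valid_x}, {'out': valid_y}))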
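The two callbacks in `train()` only matter across epochs. As a rough illustration, here is a simplified, pure-Python replay of how `ModelCheckpoint(save_best_only=True)` and `EarlyStopping(patience=10)` react to a sequence of per-epoch `val_loss` values. This is not Keras's own implementation: `simulate_training` is a hypothetical helper, and it assumes `min_delta=0` and that `mode='auto'` resolves to "lower is better" for a loss.

```python
def simulate_training(val_losses, patience=10):
    """Hypothetical helper: replay per-epoch val_loss values through simplified
    versions of ModelCheckpoint(save_best_only=True) and EarlyStopping.

    Returns (saved_epochs, stopped_epoch): the epochs at which a checkpoint
    would be written, and the epoch at which training stops (None if the
    epochs run out before early stopping fires).
    """
    best = float('inf')  # assumption: mode='auto' on a loss means lower is better
    wait = 0             # consecutive epochs without improvement
    saved_epochs = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
            saved_epochs.append(epoch)  # save_best_only: overwrite best.model
        else:
            wait += 1
            if wait >= patience:
                return saved_epochs, epoch  # early stopping halts training
    return saved_epochs, None


# One epoch, as in train(): one checkpoint, early stopping never fires.
print(simulate_training([0.9]))  # ([0], None)

# A run that plateaus after epoch 2 stops 10 epochs later.
print(simulate_training([1.0, 0.8, 0.7] + [0.75] * 12, patience=10))  # ([0, 1, 2], 12)
```

With `nb_epoch=1`, as in the snippet above, only one `val_loss` value is ever produced, so `patience=10` has no effect; it becomes relevant only when the epoch count is raised well above the patience.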