genderclassifier.py source code

python

Project: namegenderclassifier  Author: joaoalvarenga
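# Module-level imports this method relies on (Keras 1.x names, matching init= and nb_epoch= below).
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras.utils.np_utils import to_categorical

def train(self, dataset, train_split=0.8, dense_size=32, learning_rate=0.001, batch_size=32, epochs=50, activation='relu'):
        # Expected to fill self.__train_data and self.__test_data according to train_split.
        self.__load_dataset(dataset, train_split)

        # Column 0 holds the feature vectors, column 1 the integer class label (2 classes).
        train_x = np.array(self.__train_data[:, 0].tolist())
        train_y = to_categorical(self.__train_data[:, 1], 2)

        test_x = np.array(self.__test_data[:, 0].tolist())
        test_y = to_categorical(self.__test_data[:, 1], 2)

        print(train_x.shape)  # (num_samples, num_features)
        # One hidden layer followed by a softmax over the classes.
        # In Keras 2 the init and nb_epoch arguments are named kernel_initializer and epochs.
        self.__model = Sequential()
        self.__model.add(Dense(dense_size, input_dim=train_x.shape[1], activation=activation, init='glorot_uniform'))
        self.__model.add(Dense(train_y.shape[1], activation='softmax', init='glorot_uniform'))
        # Use the learning_rate argument; the original hard-coded 0.001 and ignored it.
        self.__model.compile(optimizer=Adam(lr=learning_rate), loss='categorical_crossentropy', metrics=['categorical_accuracy'])

        # The held-out split is used as validation data during training.
        self.__model.fit(train_x, train_y, batch_size=batch_size, nb_epoch=epochs, validation_data=(test_x, test_y), verbose=2)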
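
For readers unfamiliar with the label encoding and loss used in train above, here is a minimal pure-Python sketch of what to_categorical(labels, 2), a softmax output, and categorical cross-entropy compute for a single sample. The helper names one_hot, softmax, and categorical_crossentropy are hypothetical and belong neither to the project nor to Keras. Keras's own versions are vectorized and differ in detail.

```python
import math

def one_hot(labels, num_classes):
    """Turn integer class labels into one-hot rows (illustrative stand-in for to_categorical)."""
    rows = []
    for label in labels:
        row = [0.0] * num_classes
        row[int(label)] = 1.0
        rows.append(row)
    return rows

def softmax(logits):
    """Map raw scores to probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Loss for one sample: -sum(true_i * log(pred_i)), with pred clipped away from 0."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

# Three labels from a two-class problem, as in the classifier above.
targets = one_hot([0, 1, 1], 2)

# Hypothetical raw scores that favour class 0.
probs = softmax([2.0, 0.0])

# The loss is small when the favoured class is the true one and larger otherwise.
loss_right = categorical_crossentropy(targets[0], probs)
loss_wrong = categorical_crossentropy(targets[1], probs)
```

Training lowers the average of this loss over each batch, which pushes the softmax output toward the one-hot target row.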