A Simple Example of Classification Algorithms in Python Using the sklearn Library


Posted in Python on July 09, 2018

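This article walks through a simple application of classification algorithms in Python using the sklearn library, shared here for reference. The details are as follows: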

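scikit-learn is included in the Anaconda distribution; it can also be installed from the source package on the official site. The code below wraps eight classifiers (multinomial naive Bayes, KNN, logistic regression, random forest, decision tree, SVM, SVM tuned by grid-search cross-validation, and GBDT), so you only need to change the data-loading function to benchmark all of them in one run: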

# -*- coding: utf-8 -*-
'''
Created on June 4, 2016
@author: bryan
'''
import time
import pickle
import pandas as pd
from sklearn import metrics
# Multinomial Naive Bayes Classifier
# Note: MultinomialNB requires non-negative feature values (e.g. counts or frequencies)
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
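# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
  # sklearn.grid_search was removed in scikit-learn 0.20; GridSearchCV now lives in sklearn.model_selection
  from sklearn.model_selection import GridSearchCV
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  # 7 values of C x 2 values of gamma = 14 candidate settings
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  # cv=3 reproduces the 3-fold search shown in the output (newer versions default to 5 folds)
  grid_search = GridSearchCV(model, param_grid, cv=3, n_jobs=1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in best_parameters.items():
    print(para, val)
  # Retrain on the full training set with the best C and gamma
  # (grid_search.best_estimator_ is already refit this way, since refit=True by default)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model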
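def read_data(data_file):
  # Expects a CSV file with a 'label' column; every other column is used as a feature
  data = pd.read_csv(data_file)
  # First 90% of rows for training, last 10% for testing (no shuffling, so row order matters)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]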
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"  # replace with the path to your own CSV file
  thresh = 0.5  # not used anywhere in this script
  model_save_file = None  # set to a file path to pickle all trained models
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM', 'SVMCV', 'GBDT']
  classifiers = {'NB': naive_bayes_classifier,
                 'KNN': knn_classifier,
                 'LR': logistic_regression_classifier,
                 'RF': random_forest_classifier,
                 'DT': decision_tree_classifier,
                 'SVM': svm_classifier,
                 'SVMCV': svm_cross_validation,
                 'GBDT': gradient_boosting_classifier,
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
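    if model_save_file is not None:
      model_save[classifier] = model
    # precision_score/recall_score default to average='binary' with pos_label=1,
    # so this assumes a binary label column whose positive class is 1
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)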
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file is not None:
    # Use a with-block so the file is closed after writing
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
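read_data splits the CSV in file order. If the rows happen to be sorted by label, the last 10% used for testing may contain only one class.

Below is a minimal sketch of a shuffled, stratified 90/10 split. It is written in plain Python so it runs without any particular scikit-learn version. The function name stratified_split and its parameters are hypothetical and not part of the original script. For DataFrames, sklearn.model_selection.train_test_split(X, y, test_size=0.1, stratify=y, random_state=0) does a comparable job.

```python
import random
from collections import defaultdict


def stratified_split(rows, label_of, train_fraction=0.9, seed=0):
    """Hypothetical helper: shuffle each class separately and keep
    train_fraction of it for training, so both sets contain every class."""
    by_label = defaultdict(list)
    for row in rows:
        by_label[label_of(row)].append(row)

    rng = random.Random(seed)  # fixed seed -> reproducible split
    train, test = [], []
    for label in sorted(by_label):  # sorted for a deterministic class order
        group = by_label[label]
        rng.shuffle(group)
        cut = int(len(group) * train_fraction)  # same rounding as read_data
        train.extend(group[:cut])
        test.extend(group[cut:])
    return train, test


# Toy data sorted by label: an ordered 90/10 slice would put only label 1 in the test set
data = [([i, 2 * i], 0 if i < 50 else 1) for i in range(100)]
train, test = stratified_split(data, label_of=lambda row: row[1])
print(len(train), len(test))  # 90 10
print(sorted(row[1] for row in test))  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
```

Splitting each class on its own is what guarantees both labels appear in the test set. A plain shuffle of all rows would usually, but not always, achieve the same.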

The test results are as follows:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
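For a binary 0/1 label, the three numbers printed for each model reduce to a few counts: true positives (TP), false positives (FP), false negatives (FN), and correct predictions. The formulas are precision = TP / (TP + FP), recall = TP / (TP + FN), and accuracy = correct / total.

Below is a minimal plain-Python sketch of what metrics.precision_score, metrics.recall_score and metrics.accuracy_score compute with their default binary settings (pos_label=1). binary_scores is a hypothetical name. The example labels are made-up counts, one set that happens to reproduce the KNN line above, not the author's actual test data.

```python
def binary_scores(y_true, y_pred, pos_label=1):
    """Hypothetical helper: precision, recall and accuracy for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if p == pos_label and t == pos_label)
    fp = sum(1 for t, p in pairs if p == pos_label and t != pos_label)
    fn = sum(1 for t, p in pairs if p != pos_label and t == pos_label)
    correct = sum(1 for t, p in pairs if t == p)
    # sklearn's default also returns 0.0 (with a warning) when a denominator is 0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = correct / len(pairs)
    return precision, recall, accuracy


# Made-up counts: 80 positives all found (TP=80, FN=0), 2 of 71 negatives wrongly flagged (FP=2)
y_true = [1] * 80 + [0] * 71
y_pred = [1] * 80 + [1] * 2 + [0] * 69
precision, recall, accuracy = binary_scores(y_true, y_pred)
print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))  # precision: 97.56%, recall: 100.00%
print('accuracy: %.2f%%' % (100 * accuracy))  # accuracy: 98.68%
```

This also shows why recall can reach 100% while precision does not: recall ignores false positives entirely.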

I hope this article helps you with your Python programming.
