Python基于sklearn库的分类算法简单应用示例


Posted in Python onJuly 09, 2018

本文实例讲述了Python基于sklearn库的分类算法简单应用。分享给大家供大家参考,具体如下:

scikit-learn已经包含在Anaconda中。也可以在官方下载源码包进行安装。本文代码里封装了如下机器学习算法,我们修改数据加载函数,即可一键测试:

# coding=gbk
'''
Created on 2016年6月4日
@author: bryan
'''
import time
from sklearn import metrics
import pickle as pickle
import pandas as pd
# Multinomial Naive Bayes Classifier
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
  from sklearn.grid_search import GridSearchCV
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
def read_data(data_file):
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT']
  classifiers = {'NB':naive_bayes_classifier,
         'KNN':knn_classifier,
          'LR':logistic_regression_classifier,
          'RF':random_forest_classifier,
          'DT':decision_tree_classifier,
         'SVM':svm_classifier,
        'SVMCV':svm_cross_validation,
         'GBDT':gradient_boosting_classifier
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file != None:
      model_save[classifier] = model
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file != None:
    pickle.dump(model_save, open(model_save_file, 'wb'))

测试结果如下:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python 遍历字符串(含汉字)实例详解
Apr 04 Python
python获取多线程及子线程的返回值
Nov 15 Python
Sublime开发python程序的示例代码
Jan 24 Python
pandas值替换方法
Jul 10 Python
Python计算开方、立方、圆周率,精确到小数点后任意位的方法
Jul 17 Python
详解【python】str与json类型转换
Apr 29 Python
在Pandas中处理NaN值的方法
Jun 25 Python
找Python安装目录,设置环境路径以及在命令行运行python脚本实例
Mar 09 Python
Python连接Mysql进行增删改查的示例代码
Aug 03 Python
python excel和yaml文件的读取封装
Jan 12 Python
python实现简单区块链结构
Apr 25 Python
Anaconda安装pytorch和paddle的方法步骤
Apr 03 Python
Python不使用int()函数把字符串转换为数字的方法
Jul 09 #Python
python中ASCII码和字符的转换方法
Jul 09 #Python
python中ASCII码字符与int之间的转换方法
Jul 09 #Python
Python 十六进制整数与ASCii编码字符串相互转换方法
Jul 09 #Python
python 以16进制打印输出的方法
Jul 09 #Python
python爬虫之urllib3的使用示例
Jul 09 #Python
机器学习之KNN算法原理及Python实现方法详解
Jul 09 #Python
You might like
PHP extract 将数组拆分成多个变量的函数
2010/06/30 PHP
利用Ffmpeg获得flv视频缩略图和视频时间的代码
2011/09/15 PHP
基于PHP导出Excel的小经验 完美解决乱码问题
2013/06/10 PHP
PHP生成压缩文件实例
2015/02/07 PHP
PHP实现将textarea的值根据回车换行拆分至数组
2015/06/10 PHP
游戏人文件夹程序 ver 4.03
2006/07/14 Javascript
jquery 上下滚动广告
2009/06/17 Javascript
nodejs的require模块(文件模块/核心模块)及路径介绍
2013/01/14 NodeJs
jquery $("#variable") 循环改变variable的值示例
2014/02/23 Javascript
jQuery调用ajax请求的常见方法汇总
2015/03/24 Javascript
JavaScript实现为指定对象添加多个事件处理程序的方法
2015/04/17 Javascript
jQuery实用技巧必备(下)
2015/11/03 Javascript
jquery关于事件冒泡和事件委托的技巧及阻止与允许事件冒泡的三种实现方法
2015/11/27 Javascript
Vue实现百度下拉提示搜索功能
2017/06/21 Javascript
Node 自动化部署的方法
2017/10/17 Javascript
vue 使用 vue-pdf 实现pdf在线预览的示例代码
2020/04/26 Javascript
原生js 实现表单验证功能
2021/02/08 Javascript
python使用PyGame播放Midi和Mp3文件的方法
2015/04/24 Python
Python中返回字典键的值的values()方法使用
2015/05/22 Python
Python基于win32ui模块创建弹出式菜单示例
2018/05/09 Python
Python之dict(或对象)与json之间的互相转化实例
2018/06/05 Python
scrapy-redis源码分析之发送POST请求详解
2019/05/15 Python
Python流程控制 while循环实现解析
2019/09/02 Python
Python脚本操作Excel实现批量替换功能
2019/11/20 Python
英国最大的网上药品商店:Chemist Direct
2017/12/16 全球购物
anello泰国官方网站:日本流行包包品牌
2019/08/08 全球购物
无故旷工检讨书
2014/01/26 职场文书
初中美术教学反思
2014/01/29 职场文书
公司投资建议书
2014/05/16 职场文书
消防安全承诺书
2014/05/22 职场文书
交通局领导班子群众路线教育实践活动对照检查材料思想汇报
2014/10/09 职场文书
焦裕禄纪念馆观后感
2015/06/09 职场文书
2016年六一文艺汇演开幕词
2016/03/04 职场文书
导游词之永济鹳雀楼
2020/01/16 职场文书
MySQL连表查询分组去重的实现示例
2021/07/01 MySQL
html中两种获取标签内的值的方法
2022/06/10 HTML / CSS