python sklearn常用分类算法模型的调用


Posted in Python onOctober 16, 2019

本文实例为大家分享了python sklearn分类算法模型调用的具体代码,供大家参考,具体内容如下

实现对'NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'模型的简单调用。

# coding=gbk
 
import time 
from sklearn import metrics 
import pickle as pickle 
import pandas as pd
 
 
# Multinomial Naive Bayes Classifier 
def naive_bayes_classifier(train_x, train_y): 
  from sklearn.naive_bayes import MultinomialNB 
  model = MultinomialNB(alpha=0.01) 
  model.fit(train_x, train_y) 
  return model 
 
 
# KNN Classifier 
def knn_classifier(train_x, train_y): 
  from sklearn.neighbors import KNeighborsClassifier 
  model = KNeighborsClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# Logistic Regression Classifier 
def logistic_regression_classifier(train_x, train_y): 
  from sklearn.linear_model import LogisticRegression 
  model = LogisticRegression(penalty='l2') 
  model.fit(train_x, train_y) 
  return model 
 
 
# Random Forest Classifier 
def random_forest_classifier(train_x, train_y): 
  from sklearn.ensemble import RandomForestClassifier 
  model = RandomForestClassifier(n_estimators=8) 
  model.fit(train_x, train_y) 
  return model 
 
 
# Decision Tree Classifier 
def decision_tree_classifier(train_x, train_y): 
  from sklearn import tree 
  model = tree.DecisionTreeClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# GBDT(Gradient Boosting Decision Tree) Classifier 
def gradient_boosting_classifier(train_x, train_y): 
  from sklearn.ensemble import GradientBoostingClassifier 
  model = GradientBoostingClassifier(n_estimators=200) 
  model.fit(train_x, train_y) 
  return model 
 
 
# SVM Classifier 
def svm_classifier(train_x, train_y): 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
# SVM Classifier using cross validation 
def svm_cross_validation(train_x, train_y): 
  from sklearn.grid_search import GridSearchCV 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]} 
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1) 
  grid_search.fit(train_x, train_y) 
  best_parameters = grid_search.best_estimator_.get_params() 
  for para, val in list(best_parameters.items()): 
    print(para, val) 
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
def read_data(data_file): 
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
   
if __name__ == '__main__': 
  data_file = "H:\\Research\\data\\trainCG.csv" 
  thresh = 0.5 
  model_save_file = None 
  model_save = {} 
  
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'] 
  classifiers = {'NB':naive_bayes_classifier,  
         'KNN':knn_classifier, 
          'LR':logistic_regression_classifier, 
          'RF':random_forest_classifier, 
          'DT':decision_tree_classifier, 
         'SVM':svm_classifier, 
        'SVMCV':svm_cross_validation, 
         'GBDT':gradient_boosting_classifier 
  } 
   
  print('reading training and testing data...') 
  train_x, train_y, test_x, test_y = read_data(data_file) 
   
  for classifier in test_classifiers: 
    print('******************* %s ********************' % classifier) 
    start_time = time.time() 
    model = classifiers[classifier](train_x, train_y) 
    print('training took %fs!' % (time.time() - start_time)) 
    predict = model.predict(test_x) 
    if model_save_file != None: 
      model_save[classifier] = model 
    precision = metrics.precision_score(test_y, predict) 
    recall = metrics.recall_score(test_y, predict) 
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall)) 
    accuracy = metrics.accuracy_score(test_y, predict) 
    print('accuracy: %.2f%%' % (100 * accuracy))  
 
  if model_save_file != None: 
    pickle.dump(model_save, open(model_save_file, 'wb'))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python登录QQ邮箱发信的实现代码
Feb 10 Python
Windows和Linux下使用Python访问SqlServer的方法介绍
Mar 10 Python
Python编程实现从字典中提取子集的方法分析
Feb 09 Python
Python基础教程之利用期物处理并发
Mar 29 Python
python 2.7 检测一个网页是否能正常访问的方法
Dec 26 Python
python实现从本地摄像头和网络摄像头截取图片功能
Jul 11 Python
Python Django Cookie 简单用法解析
Aug 13 Python
keras获得model中某一层的某一个Tensor的输出维度教程
Jan 24 Python
使用pyecharts1.7进行简单的可视化大全
May 17 Python
Django实现后台上传并显示图片功能
May 29 Python
python字典key不能是可以是啥类型
Aug 04 Python
利用Python实现最小二乘法与梯度下降算法
Feb 21 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 #Python
使用python实现kNN分类算法
Oct 16 #Python
Python生成验证码、计算具体日期是一年中的第几天实例代码详解
Oct 16 #Python
python可视化实现KNN算法
Oct 16 #Python
python实现KNN分类算法
Oct 16 #Python
python子线程退出及线程退出控制的代码
Oct 16 #Python
python Pillow图像处理方法汇总
Oct 16 #Python
You might like
PHP采集相关教程之一 CURL函数库
2010/02/15 PHP
PHP的cURL库功能简介 抓取网页、POST数据及其他
2011/04/07 PHP
PHP+FastCGI+Nginx配置PHP运行环境
2014/08/07 PHP
php获取指定范围内最接近数的方法
2015/06/02 PHP
Zend Framework教程之Zend_Config_Xml用法分析
2016/03/23 PHP
浅谈php fopen下载远程文件的函数
2016/11/18 PHP
PHP面向对象五大原则之依赖倒置原则(DIP)详解
2018/04/08 PHP
yii框架结合charjs实现统计30天数据的方法
2020/04/04 PHP
ModelDialog JavaScript模态对话框类代码
2011/04/17 Javascript
AngularJS语法详解(续)
2015/01/23 Javascript
JQUERY简单按钮轮换选中效果实现方法
2015/05/07 Javascript
JavaScript学习笔记之数组的增、删、改、查
2016/03/23 Javascript
原生JS实现图片轮播与淡入效果的简单实例
2016/08/21 Javascript
用iframe实现不刷新整个页面上传图片的实例
2016/11/18 Javascript
ES6概念 Symbol.keyFor()方法
2016/12/25 Javascript
JS实现选定指定HTML元素对象中指定文本内容功能示例
2017/02/13 Javascript
HTML的select控件美化
2017/03/27 Javascript
Vue分页组件实例代码
2017/04/17 Javascript
javascript中mouseenter与mouseover的异同
2017/06/06 Javascript
在 Typescript 中使用可被复用的 Vue Mixin功能
2018/04/17 Javascript
bootstrap模态框关闭后清除模态框的数据方法
2018/08/10 Javascript
详解如何在Node.js的httpServer中接收前端发送的arraybuffer数据
2018/11/11 Javascript
JS/HTML5游戏常用算法之碰撞检测 包围盒检测算法详解【凹多边形的分离轴检测算法】
2018/12/13 Javascript
关于Python中异常(Exception)的汇总
2017/01/18 Python
对Python信号处理模块signal详解
2019/01/09 Python
Python爬虫解析网页的4种方式实例及原理解析
2019/12/30 Python
pytorch:model.train和model.eval用法及区别详解
2020/02/20 Python
Python模拟登录和登录跳转的参考示例
2020/10/30 Python
CSS3简单实现照片墙
2014/12/12 HTML / CSS
KLOOK客路:发现更好玩的世界,预订独一无二的旅行体验
2016/12/16 全球购物
高中考试作弊检讨书
2014/01/14 职场文书
大型演出策划方案
2014/05/28 职场文书
原来闭幕词是这样写的呀!
2019/07/01 职场文书
Pytorch使用shuffle打乱数据的操作
2021/05/20 Python
星际争霸:毕姥爷vs解冻03
2022/04/01 星际争霸
讲解Python实例练习逆序输出字符串
2022/05/06 Python