python sklearn常用分类算法模型的调用


Posted in Python onOctober 16, 2019

本文实例为大家分享了python sklearn分类算法模型调用的具体代码,供大家参考,具体内容如下

实现对'NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'模型的简单调用。

# coding=gbk
 
import time 
from sklearn import metrics 
import pickle as pickle 
import pandas as pd
 
 
# Multinomial Naive Bayes Classifier 
def naive_bayes_classifier(train_x, train_y): 
  from sklearn.naive_bayes import MultinomialNB 
  model = MultinomialNB(alpha=0.01) 
  model.fit(train_x, train_y) 
  return model 
 
 
# KNN Classifier 
def knn_classifier(train_x, train_y): 
  from sklearn.neighbors import KNeighborsClassifier 
  model = KNeighborsClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# Logistic Regression Classifier 
def logistic_regression_classifier(train_x, train_y): 
  from sklearn.linear_model import LogisticRegression 
  model = LogisticRegression(penalty='l2') 
  model.fit(train_x, train_y) 
  return model 
 
 
# Random Forest Classifier 
def random_forest_classifier(train_x, train_y): 
  from sklearn.ensemble import RandomForestClassifier 
  model = RandomForestClassifier(n_estimators=8) 
  model.fit(train_x, train_y) 
  return model 
 
 
# Decision Tree Classifier 
def decision_tree_classifier(train_x, train_y): 
  from sklearn import tree 
  model = tree.DecisionTreeClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# GBDT(Gradient Boosting Decision Tree) Classifier 
def gradient_boosting_classifier(train_x, train_y): 
  from sklearn.ensemble import GradientBoostingClassifier 
  model = GradientBoostingClassifier(n_estimators=200) 
  model.fit(train_x, train_y) 
  return model 
 
 
# SVM Classifier 
def svm_classifier(train_x, train_y): 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
# SVM Classifier using cross validation 
def svm_cross_validation(train_x, train_y): 
  from sklearn.grid_search import GridSearchCV 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]} 
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1) 
  grid_search.fit(train_x, train_y) 
  best_parameters = grid_search.best_estimator_.get_params() 
  for para, val in list(best_parameters.items()): 
    print(para, val) 
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
def read_data(data_file): 
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
   
if __name__ == '__main__': 
  data_file = "H:\\Research\\data\\trainCG.csv" 
  thresh = 0.5 
  model_save_file = None 
  model_save = {} 
  
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'] 
  classifiers = {'NB':naive_bayes_classifier,  
         'KNN':knn_classifier, 
          'LR':logistic_regression_classifier, 
          'RF':random_forest_classifier, 
          'DT':decision_tree_classifier, 
         'SVM':svm_classifier, 
        'SVMCV':svm_cross_validation, 
         'GBDT':gradient_boosting_classifier 
  } 
   
  print('reading training and testing data...') 
  train_x, train_y, test_x, test_y = read_data(data_file) 
   
  for classifier in test_classifiers: 
    print('******************* %s ********************' % classifier) 
    start_time = time.time() 
    model = classifiers[classifier](train_x, train_y) 
    print('training took %fs!' % (time.time() - start_time)) 
    predict = model.predict(test_x) 
    if model_save_file != None: 
      model_save[classifier] = model 
    precision = metrics.precision_score(test_y, predict) 
    recall = metrics.recall_score(test_y, predict) 
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall)) 
    accuracy = metrics.accuracy_score(test_y, predict) 
    print('accuracy: %.2f%%' % (100 * accuracy))  
 
  if model_save_file != None: 
    pickle.dump(model_save, open(model_save_file, 'wb'))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现对excel文件列表值进行统计的方法
Jul 25 Python
Python网络爬虫出现乱码问题的解决方法
Jan 05 Python
Python实现PS滤镜碎片特效功能示例
Jan 24 Python
Python读取Word(.docx)正文信息的方法
Mar 15 Python
Django项目中用JS实现加载子页面并传值的方法
May 28 Python
浅谈Python中的全局锁(GIL)问题
Jan 11 Python
Python、 Pycharm、Django安装详细教程(图文)
Apr 12 Python
django 使用 PIL 压缩图片的例子
Aug 16 Python
python打包成so文件过程解析
Sep 28 Python
Django使用消息提示简单的弹出个对话框实例
Nov 15 Python
python标准库os库的函数介绍
Feb 12 Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 #Python
使用python实现kNN分类算法
Oct 16 #Python
Python生成验证码、计算具体日期是一年中的第几天实例代码详解
Oct 16 #Python
python可视化实现KNN算法
Oct 16 #Python
python实现KNN分类算法
Oct 16 #Python
python子线程退出及线程退出控制的代码
Oct 16 #Python
python Pillow图像处理方法汇总
Oct 16 #Python
You might like
php通过array_push()函数添加多个变量到数组末尾的方法
2015/03/18 PHP
php实现中文字符截取防乱码方法汇总
2015/04/29 PHP
php开发时容易忘记的一些技术细节
2016/02/03 PHP
PHP 数组基本操作小结(推荐)
2016/06/13 PHP
php基于mcrypt_encrypt和mcrypt_decrypt实现字符串加密解密的方法
2016/07/12 PHP
解释&&和||在javascript中的另类用法
2014/07/28 Javascript
js+html5实现canvas绘制简单矩形的方法
2015/06/05 Javascript
javascript实现动态标签云
2015/10/16 Javascript
jquery如何获取元素的滚动条高度等实现代码
2015/10/19 Javascript
javascript实现全角转半角的方法
2016/01/23 Javascript
JavaScript中输出信息的方法(信息确认框-提示输入框-文档流输出)
2016/06/12 Javascript
jQuery实现带右侧索引功能的通讯录示例【附源码下载】
2018/04/17 jQuery
vue实现自定义多选与单选的答题功能
2018/07/05 Javascript
JS实现的视频弹幕效果示例
2018/08/17 Javascript
Node.js系列之安装配置与基本使用(1)
2019/08/30 Javascript
Vue.js实现可编辑的表格
2019/12/11 Javascript
js实现简单放大镜效果
2020/03/07 Javascript
[02:21]DOTA2英雄基础教程 蝙蝠骑士
2013/12/16 DOTA
python 测试实现方法
2008/12/24 Python
python实现通过代理服务器访问远程url的方法
2015/04/29 Python
Python生成8位随机字符串的方法分析
2017/12/05 Python
python flask解析json数据不完整的解决方法
2019/05/26 Python
django多对多表的创建,级联删除及手动创建第三张表
2019/07/25 Python
利用matplotlib为图片上添加触发事件进行交互
2020/04/23 Python
Python 字符串池化的前提
2020/07/03 Python
CSS3 优势以及网页设计师如何使用CSS3技术
2009/07/29 HTML / CSS
HTML5中的Scoped属性使用实例
2014/04/23 HTML / CSS
澳大利亚的奢侈品牌:Oroton
2016/08/26 全球购物
介绍JAVA 中的Collection FrameWork(及如何写自己的数据结构)
2014/10/31 面试题
行政助理的职责
2013/11/14 职场文书
工地安全检查制度
2014/02/04 职场文书
《孔子拜师》教学反思
2014/02/24 职场文书
北体毕业生求职信
2014/02/28 职场文书
DSP接收机前端设想
2022/04/05 无线电
Axios代理配置及封装响应拦截处理方式
2022/04/07 Vue.js
在SQL Server中使用 Try Catch 处理异常的示例详解
2022/07/15 SQL Server