python多线程方法详解


Posted in Python onJanuary 18, 2022

处理多个数据和多文件时,使用for循环的速度非常慢,此时需要用多线程来加速运行进度,常用的模块为multiprocess和joblib,下面对两种包我常用的方法进行说明。

1、模块安装

pip install multiprocessing
pip install joblib

2、以分块计算NDVI为例

首先导入需要的包

import numpy as np
from osgeo import gdal
import time
from multiprocessing import cpu_count
from multiprocessing import Pool
from joblib import Parallel, delayed

定义GdalUtil类,以读取遥感数据

class GdalUtil:
    def __init__(self):
        pass
    @staticmethod
    def read_file(raster_file, read_band=None):
        """读取栅格数据"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
        # 打开输入图像
        dataset = gdal.Open(raster_file, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_file)
        # 列
        raster_width = dataset.RasterXSize
        # 行
        raster_height = dataset.RasterYSize
        # 读取数据
        if read_band == None:
            data_array = dataset.ReadAsArray(0, 0, raster_width, raster_height)
        else:
            band = dataset.GetRasterBand(read_band)
            data_array = band.ReadAsArray(0, 0, raster_width, raster_height)
        return data_array
 
    @staticmethod
    def read_block_data(dataset, band_num, cols_read, rows_read, start_col=0, start_row=0):
        band = dataset.GetRasterBand(band_num)
        res_data = band.ReadAsArray(start_col, start_row, cols_read, rows_read)
        return res_data
 
    @staticmethod
    def get_raster_band(raster_path):
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
        raster_band = dataset.RasterCount
        return raster_band
 
    @staticmethod
    def get_file_size(raster_path):
        """获取栅格仿射变换参数"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
 
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
        # 列
        raster_width = dataset.RasterXSize
        # 行
        raster_height = dataset.RasterYSize
        return raster_width, raster_height
 
    @staticmethod
    def get_file_geotransform(raster_path):
        """获取栅格仿射变换参数"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
 
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
 
        # 获取输入图像仿射变换参数
        input_geotransform = dataset.GetGeoTransform()
        return input_geotransform
 
    @staticmethod
    def get_file_proj(raster_path):
        """获取栅格图像空间参考"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
 
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
 
        # 获取输入图像空间参考
        input_project = dataset.GetProjection()
        return input_project
 
    @staticmethod
    def write_file(dataset, geotransform, project, output_path, out_format='GTiff', eType=gdal.GDT_Float32):
        """写入栅格"""
        if np.ndim(dataset) == 3:
            out_band, out_rows, out_cols = dataset.shape
        else:
            out_band = 1
            out_rows, out_cols = dataset.shape
 
        # 创建指定输出格式的驱动
        out_driver = gdal.GetDriverByName(out_format)
        if out_driver == None:
            print('格式%s 不支持Creat()方法.\n', out_format)
            return
 
        out_dataset = out_driver.Create(output_path, xsize=out_cols,
                                        ysize=out_rows, bands=out_band,
                                        eType=eType)
        # 设置输出图像的仿射参数
        out_dataset.SetGeoTransform(geotransform)
 
        # 设置输出图像的投影参数
        out_dataset.SetProjection(project)
 
        # 写出数据
        if out_band == 1:
            out_dataset.GetRasterBand(1).WriteArray(dataset)
        else:
            for i in range(out_band):
                out_dataset.GetRasterBand(i + 1).WriteArray(dataset[i])
        del out_dataset

定义计算NDVI的函数

def cal_ndvi(multi):
    '''
    计算高分NDVI
    :param multi:格式为列表,依次包含[遥感文件路径,开始行号,开始列号,待读的行数,待读的列数]
    :return: NDVI数组
    '''
    input_file, start_col, start_row, cols_step, rows_step = multi
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    nir_data = GdalUtil.read_block_data(dataset, 4, cols_step, rows_step, start_col=start_col, start_row=start_row)
    red_data = GdalUtil.read_block_data(dataset, 3, cols_step, rows_step, start_col=start_col, start_row=start_row)
    ndvi = (nir_data - red_data) / (nir_data + red_data)
    ndvi[(ndvi > 1.5) | (ndvi < -1)] = 0
    return ndvi
定义主函数
if __name__ == "__main__":
    input_file = r'D:\originalData\GF1\namucuo2021.tif'
    output_file = r'D:\originalData\GF1\namucuo2021_ndvi.tif'
    method = 'joblib'
    # method = 'multiprocessing'
    # 获取文件主要信息
    raster_cols, raster_rows = GdalUtil.get_file_size(input_file)
    geotransform = GdalUtil.get_file_geotransform(input_file)
    project = GdalUtil.get_file_proj(input_file)
    # 定义分块大小
    rows_block_size = 50
    cols_block_size = 50
    multi = []
    for j in range(0, raster_rows, rows_block_size):
        for i in range(0, raster_cols, cols_block_size):
            if j + rows_block_size < raster_rows:
                rows_step = rows_block_size
            else:
                rows_step = raster_rows - j
            # 数据横向步长
            if i + cols_block_size < raster_cols:
                cols_step = cols_block_size
            else:
                cols_step = raster_cols - i
            temp_multi = [input_file, i, j, cols_step, rows_step]
            multi.append(temp_multi)
 
    t1 = time.time()
    if method == 'multiprocessing':
        # multiprocessing方法
        pool = Pool(processes=cpu_count()-1)
        # 注意map函数中传入的参数应该是可迭代对象,如list;返回值为list
        res = pool.map(cal_ndvi, multi)
        pool.close()
        pool.join()
    else:
        # joblib方法
        res = Parallel(n_jobs=-1)(delayed(cal_ndvi)(input_list) for input_list in multi)
 
    t2 = time.time()
    print("Total time:" + (t2 - t1).__str__())
 
    # 将multiprocessing中的结果提取出来,放回对应的矩阵位置中
    out_data = np.zeros([raster_rows, raster_cols], dtype='float')
    for result, input_multi in zip(res, multi):
        start_col = input_multi[1]
        start_row = input_multi[2]
        cols_step = input_multi[3]
        rows_step = input_multi[4]
        out_data[start_row:start_row + rows_step, start_col:start_col + cols_step] = result
 
    GdalUtil.write_file(out_data, geotransform, project, output_file)

双重for循环时,两层for循环都使用multiprocessing时会报错,这时可以外层for循环使用joblib方法,内层for循环改为multiprocessing方法,不会报错

到此这篇关于python多线程方法详解的文章就介绍到这了,更多相关python多线程内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python和pyqt实现360的CLable控件
Feb 21 Python
浅析Python中将单词首字母大写的capitalize()方法
May 18 Python
pandas apply 函数 实现多进程的示例讲解
Apr 20 Python
python实现输入数字的连续加减方法
Jun 22 Python
Python之inspect模块实现获取加载模块路径的方法
Oct 16 Python
利用python和百度地图API实现数据地图标注的方法
May 13 Python
python导包的几种方法(自定义包的生成以及导入详解)
Jul 15 Python
8段用于数据清洗Python代码(小结)
Oct 31 Python
python GUI库图形界面开发之PyQt5美化窗体与控件(异形窗体)实例
Feb 25 Python
Python文本文件的合并操作方法代码实例
Mar 31 Python
Python基于gevent实现高并发代码实例
May 15 Python
python如何支持并发方法详解
Jul 25 Python
用Python生成会跳舞的美女
基于Pygame实现简单的贪吃蛇游戏
Dec 06 #Python
Python可变集合和不可变集合的构造方法大全
Dec 06 #Python
Python实现视频中添加音频工具详解
Dec 06 #Python
Python实现GIF动图以及视频卡通化详解
Python实现照片卡通化
用Python爬取英雄联盟的皮肤详细示例
You might like
php支持断点续传、分块下载的类
2016/05/02 PHP
PHP微信公众号自动发送红包API
2016/06/01 PHP
PHP面向对象程序设计OOP继承用法入门示例
2016/12/27 PHP
javascript 一个自定义长度的文本自动换行的函数
2007/08/19 Javascript
扩展jQuery 键盘事件的几个基本方法
2009/10/30 Javascript
jQuery学习笔记之jQuery的DOM操作
2010/12/22 Javascript
让你的博文自动带上缩址的实现代码,方便发到微博客上
2010/12/28 Javascript
ToolTips JQEURY插件之简洁小提示框效果
2011/11/19 Javascript
javascript在当前窗口关闭前检测窗口是否关闭
2014/09/29 Javascript
Bootstrap和Java分页实例第一篇
2016/12/23 Javascript
jQuery EasyUI Panel面板组件使用详解
2017/02/28 Javascript
jQuery插件FusionCharts绘制的3D饼状图效果实例【附demo源码下载】
2017/03/03 Javascript
JS简单获取当前日期和农历日期的方法
2017/04/17 Javascript
js实现音乐播放控制条
2017/09/09 Javascript
jQuery+CSS实现的table表格行列转置功能示例
2018/01/08 jQuery
解决bootstrap-select 动态加载数据不显示的问题
2018/08/10 Javascript
微信小程序实现图片上传
2019/05/23 Javascript
通过javascript实现段落的收缩与展开
2019/06/26 Javascript
[46:25]DOTA2上海特级锦标赛主赛事日 - 4 败者组第五轮 MVP.Phx VS EG第二局
2016/03/05 DOTA
python实现逻辑回归的方法示例
2017/05/02 Python
深入理解Django自定义信号(signals)
2018/10/15 Python
pyqt5之将textBrowser的内容写入txt文档的方法
2019/06/21 Python
python3 自动打印出最新版本执行的mysql2redis实例
2020/04/09 Python
jupyter notebook 多环境conda kernel配置方式
2020/04/10 Python
设置jupyter中DataFrame的显示限制方式
2020/04/12 Python
HTML5制作酷炫音频播放器插件图文教程
2014/12/30 HTML / CSS
HTML5 DeviceOrientation实现手机网站摇一摇功能代码实例
2015/04/24 HTML / CSS
Ryderwear澳洲官网:澳大利亚高端健身训练装备品牌
2018/09/18 全球购物
AVI-8手表美国官方商店:AVI-8 USA
2019/04/10 全球购物
2014年党风建设工作总结
2014/11/19 职场文书
python 自动刷新网页的两种方法
2021/04/20 Python
Django显示可视化图表的实践
2021/05/10 Python
基于Redis延迟队列的实现代码
2021/05/13 Redis
MySQL8.0的WITH查询详情
2021/08/30 MySQL
Python 全局空间和局部空间
2022/04/06 Python
Mysql排查分析慢sql之explain实战案例
2022/04/19 MySQL