python多线程方法详解


Posted in Python onJanuary 18, 2022

处理多个数据和多文件时,使用for循环的速度非常慢,此时需要用多线程来加速运行进度,常用的模块为multiprocess和joblib,下面对两种包我常用的方法进行说明。

1、模块安装

pip install multiprocessing
pip install joblib

2、以分块计算NDVI为例

首先导入需要的包

import numpy as np
from osgeo import gdal
import time
from multiprocessing import cpu_count
from multiprocessing import Pool
from joblib import Parallel, delayed

定义GdalUtil类,以读取遥感数据

class GdalUtil:
    def __init__(self):
        pass
    @staticmethod
    def read_file(raster_file, read_band=None):
        """读取栅格数据"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
        # 打开输入图像
        dataset = gdal.Open(raster_file, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_file)
        # 列
        raster_width = dataset.RasterXSize
        # 行
        raster_height = dataset.RasterYSize
        # 读取数据
        if read_band == None:
            data_array = dataset.ReadAsArray(0, 0, raster_width, raster_height)
        else:
            band = dataset.GetRasterBand(read_band)
            data_array = band.ReadAsArray(0, 0, raster_width, raster_height)
        return data_array
 
    @staticmethod
    def read_block_data(dataset, band_num, cols_read, rows_read, start_col=0, start_row=0):
        band = dataset.GetRasterBand(band_num)
        res_data = band.ReadAsArray(start_col, start_row, cols_read, rows_read)
        return res_data
 
    @staticmethod
    def get_raster_band(raster_path):
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
        raster_band = dataset.RasterCount
        return raster_band
 
    @staticmethod
    def get_file_size(raster_path):
        """获取栅格仿射变换参数"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
 
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
        # 列
        raster_width = dataset.RasterXSize
        # 行
        raster_height = dataset.RasterYSize
        return raster_width, raster_height
 
    @staticmethod
    def get_file_geotransform(raster_path):
        """获取栅格仿射变换参数"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
 
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
 
        # 获取输入图像仿射变换参数
        input_geotransform = dataset.GetGeoTransform()
        return input_geotransform
 
    @staticmethod
    def get_file_proj(raster_path):
        """获取栅格图像空间参考"""
        # 注册栅格驱动
        gdal.AllRegister()
        gdal.SetConfigOption('gdal_FILENAME_IS_UTF8', 'YES')
 
        # 打开输入图像
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset == None:
            print('打开图像{0} 失败.\n', raster_path)
 
        # 获取输入图像空间参考
        input_project = dataset.GetProjection()
        return input_project
 
    @staticmethod
    def write_file(dataset, geotransform, project, output_path, out_format='GTiff', eType=gdal.GDT_Float32):
        """写入栅格"""
        if np.ndim(dataset) == 3:
            out_band, out_rows, out_cols = dataset.shape
        else:
            out_band = 1
            out_rows, out_cols = dataset.shape
 
        # 创建指定输出格式的驱动
        out_driver = gdal.GetDriverByName(out_format)
        if out_driver == None:
            print('格式%s 不支持Creat()方法.\n', out_format)
            return
 
        out_dataset = out_driver.Create(output_path, xsize=out_cols,
                                        ysize=out_rows, bands=out_band,
                                        eType=eType)
        # 设置输出图像的仿射参数
        out_dataset.SetGeoTransform(geotransform)
 
        # 设置输出图像的投影参数
        out_dataset.SetProjection(project)
 
        # 写出数据
        if out_band == 1:
            out_dataset.GetRasterBand(1).WriteArray(dataset)
        else:
            for i in range(out_band):
                out_dataset.GetRasterBand(i + 1).WriteArray(dataset[i])
        del out_dataset

定义计算NDVI的函数

def cal_ndvi(multi):
    '''
    计算高分NDVI
    :param multi:格式为列表,依次包含[遥感文件路径,开始行号,开始列号,待读的行数,待读的列数]
    :return: NDVI数组
    '''
    input_file, start_col, start_row, cols_step, rows_step = multi
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    nir_data = GdalUtil.read_block_data(dataset, 4, cols_step, rows_step, start_col=start_col, start_row=start_row)
    red_data = GdalUtil.read_block_data(dataset, 3, cols_step, rows_step, start_col=start_col, start_row=start_row)
    ndvi = (nir_data - red_data) / (nir_data + red_data)
    ndvi[(ndvi > 1.5) | (ndvi < -1)] = 0
    return ndvi
定义主函数
if __name__ == "__main__":
    input_file = r'D:\originalData\GF1\namucuo2021.tif'
    output_file = r'D:\originalData\GF1\namucuo2021_ndvi.tif'
    method = 'joblib'
    # method = 'multiprocessing'
    # 获取文件主要信息
    raster_cols, raster_rows = GdalUtil.get_file_size(input_file)
    geotransform = GdalUtil.get_file_geotransform(input_file)
    project = GdalUtil.get_file_proj(input_file)
    # 定义分块大小
    rows_block_size = 50
    cols_block_size = 50
    multi = []
    for j in range(0, raster_rows, rows_block_size):
        for i in range(0, raster_cols, cols_block_size):
            if j + rows_block_size < raster_rows:
                rows_step = rows_block_size
            else:
                rows_step = raster_rows - j
            # 数据横向步长
            if i + cols_block_size < raster_cols:
                cols_step = cols_block_size
            else:
                cols_step = raster_cols - i
            temp_multi = [input_file, i, j, cols_step, rows_step]
            multi.append(temp_multi)
 
    t1 = time.time()
    if method == 'multiprocessing':
        # multiprocessing方法
        pool = Pool(processes=cpu_count()-1)
        # 注意map函数中传入的参数应该是可迭代对象,如list;返回值为list
        res = pool.map(cal_ndvi, multi)
        pool.close()
        pool.join()
    else:
        # joblib方法
        res = Parallel(n_jobs=-1)(delayed(cal_ndvi)(input_list) for input_list in multi)
 
    t2 = time.time()
    print("Total time:" + (t2 - t1).__str__())
 
    # 将multiprocessing中的结果提取出来,放回对应的矩阵位置中
    out_data = np.zeros([raster_rows, raster_cols], dtype='float')
    for result, input_multi in zip(res, multi):
        start_col = input_multi[1]
        start_row = input_multi[2]
        cols_step = input_multi[3]
        rows_step = input_multi[4]
        out_data[start_row:start_row + rows_step, start_col:start_col + cols_step] = result
 
    GdalUtil.write_file(out_data, geotransform, project, output_file)

双重for循环时,两层for循环都使用multiprocessing时会报错,这时可以外层for循环使用joblib方法,内层for循环改为multiprocessing方法,不会报错

到此这篇关于python多线程方法详解的文章就介绍到这了,更多相关python多线程内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现定时播放mp3
Mar 29 Python
使用Python的Tornado框架实现一个简单的WebQQ机器人
Apr 24 Python
Python实现把数字转换成中文
Jun 29 Python
讲解Python的Scrapy爬虫框架使用代理进行采集的方法
Feb 18 Python
详解Python中的文件操作
Aug 28 Python
python开发环境PyScripter中文乱码问题解决方案
Sep 11 Python
使用python 爬虫抓站的一些技巧总结
Jan 10 Python
TensorFlow saver指定变量的存取
Mar 10 Python
将pandas.dataframe的数据写入到文件中的方法
Dec 07 Python
详解python函数的闭包问题(内部函数与外部函数详述)
May 17 Python
opencv python 图片读取与显示图片窗口未响应问题的解决
Apr 24 Python
如何使用Python提取Chrome浏览器保存的密码
Jun 09 Python
用Python生成会跳舞的美女
基于Pygame实现简单的贪吃蛇游戏
Dec 06 #Python
Python可变集合和不可变集合的构造方法大全
Dec 06 #Python
Python实现视频中添加音频工具详解
Dec 06 #Python
Python实现GIF动图以及视频卡通化详解
Python实现照片卡通化
用Python爬取英雄联盟的皮肤详细示例
You might like
ThinkPHP自动验证失败的解决方法
2011/06/09 PHP
PHP操作数组的一些函数整理介绍
2011/07/17 PHP
高性能PHP框架Symfony2经典入门教程
2014/07/08 PHP
PHP中把数据库查询结果输出为json格式简单实例
2015/04/09 PHP
CodeMirror2 IE7/IE8 下面未知运行时错误的解决方法
2012/03/29 Javascript
gridview生成时如何去掉style属性中的border-collapse
2014/09/30 Javascript
js中取得变量绝对值的方法
2015/01/03 Javascript
javascript中的正则表达式使用指南
2015/03/01 Javascript
JS实现漂亮的窗口拖拽效果(可改变大小、最大化、最小化、关闭)
2015/10/10 Javascript
js验证身份证号有效性并提示对应信息
2015/10/19 Javascript
常用原生JS兼容性写法汇总
2016/04/27 Javascript
jquery判断对象是否为空并遍历对象的简单实例
2016/07/26 Javascript
jQuery实现的导航下拉菜单效果示例
2016/09/05 Javascript
浅谈JS使用[ ]来访问对象属性
2016/09/21 Javascript
js中toString()和String()区别详解
2017/03/23 Javascript
JavaScript正则表达式和级联效果
2017/09/14 Javascript
Vue2.0 实现歌手列表滚动及右侧快速入口功能
2018/08/08 Javascript
微信小程序五子棋游戏的棋盘,重置,对弈实现方法【附demo源码下载】
2019/02/20 Javascript
使用RxJS更优雅地进行定时请求详析
2019/06/02 Javascript
vue el-table实现自定义表头
2019/12/11 Javascript
Vue强制组件重新渲染的方法讨论
2020/02/03 Javascript
如何使用vue slot创建一个模态框的实例代码
2020/05/24 Javascript
详解JavaScript执行模型
2020/11/16 Javascript
[01:09]DOTA2次级职业联赛 - 99战队宣传片
2014/12/01 DOTA
[02:12]2019完美世界全国高校联赛(春季赛)报名开启
2019/03/01 DOTA
Python urllib模块urlopen()与urlretrieve()详解
2013/11/01 Python
深入理解Django自定义信号(signals)
2018/10/15 Python
python @propert装饰器使用方法原理解析
2019/12/25 Python
Window版下在Jupyter中编写TensorFlow的环境搭建
2020/04/10 Python
洗车工岗位职责
2014/03/15 职场文书
三月雷锋月活动总结
2014/07/03 职场文书
ktv周年庆活动方案
2014/08/18 职场文书
公司的门卫岗位职责
2014/09/09 职场文书
2014第二批党的群众路线教育实践活动对照检查材料思想汇报
2014/09/18 职场文书
2015年学校保卫部工作总结
2015/05/11 职场文书
详解缓存穿透击穿雪崩解决方案
2021/05/28 Redis