pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用PyHook监听鼠标和键盘事件实例
Jul 18 Python
python+Django+apache的配置方法详解
Jun 01 Python
详解flask入门模板引擎
Jul 18 Python
python通过zabbix api获取主机
Sep 17 Python
python读取文本中的坐标方法
Oct 14 Python
解决pycharm运行时interpreter为空的问题
Oct 29 Python
使用python对多个txt文件中的数据进行筛选的方法
Jul 10 Python
Python线上环境使用日志的及配置文件
Jul 28 Python
windows10在visual studio2019下配置使用openCV4.3.0
Jul 14 Python
2021年的Python 时间轴和即将推出的功能详解
Jul 27 Python
python 决策树算法的实现
Oct 09 Python
opencv-python图像配准(匹配和叠加)的实现
Jun 23 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
dede全站URL静态化改造[070414更正]
2007/04/17 PHP
用穿越火线快速入门php面向对象
2012/02/22 PHP
PHP读取数据库并按照中文名称进行排序实现代码
2013/01/29 PHP
PHP利用MySQL保存session的实现思路及示例代码
2014/09/09 PHP
Laravel框架运行出错提示RuntimeException No application encryption key has been specified.解决方法
2019/04/02 PHP
微信JSSDK分享功能图文实例详解
2019/04/08 PHP
JavaScript语句可以不以;结尾的烦恼
2007/03/08 Javascript
setTimeout和setInterval的区别你真的了解吗?
2011/03/31 Javascript
ExtJS判断IE浏览器类型的方法
2014/02/10 Javascript
ECMAScript5中的对象存取器属性:getter和setter介绍
2014/12/08 Javascript
JavaScript弹出新窗口后向父窗口输出内容的方法
2015/04/06 Javascript
基于HTML5上使用iScroll实现下拉刷新,上拉加载更多
2016/05/21 Javascript
Bootstrap modal 多弹窗之叠加显示不出弹窗问题的解决方案
2017/02/23 Javascript
js按条件生成随机json:randomjson实现方法
2017/04/07 Javascript
nuxt.js中间件实现拦截权限判断的方法
2018/11/21 Javascript
JavaScript学习笔记之DOM操作实例分析
2019/01/08 Javascript
JS/jQuery实现超简单的Table表格添加,删除行功能示例
2019/07/31 jQuery
vue输入节流,避免实时请求接口的实例代码
2019/10/30 Javascript
微信小程序淘宝首页双排图片布局排版代码(推荐)
2020/10/29 Javascript
[01:00:26]Ti4主赛事胜者组第一天 EG vs NEWBEE 1
2014/07/19 DOTA
python实现将内容分行输出
2015/11/05 Python
Python中矩阵库Numpy基本操作详解
2017/11/21 Python
解决nohup重定向python输出到文件不成功的问题
2018/05/11 Python
python实现随机漫步算法
2018/08/27 Python
Python GUI布局尺寸适配方法
2018/10/11 Python
python 梯度法求解函数极值的实例
2019/07/10 Python
Python3创建Django项目的几种方法(3种)
2020/06/03 Python
HTML5 虚拟键盘出现挡住输入框的解决办法
2017/02/14 HTML / CSS
巴西最大的玩具连锁店:Ri Happy
2020/06/17 全球购物
27个经典Linux面试题及答案,你知道几个?
2013/01/10 面试题
采购主管的岗位职责
2013/12/17 职场文书
化学专业自荐信
2014/05/28 职场文书
幼儿园秋季开学寄语
2014/08/02 职场文书
教师学习群众路线心得体会
2014/11/04 职场文书
2014年卫生保健工作总结
2014/12/08 职场文书
Python中Cookies导出某站用户数据的方法
2021/05/17 Python