pytorch中使用cuda扩展的实现示例


Posted in Python onFebruary 12, 2020

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的内存泄漏及gc模块的使用分析
Jul 16 Python
Python isinstance函数介绍
Apr 14 Python
详解使用 pyenv 管理多个版本 python 环境
Oct 19 Python
python中的迭代和可迭代对象代码示例
Dec 27 Python
python爬虫爬取某站上海租房图片
Feb 04 Python
python进行TCP端口扫描的实现
Dec 21 Python
Python设置matplotlib.plot的坐标轴刻度间隔以及刻度范围
Jun 25 Python
pyQt5实时刷新界面的示例
Jun 25 Python
使用python实现男神女神颜值打分系统(推荐)
Oct 31 Python
python不使用for计算两组、多个矩形两两间的iou方式
Jan 18 Python
Python爬虫工具requests-html使用解析
Apr 29 Python
python批量提取图片信息并保存的实现
Feb 05 Python
pycharm内无法import已安装的模块问题解决
Feb 12 #Python
PyTorch笔记之scatter()函数的使用
Feb 12 #Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 #Python
pycharm无法导入本地模块的解决方式
Feb 12 #Python
解决pycharm中导入自己写的.py函数出错问题
Feb 12 #Python
解决pycharm同一目录下无法import其他文件
Feb 12 #Python
适合Python初学者的一些编程技巧
Feb 12 #Python
You might like
php实现根据字符串生成对应数组的方法
2014/09/22 PHP
如何让CI框架支持service层
2014/10/29 PHP
一个完整的PHP类包含的七种语法说明
2015/06/04 PHP
PHP生成随机字符串(3种方法)
2015/09/25 PHP
ThinkPHP模板循环输出Volist标签用法实例详解
2016/03/23 PHP
如何直接访问php实例对象中的private属性详解
2017/10/12 PHP
jQuery EasyUI API 中文文档 - Tree树使用介绍
2011/11/19 Javascript
JS实现清除指定cookies的方法
2014/09/20 Javascript
jquery+css实现的红色线条横向二级菜单效果
2015/08/22 Javascript
如何利用JQuery实现从底部回到顶部的功能
2016/12/27 Javascript
jQuery.Validate表单验证插件的使用示例详解
2017/01/04 Javascript
Vue.js自定义指令的用法与实例解析
2017/01/18 Javascript
jquery表单验证实例仿Toast提示效果
2017/03/03 Javascript
js判断PC端与移动端跳转
2020/12/24 Javascript
微信小程序 实现点击添加移除class
2017/06/12 Javascript
ES6扩展运算符用法实例分析
2017/10/31 Javascript
全站最详细的Vuex教程
2018/04/13 Javascript
ng-repeat指令在迭代对象时的去重方法
2018/10/02 Javascript
基于Fixed定位的框选功能的实现代码
2019/05/13 Javascript
小程序封装路由文件和路由方法(5种全解析)
2019/05/26 Javascript
vue 动态创建组件的两种方法
2020/12/31 Vue.js
python动态加载包的方法小结
2016/04/18 Python
Python标准库06之子进程 (subprocess包) 详解
2016/12/07 Python
Python中input与raw_input 之间的比较
2017/08/20 Python
python多进程实现进程间通信实例
2017/11/24 Python
IntelliJ IDEA安装运行python插件方法
2018/12/10 Python
python实现两个字典合并,两个list合并
2019/12/02 Python
法国春天百货官网:Printemps.com
2020/06/29 全球购物
美国轻奢时尚购物网站:REVOLVE(支持中文)
2020/07/18 全球购物
Java如何支持I18N?
2016/10/31 面试题
幼儿园社区活动总结
2014/07/07 职场文书
重阳节活动总结
2014/08/27 职场文书
大学生党校培训心得体会
2014/09/11 职场文书
庆祝国庆节演讲稿2014
2014/09/19 职场文书
2015年财务经理工作总结
2015/05/13 职场文书
检举信的写法
2019/04/10 职场文书