您的位置 首页 PyTorch 教程

PyTorch 源码浅析(三)

PyTorch入门实战教程

THC——基于CUDA的张量代数库

THC里实际上已经开始使用C++了,并不是一个像TH一样的纯C的库。这可能也是个人更偏爱mshadow的原因,有些地方写得像C有些地方像C++就多少有些觉得丑。

不过这里还是要先介绍一些C++里常见的操作,它们在PyTorch里作为基本的工具(utils)出现。

利用模板提取类型信息

在C++里利用模板做泛型编程时,我们常常会遇到这种需要获得一些类型信息的问题。虽然默认的标准库里有type_trait但是有时候还是需要自己实现一些相关的trait。以下用THC/THCTensorTypeUtils.cuh中的SameType进行说明,CUDA C++中大部分语法和C++是一样的,并且目前新版本的nvcc也是支持C++11标准的,所以可以放心使用一些你喜欢的特性。

SameType的作用是判断两个类型是否相同,我们可以借助C++语言中类型的偏特化来实现(注意这个不能用在函数上),C++的偏特化会产生几个不同模板参数,但是成员不同,名称相同的类,我们可以用静态类型来记录类型的信息,将判断类型是否相等的操作交给模板。

首先是一般的情况,两个类型可能不相等,这个时候让其静态成员same为false,然后再用偏特化的版本

C++的编译器会优先匹配偏特化的类型,这样当你填入模板中的两个类型相同时,就会匹配到这个类型然后得到same的值为true

类似地,在THC/THCTensorTypeUtils.cuh中实现了另外一个用来获取Tensor类型信息的类TensorInfo。由于THC并非一个使用模板提供泛型的库,它依然采用了TH中用宏产生泛型的方法,这里需要手动填入类型名称,PyTorch选择了用宏TENSOR_UTILS来填入类型名称。

这种方法在THC里非常常见,这主要是因为不同类型的具体实现可能会有所不同(比如16位浮点类型)而需要根据不同类型去写其函数。实际上在mshadow中,其expression template(表达式模板,一种利用C++模板实现懒惰求值的方法,详见mshadow的README,讲得非常详细)引擎也是通过类似的方式来添加算符的。

但是不同于内建的浮点类型和整数类型,CUDA中复数类型有float2,cuComplex (实际上就是float2),thrust::complex三种选择,同时还需要与TH中使用的C中的复数类型_Complex,以及胶水部分的std::complex进行转换,使得在THNumerics等部分的实现变得比CPU更加麻烦,有时候需要自己去实现复数的基本数学函数等,还需要小心使用reinterpret_cast。此外后面会讲到上一篇最后提到的reduce操作在v0.3中所用到的trick不能在复数上直接应用,这是因为在后来PyTorch在THCTensorMathReduce使用了__shfl_xxx系列的函数来提供更快的reduction等操作(具体的方法见上一篇给出的NVIDIA的那个ppt),但是这个不支持float2,得想办法自己workaround,这也导致PZ的那个complex版本虽然有CUDA支持,但是一直停留在v0.1。

BLAS和Magma

不同于CPU,在GPU上提供BLAS和线性代数算法支持的库是cuBLAS和Magma,PyTorch会在编译时检测到这两个库的时候使用它们但是在没有检测到的时候会调用默认实现,会慢。其支持的功能和CPU上的Lapack是一致的。

Reduce

我们还是通过一些具体的例子来看它是怎么工作的,比如在 THC/THCTensorMathReduce.cu 中的sum函数,它的作用是对某个维度上的元素求和。

实际上可以看到除了处理多GPU的THCAssertSameGPU实现几乎和CPU中的类似,也是用一个reduceDim函数将某个操作遍历这个维度上的所有元素在TH中对应的宏为TH_TENSOR_APPLY_DIM系列,但不同的是由于现在reduce系列的操作由函数实现,不能简单地利用宏的文本替换将需要做的操作简单放入,这里的操作都被定义为了类型的静态成员,包括类型转换。

然后我们看到这里使用的函数为THC_reduceDim 我们尝试去找一下它的定义,在源码目录下用

发现它被定义在 THC/THCReduce.cuh的258行(可能因版本有所不同)。

类似于CPU中的tensor apply,这里也会因内存是否连续而进行优化,思路是类似的,对连续内存尽量合并为一个维度进行遍历,只是采用了GPU加速。

其余的部分和TH很类似,感觉不用介绍了。

源码浅析系列目录

PyTorch 源码浅析(一)

PyTorch 源码浅析(二)

PyTorch 源码浅析(三)

PyTorch 源码浅析(四)

文章来源:罗秀哲知乎专栏

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部