您的位置 首页 PyTorch 教程

如何有效地阅读 PyTorch 的源代码?

PyTorch入门实战教程

本文作者罗秀哲,来源知乎

PyTorch的模块化感觉很好,读某一个部分不需要很清楚其它部分的具体实现。不过复数是大部分文件里都需要检查一下要不要改的…

读底层代码之前

你需要简单了解以下接口(知道在哪里查文档就好)

  1. BLAS/LAPACK的接口,这个各个BLAS库都或多或少有文档来描述,但是个人认为其中intel mkl的最好。有时候还是要参考一下最早的这个BLAS (Basic Linear Algebra Subprograms) 但是注意PyTorch不是调用的CBLAS的接口,而是直接调用Fortran的函数所以所有的变量都是通过指针传递,并且相应的routine后面有一个下划线_
  2. CUBLAS/Thrust这个是CUDA的BLAS部分会用到。这个文档直接读NVIDIA官方的就好了。
  3. 相关的数学知识

CPU张量运算库TH

这个部分是PyTorch里最底层的部分,但是也是我觉得写的最好的部分。感觉读懂这一部分其他底层的代码就都很容易读了。它用宏实现了C里的泛型支持,然后用宏+命名实现了简单的OOP。甚至在管理文件的部分还有virtual table。

这个最早是从Lua版本的Torch迁移过来的,那会儿似乎还没有完整的SIMD。现在提供了完整的SIMD支持。数据结构部分主要实现了THStorageTHTensor两个。这两个是干嘛用的在PyTorch的Python文档里都有写。其实比起我们物理这边用的Tensor Network的实现要简单,因为不用连腿….

THTensor不负责存储,是一种查看存储对象(THStorage)的方法。感觉实现思路上和MXNet的backend,mshadow里的Tensor有一些类似的地方。不过因为C不支持重载,所以这里完全没有懒惰求值,当然用的时候也许会有些麻烦。

这里比较有特点的地方是用宏实现的泛型。这个方法在底层的库里有广泛的使用。等到lib目录上层的代码能够使用THPP这个支持C++模板的封装以后就不会见到了。(THPP和facebook/THPP那个不一样,后者是一个C++的TH实现)。

TH主要通过宏替换,比如real(这个名字在我改成复数的时候可别扭了…然后还得小心别和cpp里的std::real一个不小心重名,然后被替换成std::double之类的东西…)真正用来实现的需要泛型的源代码写在generic目录下,然后在外面通过几个头文件里定义的宏进行不同变量名称的替换。

然后还有一个比较难读的宏应该就是TH_TENSOR_APPLY这个开头的几个宏函数了。这几个比较像是在现在更上层语言里用到的map,或者broadcast。也就是把一段操作应用到每个Tensor的元素上去。为了支持不同个数的Tensor之间的运算,有好几个不同个数Tensor的。比如TH_TENSOR_APPLY_3就是说三个Tensor里要怎么做element-wise运算。然后需要注意的是,在宏函数里输入的操作部分也就是CODE部分,一定不要出现逗号,和内嵌的宏。

然后TensorMath和Conv等部分的函数都和Python函数功能类似,不懂直接看Python文档就行了。

CPU神经网络库THNN

这个部分就很简单了,就是一些Python层的Function对象在底层的定义和实现。不懂的就直接查Python文档,都有对应的。

自动微分部分

这个是在C++里写的,然后接了Python。C++里就是实现了三个list用来存Variable啊什么的…大致思路应该是把使用过forward的函数记下来,然后后面调用backward来算梯度。所以叫Tape-based

剩下的大部分就是Python和一些glue代码了。Python的文档很全,更好读,直接看就行。

不太有空更新,如果有空我会更的…长期待续…不过其实你要是读了TH部分,基本感觉别的也没什么好说的了。

最近在给PyTorch增加对复数的支持(其实我也不想写代码,然而没人给造轮子啊…之前马普的一个哥们儿已经改了一些CPU和CUDA的部分,但是问题和没改的还很多)所以大概在两三个月前开始读了从上层的Python到底层C和CUDA的实现。感觉只有一开始在C的部分稍微需要动一下脑子,在向量化那一块可能还需要学习一下SIMD/SSE的指令,其它地方就都很容易读。

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部