
PyTorch Cheat Sheet for NumPy Users

PyTorch Hands-On Tutorial for Beginners

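This is a getting-started guide to PyTorch for NumPy users. If you are already comfortable with NumPy, the tables below will help you quickly find the corresponding PyTorch data types and operations.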

Types

| NumPy | PyTorch |
| --- | --- |
| `np.ndarray` | `torch.Tensor` |
| `np.float32` | `torch.float32`; `torch.float` |
| `np.float64` | `torch.float64`; `torch.double` |
| `np.float16` | `torch.float16`; `torch.half` |
| `np.int8` | `torch.int8` |
| `np.uint8` | `torch.uint8` |
| `np.int16` | `torch.int16`; `torch.short` |
| `np.int32` | `torch.int32`; `torch.int` |
| `np.int64` | `torch.int64`; `torch.long` |
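A quick way to check the dtype mapping above is to move data between the two libraries. The sketch below assumes NumPy and PyTorch are installed and that the default PyTorch dtype has not been changed with `torch.set_default_dtype`; it also shows one difference worth remembering: float literals default to `float64` in NumPy but `float32` in PyTorch.

```python
import numpy as np
import torch

# torch.from_numpy keeps the NumPy dtype (and shares memory with the array)
a = np.zeros(3, dtype=np.int16)
t = torch.from_numpy(a)
print(t.dtype)  # torch.int16

# The short names in the table are aliases for the same dtypes
print(torch.float == torch.float32, torch.long == torch.int64)  # True True

# The default floating-point dtype differs between the two libraries
print(np.array([1.0]).dtype)      # float64
print(torch.tensor([1.0]).dtype)  # torch.float32
```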

Constructors

Ones and zeros

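| NumPy | PyTorch |
| --- | --- |
| `np.empty((2, 3))` | `torch.empty(2, 3)` |
| `np.empty_like(x)` | `torch.empty_like(x)` |
| `np.eye` | `torch.eye` |
| `np.identity` | `torch.eye` |
| `np.ones` | `torch.ones` |
| `np.ones_like` | `torch.ones_like` |
| `np.zeros` | `torch.zeros` |
| `np.zeros_like` | `torch.zeros_like` |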
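A minimal sketch of these constructors, assuming NumPy and PyTorch are installed. It illustrates that PyTorch accepts the shape either as separate integers or as a tuple, that `np.identity` and `np.eye` both map to `torch.eye`, and that the `*_like` functions copy shape and dtype from an existing tensor.

```python
import numpy as np
import torch

# PyTorch accepts the shape as separate ints or as a tuple
e1 = torch.empty(2, 3)
e2 = torch.empty((2, 3))
print(e1.shape == e2.shape)  # True

# np.identity(n) and np.eye(n) both correspond to torch.eye(n)
i_np = np.identity(3)
i_t = torch.eye(3)
print(np.array_equal(i_np, i_t.numpy()))  # True

# *_like constructors reuse the shape and dtype of an existing tensor
x = torch.ones(2, 2, dtype=torch.int64)
z = torch.zeros_like(x)
print(z.dtype, z.shape)  # torch.int64 torch.Size([2, 2])
```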

From existing data

| NumPy | PyTorch |
| --- | --- |
| `np.array([[1, 2], [3, 4]])` | `torch.tensor([[1, 2], [3, 4]])` |
| `np.array([3.2, 4.3], dtype=np.float16)`; `np.float16([3.2, 4.3])` | `torch.tensor([3.2, 4.3], dtype=torch.float16)` |
| `x.copy()` | `x.clone()` |
| `np.fromfile(file)` | `torch.from_numpy(np.fromfile(file))` |
| `np.frombuffer` | |
| `np.fromfunction` | |
| `np.fromiter` | |
| `np.fromstring` | |
| `np.load` | `torch.load` (for files written by `torch.save`) |
| `np.loadtxt` | |
| `np.concatenate` | `torch.cat` |
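The copy semantics here are a common source of surprises. The sketch below assumes NumPy and PyTorch are installed; `torch.from_numpy` is not in the table above but is included to contrast a tensor that shares memory with one that copies. It also shows that NumPy's `axis` argument becomes `dim` in `torch.cat`.

```python
import numpy as np
import torch

a = np.array([[1, 2], [3, 4]])

t_copy = torch.tensor(a)        # copies the data
t_shared = torch.from_numpy(a)  # shares memory with `a`
a[0, 0] = 100
print(t_copy[0, 0].item(), t_shared[0, 0].item())  # 1 100

# x.clone() is the counterpart of x.copy(): an independent copy
c = t_copy.clone()
c[0, 0] = -1
print(t_copy[0, 0].item())  # 1

# np.concatenate(..., axis=k) corresponds to torch.cat(..., dim=k)
b = np.concatenate([a, a], axis=1)
tb = torch.cat([t_shared, t_shared], dim=1)
print(np.array_equal(b, tb.numpy()))  # True
```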

Numerical ranges

| NumPy | PyTorch |
| --- | --- |
| `np.arange(10)` | `torch.arange(10)` |
| `np.arange(2, 3, 0.1)` | `torch.arange(2, 3, 0.1)` |
| `np.linspace` | `torch.linspace` |
| `np.logspace` | `torch.logspace` |
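A short comparison of the range functions, assuming NumPy and a recent PyTorch. One difference to keep in mind: `np.linspace` defaults to 50 points, whereas recent PyTorch versions expect the number of points (`steps`) to be passed explicitly. With fractional steps, `linspace` is usually safer than `arange`, because `arange` with a float step is subject to rounding.

```python
import numpy as np
import torch

r = torch.arange(10)
print(r.dtype)  # torch.int64

lin_np = np.linspace(0, 1, 5)
lin_t = torch.linspace(0, 1, steps=5)  # pass `steps` explicitly
print(np.allclose(lin_np, lin_t.numpy()))  # True

log_np = np.logspace(0, 2, 3)  # 10**0, 10**1, 10**2
log_t = torch.logspace(0, 2, steps=3)
print(np.allclose(log_np, log_t.numpy()))  # True
```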

Building matrices

| NumPy | PyTorch |
| --- | --- |
| `np.diag` | `torch.diag` |
| `np.tril` | `torch.tril` |
| `np.triu` | `torch.triu` |
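A minimal sketch of the matrix constructors, assuming NumPy and PyTorch are installed. `diag` works in both directions: it extracts the diagonal of a 2-D input and builds a diagonal matrix from a 1-D input. The diagonal offset for `tril`/`triu` is called `k` in NumPy and `diagonal` in PyTorch.

```python
import numpy as np
import torch

m_np = np.arange(9).reshape(3, 3)
m_t = torch.arange(9).reshape(3, 3)

# 2-D input: extract the diagonal; 1-D input: build a diagonal matrix
print(torch.diag(m_t))  # tensor([0, 4, 8])
d = torch.diag(torch.tensor([1, 2]))  # 2x2 matrix with 1 and 2 on the diagonal

# Offset below the main diagonal: `k` in NumPy, `diagonal` in PyTorch
low_np = np.tril(m_np, k=-1)
low_t = torch.tril(m_t, diagonal=-1)
print(np.array_equal(low_np, low_t.numpy()))  # True

up_np = np.triu(m_np, k=1)
up_t = torch.triu(m_t, diagonal=1)
print(np.array_equal(up_np, up_t.numpy()))  # True
```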

Attributes

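| NumPy | PyTorch |
| --- | --- |
| `x.shape` | `x.shape` |
| `x.strides` | `x.stride()` (counted in elements; NumPy counts bytes) |
| `x.ndim` | `x.dim()`; `x.ndim` |
| `x.data` | `x.data` |
| `x.size` | `x.nelement()`; `x.numel()` |
| `x.dtype` | `x.dtype` |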
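Two of these attributes look alike but behave differently, which the sketch below illustrates (assuming NumPy and PyTorch are installed): strides are measured in bytes in NumPy but in elements in PyTorch, and `x.size` is an element count in NumPy while `x.size()` returns the shape in PyTorch.

```python
import numpy as np
import torch

a = np.zeros((2, 3), dtype=np.float32)
t = torch.zeros(2, 3, dtype=torch.float32)

# NumPy strides are in bytes, PyTorch strides are in elements
print(a.strides)   # (12, 4)
print(t.stride())  # (3, 1)

# x.size is the element count in NumPy, but x.size() is the shape in PyTorch
print(a.size, t.nelement(), t.numel())  # 6 6 6
print(t.size())                         # torch.Size([2, 3])
print(a.ndim, t.dim(), t.ndim)          # 2 2 2
```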

Indexing

| NumPy | PyTorch |
| --- | --- |
| `x[0]` | `x[0]` |
| `x[:, 0]` | `x[:, 0]` |
| `x[indices]` | `x[indices]` |
| `np.take(x, indices)` | `torch.take(x, torch.LongTensor(indices))` |
| `x[x != 0]` | `x[x != 0]` |
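A small example of the indexing rows, assuming NumPy and PyTorch are installed. Without an `axis`, both `np.take` and `torch.take` treat the input as a flattened, row-major array. When NumPy's `axis` argument is used, `torch.index_select` (not listed in the table above) is the closer match.

```python
import numpy as np
import torch

a = np.array([[1, 0, 3], [0, 5, 6]])
t = torch.tensor([[1, 0, 3], [0, 5, 6]])
indices = [0, 2, 5]

# Without an axis, both index into the flattened input
print(np.take(a, indices))                       # [1 3 6]
print(torch.take(t, torch.LongTensor(indices)))  # tensor([1, 3, 6])

# Boolean masks return a 1-D result of the selected elements
print(t[t != 0])  # tensor([1, 3, 5, 6])

# np.take(..., axis=1) corresponds to torch.index_select(..., dim=1, ...)
cols = torch.index_select(t, 1, torch.tensor([0, 2]))
print(np.array_equal(np.take(a, [0, 2], axis=1), cols.numpy()))  # True
```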

Shape manipulation

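| NumPy | PyTorch |
| --- | --- |
| `x.reshape` | `x.reshape`; `x.view` |
| `x.resize()` | `x.resize_` |
| | `x.resize_as_` |
| `x.transpose` | `x.transpose` or `x.permute` |
| `x.flatten` | `x.flatten()`; `x.view(-1)` |
| `x.squeeze()` | `x.squeeze()` |
| `x[:, np.newaxis]`; `np.expand_dims(x, 1)` | `x.unsqueeze(1)` |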
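A sketch of the shape operations, assuming PyTorch is installed. `transpose` swaps exactly two dimensions, while `permute` reorders all of them; NumPy's `x.transpose()` with no arguments reverses every axis, which corresponds to `permute` with the dimensions in reverse order. `view` only works when the memory layout allows it, whereas `reshape` falls back to a copy.

```python
import torch

x = torch.arange(6).reshape(2, 3)

# transpose swaps two dims; permute reorders all of them
y = x.transpose(0, 1)  # shape (3, 2)
w = torch.zeros(2, 3, 4)
rev = w.permute(*reversed(range(w.dim())))  # like NumPy's w.transpose()
print(rev.shape)  # torch.Size([4, 3, 2])

# y is not contiguous: y.view(-1) would raise a RuntimeError,
# while y.reshape(-1) (or y.contiguous().view(-1)) works
print(y.is_contiguous())  # False
flat = y.reshape(-1)

print(x.unsqueeze(1).shape)                  # torch.Size([2, 1, 3])
print(torch.zeros(1, 3, 1).squeeze().shape)  # torch.Size([3])
```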

Data selection

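| NumPy | PyTorch |
| --- | --- |
| `np.put` | |
| `x.put` | `x.put_` |
| `x.repeat` | see the example below |
| `np.tile(x, (3, 2))` | `x.repeat(3, 2)` |
| `np.choose` | |
| `np.sort` | `sorted, indices = torch.sort(x, [dim])` |
| `np.argsort` | `sorted, indices = torch.sort(x, [dim])` |
| `np.nonzero` | `torch.nonzero(x, as_tuple=True)` |
| `np.where` | `torch.where` |
| `x[::-1]` | `torch.flip(x, [0])` (negative slice steps are not supported) |

Note that `x.repeat` means different things in the two libraries: NumPy repeats each element, while PyTorch tiles the whole tensor.

```python
import numpy as np
import torch

x = np.array([1, 2, 3])
x.repeat(2)  # [1, 1, 2, 2, 3, 3]

x = torch.tensor([1, 2, 3])
x.repeat(2)  # [1, 2, 3, 1, 2, 3]
x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1)  # [1, 1, 2, 2, 3, 3]
```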
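A few more data-selection patterns, assuming NumPy and a reasonably recent PyTorch. `torch.repeat_interleave` gives NumPy-style element-wise repetition directly; `torch.sort` returns both values and indices (PyTorch also provides `torch.argsort` when only the indices are needed); `torch.nonzero` returns an N×ndim index tensor unless `as_tuple=True` is passed, which matches the tuple that `np.nonzero` returns.

```python
import torch

x = torch.tensor([3, 1, 2])

# Element-wise repetition, like NumPy's x.repeat(2)
print(torch.repeat_interleave(x, 2))  # tensor([3, 3, 1, 1, 2, 2])

# torch.sort returns (values, indices); the indices play the role of np.argsort
values, indices = torch.sort(x)
print(values, indices)  # tensor([1, 2, 3]) tensor([1, 2, 0])

# np.nonzero returns a tuple of index arrays; as_tuple=True matches that
m = torch.tensor([[0, 1], [2, 0]])
rows, cols = torch.nonzero(m, as_tuple=True)
print(rows, cols)        # tensor([0, 1]) tensor([1, 0])
print(torch.nonzero(m))  # one (row, col) pair per nonzero element

# Negative-step slicing is not supported; torch.flip reverses instead
print(torch.flip(x, [0]))  # tensor([2, 1, 3])
```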

Numerical computation

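| NumPy | PyTorch |
| --- | --- |
| `x.min` | `x.min` |
| `x.argmin` | `x.argmin` |
| `x.max` | `x.max` |
| `x.argmax` | `x.argmax` |
| `x.clip` | `x.clamp` |
| `x.round` | `x.round` |
| `np.floor(x)` | `torch.floor(x)`; `x.floor()` |
| `np.ceil(x)` | `torch.ceil(x)`; `x.ceil()` |
| `x.trace` | `x.trace` |
| `x.sum` | `x.sum` |
| `x.cumsum` | `x.cumsum` |
| `x.mean` | `x.mean` |
| `x.std` | `x.std` (unbiased by default; NumPy defaults to `ddof=0`) |
| `x.prod` | `x.prod` |
| `x.cumprod` | `x.cumprod` |
| `x.all` | `(x != 0).sum() == x.nelement()` |
| `x.any` | `(x != 0).sum() > 0` |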
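A sketch of the reductions where the two libraries disagree, assuming NumPy and PyTorch are installed. With a dimension argument, `t.max`/`t.min` return a `(values, indices)` pair rather than just values. `std` uses the unbiased estimator by default in PyTorch but `ddof=0` in NumPy; `unbiased=False` is used below to match NumPy (newer PyTorch versions also accept `correction=0`).

```python
import numpy as np
import torch

a = np.array([[1.0, 5.0], [3.0, 2.0]])
t = torch.tensor([[1.0, 5.0], [3.0, 2.0]])

# With a dimension, torch max/min return (values, indices)
values, idx = t.max(dim=1)
print(values, idx)    # tensor([5., 3.]) tensor([1, 0])
print(a.max(axis=1))  # [5. 3.]

# clip -> clamp: limit every element to the range [2, 4]
c = t.clamp(2, 4)

# std: NumPy defaults to ddof=0, PyTorch to the unbiased estimator
pop_np = np.std(a)
pop_t = t.std(unbiased=False).item()  # matches np.std(a)
samp_t = t.std().item()               # matches np.std(a, ddof=1)

# all/any test for nonzero elements, as NumPy does
print(bool((t != 0).all()), bool((t > 4).any()))  # True True
```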

Comparison

| NumPy | PyTorch |
| --- | --- |
| `np.less` | `x.lt` |
| `np.less_equal` | `x.le` |
| `np.greater` | `x.gt` |
| `np.greater_equal` | `x.ge` |
| `np.equal` | `x.eq` |
| `np.not_equal` | `x.ne` |
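A minimal comparison example, assuming NumPy and PyTorch are installed. Both libraries return element-wise boolean masks, and the PyTorch methods are equivalent to the operators `<`, `<=`, `>`, `>=`, `==`, `!=`. For a single yes/no answer on whole-tensor equality, `torch.equal` (not in the table) plays the role of `np.array_equal`.

```python
import numpy as np
import torch

a = np.array([1, 2, 3])
t = torch.tensor([1, 2, 3])

print(np.less(a, 2))  # [ True False False]
print(t.lt(2))        # boolean tensor, same values as above

# The methods and the operators give identical masks
print(torch.equal(t.le(2), t <= 2))  # True

# Whole-tensor equality (shape and values), like np.array_equal
print(torch.equal(t, torch.tensor([1, 2, 3])))  # True
```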


Hopefully this guide helps you quickly see how NumPy and PyTorch relate and where they differ.
