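This is a getting-started guide to PyTorch for NumPy users. If you are already comfortable with NumPy, the tables below map each NumPy type and operation to its PyTorch counterpart, so you can get up to speed quickly.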
Types
| NumPy | PyTorch |
|---|---|
| np.ndarray | torch.Tensor |
| np.float32 | torch.float32; torch.float |
| np.float64 | torch.float64; torch.double |
| np.float16 | torch.float16; torch.half |
| np.int8 | torch.int8 |
| np.uint8 | torch.uint8 |
| np.int16 | torch.int16; torch.short |
| np.int32 | torch.int32; torch.int |
| np.int64 | torch.int64; torch.long |
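
The dtype table can be checked directly. The snippet below is a minimal sketch (the variable names are illustrative, not from the original guide) showing that the short names are aliases, and that the default floating-point dtype differs between the two libraries.

```python
import numpy as np
import torch

# The short names in the table are aliases for the same dtypes.
print(torch.float == torch.float32)   # True
print(torch.long == torch.int64)      # True

# Default floating-point dtypes differ: NumPy uses float64, PyTorch float32.
a = np.array([1.0, 2.0])
t = torch.tensor([1.0, 2.0])
print(a.dtype, t.dtype)               # float64 torch.float32

# Explicit casts, NumPy-style and PyTorch-style.
a16 = a.astype(np.float16)
t16 = t.to(torch.float16)
print(a16.dtype, t16.dtype)           # float16 torch.float16
```

The differing defaults are worth keeping in mind when comparing results from the two libraries numerically.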
Constructors
Ones and zeros
| NumPy | PyTorch |
|---|---|
| np.empty((2, 3)) | torch.empty(2, 3) |
| np.empty_like(x) | torch.empty_like(x) |
| np.eye | torch.eye |
| np.identity | torch.eye |
| np.ones | torch.ones |
| np.ones_like | torch.ones_like |
| np.zeros | torch.zeros |
| np.zeros_like | torch.zeros_like |
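
As a small illustration of the constructors above (a sketch with made-up variable names), note that NumPy takes the shape as a single tuple, while PyTorch accepts either separate integers or a tuple.

```python
import numpy as np
import torch

# NumPy: shape is one tuple. PyTorch: separate ints or a tuple both work.
a = np.zeros((2, 3))
t = torch.zeros(2, 3)
t2 = torch.zeros((2, 3))
print(a.shape, t.shape, t2.shape)

# np.identity(n) and np.eye(n) both correspond to torch.eye(n).
eye_np = np.identity(3)
eye_t = torch.eye(3)

# The *_like constructors copy shape and dtype from an existing object.
ones_np = np.ones_like(a)
ones_t = torch.ones_like(t)
print(ones_np.dtype, ones_t.dtype)
```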
From existing data
| NumPy | PyTorch |
|---|---|
| np.array([[1, 2], [3, 4]]) | torch.tensor([[1, 2], [3, 4]]) |
| np.array([3.2, 4.3], dtype=np.float16); np.float16([3.2, 4.3]) | torch.tensor([3.2, 4.3], dtype=torch.float16) |
| x.copy() | x.clone() |
| np.fromfile(file) | torch.from_numpy(np.fromfile(file)) |
| np.frombuffer | torch.frombuffer |
| np.fromfunction | |
| np.fromiter | |
| np.fromstring | |
| np.load | torch.load |
| np.loadtxt | |
| np.concatenate | torch.cat |
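
When moving data between the two libraries, it helps to know which calls copy and which share memory. The following sketch (variable names are illustrative) contrasts `torch.tensor`, which copies, with `torch.from_numpy`, which shares the underlying buffer, and then shows `clone` and `torch.cat`.

```python
import numpy as np
import torch

a = np.array([[1, 2], [3, 4]])

copied = torch.tensor(a)      # independent copy of the data
shared = torch.from_numpy(a)  # shares memory with the NumPy array

a[0, 0] = 100                 # modify the NumPy array in place
print(copied[0, 0].item())    # 1
print(shared[0, 0].item())    # 100

# x.copy() in NumPy corresponds to x.clone() in PyTorch.
c = shared.clone()
c[0, 0] = -1                  # does not affect `a` or `shared`
print(a[0, 0])                # 100

# np.concatenate corresponds to torch.cat (axis= becomes dim=).
cat_np = np.concatenate([a, a], axis=0)
cat_t = torch.cat([shared, shared], dim=0)
print(cat_np.shape, cat_t.shape)
```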
Numerical ranges
| NumPy | PyTorch |
|---|---|
| np.arange(10) | torch.arange(10) |
| np.arange(2, 3, 0.1) | torch.arange(2, 3, 0.1) |
| np.linspace | torch.linspace |
| np.logspace | torch.logspace |
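
A quick sketch of the range constructors follows. One assumption worth flagging: NumPy calls the point count `num` (and gives it a default), while recent PyTorch versions expect `steps` to be passed explicitly, so the example always supplies it.

```python
import numpy as np
import torch

r_np = np.arange(10)
r_t = torch.arange(10)

# Evenly spaced points on a linear scale.
lin_np = np.linspace(0.0, 1.0, num=5)
lin_t = torch.linspace(0.0, 1.0, steps=5)

# Evenly spaced exponents, base 10: 10**0, 10**1, 10**2.
log_np = np.logspace(0.0, 2.0, num=3)
log_t = torch.logspace(0.0, 2.0, steps=3)

print(lin_np, lin_t)
print(log_np, log_t)
```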
Building matrices
| NumPy | PyTorch |
|---|---|
| np.diag | torch.diag |
| np.tril | torch.tril |
| np.triu | torch.triu |
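
The matrix helpers behave the same way in both libraries, as this short sketch suggests (the 3x3 example data is made up for illustration).

```python
import numpy as np
import torch

a = np.arange(1, 10).reshape(3, 3)
t = torch.arange(1, 10).reshape(3, 3)

# On a 2-D input, diag extracts the main diagonal.
print(np.diag(a))       # [1 5 9]
print(torch.diag(t))    # tensor([1, 5, 9])

# On a 1-D input, diag builds a diagonal matrix.
d_np = np.diag(np.array([1, 2]))
d_t = torch.diag(torch.tensor([1, 2]))

# Lower and upper triangular parts.
low_np, low_t = np.tril(a), torch.tril(t)
up_np, up_t = np.triu(a), torch.triu(t)
```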
Attributes
| NumPy | PyTorch |
|---|---|
| x.shape | x.shape |
| x.strides | x.stride() |
| x.ndim | x.dim() |
| x.data | x.data |
| x.size | x.nelement() |
| x.dtype | x.dtype |
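
Two of these attributes look alike but differ in detail. The sketch below (illustrative names only) shows that NumPy strides are measured in bytes while PyTorch strides are measured in elements, and that `t.size()` in PyTorch returns the shape rather than the element count.

```python
import numpy as np
import torch

a = np.zeros((2, 3), dtype=np.float32)
t = torch.zeros(2, 3, dtype=torch.float32)

print(a.shape, t.shape)        # (2, 3) torch.Size([2, 3])

# NumPy strides are in bytes; PyTorch strides are in elements.
print(a.strides)               # (12, 4)
print(t.stride())              # (3, 1)

print(a.ndim, t.dim())         # 2 2

# a.size is the element count; in PyTorch use nelement() or numel().
print(a.size, t.nelement(), t.numel())   # 6 6 6

# t.size() is the shape, not the number of elements.
print(t.size())                # torch.Size([2, 3])
```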
Indexing
| NumPy | PyTorch |
|---|---|
| x[0] | x[0] |
| x[:, 0] | x[:, 0] |
| x[indices] | x[indices] |
| np.take(x, indices) | torch.take(x, torch.LongTensor(indices)) |
| x[x != 0] | x[x != 0] |
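
Indexing syntax carries over almost unchanged. The sketch below uses a small made-up array to show basic, integer-list, flat (`take`), and boolean-mask indexing side by side.

```python
import numpy as np
import torch

a = np.array([[0, 1, 2], [3, 0, 5]])
t = torch.tensor([[0, 1, 2], [3, 0, 5]])

# Basic indexing: first row, first column.
print(a[0], t[0])
print(a[:, 0], t[:, 0])

# Integer-list indexing selects rows in the given order.
idx = [1, 0]
rows_np = a[idx]
rows_t = t[idx]

# take() indexes into the flattened data; PyTorch needs an integer tensor.
flat_np = np.take(a, [1, 5])
flat_t = torch.take(t, torch.tensor([1, 5]))

# Boolean masks return the selected elements as a 1-D result.
nz_np = a[a != 0]
nz_t = t[t != 0]
print(nz_np, nz_t)
```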
Shape manipulation
| NumPy | PyTorch |
|---|---|
| x.reshape | x.reshape; x.view |
| x.resize() | x.resize_ |
| | x.resize_as_ |
| x.transpose | x.transpose or x.permute |
| x.flatten | x.flatten; x.view(-1) |
| x.squeeze() | x.squeeze() |
| x[:, np.newaxis]; np.expand_dims(x, 1) | x.unsqueeze(1) |
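
The shape operations have a few subtle differences. This sketch (with illustrative data) shows that `view` needs contiguous memory while `reshape` copies when it has to, and that PyTorch's `transpose` swaps exactly two dimensions whereas `permute` plays the role of NumPy's general `transpose`.

```python
import numpy as np
import torch

t = torch.arange(6).reshape(2, 3)
tt = t.transpose(0, 1)          # shape (3, 2), no longer contiguous
print(tt.is_contiguous())       # False

# reshape works on non-contiguous tensors (copying if needed).
flat = tt.reshape(-1)
print(flat)                     # tensor([0, 3, 1, 4, 2, 5])

# view refuses when the memory layout is incompatible.
try:
    tt.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)              # True

# General axis reordering: NumPy transpose vs PyTorch permute.
a3 = np.zeros((2, 3, 4))
t3 = torch.zeros(2, 3, 4)
print(a3.transpose(2, 0, 1).shape, t3.permute(2, 0, 1).shape)

# Adding an axis.
a = np.arange(6).reshape(2, 3)
print(a[:, np.newaxis].shape, t.unsqueeze(1).shape)
```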
Item selection and manipulation
| NumPy | PyTorch |
|---|---|
| np.put | |
| x.put | x.put_ |
| x.repeat(2)  # with x = np.array([1, 2, 3]): [1, 1, 2, 2, 3, 3] | x.repeat_interleave(2)  # with x = torch.tensor([1, 2, 3]): [1, 1, 2, 2, 3, 3]; note that x.repeat(2) tiles instead and gives [1, 2, 3, 1, 2, 3] |
| np.tile(x, (3, 2)) | x.repeat(3, 2) |
| np.choose | |
| np.sort | sorted, indices = torch.sort(x, [dim]) |
| np.argsort | sorted, indices = torch.sort(x, [dim]) |
| np.nonzero | torch.nonzero |
| np.where | torch.where |
| x[::-1] | torch.flip(x, [0]) |
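
The `repeat` naming clash is the easiest one to trip over, so here is a minimal sketch (illustrative names) covering it together with tiling, sorting, and reversing.

```python
import numpy as np
import torch

x_np = np.array([1, 2, 3])
x_t = torch.tensor([1, 2, 3])

# NumPy repeat duplicates each element; PyTorch repeat tiles the tensor.
print(x_np.repeat(2))             # [1 1 2 2 3 3]
print(x_t.repeat(2))              # tensor([1, 2, 3, 1, 2, 3])
print(x_t.repeat_interleave(2))   # tensor([1, 1, 2, 2, 3, 3])

# np.tile corresponds to Tensor.repeat.
tiled_np = np.tile(x_np, (3, 2))
tiled_t = x_t.repeat(3, 2)
print(tiled_np.shape, tiled_t.shape)

# torch.sort returns both the sorted values and the argsort indices.
values, indices = torch.sort(torch.tensor([3, 1, 2]))
print(values, indices)            # tensor([1, 2, 3]) tensor([1, 2, 0])

# Negative-step slicing is not supported on tensors; use torch.flip.
rev_np = x_np[::-1]
rev_t = torch.flip(x_t, [0])
print(rev_np, rev_t)
```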
Calculation
| NumPy | PyTorch |
|---|---|
| x.min | x.min |
| x.argmin | x.argmin |
| x.max | x.max |
| x.argmax | x.argmax |
| x.clip | x.clamp |
| x.round | x.round |
| np.floor(x) | torch.floor(x); x.floor() |
| np.ceil(x) | torch.ceil(x); x.ceil() |
| x.trace | x.trace |
| x.sum | x.sum |
| x.cumsum | x.cumsum |
| x.mean | x.mean |
| x.std | x.std |
| x.prod | x.prod |
| x.cumprod | x.cumprod |
| x.all | x.all |
| x.any | x.any |
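
Although the method names match, a few results differ. The sketch below (with made-up data) highlights two of them: reducing with `max` along a dimension returns both values and indices in PyTorch, and `std` defaults to the population formula in NumPy but to the sample (Bessel-corrected) formula in PyTorch.

```python
import numpy as np
import torch

a = np.array([[1.0, 5.0], [3.0, 2.0]])
t = torch.tensor([[1.0, 5.0], [3.0, 2.0]])

# Reducing along an axis: PyTorch max returns (values, indices).
print(a.max(axis=1))                  # [5. 3.]
values, indices = t.max(dim=1)
print(values, indices)                # tensor([5., 3.]) tensor([1, 0])

# clip in NumPy is clamp in PyTorch.
clipped_np = a.clip(2, 4)
clipped_t = t.clamp(2, 4)

# Default std differs: NumPy divides by N, PyTorch by N - 1.
std_np_population = a.std()
std_np_sample = a.std(ddof=1)
std_t_default = t.std().item()
std_t_population = t.std(unbiased=False).item()
print(std_np_population, std_t_default)

# all / any on a boolean condition.
all_positive = bool((t > 0).all())
any_large = bool((t > 4).any())
```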
Comparison
| NumPy | PyTorch |
|---|---|
| np.less | x.lt |
| np.less_equal | x.le |
| np.greater | x.gt |
| np.greater_equal | x.ge |
| np.equal | x.eq |
| np.not_equal | x.ne |
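
The comparison methods work elementwise, just like the NumPy functions, and the usual operators are available in both libraries. One thing to watch, shown in this small sketch with illustrative data: `torch.equal` is not elementwise and returns a single Python bool.

```python
import numpy as np
import torch

a, b = np.array([1, 2, 3]), np.array([3, 2, 1])
x, y = torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])

# Elementwise comparisons return boolean arrays / tensors.
print(np.less(a, b))     # [ True False False]
print(x.lt(y))           # tensor([ True, False, False])
print(x < y)             # same result as x.lt(y)

eq_np = np.equal(a, b)
eq_t = x.eq(y)
print(eq_np, eq_t)

# torch.equal compares whole tensors and returns one bool.
same = torch.equal(x, y)
print(same)              # False
```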
Hopefully this guide helps you quickly see where NumPy and PyTorch line up and where they differ.