PyTorch 学习笔记(一):什么是 PyTorch

PyTorch入门实战教程

PyTorch是一个动态的建图的工具。不像Tensorflow那样,先建图,然后通过feed和run重复执行建好的图。相对来说,PyTorch具有更好的灵活性。

编写一个深度网络需要关注的地方是:

  1. 网络的参数应该由什么对象保存
  2. 如何构建网络
  3. 如何计算梯度和更新参数

如何保存参数

pytorch中有两种变量类型,一个是Tensor,一个是Variable。

Tensor: 就像ndarray一样,一维Tensor叫Vector,二维Tensor叫Matrix,三维及以上称为Tensor
Variable:是Tensor的一个wrapper,不仅保存了值,而且保存了这个值的creator,需要BP的网络都是Variable参与运算

自动求导

pytorch的自动求导工具包在torch.autograd中

neural networks

使用torch.nn包中的工具来构建神经网络 需要以下几步:

  • 定义神经网络的权重,搭建网络结构
  • 遍历整个数据集进行训练
  • 将数据输入神经网络
  • 计算loss
  • 计算网络权重的梯度
  • 更新网络权重
    • weight = weight + learning_rate * gradient

上述代码输出:

使用loss criterion 和 optimizer训练网络

torch.nn包下有很多loss标准。同时torch.optimizer帮助完成更新权重的工作。这样就不需要手动更新参数了

整体NN结构

其它

关于求梯度,只有我们定义的Variable才会被求梯度,由creator创造的不会去求梯度

自己定义Variable的时候,记得Variable(Tensor, requires_grad = True),这样才会被求梯度,不然的话,是不会求梯度的

文章来源:Keith

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(1)

返回顶部