您的位置 首页 PyTorch 教程

PyTorch 入门之五分钟实现简单二分类器

PyTorch入门实战教程

很多时候我们想找一个简单的分类器示例,却找来找去都是图像分类,而且看起来云里雾里,很难入门。今天我们用 PyTorch 教大家实现一个很简单的二分类器,所用的数据来自 Scikit learn。

我们首先来生成数据,一共 200 个样本:

绘制以下样本的散点图:

可以看到我们生成了两类数据,分别用 0 和 1 来表示。我们接下来将要在这个样本数据上构造一个分类器,采用的是一个很简单的全连接网络,网络结构如下:

这个网络包含一个输入层,一个中间层,一个输出层。中间层包含 3 个神经元,使用的激活函数是 tanh。当然,中间层的神经元越多,分类效果一般越好,但这个 3 层的网络对于我们的样本数据已经足够用了。我们来算一下参数数量:上图中一共有 6 6 = 12 条线,就是 12 个权重,加上 3 2 = 5 个 bias,一共 17 个参数需要训练。

首先我们将样本数据从 numpy 转成 tensor:

然后构建我们的神经网络:

我们的损失函数用 CrossEntropyLoss,梯度优化器使用 Adam:

开始训练:

我们来看一下 training error:

为了更直观地展示分类结果,我们将结果可视化:

下面的函数帮助我们在两个分类之间画一条分界线:

输出图像:

可以看出,分类效果还是很不错的。

完整代码:Github

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(1)

返回顶部