您的位置 首页 PyTorch 教程

在 Android 上运行 PyTorch Mobile 进行图像分类

PyTorch入门实战教程

前几天 Facebook 刚刚发布了 PyTorch Mobile,为了加速手机上的 AI 模型的开发和部署,适用于 Android 和 iOS。在今天的教程里,PyTorch 中文网为大家整理了如何将 ImageNet 预训练模型迁移到手机上,并制作一个 Android 应用来进行图像识别。

本文主要包括以下几步:

  1. 将深度学习模型转换为 TorchScript 格式(Python)
  2. 将 PyTorch Mobile 加入 Gradle 依赖(Java)
  3. 用 PyTorch Mobile 在手机上加载模型进行图像分类(Java)

准备工作

首先需要安装最新版的 PyTorch,本文的版本是 1.3.0.

其次需要安装 Android Studio 进行 Android 开发。

模型格式转换

为了能够在 Android 上使用我们的深度学习模型,需要将其转换为 TorchScript 格式。这个过程非常简单。下面的代码将预训练的 MobileNetV2 模型转换为 TorchScript 格式:

上述代码会将转换好的模型存为文件 “mobilenet-v2.pt”。

1. 创建 Android 项目,添加 PyTorch Mobile

首先用 Android Studio 创建一个项目名为 PytorchAndroid,然后打开 build.gradle 文件添加 PyTorch Mobile 和 TorchVision Mobile:

文件示例如下:

然后 Android Studio 会提醒进行同步,点击 “Sync Now” 会自动下载所需要的依赖包。

2. 将模型放到 assets 文件夹

按照下列步骤创建 assets 文件夹:New -> Folder -> Assets Folder。然后将 “mobilenet-v2.pt” 文件放到这个 assets 文件件内。

3. 添加 ImageNet 标签

在 app 包内,创建名为 “Constants.java” 的 Java 文件,将这个文件里的内容复制粘贴进去。

4. 添加分类

在 app 包内,创建名为 “Classifier.java” 的 Java 文件,放入下列代码:

这个是我们整个项目的核心文件。其中 preprocess 函数接收一张 bitmap 图像,然后调整大小,做标准化处理,再把处理后的文件返回为 Tensor 格式以备模型使用。argmax 函数返回最大值所在的 index。predict 函数接收一张 bitmap 图像,将其处理为 Tensor,放入模型得到预测结果。

5. 添加工具辅助类

创建文件 “Utils.java” 然后放入下列代码:

6. 添加 Main Activity

创建文件 “MainActivity.java” 然后放入下列代码:

文件 “activity_main.xml” 应该长这样:

添加文件 “content_main.xml”:

上面代码主要做的是,点击按钮后,调用外部摄像头拍摄,得到 bitmap 图像后调用分类器得到预测结果。

7. 添加 Result Activity

创建一个 Basic Activity 文件 “Result.java” 然后放入下列代码:

文件 “activity_result.xml” 如下:

文件 “content_result.xml” 如下:

然后差不多就完成了!

编译项目

接下来就是编译和运行自己的 Android 应用了,如下图:

大家可以多试试不同的物品进行识别。

总结

本文我们成功用 PyTorch Mobiel 和 TorchVision Mobile 创建了一个 Android 手机应用,来进行 1000 多个类别的图像识别。本文原作者的 Twitter 在这里,大家可以关注一下,本文的完整代码见 Github

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(3)

  1. 我自己训练的.pt模型导入报错

    2020-04-23 17:31:39.748 30471-30471/com.johnolafenwa.pytorchandroid E/AndroidRuntime: FATAL EXCEPTION: main
    Process: com.johnolafenwa.pytorchandroid, PID: 30471
    java.lang.RuntimeException: Unable to start activity ComponentInfo{com.johnolafenwa.pytorchandroid/com.johnolafenwa.pytorchandroid.MainActivity}: com.facebook.jni.CppException: [enforce fail at inline_container.cc:137] . PytorchStreamReader failed reading zip archive: failed finding central directory
    (no backtrace available)
    at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2728)
    at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:2789)
    at android.app.ActivityThread.-wrap12(ActivityThread.java)
    at android.app.ActivityThread$H.handleMessage(ActivityThread.java:1513)
    at android.os.Handler.dispatchMessage(Handler.java:102)
    at android.os.Looper.loop(Looper.java:154)
    at android.app.ActivityThread.main(ActivityThread.java:6211)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:903)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:793)
    Caused by: com.facebook.jni.CppException: [enforce fail at inline_container.cc:137] . PytorchStreamReader failed reading zip archive: failed finding central directory
    (no backtrace available)
    at org.pytorch.Module$NativePeer.initHybrid(Native Method)
    at org.pytorch.Module$NativePeer.(Module.java:70)
    at org.pytorch.Module.(Module.java:25)
    at org.pytorch.Module.load(Module.java:21)
    at com.johnolafenwa.pytorchandroid.Classifier.(Classifier.java:18)
    at com.johnolafenwa.pytorchandroid.MainActivity.onCreate(MainActivity.java:28)
    at android.app.Activity.performCreate(Activity.java:6859)
    at android.app.Instrumentation.callActivityOnCreate(Instrumentation.java:1125)
    at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2681)
    … 9 more

返回顶部