前几天 Facebook 刚刚发布了 PyTorch Mobile,为了加速手机上的 AI 模型的开发和部署,适用于 Android 和 iOS。在今天的教程里,PyTorch 中文网为大家整理了如何将 ImageNet 预训练模型迁移到手机上,并制作一个 Android 应用来进行图像识别。
本文主要包括以下几步:
- 将深度学习模型转换为 TorchScript 格式(Python)
- 将 PyTorch Mobile 加入 Gradle 依赖(Java)
- 用 PyTorch Mobile 在手机上加载模型进行图像分类(Java)
文章目录
准备工作
首先需要安装最新版的 PyTorch,本文的版本是 1.3.0.
其次需要安装 Android Studio 进行 Android 开发。
模型格式转换
为了能够在 Android 上使用我们的深度学习模型,需要将其转换为 TorchScript 格式。这个过程非常简单。下面的代码将预训练的 MobileNetV2 模型转换为 TorchScript 格式:
1 2 3 4 5 6 7 8 9 10 | import torch from torchvision.models import mobilenet_v2 model = mobilenet_v2(pretrained=True) model.eval() input_tensor = torch.rand(1,3,224,224) script_model = torch.jit.trace(model,input_tensor) script_model.save("mobilenet-v2.pt") |
上述代码会将转换好的模型存为文件 “mobilenet-v2.pt”。
1. 创建 Android 项目,添加 PyTorch Mobile
首先用 Android Studio 创建一个项目名为 PytorchAndroid,然后打开 build.gradle 文件添加 PyTorch Mobile 和 TorchVision Mobile:
1 2 | implementation ‘org.pytorch:pytorch_android:1.3.0’ implementation ‘org.pytorch:pytorch_android_torchvision:1.3.0’ |
文件示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | apply plugin: 'com.android.application' android { compileSdkVersion 28 defaultConfig { applicationId "com.johnolafenwa.pytorchandroid" minSdkVersion 21 targetSdkVersion 28 versionCode 1 versionName "1.0" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" } buildTypes { release { minifyEnabled false proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' } } } dependencies { implementation fileTree(dir: 'libs', include: ['*.jar']) implementation 'org.pytorch:pytorch_android:1.3.0' implementation 'org.pytorch:pytorch_android_torchvision:1.3.0' implementation 'com.android.support:appcompat-v7:28.0.0' implementation 'com.android.support.constraint:constraint-layout:1.1.3' implementation 'com.android.support:design:28.0.0' testImplementation 'junit:junit:4.12' androidTestImplementation 'com.android.support.test:runner:1.0.2' androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2' } |
然后 Android Studio 会提醒进行同步,点击 “Sync Now” 会自动下载所需要的依赖包。
2. 将模型放到 assets 文件夹
按照下列步骤创建 assets 文件夹:New -> Folder -> Assets Folder。然后将 “mobilenet-v2.pt” 文件放到这个 assets 文件件内。
3. 添加 ImageNet 标签
在 app 包内,创建名为 “Constants.java” 的 Java 文件,将这个文件里的内容复制粘贴进去。
4. 添加分类
在 app 包内,创建名为 “Classifier.java” 的 Java 文件,放入下列代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | package com.johnolafenwa.pytorchandroid; import android.graphics.Bitmap; import org.pytorch.Tensor; import org.pytorch.Module; import org.pytorch.IValue; import org.pytorch.torchvision.TensorImageUtils; public class Classifier { Module model; float[] mean = {0.485f, 0.456f, 0.406f}; float[] std = {0.229f, 0.224f, 0.225f}; public Classifier(String modelPath){ model = Module.load(modelPath); } public void setMeanAndStd(float[] mean, float[] std){ this.mean = mean; this.std = std; } public Tensor preprocess(Bitmap bitmap, int size){ bitmap = Bitmap.createScaledBitmap(bitmap,size,size,false); return TensorImageUtils.bitmapToFloat32Tensor(bitmap,this.mean,this.std); } public int argMax(float[] inputs){ int maxIndex = -1; float maxvalue = 0.0f; for (int i = 0; i < inputs.length; i++){ if(inputs[i] > maxvalue) { maxIndex = i; maxvalue = inputs[i]; } } return maxIndex; } public String predict(Bitmap bitmap){ Tensor tensor = preprocess(bitmap,224); IValue inputs = IValue.from(tensor); Tensor outputs = model.forward(inputs).toTensor(); float[] scores = outputs.getDataAsFloatArray(); int classIndex = argMax(scores); return Constants.IMAGENET_CLASSES[classIndex]; } } |
这个是我们整个项目的核心文件。其中 preprocess
函数接收一张 bitmap 图像,然后调整大小,做标准化处理,再把处理后的文件返回为 Tensor 格式以备模型使用。argmax
函数返回最大值所在的 index。predict
函数接收一张 bitmap 图像,将其处理为 Tensor,放入模型得到预测结果。
5. 添加工具辅助类
创建文件 “Utils.java” 然后放入下列代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | package com.johnolafenwa.pytorchandroid; import android.content.Context; import android.util.Log; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; public class Utils { public static String assetFilePath(Context context, String assetName) { File file = new File(context.getFilesDir(), assetName); try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } catch (IOException e) { Log.e("pytorchandroid", "Error process asset " + assetName + " to file path"); } return null; } } |
6. 添加 Main Activity
创建文件 “MainActivity.java” 然后放入下列代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | package com.johnolafenwa.pytorchandroid; import android.content.Intent; import android.graphics.Bitmap; import android.os.Bundle; import android.provider.MediaStore; import android.support.v7.app.AppCompatActivity; import android.support.v7.widget.Toolbar; import android.util.Log; import android.view.View; import android.widget.Button; import java.io.File; public class MainActivity extends AppCompatActivity { int cameraRequestCode = 001; Classifier classifier; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); Toolbar toolbar = findViewById(R.id.toolbar); setSupportActionBar(toolbar); classifier = new Classifier(Utils.assetFilePath(this,"mobilenet-v2.pt")); Button capture = findViewById(R.id.capture); capture.setOnClickListener(new View.OnClickListener(){ @Override public void onClick(View view){ Intent cameraIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE); startActivityForResult(cameraIntent,cameraRequestCode); } }); } @Override protected void onActivityResult(int requestCode, int resultCode, Intent data){ if(requestCode == cameraRequestCode && resultCode == RESULT_OK){ Intent resultView = new Intent(this,Result.class); resultView.putExtra("imagedata",data.getExtras()); Bitmap imageBitmap = (Bitmap) data.getExtras().get("data"); String pred = classifier.predict(imageBitmap); resultView.putExtra("pred",pred); startActivity(resultView); } } } |
文件 “activity_main.xml” 应该长这样:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | <?xml version="1.0" encoding="utf-8"?> <android.support.design.widget.CoordinatorLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:app="http://schemas.android.com/apk/res-auto" xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" tools:context=".MainActivity" > <android.support.design.widget.AppBarLayout android:layout_width="match_parent" android:layout_height="wrap_content" android:theme="@style/AppTheme.AppBarOverlay"> <android.support.v7.widget.Toolbar android:id="@+id/toolbar" android:layout_width="match_parent" android:layout_height="?attr/actionBarSize" android:background="?attr/colorPrimary" app:popupTheme="@style/AppTheme.PopupOverlay" /> </android.support.design.widget.AppBarLayout> <include layout="@layout/content_main" /> </android.support.design.widget.CoordinatorLayout> |
添加文件 “content_main.xml”:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | <?xml version="1.0" encoding="utf-8"?> <android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:app="http://schemas.android.com/apk/res-auto" xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" app:layout_behavior="@string/appbar_scrolling_view_behavior" tools:context=".MainActivity" tools:showIn="@layout/activity_main" > <Button android:layout_width="wrap_content" android:layout_height="wrap_content" android:id="@+id/capture" android:text="Take A Picture" android:textColor="#ffffff" android:textSize="26dp" android:background="#83D5C4" android:padding="5dp" android:fontFamily="cursive" app:layout_constraintTop_toTopOf="parent" app:layout_constraintBottom_toBottomOf="parent" app:layout_constraintStart_toStartOf="parent" app:layout_constraintEnd_toEndOf="parent" /> </android.support.constraint.ConstraintLayout> |
上面代码主要做的是,点击按钮后,调用外部摄像头拍摄,得到 bitmap 图像后调用分类器得到预测结果。
7. 添加 Result Activity
创建一个 Basic Activity 文件 “Result.java” 然后放入下列代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | package com.johnolafenwa.pytorchandroid; import android.graphics.Bitmap; import android.os.Bundle; import android.support.v7.app.AppCompatActivity; import android.support.v7.widget.Toolbar; import android.widget.ImageView; import android.widget.TextView; public class Result extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_result); Toolbar toolbar = findViewById(R.id.toolbar); setSupportActionBar(toolbar); Bitmap imageBitmap = (Bitmap) getIntent().getBundleExtra("imagedata").get("data"); String pred = getIntent().getStringExtra("pred"); ImageView imageView = findViewById(R.id.image); imageView.setImageBitmap(imageBitmap); TextView textView = findViewById(R.id.label); textView.setText(pred); } } |
文件 “activity_result.xml” 如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | <?xml version="1.0" encoding="utf-8"?> <android.support.design.widget.CoordinatorLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:app="http://schemas.android.com/apk/res-auto" xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" tools:context=".Result"> <android.support.design.widget.AppBarLayout android:layout_width="match_parent" android:layout_height="wrap_content" android:theme="@style/AppTheme.AppBarOverlay"> <android.support.v7.widget.Toolbar android:id="@+id/toolbar" android:layout_width="match_parent" android:layout_height="?attr/actionBarSize" android:background="?attr/colorPrimary" app:popupTheme="@style/AppTheme.PopupOverlay" /> </android.support.design.widget.AppBarLayout> <include layout="@layout/content_result" /> </android.support.design.widget.CoordinatorLayout> |
文件 “content_result.xml” 如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | <?xml version="1.0" encoding="utf-8"?> <android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:app="http://schemas.android.com/apk/res-auto" xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" app:layout_behavior="@string/appbar_scrolling_view_behavior" tools:context=".Result" tools:showIn="@layout/activity_result"> <ImageView android:layout_width="match_parent" android:layout_height="wrap_content" android:adjustViewBounds="true" android:src="@drawable/ic_launcher_background" app:layout_constraintEnd_toEndOf="parent" app:layout_constraintStart_toStartOf="parent" app:layout_constraintTop_toTopOf="parent" app:layout_constraintBottom_toBottomOf="parent" android:id="@+id/image" /> <TextView android:layout_width="wrap_content" android:layout_height="wrap_content" android:text="Hello World" android:id="@+id/label" android:textSize="16pt" app:layout_constraintStart_toStartOf="@id/image" app:layout_constraintEnd_toEndOf="@id/image" app:layout_constraintTop_toBottomOf="@id/image" /> </android.support.constraint.ConstraintLayout> |
然后差不多就完成了!
编译项目
接下来就是编译和运行自己的 Android 应用了,如下图:
大家可以多试试不同的物品进行识别。
总结
本文我们成功用 PyTorch Mobiel 和 TorchVision Mobile 创建了一个 Android 手机应用,来进行 1000 多个类别的图像识别。本文原作者的 Twitter 在这里,大家可以关注一下,本文的完整代码见 Github。
本站微信群、QQ群(三群号 726282629):
我自己训练的.pt模型导入报错
2020-04-23 17:31:39.748 30471-30471/com.johnolafenwa.pytorchandroid E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.johnolafenwa.pytorchandroid, PID: 30471
java.lang.RuntimeException: Unable to start activity ComponentInfo{com.johnolafenwa.pytorchandroid/com.johnolafenwa.pytorchandroid.MainActivity}: com.facebook.jni.CppException: [enforce fail at inline_container.cc:137] . PytorchStreamReader failed reading zip archive: failed finding central directory
(no backtrace available)
at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2728)
at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:2789)
at android.app.ActivityThread.-wrap12(ActivityThread.java)
at android.app.ActivityThread$H.handleMessage(ActivityThread.java:1513)
at android.os.Handler.dispatchMessage(Handler.java:102)
at android.os.Looper.loop(Looper.java:154)
at android.app.ActivityThread.main(ActivityThread.java:6211)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:903)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:793)
Caused by: com.facebook.jni.CppException: [enforce fail at inline_container.cc:137] . PytorchStreamReader failed reading zip archive: failed finding central directory
(no backtrace available)
at org.pytorch.Module$NativePeer.initHybrid(Native Method)
at org.pytorch.Module$NativePeer.(Module.java:70)
at org.pytorch.Module.(Module.java:25)
at org.pytorch.Module.load(Module.java:21)
at com.johnolafenwa.pytorchandroid.Classifier.(Classifier.java:18)
at com.johnolafenwa.pytorchandroid.MainActivity.onCreate(MainActivity.java:28)
at android.app.Activity.performCreate(Activity.java:6859)
at android.app.Instrumentation.callActivityOnCreate(Instrumentation.java:1125)
at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2681)
… 9 more
您好请问您解决了么,我也是导入自己的模型报错
您好请问您解决这个问题了么