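A Summary of How to Deploy a PyTorch Model to Android
I. Android Studio Environment Setup
1. Fixing installer download problems
When you download the installer from the Android Studio website, the download may fail. In that case, right-click the download link, copy it, and download it with IDM (Internet Download Manager).
2. Installation
During installation, make sure Android Virtual Device is checked; otherwise you will not be able to use the emulator.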
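On the first launch after installation, Android Studio reports that no SDK is available and offers a proxy setup; simply click Cancel.
Next, a dedicated SDK component installer appears, but some components may be marked unavailable and cannot be installed. They become installable once you have created a project and configured Gradle.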
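II. Creating the Project
If Android Studio will not open after installation, running it as administrator fixes the problem.
Select New Project.
Choose whichever template you like.
For Build configuration language, choose Kotlin DSL (build.gradle.kts) [Recommended]; otherwise some Gradle files may be missing.
After creation, the project tree does not contain the app folder yet; Gradle has to be configured manually first.
Open gradle/wrapper/gradle-wrapper.properties and change distributionUrl to a local path. (First download the matching zip from the original URL.)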
Original:
#Wed May 01 21:02:04 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
Change it to:
#Wed May 01 21:02:04 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
# local path of the gradle-8.4-bin.zip you downloaded
distributionUrl=file:///D://Android//gradle-8.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
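If you are unsure which zip to download or how to write the local URL, the sketch below may help. It is a minimal, hypothetical helper written for this note; it is not part of Gradle or Android Studio. It does three things:
- Reads distributionUrl the same way java.util.Properties does, which turns the escaped `https\://` back into `https://`.
- Extracts the zip name you need to download.
- Builds a `file:` URL for a local copy.

The directory passed in is an assumption; replace it with wherever you saved the zip.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Properties;

// Hypothetical helper (not part of Gradle): derives the local distributionUrl
// from the contents of gradle-wrapper.properties.
class WrapperUrlHelper {

    // java.util.Properties unescapes "https\://" to "https://" while loading.
    static String distributionUrl(String propertiesText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(propertiesText));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        String url = props.getProperty("distributionUrl");
        if (url == null) {
            throw new IllegalArgumentException("distributionUrl not found");
        }
        return url;
    }

    // The last path segment names the zip to download, e.g. gradle-8.4-bin.zip.
    static String zipFileName(String url) {
        return url.substring(url.lastIndexOf('/') + 1);
    }

    // Builds a file: URI pointing at the zip inside localDir (the file need not exist yet).
    static String localDistributionUrl(Path localDir, String zipName) {
        return localDir.resolve(zipName).toAbsolutePath().toUri().toString();
    }

    public static void main(String[] args) {
        String text = "distributionBase=GRADLE_USER_HOME\n"
                + "distributionUrl=https\\://services.gradle.org/distributions/gradle-8.4-bin.zip\n";
        String url = distributionUrl(text);
        String zip = zipFileName(url);
        System.out.println(url); // https://services.gradle.org/distributions/gradle-8.4-bin.zip
        System.out.println(zip); // gradle-8.4-bin.zip
        // On Windows this prints file:///D:/Android/gradle-8.4-bin.zip
        System.out.println(localDistributionUrl(Paths.get("D:/Android"), zip));
    }
}
```

Gradle writes the colon in the URL escaped as `\:`. Inside a property value, an unescaped colon is read the same way, so either form works in distributionUrl.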
Switch settings.gradle.kts to the Aliyun mirrors (copy and paste the following):
pluginManagement {
    repositories {
        maven { url = uri("https://www.jitpack.io") }
        maven { url = uri("https://maven.aliyun.com/repository/releases") }
        maven { url = uri("https://maven.aliyun.com/repository/google") }
        maven { url = uri("https://maven.aliyun.com/repository/central") }
        maven { url = uri("https://maven.aliyun.com/repository/gradle-plugin") }
        maven { url = uri("https://maven.aliyun.com/repository/public") }
        google()
        mavenCentral()
        gradlePluginPortal()
    }
}
dependencyResolutionManagement {
    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
    repositories {
        maven { url = uri("https://www.jitpack.io") }
        maven { url = uri("https://maven.aliyun.com/repository/releases") }
        maven { url = uri("https://maven.aliyun.com/repository/google") }
        maven { url = uri("https://maven.aliyun.com/repository/central") }
        maven { url = uri("https://maven.aliyun.com/repository/gradle-plugin") }
        maven { url = uri("https://maven.aliyun.com/repository/public") }
        google()
        mavenCentral()
    }
}

rootProject.name = "Helloword"
include(":app")
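Open build.gradle.kts and click Sync Now. Gradle configures the project automatically, and after a short wait the app folder appears.
Switch the project view to Project to show all files.
The initial project is now set up.
III. Converting the Trained Model Weights
The trained .pth checkpoint must be converted to a TorchScript .pt file: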
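"""
This script uses the resnet32 network; change the import if you use another model.
The checkpoint dict was saved as:
ckpt = {
    'weight': model.state_dict(),
    'epoch': epoch,
    'cfg': opt.model,
    'index': name
}
"""
import torch
import torch.nn as nn

from models.resnet_cifar import resnet32  # make sure this imports your actual model architecture

# Rebuild the model exactly as it was built for training
model = resnet32()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 100)  # change 100 to your number of classes

# Load the weights
checkpoint = torch.load('modelleader_best.pth', map_location=torch.device('cpu'))
# strict=False ignores mismatched keys, which can silently leave layers untrained,
# so print the result to confirm nothing important is missing
result = model.load_state_dict(checkpoint['weight'], strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
model.eval()

# Convert the model to TorchScript by tracing it with an example input
example_input = torch.rand(1, 3, 32, 32)  # change this to match your model's input size
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("model.pt")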
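IV. Building the PyTorch Android Project
Create an assets folder under app/src/main and put the converted model in it. Create the folder via New > Folder > Assets Folder rather than as an ordinary directory; otherwise the app may not find the model.
Under com/example/myapplication, create two classes: CifarClassed and MainActivity.
MainActivity class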
package com.example.myapplication;

import android.content.Context;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MainActivity extends AppCompatActivity {
    private static final int PERMISSION_REQUEST_CODE = 101;
    private static final int REQUEST_IMAGE_CAPTURE = 1;
    private static final int REQUEST_IMAGE_SELECT = 2;

    private ImageView imageView;
    private TextView textView;
    private Module module;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // Request the camera permission if it has not been granted yet
        if (ContextCompat.checkSelfPermission(this, android.Manifest.permission.CAMERA)
                != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this,
                    new String[]{android.Manifest.permission.CAMERA}, PERMISSION_REQUEST_CODE);
        }

        imageView = findViewById(R.id.image);
        textView = findViewById(R.id.text);
        ImageView logoImageView = findViewById(R.id.logo);
        logoImageView.setImageResource(R.drawable.logo);

        Button takePhotoButton = findViewById(R.id.button_take_photo);
        Button selectImageButton = findViewById(R.id.button_select_image);
        takePhotoButton.setOnClickListener(v -> dispatchTakePictureIntent());
        selectImageButton.setOnClickListener(v -> dispatchGalleryIntent());

        try {
            // Module.load needs a file path, so the asset is first copied to internal storage
            module = Module.load(assetFilePath(this, "model.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
        }
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions,
                                           @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode == PERMISSION_REQUEST_CODE) {
            if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
                Log.d("Permissions", "Camera permission granted");
            } else {
                Log.d("Permissions", "Camera permission denied");
            }
        }
    }

    private void dispatchTakePictureIntent() {
        Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
        // On Android 11+ (targetSdk 30+), resolveActivity returns null unless
        // AndroidManifest.xml declares a <queries> entry for ACTION_IMAGE_CAPTURE
        if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
            startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE);
        }
    }

    private void dispatchGalleryIntent() {
        Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
        startActivityForResult(intent, REQUEST_IMAGE_SELECT);
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode != RESULT_OK || data == null) {
            return;
        }
        Bitmap imageBitmap = null;
        if (requestCode == REQUEST_IMAGE_CAPTURE) {
            // Without EXTRA_OUTPUT the camera app returns only a small thumbnail in "data"
            Bundle extras = data.getExtras();
            if (extras != null) {
                imageBitmap = (Bitmap) extras.get("data");
            }
        } else if (requestCode == REQUEST_IMAGE_SELECT) {
            try {
                imageBitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), data.getData());
            } catch (IOException e) {
                Log.e("PytorchHelloWorld", "Error loading image", e);
            }
        }
        // Skip inference if no bitmap could be obtained (avoids a NullPointerException)
        if (imageBitmap != null) {
            imageView.setImageBitmap(imageBitmap);
            classifyImage(imageBitmap);
        }
    }

    private void classifyImage(Bitmap bitmap) {
        // Resize to the 32x32 input size used when tracing the model
        Bitmap resizedBitmap = resizeBitmap(bitmap, 32, 32);

        // Convert to a normalized [1, 3, 32, 32] float tensor;
        // mean/std must match the normalization used during training
        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap,
                new float[]{0.485f, 0.456f, 0.406f},  // mean
                new float[]{0.229f, 0.224f, 0.225f}); // std

        // Run inference and convert the logits into probabilities
        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        float[] probabilities = softmax(outputTensor.getDataAsFloatArray());

        // Pick the class with the highest probability
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < probabilities.length; i++) {
            if (probabilities[i] > maxScore) {
                maxScore = probabilities[i];
                maxScoreIdx = i;
            }
        }

        // UI updates must run on the main thread
        final int maxIndex = maxScoreIdx;
        final float finalMaxScore = maxScore;
        runOnUiThread(() -> {
            textView.setText("Prediction: " + CifarClassed.IMAGENET_CLASSES[maxIndex]
                    + " (" + String.format("%.2f%%", finalMaxScore * 100) + ")");
            textView.setVisibility(View.VISIBLE); // make the TextView visible
        });
    }

    private float[] softmax(float[] scores) {
        // Subtract the largest logit before exponentiating for numerical stability
        float max = Float.NEGATIVE_INFINITY;
        for (float score : scores) {
            if (score > max) max = score;
        }
        float sum = 0.0f;
        float[] exps = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            exps[i] = (float) Math.exp(scores[i] - max);
            sum += exps[i];
        }
        for (int i = 0; i < exps.length; i++) {
            exps[i] /= sum;
        }
        return exps;
    }

    // Resizes the Bitmap to the model's input size
    private Bitmap resizeBitmap(Bitmap originalBitmap, int targetWidth, int targetHeight) {
        return Bitmap.createScaledBitmap(originalBitmap, targetWidth, targetHeight, false);
    }

    // Copies an asset to internal storage (once) and returns its absolute path
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }
        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}
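classifyImage hands the resized Bitmap to TensorImageUtils.bitmapToFloat32Tensor. The tensor that call produces must have the same layout and normalization as example_input in the conversion script, i.e. [1, 3, 32, 32].

The sketch below is plain Java and mirrors what that call is understood to do:
- scale each RGB channel to 0..1;
- apply (x - mean) / std;
- store the result channel by channel (all R, then all G, then all B).

It is a hypothetical stand-in for illustration, not the library's own code. It takes an int[] of 0xAARRGGBB pixels, as Bitmap.getPixels returns them, so it runs without Android.

```java
import java.util.Arrays;

// Hypothetical illustration of bitmap-to-tensor preprocessing:
// ARGB pixels -> normalized floats in CHW (channel, height, width) order.
class PixelNormalizer {

    static float[] toNormalizedChw(int[] argb, int width, int height, float[] mean, float[] std) {
        int planeSize = width * height;
        if (argb.length != planeSize) {
            throw new IllegalArgumentException("pixel count does not match width * height");
        }
        float[] out = new float[3 * planeSize];
        for (int i = 0; i < planeSize; i++) {
            int p = argb[i];
            // Extract 8-bit channels and scale to 0..1
            float r = ((p >> 16) & 0xFF) / 255.0f;
            float g = ((p >> 8) & 0xFF) / 255.0f;
            float b = (p & 0xFF) / 255.0f;
            // Write each channel into its own plane: [R..., G..., B...]
            out[i] = (r - mean[0]) / std[0];
            out[planeSize + i] = (g - mean[1]) / std[1];
            out[2 * planeSize + i] = (b - mean[2]) / std[2];
        }
        return out;
    }

    public static void main(String[] args) {
        // 2x1 image: one pure red pixel, one pure white pixel
        int[] pixels = {0xFFFF0000, 0xFFFFFFFF};
        float[] mean = {0.5f, 0.5f, 0.5f};
        float[] std = {0.5f, 0.5f, 0.5f};
        float[] chw = toNormalizedChw(pixels, 2, 1, mean, std);
        System.out.println(Arrays.toString(chw)); // [1.0, 1.0, -1.0, 1.0, -1.0, 1.0]
    }
}
```

The mean and std values are the part most likely to go wrong. If training used different constants than the ImageNet values passed in classifyImage (for example, dataset-specific CIFAR statistics), the model will see shifted inputs and its accuracy on the phone will drop.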
CifarClassed class
package com.example.myapplication;

public class CifarClassed {
    // CIFAR-100 fine labels in dataset index order (the field name is kept because MainActivity uses it)
    public static String[] IMAGENET_CLASSES = new String[]{
            "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle",
            "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle",
            "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
            "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "keyboard",
            "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain",
            "mouse", "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
            "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket",
            "rose", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider",
            "squirrel", "streetcar", "sunflower", "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor",
            "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm"
    };
}
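The app indexes CifarClassed.IMAGENET_CLASSES by the model's output position. Two things must therefore agree:
- The array must list the classes in the same order used during training.
- Its length must equal the size of the final layer (100 here, from nn.Linear(num_ftrs, 100)).

The sketch below is a hypothetical helper, not part of PyTorch Android. It shows a plain-Java way to check the length and report the k most likely labels instead of only the top one.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: maps model output scores to the k highest-scoring label names.
class TopK {

    static List<String> topK(float[] scores, String[] labels, int k) {
        // A length mismatch usually means the wrong label list or class count
        if (scores.length != labels.length) {
            throw new IllegalArgumentException(
                    "model outputs " + scores.length + " scores but there are " + labels.length + " labels");
        }
        Integer[] order = new Integer[scores.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort indices by descending score
        Arrays.sort(order, (a, b) -> Float.compare(scores[b], scores[a]));
        List<String> result = new ArrayList<>();
        for (int i = 0; i < Math.min(k, order.length); i++) {
            result.add(labels[order[i]]);
        }
        return result;
    }

    public static void main(String[] args) {
        String[] labels = {"apple", "aquarium_fish", "baby"};
        float[] scores = {0.1f, 2.5f, 1.0f};
        System.out.println(topK(scores, labels, 2)); // [aquarium_fish, baby]
    }
}
```

Softmax does not change the ranking, so the helper works on either raw logits or probabilities.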
The page layout is stored in MyApplication\app\src\main\res\layout\activity_main.xml.
Create MyApplication\app\src\main\res\drawable\circle_shape.xml yourself.
Create MyApplication\app\src\main\res\drawable\rounded_background yourself.
Add the camera and photo-reading permissions to MyApplication\app\src\main\AndroidManifest.xml.
Configure the app-level build.gradle.kts (MyApplication\app\build.gradle.kts) as follows:
plugins {
    alias(libs.plugins.androidApplication)
}

android {
    namespace = "com.example.myapplication"
    compileSdk = 34

    sourceSets {
        getByName("main") {
            jniLibs.srcDir("libs")
        }
    }
    packaging {
        resources.excludes.add("META-INF/*")
    }

    defaultConfig {
        applicationId = "com.example.myapplication"
        minSdk = 24
        targetSdk = 34
        versionCode = 1
        versionName = "1.0"
        testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            isMinifyEnabled = false
            proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro")
        }
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_1_8
        targetCompatibility = JavaVersion.VERSION_1_8
    }
}

dependencies {
    // libs.* are version-catalog aliases; make sure they are defined in gradle/libs.versions.toml
    implementation(libs.appcompat)
    implementation(libs.material)
    implementation(libs.activity)
    implementation(libs.constraintlayout)
    testImplementation(libs.junit)
    androidTestImplementation(libs.ext.junit)
    androidTestImplementation(libs.espresso.core)
    // PyTorch Mobile runtime and the torchvision helpers (TensorImageUtils)
    implementation("org.pytorch:pytorch_android:1.12.1")
    implementation("org.pytorch:pytorch_android_torchvision:1.12.1")
    implementation("com.google.android.exoplayer:exoplayer:2.14.1")
    implementation("androidx.localbroadcastmanager:localbroadcastmanager:1.0.0")
    implementation("androidx.activity:activity:1.2.0")
    implementation("androidx.fragment:fragment:1.3.0")
    implementation("de.hdodenhof:circleimageview:3.1.0")
}
The following block fixes the build error below:
packaging {
    resources.excludes.add("META-INF/*")
}
Caused by: com.android.builder.merge.DuplicateRelativeFileException: 2 files found with path 'META-INF/androidx.core_core.version'.
Excluding the files one at a time is tedious, because more than one file conflicts.
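After completing the steps above, click Sync Now to finish configuring the dependencies, then select a virtual device in the IDE.
Once that is done, you can run the project from MainActivity.java.
V. Exporting the APK
Click Create, and the APK file is generated.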