关于将Pytorch模型部署到安卓移动端方法总结

07-19 1427阅读

一、Android Studio环境配置

1.安装包下载问题解决

在Android Studio官网下载编译工具时,会出现无法下载的问题,可右键复制下载链接IDMan中进行下载。

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

2.安装

安装过程中,需要将Android Virtual Device勾选,否则无法使用虚拟机。

关于将Pytorch模型部署到安卓移动端方法总结

安装启动后,会提示没有SDK,设置代码,直接选择cancel键。

关于将Pytorch模型部署到安卓移动端方法总结

完后,会有专门的SKD组件的安装,但是会有unavailable不可安装的情况出现,可通过创建项目后配置gradle后便可以安装了。

关于将Pytorch模型部署到安卓移动端方法总结

二、项目创建

软件安装后可能出现打不开的情况,可选择以管理员身份启动即可解决问题。

选择New Project

关于将Pytorch模型部署到安卓移动端方法总结

选择喜欢的界面样式即可。

关于将Pytorch模型部署到安卓移动端方法总结使用语言、SDK根据自行需求进行选择就行。

Build configuration language建议选择Kotlin DSL(build.gradle.kts)[Recommended],否则会出现缺少gradle文件的情况。

关于将Pytorch模型部署到安卓移动端方法总结

创建完后会出现如下项目目录,并不会直接出现app的文件夹,需要手动配置gradle。

关于将Pytorch模型部署到安卓移动端方法总结

按照如下目录gradle/wrapper/gradle-wrapper.properties修改distributionUrl为本地地址。(根据原先的地址下载对应的压缩包)

#Wed May 01 21:02:04 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
更变为
#Wed May 01 21:02:04 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
# 对应的gradle-8.4-bin.zip本地地址即可
distributionUrl=file:///D://Android//gradle-8.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

在settings.gradle.kts更换阿里源(直接复制粘贴即可)

pluginManagement {
    repositories {
        maven { url=uri ("https://www.jitpack.io")}
        maven { url=uri ("https://maven.aliyun.com/repository/releases")}
        maven { url=uri ("https://maven.aliyun.com/repository/google")}
        maven { url=uri ("https://maven.aliyun.com/repository/central")}
        maven { url=uri ("https://maven.aliyun.com/repository/gradle-plugin")}
        maven { url=uri ("https://maven.aliyun.com/repository/public")}
        google()
        mavenCentral()
        gradlePluginPortal()
    }
}
dependencyResolutionManagement {
    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
    repositories {
        maven { url=uri ("https://www.jitpack.io")}
        maven { url=uri ("https://maven.aliyun.com/repository/releases")}
        maven { url=uri ("https://maven.aliyun.com/repository/google")}
        maven { url=uri ("https://maven.aliyun.com/repository/central")}
        maven { url=uri ("https://maven.aliyun.com/repository/gradle-plugin")}
        maven { url=uri ("https://maven.aliyun.com/repository/public")}
        google()
        mavenCentral()
    }
}
rootProject.name = "Helloword"
include(":app")

在build.gradle.kts中点击sync now即可自动配置,稍等即可便可变成app文件夹的形式。

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

选择Project,变成全部文件的形式。

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

初始新建项目即刻完成。

三、训练模型权重转化

需将训练好的.pth文件转化为.pt文件

"""
该程序使用的是resnet32网络,用到其他网络可自行更改
保存的权重字典目录如下所示。
      ckpt = {
            'weight': model.state_dict(),
            'epoch': epoch,
            'cfg': opt.model,
            'index': name
        }
"""
from models.resnet_cifar import resnet32  # 确保引用你的正确模型架构
import torch
import torch.nn as nn
# 假设你的ResNet定义在resnet.py文件中
model = resnet32()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 100)  # 修改这里的100为你的类别数
# 加载权重
checkpoint = torch.load('modelleader_best.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['weight'], strict=False)  # 使用strict=False可以忽略不匹配的键
model.eval()
# 将模型转换为TorchScript
example_input = torch.rand(1, 3, 32, 32)  # 修改这里以匹配你的模型输入尺寸
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("model.pt")

四、Pytorch项目搭建工作

在如下目录下创建assets文件,将转化好的模型放在里面即可,切记不可直接创建文件夹,会出现找不到模型问题。

关于将Pytorch模型部署到安卓移动端方法总结

在com/example/myapplication下创建了两个类cifarClassed,MainActivity。

关于将Pytorch模型部署到安卓移动端方法总结

MainActivity类
package com.example.myapplication;
import android.content.Context;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import androidx.core.content.FileProvider;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
public class MainActivity extends AppCompatActivity {
    private static final int PERMISSION_REQUEST_CODE = 101;
    private static final int REQUEST_IMAGE_CAPTURE = 1;
    private static final int REQUEST_IMAGE_SELECT = 2;
    private ImageView imageView;
    private TextView textView;
    private Module module;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        // 检查相机权限
        if (ContextCompat.checkSelfPermission(this, android.Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{android.Manifest.permission.CAMERA}, PERMISSION_REQUEST_CODE);
        }
        imageView = findViewById(R.id.image);
        textView = findViewById(R.id.text);
        ImageView logoImageView = findViewById(R.id.logo);
        logoImageView.setImageResource(R.drawable.logo);
        Button takePhotoButton = findViewById(R.id.button_take_photo);
        Button selectImageButton = findViewById(R.id.button_select_image);
        takePhotoButton.setOnClickListener(v -> dispatchTakePictureIntent());
        selectImageButton.setOnClickListener(v -> dispatchGalleryIntent());
        try {
            module = Module.load(assetFilePath(this, "model.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
        }
    }
    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode == PERMISSION_REQUEST_CODE) {
            if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
                // 权限被授予
                Log.d("Permissions", "Camera permission granted");
            } else {
                // 权限被拒绝
                Log.d("Permissions", "Camera permission denied");
            }
        }
    }
    private void dispatchTakePictureIntent() {
        Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
        if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
            startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE);
        }
    }
    private void dispatchGalleryIntent() {
        Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
        startActivityForResult(intent, REQUEST_IMAGE_SELECT);
    }
    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode == RESULT_OK && (requestCode == REQUEST_IMAGE_CAPTURE || requestCode == REQUEST_IMAGE_SELECT)) {
            Bitmap imageBitmap = null;
            if (requestCode == REQUEST_IMAGE_CAPTURE) {
                Bundle extras = data.getExtras();
                imageBitmap = (Bitmap) extras.get("data");
            } else if (requestCode == REQUEST_IMAGE_SELECT) {
                try {
                    imageBitmap = MediaStore.Images.Media.getBitmap(this.getContentResolver(), data.getData());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            imageView.setImageBitmap(imageBitmap);
            classifyImage(imageBitmap);
        }
    }
//    private void classifyImage(Bitmap bitmap) {
//        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
//                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
//        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
//        float[] scores = outputTensor.getDataAsFloatArray();
//        float maxScore = -Float.MAX_VALUE;
//        int maxScoreIdx = -1;
//        for (int i = 0; i  maxScore) {
//                maxScore = scores[i];
//                maxScoreIdx = i;
//            }
//        }
//        textView.setText("推理结果:" + CifarClassed.IMAGENET_CLASSES[maxScoreIdx]);
//        textView.setVisibility(View.VISIBLE); // 设置 TextView 可见
//    }
//    private void classifyImage(Bitmap bitmap) {
//        // 调整图像大小为 32x32 像素
//        Bitmap resizedBitmap = resizeBitmap(bitmap, 32, 32);
//
//        // 将调整大小后的图像转换为 PyTorch Tensor
//        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap,
//                new float[]{0.485f, 0.456f, 0.406f}, // 均值 Mean
//                new float[]{0.229f, 0.224f, 0.225f}); // 标准差 Std
//
//        // 推理
//        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
//        float[] scores = outputTensor.getDataAsFloatArray();
//        float maxScore = -Float.MAX_VALUE;
//        int maxScoreIdx = -1;
//        for (int i = 0; i  maxScore) {
//                maxScore = scores[i];
//                maxScoreIdx = i;
//            }
//        }
//        textView.setText("推理结果:" + CifarClassed.IMAGENET_CLASSES[maxScoreIdx]);
//        textView.setVisibility(View.VISIBLE); // 设置 TextView 可见
//    }
//
    private float[] softmax(float[] scores) {
        float max = Float.NEGATIVE_INFINITY;
        for (float score : scores) {
            if (score > max) max = score;
        }
        float sum = 0.0f;
        float[] exps = new float[scores.length];
        for (int i = 0; i  maxScore) {
                maxScore = probabilities[i];
                maxScoreIdx = i;
            }
        }
        // 更新 UI 必须在主线程中完成
        final int maxIndex = maxScoreIdx;
        final float finalMaxScore = maxScore;
        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                textView.setText("推理结果:" + CifarClassed.IMAGENET_CLASSES[maxIndex] + " (" + String.format("%.2f%%", finalMaxScore * 100) + ")");
                textView.setVisibility(View.VISIBLE); // 设置 TextView 可见
            }
        });
    }
///
    //
    // 方法来调整 Bitmap 的大小
    private Bitmap resizeBitmap(Bitmap originalBitmap, int targetWidth, int targetHeight) {
        return Bitmap.createScaledBitmap(originalBitmap, targetWidth, targetHeight, false);
    }
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }
        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}
CifarClassed类
package com.example.myapplication;
public class CifarClassed {
    public static String[] IMAGENET_CLASSES = new String[]{
            "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle",
            "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel",
            "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock",
            "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
            "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house",
            "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion", "lizard",
            "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse", "mushroom",
            "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck",
            "pine_tree", "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon",
            "ray", "road", "rocket", "rose", "sea", "seal", "shark", "shrew", "skunk",
            "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflower",
            "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor",
            "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf",
            "woman", "worm"
    };
}

页面布局存放在MyApplication\app\src\main\res\layout\activity_main.xml文件中。

    
        
        
         
    
    
    
        
        
    
    
    
    

在MyApplication\app\src\main\res\drawable\circle_shape.xml(自行创建)


      
      

在MyApplication\app\src\main\res\drawable\rounded_background(自行创建)


      
      

在MyApplication\app\src\main\AndroidManifest.xml添加相机与读取照片的权限。


    
    
    

    
        
            
                
                
            
        
    

app级别build.gradle.kts(MyApplication\app\build.gradle.kts)配置如下。

plugins {
    alias(libs.plugins.androidApplication)
}
android {
    namespace = "com.example.myapplication"
    compileSdk = 34
    sourceSets {
        getByName("main") {
            jniLibs.srcDir("libs")
        }
    }
    packaging {
        resources.excludes.add("META-INF/*")
    }
    defaultConfig {
        applicationId = "com.example.myapplication"
        minSdk = 24
        targetSdk = 34
        versionCode = 1
        versionName = "1.0"
        testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
    }
    buildTypes {
        release {
            isMinifyEnabled = false
            proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro")
        }
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_1_8
        targetCompatibility = JavaVersion.VERSION_1_8
    }
}
dependencies {
    // 使用 alias 来指定库,确保 libs.aliases.gradle 中已经定义了这些别名
    implementation(libs.appcompat)
    implementation(libs.material)
    implementation(libs.activity)
    implementation(libs.constraintlayout)
    testImplementation(libs.junit)
    androidTestImplementation(libs.ext.junit)
    androidTestImplementation(libs.espresso.core)
    implementation("org.pytorch:pytorch_android:1.12.1")
    implementation("org.pytorch:pytorch_android_torchvision:1.12.1")
    implementation("com.google.android.exoplayer:exoplayer:2.14.1")
    implementation("androidx.localbroadcastmanager:localbroadcastmanager:1.0.0")
    implementation("androidx.activity:activity:1.2.0")
    implementation("androidx.fragment:fragment:1.3.0")
    implementation("de.hdodenhof:circleimageview:3.1.0")

}

这段可解决如下bug。

    packaging {
        resources.excludes.add("META-INF/*")
    }
Caused by: com.android.builder.merge.DuplicateRelativeFileException: 2 files found with path ‘META-INF/androidx.core_core.version’.

手动添加非常麻烦,因为不止一个文件冲突!!!

完成以上步骤再按下Sync Now完成依赖的配置工作,需在编译器中自行选择虚拟设备。

关于将Pytorch模型部署到安卓移动端方法总结关于将Pytorch模型部署到安卓移动端方法总结关于将Pytorch模型部署到安卓移动端方法总结

完成后即可在MainActivity.java文件启动项目。

五、APK安装包导出

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

关于将Pytorch模型部署到安卓移动端方法总结

 点击create创建即可,便可得到apk文件。

六、效果图

关于将Pytorch模型部署到安卓移动端方法总结

VPS购买请点击我

文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。

目录[+]