I trained a Keypoint R-CNN detection model, but when I deploy it on Android the app crashes on startup with this error:
Process: com.example.java_pytorch, PID: 5028
java.lang.RuntimeException: Unable to start activity ComponentInfo{com.example.java_pytorch/com.example.java_pytorch.MainActivity}: com.facebook.jni.CppException:
Unknown builtin op: torchvision::nms.
Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript.
:
File "code/torch/torchvision/models/detection/keypoint_rcnn/___torch_mangle_501.py", line 783
_360 = torch.index(boxes4, _359)
_361 = annotate(List[Optional[Tensor]], [curr_indices])
curr_keep_indices = ops.torchvision.nms(_360, torch.index(scores1, _361), 0.69999999999999996)
~~~~~~~~~~~~~~~~~~~ <--- HERE
_362 = annotate(List[Optional[Tensor]], [curr_keep_indices])
_363 = torch.index(curr_indices, _362)
I have already added these dependencies to my app's build.gradle:
implementation 'org.pytorch:pytorch_android:1.13.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.13.0'
implementation 'org.pytorch:torchvision_ops:0.13.0'
implementation 'com.facebook.soloader:nativeloader:0.8.0'
Can anyone help me understand and resolve this error?
I trained the model by following this tutorial: https://medium.com/@alexppppp/how-to-train-a-custom-keypoint-detection-model-with-pytorch-d9af90e111da. My Android source code is below:
package com.example.java_pytorch;
// MainActivity.java
import androidx.appcompat.app.AppCompatActivity;
import android.os.Bundle;
import android.Manifest;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.util.Log;
import android.view.SurfaceView;
import android.view.View;
import android.view.WindowManager;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.Toast;
import androidx.annotation.NonNull;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import org.opencv.android.BaseLoaderCallback;
import org.opencv.android.CameraBridgeViewBase;
import org.opencv.android.LoaderCallbackInterface;
import org.opencv.android.OpenCVLoader;
import org.opencv.core.Mat;
import org.opencv.android.Utils;
import org.pytorch.IValue;
//import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import android.content.Context;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
public class MainActivity extends AppCompatActivity implements CameraBridgeViewBase.CvCameraViewListener2 {
private static final int CAMERA_PERMISSION_REQUEST_CODE = 100;
private static final String TAG = "MainActivity";
private CameraBridgeViewBase mOpenCvCameraView;
private boolean mIsCameraPermissionGranted = false;
private Button mStartCameraButton;
private ImageView mImageView;
private Module mPyTorchModel;
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni");
}
static {
System.loadLibrary("opencv_java3"); // Adjust the name based on your OpenCV version
}
private BaseLoaderCallback mLoaderCallback = new BaseLoaderCallback(this) {
@Override
public void onManagerConnected(int status) {
switch (status) {
case LoaderCallbackInterface.SUCCESS:
mOpenCvCameraView.enableView();
break;
default:
super.onManagerConnected(status);
break;
}
}
};
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
mOpenCvCameraView = findViewById(R.id.camera_view);
mOpenCvCameraView.setVisibility(SurfaceView.VISIBLE);
mOpenCvCameraView.setCvCameraViewListener(this);
mStartCameraButton = findViewById(R.id.start_camera_button);
mStartCameraButton.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
startCamera();
}
});
mImageView = findViewById(R.id.image_view);
if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.CAMERA}, CAMERA_PERMISSION_REQUEST_CODE);
} else {
mIsCameraPermissionGranted = true;
}
// Load PyTorch model
mPyTorchModel = Module.load(assetFilePath(getApplicationContext(), "mobilenetv2_quantized.pt"));
}
private void startCamera() {
if (!OpenCVLoader.initDebug()) {
Log.d(TAG, "OpenCV library not found!");
OpenCVLoader.initAsync(OpenCVLoader.OPENCV_VERSION, this, mLoaderCallback);
} else {
Log.d(TAG, "OpenCV library found inside package. Using it!");
mLoaderCallback.onManagerConnected(LoaderCallbackInterface.SUCCESS);
}
}
@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
if (requestCode == CAMERA_PERMISSION_REQUEST_CODE) {
if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
mIsCameraPermissionGranted = true;
startCamera(); // Start camera after getting permission
} else {
Toast.makeText(this, "Camera permission is required for this app", Toast.LENGTH_SHORT).show();
finish();
}
}
}
@Override
protected void onResume() {
super.onResume();
if (mIsCameraPermissionGranted) {
startCamera(); // Start camera when the activity resumes
}
}
@Override
protected void onPause() {
super.onPause();
if (mOpenCvCameraView != null) {
mOpenCvCameraView.disableView();
}
}
@Override
protected void onDestroy() {
super.onDestroy();
if (mOpenCvCameraView != null) {
mOpenCvCameraView.disableView();
}
}
@Override
public void onCameraViewStarted(int width, int height) {}
@Override
public void onCameraViewStopped() {}
@Override
public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame inputFrame) {
Mat rgba = inputFrame.rgba();
// Convert Mat to Bitmap
Bitmap bitmap = Bitmap.createBitmap(rgba.cols(), rgba.rows(), Bitmap.Config.ARGB_8888);
Utils.matToBitmap(rgba, bitmap);
// Process bitmap with PyTorch model
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
try {
IValue output = mPyTorchModel.forward(IValue.from(inputTensor));
// Process output and draw keypoints on rgba Mat
} catch (Exception e) {
Log.e(TAG, "Error processing image with PyTorch model: " + e.getMessage());
}
return rgba;
}
private String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
try (InputStream is = context.getAssets().open(assetName)) {
try (FileOutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
}