My .tflite model works in Python, but it doesn't work in my Android project.
As far as I can tell, there is no difference between the Python and Android code.
What kind of layer or function changes the dimensions of the input?
Standalone code to reproduce the issue
Model definition (Python):
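from tensorflow.keras import Input, Sequential
from tensorflow.keras.layers import GRU, Dense, Dropout, Embedding

model = Sequential()
model.add(Input([256], dtype="int32"))
model.add(Embedding(35000, 10))
model.add(GRU(10))
# model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))
model.compile(loss='binary_crossentropy', optimizer="adam", metrics=['accuracy'])
model.summary()  # summary() prints the table itself and returns None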
Model: "sequential_43"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ Layer (type)             ┃ Output Shape    ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ embedding_35 (Embedding) │ (None, 256, 10) │ 350,000 │
├──────────────────────────┼─────────────────┼─────────┤
│ gru_33 (GRU)             │ (None, 10)      │     660 │
├──────────────────────────┼─────────────────┼─────────┤
│ dense_27 (Dense)         │ (None, 2)       │      22 │
└──────────────────────────┴─────────────────┴─────────┘
Total params: 350,682 (1.34 MB)
Trainable params: 350,682 (1.34 MB)
Non-trainable params: 0 (0.00 B)
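The summary shows which layers change the number of dimensions: Embedding appends a feature axis, (None, 256) -> (None, 256, 10), and GRU (with the default return_sequences=False) consumes the time axis, (None, 256, 10) -> (None, 10). A minimal Java sketch of that shape bookkeeping follows. ShapeTrace and its methods are hypothetical names that only mimic the Keras-level shape rules; the converted .tflite graph lowers the GRU into its own ops (the log mentions a TRANSPOSE node), so this is an illustration, not TFLite code.

```java
import java.util.Arrays;

// Hypothetical helper (not TFLite or Keras code): propagates shapes through the
// three layers of the Sequential model above using the usual Keras shape rules.
final class ShapeTrace {

    // Embedding appends the embedding size as a new last axis.
    static int[] embedding(int[] in, int outputDim) {
        int[] out = Arrays.copyOf(in, in.length + 1);
        out[in.length] = outputDim;
        return out;
    }

    // GRU with return_sequences=False expects (batch, time, features)
    // and drops the time axis, returning (batch, units).
    static int[] gru(int[] in, int units) {
        if (in.length != 3) {
            throw new IllegalArgumentException(
                "GRU expects a rank-3 input, got rank " + in.length);
        }
        return new int[] {in[0], units};
    }

    // Dense replaces the last axis with the number of units.
    static int[] dense(int[] in, int units) {
        int[] out = Arrays.copyOf(in, in.length);
        out[in.length - 1] = units;
        return out;
    }

    // Embedding(35000, 10) -> GRU(10) -> Dense(2), as in the model above.
    static int[] trace(int[] inputShape) {
        return dense(gru(embedding(inputShape, 10), 10), 2);
    }

    public static void main(String[] args) {
        // Batched input, matching the .tflite input shape [1, 256].
        System.out.println(Arrays.toString(trace(new int[] {1, 256}))); // [1, 2]
        // Unbatched input [256]: Embedding yields [256, 10], which is only rank 2.
        try {
            trace(new int[] {256});
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // GRU expects a rank-3 input, got rank 2
        }
    }
}
```

The point of the sketch is that the batch axis matters: without it, every downstream layer sees one dimension fewer than it was built for.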
Android code (this is where the error occurs):
// Needs java.util.HashMap, java.util.Map, android.os.SystemClock and android.util.Log in the enclosing class.
private AnalysisResult analyzeTextTFLite(String text) {
    // Convert the text to custom token IDs.
    Feature feature = featureConverter.convert(text, ADD_SPECIAL_TOKENS);
    int curSeqLen = feature.inputIds.length;

    // Copy the IDs into a fixed-length buffer; unused positions stay 0.
    // Math.min avoids an ArrayIndexOutOfBoundsException for texts longer than 256 tokens.
    int[] tfInputs = new int[256];
    for (int j = 0; j < Math.min(curSeqLen, tfInputs.length); j++) {
        tfInputs[j] = feature.inputIds[j];
    }

    // Signature input/output names of the converted model.
    Map<String, Object> inputsMap = new HashMap<>();
    Map<String, Object> outputMap = new HashMap<>();
    inputsMap.put("keras_tensor_198", tfInputs);
    float[][] logits = new float[1][2];
    outputMap.put("output_0", logits);

    final long moduleForwardStartTime = SystemClock.elapsedRealtime();
    tflite.runSignature(inputsMap, outputMap);  // this call throws the exception shown in the log below
    Log.d(TAG, "Model inference score : " + logits[0][0] + "," + logits[0][1]);

    float[] scores = new float[2];
    scores[0] = argmax(logits[0]);
    final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
    return new AnalysisResult(scores, moduleForwardDuration);
}
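In Python, interpreter.set_tensor rejects an array whose shape does not match the input tensor, so the Python input is presumably shaped [1, 256], the 'shape' reported by input_details. tfInputs in the Android code, by contrast, is a one-dimensional int[256]. A minimal sketch of building a two-dimensional int[1][256] input on the Java side, assuming the tokenizer returns an int[] of IDs and that 0 is the padding ID used in training. TfliteInputs, toBatchOfOne and argmax are hypothetical names; the original method calls an argmax helper that is not shown, so this version is only a guess at its behavior.

```java
import java.util.Arrays;

// Hypothetical helpers (not part of the TFLite API): pack token IDs into a
// batch of one and pick the winning class from the [1][2] output.
final class TfliteInputs {

    static final int MAX_LEN = 256; // sequence length from input_details: [1, 256]

    // Copies at most MAX_LEN ids into row 0; the remaining positions stay 0,
    // which assumes 0 is the padding id used when the model was trained.
    static int[][] toBatchOfOne(int[] inputIds) {
        int[][] batch = new int[1][MAX_LEN];
        int n = Math.min(inputIds.length, MAX_LEN);
        System.arraycopy(inputIds, 0, batch[0], 0, n);
        return batch;
    }

    // Index of the largest score; ties resolve to the lowest index.
    static int argmax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[][] input = toBatchOfOne(new int[] {2, 517, 3});
        System.out.println(input.length + " x " + input[0].length); // 1 x 256
        System.out.println(Arrays.toString(Arrays.copyOf(input[0], 5))); // [2, 517, 3, 0, 0]
        System.out.println(argmax(new float[] {0.2f, 0.8f})); // 1
    }
}
```

With a helper like this, the value put into inputsMap under "keras_tensor_198" would be toBatchOfOne(feature.inputIds), whose shape matches the [1, 256] the model reports.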
Python code (this works):
input_details = interpreter.get_input_details()
print(input_details)
"""
[{'name': 'serving_default_keras_tensor_198:0',
  'index': 0,
  'shape': array([ 1, 256]),
  'shape_signature': array([ -1, 256]),
  'dtype': numpy.int32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
                              'zero_points': array([], dtype=int32),
                              'quantized_dimension': 0},
  'sparsity_parameters': {}}]
"""
output_details = interpreter.get_output_details()
print(output_details)
"""
[{'name': 'StatefulPartitionedCall_1:0',
  'index': 36,
  'shape': array([1, 2]),
  'shape_signature': array([-1, 2]),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
                              'zero_points': array([], dtype=int32),
                              'quantized_dimension': 0},
  'sparsity_parameters': {}}]
"""
interpreter.set_tensor(input_details[0]['index'], encode_plus_inputs["input_ids"])
interpreter.invoke()
# Read the output only after invoke(); before that the tensor holds no result.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
Relevant log output
FATAL EXCEPTION: ModuleActivity
Process: org.pytorch.demo, PID: 32712
java.lang.IllegalArgumentException: Internal error: Failed to run on the given Interpreter: tensorflow/lite/kernels/transpose.cc:63 op_context->perm->dims->data[0] != dims (3 != 2)
Node number 10 (TRANSPOSE) failed to prepare.
at org.tensorflow.lite.NativeInterpreterWrapper.run(Native Method)
at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:264)
at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:194)
at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:271)
at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:284)
at org.pytorch.demo.nlp.NSMCPytorchActivity.analyzeTextKoElectraTFLite(NSMCPytorchActivity.java:281)
at org.pytorch.demo.nlp.NSMCPytorchActivity.analyzeText(NSMCPytorchActivity.java:201)
at org.pytorch.demo.nlp.NSMCPytorchActivity.lambda$new$2$org-pytorch-demo-nlp-NSMCPytorchActivity(NSMCPytorchActivity.java:154)
at org.pytorch.demo.nlp.NSMCPytorchActivity$$ExternalSyntheticLambda2.run(Unknown Source:4)
at android.os.Handler.handleCallback(Handler.java:958)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loopOnce(Looper.java:230)
at android.os.Looper.loop(Looper.java:319)
at android.os.HandlerThread.run(HandlerThread.java:67)
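The message says a TRANSPOSE node received a rank-2 tensor while its permutation has 3 entries (3 != 2). One plausible reading, not verified against this exact model: as I understand the TFLite Java binding, the Interpreter takes the input shape from the dimensions of the Java array passed to runSignature and resizes the input tensor when it differs. A one-dimensional int[256] would then turn the [1, 256] input into [256], so the Embedding output becomes [256, 10] (rank 2) instead of [1, 256, 10]. The sketch below only illustrates how a Java array's nesting maps to a shape; JavaArrayShape and shapeOf are hypothetical names, not the library's implementation.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper (not TFLite code): computes the shape described by a
// non-ragged Java array by walking element 0 of each nesting level.
final class JavaArrayShape {

    static List<Integer> shapeOf(Object array) {
        List<Integer> shape = new ArrayList<>();
        Object current = array;
        while (current != null && current.getClass().isArray()) {
            int length = Array.getLength(current);
            shape.add(length);
            if (length == 0) {
                break; // cannot look inside an empty level
            }
            current = Array.get(current, 0); // primitives come back boxed, ending the loop
        }
        return shape;
    }

    public static void main(String[] args) {
        System.out.println(shapeOf(new int[256]));    // [256]    (rank 1)
        System.out.println(shapeOf(new int[1][256])); // [1, 256] (rank 2)
        System.out.println(shapeOf(new float[1][2])); // [1, 2]
    }
}
```

Under that reading, the output buffer float[1][2] already has the expected rank, and only the input array is missing the batch dimension.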
I think my model should work in the Android code as well, since it works in Python.
심인용 is a new contributor to this site.