I’m working on an image segmentation project in Flutter using TensorFlow Lite. My model is supposed to segment an image into 4 classes, but the output I’m getting only shows 2 classes. Here’s the relevant part of my code:
Future<Uint8List> predict(img.Image image) async {
if (interpreter == null) {
throw StateError('Cannot run inference, Interpreter is null');
}
print('Loading image into TensorImage...');
_inputImage = TensorImage(TfLiteType.float32);
_inputImage.loadImage(image);
print('Processing image...');
final ImageProcessor imageProcessor = ImageProcessorBuilder()
.add(ResizeOp(512, 512, ResizeMethod.NEAREST_NEIGHBOUR))
.build();
_inputImage = imageProcessor.process(_inputImage);
final input = [_inputImage.tensorBuffer.buffer.asFloat32List()];
// Prepare output tensor
final output = {
0: [List<List<List<double>>>.filled(512, List<List<double>>.filled(512, List<double>.filled(4, 0)))],
};
print('Running interpreter...');
interpreter.runForMultipleInputs([input], output);
print('Post-processing output...');
final processedOutput = postProcess2(output[0][0]);
print('Converting processed output to image...');
img.Image _image = convertProcessedOutputToImage3(processedOutput);
if (_image == null) {
throw StateError('Image conversion failed');
}
print('Encoding image to PNG...');
final Uint8List pngBytes = img.encodePng(_image);
await saveBytesToGallery(pngBytes);
print('Result image saved to gallery');
return pngBytes;
}
List<List<int>> postProcess2(List<List<List<double>>> rawOutput) {
final List<List<int>> processed = List.generate(512, (_) => List.filled(512, 0));
List<int> classCount = [0, 0, 0, 0];
for (int y = 0; y < 512; y++) {
for (int x = 0; x < 512; x++) {
List<double> pixelProbabilities = rawOutput[y][x];
int classIndex = argmax(pixelProbabilities);
processed[y][x] = classIndex;
classCount[classIndex]++;
if (y < 5 && x < 5) {
print("Pixel ($x, $y) probabilities: $pixelProbabilities, classified as: $classIndex");
}
}
}
print("Detailed class distribution:");
for (int i = 0; i < classCount.length; i++) {
print("Class $i: ${classCount[i]} pixels (${(classCount[i] / (512 * 512) * 100).toStringAsFixed(2)}%)");
}
return processed;
}
int argmax(List<double> list) {
return list.indexOf(list.reduce((a, b) => a > b ? a : b));
}
img.Image convertProcessedOutputToImage3(List<List<int>> processedOutput) {
final img.Image image = img.Image(512, 512);
for (int y = 0; y < 512; y++) {
for (int x = 0; x < 512; x++) {
switch (processedOutput[y][x]) {
case 0:
image.setPixel(x, y, img.getColor(0, 0, 0)); // Black for background
break;
case 1:
image.setPixel(x, y, img.getColor(128, 128, 128)); // Medium gray for main wound area
break;
case 2:
image.setPixel(x, y, img.getColor(192, 192, 192)); // Light gray for wound edges
break;
case 3:
image.setPixel(x, y, img.getColor(255, 255, 255)); // White for deeper/necrotic areas
break;
}
}
}
return image;
}
The output I’m getting shows this distribution:
Detailed class distribution:
Class 0: 0 pixels (0.00%)
Class 1: 0 pixels (0.00%)
Class 2: 3621 pixels (1.38%)
Class 3: 258523 pixels (98.62%)
As you can see, only classes 2 and 3 are present in the output, while classes 0 and 1 are completely absent.
I’ve checked the model’s output, and it does produce probabilities for all 4 classes. Here’s a sample of the raw output for a few pixels:
Pixel (0, 0) probabilities: [0.000976, 0.000002, 0.825438, 0.173583], classified as: 2
Pixel (1, 0) probabilities: [0.002672, 0.000000, 0.076512, 0.920817], classified as: 3
Pixel (2, 0) probabilities: [0.000085, 0.000000, 0.003305, 0.996610], classified as: 3
Is there anything in the predict function or interpreter initialization that could be affecting the output?