# TensorFlow Lite Android image classification example

This document walks through the code of a simple Android mobile application that
demonstrates
[image classification](https://www.tensorflow.org/lite/models/image_classification/overview)
using the device camera.

## Explore the code

We're now going to walk through the most important parts of the sample code.

### Get camera input

This mobile application gets the camera input using the functions defined in the
file
[`CameraActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java).
This file depends on
[`AndroidManifest.xml`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/AndroidManifest.xml)
to set the camera orientation.

`CameraActivity` also contains code to capture user preferences from the UI and
make them available to other classes via convenience methods.
```java
model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase());
device = Device.valueOf(deviceSpinner.getSelectedItem().toString());
numThreads = Integer.parseInt(threadsTextView.getText().toString().trim());
```
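
These values map onto simple enums defined alongside `Classifier`. A rough
sketch of what they might look like (the `Model` constant names below are
assumptions inferred from the calls above and from the `Model.QUANTIZED` check
shown later in this document; see `Classifier.java` for the actual definitions):

```java
/** Sketch of the enums backing the UI spinners (nested in Classifier in the actual app). */
public enum Model {
  FLOAT,     // hypothetical name for the floating point variant
  QUANTIZED  // matches the Model.QUANTIZED check shown later in this document
}

public enum Device {
  CPU,
  NNAPI,
  GPU
}
```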
### Classifier

This Image Classification Android reference app demonstrates two implementation
solutions:
[`lib_task_api`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api),
which leverages the out-of-box API from the
[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier),
and
[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support),
which builds a custom inference pipeline using the
[TensorFlow Lite Support Library](https://www.tensorflow.org/lite/inference_with_metadata/lite_support).

Both solutions implement the file `Classifier.java` (see
[the one in lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java)
and
[the one in lib_support](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java)),
which contains most of the complex logic for processing the camera input and
running inference.

Two subclasses of `Classifier` exist, `ClassifierFloatMobileNet.java` and
`ClassifierQuantizedMobileNet.java`, which contain settings for floating point
and
[quantized](https://www.tensorflow.org/lite/performance/post_training_quantization)
models, respectively.

The `Classifier` class implements a static method, `create`, which is used to
instantiate the appropriate subclass based on the supplied model type (quantized
vs. floating point).
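
A minimal sketch of such a factory method, with the signature inferred from the
call `Classifier.create(this, model, device, numThreads)` that appears later in
this document (illustrative, not the verbatim source):

```java
/** Sketch of a factory method that picks the subclass for the requested model type. */
public static Classifier create(Activity activity, Model model, Device device,
    int numThreads) throws IOException {
  if (model == Model.QUANTIZED) {
    return new ClassifierQuantizedMobileNet(activity, device, numThreads);
  } else {
    return new ClassifierFloatMobileNet(activity, device, numThreads);
  }
}
```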
#### Using the TensorFlow Lite Task Library

Inference can be done using just a few lines of code with the
[`ImageClassifier`](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier)
in the TensorFlow Lite Task Library.

##### Load model and create ImageClassifier

`ImageClassifier` expects a model populated with the
[model metadata](https://www.tensorflow.org/lite/convert/metadata) and the label
file. See the
[model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements)
for more details.
`ImageClassifierOptions` allows configuring various inference options, such as
setting the maximum number of top-scored results to return using
`setMaxResults(MAX_RESULTS)`, and setting the score threshold using
`setScoreThreshold(scoreThreshold)`.
```java
// Create the ImageClassifier instance.
ImageClassifierOptions options =
    ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build();
imageClassifier = ImageClassifier.createFromFileAndOptions(activity,
    getModelPath(), options);
```
`ImageClassifier` currently does not support configuring delegates or
multi-threading, but those features are on our roadmap. Please stay tuned!
##### Run inference

`ImageClassifier` contains built-in logic to preprocess the input image, such as
rotating and resizing it. Processing options can be configured through
`ImageProcessingOptions`. In the following example, input images are rotated to
the upright orientation and cropped to the center, as the model expects a square
input (`224x224`). See the
[Java doc of `ImageProcessingOptions`](https://github.com/tensorflow/tflite-support/blob/195b574f0aa9856c618b3f1ad87bd185cddeb657/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java#L22)
for more details about how the underlying image processing is performed.
```java
TensorImage inputImage = TensorImage.fromBitmap(bitmap);
int width = bitmap.getWidth();
int height = bitmap.getHeight();
int cropSize = Math.min(width, height);
ImageProcessingOptions imageOptions =
    ImageProcessingOptions.builder()
        .setOrientation(getOrientation(sensorOrientation))
        // Set the ROI to the center of the image.
        .setRoi(
            new Rect(
                /*left=*/ (width - cropSize) / 2,
                /*top=*/ (height - cropSize) / 2,
                /*right=*/ (width + cropSize) / 2,
                /*bottom=*/ (height + cropSize) / 2))
        .build();
List<Classifications> results = imageClassifier.classify(inputImage,
    imageOptions);
```
The output of `ImageClassifier` is a list of `Classifications` instances, where
each `Classifications` element is the result for a single classification head.
All the demo models are single-head models, so `results` contains only one
`Classifications` object. Use `Classifications.getCategories()` to get a list of
the top-k categories, as specified with `MAX_RESULTS`. Each `Category` object
contains the string label and the score of that category.
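
For example, the result list could be unpacked like this (a minimal sketch;
`getLabel()` and `getScore()` are the accessors on the Task Library's
`Category` class):

```java
// The demo models are single-head, so read the first (and only) head.
List<Category> categories = results.get(0).getCategories();
for (Category category : categories) {
  String label = category.getLabel();  // class name, e.g. "goldfish"
  float score = category.getScore();   // classification score
  Log.v("ImageClassifier", label + ": " + score);
}
```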
To match the implementation of
[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support),
`results` is converted into `List<Recognition>` in the method
`getRecognitions`.
#### Using the TensorFlow Lite Support Library

##### Load model and create interpreter

To perform inference, we need to load a model file and instantiate an
`Interpreter`. This happens in the constructor of the `Classifier` class, along
with loading the list of class labels. Information about the device type and
number of threads is used to configure the `Interpreter` via the
`Interpreter.Options` instance passed into its constructor. Note that if a GPU,
DSP (Digital Signal Processor) or NPU (Neural Processing Unit) is available, a
[`Delegate`](https://www.tensorflow.org/lite/performance/delegates) can be used
to take full advantage of this hardware.

Please note that there are performance edge cases, and developers are advised to
test with a representative set of devices prior to production.
```java
protected Classifier(Activity activity, Device device, int numThreads)
    throws IOException {
  tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
  switch (device) {
    case NNAPI:
      nnApiDelegate = new NnApiDelegate();
      tfliteOptions.addDelegate(nnApiDelegate);
      break;
    case GPU:
      gpuDelegate = new GpuDelegate();
      tfliteOptions.addDelegate(gpuDelegate);
      break;
    case CPU:
      break;
  }
  tfliteOptions.setNumThreads(numThreads);
  tflite = new Interpreter(tfliteModel, tfliteOptions);
  labels = FileUtil.loadLabels(activity, getLabelPath());
  ...
```
For Android devices, we recommend pre-loading and memory mapping the model file
to offer faster load times and reduce the dirty pages in memory. The method
`FileUtil.loadMappedFile` does this, returning a `MappedByteBuffer` containing
the model.

The `MappedByteBuffer` is passed into the `Interpreter` constructor, along with
an `Interpreter.Options` object. This object can be used to configure the
interpreter, for example by setting the number of threads (`.setNumThreads(1)`)
or enabling [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks)
(`.addDelegate(nnApiDelegate)`).
##### Pre-process bitmap image

Next, the input camera bitmap image is converted to a `TensorImage` for
efficient processing and pre-processed. The steps are shown in the private
`loadImage` method:
```java
/** Loads input image, and applies preprocessing. */
private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) {
  // Loads bitmap into a TensorImage.
  inputImageBuffer.load(bitmap);

  // Creates processor for the TensorImage.
  int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
  int numRotation = sensorOrientation / 90;
  ImageProcessor imageProcessor =
      new ImageProcessor.Builder()
          .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
          .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.BILINEAR))
          .add(new Rot90Op(numRotation))
          .add(getPreprocessNormalizeOp())
          .build();
  return imageProcessor.process(inputImageBuffer);
}
```
The pre-processing is largely the same for quantized and float models with one
exception: normalization.

In `ClassifierFloatMobileNet`, the normalization parameters are defined as:

```java
private static final float IMAGE_MEAN = 127.5f;
private static final float IMAGE_STD = 127.5f;
```
In `ClassifierQuantizedMobileNet`, normalization is not required. Thus the
normalization parameters are defined as:

```java
private static final float IMAGE_MEAN = 0.0f;
private static final float IMAGE_STD = 1.0f;
```
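
Each subclass hands these constants to the pre-processing pipeline via the
`getPreprocessNormalizeOp()` hook seen in `loadImage` above. A sketch of how
that hook might be implemented with the Support Library's `NormalizeOp`, which
maps each pixel value `v` to `(v - mean) / std`:

```java
@Override
protected TensorOperator getPreprocessNormalizeOp() {
  // For the quantized model (mean 0.0f, std 1.0f) this is an identity op,
  // so pixel values pass through unchanged.
  return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
}
```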
##### Allocate output object

Instantiate the output `TensorBuffer` for the output of the model.

```java
/** Output probability TensorBuffer. */
private final TensorBuffer outputProbabilityBuffer;

//...

// Get the array size for the output buffer from the TensorFlow Lite model file.
int probabilityTensorIndex = 0;
int[] probabilityShape =
    tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, 1001}
DataType probabilityDataType =
    tflite.getOutputTensor(probabilityTensorIndex).dataType();

// Creates the output tensor and its processor.
outputProbabilityBuffer =
    TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);

// Creates the post processor for the output probability.
probabilityProcessor =
    new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build();
```
For quantized models, we need to de-quantize the prediction with the
`NormalizeOp` (de-quantization is essentially a linear transformation). For
float models, de-quantization is not required, but to keep the API uniform it is
applied to float models too, with mean and std set to 0.0f and 1.0f,
respectively. To be more specific:

In `ClassifierQuantizedMobileNet`, the normalization parameters are defined as:

```java
private static final float PROBABILITY_MEAN = 0.0f;
private static final float PROBABILITY_STD = 255.0f;
```

In `ClassifierFloatMobileNet`, the normalization parameters are defined as:

```java
private static final float PROBABILITY_MEAN = 0.0f;
private static final float PROBABILITY_STD = 1.0f;
```
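
As with pre-processing, these constants would feed the post-processing pipeline
through the `getPostprocessNormalizeOp()` hook used when building
`probabilityProcessor` above. A sketch:

```java
@Override
protected TensorOperator getPostprocessNormalizeOp() {
  // For the quantized model this rescales uint8 scores from [0, 255] into
  // [0, 1]; for the float model (mean 0.0f, std 1.0f) it is an identity op.
  return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD);
}
```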
##### Run inference

Inference is performed using the following code in the `Classifier` class:

```java
tflite.run(inputImageBuffer.getBuffer(),
    outputProbabilityBuffer.getBuffer().rewind());
```
##### Recognize image

Rather than call `run` directly, the method `recognizeImage` is used. It accepts
a bitmap and sensor orientation, runs inference, and returns a sorted `List` of
`Recognition` instances, each corresponding to a label. The method will return a
number of results bounded by `MAX_RESULTS`, which is 3 by default.
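
Putting the pieces of this section together, `recognizeImage` might look roughly
like this (a condensed sketch assembled from the fragments shown above and
below, not the verbatim source):

```java
public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
  // Pre-process the camera frame (crop, resize, rotate, normalize).
  inputImageBuffer = loadImage(bitmap, sensorOrientation);

  // Run inference.
  tflite.run(inputImageBuffer.getBuffer(),
      outputProbabilityBuffer.getBuffer().rewind());

  // Map the (de-quantized) probabilities to labels and keep the top results.
  Map<String, Float> labeledProbability =
      new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
          .getMapWithFloatValue();
  return getTopKProbability(labeledProbability);
}
```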
`Recognition` is a simple class that contains information about a specific
recognition result, including its `title` and `confidence`. Using the specified
post-processing normalization, the confidence is converted to a probability
between 0 and 1 that the image represents the given class.
```java
/** Gets the label to probability map. */
Map<String, Float> labeledProbability =
    new TensorLabel(labels,
            probabilityProcessor.process(outputProbabilityBuffer))
        .getMapWithFloatValue();
```

A `PriorityQueue` is used for sorting.
```java
/** Gets the top-k results. */
private static List<Recognition> getTopKProbability(
    Map<String, Float> labelProb) {
  // Find the best classifications.
  PriorityQueue<Recognition> pq =
      new PriorityQueue<>(
          MAX_RESULTS,
          new Comparator<Recognition>() {
            @Override
            public int compare(Recognition lhs, Recognition rhs) {
              // Intentionally reversed to put high confidence at the head of
              // the queue.
              return Float.compare(rhs.getConfidence(), lhs.getConfidence());
            }
          });
  for (Map.Entry<String, Float> entry : labelProb.entrySet()) {
    pq.add(new Recognition("" + entry.getKey(), entry.getKey(),
        entry.getValue(), null));
  }

  final ArrayList<Recognition> recognitions = new ArrayList<>();
  int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
  for (int i = 0; i < recognitionsSize; ++i) {
    recognitions.add(pq.poll());
  }
  return recognitions;
}
```
### Display results

The classifier is invoked and inference results are displayed by the
`processImage()` function in
[`ClassifierActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java).

`ClassifierActivity` is a subclass of `CameraActivity` that contains method
implementations that render the camera image, run classification, and display
the results. The method `processImage()` runs classification on a background
thread as fast as possible, rendering information on the UI thread to avoid
blocking inference and creating latency.
```java
@Override
protected void processImage() {
  rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth,
      previewHeight);
  final int imageSizeX = classifier.getImageSizeX();
  final int imageSizeY = classifier.getImageSizeY();
  runInBackground(
      new Runnable() {
        @Override
        public void run() {
          if (classifier != null) {
            final long startTime = SystemClock.uptimeMillis();
            final List<Classifier.Recognition> results =
                classifier.recognizeImage(rgbFrameBitmap, sensorOrientation);
            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
            LOGGER.v("Detect: %s", results);

            runOnUiThread(
                new Runnable() {
                  @Override
                  public void run() {
                    showResultsInBottomSheet(results);
                    showFrameInfo(previewWidth + "x" + previewHeight);
                    showCropInfo(imageSizeX + "x" + imageSizeY);
                    showCameraResolution(imageSizeX + "x" + imageSizeY);
                    showRotationInfo(String.valueOf(sensorOrientation));
                    showInference(lastProcessingTimeMs + "ms");
                  }
                });
          }
          readyForNextImage();
        }
      });
}
```
Another important role of `ClassifierActivity` is to determine user preferences
(by interrogating `CameraActivity`) and instantiate the appropriately configured
`Classifier` subclass. This happens when the video feed begins (via
`onPreviewSizeChosen()`) and when options are changed in the UI (via
`onInferenceConfigurationChanged()`).
```java
private void recreateClassifier(Model model, Device device, int numThreads) {
  if (classifier != null) {
    LOGGER.d("Closing classifier.");
    classifier.close();
    classifier = null;
  }
  if (device == Device.GPU && model == Model.QUANTIZED) {
    LOGGER.d("Not creating classifier: GPU doesn't support quantized models.");
    runOnUiThread(
        () -> {
          Toast.makeText(this, "GPU does not yet support quantized models.",
                  Toast.LENGTH_LONG)
              .show();
        });
    return;
  }
  try {
    LOGGER.d(
        "Creating classifier (model=%s, device=%s, numThreads=%d)", model,
        device, numThreads);
    classifier = Classifier.create(this, model, device, numThreads);
  } catch (IOException e) {
    LOGGER.e(e, "Failed to create classifier.");
  }
}
```