mediapipe/docs/solutions/object_detection_saved_model.md
Copybara-Service d6fb7c365e Merge pull request #4145 from lucifertrj:poseDocs
PiperOrigin-RevId: 516231033
2023-03-13 09:22:34 -07:00

TensorFlow/TFLite Object Detection Model

TensorFlow model

The model is trained on the MSCOCO 2014 dataset using the TensorFlow Object Detection API. It is a MobileNetV2-based SSD model with a 0.5 depth multiplier. The detailed training configuration is in the provided pipeline.config file. The model is relatively compact, with 0.171 mAP, so that it can achieve real-time performance on mobile devices. You can compare it with other models from the TensorFlow detection model zoo.

TFLite model

The TFLite model is converted from the TensorFlow model above. The conversion steps are similar to this tutorial, with minor modifications. Assume we have a trained TensorFlow model that includes the checkpoint files and the training configuration file, for example the files provided in this repo:

  • model.ckpt.index
  • model.ckpt.meta
  • model.ckpt.data-00000-of-00001
  • pipeline.config
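Before exporting, it can help to confirm that the model directory actually contains these files. The function below is a minimal, hypothetical pre-flight check (`check_model_dir` is not part of the Object Detection API). It also prints the checkpoint prefix, i.e. the shared stem of the three model.ckpt.* files, which is the form that --trained_checkpoint_prefix expects.

```shell
#!/usr/bin/env bash
# Hypothetical pre-flight check (not part of the TensorFlow Object
# Detection API): verify that a model directory contains the checkpoint
# and configuration files listed above.
check_model_dir() {
  local dir="$1"
  local f
  for f in model.ckpt.index model.ckpt.meta \
           model.ckpt.data-00000-of-00001 pipeline.config; do
    if [ ! -f "${dir}/${f}" ]; then
      echo "missing: ${dir}/${f}" >&2
      return 1
    fi
  done
  # The checkpoint prefix is the shared stem of the model.ckpt.* files.
  echo "${dir}/model.ckpt"
}
```

For example, `check_model_dir "${PATH_TO_MODEL}"` prints `path/to/the/model/model.ckpt` when all four files are present, and reports the first missing file otherwise.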

Make sure you have installed these Python libraries. Then, to get the frozen graph, run the export_tflite_ssd_graph.py script from the models/research directory with this command:

$ PATH_TO_MODEL=path/to/the/model
$ bazel run object_detection:export_tflite_ssd_graph -- \
    --pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
    --trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
    --output_directory ${PATH_TO_MODEL} \
    --add_postprocessing_op=False

The exported model contains two files:

  • tflite_graph.pb
  • tflite_graph.pbtxt

The difference between this step and the one in the tutorial is that we set add_postprocessing_op to False. MediaPipe provides all the calculators needed for post-processing, so we can exclude the custom TFLite post-processing ops of the original graph, e.g., non-maximum suppression. This makes it possible to integrate different post-processing algorithms and implementations.

Optional: You can install and use the graph tool to inspect the input/output of the exported model:

$ bazel run graph_transforms:summarize_graph -- \
    --in_graph=${PATH_TO_MODEL}/tflite_graph.pb

You should see that the input image size of the model is 320x320 and that the outputs of the model are:

  • raw_outputs/box_encodings
  • raw_outputs/class_predictions

The last step is to convert the model to TFLite. You can look at this guide for more detail. For this example, you just need to run:

$ tflite_convert \
  --graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
  --output_file=${PATH_TO_MODEL}/model.tflite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --inference_type=FLOAT \
  --input_shapes=1,320,320,3 \
  --input_arrays=normalized_input_image_tensor \
  --output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions

Now you have the TFLite model model.tflite ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.
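Before wiring the file into a graph, a quick sanity check can catch a failed or truncated conversion. A TFLite model is a FlatBuffer whose 4-byte file identifier "TFL3" sits at byte offset 4, so the minimal, hypothetical helper below (`is_tflite_file` is not a TFLite or MediaPipe tool) only checks that the file is non-empty and carries that identifier; it does not validate the graph itself.

```shell
#!/usr/bin/env bash
# Hypothetical sanity check: succeed only if the file is non-empty and
# bytes 5-8 are the TFLite FlatBuffer identifier "TFL3".
is_tflite_file() {
  local file="$1"
  [ -s "$file" ] || return 1
  [ "$(head -c 8 "$file" | tail -c 4)" = "TFL3" ]
}
```

For example: `is_tflite_file "${PATH_TO_MODEL}/model.tflite" && echo "model.tflite looks valid"`.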