Merge branch 'google:master' into image-classification-python-impl

Kinar R 2022-10-14 15:31:01 +05:30 committed by GitHub
commit f160f28039
67 changed files with 4607 additions and 234 deletions

View File

@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow
```bash
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt
```
This will open up your webcam as long as it is connected and on. Any errors

View File

@ -209,11 +209,18 @@ cc_library(
alwayslink = 1,
)
mediapipe_proto_library(
name = "rotation_mode_proto",
srcs = ["rotation_mode.proto"],
visibility = ["//visibility:public"],
)
mediapipe_proto_library(
name = "image_transformation_calculator_proto",
srcs = ["image_transformation_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
":rotation_mode_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
@ -238,6 +245,7 @@ cc_library(
}),
visibility = ["//visibility:public"],
deps = [
":rotation_mode_cc_proto",
":image_transformation_calculator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/calculators/image/rotation_mode.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"

View File

@ -16,20 +16,10 @@ syntax = "proto2";
package mediapipe;
import "mediapipe/calculators/image/rotation_mode.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";
// Counterclockwise rotation.
message RotationMode {
enum Mode {
UNKNOWN = 0;
ROTATION_0 = 1;
ROTATION_90 = 2;
ROTATION_180 = 3;
ROTATION_270 = 4;
}
}
message ImageTransformationCalculatorOptions {
extend CalculatorOptions {
optional ImageTransformationCalculatorOptions ext = 251952830;

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
// Counterclockwise rotation.
message RotationMode {
enum Mode {
UNKNOWN = 0;
ROTATION_0 = 1;
ROTATION_90 = 2;
ROTATION_180 = 3;
ROTATION_270 = 4;
}
}
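
Since `RotationMode` now lives in its own `rotation_mode.proto` (with a dedicated `rotation_mode_proto` build target), dependents can include it directly. A minimal C++ sketch, assuming `ImageTransformationCalculatorOptions` keeps its existing `rotation_mode` field:

```cpp
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/calculators/image/rotation_mode.pb.h"

int main() {
  mediapipe::ImageTransformationCalculatorOptions options;
  // Request a 90-degree counterclockwise rotation. The field name is assumed
  // from the pre-existing options proto; only the message's location changed.
  options.set_rotation_mode(mediapipe::RotationMode::ROTATION_90);
  return 0;
}
```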

View File

@ -14,15 +14,10 @@ cc_library(
name = "builder",
hdrs = ["builder.h"],
deps = [
":const_str",
":contract",
":node",
":packet",
":port",
"//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_contract",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)

View File

@ -5,12 +5,7 @@
#include <type_traits>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_contract.h"
@ -120,6 +115,9 @@ using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
} // namespace internal_builder
template <bool IsSide, typename T = internal::Generic>
class SourceImpl;
// These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API.
template <bool IsSide, typename T = internal::Generic>
@ -137,10 +135,14 @@ class DestinationImpl {
return DestinationImpl<IsSide, U>(&base_);
}
private:
DestinationBase& base_;
template <bool Source_IsSide, typename Source_T>
friend class SourceImpl;
};
template <bool IsSide, typename T = internal::Generic>
template <bool IsSide, typename T>
class SourceImpl {
public:
using Base = SourceBase;
@ -438,8 +440,9 @@ class Graph {
// Creates a node of a specific type. Should be used for pure interfaces,
// which do not have a built-in type string.
template <class Calc>
Node<Calc>& AddNode(const std::string& type) {
auto node = std::make_unique<Node<Calc>>(type);
Node<Calc>& AddNode(absl::string_view type) {
auto node =
std::make_unique<Node<Calc>>(std::string(type.data(), type.size()));
auto node_p = node.get();
nodes_.emplace_back(std::move(node));
return *node_p;
@ -447,16 +450,18 @@ class Graph {
// Creates a generic node, with no compile-time checking of inputs and
// outputs. This can be used for calculators whose contract is not visible.
GenericNode& AddNode(const std::string& type) {
auto node = std::make_unique<GenericNode>(type);
GenericNode& AddNode(absl::string_view type) {
auto node =
std::make_unique<GenericNode>(std::string(type.data(), type.size()));
auto node_p = node.get();
nodes_.emplace_back(std::move(node));
return *node_p;
}
// For legacy PacketGenerators.
PacketGenerator& AddPacketGenerator(const std::string& type) {
auto node = std::make_unique<PacketGenerator>(type);
PacketGenerator& AddPacketGenerator(absl::string_view type) {
auto node = std::make_unique<PacketGenerator>(
std::string(type.data(), type.size()));
auto node_p = node.get();
packet_gens_.emplace_back(std::move(node));
return *node_p;
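
The `AddNode` overloads now accept `absl::string_view` and copy it into an owned `std::string`, so callers can pass literals or views without materializing a `std::string` first. A minimal sketch of the builder API these entry points serve; the calculator name and stream tags are placeholders:

```cpp
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"

mediapipe::CalculatorGraphConfig BuildGraph() {
  mediapipe::api2::builder::Graph graph;
  // The type string may be any string-like value convertible to
  // absl::string_view; "MyCalculator" is a hypothetical registered type.
  auto& node = graph.AddNode("MyCalculator");
  graph.In("IMAGE") >> node.In("IMAGE");
  node.Out("IMAGE") >> graph.Out("IMAGE");
  return graph.GetConfig();
}
```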

View File

@ -30,3 +30,10 @@ android_library(
"@maven//:com_google_guava_guava",
],
)
# Expose the java source files for building mediapipe AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -164,7 +164,10 @@ EOF
assets_dir = assets_dir,
)
_aar_with_jni(name, name + "_android_lib")
mediapipe_build_aar_with_jni(
name = name,
android_library = name + "_android_lib",
)
def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
"""Generates MediaPipe jni library.
@ -203,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
alwayslink = 1,
)
def _aar_with_jni(name, android_library):
def mediapipe_build_aar_with_jni(name, android_library):
"""Builds MediaPipe AAR with jni.
Args:
name: The bazel target name.
android_library: the android library that contains jni.
"""
# Generates dummy AndroidManifest.xml for dummy apk usage
# (dummy apk is generated by <name>_dummy_app target below)
native.genrule(
@ -214,7 +224,7 @@ cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
package="dummy.package.for.so">
<uses-sdk android:minSdkVersion="21"/>
<uses-sdk android:minSdkVersion="24"/>
</manifest>
EOF
""",
@ -241,6 +251,7 @@ chmod +w $(location :{}.aar)
origdir=$$PWD
cd $$(mktemp -d)
unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
find lib -name *_dummy_app.so -delete
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
@ -287,6 +298,36 @@ def mediapipe_java_proto_srcs(name = ""):
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:stream_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StreamHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_factory_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketFactoryProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_generator_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:status_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StatusHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:mediapipe_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",

View File

@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
)

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training models. Shared across tasks."""
import dataclasses
import tempfile
from typing import Optional
# TODO: Integrate this class into ImageClassifier and other tasks.
@dataclasses.dataclass
class BaseHParams:
"""Hyperparameters used for training models.
A common set of hyperparameters shared by the training jobs of all model
maker tasks.
Attributes:
learning_rate: The learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
steps_per_epoch: An optional integer indicating the number of training steps
per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size.
shuffle: True if the dataset is shuffled before training.
export_dir: The location of the model checkpoint files.
distribution_strategy: A string specifying which Distribution Strategy to
use. Accepted values are 'off', 'one_device', 'mirrored',
'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
insensitive. 'off' means not to use Distribution Strategy; 'tpu' means to
use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy
documentation for more details:
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
num_gpus: How many GPUs to use at each worker with the
DistributionStrategies API. The default is -1, which means utilize all
available GPUs.
tpu: The Cloud TPU to use for training. This should be either the name used
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
"""
# Parameters for train configuration
learning_rate: float
batch_size: int
epochs: int
steps_per_epoch: Optional[int] = None
# Dataset-related parameters
shuffle: bool = False
# Parameters for model / checkpoint files
export_dir: str = tempfile.mkdtemp()
# Parameters for hardware acceleration
distribution_strategy: str = 'off'
num_gpus: int = -1 # default value of -1 means use all available GPUs
tpu: str = ''

View File

@ -26,6 +26,7 @@ from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
import numpy as np
import tensorflow as tf
# resources dependency
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import quantization
@ -33,6 +34,31 @@ DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
ESTIMITED_STEPS_PER_EPOCH = 1000
def load_keras_model(model_path: str,
compile_on_load: bool = False) -> tf.keras.Model:
"""Loads a tensorflow Keras model from file and returns the Keras model.
Args:
model_path: Relative path to a directory containing model data, such as
<parent_path>/saved_model/.
compile_on_load: Whether the model should be compiled while loading. If
False, the model returned has to be compiled with the appropriate loss
function and custom metrics before running for inference on a test
dataset.
Returns:
A tensorflow Keras model.
"""
# Extract the file path before mediapipe/ as the `base_dir`. Joining it with
# the `model_path`, which defines the relative path under mediapipe/, yields
# the absolute path of the model files directory.
cwd = os.path.dirname(__file__)
base_dir = cwd[:cwd.rfind('mediapipe')]
absolute_path = os.path.join(base_dir, model_path)
return tf.keras.models.load_model(
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
batch_size: Optional[int] = None,
train_data: Optional[dataset.Dataset] = None) -> int:
@ -68,7 +94,8 @@ def export_tflite(
tflite_filepath: str,
quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)):
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None):
"""Converts the model to tflite format and saves it.
Args:
@ -76,6 +103,9 @@ def export_tflite(
tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature,
label, and is_training.
"""
if tflite_filepath is None:
raise ValueError(
@ -87,7 +117,8 @@ def export_tflite(
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
if quantization_config:
converter = quantization_config.set_converter_with_quantization(converter)
converter = quantization_config.set_converter_with_quantization(
converter, preprocess=preprocess)
converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert()

View File

@ -15,7 +15,6 @@
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util
@ -25,6 +24,18 @@ from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
def test_load_model(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
model.save(saved_model_path)
loaded_model = model_util.load_keras_model(saved_model_path)
input_tensors = test_util.create_random_sample(size=[1, input_dim])
model_output = model.predict_on_batch(input_tensors)
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
self.assertTrue((model_output == loaded_model_output).all())
@parameterized.named_parameters(
dict(
testcase_name='input_only_steps_per_epoch',
@ -124,9 +135,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
input_dim: int,
max_input_value: int = 1000,
atol: float = 1e-04):
np.random.seed(0)
random_input = np.random.uniform(
low=0, high=max_input_value, size=(1, input_dim)).astype(np.float32)
random_input = test_util.create_random_sample(
size=[1, input_dim], high=max_input_value)
random_input = tf.convert_to_tensor(random_input)
self.assertTrue(
test_util.is_same_output(

View File

@ -46,6 +46,24 @@ def create_dataset(data_size: int,
return dataset
def create_random_sample(size: Union[int, List[int]],
low: float = 0,
high: float = 1) -> np.ndarray:
"""Creates and returns a random sample with floating point values.
Args:
size: Size of the output multi-dimensional array.
low: Lower boundary of the output values.
high: Higher boundary of the output values.
Returns:
1-D array if the size is a scalar. Otherwise, an N-D array whose dimensions
match the input size.
"""
np.random.seed(0)
return np.random.uniform(low=low, high=high, size=size).astype(np.float32)
def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
"""Builds a simple Keras model for test."""
inputs = tf.keras.layers.Input(shape=input_shape)

View File

@ -21,6 +21,9 @@ import "mediapipe/framework/formats/classification.proto";
import "mediapipe/framework/formats/landmark.proto";
import "mediapipe/framework/formats/rect.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "LandmarksDetectionResultProto";
message LandmarksDetectionResult {
optional mediapipe.NormalizedLandmarkList landmarks = 1;
optional mediapipe.ClassificationList classifications = 2;

View File

@ -180,6 +180,15 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return;
}
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
if (status_or_packets.value()[kHandGesturesStreamName].IsEmpty()) {
Packet empty_packet =
status_or_packets.value()[kHandGesturesStreamName];
result_callback(
{{{}, {}, {}, {}}}, image_packet.Get<Image>(),
empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
return;
}
Packet gesture_packet =
status_or_packets.value()[kHandGesturesStreamName];
Packet handedness_packet =
@ -188,7 +197,6 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
status_or_packets.value()[kHandLandmarksStreamName];
Packet hand_world_landmarks_packet =
status_or_packets.value()[kHandWorldLandmarksStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(
{{gesture_packet.Get<std::vector<ClassificationList>>(),
handedness_packet.Get<std::vector<ClassificationList>>(),
@ -218,6 +226,9 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
ASSIGN_OR_RETURN(auto output_packets,
ProcessImageData({{kImageInStreamName,
MakePacket<Image>(std::move(image))}}));
if (output_packets[kHandGesturesStreamName].IsEmpty()) {
return {{{}, {}, {}, {}}};
}
return {
{/* gestures= */ {output_packets[kHandGesturesStreamName]
.Get<std::vector<ClassificationList>>()},
@ -247,6 +258,9 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
if (output_packets[kHandGesturesStreamName].IsEmpty()) {
return {{{}, {}, {}, {}}};
}
return {
{/* gestures= */ {output_packets[kHandGesturesStreamName]
.Get<std::vector<ClassificationList>>()},

View File

@ -87,9 +87,8 @@ struct GestureRecognizerOptions {
// Performs hand gesture recognition on the given image.
//
// TODO add the link to DevSite.
// This API expects expects a pre-trained hand gesture model asset bundle, or a
// custom one created using Model Maker. See <link to the DevSite documentation
// page>.
// This API expects a pre-trained hand gesture model asset bundle, or a custom
// one created using Model Maker. See <link to the DevSite documentation page>.
//
// Inputs:
// Image

View File

@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer";
option java_outer_classname = "GestureClassifierGraphOptionsProto";
message GestureClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional GestureClassifierGraphOptions ext = 478825465;

View File

@ -20,6 +20,9 @@ package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer";
option java_outer_classname = "GestureEmbedderGraphOptionsProto";
message GestureEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional GestureEmbedderGraphOptions ext = 478825422;

View File

@ -23,6 +23,9 @@ import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer";
option java_outer_classname = "HandGestureRecognizerGraphOptionsProto";
message HandGestureRecognizerGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandGestureRecognizerGraphOptions ext = 463370452;

View File

@ -22,6 +22,9 @@ import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.handlandmarker";
option java_outer_classname = "HandLandmarkerGraphOptionsProto";
message HandLandmarkerGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandLandmarkerGraphOptions ext = 462713202;

View File

@ -20,6 +20,9 @@ package mediapipe.tasks.vision.hand_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.handlandmarker";
option java_outer_classname = "HandLandmarksDetectorGraphOptionsProto";
message HandLandmarksDetectorGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandLandmarksDetectorGraphOptions ext = 474472470;

View File

@ -59,14 +59,24 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap;
// Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() {
NormalizedRect norm_rect;
norm_rect.set_x_center(0.5);
norm_rect.set_y_center(0.5);
norm_rect.set_width(1);
norm_rect.set_height(1);
return norm_rect;
// Returns a NormalizedRect covering the full image if input is not present.
// Otherwise, makes sure the x_center, y_center, width and height are set in
// case only a rotation was provided in the input.
NormalizedRect FillNormalizedRect(
std::optional<NormalizedRect> normalized_rect) {
NormalizedRect result;
if (normalized_rect.has_value()) {
result = *normalized_rect;
}
bool has_coordinates = result.has_x_center() || result.has_y_center() ||
result.has_width() || result.has_height();
if (!has_coordinates) {
result.set_x_center(0.5);
result.set_y_center(0.5);
result.set_width(1);
result.set_height(1);
}
return result;
}
// Creates a MediaPipe graph config that contains a subgraph node of
@ -154,15 +164,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
}
absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
Image image, std::optional<NormalizedRect> roi) {
Image image, std::optional<NormalizedRect> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData(
@ -173,15 +182,15 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
}
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) {
Image image, int64 timestamp_ms,
std::optional<NormalizedRect> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
@ -195,16 +204,16 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
.Get<ClassificationResult>();
}
absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
std::optional<NormalizedRect> roi) {
absl::Status ImageClassifier::ClassifyAsync(
Image image, int64 timestamp_ms,
std::optional<NormalizedRect> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))

View File

@ -105,9 +105,18 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
static absl::StatusOr<std::unique_ptr<ImageClassifier>> Create(
std::unique_ptr<ImageClassifierOptions> options);
// Performs image classification on the provided single image. Classification
// is performed on the region of interest specified by the `roi` argument if
// provided, or on the entire image otherwise.
// Performs image classification on the provided single image.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation).
// and/or
// - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageClassifier is created with the image
// running mode.
@ -117,11 +126,21 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// YUVToImageCalculator is integrated.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
std::optional<mediapipe::NormalizedRect> image_processing_options =
std::nullopt);
// Performs image classification on the provided video frame. Classification
// is performed on the region of interested specified by the `roi` argument if
// provided, or on the entire image otherwise.
// Performs image classification on the provided video frame.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation).
// and/or
// - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageClassifier is created with the video
// running mode.
@ -131,12 +150,22 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// must be monotonically increasing.
absl::StatusOr<components::containers::proto::ClassificationResult>
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
std::optional<mediapipe::NormalizedRect>
image_processing_options = std::nullopt);
// Sends live image data to image classification, and the results will be
// available via the "result_callback" provided in the ImageClassifierOptions.
// Classification is performed on the region of interested specified by the
// `roi` argument if provided, or on the entire image otherwise.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation).
// and/or
// - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageClassifier is created with the live
// stream running mode.
@ -153,9 +182,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status ClassifyAsync(
mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect>
image_processing_options = std::nullopt);
// TODO: add Classify() variants taking a region of interest as
// additional argument.
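
A condensed sketch of the new `image_processing_options` parameter in use (the tests in the next file exercise the same flow); the crop coordinates are illustrative:

```cpp
#include <cmath>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"

using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::vision::image_classifier::ImageClassifier;

// Classifies a crop around a region of interest, rotated 90 degrees
// anti-clockwise. The crop is extracted first, then the rotation applied.
absl::StatusOr<ClassificationResult> ClassifyRotatedCrop(
    ImageClassifier& classifier, mediapipe::Image image) {
  mediapipe::NormalizedRect image_processing_options;
  image_processing_options.set_x_center(0.5);
  image_processing_options.set_y_center(0.5);
  image_processing_options.set_width(0.25);
  image_processing_options.set_height(0.25);
  image_processing_options.set_rotation(M_PI / 2.0);
  return classifier.Classify(std::move(image), image_processing_options);
}
```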

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
#include <cmath>
#include <functional>
#include <memory>
#include <string>
@ -546,18 +547,102 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// NormalizedRect around the soccer ball.
NormalizedRect roi;
roi.set_x_center(0.532);
roi.set_y_center(0.521);
roi.set_width(0.164);
roi.set_height(0.427);
// Crop around the soccer ball.
NormalizedRect image_processing_options;
image_processing_options.set_x_center(0.532);
image_processing_options.set_y_center(0.521);
image_processing_options.set_width(0.164);
image_processing_options.set_height(0.427);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image, roi));
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
auto options = std::make_unique<ImageClassifierOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
options->classifier_options.max_results = 3;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Specify a 90° anti-clockwise rotation.
NormalizedRect image_processing_options;
image_processing_options.set_rotation(M_PI / 2.0);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
// Results differ slightly from the non-rotated image, but that's expected
// as models are very sensitive to the slightest numerical differences
// introduced by the rotation and JPG encoding.
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.6371766
category_name: "cheeseburger"
}
categories {
index: 963
score: 0.049443405
category_name: "meat loaf"
}
categories {
index: 925
score: 0.047918003
category_name: "guacamole"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"multi_objects_rotated.jpg")));
auto options = std::make_unique<ImageClassifierOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Crop around the chair, with 90° anti-clockwise rotation.
NormalizedRect image_processing_options;
image_processing_options.set_x_center(0.2821);
image_processing_options.set_y_center(0.2406);
image_processing_options.set_width(0.5642);
image_processing_options.set_height(0.1286);
image_processing_options.set_rotation(M_PI / 2.0);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
ExpectApproximatelyEqual(results,
ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 560
score: 0.6800408
category_name: "folding chair"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}
class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@ -646,16 +731,17 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// NormalizedRect around the soccer ball.
NormalizedRect roi;
roi.set_x_center(0.532);
roi.set_y_center(0.521);
roi.set_width(0.164);
roi.set_height(0.427);
// Crop around the soccer ball.
NormalizedRect image_processing_options;
image_processing_options.set_x_center(0.532);
image_processing_options.set_y_center(0.521);
image_processing_options.set_width(0.164);
image_processing_options.set_height(0.427);
for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results,
image_classifier->ClassifyForVideo(image, i, roi));
MP_ASSERT_OK_AND_ASSIGN(
auto results,
image_classifier->ClassifyForVideo(image, i, image_processing_options));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
}
MP_ASSERT_OK(image_classifier->Close());
@ -790,15 +876,16 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// NormalizedRect around the soccer ball.
NormalizedRect roi;
roi.set_x_center(0.532);
roi.set_y_center(0.521);
roi.set_width(0.164);
roi.set_height(0.427);
// Crop around the soccer ball.
NormalizedRect image_processing_options;
image_processing_options.set_x_center(0.532);
image_processing_options.set_y_center(0.521);
image_processing_options.set_width(0.164);
image_processing_options.set_height(0.427);
for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i, roi));
MP_ASSERT_OK(
image_classifier->ClassifyAsync(image, i, image_processing_options));
}
MP_ASSERT_OK(image_classifier->Close());

View File

@ -34,8 +34,8 @@ android_binary(
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:objectdetector",
"//third_party:androidx_appcompat",
"//third_party:androidx_constraint_layout",
"//third_party:opencv",

View File

@ -63,3 +63,10 @@ android_library(
"@maven//:com_google_guava_guava",
],
)
# Expose the java source files for building mediapipe tasks core AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"],
)

View File

@ -28,3 +28,10 @@ android_library(
"@maven//:com_google_guava_guava",
],
)
# Expose the java source files for building mediapipe tasks core AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"],
)

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.core">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -20,6 +20,7 @@ android_library(
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "AndroidManifest.xml",
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",

View File

@ -0,0 +1,142 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["core/*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
":libmediapipe_tasks_vision_jni_lib",
"//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"@maven//:com_google_guava_guava",
],
)
# The native library of all MediaPipe vision tasks.
cc_binary(
name = "libmediapipe_tasks_vision_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
cc_library(
name = "libmediapipe_tasks_vision_jni_lib",
srcs = [":libmediapipe_tasks_vision_jni.so"],
alwayslink = 1,
)
android_library(
name = "objectdetector",
srcs = [
"objectdetector/ObjectDetectionResult.java",
"objectdetector/ObjectDetector.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "objectdetector/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "imageclassifier",
srcs = [
"imageclassifier/ImageClassificationResult.java",
"imageclassifier/ImageClassifier.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "imageclassifier/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "gesturerecognizer",
srcs = [
"gesturerecognizer/GestureRecognitionResult.java",
"gesturerecognizer/GestureRecognizer.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "gesturerecognizer/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -1,53 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
":libmediapipe_tasks_vision_jni_lib",
"//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"@maven//:com_google_guava_guava",
],
)
# The native library of all MediaPipe vision tasks.
cc_binary(
name = "libmediapipe_tasks_vision_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
cc_library(
name = "libmediapipe_tasks_vision_jni_lib",
srcs = [":libmediapipe_tasks_vision_jni.so"],
alwayslink = 1,
)

View File

@ -1,40 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "gesturerecognizer",
srcs = [
"GestureRecognitionResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = ":AndroidManifest.xml",
deps = [
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,466 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.gesturerecognizer;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.handdetector.HandDetectorGraphOptionsProto;
import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarkerGraphOptionsProto;
import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarksDetectorGraphOptionsProto;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs gesture recognition on images.
*
* <p>This API expects a pre-trained hand gesture model asset bundle, or a custom one created using
* Model Maker. See <TODO link to the DevSite documentation page>.
*
* <ul>
* <li>Input image {@link Image}
* <ul>
* <li>The image that gesture recognition runs on.
* </ul>
* <li>Output GestureRecognitionResult {@link GestureRecognitionResult}
* <ul>
* <li>A GestureRecognitionResult containing hand landmarks and recognized hand gestures.
* </ul>
* </ul>
*/
public final class GestureRecognizer extends BaseVisionTaskApi {
private static final String TAG = GestureRecognizer.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"LANDMARKS:hand_landmarks",
"WORLD_LANDMARKS:world_hand_landmarks",
"HANDEDNESS:handedness",
"HAND_GESTURES:hand_gestures",
"IMAGE:image_out"));
private static final int LANDMARKS_OUT_STREAM_INDEX = 0;
private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1;
private static final int HANDEDNESS_OUT_STREAM_INDEX = 2;
private static final int HAND_GESTURES_OUT_STREAM_INDEX = 3;
private static final int IMAGE_OUT_STREAM_INDEX = 4;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph";
/**
* Creates a {@link GestureRecognizer} instance from a model file and the default {@link
* GestureRecognizerOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the gesture recognition model with metadata in the assets.
* @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation.
*/
public static GestureRecognizer createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates a {@link GestureRecognizer} instance from a model file and the default {@link
* GestureRecognizerOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the gesture recognition model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation.
*/
public static GestureRecognizer createFromFile(Context context, File modelFile)
throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates a {@link GestureRecognizer} instance from a model buffer and the default {@link
* GestureRecognizerOptions}.
*
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the gesture
* recognition model.
* @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation.
*/
public static GestureRecognizer createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates a {@link GestureRecognizer} instance from a {@link GestureRecognizerOptions}.
*
* @param context an Android {@link Context}.
* @param recognizerOptions a {@link GestureRecognizerOptions} instance.
* @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation.
*/
public static GestureRecognizer createFromOptions(
Context context, GestureRecognizerOptions recognizerOptions) {
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<GestureRecognitionResult, Image> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<GestureRecognitionResult, Image>() {
@Override
public GestureRecognitionResult convertToTaskResult(List<Packet> packets) {
// If no hands are detected in the image, just return empty lists.
if (packets.get(HAND_GESTURES_OUT_STREAM_INDEX).isEmpty()) {
return GestureRecognitionResult.create(
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp());
}
return GestureRecognitionResult.create(
PacketGetter.getProtoVector(
packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()),
PacketGetter.getProtoVector(
packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()),
PacketGetter.getProtoVector(
packets.get(HANDEDNESS_OUT_STREAM_INDEX), ClassificationList.parser()),
PacketGetter.getProtoVector(
packets.get(HAND_GESTURES_OUT_STREAM_INDEX), ClassificationList.parser()),
packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp());
}
@Override
public Image convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
recognizerOptions.resultListener().ifPresent(handler::setResultListener);
recognizerOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<GestureRecognizerOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(recognizerOptions)
.setEnableFlowLimiting(recognizerOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
return new GestureRecognizer(runner, recognizerOptions.runningMode());
}
/**
* Constructor to initialize a {@link GestureRecognizer} from a {@link TaskRunner} and a {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private GestureRecognizer(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
}
/**
* Performs gesture recognition on the provided single image. Only use this method when the {@link
* GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc
* for input image format.
*
* <p>{@link GestureRecognizer} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public GestureRecognitionResult recognize(Image inputImage) {
return (GestureRecognitionResult) processImageData(inputImage);
}
/**
* Performs gesture recognition on the provided video frame. Only use this method when the {@link
* GestureRecognizer} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link GestureRecognizer} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) {
return (GestureRecognitionResult) processVideoData(inputImage, inputTimestampMs);
}
/**
* Sends live image data to perform gesture recognition, and the results will be available via the
* {@link ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method
* when the {@link GestureRecognizer} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the gesture recognizer. The input timestamps must be monotonically increasing.
*
* <p>{@link GestureRecognizer} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void recognizeAsync(Image inputImage, long inputTimestampMs) {
sendLiveStreamData(inputImage, inputTimestampMs);
}
/** Options for setting up a {@link GestureRecognizer}. */
@AutoValue
public abstract static class GestureRecognizerOptions extends TaskOptions {
/** Builder for {@link GestureRecognizerOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the gesture recognizer task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the running mode for the gesture recognizer task. Defaults to the image mode. Gesture
* recognizer has three modes:
*
* <ul>
* <li>IMAGE: The mode for recognizing gestures on single image inputs.
* <li>VIDEO: The mode for recognizing gestures on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for for recognizing gestures on a live stream of input data,
* such as from camera. In this mode, {@code setResultListener} must be called to set up
* a listener to receive the recognition results asynchronously.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode value);
// TODO: remove these. Temporary solutions before bundle asset is ready.
public abstract Builder setBaseOptionsHandDetector(BaseOptions value);
public abstract Builder setBaseOptionsHandLandmarker(BaseOptions value);
public abstract Builder setBaseOptionsGestureRecognizer(BaseOptions value);
/** Sets the maximum number of hands that can be detected by the GestureRecognizer. */
public abstract Builder setNumHands(Integer value);
/** Sets the minimum confidence score for hand detection to be considered successful. */
public abstract Builder setMinHandDetectionConfidence(Float value);
/** Sets the minimum confidence score of hand presence in hand landmark detection. */
public abstract Builder setMinHandPresenceConfidence(Float value);
/** Sets the minimum confidence score for hand tracking to be considered successful. */
public abstract Builder setMinTrackingConfidence(Float value);
/**
* Sets the minimum confidence score for gestures to be considered successful. If < 0, the
* model's default gesture confidence threshold of 0.5 is used.
*
* <p>TODO Note this option is subject to change, after scoring merging
* calculator is implemented.
*/
public abstract Builder setMinGestureConfidence(Float value);
/**
* Sets the result listener to receive the detection results asynchronously when the gesture
* recognizer is in the live stream mode.
*/
public abstract Builder setResultListener(
ResultListener<GestureRecognitionResult, Image> value);
/** Sets an optional error listener. */
public abstract Builder setErrorListener(ErrorListener value);
abstract GestureRecognizerOptions autoBuild();
/**
* Validates and builds the {@link GestureRecognizerOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
properly configured. The result listener should only be set when the gesture recognizer
* is in the live stream mode.
*/
public final GestureRecognizerOptions build() {
GestureRecognizerOptions options = autoBuild();
if (options.runningMode() == RunningMode.LIVE_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The gesture recognizer is in the live stream mode, a user-defined result listener"
+ " must be provided in GestureRecognizerOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The gesture recognizer is in the image or the video mode, a user-defined result"
+ " listener shouldn't be provided in GestureRecognizerOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
// TODO: remove these. Temporary solutions before bundle asset is ready.
abstract BaseOptions baseOptionsHandDetector();
abstract BaseOptions baseOptionsHandLandmarker();
abstract BaseOptions baseOptionsGestureRecognizer();
abstract RunningMode runningMode();
abstract Optional<Integer> numHands();
abstract Optional<Float> minHandDetectionConfidence();
abstract Optional<Float> minHandPresenceConfidence();
abstract Optional<Float> minTrackingConfidence();
// TODO update gesture confidence options after score merging calculator is ready.
abstract Optional<Float> minGestureConfidence();
abstract Optional<ResultListener<GestureRecognitionResult, Image>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_GestureRecognizer_GestureRecognizerOptions.Builder()
.setRunningMode(RunningMode.IMAGE)
.setNumHands(1)
.setMinHandDetectionConfidence(0.5f)
.setMinHandPresenceConfidence(0.5f)
.setMinTrackingConfidence(0.5f)
.setMinGestureConfidence(-1f);
}
/**
* Converts a {@link GestureRecognizerOptions} to a {@link CalculatorOptions} protobuf message.
*/
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptions()));
GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.Builder taskOptionsBuilder =
GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
// Setup HandDetectorGraphOptions.
HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder
handDetectorGraphOptionsBuilder =
HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsHandDetector())));
numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands);
minHandDetectionConfidence()
.ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence);
// Setup HandLandmarkerGraphOptions.
HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder
handLandmarksDetectorGraphOptionsBuilder =
HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker())));
minHandPresenceConfidence()
.ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence);
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder
handLandmarkerGraphOptionsBuilder =
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker())));
minTrackingConfidence()
.ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence);
handLandmarkerGraphOptionsBuilder
.setHandDetectorGraphOptions(handDetectorGraphOptionsBuilder.build())
.setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptionsBuilder.build());
// Setup HandGestureRecognizerGraphOptions.
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder
handGestureRecognizerGraphOptionsBuilder =
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsGestureRecognizer())));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold);
handGestureRecognizerGraphOptionsBuilder.setClassifierOptions(
classifierOptionsBuilder.build());
taskOptionsBuilder
.setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build())
.setHandGestureRecognizerGraphOptions(handGestureRecognizerGraphOptionsBuilder.build());
return CalculatorOptions.newBuilder()
.setExtension(
GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
}
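// -------------------------------------------------------------------------
// Illustrative usage sketch: a minimal image-mode run of the API above,
// mirroring the option wiring used in GestureRecognizerTest later in this
// commit. The .tflite asset paths are hypothetical placeholders, and
// `context`/`image` are assumed to be supplied by the caller; imports follow
// the file above.
final class GestureRecognizerUsageSketch {
  static GestureRecognitionResult recognizeOnce(Context context, Image image) {
    GestureRecognizer.GestureRecognizerOptions options =
        GestureRecognizer.GestureRecognizerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("gesture_recognizer.tflite").build())
            // Temporary per-subgraph options (see the TODOs in the Builder above).
            .setBaseOptionsHandDetector(
                BaseOptions.builder().setModelAssetPath("palm_detection.tflite").build())
            .setBaseOptionsHandLandmarker(
                BaseOptions.builder().setModelAssetPath("hand_landmark.tflite").build())
            .setBaseOptionsGestureRecognizer(
                BaseOptions.builder().setModelAssetPath("gesture_recognizer.tflite").build())
            .setRunningMode(RunningMode.IMAGE)
            .setNumHands(2)
            .build();
    GestureRecognizer recognizer = GestureRecognizer.createFromOptions(context, options);
    return recognizer.recognize(image);
  }
}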

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.imageclassifier">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@@ -0,0 +1,102 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imageclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the classification results generated by {@link ImageClassifier}. */
@AutoValue
public abstract class ImageClassificationResult implements TaskResult {
/**
* Creates an {@link ImageClassificationResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
* message.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static ImageClassificationResult create(
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
classificationResult.getClassificationsList()) {
classifications.add(classificationsFromProto(classificationsProto));
}
return new AutoValue_ImageClassificationResult(
timestampMs, Collections.unmodifiableList(classifications));
}
@Override
public abstract long timestampMs();
/** Contains one set of results per classifier head. */
public abstract List<Classifications> classifications();
/**
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
*
* @param category the {@link CategoryProto.Category} protobuf message to convert.
*/
static Category categoryFromProto(CategoryProto.Category category) {
return Category.create(
category.getScore(),
category.getIndex(),
category.getCategoryName(),
category.getDisplayName());
}
/**
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
* ClassificationEntry} object.
*
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
*/
static ClassificationEntry classificationEntryFromProto(
ClassificationsProto.ClassificationEntry entry) {
List<Category> categories = new ArrayList<>();
for (CategoryProto.Category category : entry.getCategoriesList()) {
categories.add(categoryFromProto(category));
}
return ClassificationEntry.create(categories, entry.getTimestampMs());
}
/**
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
* Classifications} object.
*
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
* convert.
*/
static Classifications classificationsFromProto(
ClassificationsProto.Classifications classifications) {
List<ClassificationEntry> entries = new ArrayList<>();
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
entries.add(classificationEntryFromProto(entry));
}
return Classifications.create(
entries, classifications.getHeadIndex(), classifications.getHeadName());
}
}
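// -------------------------------------------------------------------------
// Illustrative usage sketch: reading the top-scoring category of the first
// classifier head from a result. The accessor names entries(), categories(),
// headIndex(), and score() are assumptions inferred from the create(...)
// factory signatures above; categoryName() appears in the tests later in
// this commit.
final class ImageClassificationResultUsageSketch {
  static void printTopCategory(ImageClassificationResult result) {
    Classifications head = result.classifications().get(0);   // first classifier head
    ClassificationEntry entry = head.entries().get(0);        // first entry of that head
    Category top = entry.categories().get(0);                 // top-scoring category
    System.out.printf(
        "head=%d label=%s score=%.3f%n", head.headIndex(), top.categoryName(), top.score());
  }
}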

View File

@@ -0,0 +1,456 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imageclassifier;
import android.content.Context;
import android.graphics.RectF;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs classification on images.
*
* <p>The API expects a TFLite model with optional, but strongly recommended, <a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
*
* <p>The API supports models with one image input tensor and one or more output tensors. To be more
* specific, here are the requirements.
*
* <ul>
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
* <ul>
* <li>image input of size {@code [batch x height x width x channels]}.
* <li>batch inference is not supported ({@code batch} is required to be 1).
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
* <li>if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the
* metadata for input normalization.
* </ul>
* <li>At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with:
* <ul>
* <li>{@code N} classes and either 2 or 4 dimensions, i.e. {@code [1 x N]} or {@code [1 x 1
* x 1 x N]}
* <li>optional (but recommended) label map(s) as AssociatedFile-s with type
* TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
* any) is used to fill the {@code class_name} field of the results. The {@code
* display_name} field is filled from the AssociatedFile (if any) whose locale matches
* the {@code display_names_locale} field of the {@code ImageClassifierOptions} used at
* creation time ("en" by default, i.e. English). If none of these are available, only
* the {@code index} field of the results will be filled.
* <li>optional score calibration can be attached using ScoreCalibrationOptions and an
* AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See <a
* href="https://github.com/google/mediapipe/blob/master/mediapipe/tasks/metadata/metadata_schema.fbs">
* metadata_schema.fbs</a> for more details.
* </ul>
* </ul>
*
* <p>An example of such model can be found <a
* href="https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1">
* TensorFlow Hub</a>.
*/
public final class ImageClassifier extends BaseVisionTaskApi {
private static final String TAG = ImageClassifier.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out"));
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
static {
ProtoUtil.registerTypeName(
ClassificationsProto.ClassificationResult.class,
"mediapipe.tasks.components.containers.proto.ClassificationResult");
}
/**
* Creates an {@link ImageClassifier} instance from a model file and default {@link
* ImageClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the classification model in the assets.
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
*/
public static ImageClassifier createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link ImageClassifier} instance from a model file and default {@link
* ImageClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the classification model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
*/
public static ImageClassifier createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates an {@link ImageClassifier} instance from a model buffer and default {@link
* ImageClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
* classification model.
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
*/
public static ImageClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link ImageClassifier} instance from an {@link ImageClassifierOptions} instance.
*
* @param context an Android {@link Context}.
* @param options an {@link ImageClassifierOptions} instance.
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
*/
public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
OutputHandler<ImageClassificationResult, Image> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ImageClassificationResult, Image>() {
@Override
public ImageClassificationResult convertToTaskResult(List<Packet> packets) {
try {
return ImageClassificationResult.create(
PacketGetter.getProto(
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.getDefaultInstance()),
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
} catch (InvalidProtocolBufferException e) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
}
}
@Override
public Image convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
options.resultListener().ifPresent(handler::setResultListener);
options.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<ImageClassifierOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(options)
.setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
return new ImageClassifier(runner, options.runningMode());
}
/**
* Constructor to initialize an {@link ImageClassifier} from a {@link TaskRunner} and a {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private ImageClassifier(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
}
/**
* Performs classification on the provided single image. Only use this method when the {@link
* ImageClassifier} is created with {@link RunningMode.IMAGE}.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classify(Image inputImage) {
return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF());
}
/**
* Performs classification on the provided single image and region-of-interest. Only use this
* method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} specifying the region of interest on which to perform
* classification. Coordinates are expected to be specified as normalized values in [0,1].
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classify(Image inputImage, RectF roi) {
return (ImageClassificationResult) processImageData(inputImage, roi);
}
/**
* Performs classification on the provided video frame. Only use this method when the {@link
* ImageClassifier} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classifyForVideo(Image inputImage, long inputTimestampMs) {
return (ImageClassificationResult)
processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
}
/**
* Performs classification on the provided video frame with additional region-of-interest. Only
* use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} specifying the region of interest on which to perform
* classification. Coordinates are expected to be specified as normalized values in [0,1].
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classifyForVideo(
Image inputImage, RectF roi, long inputTimestampMs) {
return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs);
}
/**
* Sends live image data to perform classification, and the results will be available via the
* {@link ResultListener} provided in the {@link ImageClassifierOptions}. Only use this method
* when the {@link ImageClassifier} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the image classifier. The input timestamps must be monotonically increasing.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void classifyAsync(Image inputImage, long inputTimestampMs) {
sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
}
/**
* Sends live image data and additional region-of-interest to perform classification, and the
* results will be available via the {@link ResultListener} provided in the {@link
* ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
* {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the image classifier. The input timestamps must be monotonically increasing.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} specifying the region of interest on which to perform
* classification. Coordinates are expected to be specified as normalized values in [0,1].
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void classifyAsync(Image inputImage, RectF roi, long inputTimestampMs) {
sendLiveStreamData(inputImage, roi, inputTimestampMs);
}
/** Options for setting up an {@link ImageClassifier}. */
@AutoValue
public abstract static class ImageClassifierOptions extends TaskOptions {
/** Builder for {@link ImageClassifierOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the {@link BaseOptions} for the image classifier task. */
public abstract Builder setBaseOptions(BaseOptions baseOptions);
/**
* Sets the {@link RunningMode} for the image classifier task. Defaults to the image mode.
* Image classifier has three modes:
*
* <ul>
* <li>IMAGE: The mode for performing classification on single image inputs.
* <li>VIDEO: The mode for performing classification on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for for performing classification on a live stream of input
* data, such as from camera. In this mode, {@code setResultListener} must be called to
* set up a listener to receive the classification results asynchronously.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode runningMode);
/**
* Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
* score threshold, number of results, etc.
*/
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
/**
* Sets the {@link ResultListener} to receive the classification results asynchronously when
* the image classifier is in the live stream mode.
*/
public abstract Builder setResultListener(
ResultListener<ImageClassificationResult, Image> resultListener);
/** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener errorListener);
abstract ImageClassifierOptions autoBuild();
/**
* Validates and builds the {@link ImageClassifierOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the image classifier
* is in the live stream mode.
*/
public final ImageClassifierOptions build() {
ImageClassifierOptions options = autoBuild();
if (options.runningMode() == RunningMode.LIVE_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The image classifier is in the live stream mode, a user-defined result listener"
+ " must be provided in the ImageClassifierOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The image classifier is in the image or video mode, a user-defined result listener"
+ " shouldn't be provided in ImageClassifierOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<ResultListener<ImageClassificationResult, Image>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
.setRunningMode(RunningMode.IMAGE);
}
/**
* Converts a {@link ImageClassifierOptions} to a {@link CalculatorOptions} protobuf message.
*/
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (classifierOptions().isPresent()) {
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder()
.setExtension(
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
/** Creates a RectF covering the full image. */
private static RectF buildFullImageRectF() {
return new RectF(0, 0, 1, 1);
}
}
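// -------------------------------------------------------------------------
// Illustrative usage sketch: typical image-mode use of the API above, with
// and without a region of interest. "classifier.tflite" is a hypothetical
// asset path; `context` and `image` are assumed to be supplied by the caller.
final class ImageClassifierUsageSketch {
  static void classifyWholeAndQuadrant(Context context, Image image) {
    ImageClassifier classifier = ImageClassifier.createFromFile(context, "classifier.tflite");
    // Classify the full image; internally this uses a RectF covering [0,1]x[0,1].
    ImageClassificationResult full = classifier.classify(image);
    // Classify only the top-left quadrant via normalized ROI coordinates.
    ImageClassificationResult quadrant =
        classifier.classify(image, new RectF(0f, 0f, 0.5f, 0.5f));
  }
}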

View File

@@ -1,44 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "objectdetector",
srcs = [
"ObjectDetectionResult.java",
"ObjectDetector.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = ":AndroidManifest.xml",
deps = [
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.gesturerecognizertest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="gesturerecognizertest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.vision.gesturerecognizertest" />
</manifest>

View File

@@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@@ -0,0 +1,495 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.gesturerecognizer;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.truth.Correspondence;
import com.google.mediapipe.formats.proto.ClassificationProto;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Landmark;
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions;
import java.io.InputStream;
import java.util.Arrays;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link GestureRecognizer}. */
@RunWith(Suite.class)
@SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class})
public class GestureRecognizerTest {
private static final String HAND_DETECTOR_MODEL_FILE = "palm_detection_full.tflite";
private static final String HAND_LANDMARKER_MODEL_FILE = "hand_landmark_full.tflite";
private static final String GESTURE_RECOGNIZER_MODEL_FILE =
"cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite";
private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
private static final String TAG = "Gesture Recognizer Test";
private static final String THUMB_UP_LABEL = "Thumb_Up";
private static final int THUMB_UP_INDEX = 5;
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
private static final int IMAGE_WIDTH = 382;
private static final int IMAGE_HEIGHT = 406;
@RunWith(AndroidJUnit4.class)
public static final class General extends GestureRecognizerTest {
@Test
public void recognize_successWithValidModels() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void recognize_successWithEmptyResult() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(NO_HANDS_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
assertThat(actualResult.gestures()).isEmpty();
}
@Test
public void recognize_successWithMinGestureConfidence() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
// TODO update the confidence to be in range [0,1] after embedding model
// and scoring calculator is integrated.
.setMinGestureConfidence(3.0f)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
// Only contains one top scoring gesture.
assertThat(actualResult.gestures().get(0)).hasSize(1);
assertActualGestureEqualExpectedGesture(
actualResult.gestures().get(0).get(0), expectedResult.gestures().get(0).get(0));
}
@Test
public void recognize_successWithNumHands() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setNumHands(2)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE));
assertThat(actualResult.handednesses()).hasSize(2);
}
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends GestureRecognizerTest {
@Test
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE)
.build())
.setBaseOptionsHandDetector(
BaseOptions.builder()
.setModelAssetPath(HAND_DETECTOR_MODEL_FILE)
.build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_MODEL_FILE)
.build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE)
.build())
.setRunningMode(mode)
.setResultListener((gestureRecognitionResult, inputImage) -> {})
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener shouldn't be provided");
}
}
}
@Test
public void create_failsWithMissingResultListenerInLiveStreamMode() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE)
.build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE)
.build())
.setRunningMode(RunningMode.LIVE_STREAM)
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener must be provided");
}
@Test
public void recognize_failsWithCallingWrongApiInImageMode() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void recognize_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((gestureRecognitionResult, inputImage) -> {})
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@Test
public void recognize_successWithImageMode() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void recognize_successWithVideoMode() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
for (int i = 0; i < 3; i++) {
GestureRecognitionResult actualResult =
gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
}
@Test
public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception {
Image image = getImageFromAsset(THUMB_UP_IMAGE);
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(actualResult, inputImage) -> {
assertActualResultApproximatelyEqualsToExpectedResult(
actualResult, expectedResult);
assertImageSizeIsExpected(inputImage);
})
.build();
try (GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
gestureRecognizer.recognizeAsync(image, 1);
MediaPipeException exception =
assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
}
}
@Test
public void recognize_successWithLiveStreamMode() throws Exception {
Image image = getImageFromAsset(THUMB_UP_IMAGE);
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(actualResult, inputImage) -> {
assertActualResultApproximatelyEqualsToExpectedResult(
actualResult, expectedResult);
assertImageSizeIsExpected(inputImage);
})
.build();
try (GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
gestureRecognizer.recognizeAsync(image, i);
}
}
}
private static Image getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
private static GestureRecognitionResult getExpectedGestureRecognitionResult(
String filePath, String gestureLabel, int gestureIndex) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
LandmarksDetectionResult landmarksDetectionResultProto =
LandmarksDetectionResult.parser().parseFrom(istr);
ClassificationProto.ClassificationList gesturesProto =
ClassificationProto.ClassificationList.newBuilder()
.addClassification(
ClassificationProto.Classification.newBuilder()
.setLabel(gestureLabel)
.setIndex(gestureIndex))
.build();
return GestureRecognitionResult.create(
Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()),
Arrays.asList(landmarksDetectionResultProto.getClassifications()),
Arrays.asList(gesturesProto),
/*timestampMs=*/ 0);
}
private static void assertActualResultApproximatelyEqualsToExpectedResult(
GestureRecognitionResult actualResult, GestureRecognitionResult expectedResult) {
// Expects to have the same number of hands detected.
assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size());
assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size());
assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size());
assertThat(actualResult.gestures()).hasSize(expectedResult.gestures().size());
// Actual landmarks match expected landmarks.
assertThat(actualResult.landmarks().get(0))
.comparingElementsUsing(
Correspondence.from(
(Correspondence.BinaryPredicate<Landmark, Landmark>)
(actual, expected) -> {
return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
.compare(actual.x(), expected.x())
&& Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
.compare(actual.y(), expected.y());
},
"landmarks approximately equal to"))
.containsExactlyElementsIn(expectedResult.landmarks().get(0));
// Actual handedness matches expected handedness.
Category actualTopHandedness = actualResult.handednesses().get(0).get(0);
Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0);
assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index());
assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName());
// Actual gesture with top score matches expected gesture.
Category actualTopGesture = actualResult.gestures().get(0).get(0);
Category expectedTopGesture = expectedResult.gestures().get(0).get(0);
assertActualGestureEqualExpectedGesture(actualTopGesture, expectedTopGesture);
}
private static void assertActualGestureEqualExpectedGesture(
Category actualGesture, Category expectedGesture) {
assertThat(actualGesture.index()).isEqualTo(expectedGesture.index());
assertThat(actualGesture.categoryName()).isEqualTo(expectedGesture.categoryName());
}
private static void assertImageSizeIsExpected(Image inputImage) {
assertThat(inputImage).isNotNull();
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
}
}

View File

@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.imageclassifiertest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="imageclassifiertest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.vision.imageclassifiertest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,445 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imageclassifier;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link ImageClassifier}. */
@RunWith(Suite.class)
@SuiteClasses({ImageClassifierTest.General.class, ImageClassifierTest.RunningModeTest.class})
public class ImageClassifierTest {
private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite";
private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite";
private static final String BURGER_IMAGE = "burger.jpg";
private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg";
@RunWith(AndroidJUnit4.class)
public static final class General extends ImageClassifierTest {
@Test
public void create_failsWithMissingModel() throws Exception {
String nonExistentFile = "/path/to/non/existent/file";
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
ImageClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), nonExistentFile));
assertThat(exception).hasMessageThat().contains(nonExistentFile);
}
@Test
public void create_failsWithInvalidModelBuffer() throws Exception {
// Create a non-direct model ByteBuffer.
ByteBuffer modelBuffer =
TestUtils.loadToNonDirectByteBuffer(
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageClassifier.createFromBuffer(
ApplicationProvider.getApplicationContext(), modelBuffer));
assertThat(exception)
.hasMessageThat()
.contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
}
@Test
public void classify_succeedsWithNoOptions() throws Exception {
ImageClassifier imageClassifier =
ImageClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0);
assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001);
assertThat(results.classifications().get(0).entries().get(0).categories().get(0))
.isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
}
@Test
public void classify_succeedsWithFloatModel() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results,
Arrays.asList(
Category.create(0.7952058f, 934, "cheeseburger", ""),
Category.create(0.027329788f, 932, "bagel", ""),
Category.create(0.019334773f, 925, "guacamole", "")));
}
@Test
public void classify_succeedsWithQuantizedModel() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
}
@Test
public void classify_succeedsWithScoreThreshold() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results,
Arrays.asList(
Category.create(0.7952058f, 934, "cheeseburger", ""),
Category.create(0.027329788f, 932, "bagel", "")));
}
@Test
public void classify_succeedsWithAllowlist() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(
ClassifierOptions.builder()
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
.build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results,
Arrays.asList(
Category.create(0.7952058f, 934, "cheeseburger", ""),
Category.create(0.019334773f, 925, "guacamole", ""),
Category.create(0.006279315f, 963, "meat loaf", "")));
}
@Test
public void classify_succeedsWithDenylist() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(
ClassifierOptions.builder()
.setMaxResults(3)
.setCategoryDenylist(Arrays.asList("bagel"))
.build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results,
Arrays.asList(
Category.create(0.7952058f, 934, "cheeseburger", ""),
Category.create(0.019334773f, 925, "guacamole", ""),
Category.create(0.006279315f, 963, "meat loaf", "")));
}
@Test
public void classify_succeedsWithRegionOfInterest() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
// RectF around the soccer ball.
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
ImageClassificationResult results =
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi);
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
}
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends ImageClassifierTest {
@Test
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setRunningMode(mode)
.setResultListener((imageClassificationResult, inputImage) -> {})
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener shouldn't be provided");
}
}
@Test
    public void create_failsWithMissingResultListenerInLiveStreamMode() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener must be provided");
}
@Test
public void classify_failsWithCallingWrongApiInImageMode() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void classify_failsWithCallingWrongApiInVideoMode() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
    public void classify_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((imageClassificationResult, inputImage) -> {})
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@Test
public void classify_succeedsWithImageMode() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
}
@Test
public void classify_succeedsWithVideoMode() throws Exception {
Image image = getImageFromAsset(BURGER_IMAGE);
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setRunningMode(RunningMode.VIDEO)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
ImageClassificationResult results = imageClassifier.classifyForVideo(image, i);
assertHasOneHeadAndOneTimestamp(results, i);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
}
}
@Test
public void classify_failsWithOutOfOrderInputTimestamps() throws Exception {
Image image = getImageFromAsset(BURGER_IMAGE);
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(imageClassificationResult, inputImage) -> {
assertCategoriesAre(
imageClassificationResult,
Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
assertImageSizeIsExpected(inputImage);
})
.build();
try (ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1);
MediaPipeException exception =
assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
}
}
@Test
public void classify_succeedsWithLiveStreamMode() throws Exception {
Image image = getImageFromAsset(BURGER_IMAGE);
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(imageClassificationResult, inputImage) -> {
assertCategoriesAre(
imageClassificationResult,
Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
assertImageSizeIsExpected(inputImage);
})
.build();
try (ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; ++i) {
imageClassifier.classifyAsync(image, i);
}
}
}
}
private static Image getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
private static void assertHasOneHeadAndOneTimestamp(
ImageClassificationResult results, long timestampMs) {
assertThat(results.classifications()).hasSize(1);
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
assertThat(results.classifications().get(0).entries()).hasSize(1);
assertThat(results.classifications().get(0).entries().get(0).timestampMs())
.isEqualTo(timestampMs);
}
private static void assertCategoriesAre(
ImageClassificationResult results, List<Category> categories) {
assertThat(results.classifications().get(0).entries().get(0).categories())
.hasSize(categories.size());
for (int i = 0; i < categories.size(); i++) {
assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
.isEqualTo(categories.get(i));
}
}
private static void assertImageSizeIsExpected(Image inputImage) {
assertThat(inputImage).isNotNull();
assertThat(inputImage.getWidth()).isEqualTo(480);
assertThat(inputImage.getHeight()).isEqualTo(325);
}
}

View File

@ -0,0 +1,21 @@
# Placeholder for internal Python strict library compatibility macro.
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"],
)
py_library(
name = "metadata_info",
srcs = [
"metadata_info.py",
],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/metadata:schema_py",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,446 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for common model metadata information."""
import csv
import os
from typing import List, Optional, Type
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
# Min and max values for UINT8 tensors.
_MIN_UINT8 = 0
_MAX_UINT8 = 255
# Default description for vocabulary files.
_VOCAB_FILE_DESCRIPTION = ("Vocabulary file to convert natural language "
"words to embedding vectors.")
class GeneralMd:
"""A container for common metadata information of a model.
Attributes:
name: name of the model.
version: version of the model.
description: description of what the model does.
author: author of the model.
licenses: licenses of the model.
"""
def __init__(self,
name: Optional[str] = None,
version: Optional[str] = None,
description: Optional[str] = None,
author: Optional[str] = None,
licenses: Optional[str] = None) -> None:
self.name = name
self.version = version
self.description = description
self.author = author
self.licenses = licenses
def create_metadata(self) -> _metadata_fb.ModelMetadataT:
"""Creates the model metadata based on the general model information.
Returns:
A Flatbuffers Python object of the model metadata.
"""
model_metadata = _metadata_fb.ModelMetadataT()
model_metadata.name = self.name
model_metadata.version = self.version
model_metadata.description = self.description
model_metadata.author = self.author
model_metadata.license = self.licenses
return model_metadata
class AssociatedFileMd:
"""A container for common associated file metadata information.
Attributes:
file_path: path to the associated file.
description: description of the associated file.
file_type: file type of the associated file [1].
locale: locale of the associated file [2].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L77
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L176
"""
def __init__(
self,
file_path: str,
description: Optional[str] = None,
file_type: Optional[int] = _metadata_fb.AssociatedFileType.UNKNOWN,
locale: Optional[str] = None) -> None:
self.file_path = file_path
self.description = description
self.file_type = file_type
self.locale = locale
def create_metadata(self) -> _metadata_fb.AssociatedFileT:
"""Creates the associated file metadata.
Returns:
A Flatbuffers Python object of the associated file metadata.
"""
file_metadata = _metadata_fb.AssociatedFileT()
file_metadata.name = os.path.basename(self.file_path)
file_metadata.description = self.description
file_metadata.type = self.file_type
file_metadata.locale = self.locale
return file_metadata
class LabelFileMd(AssociatedFileMd):
"""A container for label file metadata information."""
_LABEL_FILE_DESCRIPTION = ("Labels for categories that the model can "
"recognize.")
_FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
def __init__(self, file_path: str, locale: Optional[str] = None) -> None:
"""Creates a LabelFileMd object.
Args:
file_path: file_path of the label file.
locale: locale of the label file [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L176
"""
super().__init__(file_path, self._LABEL_FILE_DESCRIPTION, self._FILE_TYPE,
locale)
class ScoreCalibrationMd:
"""A container for score calibration [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
"""
  _SCORE_CALIBRATION_FILE_DESCRIPTION = (
      "Contains sigmoid-based score calibration parameters. The main purpose "
      "of score calibration is to make scores across classes comparable, so "
      "that a common threshold can be used for all output classes.")
_FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_SCORE_CALIBRATION
def __init__(self,
score_transformation_type: _metadata_fb.ScoreTransformationType,
default_score: float, file_path: str) -> None:
"""Creates a ScoreCalibrationMd object.
Args:
score_transformation_type: type of the function used for transforming the
uncalibrated score before applying score calibration.
default_score: the default calibrated score to apply if the uncalibrated
score is below min_score or if no parameters were specified for a given
index.
file_path: file_path of the score calibration file [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L133
Raises:
ValueError: if the score_calibration file is malformed.
"""
self._score_transformation_type = score_transformation_type
self._default_score = default_score
self._file_path = file_path
# Sanity check the score calibration file.
with open(self._file_path) as calibration_file:
csv_reader = csv.reader(calibration_file, delimiter=",")
for row in csv_reader:
if row and len(row) != 3 and len(row) != 4:
raise ValueError(
f"Expected empty lines or 3 or 4 parameters per line in score"
f" calibration file, but got {len(row)}.")
if row and float(row[0]) < 0:
raise ValueError(
f"Expected scale to be a non-negative value, but got "
f"{float(row[0])}.")
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the score calibration metadata based on the information.
Returns:
A Flatbuffers Python object of the score calibration metadata.
"""
score_calibration = _metadata_fb.ProcessUnitT()
score_calibration.optionsType = (
_metadata_fb.ProcessUnitOptions.ScoreCalibrationOptions)
options = _metadata_fb.ScoreCalibrationOptionsT()
options.scoreTransformation = self._score_transformation_type
options.defaultScore = self._default_score
score_calibration.options = options
return score_calibration
def create_score_calibration_file_md(self) -> AssociatedFileMd:
return AssociatedFileMd(self._file_path,
self._SCORE_CALIBRATION_FILE_DESCRIPTION,
self._FILE_TYPE)
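# For reference, an illustrative excerpt of a well-formed calibration file
# (values below are made up, mirroring score_calibration.txt in the test
# data): each non-empty line holds 3 or 4 comma-separated sigmoid
# parameters, starting with a non-negative scale.
#
#   0.98,0.37,0.53,0.71
#   0.96,1.06,0.28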
class TensorMd:
"""A container for common tensor metadata information.
Attributes:
name: name of the tensor.
description: description of what the tensor is.
min_values: per-channel minimum value of the tensor.
max_values: per-channel maximum value of the tensor.
content_type: content_type of the tensor.
associated_files: information of the associated files in the tensor.
tensor_name: name of the corresponding tensor [1] in the TFLite model. It is
used to locate the corresponding tensor and decide the order of the tensor
metadata [2] when populating model metadata.
[1]:
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
"""
def __init__(
self,
name: Optional[str] = None,
description: Optional[str] = None,
min_values: Optional[List[float]] = None,
max_values: Optional[List[float]] = None,
content_type: int = _metadata_fb.ContentProperties.FeatureProperties,
associated_files: Optional[List[Type[AssociatedFileMd]]] = None,
tensor_name: Optional[str] = None) -> None:
self.name = name
self.description = description
self.min_values = min_values
self.max_values = max_values
self.content_type = content_type
self.associated_files = associated_files
self.tensor_name = tensor_name
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
"""Creates the input tensor metadata based on the information.
Returns:
A Flatbuffers Python object of the input metadata.
"""
tensor_metadata = _metadata_fb.TensorMetadataT()
tensor_metadata.name = self.name
tensor_metadata.description = self.description
# Create min and max values
stats = _metadata_fb.StatsT()
stats.max = self.max_values
stats.min = self.min_values
tensor_metadata.stats = stats
# Create content properties
content = _metadata_fb.ContentT()
if self.content_type is _metadata_fb.ContentProperties.FeatureProperties:
content.contentProperties = _metadata_fb.FeaturePropertiesT()
elif self.content_type is _metadata_fb.ContentProperties.ImageProperties:
content.contentProperties = _metadata_fb.ImagePropertiesT()
elif self.content_type is (
_metadata_fb.ContentProperties.BoundingBoxProperties):
content.contentProperties = _metadata_fb.BoundingBoxPropertiesT()
elif self.content_type is _metadata_fb.ContentProperties.AudioProperties:
content.contentProperties = _metadata_fb.AudioPropertiesT()
content.contentPropertiesType = self.content_type
tensor_metadata.content = content
# TODO: check if multiple label files have populated locale.
# Create associated files
if self.associated_files:
tensor_metadata.associatedFiles = [
file.create_metadata() for file in self.associated_files
]
return tensor_metadata
class InputImageTensorMd(TensorMd):
"""A container for input image tensor metadata information.
Attributes:
norm_mean: the mean value used in tensor normalization [1].
norm_std: the std value used in the tensor normalization [1]. norm_mean and
norm_std must have the same dimension.
color_space_type: the color space type of the input image [2].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L198
"""
# Min and max float values for image pixels.
_MIN_PIXEL = 0.0
_MAX_PIXEL = 255.0
def __init__(
self,
name: Optional[str] = None,
description: Optional[str] = None,
norm_mean: Optional[List[float]] = None,
norm_std: Optional[List[float]] = None,
color_space_type: Optional[int] = _metadata_fb.ColorSpaceType.UNKNOWN,
tensor_type: Optional["_schema_fb.TensorType"] = None) -> None:
"""Initializes the instance of InputImageTensorMd.
Args:
name: name of the tensor.
description: description of what the tensor is.
norm_mean: the mean value used in tensor normalization [1].
norm_std: the std value used in the tensor normalization [1]. norm_mean
and norm_std must have the same dimension.
color_space_type: the color space type of the input image [2].
tensor_type: data type of the tensor.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L198
Raises:
ValueError: if norm_mean and norm_std have different dimensions.
"""
if norm_std and norm_mean and len(norm_std) != len(norm_mean):
raise ValueError(
f"norm_mean and norm_std are expected to be the same dim. But got "
f"{len(norm_mean)} and {len(norm_std)}")
if tensor_type is _schema_fb.TensorType.UINT8:
min_values = [_MIN_UINT8]
max_values = [_MAX_UINT8]
elif tensor_type is _schema_fb.TensorType.FLOAT32 and norm_std and norm_mean:
min_values = [
float(self._MIN_PIXEL - mean) / std
for mean, std in zip(norm_mean, norm_std)
]
max_values = [
float(self._MAX_PIXEL - mean) / std
for mean, std in zip(norm_mean, norm_std)
]
else:
      # UINT8 and FLOAT32 are currently the only tensor types supported by
      # the Task library.
min_values = None
max_values = None
super().__init__(name, description, min_values, max_values,
_metadata_fb.ContentProperties.ImageProperties)
self.norm_mean = norm_mean
self.norm_std = norm_std
self.color_space_type = color_space_type
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
"""Creates the input image metadata based on the information.
Returns:
A Flatbuffers Python object of the input image metadata.
"""
tensor_metadata = super().create_metadata()
tensor_metadata.content.contentProperties.colorSpace = self.color_space_type
# Create normalization parameters
if self.norm_mean and self.norm_std:
normalization = _metadata_fb.ProcessUnitT()
normalization.optionsType = (
_metadata_fb.ProcessUnitOptions.NormalizationOptions)
normalization.options = _metadata_fb.NormalizationOptionsT()
normalization.options.mean = self.norm_mean
normalization.options.std = self.norm_std
tensor_metadata.processUnits = [normalization]
return tensor_metadata
class ClassificationTensorMd(TensorMd):
"""A container for the classification tensor metadata information.
Attributes:
label_files: information of the label files [1] in the classification
tensor.
score_calibration_md: information of the score calibration operation [2] in
the classification tensor.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
"""
# Min and max float values for classification results.
_MIN_FLOAT = 0.0
_MAX_FLOAT = 1.0
def __init__(self,
name: Optional[str] = None,
description: Optional[str] = None,
label_files: Optional[List[LabelFileMd]] = None,
tensor_type: Optional[int] = None,
score_calibration_md: Optional[ScoreCalibrationMd] = None,
tensor_name: Optional[str] = None) -> None:
"""Initializes the instance of ClassificationTensorMd.
Args:
name: name of the tensor.
description: description of what the tensor is.
label_files: information of the label files [1] in the classification
tensor.
tensor_type: data type of the tensor.
      score_calibration_md: information of the score calibration operation [2]
        in the classification tensor.
tensor_name: name of the corresponding tensor [3] in the TFLite model. It
is used to locate the corresponding classification tensor and decide the
order of the tensor metadata [4] when populating model metadata.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
[3]:
https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L623-L640
"""
self.score_calibration_md = score_calibration_md
if tensor_type is _schema_fb.TensorType.UINT8:
min_values = [_MIN_UINT8]
max_values = [_MAX_UINT8]
elif tensor_type is _schema_fb.TensorType.FLOAT32:
min_values = [self._MIN_FLOAT]
max_values = [self._MAX_FLOAT]
else:
      # UINT8 and FLOAT32 are currently the only tensor types supported by
      # the Task library.
min_values = None
max_values = None
associated_files = label_files or []
if self.score_calibration_md:
associated_files.append(
score_calibration_md.create_score_calibration_file_md())
super().__init__(name, description, min_values, max_values,
_metadata_fb.ContentProperties.FeatureProperties,
associated_files, tensor_name)
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
"""Creates the classification tensor metadata based on the information."""
tensor_metadata = super().create_metadata()
if self.score_calibration_md:
tensor_metadata.processUnits = [
self.score_calibration_md.create_metadata()
]
return tensor_metadata
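Taken together, a minimal sketch of how these containers might be combined for an image classifier. The label file name, normalization values, and descriptions below are illustrative assumptions, not artifacts shipped with this change:

```python
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

# General model information.
general_md = metadata_info.GeneralMd(
    name="image_classifier",
    version="v1",
    description="Classifies images into known categories.",
    author="MediaPipe",
    licenses="Apache")

# Float RGB input, normalized from [0, 255] to [-1, 1].
input_md = metadata_info.InputImageTensorMd(
    name="image",
    description="Input image to be classified.",
    norm_mean=[127.5, 127.5, 127.5],
    norm_std=[127.5, 127.5, 127.5],
    color_space_type=_metadata_fb.ColorSpaceType.RGB,
    tensor_type=_schema_fb.TensorType.FLOAT32)

# Classification output with an English label file.
output_md = metadata_info.ClassificationTensorMd(
    name="probability",
    description="Score of each category.",
    label_files=[metadata_info.LabelFileMd("labels.txt", locale="en")],
    tensor_type=_schema_fb.TensorType.FLOAT32)

# Each container builds its own Flatbuffers object; a metadata writer packs
# them into subgraph/model metadata, as the tests below demonstrate.
model_meta = general_md.create_metadata()
input_meta = input_md.create_metadata()
output_meta = output_md.create_metadata()
```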

View File

@ -0,0 +1,26 @@
# Placeholder for internal Python strict test compatibility macro.
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
py_test(
name = "metadata_info_test",
srcs = ["metadata_info_test.py"],
data = [
"//mediapipe/tasks/testdata/metadata:data_files",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/metadata:schema_py",
"//mediapipe/tasks/python/metadata",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
"//mediapipe/tasks/python/test:test_utils",
"@flatbuffers//:runtime_py",
],
)

View File

@ -0,0 +1,343 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata info classes."""
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata import metadata as _metadata
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
from mediapipe.tasks.python.test import test_utils
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt")
class GeneralMdTest(absltest.TestCase):
_EXPECTED_GENERAL_META_JSON = test_utils.get_test_data_path(
"general_meta.json")
def test_create_metadata_should_succeed(self):
general_md = metadata_info.GeneralMd(
name="model",
version="v1",
description="A ML model.",
author="MediaPipe",
licenses="Apache")
general_metadata = general_md.create_metadata()
# Create the Flatbuffers object and convert it to the json format.
builder = flatbuffers.Builder(0)
builder.Finish(
general_metadata.Pack(builder),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_json = _metadata.convert_to_json(bytes(builder.Output()))
with open(self._EXPECTED_GENERAL_META_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class AssociatedFileMdTest(absltest.TestCase):
_EXPECTED_META_JSON = test_utils.get_test_data_path(
"associated_file_meta.json")
def test_create_metadata_should_succeed(self):
file_md = metadata_info.AssociatedFileMd(
file_path="label.txt",
description="The label file.",
file_type=_metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS,
locale="en")
file_metadata = file_md.create_metadata()
# Create the Flatbuffers object and convert it to the json format.
model_metadata = _metadata_fb.ModelMetadataT()
model_metadata.associatedFiles = [file_metadata]
builder = flatbuffers.Builder(0)
builder.Finish(
model_metadata.Pack(builder),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_json = _metadata.convert_to_json(bytes(builder.Output()))
with open(self._EXPECTED_META_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class TensorMdTest(parameterized.TestCase):
_TENSOR_NAME = "input"
_TENSOR_DESCRIPTION = "The input tensor."
_TENSOR_MIN = 0
_TENSOR_MAX = 1
_LABEL_FILE_EN = "labels.txt"
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
_EXPECTED_FEATURE_TENSOR_JSON = test_utils.get_test_data_path(
"feature_tensor_meta.json")
_EXPECTED_IMAGE_TENSOR_JSON = test_utils.get_test_data_path(
"image_tensor_meta.json")
_EXPECTED_BOUNDING_BOX_TENSOR_JSON = test_utils.get_test_data_path(
"bounding_box_tensor_meta.json")
@parameterized.named_parameters(
{
"testcase_name": "feature_tensor",
"content_type": _metadata_fb.ContentProperties.FeatureProperties,
"golden_json": _EXPECTED_FEATURE_TENSOR_JSON
}, {
"testcase_name": "image_tensor",
"content_type": _metadata_fb.ContentProperties.ImageProperties,
"golden_json": _EXPECTED_IMAGE_TENSOR_JSON
}, {
"testcase_name": "bounding_box_tensor",
"content_type": _metadata_fb.ContentProperties.BoundingBoxProperties,
"golden_json": _EXPECTED_BOUNDING_BOX_TENSOR_JSON
})
def test_create_metadata_should_succeed(self, content_type, golden_json):
associated_file1 = metadata_info.AssociatedFileMd(
file_path=self._LABEL_FILE_EN, locale="en")
associated_file2 = metadata_info.AssociatedFileMd(
file_path=self._LABEL_FILE_CN, locale="cn")
tensor_md = metadata_info.TensorMd(
name=self._TENSOR_NAME,
description=self._TENSOR_DESCRIPTION,
min_values=[self._TENSOR_MIN],
max_values=[self._TENSOR_MAX],
content_type=content_type,
associated_files=[associated_file1, associated_file2])
tensor_metadata = tensor_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(tensor_metadata))
with open(golden_json, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class InputImageTensorMdTest(parameterized.TestCase):
_NAME = "image"
_DESCRIPTION = "The input image."
_NORM_MEAN = (0, 127.5, 255)
_NORM_STD = (127.5, 127.5, 127.5)
_COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
"input_image_tensor_float_meta.json")
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
"input_image_tensor_uint8_meta.json")
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
"input_image_tensor_unsupported_meta.json")
@parameterized.named_parameters(
{
"testcase_name": "float",
"tensor_type": _schema_fb.TensorType.FLOAT32,
"golden_json": _EXPECTED_FLOAT_TENSOR_JSON
}, {
"testcase_name": "uint8",
"tensor_type": _schema_fb.TensorType.UINT8,
"golden_json": _EXPECTED_UINT8_TENSOR_JSON
}, {
"testcase_name": "unsupported_tensor_type",
"tensor_type": _schema_fb.TensorType.INT16,
"golden_json": _EXPECTED_UNSUPPORTED_TENSOR_JSON
})
def test_create_metadata_should_succeed(self, tensor_type, golden_json):
    tensor_md = metadata_info.InputImageTensorMd(
name=self._NAME,
description=self._DESCRIPTION,
norm_mean=list(self._NORM_MEAN),
norm_std=list(self._NORM_STD),
color_space_type=self._COLOR_SPACE_TYPE,
tensor_type=tensor_type)
    tensor_metadata = tensor_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(tensor_metadata))
with open(golden_json, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_init_should_throw_exception_with_incompatible_mean_and_std(self):
norm_mean = [0]
norm_std = [1, 2]
with self.assertRaises(ValueError) as error:
metadata_info.InputImageTensorMd(norm_mean=norm_mean, norm_std=norm_std)
self.assertEqual(
f"norm_mean and norm_std are expected to be the same dim. But got "
f"{len(norm_mean)} and {len(norm_std)}", str(error.exception))
class ClassificationTensorMdTest(parameterized.TestCase):
_NAME = "probability"
_DESCRIPTION = "The classification result tensor."
_LABEL_FILE_EN = "labels.txt"
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
_CALIBRATION_DEFAULT_SCORE = 0.2
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
"classification_tensor_float_meta.json")
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
"classification_tensor_uint8_meta.json")
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
"classification_tensor_unsupported_meta.json")
@parameterized.named_parameters(
{
"testcase_name": "float",
"tensor_type": _schema_fb.TensorType.FLOAT32,
"golden_json": _EXPECTED_FLOAT_TENSOR_JSON
}, {
"testcase_name": "uint8",
"tensor_type": _schema_fb.TensorType.UINT8,
"golden_json": _EXPECTED_UINT8_TENSOR_JSON
}, {
"testcase_name": "unsupported_tensor_type",
"tensor_type": _schema_fb.TensorType.INT16,
"golden_json": _EXPECTED_UNSUPPORTED_TENSOR_JSON
})
def test_create_metadata_should_succeed(self, tensor_type, golden_json):
label_file_en = metadata_info.LabelFileMd(
file_path=self._LABEL_FILE_EN, locale="en")
label_file_cn = metadata_info.LabelFileMd(
file_path=self._LABEL_FILE_CN, locale="cn")
score_calibration_md = metadata_info.ScoreCalibrationMd(
_metadata_fb.ScoreTransformationType.IDENTITY,
self._CALIBRATION_DEFAULT_SCORE, _SCORE_CALIBRATION_FILE)
    tensor_md = metadata_info.ClassificationTensorMd(
name=self._NAME,
description=self._DESCRIPTION,
label_files=[label_file_en, label_file_cn],
tensor_type=tensor_type,
score_calibration_md=score_calibration_md)
    tensor_metadata = tensor_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(tensor_metadata))
with open(golden_json, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class ScoreCalibrationMdTest(absltest.TestCase):
_DEFAULT_VALUE = 0.2
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
"score_calibration_tensor_meta.json")
_EXPECTED_MODEL_META_JSON = test_utils.get_test_data_path(
"score_calibration_file_meta.json")
def test_create_metadata_should_succeed(self):
score_calibration_md = metadata_info.ScoreCalibrationMd(
_metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE,
_SCORE_CALIBRATION_FILE)
score_calibration_metadata = score_calibration_md.create_metadata()
metadata_json = _metadata.convert_to_json(
        _create_dummy_model_metadata_with_process_unit(
score_calibration_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_create_score_calibration_file_md_should_succeed(self):
score_calibration_md = metadata_info.ScoreCalibrationMd(
_metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE,
_SCORE_CALIBRATION_FILE)
score_calibration_file_md = (
score_calibration_md.create_score_calibration_file_md())
file_metadata = score_calibration_file_md.create_metadata()
# Create the Flatbuffers object and convert it to the json format.
model_metadata = _metadata_fb.ModelMetadataT()
model_metadata.associatedFiles = [file_metadata]
builder = flatbuffers.Builder(0)
builder.Finish(
model_metadata.Pack(builder),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_json = _metadata.convert_to_json(bytes(builder.Output()))
with open(self._EXPECTED_MODEL_META_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
  def test_create_score_calibration_file_fails_with_fewer_columns(self):
with tempfile.TemporaryDirectory() as temp_dir:
malformed_calibration_file = test_utils.create_calibration_file(
temp_dir, content="1.0,0.2")
with self.assertRaisesRegex(
ValueError,
"Expected empty lines or 3 or 4 parameters per line in score" +
" calibration file, but got 2."):
metadata_info.ScoreCalibrationMd(
_metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE,
malformed_calibration_file)
def test_create_score_calibration_file_fails_with_negative_scale(self):
with tempfile.TemporaryDirectory() as temp_dir:
malformed_calibration_file = test_utils.create_calibration_file(
temp_dir, content="-1.0,0.2,0.1")
with self.assertRaisesRegex(
ValueError,
"Expected scale to be a non-negative value, but got -1.0."):
metadata_info.ScoreCalibrationMd(
_metadata_fb.ScoreTransformationType.LOG, self._DEFAULT_VALUE,
malformed_calibration_file)
def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata.
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
subgraph_metadata.inputTensorMetadata = [tensor_metadata]
model_metadata = _metadata_fb.ModelMetadataT()
model_metadata.subgraphMetadata = [subgraph_metadata]
# Create the Flatbuffers object and convert it to the json format.
builder = flatbuffers.Builder(0)
builder.Finish(
model_metadata.Pack(builder),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
return bytes(builder.Output())
def _create_dummy_model_metadata_with_process_unit(
    process_unit_metadata: _metadata_fb.ProcessUnitT) -> bytes:
  # Create a dummy model using the process unit metadata.
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
subgraph_metadata.inputProcessUnits = [process_unit_metadata]
model_metadata = _metadata_fb.ModelMetadataT()
model_metadata.subgraphMetadata = [subgraph_metadata]
# Create the Flatbuffers object and convert it to the json format.
builder = flatbuffers.Builder(0)
builder.Finish(
model_metadata.Pack(builder),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
return bytes(builder.Output())
if __name__ == "__main__":
absltest.main()

View File

@ -43,3 +43,13 @@ def get_test_data_path(file_or_dirname: str) -> str:
if f.endswith(file_or_dirname):
return os.path.join(directory, f)
raise ValueError("No %s in test directory" % file_or_dirname)
def create_calibration_file(file_dir: str,
file_name: str = "score_calibration.txt",
content: str = "1.0,2.0,3.0,4.0") -> str:
"""Creates the calibration file."""
calibration_file = os.path.join(file_dir, file_name)
with open(calibration_file, mode="w") as file:
file.write(content)
return calibration_file

View File

@ -33,7 +33,21 @@ mediapipe_files(srcs = [
exports_files([
"external_file",
"general_meta.json",
"golden_json.json",
"associated_file_meta.json",
"bounding_box_tensor_meta.json",
"classification_tensor_float_meta.json",
"classification_tensor_uint8_meta.json",
"classification_tensor_unsupported_meta.json",
"feature_tensor_meta.json",
"image_tensor_meta.json",
"input_image_tensor_float_meta.json",
"input_image_tensor_uint8_meta.json",
"input_image_tensor_unsupported_meta.json",
"score_calibration.txt",
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",
])
filegroup(
@ -51,7 +65,21 @@ filegroup(
filegroup(
name = "data_files",
srcs = [
"associated_file_meta.json",
"bounding_box_tensor_meta.json",
"classification_tensor_float_meta.json",
"classification_tensor_uint8_meta.json",
"classification_tensor_unsupported_meta.json",
"external_file",
"feature_tensor_meta.json",
"general_meta.json",
"golden_json.json",
"image_tensor_meta.json",
"input_image_tensor_float_meta.json",
"input_image_tensor_uint8_meta.json",
"input_image_tensor_unsupported_meta.json",
"score_calibration.txt",
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",
],
)

View File

@ -0,0 +1,10 @@
{
"associated_files": [
{
"name": "label.txt",
"description": "The label file.",
"type": "TENSOR_AXIS_LABELS",
"locale": "en"
}
]
}

View File

@ -0,0 +1,35 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input",
"description": "The input tensor.",
"content": {
"content_properties_type": "BoundingBoxProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"locale": "en"
},
{
"name": "labels_cn.txt",
"locale": "cn"
}
]
}
]
}
]
}

View File

@ -0,0 +1,52 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "probability",
"description": "The classification result tensor.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreCalibrationOptions",
"options": {
"default_score": 0.2
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS",
"locale": "en"
},
{
"name": "labels_cn.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS",
"locale": "cn"
},
{
"name": "score_calibration.txt",
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
}
]
}
]
}
]
}

View File

@ -0,0 +1,52 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "probability",
"description": "The classification result tensor.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreCalibrationOptions",
"options": {
"default_score": 0.2
}
}
],
"stats": {
"max": [
255.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS",
"locale": "en"
},
{
"name": "labels_cn.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS",
"locale": "cn"
},
{
"name": "score_calibration.txt",
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
}
]
}
]
}
]
}

View File

@ -0,0 +1,46 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "probability",
"description": "The classification result tensor.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreCalibrationOptions",
"options": {
"default_score": 0.2
}
}
],
"stats": {
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS",
"locale": "en"
},
{
"name": "labels_cn.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS",
"locale": "cn"
},
{
"name": "score_calibration.txt",
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
}
]
}
]
}
]
}

View File

@ -0,0 +1,35 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input",
"description": "The input tensor.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"locale": "en"
},
{
"name": "labels_cn.txt",
"locale": "cn"
}
]
}
]
}
]
}

View File

@ -0,0 +1,7 @@
{
"name": "model",
"description": "A ML model.",
"version": "v1",
"author": "MediaPipe",
"license": "Apache"
}

View File

@ -0,0 +1,35 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input",
"description": "The input tensor.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"locale": "en"
},
{
"name": "labels_cn.txt",
"locale": "cn"
}
]
}
]
}
]
}

View File

@ -0,0 +1,47 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "The input image.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0,
127.5,
255.0
],
"std": [
127.5,
127.5,
127.5
]
}
}
],
"stats": {
"max": [
2.0,
1.0,
0.0
],
"min": [
0.0,
-1.0,
-2.0
]
}
}
]
}
]
}
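As a cross-check, the stats in this float golden file follow from the normalization parameters above, via the per-channel bounds InputImageTensorMd computes: min = (0 - mean) / std and max = (255 - mean) / std. A quick sketch:

```python
# Reproduce the per-channel stats from the normalization parameters in the
# golden file above.
norm_mean = [0.0, 127.5, 255.0]
norm_std = [127.5, 127.5, 127.5]

mins = [(0.0 - m) / s for m, s in zip(norm_mean, norm_std)]
maxs = [(255.0 - m) / s for m, s in zip(norm_mean, norm_std)]

assert mins == [0.0, -1.0, -2.0]  # "min" entries in the JSON stats
assert maxs == [2.0, 1.0, 0.0]  # "max" entries in the JSON stats
```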

View File

@ -0,0 +1,43 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "The input image.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0,
127.5,
255.0
],
"std": [
127.5,
127.5,
127.5
]
}
}
],
"stats": {
"max": [
255.0
],
"min": [
0.0
]
}
}
]
}
]
}

View File

@ -0,0 +1,37 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "The input image.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0,
127.5,
255.0
],
"std": [
127.5,
127.5,
127.5
]
}
}
],
"stats": {
}
}
]
}
]
}

View File

@ -0,0 +1,511 @@
0.9876328110694885,0.36622241139411926,0.5352765321731567,0.71484375
0.9584911465644836,1.0602262020111084,0.2777034342288971,0.019999999552965164
0.9698624014854431,0.8795201778411865,0.539591908454895,0.00390625
0.7486230731010437,1.1876736879348755,2.552982807159424,0.019999999552965164
0.9745277166366577,0.3739396333694458,0.4621727764606476,0.19921875
0.9683839678764343,0.6996201276779175,0.7690851092338562,0.019999999552965164
0.6875,0.31044548749923706,1.0056899785995483,0.019999999552965164
0.9849396347999573,0.8532888889312744,-0.2361421436071396,0.03125
0.9878578186035156,1.0118975639343262,0.13313621282577515,0.359375
0.9915205836296082,0.4434199929237366,1.0268371105194092,0.05078125
0.9370332360267639,0.4586562216281891,-0.08101099729537964,0.019999999552965164
0.9905818104743958,0.8670706152915955,0.012704282067716122,0.019999999552965164
0.9080020189285278,0.8507471680641174,0.5081117749214172,0.019999999552965164
0.985953152179718,0.9933826923370361,-0.8114940524101257,0.109375
0.9819648861885071,1.12098228931427,-0.6330763697624207,0.01171875
0.9025918245315552,0.7803755402565002,0.03275677561759949,0.08984375
0.9863958954811096,0.11243592947721481,0.935604453086853,0.61328125
0.9905291795730591,0.3710605800151825,0.708966851234436,0.359375
0.9917052984237671,0.9596433043479919,0.19800108671188354,0.09765625
0.8762937188148499,0.3449830114841461,0.5352474451065063,0.0078125
0.9902125000953674,0.8918796181678772,-0.1306992471218109,0.26171875
0.9902340173721313,0.9177873134613037,-0.4322589933872223,0.019999999552965164
0.9707600474357605,0.7028177976608276,0.9813734889030457,0.019999999552965164
0.9823090434074402,1.0499590635299683,0.12045472860336304,0.0078125
0.990516185760498,0.9449402093887329,1.3773189783096313,0.019999999552965164
0.9875434041023254,0.577914297580719,1.282518982887268,0.0390625
0.9821421504020691,0.0967339277267456,0.8279788494110107,0.47265625
0.9875047206878662,0.9038218259811401,2.1208062171936035,0.38671875
0.9857864379882812,0.8627446889877319,0.18189261853694916,0.019999999552965164
0.9647751450538635,1.0752476453781128,-0.018294010311365128,0.0234375
0.9830358624458313,0.5638481378555298,0.8346489667892456,0.019999999552965164
0.9904966354370117,1.0160938501358032,-0.0573287308216095,0.00390625
0.8458405137062073,0.4868394434452057,0.6617084741592407,0.019999999552965164
0.9847381711006165,0.5939620137214661,0.008616370148956776,0.00390625
0.9375938773155212,0.723095178604126,0.6635608077049255,0.019999999552965164
0.9334303140640259,0.5689108967781067,0.37019580602645874,0.019999999552965164
0.9716793894767761,1.0037211179733276,0.5898993611335754,0.02734375
0.9197732210159302,0.46794334053993225,0.7365336418151855,0.640625
0.9857497811317444,0.7299028635025024,0.9195274114608765,0.0390625
0.8758038282394409,1.200216293334961,0.02580185979604721,0.019999999552965164
0.9841026067733765,0.8050475716590881,0.9698556661605835,0.0078125
0.9908539652824402,0.7911490201950073,0.19351358711719513,0.12109375
0.9179956316947937,0.023991893976926804,0.35193610191345215,0.04296875
0.9903728365898132,0.7744967341423035,0.2686336636543274,0.359375
0.906022846698761,0.5766159892082214,1.0600007772445679,0.04296875
0.9885554909706116,0.99117511510849,0.5611960291862488,0.4140625
0.9906331896781921,1.1376535892486572,1.45369291305542,0.019999999552965164
0.9640991687774658,0.5387894511222839,1.1824018955230713,0.019999999552965164
0.9932155609130859,0.4347895085811615,1.3938102722167969,0.0078125
0.9884702563285828,0.885567843914032,0.1556047648191452,0.1484375
0.9891508221626282,0.04143073782324791,0.6111864447593689,0.0078125
0.8935436010360718,0.2937895655632019,0.3215920031070709,0.00390625
0.8327123522758484,0.8381986021995544,-0.026293788105249405,0.019999999552965164
0.9839455485343933,0.9581400156021118,1.495324969291687,0.640625
0.9904995560646057,0.9168422818183899,0.33293962478637695,0.015625
0.9856975674629211,1.0433714389801025,0.5954801440238953,0.019999999552965164
0.9942344427108765,0.7206616997718811,1.666426181793213,0.9609375
0.8182767033576965,0.9546273946762085,0.5500107407569885,0.019999999552965164
0.9631295800209045,0.6277880668640137,0.05952891707420349,0.05859375
0.9819005727767944,1.0826934576034546,0.7444049715995789,0.30859375
0.9884315133094788,1.0500890016555786,1.1161768436431885,0.019999999552965164
0.9175815582275391,0.09232989698648453,1.596696138381958,0.47265625
0.9868760108947754,0.903079628944397,-0.15774966776371002,0.8515625
0.9866015911102295,0.7533788084983826,0.7489103078842163,0.03125
0.8074312806129456,0.8615151643753052,0.40621864795684814,0.00390625
0.9829285144805908,0.8954831957817078,0.4462486207485199,0.02734375
0.9681841135025024,0.6257772445678711,0.43809664249420166,0.38671875
0.9872947931289673,0.9947993159294128,0.9271130561828613,0.26171875
0.7997345328330994,0.3995186686515808,-0.3755347430706024,0.019999999552965164
0.9922754168510437,1.1357101202011108,-0.10267537832260132,0.5
0.9861471652984619,0.8725204467773438,1.1657888889312744,0.019999999552965164
0.9888646006584167,1.2098380327224731,-0.27832522988319397,0.05078125
0.5641342997550964,1.0501892566680908,1.9519661664962769,0.019999999552965164
0.9548168778419495,0.8971696496009827,1.378737449645996,0.00390625
0.9875019788742065,0.8718118071556091,0.5476236939430237,0.0078125
0.9725168347358704,0.6989551782608032,-1.3157455921173096,0.61328125
0.9864014983177185,0.7576251029968262,-0.41650667786598206,0.00390625
0.960071861743927,0.13068856298923492,0.4819187819957733,0.019999999552965164
0.9849705100059509,0.7724528908729553,0.3877875804901123,0.03125
0.9703006744384766,0.8848260641098022,-1.1767181158065796,0.80078125
0.9837008714675903,0.7015050053596497,0.18209102749824524,0.00390625
0.9579976797103882,0.053806986659765244,2.7309608459472656,0.4000000059604645
0.9896979928016663,0.41135814785957336,0.5738034844398499,0.019999999552965164
0.9853873252868652,0.5438565611839294,0.20562179386615753,0.02734375
0.9784129858016968,0.6330984830856323,-0.1789831817150116,0.015625
0.9375,0.855596125125885,-0.1933964192867279,0.019999999552965164
0.9524176716804504,0.08709807693958282,0.6299692988395691,0.33203125
0.9808038473129272,1.2909820079803467,0.3397117257118225,0.00390625
0.8008236885070801,0.7974631786346436,1.0567312240600586,0.019999999552965164
0.9421642422676086,0.6754576563835144,0.32419073581695557,0.23828125
0.9072281718254089,1.1716840267181396,-0.10382208228111267,0.00390625
0.9497162103652954,1.1582106351852417,-0.11845408380031586,0.00390625
0.9773319959640503,0.5042116641998291,1.2815768718719482,0.23828125
0.9743752479553223,1.1731196641921997,0.48585158586502075,0.1640625
0.9601503610610962,1.0114264488220215,-0.9113408327102661,0.38671875
0.97279292345047,0.32572469115257263,0.548393964767456,0.01171875
0.9845231175422668,0.9852075576782227,1.0973742008209229,0.69140625
0.9764596223831177,0.2248251885175705,0.8963963985443115,0.33203125
0.8746626377105713,0.016590777784585953,1.4492003917694092,0.359375
0.9726155996322632,0.8712832927703857,-0.6451321840286255,0.52734375
0.980800211429596,0.8469374775886536,0.0718703418970108,0.04296875
0.7734344005584717,0.8508065342903137,0.4233662784099579,0.019999999552965164
0.969182550907135,0.8082079887390137,-0.4314402937889099,0.0234375
0.9037994742393494,0.1387290209531784,1.8660004138946533,0.5
0.9869260191917419,0.6927974820137024,0.4927133619785309,0.019999999552965164
0.8794143795967102,0.8060213327407837,-0.6247795820236206,0.09765625
0.9895913600921631,0.8851431012153625,0.9641156196594238,0.28515625
0.9833245873451233,0.9379183053970337,1.5143399238586426,0.0078125
0.26580730080604553,1.488408088684082,2.5120370388031006,0.019999999552965164
0.9859549403190613,1.5805137157440186,0.7283271551132202,0.01171875
0.9376091361045837,0.6854841709136963,0.20175717771053314,0.00390625
0.965065598487854,0.7363166213035583,-0.3636060357093811,0.1484375
0.9904685020446777,0.9182849526405334,0.30159056186676025,0.05859375
0.5014551877975464,0.7409977912902832,0.2045259326696396,0.019999999552965164
0.9434370398521423,0.3679845631122589,0.6447131633758545,0.38671875
0.9806621670722961,0.9568924307823181,1.2417932748794556,0.019999999552965164
0.9825865626335144,1.2273900508880615,-0.0674915760755539,0.0390625
0.9859767556190491,0.7635276317596436,-0.8502742648124695,0.109375
0.9701240658760071,0.46266916394233704,0.38697123527526855,0.0703125
0.9651575088500977,0.5057743191719055,0.6578569412231445,0.0078125
0.9685596227645874,0.6961715817451477,0.20829983055591583,0.015625
0.9772806167602539,0.8312440514564514,-0.09966880083084106,0.019999999552965164
0.9718109369277954,0.8248763680458069,1.2387524843215942,0.08984375
0.9890084266662598,2.0058324337005615,1.7648913860321045,0.019999999552965164
0.9813475608825684,1.02803373336792,1.4689184427261353,0.019999999552965164
0.9925220608711243,0.8020634055137634,0.7509317994117737,0.015625
0.9754987955093384,0.5145153999328613,0.4638928472995758,0.00390625
0.9735408425331116,0.7434492111206055,0.06251777708530426,0.01171875
0.8753963112831116,1.6830265522003174,4.509310722351074,0.019999999552965164
0.9385876655578613,0.46194836497306824,0.13496099412441254,0.13671875
0.9676342010498047,0.5462782979011536,0.9306238889694214,0.1796875
0.9829097986221313,0.8054409623146057,0.11194216459989548,0.08984375
0.9503080248832703,0.44028621912002563,0.4689175486564636,0.00390625
0.9808863997459412,0.8023126721382141,-0.022534284740686417,0.015625
0.9079821109771729,0.33415740728378296,0.544142484664917,0.019999999552965164
0.9839802980422974,0.9184480905532837,0.2658761739730835,0.1484375
0.75,0.8216301798820496,0.3300539255142212,0.019999999552965164
0.9590148329734802,0.722118616104126,0.255025178194046,0.015625
0.9616804122924805,0.8398274779319763,0.33006206154823303,0.019999999552965164
0.7859238386154175,0.5596626400947571,0.5452361702919006,0.019999999552965164
0.9842674732208252,0.07029404491186142,1.189304232597351,0.30859375
0.7237641215324402,0.2756437361240387,-0.10612351447343826,0.019999999552965164
0.9793540239334106,0.5117573738098145,0.8033715486526489,0.01953125
0.9825188517570496,0.3965616822242737,0.17742407321929932,0.019999999552965164
0.9859991073608398,1.32109534740448,0.5763598084449768,0.019999999552965164
0.9551243782043457,0.3639756143093109,0.19449777901172638,0.00390625
0.9606218338012695,0.8222983479499817,0.43461644649505615,0.00390625
0.9785885810852051,0.9104304909706116,0.2279568761587143,0.01171875
0.9705367684364319,0.0769517719745636,0.7330215573310852,0.04296875
0.9736841320991516,0.9110560417175293,0.10864781588315964,0.05859375
0.9880238771438599,1.1702078580856323,0.05487633869051933,0.00390625
0.9913991093635559,0.7445327043533325,1.2198610305786133,0.01171875
0.8302573561668396,0.33997753262519836,1.0731935501098633,0.019999999552965164
0.9880614280700684,0.9227356910705566,2.1198885440826416,0.61328125
0.9173498153686523,0.2221490740776062,0.11565151065587997,0.0078125
0.962620735168457,1.011454701423645,-1.5519139766693115,0.8203125
0.9828791618347168,0.7543124556541443,0.29118794202804565,0.00390625
0.9908701181411743,0.8183356523513794,0.48734790086746216,0.019999999552965164
0.5002585649490356,0.12179236859083176,0.20199841260910034,0.019999999552965164
0.9631574153900146,0.41631683707237244,1.1000276803970337,0.44140625
0.9875426888465881,0.8117235898971558,0.8689690232276917,0.08203125
0.9410585761070251,0.3703889548778534,0.7951740026473999,0.0078125
0.9877454042434692,0.2155231237411499,1.635109305381775,0.94921875
0.9860436320304871,1.0054532289505005,-0.9608616232872009,0.03125
0.9721421003341675,0.5174740552902222,0.43327680230140686,0.0078125
0.9908374547958374,0.8122930526733398,0.21533408761024475,0.0078125
0.9896888136863708,0.7030488848686218,-0.062063876539468765,0.01953125
0.9861313700675964,0.49431633949279785,0.981758177280426,0.12109375
0.9792494177818298,1.0670701265335083,0.7028639316558838,0.019999999552965164
0.9871346950531006,1.3606067895889282,-3.00394868850708,0.61328125
0.9583333134651184,0.9180184602737427,-0.05760742351412773,0.019999999552965164
0.9764145612716675,0.5258041024208069,1.1425464153289795,0.019999999552965164
0.9076833128929138,1.081973910331726,0.6340405344963074,0.019999999552965164
0.9895729422569275,0.27958083152770996,1.2441545724868774,0.08203125
0.916824221611023,0.7878308892250061,-1.3060243129730225,0.359375
0.9883677363395691,0.6098470687866211,0.7665972709655762,0.52734375
0.949999988079071,0.818132758140564,1.5476282835006714,0.019999999552965164
0.9666821360588074,0.707548201084137,0.7326748967170715,0.00390625
0.9861665368080139,0.7194502353668213,2.1585183143615723,0.38671875
0.9811879992485046,0.32190269231796265,0.31508582830429077,0.05078125
0.9625869989395142,0.11173010617494583,0.9030138850212097,0.019999999552965164
0.9675677418708801,0.49738144874572754,0.5481624007225037,0.019999999552965164
0.9764066934585571,1.0306450128555298,0.2257029116153717,0.00390625
0.9857029318809509,0.8312124013900757,-0.12777498364448547,0.00390625
0.9781621098518372,0.621485710144043,0.3126043975353241,0.21875
0.9705549478530884,0.15182119607925415,1.7296228408813477,0.13671875
0.9801698923110962,0.8953424692153931,0.6697174310684204,0.019999999552965164
0.9842199087142944,0.7984838485717773,0.7436375617980957,0.0078125
0.9159231185913086,0.05519663542509079,0.011483916081488132,0.47265625
0.9742691516876221,0.9268448352813721,1.1530364751815796,0.019999999552965164
0.9579406380653381,0.7879363894462585,1.1582229137420654,0.00390625
0.8999202251434326,0.8120636343955994,0.37021151185035706,0.019999999552965164
0.9870507121086121,1.1666820049285889,1.387096881866455,0.019999999552965164
0.9769532680511475,0.6519474983215332,0.3170791268348694,0.109375
0.9546447396278381,0.7559569478034973,0.9533731937408447,0.0078125
0.9773718118667603,1.3183629512786865,1.0090563297271729,0.019999999552965164
0.9049819707870483,1.0706751346588135,1.7704588174819946,0.019999999552965164
0.9003662467002869,0.7251236438751221,-1.4905513525009155,0.4140625
0.9834321141242981,0.5246152877807617,1.2191725969314575,0.47265625
0.9748008847236633,0.8448761105537415,-0.01744924671947956,0.00390625
0.9904628396034241,0.8762193918228149,0.22459718585014343,0.01171875
0.6833457946777344,0.8996955752372742,1.2423095703125,0.019999999552965164
0.9909645318984985,0.8978683948516846,0.7022045254707336,0.019999999552965164
0.9843918681144714,0.12815311551094055,1.5720607042312622,0.78125
0.9382115602493286,0.4989806115627289,1.1206520795822144,0.03515625
0.9832627177238464,0.6727185845375061,0.2797912657260895,0.08984375
0.8830162286758423,1.1294968128204346,1.1474463939666748,0.019999999552965164
0.9554208517074585,0.9476046562194824,0.8490120768547058,0.019999999552965164
0.98823082447052,0.7835749983787537,0.5608289837837219,0.03515625
0.9790570139884949,0.9982950091362,0.3763321042060852,0.00390625
0.5039305686950684,0.9079190492630005,1.265581488609314,0.019999999552965164
0.9871423840522766,0.6633929014205933,0.09028752893209457,0.019999999552965164
0.8614975214004517,0.9595098495483398,-0.5349600315093994,0.00390625
0.9873358011245728,0.698331892490387,0.7571848630905151,0.1484375
0.7227392196655273,1.1300171613693237,1.1754553318023682,0.019999999552965164
0.9814568758010864,0.46864795684814453,0.6286783218383789,0.19921875
0.9876973032951355,0.29863566160202026,0.7726709842681885,0.61328125
0.9887779951095581,1.1818888187408447,-1.0321481227874756,0.38671875
0.9684743285179138,0.7226923108100891,0.0908145159482956,0.0390625
0.9854185581207275,1.0576037168502808,0.35190048813819885,0.0078125
0.9463624954223633,0.781932532787323,0.7598024606704712,0.01171875
0.9837555885314941,0.8735848665237427,0.5948384404182434,0.019999999552965164
0.9700835347175598,0.45710718631744385,2.141801357269287,0.8359375
0.9896127581596375,1.018708348274231,0.23626597225666046,0.01953125
0.7728451490402222,8.084141001063472e-08,0.7415778636932373,0.4000000059604645
0.9838477969169617,0.8994008302688599,0.15494465827941895,0.00390625
0.9421281218528748,0.4648025333881378,0.12706322968006134,0.00390625
0.9843724370002747,1.0055731534957886,-0.911835253238678,0.23828125
0.958256185054779,1.1208757162094116,-0.31016042828559875,0.0078125
0.9832971692085266,0.056124646216630936,1.7148709297180176,0.23828125
0.9804430603981018,0.4016909897327423,0.6085042357444763,0.0703125
0.9825966358184814,0.9228396415710449,0.912163257598877,0.019999999552965164
0.9441317915916443,0.048142336308956146,0.6141980290412903,0.109375
0.9856440424919128,0.8616625666618347,0.28943121433258057,0.015625
0.9913654923439026,1.0482347011566162,0.6889304518699646,0.015625
0.97914719581604,0.8870795369148254,-0.700239360332489,0.015625
0.9836585521697998,0.5450212955474854,0.009687358513474464,0.01953125
0.990472137928009,0.8221097588539124,2.5926225185394287,0.97265625
0.6274135708808899,0.6787079572677612,0.12988793849945068,0.015625
0.982601523399353,0.7495649456977844,1.2217103242874146,0.019999999552965164
0.9841020703315735,0.9071263670921326,1.3682825565338135,0.09765625
0.9872562885284424,0.818276584148407,-0.14663955569267273,0.05859375
0.5041943192481995,0.35444244742393494,0.46112486720085144,0.00390625
0.7517910599708557,0.91172856092453,1.3611085414886475,0.019999999552965164
0.9861181378364563,1.0613479614257812,-0.46272075176239014,0.015625
0.9914185404777527,0.9464229941368103,1.2103853225708008,0.0234375
0.984909176826477,0.5985794067382812,0.7704220414161682,0.08203125
0.9575125575065613,0.7695640325546265,0.6132461428642273,0.00390625
0.9845197200775146,0.7421835064888,1.332088589668274,0.019999999552965164
0.9470700621604919,0.357934832572937,1.0986406803131104,0.359375
0.9287161231040955,0.6833012104034424,0.373298704624176,0.00390625
0.9531774520874023,0.3247152864933014,0.6011538505554199,0.66796875
0.9779354929924011,0.828241229057312,0.3349589705467224,0.03125
0.9863978028297424,0.932086169719696,0.04865559563040733,0.02734375
0.9826814532279968,0.06353739649057388,1.879408359527588,0.61328125
0.974474310874939,0.8063777685165405,0.8257133364677429,0.019999999552965164
0.9670184254646301,0.09195757657289505,1.7024414539337158,0.5
0.9885809421539307,0.7981435656547546,-0.11792337149381638,0.0703125
0.9829109907150269,0.9578585028648376,-1.9371291399002075,0.13671875
0.9754639863967896,1.137816071510315,0.5887423157691956,0.00390625
0.9755549430847168,0.677255392074585,0.20494212210178375,0.00390625
0.9903355836868286,1.0475162267684937,2.1768462657928467,0.52734375
0.9855127930641174,0.9580414891242981,0.35021960735321045,0.76171875
0.9450457692146301,0.4737727642059326,-0.3041325807571411,0.01171875
0.9360163807868958,0.9219141006469727,1.2481396198272705,0.019999999552965164
0.9696909189224243,0.06589268147945404,1.456658124923706,0.30000001192092896
0.6495901942253113,0.8538134098052979,0.3043774366378784,0.019999999552965164
0.9901140928268433,0.8112474679946899,0.7102972269058228,0.019999999552965164
0.9925929307937622,0.49307680130004883,0.6297348737716675,0.019999999552965164
0.9840761423110962,0.5691578388214111,0.9437046647071838,0.00390625
0.9625457525253296,0.9322702288627625,1.3358750343322754,0.0234375
0.9820173978805542,0.6805416345596313,1.0065922737121582,0.05859375
0.9883391261100769,0.742003321647644,0.6168643236160278,0.0078125
0.9119130969047546,0.8404607176780701,0.8882355690002441,0.01171875
0.9854885935783386,1.295777440071106,0.5272557735443115,0.00390625
0.9911734461784363,1.152715802192688,-0.05230601131916046,0.019999999552965164
0.8071879744529724,0.4576769471168518,1.391660451889038,0.00390625
0.9919166564941406,1.1775370836257935,0.5039792060852051,0.019999999552965164
0.9831258654594421,0.9164834022521973,0.3790256977081299,0.01171875
0.990642249584198,0.9242916107177734,1.477474570274353,0.38671875
0.7415178418159485,0.2909083068370819,0.19971248507499695,0.019999999552965164
0.9146556854248047,0.06850286573171616,1.3211928606033325,0.61328125
0.976986825466156,0.6469135284423828,-0.7279839515686035,0.02734375
0.968462347984314,0.4640704393386841,1.4650955200195312,0.1484375
0.937825083732605,0.9767780303955078,-0.7378027439117432,0.0390625
0.9878604412078857,1.1423084735870361,1.7311146259307861,0.1484375
0.9904257655143738,0.9551829099655151,1.564165472984314,0.00390625
0.9830996990203857,0.92529296875,-0.1086890697479248,0.02734375
0.9820512533187866,0.7556048631668091,0.6512532830238342,0.109375
0.9740781188011169,0.8380919098854065,0.19731587171554565,0.019999999552965164
0.9830799698829651,1.183397650718689,-0.801214873790741,0.019999999552965164
0.9898439049720764,1.168870210647583,1.2985308170318604,0.00390625
0.97286057472229,0.8012385964393616,-1.657444953918457,0.09765625
0.9182834625244141,0.5254654884338379,-0.027080848813056946,0.04296875
0.9729798436164856,0.4111078381538391,1.077646255493164,0.019999999552965164
0.6875,1.756393551826477,0.34522199630737305,0.019999999552965164
0.9920725226402283,1.0676580667495728,1.1592471599578857,0.019999999552965164
0.37564563751220703,0.07466565072536469,0.3562135696411133,0.019999999552965164
0.9894161224365234,0.8109862804412842,1.3056280612945557,0.0390625
0.9386259317398071,0.5322021842002869,-0.03461914509534836,0.08984375
0.9866133332252502,0.8940346240997314,1.0361984968185425,0.00390625
0.9822850823402405,0.6215930581092834,-0.6859042048454285,0.00390625
0.9752063155174255,1.0129338502883911,0.3866007626056671,0.019999999552965164
0.9825329184532166,0.567034125328064,0.5370683670043945,0.5
0.9422088861465454,0.9411858320236206,0.5332568883895874,0.38671875
0.9506444931030273,0.7494101524353027,0.9869776368141174,0.00390625
0.9923189282417297,1.1255286931991577,0.8734608292579651,0.019999999552965164
0.9807777404785156,0.9558923244476318,1.5415621995925903,0.09765625
0.961335301399231,0.7840818762779236,0.06915930658578873,0.00390625
0.9867202639579773,1.0596263408660889,0.21268242597579956,0.0078125
0.9926426410675049,0.8886650204658508,0.6200761198997498,0.019999999552965164
0.9791930913925171,0.4474319517612457,0.5827012062072754,0.019999999552965164
0.986801028251648,1.1846712827682495,1.4253416061401367,0.00390625
0.9549052119255066,0.6142332553863525,0.4867286682128906,0.00390625
0.983259916305542,0.42561075091362,0.9666317105293274,0.08203125
0.98175048828125,0.7744573354721069,0.4953071177005768,0.019999999552965164
0.987273097038269,0.8209654092788696,0.5267868041992188,0.019999999552965164
0.9916341304779053,0.6881924271583557,0.9522916078567505,0.019999999552965164
0.9819192886352539,0.8128346800804138,0.6556753516197205,0.05859375
0.9854727387428284,0.6597779393196106,0.9645410180091858,0.8359375
0.9891805648803711,0.7752296924591064,1.34084153175354,0.52734375
0.9489904046058655,0.6988677978515625,0.5052891969680786,0.019999999552965164
0.9741962552070618,0.43797168135643005,0.7825477123260498,0.01171875
0.9907783269882202,0.8732656240463257,1.1458243131637573,0.19921875
0.9760454297065735,0.7810378670692444,-0.29553040862083435,0.015625
0.9885720014572144,0.8427382707595825,0.2628841996192932,0.019999999552965164
0.8171960115432739,0.3271152079105377,1.30915105342865,0.26171875
0.9881270527839661,0.13021250069141388,1.6307408809661865,0.55859375
0.9751906991004944,0.8255484104156494,0.21788427233695984,0.019999999552965164
0.9630831480026245,2.1396600701476974e-15,2.883542776107788,0.5
0.8849332332611084,0.888649582862854,1.0651483535766602,0.01171875
0.9897550344467163,0.08640030771493912,2.661073923110962,0.69140625
0.9030827879905701,0.7017505168914795,0.07822071760892868,0.00390625
0.9650112986564636,0.36098214983940125,0.7112777829170227,0.0078125
0.9872719049453735,0.7115703821182251,0.6924230456352234,0.019999999552965164
0.5884749889373779,0.0942283645272255,0.24825790524482727,0.019999999552965164
0.9642857313156128,0.5304845571517944,0.6281308531761169,0.019999999552965164
0.9651434421539307,0.07168509066104889,1.4704163074493408,0.61328125
0.9779187440872192,1.0171563625335693,-2.8089962005615234,0.1484375
0.9375227689743042,0.9291267991065979,0.6853470802307129,0.019999999552965164
0.9820515513420105,0.7226945757865906,-0.19336646795272827,0.61328125
0.984882652759552,0.8176864385604858,1.161419153213501,0.0078125
0.9573767185211182,0.9027169346809387,0.15423306822776794,0.26171875
0.9059234261512756,0.872424840927124,0.7419941425323486,0.019999999552965164
0.9914654493331909,1.0662620067596436,2.7141172885894775,0.55859375
0.9839044809341431,0.9037585854530334,0.7042809724807739,0.01953125
0.986689567565918,0.6848335266113281,0.9014078974723816,0.00390625
0.9837497472763062,0.7507086396217346,0.7179840207099915,0.0078125
0.9895229339599609,1.1564929485321045,0.5822750926017761,0.019999999552965164
0.9845471978187561,0.8716567158699036,0.19987598061561584,0.01953125
0.971385657787323,0.49073365330696106,1.2333439588546753,0.73828125
0.9841684699058533,0.6468350887298584,1.0000839233398438,0.0703125
0.9882851839065552,0.26080548763275146,0.8985073566436768,0.01171875
0.9851044416427612,0.8687262535095215,0.07842865586280823,0.1796875
0.9799972772598267,0.25032666325569153,1.2494641542434692,0.10000000149011612
0.9896620512008667,0.7762697339057922,0.20227234065532684,0.019999999552965164
0.990495502948761,0.15801414847373962,1.006077766418457,0.01171875
0.9806667566299438,0.7082678079605103,0.35462483763694763,0.02734375
0.9715457558631897,0.0615643672645092,0.9478678703308105,0.4000000059604645
0.9168440103530884,0.5679594874382019,-0.6143214106559753,0.1484375
0.9824567437171936,0.45072048902511597,1.0683321952819824,0.1484375
0.9840478301048279,0.08733312040567398,1.3535010814666748,0.47265625
0.9896746873855591,1.1761761903762817,0.7102295756340027,0.94140625
0.9827673435211182,0.8215981125831604,0.6729252338409424,0.019999999552965164
0.9906817674636841,0.16318124532699585,1.133107304573059,0.30000001192092896
0.9701097011566162,1.0519390106201172,-0.16105352342128754,0.00390625
0.9417809844017029,0.7868722081184387,1.1539735794067383,0.019999999552965164
0.9615354537963867,0.8469739556312561,0.6801642179489136,0.0390625
0.988472580909729,0.81600022315979,0.6296193599700928,0.019999999552965164
0.9841001629829407,0.8400164246559143,-0.06806250661611557,0.00390625
0.9276565313339233,0.32582467794418335,-0.14148345589637756,0.019999999552965164
0.7008209228515625,0.545078694820404,1.1250351667404175,0.019999999552965164
0.9907881021499634,0.9919379353523254,-0.12143492698669434,0.019999999552965164
0.9702130556106567,0.7762024402618408,0.24524429440498352,0.0078125
0.9876235723495483,0.7181832790374756,0.41931474208831787,0.019999999552965164
0.9841905236244202,0.8836563229560852,0.28947240114212036,0.00390625
0.990247905254364,0.9825950860977173,0.6003378033638,0.00390625
0.9635987281799316,0.3707619905471802,-0.03457726538181305,0.0390625
0.9924789071083069,1.485293984413147,0.5796234607696533,0.00390625
0.9839015603065491,0.06343062222003937,1.9442640542984009,0.5
0.9927193522453308,0.7006005048751831,0.3714500069618225,0.019999999552965164
0.9870567321777344,0.869498610496521,1.5008329153060913,0.00390625
0.9002388119697571,0.4945279657840729,-0.27996397018432617,0.0078125
0.98891282081604,0.8541091680526733,0.5112633109092712,0.66796875
0.9001862406730652,0.43330734968185425,0.3592444360256195,0.00390625
0.958705723285675,0.7425220012664795,0.15833647549152374,0.00390625
0.9910086989402771,0.9245886206626892,0.8454338908195496,0.01953125
0.9912900328636169,1.3806378841400146,1.0953043699264526,0.99609375
0.9887956976890564,1.0331758260726929,0.6490115523338318,0.640625
0.8638584017753601,0.902369499206543,-0.2767508327960968,0.0078125
0.7059138417243958,1.0,1.032223091723683e-11,0.019999999552965164
0.9889519810676575,0.8361310362815857,0.811896800994873,0.03515625
0.970467209815979,0.07315781712532043,0.20799599587917328,0.00390625
0.9828550219535828,0.8393198251724243,0.6089786291122437,0.28515625
0.9553551077842712,0.7775288820266724,-0.4464336037635803,0.046875
0.9782186150550842,0.4313304126262665,0.4458310604095459,0.019999999552965164
0.9371097087860107,0.9338632225990295,1.3358187675476074,0.019999999552965164
0.9861361384391785,0.24091234803199768,1.4301774501800537,0.80078125
0.9890525341033936,1.1365840435028076,0.3055979013442993,0.00390625
0.957517683506012,0.058012738823890686,0.15909947454929352,0.046875
0.9762251377105713,0.72292160987854,0.49151331186294556,0.019999999552965164
0.9875496625900269,0.9114606976509094,-0.5052767992019653,0.05859375
0.9715835452079773,0.8113637566566467,-2.0302956104278564,0.019999999552965164
0.9846333265304565,0.49688151478767395,0.7285738587379456,0.019999999552965164
0.98553466796875,0.1484774351119995,1.3616747856140137,0.5859375
0.9866309762001038,1.0217945575714111,-0.8717418313026428,0.02734375
0.9891880750656128,0.42588523030281067,0.7833192944526672,0.109375
0.9870361685752869,0.8525673151016235,1.2773776054382324,0.019999999552965164
0.9897037744522095,0.8012522459030151,0.3973642885684967,0.109375
0.9828903079032898,1.1558295488357544,-0.6781614422798157,0.5859375
0.9924454689025879,1.1040401458740234,1.3243318796157837,0.019999999552965164
0.9826735258102417,1.0064337253570557,-0.5324167013168335,0.38671875
0.949999988079071,0.8152432441711426,0.6293236613273621,0.00390625
0.9905489087104797,0.9191447496414185,0.5621309876441956,0.019999999552965164
0.9664857387542725,0.5995981693267822,-0.7409313321113586,0.01171875
0.9847198724746704,0.8284208178520203,0.2851041555404663,0.9296875
0.9342833757400513,0.5566492676734924,0.6875373721122742,0.019999999552965164
0.8894915580749512,0.4102778434753418,0.37977635860443115,0.01953125
0.9870865941047668,0.44245558977127075,0.16041725873947144,0.10000000149011612
0.9890456795692444,1.1491310596466064,1.0844204425811768,0.01953125
0.7304704785346985,0.12790271639823914,-0.1085965558886528,0.019999999552965164
0.9830618500709534,0.8738722205162048,-0.11583804339170456,0.0234375
0.9885876178741455,0.744857668876648,0.11028216779232025,0.01953125
0.9575535655021667,0.3011772632598877,0.5136104226112366,0.00390625
0.9298899173736572,1.1736249923706055,4.0247297286987305,0.09765625
0.9907795190811157,1.0897759199142456,0.6261603236198425,0.019999999552965164
0.9855174422264099,0.6543705463409424,0.08955699950456619,0.08984375
0.976660430431366,0.5610390901565552,0.6389923095703125,0.0390625
0.9870068430900574,0.80875563621521,-0.6651867032051086,0.08984375
0.9652793407440186,0.5887689590454102,0.5353426933288574,0.0703125
0.9875175952911377,0.7699108123779297,0.876632034778595,0.019999999552965164
0.9016479849815369,0.9994669556617737,0.30356451869010925,0.015625
0.989987850189209,0.7350922226905823,0.8748764991760254,0.0078125
0.983323335647583,0.8931586146354675,1.0226351022720337,0.01171875
0.9914804100990295,0.9369975328445435,0.8283791542053223,0.019999999552965164
0.9704275727272034,1.124052882194519,0.9457330107688904,0.019999999552965164
0.9867291450500488,0.9667392373085022,-0.6122757196426392,0.44140625
0.9887421131134033,0.7823470234870911,0.343982458114624,0.00390625
0.9861542582511902,0.9171664118766785,0.35665032267570496,0.019999999552965164
0.9772396683692932,0.08705096691846848,1.7621256113052368,0.66796875
0.9819098114967346,0.8605496883392334,0.5151250958442688,0.01171875
0.982971727848053,0.5631197690963745,1.608361005783081,0.019999999552965164
0.9914254546165466,0.3850722908973694,1.4068152904510498,0.98828125
0.9880355596542358,1.1387118101119995,1.4653834104537964,0.05859375
0.9586950540542603,1.7633997201919556,1.0344760417938232,0.019999999552965164
0.9828103184700012,0.8817474842071533,0.7680216431617737,0.890625
0.9880233407020569,0.899823784828186,0.44692227244377136,0.19921875
0.9862816333770752,0.8610615134239197,0.4195229709148407,0.03125
0.9813369512557983,0.8014124631881714,1.1136316061019897,0.0078125
0.9148907661437988,0.5909111499786377,1.2860896587371826,0.015625
0.9865161776542664,0.8720636963844299,0.6233670115470886,0.015625
0.9786784648895264,0.48225611448287964,-0.005022380966693163,0.12109375
0.9843324422836304,1.0519789457321167,-2.2056643962860107,0.03125
0.9688847064971924,0.8007095456123352,0.14495795965194702,0.1640625
0.9724696278572083,0.9987169504165649,0.32869264483451843,0.019999999552965164
0.9875112175941467,1.0948023796081543,2.15657114982605,0.03125
0.9923174381256104,0.10759950429201126,0.6762840747833252,0.019999999552965164
0.9666666388511658,0.6234443783760071,1.4971232414245605,0.0390625
0.989655613899231,0.8248854279518127,0.4701078534126282,0.019999999552965164
0.9753870368003845,0.6746605634689331,-0.23550045490264893,0.1640625
0.9170913100242615,1.0504746437072754,2.7344093322753906,0.019999999552965164
0.9821392297744751,1.4154850244522095,1.2012253999710083,0.019999999552965164
0.9886221885681152,1.22860586643219,1.160277009010315,0.890625
0.9877735376358032,0.6805673837661743,1.5975077152252197,0.359375
0.9831939339637756,0.6648986339569092,1.1059051752090454,0.28515625
0.950076162815094,0.724887490272522,0.316800057888031,0.019999999552965164
0.9817547798156738,0.8619367480278015,-0.24251239001750946,0.109375
0.9849069714546204,0.8399055004119873,1.7567216157913208,0.4000000059604645
0.9821556806564331,0.8135135769844055,0.33616918325424194,0.0078125
0.8329862356185913,0.7938078045845032,1.0597797632217407,0.019999999552965164
0.9856904149055481,0.05120579153299332,0.8267747759819031,0.5
0.9766159057617188,0.7623113989830017,0.7656452059745789,0.09765625
0.9885436296463013,0.9814053177833557,0.05546858534216881,0.00390625
0.9900276064872742,0.9320858716964722,-0.36458709836006165,0.03125
0.9058290123939514,0.7260504364967346,1.1726433038711548,0.019999999552965164
0.9503811597824097,0.6632846593856812,0.7332696914672852,0.019999999552965164
0.9846004247665405,0.6996731758117676,-0.8613988757133484,0.019999999552965164
0.9897956252098083,0.8407823443412781,1.2952353954315186,0.76171875
0.9898385405540466,0.7309674024581909,0.7317643761634827,0.019999999552965164
0.9850022196769714,0.7537633180618286,0.3925366699695587,0.03125
0.9858620762825012,0.9250133633613586,2.0220303535461426,0.9296875
0.8120821714401245,0.3994182348251343,-0.4576922655105591,0.019999999552965164
0.9496838450431824,0.8251343965530396,0.15125347673892975,0.019999999552965164
0.9420520067214966,0.6087028384208679,1.0767998695373535,0.019999999552965164
0.9899152517318726,0.8887513279914856,0.9602599143981934,0.019999999552965164
0.9461711049079895,1.1373282670974731,0.6371906995773315,0.00390625
0.9834751486778259,0.7226889729499817,0.8995278477668762,0.109375
0.9850850105285645,1.2857465744018555,-2.2220215797424316,0.38671875
0.9789451956748962,0.9153420925140381,0.12551555037498474,0.01171875
0.8774109482765198,0.9271970987319946,0.5529487729072571,0.019999999552965164
0.9074040651321411,0.920030951499939,0.40618932247161865,0.00390625
0.9878932237625122,0.5347745418548584,0.8865230679512024,0.046875
0.937852144241333,1.1346293687820435,-0.3324768841266632,0.019999999552965164
0.7542195916175842,0.44728168845176697,0.45312440395355225,0.019999999552965164
0.9915731549263,1.3838905096054077,-0.043990228325128555,0.01171875
0.9284758567810059,0.4973248541355133,0.9887621998786926,0.019999999552965164
0.9700435400009155,0.8664135336875916,1.0059133768081665,0.046875
0.9667003750801086,0.7796391844749451,-0.10554620623588562,0.00390625
0.9698932766914368,0.7340040802955627,0.4837290942668915,0.00390625
0.973517894744873,0.9678344130516052,0.36683231592178345,0.00390625
0.9770389795303345,0.8958415389060974,1.2423408031463623,0.015625
0.9902989864349365,0.7568255066871643,0.9843511581420898,0.019999999552965164
0.9908176064491272,0.8731094002723694,0.6906698346138,0.00390625
0.9901729226112366,0.8561913371086121,0.8783953189849854,0.5859375
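
The rows above are per-class sigmoid parameters for score calibration, one line per output class. A reading sketch, assuming the TFLite Task Library layout (scale, slope, offset, optional min_score per line; an empty line would mean no calibration for that class):

```python
# Parse score_calibration.txt rows. The column layout is an assumption
# based on the TFLite Task Library score calibration convention.
from typing import List, NamedTuple, Optional

class Sigmoid(NamedTuple):
    scale: float
    slope: float
    offset: float
    min_score: Optional[float]

def parse_calibration(path: str) -> List[Optional[Sigmoid]]:
    rows: List[Optional[Sigmoid]] = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                rows.append(None)  # class without calibration parameters
                continue
            vals = [float(v) for v in line.split(",")]
            rows.append(Sigmoid(vals[0], vals[1], vals[2],
                                vals[3] if len(vals) > 3 else None))
    return rows
```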

View File

@ -0,0 +1,9 @@
{
"associated_files": [
{
"name": "score_calibration.txt",
"description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.",
"type": "TENSOR_AXIS_SCORE_CALIBRATION"
}
]
}

View File

@ -0,0 +1,15 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "ScoreCalibrationOptions",
"options": {
"score_transformation": "LOG",
"default_score": 0.2
}
}
]
}
]
}
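
A sketch of applying one row of parameters under the `LOG` transformation configured above. The formula and the `min_score` semantics are assumptions based on the TFLite Task Library convention: `calibrated = scale / (1 + exp(-(slope * log(score) + offset)))`, with `default_score` returned when a class has no parameters or its uncalibrated score falls below the optional threshold:

```python
# Apply sigmoid score calibration with the "LOG" transformation.
# Formula and threshold handling are hedged assumptions, not code
# from this repository.
import math

def calibrate(score, scale, slope, offset, min_score=None, default_score=0.2):
    if score <= 0.0 or (min_score is not None and score < min_score):
        return default_score  # no usable calibration for this score
    return scale / (1.0 + math.exp(-(slope * math.log(score) + offset)))

# First row of score_calibration.txt above, applied to a raw score of 0.9:
print(calibrate(0.9, 0.9876328110694885, 0.36622241139411926,
                0.5352765321731567, min_score=0.71484375))  # ~0.61
```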

View File

@ -25,6 +25,7 @@ package(
mediapipe_files(srcs = [
"burger.jpg",
"burger_crop.jpg",
"burger_rotated.jpg",
"cat.jpg",
"cat_mask.jpg",
"cats_and_dogs.jpg",
@ -46,6 +47,7 @@ mediapipe_files(srcs = [
"mobilenet_v3_small_100_224_embedder.tflite",
"mozart_square.jpg",
"multi_objects.jpg",
"multi_objects_rotated.jpg",
"palm_detection_full.tflite",
"pointing_up.jpg",
"right_hands.jpg",
@ -72,6 +74,7 @@ filegroup(
srcs = [
"burger.jpg",
"burger_crop.jpg",
"burger_rotated.jpg",
"cat.jpg",
"cat_mask.jpg",
"cats_and_dogs.jpg",
@ -81,6 +84,7 @@ filegroup(
"left_hands.jpg",
"mozart_square.jpg",
"multi_objects.jpg",
"multi_objects_rotated.jpg",
"pointing_up.jpg",
"right_hands.jpg",
"segmentation_golden_rotation0.png",

View File

@ -22,12 +22,24 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/albert_with_metadata.tflite?generation=1661875651648830"],
)
http_file(
name = "com_google_mediapipe_associated_file_meta_json",
sha256 = "5b2cba11ae893e1226af6570813955889e9f171d6d2c67b3e96ecb6b96d8c681",
urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_tflite",
sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1663009542017720"],
)
http_file(
name = "com_google_mediapipe_bounding_box_tensor_meta_json",
sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a",
urls = ["https://storage.googleapis.com/mediapipe-assets/bounding_box_tensor_meta.json?generation=1665422797529909"],
)
http_file(
name = "com_google_mediapipe_BUILD",
sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
@ -46,6 +58,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/burger.jpg?generation=1661875667922678"],
)
http_file(
name = "com_google_mediapipe_burger_rotated_jpg",
sha256 = "b7bb5e59ef778f3ce6b3e616c511908a53d513b83a56aae58b7453e14b0a4b2a",
urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"],
)
http_file(
name = "com_google_mediapipe_cat_jpg",
sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be",
@ -70,6 +88,24 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_no_resizing.jpg?generation=1661875687251296"],
)
http_file(
name = "com_google_mediapipe_classification_tensor_float_meta_json",
sha256 = "1d10b1c9c87eabac330651136804074ddc134779e94a73cf783207c3aa2a5619",
urls = ["https://storage.googleapis.com/mediapipe-assets/classification_tensor_float_meta.json?generation=1665422803073223"],
)
http_file(
name = "com_google_mediapipe_classification_tensor_uint8_meta_json",
sha256 = "74f4d64ee0017d11e0fdc975a88d974d73b72b889fd4d67992356052edde0f1e",
urls = ["https://storage.googleapis.com/mediapipe-assets/classification_tensor_uint8_meta.json?generation=1665422808178685"],
)
http_file(
name = "com_google_mediapipe_classification_tensor_unsupported_meta_json",
sha256 = "4810ad8a00f0078c6a693114d00f692aa70ff2d61030a6e516db1e654707e208",
urls = ["https://storage.googleapis.com/mediapipe-assets/classification_tensor_unsupported_meta.json?generation=1665422813312699"],
)
http_file(
name = "com_google_mediapipe_coco_efficientdet_lite0_v1_1_0_quant_2021_09_06_tflite",
sha256 = "dee1b4af055a644804d5594442300ecc9e4f7080c25b7c044c98f527eeabb6cf",
@ -166,6 +202,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"],
)
http_file(
name = "com_google_mediapipe_feature_tensor_meta_json",
sha256 = "b2c30ddfd495956ce81085f8a143422f4310b002cfbf1c594ff2ee0576e29d6f",
urls = ["https://storage.googleapis.com/mediapipe-assets/feature_tensor_meta.json?generation=1665422818797346"],
)
http_file(
name = "com_google_mediapipe_general_meta_json",
sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f",
urls = ["https://storage.googleapis.com/mediapipe-assets/general_meta.json?generation=1665422822603848"],
)
http_file(
name = "com_google_mediapipe_golden_json_json",
sha256 = "55c0c88748d099aa379930504df62c6c8f1d8874ea52d2f8a925f352c4c7f09c",
@ -208,6 +256,30 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_recrop.tflite?generation=1661875770633070"],
)
http_file(
name = "com_google_mediapipe_image_tensor_meta_json",
sha256 = "aad86fde3defb379c82ff7ee48e50493a58529cdc0623cf0d7bf135c3577060e",
urls = ["https://storage.googleapis.com/mediapipe-assets/image_tensor_meta.json?generation=1665422826106636"],
)
http_file(
name = "com_google_mediapipe_input_image_tensor_float_meta_json",
sha256 = "426ecf5c3ace61db3936b950c3709daece15827ea21905ddbcdc81b1c6e70232",
urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_float_meta.json?generation=1665422829230563"],
)
http_file(
name = "com_google_mediapipe_input_image_tensor_uint8_meta_json",
sha256 = "dc7ff86b606641e480c7d154b5f467e1f8c895f85733c73ba47a259a66ed187b",
urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_uint8_meta.json?generation=1665422832572887"],
)
http_file(
name = "com_google_mediapipe_input_image_tensor_unsupported_meta_json",
sha256 = "443d436c2068df8201b9822c35e724acfd8004a788d388e7d74c38a2425c55df",
urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_unsupported_meta.json?generation=1665422835757143"],
)
http_file(
name = "com_google_mediapipe_iris_and_gaze_tflite",
sha256 = "b6dcb860a92a3c7264a8e50786f46cecb529672cdafc17d39c78931257da661d",
@ -370,6 +442,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects.jpg?generation=1663251779213308"],
)
http_file(
name = "com_google_mediapipe_multi_objects_rotated_jpg",
sha256 = "175f6c572ffbab6554e382fd5056d09720eef931ccc4ed79481bdc47a8443911",
urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects_rotated.jpg?generation=1665065847969523"],
)
http_file(
name = "com_google_mediapipe_object_detection_3d_camera_tflite",
sha256 = "f66e92e81ed3f4698f74d565a7668e016e2288ea92fb42938e33b778bd1e110d",
@ -472,6 +550,24 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands.jpg?generation=1661875908672404"],
)
http_file(
name = "com_google_mediapipe_score_calibration_file_meta_json",
sha256 = "6a3c305620371f662419a496f75be5a10caebca7803b1e99d8d5d22ba51cda94",
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration_file_meta.json?generation=1665422841236117"],
)
http_file(
name = "com_google_mediapipe_score_calibration_tensor_meta_json",
sha256 = "24cbde7f76dd6a09a55d07f30493c2f254d61154eb2e8d18ed947ff56781186d",
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration_tensor_meta.json?generation=1665422844327992"],
)
http_file(
name = "com_google_mediapipe_score_calibration_txt",
sha256 = "34b0c51a8c79b4515bdd24e440c4b76a9f0fd01ef6385b36af983036e7be6271",
urls = ["https://storage.googleapis.com/mediapipe-assets/score_calibration.txt?generation=1665422847392804"],
)
http_file(
name = "com_google_mediapipe_segmentation_golden_rotation0_png",
sha256 = "9ee993919b753118928ba2d14f7c5c83a6cfc23355e6943dac4ad81eedd73069",