Merge branch 'master' into gesture-recognizer-python

This commit is contained in:
Kinar R 2022-10-31 16:47:43 +05:30 committed by GitHub
commit 5ec87c8bd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1370 additions and 121 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
from typing import List, Tuple
import tensorflow as tf
@ -21,19 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models."""
"""Dataset Loader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str]):
super().__init__(dataset, size)
self._index_by_label = index_by_label
self._label_names = label_names
@property
def num_classes(self: ds._DatasetT) -> int:
return len(self._index_by_label)
return len(self._label_names)
@property
def index_by_label(self: ds._DatasetT) -> Any:
return self._index_by_label
def label_names(self: ds._DatasetT) -> List[str]:
return self._label_names
def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@ -48,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
Returns:
The splitted two sub datasets.
"""
return self._split(fraction, self._index_by_label)
return self._split(fraction, self._label_names)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple, TypeVar
from typing import Any, List, Tuple, TypeVar
# Dependency imports
@ -37,26 +37,22 @@ class ClassificationDatasetTest(tf.test.TestCase):
"""
def __init__(self, dataset: tf.data.Dataset, size: int,
index_by_label: Any, value: Any):
super().__init__(
dataset=dataset, size=size, index_by_label=index_by_label)
label_names: List[str], value: Any):
super().__init__(dataset=dataset, size=size, label_names=label_names)
self.value = value
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.index_by_label, self.value)
return self._split(fraction, self.label_names, self.value)
# Some dummy inputs.
magic_value = 42
num_classes = 2
index_by_label = (False, True)
label_names = ['foo', 'bar']
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataset(
dataset=ds,
size=len(ds),
index_by_label=index_by_label,
value=magic_value)
dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
# Train/Test data split.
fraction = .25
@ -73,7 +69,7 @@ class ClassificationDatasetTest(tf.test.TestCase):
# Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_by_label, index_by_label)
self.assertEqual(test_data.label_names, label_names)
self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value)

View File

@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool,
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
full_train: bool):
"""Initilizes a classifier with its specifications.
Args:
model_spec: Specification for the model.
index_by_label: A list that map from index to label class name.
label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top
classification layers.
"""
super(Classifier, self).__init__(model_spec, shuffle)
self._index_by_label = index_by_label
self._label_names = label_names
self._full_train = full_train
self._num_classes = len(index_by_label)
self._num_classes = len(label_names)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset.
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
label_filepath = os.path.join(export_dir, label_filename)
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self._index_by_label))
f.write('\n'.join(self._label_names))

View File

@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
def setUp(self):
super(ClassifierTest, self).setUp()
index_by_label = ['cat', 'dog']
label_names = ['cat', 'dog']
self.model = MockClassifier(
model_spec=None,
index_by_label=index_by_label,
label_names=label_names,
shuffle=False,
full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)

View File

@ -106,4 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names))
return Dataset(
dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
dataset=image_label_ds, size=all_image_size, label_names=label_names)

View File

@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg'])
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2)
for i, elem in enumerate(train_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.index_by_label, ['pos', 'neg'])
self.assertEqual(train_data.label_names, ['pos', 'neg'])
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.index_by_label, ['pos', 'neg'])
self.assertEqual(test_data.label_names, ['pos', 'neg'])
def test_from_folder(self):
data = dataset.Dataset.from_folder(dirname=self.image_path)
self.assertLen(data, 2)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.index_by_label, ['daisy', 'tulips'])
self.assertEqual(data.label_names, ['daisy', 'tulips'])
for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0:
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(train_data, 1034)
self.assertEqual(train_data.num_classes, 3)
self.assertEqual(train_data.index_by_label,
self.assertEqual(train_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(validation_data, 133)
self.assertEqual(validation_data.num_classes, 3)
self.assertEqual(validation_data.index_by_label,
self.assertEqual(validation_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(test_data, 128)
self.assertEqual(test_data.num_classes, 3)
self.assertEqual(test_data.index_by_label,
self.assertEqual(test_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])

View File

@ -13,7 +13,7 @@
# limitations under the License.
"""APIs to train image classifier model."""
from typing import Any, List, Optional
from typing import List, Optional
import tensorflow as tf
import tensorflow_hub as hub
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any],
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
hparams: hp.HParams):
"""Initializes ImageClassifier class.
Args:
model_spec: Specification for the model.
index_by_label: A list that maps from index to label class name.
label_names: A list of label names for the classes.
hparams: The hyperparameters for training image classifier.
"""
super().__init__(
model_spec=model_spec,
index_by_label=index_by_label,
label_names=label_names,
shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning)
self._hparams = hparams
@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
spec = ms.SupportedModels.get(model_spec)
image_classifier = cls(
model_spec=spec,
index_by_label=train_data.index_by_label,
hparams=hparams)
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
image_classifier._create_model()

View File

@ -88,6 +88,7 @@ cc_library(
name = "builtin_task_graphs",
deps = [
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
],

View File

@ -62,14 +62,18 @@ cc_library(
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:combined_prediction_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:combined_prediction_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
@ -77,8 +81,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",

View File

@ -153,13 +153,12 @@ class CombinedPredictionCalculator : public Node {
// After loop, if have winning prediction return. Otherwise empty packet.
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
auto collection = kClassificationListIn(cc);
for (int idx = 0; idx < collection.Count(); ++idx) {
const auto& packet = collection[idx];
if (packet.IsEmpty()) {
for (const auto& input : collection) {
if (input.IsEmpty() || input.Get().classification_size() == 0) {
continue;
}
auto prediction = GetWinningPrediction(
packet.Get(), classwise_thresholds_, options_.background_label(),
input.Get(), classwise_thresholds_, options_.background_label(),
options_.default_global_threshold());
if (prediction->classification(0).label() !=
options_.background_label()) {

View File

@ -146,6 +146,10 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
->mutable_canned_gesture_classifier_graph_options()
->mutable_classifier_options()
->set_score_threshold(options->min_gesture_confidence);
hand_gesture_recognizer_graph_options
->mutable_custom_gesture_classifier_graph_options()
->mutable_classifier_options()
->set_score_threshold(options->min_gesture_confidence);
}
return options_proto;
}

View File

@ -30,14 +30,17 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h"
@ -58,6 +61,7 @@ using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::processors::
ConfigureTensorsToClassificationCalculator;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::core::proto::BaseOptions;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions;
@ -78,13 +82,20 @@ constexpr char kVectorTag[] = "VECTOR";
constexpr char kIndexTag[] = "INDEX";
constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kPredictionTag[] = "PREDICTION";
constexpr char kBackgroundLabel[] = "None";
constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite";
constexpr char kCannedGestureClassifierTFLiteName[] =
"canned_gesture_classifier.tflite";
constexpr char kCustomGestureClassifierTFLiteName[] =
"custom_gesture_classifier.tflite";
struct SubTaskModelResources {
const core::ModelResources* gesture_embedder_model_resource;
const core::ModelResources* canned_gesture_classifier_model_resource;
const core::ModelResources* gesture_embedder_model_resource = nullptr;
const core::ModelResources* canned_gesture_classifier_model_resource =
nullptr;
const core::ModelResources* custom_gesture_classifier_model_resource =
nullptr;
};
Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
@ -94,41 +105,21 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
return node[Output<std::vector<Tensor>>{"TENSORS"}];
}
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
HandGestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto gesture_embedder_file,
resources.GetModelFile(kGestureEmbedderTFLiteName));
auto* gesture_embedder_graph_options =
options->mutable_gesture_embedder_graph_options();
SetExternalFile(gesture_embedder_file,
gesture_embedder_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
gesture_embedder_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
auto* canned_gesture_classifier_graph_options =
options->mutable_canned_gesture_classifier_graph_options();
SetExternalFile(
canned_gesture_classifier_file,
canned_gesture_classifier_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
canned_gesture_classifier_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
canned_gesture_classifier_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
absl::Status ConfigureCombinedPredictionCalculator(
CombinedPredictionCalculatorOptions* options) {
options->set_background_label(kBackgroundLabel);
return absl::OkStatus();
}
void PopulateAccelerationAndUseStreamMode(
const BaseOptions& parent_base_options,
BaseOptions* sub_task_base_options) {
sub_task_base_options->mutable_acceleration()->CopyFrom(
parent_base_options.acceleration());
sub_task_base_options->set_use_stream_mode(
parent_base_options.use_stream_mode());
}
} // namespace
// A
@ -212,6 +203,56 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
}
private:
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
HandGestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto gesture_embedder_file,
resources.GetModelFile(kGestureEmbedderTFLiteName));
auto* gesture_embedder_graph_options =
options->mutable_gesture_embedder_graph_options();
SetExternalFile(gesture_embedder_file,
gesture_embedder_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
PopulateAccelerationAndUseStreamMode(
options->base_options(),
gesture_embedder_graph_options->mutable_base_options());
ASSIGN_OR_RETURN(
const auto canned_gesture_classifier_file,
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
auto* canned_gesture_classifier_graph_options =
options->mutable_canned_gesture_classifier_graph_options();
SetExternalFile(
canned_gesture_classifier_file,
canned_gesture_classifier_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
PopulateAccelerationAndUseStreamMode(
options->base_options(),
canned_gesture_classifier_graph_options->mutable_base_options());
const auto custom_gesture_classifier_file =
resources.GetModelFile(kCustomGestureClassifierTFLiteName);
if (custom_gesture_classifier_file.ok()) {
has_custom_gesture_classifier = true;
auto* custom_gesture_classifier_graph_options =
options->mutable_custom_gesture_classifier_graph_options();
SetExternalFile(
custom_gesture_classifier_file.value(),
custom_gesture_classifier_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
PopulateAccelerationAndUseStreamMode(
options->base_options(),
custom_gesture_classifier_graph_options->mutable_base_options());
} else {
LOG(INFO) << "Custom gesture classifier is not defined.";
}
return absl::OkStatus();
}
absl::StatusOr<SubTaskModelResources> CreateSubTaskModelResources(
SubgraphContext* sc) {
auto* options = sc->MutableOptions<HandGestureRecognizerGraphOptions>();
@ -237,6 +278,19 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
std::make_unique<core::proto::ExternalFile>(
std::move(canned_gesture_classifier_model_asset)),
"_canned_gesture_classifier"));
if (has_custom_gesture_classifier) {
auto& custom_gesture_classifier_model_asset =
*options->mutable_custom_gesture_classifier_graph_options()
->mutable_base_options()
->mutable_model_asset();
ASSIGN_OR_RETURN(
sub_task_model_resources.custom_gesture_classifier_model_resource,
CreateModelResources(
sc,
std::make_unique<core::proto::ExternalFile>(
std::move(custom_gesture_classifier_model_asset)),
"_custom_gesture_classifier"));
}
return sub_task_model_resources;
}
@ -302,7 +356,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
hand_world_landmarks_tensor >> concatenate_tensor_vector.In(2);
auto concatenated_tensors = concatenate_tensor_vector.Out("");
// Inference for static hand gesture recognition.
// Inference for gesture embedder.
auto& gesture_embedder_inference =
AddInference(*sub_task_model_resources.gesture_embedder_model_resource,
graph_options.gesture_embedder_graph_options()
@ -310,34 +364,64 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
.acceleration(),
graph);
concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag);
auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag);
auto embedding_tensors =
gesture_embedder_inference.Out(kTensorsTag).Cast<Tensor>();
auto& canned_gesture_classifier_inference = AddInference(
*sub_task_model_resources.canned_gesture_classifier_model_resource,
graph_options.canned_gesture_classifier_graph_options()
.base_options()
.acceleration(),
graph);
embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag);
auto inference_output_tensors =
canned_gesture_classifier_inference.Out(kTensorsTag);
auto& combine_predictions = graph.AddNode("CombinedPredictionCalculator");
MP_RETURN_IF_ERROR(ConfigureCombinedPredictionCalculator(
&combine_predictions
.GetOptions<CombinedPredictionCalculatorOptions>()));
int classifier_nums = 0;
// Inference for custom gesture classifier if it exists.
if (has_custom_gesture_classifier) {
ASSIGN_OR_RETURN(
auto gesture_clasification_list,
GetGestureClassificationList(
sub_task_model_resources.custom_gesture_classifier_model_resource,
graph_options.custom_gesture_classifier_graph_options(),
embedding_tensors, graph));
gesture_clasification_list >> combine_predictions.In(classifier_nums++);
}
// Inference for canned gesture classifier.
ASSIGN_OR_RETURN(
auto gesture_clasification_list,
GetGestureClassificationList(
sub_task_model_resources.canned_gesture_classifier_model_resource,
graph_options.canned_gesture_classifier_graph_options(),
embedding_tensors, graph));
gesture_clasification_list >> combine_predictions.In(classifier_nums++);
auto combined_classification_list =
combine_predictions.Out(kPredictionTag).Cast<ClassificationList>();
return combined_classification_list;
}
absl::StatusOr<Source<ClassificationList>> GetGestureClassificationList(
const core::ModelResources* model_resources,
const proto::GestureClassifierGraphOptions& options,
Source<Tensor>& embedding_tensors, Graph& graph) {
auto& custom_gesture_classifier_inference = AddInference(
*model_resources, options.base_options().acceleration(), graph);
embedding_tensors >> custom_gesture_classifier_inference.In(kTensorsTag);
auto custom_gesture_inference_out_tensors =
custom_gesture_classifier_inference.Out(kTensorsTag);
auto& tensors_to_classification =
graph.AddNode("TensorsToClassificationCalculator");
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
graph_options.canned_gesture_classifier_graph_options()
.classifier_options(),
*sub_task_model_resources.canned_gesture_classifier_model_resource
->GetMetadataExtractor(),
options.classifier_options(), *model_resources->GetMetadataExtractor(),
0,
&tensors_to_classification.GetOptions<
mediapipe::TensorsToClassificationCalculatorOptions>()));
inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
auto classification_list =
tensors_to_classification[Output<ClassificationList>(
"CLASSIFICATIONS")];
return classification_list;
custom_gesture_inference_out_tensors >>
tensors_to_classification.In(kTensorsTag);
return tensors_to_classification.Out("CLASSIFICATIONS")
.Cast<ClassificationList>();
}
bool has_custom_gesture_classifier = false;
};
// clang-format off

View File

@ -31,6 +31,8 @@ import java.util.List;
@AutoValue
public abstract class GestureRecognitionResult implements TaskResult {
private static final int kGestureDefaultIndex = -1;
/**
* Creates a {@link GestureRecognitionResult} instance from the lists of landmarks, handedness,
* and gestures protobuf messages.
@ -97,7 +99,9 @@ public abstract class GestureRecognitionResult implements TaskResult {
gestures.add(
Category.create(
classification.getScore(),
classification.getIndex(),
// Gesture index is not used, because the final gesture result comes from multiple
// classifiers.
kGestureDefaultIndex,
classification.getLabel(),
classification.getDisplayName()));
}
@ -123,6 +127,10 @@ public abstract class GestureRecognitionResult implements TaskResult {
/** Handedness of detected hands. */
public abstract List<List<Category>> handednesses();
/** Recognized hand gestures of detected hands */
/**
* Recognized hand gestures of detected hands. Note that the index of the gesture is always -1,
* because the raw indices from multiple gesture classifiers cannot consolidate to a meaningful
* index.
*/
public abstract List<List<Category>> gestures();
}

View File

@ -46,19 +46,24 @@ import org.junit.runners.Suite.SuiteClasses;
@SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class})
public class GestureRecognizerTest {
private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task";
private static final String GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE =
"gesture_recognizer_with_custom_classifier.task";
private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
private static final String FIST_IMAGE = "fist.jpg";
private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
private static final String FIST_LANDMARKS = "fist_landmarks.pb";
private static final String TAG = "Gesture Recognizer Test";
private static final String THUMB_UP_LABEL = "Thumb_Up";
private static final int THUMB_UP_INDEX = 5;
private static final String POINTING_UP_LABEL = "Pointing_Up";
private static final int POINTING_UP_INDEX = 3;
private static final String FIST_LABEL = "Closed_Fist";
private static final String ROCK_LABEL = "Rock";
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
private static final int IMAGE_WIDTH = 382;
private static final int IMAGE_HEIGHT = 406;
private static final int GESTURE_EXPECTED_INDEX = -1;
@RunWith(AndroidJUnit4.class)
public static final class General extends GestureRecognizerTest {
@ -77,7 +82,7 @@ public class GestureRecognizerTest {
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@ -108,16 +113,14 @@ public class GestureRecognizerTest {
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
// TODO update the confidence to be in range [0,1] after embedding model
// and scoring calculator is integrated.
.setMinGestureConfidence(2.0f)
.setMinGestureConfidence(0.5f)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
// Only contains one top scoring gesture.
assertThat(actualResult.gestures().get(0)).hasSize(1);
assertActualGestureEqualExpectedGesture(
@ -159,10 +162,48 @@ public class GestureRecognizerTest {
gestureRecognizer.recognize(
getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
assertThat(actualResult.gestures()).hasSize(1);
assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX);
assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL);
}
@Test
public void recognize_successWithCannedGestureFist() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(FIST_LANDMARKS, FIST_LABEL);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void recognize_successWithCustomGestureRock() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(
GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(FIST_LANDMARKS, ROCK_LABEL);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void recognize_failsWithRegionOfInterest() throws Exception {
GestureRecognizerOptions options =
@ -331,7 +372,7 @@ public class GestureRecognizerTest {
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@ -348,7 +389,7 @@ public class GestureRecognizerTest {
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
for (int i = 0; i < 3; i++) {
GestureRecognitionResult actualResult =
gestureRecognizer.recognizeForVideo(
@ -361,7 +402,7 @@ public class GestureRecognizerTest {
public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception {
MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
@ -393,7 +434,7 @@ public class GestureRecognizerTest {
public void recognize_successWithLiveSteamMode() throws Exception {
MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
@ -423,7 +464,7 @@ public class GestureRecognizerTest {
}
private static GestureRecognitionResult getExpectedGestureRecognitionResult(
String filePath, String gestureLabel, int gestureIndex) throws Exception {
String filePath, String gestureLabel) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
LandmarksDetectionResult landmarksDetectionResultProto =
@ -431,9 +472,7 @@ public class GestureRecognizerTest {
ClassificationProto.ClassificationList gesturesProto =
ClassificationProto.ClassificationList.newBuilder()
.addClassification(
ClassificationProto.Classification.newBuilder()
.setLabel(gestureLabel)
.setIndex(gestureIndex))
ClassificationProto.Classification.newBuilder().setLabel(gestureLabel))
.build();
return GestureRecognitionResult.create(
Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
@ -479,8 +518,8 @@ public class GestureRecognizerTest {
private static void assertActualGestureEqualExpectedGesture(
Category actualGesture, Category expectedGesture) {
assertThat(actualGesture.index()).isEqualTo(actualGesture.index());
assertThat(expectedGesture.categoryName()).isEqualTo(expectedGesture.categoryName());
assertThat(actualGesture.categoryName()).isEqualTo(expectedGesture.categoryName());
assertThat(actualGesture.index()).isEqualTo(GESTURE_EXPECTED_INDEX);
}
private static void assertImageSizeIsExpected(MPImage inputImage) {

View File

@ -57,6 +57,22 @@ py_test(
],
)
py_test(
name = "image_segmenter_test",
srcs = ["image_segmenter_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_segmenter",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_test(
name = "gesture_recognizer_test",
srcs = ["gesture_recognizer_test.py"],

View File

@ -0,0 +1,353 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for image segmenter."""
import enum
from typing import List
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import cv2
import numpy as np
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import vision_task_running_mode
_BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image
_ImageFormat = image_frame.ImageFormat
_OutputType = image_segmenter.OutputType
_Activation = image_segmenter.Activation
_ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98
def _similar_to_uint8_mask(actual_mask, expected_mask):
actual_mask_pixels = actual_mask.numpy_view().flatten()
expected_mask_pixels = expected_mask.numpy_view().flatten()
consistent_pixels = 0
num_pixels = len(expected_mask_pixels)
for index in range(num_pixels):
consistent_pixels += (
actual_mask_pixels[index] *
_MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index])
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class ImageSegmenterTest(parameterized.TestCase):
def setUp(self):
super().setUp()
# Load the test input image.
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_IMAGE_FILE))
# Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(_SEGMENTATION_FILE), cv2.IMREAD_GRAYSCALE)
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter:
self.assertIsInstance(segmenter, _ImageSegmenter)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageSegmenterOptions(base_options=base_options)
with _ImageSegmenter.create_from_options(options) as segmenter:
self.assertIsInstance(segmenter, _ImageSegmenter)
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
ValueError,
r"ExternalFile must specify at least one of 'file_content', "
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
base_options = _BaseOptions(model_asset_path='')
options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _ImageSegmenterOptions(base_options=base_options)
segmenter = _ImageSegmenter.create_from_options(options)
self.assertIsInstance(segmenter, _ImageSegmenter)
@parameterized.parameters((ModelFileType.FILE_NAME,),
(ModelFileType.FILE_CONTENT,))
def test_segment_succeeds_with_category_mask(self, model_file_type):
# Creates segmenter.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image)
self.assertLen(category_masks, 1)
category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct.
self.assertEqual(result_pixels.dtype, np.uint8)
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
def test_segment_succeeds_with_confidence_mask(self):
# Creates segmenter.
base_options = _BaseOptions(model_asset_path=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode.
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
segmenter = _ImageSegmenter.create_from_options(options)
category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view()
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _ImageSegmenterOptions(
base_options=base_options,
output_type=_OutputType.CONFIDENCE_MASK,
activation=_Activation.SOFTMAX)
segmenter = _ImageSegmenter.create_from_options(options)
confidence_masks = segmenter.segment(self.test_image)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
# Gather the confidence masks in a single array `confidence_mask_array`.
confidence_mask_array = np.array(
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
# Check if data type of `confidence_masks` are correct.
self.assertEqual(confidence_mask_array.dtype, np.float32)
# Compute the category mask from the created confidence mask.
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
self.assertListEqual(
calculated_category_mask.tolist(), category_mask.tolist(),
'Confidence mask does not match with the category mask.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
@parameterized.parameters((ModelFileType.FILE_NAME),
(ModelFileType.FILE_CONTENT))
def test_segment_in_context(self, model_file_type):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_contents = f.read()
base_options = _BaseOptions(model_asset_buffer=model_contents)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image)
self.assertLen(category_masks, 1)
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
def test_missing_result_callback(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass
def test_calling_segment_for_video_in_image_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
segmenter.segment_for_video(self.test_image, 0)
def test_calling_segment_async_in_image_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
segmenter.segment_async(self.test_image, 0)
def test_calling_segment_in_video_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
segmenter.segment(self.test_image)
def test_calling_segment_async_in_video_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
segmenter.segment_async(self.test_image, 0)
def test_segment_for_video_with_out_of_order_timestamp(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
unused_result = segmenter.segment_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
segmenter.segment_for_video(self.test_image, 0)
def test_segment_for_video(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK,
running_mode=_RUNNING_MODE.VIDEO)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
category_masks = segmenter.segment_for_video(self.test_image, timestamp)
self.assertLen(category_masks, 1)
self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
def test_calling_segment_in_live_stream_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
segmenter.segment(self.test_image)
def test_calling_segment_for_video_in_live_stream_mode(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
segmenter.segment_for_video(self.test_image, 0)
def test_segment_async_calls_with_illegal_timestamp(self):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageSegmenter.create_from_options(options) as segmenter:
segmenter.segment_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
segmenter.segment_async(self.test_image, 0)
def test_segment_async_calls(self):
observed_timestamp_ms = -1
def check_result(result: List[image_module.Image], output_image: _Image,
timestamp_ms: int):
# Get the output category mask.
category_mask = result[0]
self.assertEqual(output_image.width, self.test_image.width)
self.assertEqual(output_image.height, self.test_image.height)
self.assertEqual(output_image.width, self.test_seg_image.width)
self.assertEqual(output_image.height, self.test_seg_image.height)
self.assertTrue(
_similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the '
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK,
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(self.test_image, timestamp)
if __name__ == '__main__':
absltest.main()

View File

@ -59,6 +59,25 @@ py_library(
],
)
py_library(
name = "image_segmenter",
srcs = [
"image_segmenter.py",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_library(
name = "gesture_recognizer",
srcs = [

View File

@ -0,0 +1,253 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe image segmenter task."""
import dataclasses
import enum
from typing import Callable, List, Mapping, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.python._framework_bindings import task_runner
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import vision_task_running_mode
_BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
_TaskRunner = task_runner.TaskRunner
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
class Activation(enum.Enum):
NONE = 0
SIGMOID = 1
SOFTMAX = 2
@dataclasses.dataclass
class ImageSegmenterOptions:
"""Options for the image segmenter task.
Attributes:
base_options: Base options for the image segmenter task.
running_mode: The running mode of the task. Default to the image mode. Image
segmenter task has three running modes: 1) The image mode for segmenting
objects on single image inputs. 2) The video mode for segmenting objects
on the decoded frames of a video. 3) The live stream mode for segmenting
objects on a live stream of input data, such as from camera.
output_type: The output mask type allows specifying the type of
post-processing to perform on the raw model results.
activation: Activation function to apply to input tensor.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
activation: Optional[Activation] = Activation.NONE
result_callback: Optional[Callable[
[List[image_module.Image], image_module.Image, int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterOptionsProto:
"""Generates an ImageSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value, activation=self.activation.value)
return _ImageSegmenterOptionsProto(
base_options=base_options_proto,
segmenter_options=segmenter_options_proto)
class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
"""Class that performs image segmentation on images."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
"""Creates an `ImageSegmenter` object from a TensorFlow Lite model and the default `ImageSegmenterOptions`.
Note that the created `ImageSegmenter` instance is in image mode, for
performing image segmentation on single image inputs.
Args:
model_path: Path to the model.
Returns:
`ImageSegmenter` object that's created from the model file and the default
`ImageSegmenterOptions`.
Raises:
ValueError: If failed to create `ImageSegmenter` object from the provided
file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageSegmenterOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: ImageSegmenterOptions) -> 'ImageSegmenter':
"""Creates the `ImageSegmenter` object from image segmenter options.
Args:
options: Options for the image segmenter task.
Returns:
`ImageSegmenter` object that's created from `options`.
Raises:
ValueError: If failed to create `ImageSegmenter` object from
`ImageSegmenterOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
def packets_callback(output_packets: Mapping[str, packet.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
options.result_callback(segmentation_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])],
output_streams=[
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
def segment(self, image: image_module.Image) -> List[image_module.Image]:
"""Performs the actual segmentation task on the provided MediaPipe Image.
Args:
image: MediaPipe Image.
Returns:
If the output_type is CATEGORY_MASK, the returned vector of images is
per-category segmented image mask.
If the output_type is CONFIDENCE_MASK, the returned vector of images
contains only one confidence image mask. A segmentation result object that
contains a list of segmentation masks as images.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
output_packets = self._process_image_data(
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
return segmentation_result
def segment_for_video(self, image: image_module.Image,
timestamp_ms: int) -> List[image_module.Image]:
"""Performs segmentation on the provided video frames.
Only use this method when the ImageSegmenter is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
Returns:
If the output_type is CATEGORY_MASK, the returned vector of images is
per-category segmented image mask.
If the output_type is CONFIDENCE_MASK, the returned vector of images
contains only one confidence image mask. A segmentation result object that
contains a list of segmentation masks as images.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
return segmentation_result
def segment_async(self, image: image_module.Image, timestamp_ms: int) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image segmentation.
Only use this method when the ImageSegmenter is created with the live stream
running mode. The input timestamps should be monotonically increasing for
adjacent calls of this method. This method will return immediately after the
input image is accepted. The results will be available via the
`result_callback` provided in the `ImageSegmenterOptions`. The
`segment_async` method is designed to process live stream data such as
camera input. To lower the overall latency, image segmenter may drop the
input images if needed. In other words, it's not guaranteed to have output
per input image.
The `result_callback` prvoides:
- A segmentation result object that contains a list of segmentation masks
as images.
- The input image that the image segmenter runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
Raises:
ValueError: If the current input timestamp is smaller than what the image
segmenter has already processed.
"""
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})

View File

@ -37,6 +37,7 @@ mediapipe_files(srcs = [
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
"deeplabv3.tflite",
"fist.jpg",
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"left_hands.jpg",
@ -64,6 +65,7 @@ mediapipe_files(srcs = [
"selfie_segm_144_256_3.tflite",
"selfie_segm_144_256_3_expected_mask.jpg",
"thumb_up.jpg",
"victory.jpg",
])
exports_files(
@ -91,6 +93,7 @@ filegroup(
"cats_and_dogs.jpg",
"cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg",
"fist.jpg",
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"left_hands.jpg",
@ -107,6 +110,7 @@ filegroup(
"selfie_segm_128_128_3_expected_mask.jpg",
"selfie_segm_144_256_3_expected_mask.jpg",
"thumb_up.jpg",
"victory.jpg",
],
visibility = [
"//mediapipe/python:__subpackages__",
@ -149,6 +153,7 @@ filegroup(
"expected_left_up_hand_rotated_landmarks.prototxt",
"expected_right_down_hand_landmarks.prototxt",
"expected_right_up_hand_landmarks.prototxt",
"fist_landmarks.pbtxt",
"hand_detector_result_one_hand.pbtxt",
"hand_detector_result_one_hand_rotated.pbtxt",
"hand_detector_result_two_hands.pbtxt",
@ -156,5 +161,6 @@ filegroup(
"pointing_up_rotated_landmarks.pbtxt",
"thumb_up_landmarks.pbtxt",
"thumb_up_rotated_landmarks.pbtxt",
"victory_landmarks.pbtxt",
],
)

View File

@ -0,0 +1,223 @@
classifications {
classification {
score: 1.0
label: "Left"
display_name: "Left"
}
}
landmarks {
landmark {
x: 0.47709703
y: 0.66129065
z: -3.3540672e-07
}
landmark {
x: 0.6125982
y: 0.5578249
z: -0.041392017
}
landmark {
x: 0.71123487
y: 0.4316616
z: -0.064544134
}
landmark {
x: 0.6836403
y: 0.3199585
z: -0.08752567
}
landmark {
x: 0.5593274
y: 0.3206453
z: -0.09880819
}
landmark {
x: 0.60828537
y: 0.3068749
z: -0.014799656
}
landmark {
x: 0.62940764
y: 0.21414441
z: -0.06007311
}
landmark {
x: 0.6244353
y: 0.32872596
z: -0.08326768
}
landmark {
x: 0.60784453
y: 0.3684796
z: -0.09658983
}
landmark {
x: 0.5156504
y: 0.32194698
z: -0.021699267
}
landmark {
x: 0.52931
y: 0.24767634
z: -0.062571
}
landmark {
x: 0.5484773
y: 0.3805329
z: -0.07028895
}
landmark {
x: 0.54428184
y: 0.3881125
z: -0.07458326
}
landmark {
x: 0.43159598
y: 0.34918433
z: -0.037482508
}
landmark {
x: 0.4486106
y: 0.27649382
z: -0.08174769
}
landmark {
x: 0.47723144
y: 0.3964985
z: -0.06496752
}
landmark {
x: 0.46794242
y: 0.4082967
z: -0.04897496
}
landmark {
x: 0.34826216
y: 0.37813392
z: -0.057438444
}
landmark {
x: 0.3861837
y: 0.32820183
z: -0.07282783
}
landmark {
x: 0.41143674
y: 0.39734486
z: -0.047633167
}
landmark {
x: 0.39401984
y: 0.41149133
z: -0.029640475
}
}
world_landmarks {
landmark {
x: -0.008604452
y: 0.08165767
z: 0.0061365655
}
landmark {
x: 0.027301773
y: 0.061905317
z: -0.00872007
}
landmark {
x: 0.049898714
y: 0.035359327
z: -0.016682662
}
landmark {
x: 0.050297678
y: 0.005200807
z: -0.028928496
}
landmark {
x: 0.015639625
y: -0.0063155442
z: -0.03174634
}
landmark {
x: 0.029161729
y: -0.0024596984
z: 0.0011553494
}
landmark {
x: 0.034491
y: -0.017581237
z: -0.020781275
}
landmark {
x: 0.034020264
y: -0.0059247985
z: -0.02573838
}
landmark {
x: 0.02867364
y: 0.011137734
z: -0.009430941
}
landmark {
x: 0.0015385814
y: -0.004778851
z: 0.0056454404
}
landmark {
x: 0.010490709
y: -0.019680617
z: -0.027034117
}
landmark {
x: 0.0132071925
y: 0.0071370844
z: -0.034802448
}
landmark {
x: 0.0139978565
y: 0.011672501
z: -0.0040006908
}
landmark {
x: -0.019919239
y: -0.0006897822
z: -0.0003317799
}
landmark {
x: -0.01088193
y: -0.008502296
z: -0.02873486
}
landmark {
x: -0.005327127
y: 0.012745364
z: -0.034153957
}
landmark {
x: -0.0027040644
y: 0.02167169
z: -0.011669062
}
landmark {
x: -0.038813893
y: 0.011925209
z: -0.0076287366
}
landmark {
x: -0.030842202
y: 0.0010964936
z: -0.022697516
}
landmark {
x: -0.01829514
y: 0.013929318
z: -0.032819964
}
landmark {
x: -0.024175374
y: 0.022456694
z: -0.02357186
}
}

View File

@ -0,0 +1,223 @@
classifications {
classification {
score: 1.0
label: "Left"
display_name: "Left"
}
}
landmarks {
landmark {
x: 0.5164316
y: 0.804093
z: 8.7653416e-07
}
landmark {
x: 0.6063608
y: 0.7111354
z: -0.044089418
}
landmark {
x: 0.6280186
y: 0.588498
z: -0.062358405
}
landmark {
x: 0.5265348
y: 0.52083343
z: -0.08526791
}
landmark {
x: 0.4243384
y: 0.4993468
z: -0.1077741
}
landmark {
x: 0.5605667
y: 0.4489705
z: -0.016151091
}
landmark {
x: 0.5766643
y: 0.32260323
z: -0.049342215
}
landmark {
x: 0.5795845
y: 0.24180722
z: -0.07323826
}
landmark {
x: 0.5827511
y: 0.16940045
z: -0.09069163
}
landmark {
x: 0.4696163
y: 0.4599558
z: -0.032168437
}
landmark {
x: 0.44361597
y: 0.31689578
z: -0.075698614
}
landmark {
x: 0.42695498
y: 0.22273324
z: -0.10819675
}
landmark {
x: 0.40697217
y: 0.14279765
z: -0.12666894
}
landmark {
x: 0.39543492
y: 0.50612336
z: -0.055138163
}
landmark {
x: 0.3618012
y: 0.4388296
z: -0.1298119
}
landmark {
x: 0.4154368
y: 0.52674913
z: -0.1463017
}
landmark {
x: 0.44916254
y: 0.59442246
z: -0.13470782
}
landmark {
x: 0.33178204
y: 0.5731769
z: -0.08103096
}
landmark {
x: 0.3092102
y: 0.5040002
z: -0.13258384
}
landmark {
x: 0.35576707
y: 0.5576498
z: -0.12714732
}
landmark {
x: 0.393444
y: 0.6118667
z: -0.11102459
}
}
world_landmarks {
landmark {
x: 0.01299962
y: 0.09162361
z: 0.011185312
}
landmark {
x: 0.03726317
y: 0.0638103
z: -0.010005756
}
landmark {
x: 0.03975261
y: 0.03712649
z: -0.02906275
}
landmark {
x: 0.018798776
y: 0.012429599
z: -0.048737116
}
landmark {
x: -0.0128555335
y: 0.001022811
z: -0.044505004
}
landmark {
x: 0.025658218
y: -0.008031519
z: -0.0058278795
}
landmark {
x: 0.028017294
y: -0.038120236
z: -0.010376478
}
landmark {
x: 0.030067094
y: -0.059907563
z: -0.014568218
}
landmark {
x: 0.027284538
y: -0.07803874
z: -0.032692235
}
landmark {
x: 0.0013260426
y: -0.005039873
z: 0.005567288
}
landmark {
x: -0.002380834
y: -0.044605374
z: -0.0038231965
}
landmark {
x: -0.009240147
y: -0.066279344
z: -0.02161214
}
landmark {
x: -0.0092535615
y: -0.08933755
z: -0.037401434
}
landmark {
x: -0.01751284
y: 0.0037118336
z: 0.0047480655
}
landmark {
x: -0.02195602
y: -0.010006189
z: -0.02371484
}
landmark {
x: -0.012851426
y: 0.008346066
z: -0.037721373
}
landmark {
x: -0.00018795021
y: 0.026816685
z: -0.03732748
}
landmark {
x: -0.034864448
y: 0.022316
z: -0.0002774651
}
landmark {
x: -0.035896845
y: 0.01066218
z: -0.017325373
}
landmark {
x: -0.02358637
y: 0.018667895
z: -0.028403495
}
landmark {
x: -0.013704676
y: 0.033456434
z: -0.02595728
}
}

View File

@ -244,6 +244,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/feature_tensor_meta.json?generation=1665422818797346"],
)
http_file(
name = "com_google_mediapipe_fist_jpg",
sha256 = "43fa1cabf3f90d574accc9a56986e2ee48638ce59fc65af1846487f73bb2ef24",
urls = ["https://storage.googleapis.com/mediapipe-assets/fist.jpg?generation=1666999359066679"],
)
http_file(
name = "com_google_mediapipe_fist_landmarks_pbtxt",
sha256 = "76d6489e6163211ce5e9080e51983165bb9b24ff50146cc7487bd629f011c598",
urls = ["https://storage.googleapis.com/mediapipe-assets/fist_landmarks.pbtxt?generation=1666999360561864"],
)
http_file(
name = "com_google_mediapipe_general_meta_json",
sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f",
@ -838,6 +850,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/universal_sentence_encoder_qa_with_metadata.tflite?generation=1665445919252005"],
)
http_file(
name = "com_google_mediapipe_victory_jpg",
sha256 = "84cb8853e3df614e0cb5c93a25e3e2f38ea5e4f92fd428ee7d867ed3479d5764",
urls = ["https://storage.googleapis.com/mediapipe-assets/victory.jpg?generation=1666999364225126"],
)
http_file(
name = "com_google_mediapipe_victory_landmarks_pbtxt",
sha256 = "b25ab4f222674489f543afb6454396ecbc1437a7ae6213dbf0553029ae939ab0",
urls = ["https://storage.googleapis.com/mediapipe-assets/victory_landmarks.pbtxt?generation=1666999366036622"],
)
http_file(
name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",