Merge branch 'google:master' into image-classification-python-impl
This commit is contained in:
commit
aac7ff946f
|
@ -143,9 +143,7 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "packet_frequency_calculator",
|
||||
srcs = ["packet_frequency_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/util:packet_frequency_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:packet_frequency_cc_proto",
|
||||
|
@ -190,9 +188,7 @@ cc_test(
|
|||
cc_library(
|
||||
name = "packet_latency_calculator",
|
||||
srcs = ["packet_latency_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/util:latency_cc_proto",
|
||||
"//mediapipe/calculators/util:packet_latency_calculator_cc_proto",
|
||||
|
|
|
@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
text->set_left(label_left_px_);
|
||||
text->set_baseline(label_baseline_px + i * label_height_px_);
|
||||
text->set_font_face(options_.font_face());
|
||||
if (options_.outline_thickness() > 0) {
|
||||
text->set_outline_thickness(options_.outline_thickness());
|
||||
if (options_.outline_color_size() > 0) {
|
||||
*(text->mutable_outline_color()) =
|
||||
options_.outline_color(i % options_.outline_color_size());
|
||||
} else {
|
||||
text->mutable_outline_color()->set_r(0);
|
||||
text->mutable_outline_color()->set_g(0);
|
||||
text->mutable_outline_color()->set_b(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag(kRenderDataTag)
|
||||
|
|
|
@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions {
|
|||
// Thickness for drawing the label(s).
|
||||
optional double thickness = 2 [default = 2];
|
||||
|
||||
// Color of outline around each character, if any. One per label, as with
|
||||
// color attribute.
|
||||
repeated Color outline_color = 12;
|
||||
|
||||
// Thickness of outline around each character.
|
||||
optional double outline_thickness = 11;
|
||||
|
||||
// The font height in absolute pixels.
|
||||
optional int32 font_height_px = 3 [default = 50];
|
||||
|
||||
|
|
|
@ -185,7 +185,10 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
|
|||
<< "Updated existing texture which had not been marked for reuse!";
|
||||
CHECK(prod_token);
|
||||
producer_sync_ = std::move(prod_token);
|
||||
producer_context_ = producer_sync_->GetContext();
|
||||
const auto& synced_context = producer_sync_->GetContext();
|
||||
if (synced_context) {
|
||||
producer_context_ = synced_context;
|
||||
}
|
||||
}
|
||||
|
||||
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
|
||||
|
|
|
@ -34,6 +34,7 @@ android_library(
|
|||
android_library(
|
||||
name = "android_framework_no_mff",
|
||||
proguard_specs = [":proguard.pgcfg"],
|
||||
visibility = ["//visibility:public"],
|
||||
exports = [
|
||||
":android_framework_no_proguard",
|
||||
],
|
||||
|
|
|
@ -48,6 +48,8 @@ pybind_extension(
|
|||
"//mediapipe/python/pybind:timestamp",
|
||||
"//mediapipe/python/pybind:validated_graph_config",
|
||||
"//mediapipe/tasks/python/core/pybind:task_runner",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@stblib//:stb_image",
|
||||
# Type registration.
|
||||
"//mediapipe/framework:basic_types_registration",
|
||||
"//mediapipe/framework/formats:classification_registration",
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""Tests for mediapipe.python._framework_bindings.image."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
@ -23,6 +24,7 @@ import cv2
|
|||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
# resources dependency
|
||||
from mediapipe.python._framework_bindings import image
|
||||
from mediapipe.python._framework_bindings import image_frame
|
||||
|
||||
|
@ -185,6 +187,5 @@ class ImageTest(absltest.TestCase):
|
|||
gc.collect()
|
||||
self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -45,6 +45,8 @@ pybind_library(
|
|||
":util",
|
||||
"//mediapipe/framework:type_map",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@stblib//:stb_image",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -16,9 +16,11 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/python/pybind/image_frame_util.h"
|
||||
#include "mediapipe/python/pybind/util.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "stb_image.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace python {
|
||||
|
@ -225,6 +227,62 @@ void ImageSubmodule(pybind11::module* module) {
|
|||
image.is_aligned(16)
|
||||
)doc");
|
||||
|
||||
image.def_static(
|
||||
"create_from_file",
|
||||
[](const std::string& file_name) {
|
||||
int width;
|
||||
int height;
|
||||
int channels;
|
||||
auto* image_data =
|
||||
stbi_load(file_name.c_str(), &width, &height, &channels,
|
||||
/*desired_channels=*/0);
|
||||
if (image_data == nullptr) {
|
||||
throw RaisePyError(PyExc_RuntimeError,
|
||||
absl::StrFormat("Image decoding failed (%s): %s",
|
||||
stbi_failure_reason(), file_name)
|
||||
.c_str());
|
||||
}
|
||||
ImageFrameSharedPtr image_frame;
|
||||
switch (channels) {
|
||||
case 1:
|
||||
image_frame = std::make_shared<ImageFrame>(
|
||||
ImageFormat::GRAY8, width, height, width, image_data,
|
||||
stbi_image_free);
|
||||
break;
|
||||
case 3:
|
||||
image_frame = std::make_shared<ImageFrame>(
|
||||
ImageFormat::SRGB, width, height, 3 * width, image_data,
|
||||
stbi_image_free);
|
||||
break;
|
||||
case 4:
|
||||
image_frame = std::make_shared<ImageFrame>(
|
||||
ImageFormat::SRGBA, width, height, 4 * width, image_data,
|
||||
stbi_image_free);
|
||||
break;
|
||||
default:
|
||||
throw RaisePyError(
|
||||
PyExc_RuntimeError,
|
||||
absl::StrFormat(
|
||||
"Expected image with 1 (grayscale), 3 (RGB) or 4 "
|
||||
"(RGBA) channels, found %d channels.",
|
||||
channels)
|
||||
.c_str());
|
||||
}
|
||||
return Image(std::move(image_frame));
|
||||
},
|
||||
R"doc(Creates `Image` object from the image file.
|
||||
|
||||
Args:
|
||||
file_name: Image file name.
|
||||
|
||||
Returns:
|
||||
`Image` object.
|
||||
|
||||
Raises:
|
||||
RuntimeError if the image file can't be decoded.
|
||||
)doc",
|
||||
py::arg("file_name"));
|
||||
|
||||
image.def_property_readonly("width", &Image::width)
|
||||
.def_property_readonly("height", &Image::height)
|
||||
.def_property_readonly("channels", &Image::channels)
|
||||
|
|
|
@ -33,11 +33,12 @@ cc_library(
|
|||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
|
@ -60,12 +61,13 @@ cc_library(
|
|||
":audio_classifier_graph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
||||
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
||||
"//mediapipe/tasks/cc/audio/core:running_mode",
|
||||
"//mediapipe/tasks/cc/components:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
|
|
|
@ -22,10 +22,11 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -33,8 +34,12 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
|
||||
constexpr char kAudioStreamName[] = "audio_in";
|
||||
constexpr char kAudioTag[] = "AUDIO";
|
||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
|
||||
|
@ -42,16 +47,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
|||
constexpr char kSampleRateName[] = "sample_rate_in";
|
||||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.audio.AudioClassifierGraph";
|
||||
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using AudioClassifierOptionsProto =
|
||||
audio_classifier::proto::AudioClassifierOptions;
|
||||
|
||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||
// "mediapipe.tasks.audio.AudioClassifierGraph".
|
||||
// type "AudioClassifierGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<AudioClassifierOptionsProto> options_proto) {
|
||||
std::unique_ptr<proto::AudioClassifierGraphOptions> options_proto) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
|
||||
|
@ -59,7 +61,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
|
||||
subgraph.In(kSampleRateTag);
|
||||
}
|
||||
subgraph.GetOptions<AudioClassifierOptionsProto>().Swap(options_proto.get());
|
||||
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
subgraph.Out(kClassificationResultTag)
|
||||
.SetName(kClassificationResultStreamName) >>
|
||||
graph.Out(kClassificationResultTag);
|
||||
|
@ -67,18 +70,18 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
}
|
||||
|
||||
// Converts the user-facing AudioClassifierOptions struct to the internal
|
||||
// AudioClassifierOptions proto.
|
||||
std::unique_ptr<AudioClassifierOptionsProto>
|
||||
// AudioClassifierGraphOptions proto.
|
||||
std::unique_ptr<proto::AudioClassifierGraphOptions>
|
||||
ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
||||
auto options_proto = std::make_unique<AudioClassifierOptionsProto>();
|
||||
auto options_proto = std::make_unique<proto::AudioClassifierGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
||||
auto classifier_options_proto =
|
||||
std::make_unique<tasks::components::proto::ClassifierOptions>(
|
||||
components::ConvertClassifierOptionsToProto(
|
||||
std::make_unique<components::processors::proto::ClassifierOptions>(
|
||||
components::processors::ConvertClassifierOptionsToProto(
|
||||
&(options->classifier_options)));
|
||||
options_proto->mutable_classifier_options()->Swap(
|
||||
classifier_options_proto.get());
|
||||
|
@ -119,7 +122,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
|
|||
};
|
||||
}
|
||||
return core::AudioTaskApiFactory::Create<AudioClassifier,
|
||||
AudioClassifierOptionsProto>(
|
||||
proto::AudioClassifierGraphOptions>(
|
||||
CreateGraphConfig(std::move(options_proto)),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback));
|
||||
|
@ -140,6 +143,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
|
|||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||
}
|
||||
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -23,13 +23,14 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
|
||||
// The options for configuring a mediapipe audio classifier task.
|
||||
struct AudioClassifierOptions {
|
||||
|
@ -39,7 +40,7 @@ struct AudioClassifierOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
components::ClassifierOptions classifier_options;
|
||||
components::processors::ClassifierOptions classifier_options;
|
||||
|
||||
// The running mode of the audio classifier. Default to the audio clips mode.
|
||||
// Audio classifier has two running modes:
|
||||
|
@ -58,8 +59,9 @@ struct AudioClassifierOptions {
|
|||
// The user-defined result callback for processing audio stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::AUDIO_STREAM.
|
||||
std::function<void(absl::StatusOr<ClassificationResult>)> result_callback =
|
||||
nullptr;
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
// Performs audio classification on audio clips or audio stream.
|
||||
|
@ -131,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
|||
// framed audio clip.
|
||||
// TODO: Use `sample_rate` in AudioClassifierOptions by default
|
||||
// and makes `audio_sample_rate` optional.
|
||||
absl::StatusOr<ClassificationResult> Classify(mediapipe::Matrix audio_clip,
|
||||
double audio_sample_rate);
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||
mediapipe::Matrix audio_clip, double audio_sample_rate);
|
||||
|
||||
// Sends audio data (a block in a continuous audio stream) to perform audio
|
||||
// classification. Only use this method when the AudioClassifier is created
|
||||
|
@ -162,6 +164,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
|||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -28,12 +28,12 @@ limitations under the License.
|
|||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -52,6 +53,7 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::GenericNode;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
|
||||
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
|
||||
constexpr char kAudioTag[] = "AUDIO";
|
||||
|
@ -60,10 +62,9 @@ constexpr char kPacketTag[] = "PACKET";
|
|||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
using AudioClassifierOptionsProto =
|
||||
audio_classifier::proto::AudioClassifierOptions;
|
||||
|
||||
absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) {
|
||||
absl::Status SanityCheckOptions(
|
||||
const proto::AudioClassifierGraphOptions& options) {
|
||||
if (options.base_options().use_stream_mode() &&
|
||||
!options.has_default_input_audio_sample_rate()) {
|
||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||
|
@ -111,7 +112,7 @@ void ConfigureAudioToTensorCalculator(
|
|||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification.
|
||||
// An "AudioClassifierGraph" performs audio classification.
|
||||
// - Accepts CPU audio buffer and outputs classification results on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -129,12 +130,12 @@ void ConfigureAudioToTensorCalculator(
|
|||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.audio.AudioClassifierGraph"
|
||||
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
|
||||
// input_stream: "AUDIO:audio_in"
|
||||
// input_stream: "SAMPLE_RATE:sample_rate_in"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
|
||||
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -152,16 +153,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<AudioClassifierOptionsProto>(sc));
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_resources,
|
||||
CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
|
||||
Graph graph;
|
||||
const bool use_stream_mode = sc->Options<AudioClassifierOptionsProto>()
|
||||
.base_options()
|
||||
.use_stream_mode();
|
||||
const bool use_stream_mode =
|
||||
sc->Options<proto::AudioClassifierGraphOptions>()
|
||||
.base_options()
|
||||
.use_stream_mode();
|
||||
ASSIGN_OR_RETURN(
|
||||
auto classification_result_out,
|
||||
BuildAudioClassificationTask(
|
||||
sc->Options<AudioClassifierOptionsProto>(), *model_resources,
|
||||
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
|
||||
graph[Input<Matrix>(kAudioTag)],
|
||||
use_stream_mode
|
||||
? absl::nullopt
|
||||
|
@ -178,14 +181,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
// buffer (mediapipe::Matrix) and the corresponding sample rate (double) as
|
||||
// the inputs and returns one classification result per input audio buffer.
|
||||
//
|
||||
// task_options: the mediapipe tasks AudioClassifierOptions proto.
|
||||
// task_options: the mediapipe tasks AudioClassifierGraphOptions proto.
|
||||
// model_resources: the ModelSources object initialized from an audio
|
||||
// classifier model file with model metadata.
|
||||
// audio_in: (mediapipe::Matrix) stream to run audio classification on.
|
||||
// sample_rate_in: (double) optional stream of the input audio sample rate.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
|
||||
const AudioClassifierOptionsProto& task_options,
|
||||
const proto::AudioClassifierGraphOptions& task_options,
|
||||
const core::ModelResources& model_resources, Source<Matrix> audio_in,
|
||||
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
||||
|
@ -236,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
|
||||
// Adds postprocessing calculators and connects them to the graph output.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing.GetOptions<
|
||||
tasks::components::ClassificationPostprocessingOptions>()));
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ClassificationPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Time aggregation is only needed for performing audio classification on
|
||||
|
@ -257,8 +263,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph);
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph);
|
||||
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -37,17 +37,19 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
namespace {
|
||||
|
||||
using ::absl::StatusOr;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
|
@ -557,6 +559,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,12 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "audio_classifier_options_proto",
|
||||
srcs = ["audio_classifier_options.proto"],
|
||||
name = "audio_classifier_graph_options_proto",
|
||||
srcs = ["audio_classifier_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,12 +18,12 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.audio.audio_classifier.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message AudioClassifierOptions {
|
||||
message AudioClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional AudioClassifierOptions ext = 451755788;
|
||||
optional AudioClassifierGraphOptions ext = 451755788;
|
||||
}
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
||||
|
@ -31,7 +31,7 @@ message AudioClassifierOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||
|
||||
// The default sample rate of the input audio. Must be set when the
|
||||
// AudioClassifier is configured to process audio stream data.
|
|
@ -58,65 +58,6 @@ cc_library(
|
|||
|
||||
# TODO: Enable this test
|
||||
|
||||
cc_library(
|
||||
name = "classifier_options",
|
||||
srcs = ["classifier_options.cc"],
|
||||
hdrs = ["classifier_options.h"],
|
||||
deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classification_postprocessing_options_proto",
|
||||
srcs = ["classification_postprocessing_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "classification_postprocessing",
|
||||
srcs = ["classification_postprocessing.cc"],
|
||||
hdrs = ["classification_postprocessing.h"],
|
||||
deps = [
|
||||
":classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"//mediapipe/util:label_map_util",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "embedder_options",
|
||||
srcs = ["embedder_options.cc"],
|
||||
|
|
|
@ -37,8 +37,8 @@ cc_library(
|
|||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:category_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -128,7 +128,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -25,15 +25,15 @@
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
|
||||
using ::mediapipe::tasks::ClassificationResult;
|
||||
using ::mediapipe::tasks::Classifications;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
|
||||
// Aggregates ClassificationLists into a single ClassificationResult that has
|
||||
// 3 dimensions: (classification head, classification timestamp, classification
|
||||
|
|
|
@ -17,12 +17,13 @@ limitations under the License.
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
// Specialized EndLoopCalculator for Tasks specific types.
|
||||
namespace mediapipe::tasks {
|
||||
|
||||
typedef EndLoopCalculator<std::vector<ClassificationResult>>
|
||||
typedef EndLoopCalculator<
|
||||
std::vector<components::containers::proto::ClassificationResult>>
|
||||
EndLoopClassificationResultCalculator;
|
||||
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);
|
||||
|
||||
|
|
|
@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "category_proto",
|
||||
srcs = ["category.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classifications_proto",
|
||||
srcs = ["classifications.proto"],
|
||||
deps = [
|
||||
":category_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embeddings_proto",
|
||||
srcs = ["embeddings.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "landmarks_detection_result_proto",
|
||||
srcs = [
|
||||
|
@ -29,8 +47,3 @@ mediapipe_proto_library(
|
|||
"//mediapipe/framework/formats:rect_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embeddings_proto",
|
||||
srcs = ["embeddings.proto"],
|
||||
)
|
||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks;
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
// A single classification result.
|
||||
message Category {
|
|
@ -15,9 +15,9 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks;
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
import "mediapipe/tasks/cc/components/containers/category.proto";
|
||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||
|
||||
// List of predicted categories with an optional timestamp.
|
||||
message ClassificationEntry {
|
64
mediapipe/tasks/cc/components/processors/BUILD
Normal file
64
mediapipe/tasks/cc/components/processors/BUILD
Normal file
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "classifier_options",
|
||||
srcs = ["classifier_options.cc"],
|
||||
hdrs = ["classifier_options.h"],
|
||||
deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "classification_postprocessing_graph",
|
||||
srcs = ["classification_postprocessing_graph.cc"],
|
||||
hdrs = ["classification_postprocessing_graph.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"//mediapipe/util:label_map_util",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
|
@ -37,9 +37,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp;
|
|||
using ::mediapipe::api2::builder::GenericNode;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
using ::tflite::ProcessUnit;
|
||||
|
@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS";
|
|||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
|
||||
// Performs sanity checks on provided ClassifierOptions.
|
||||
absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) {
|
||||
absl::Status SanityCheckClassifierOptions(
|
||||
const proto::ClassifierOptions& options) {
|
||||
if (options.max_results() == 0) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
|
@ -203,7 +205,7 @@ absl::StatusOr<float> GetScoreThreshold(
|
|||
|
||||
// Gets the category allowlist or denylist (if any) as a set of indices.
|
||||
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
||||
const ClassifierOptions& options, const LabelItems& label_items) {
|
||||
const proto::ClassifierOptions& options, const LabelItems& label_items) {
|
||||
absl::flat_hash_set<int> category_indices;
|
||||
// Exit early if no denylist/allowlist.
|
||||
if (options.category_denylist_size() == 0 &&
|
||||
|
@ -239,7 +241,7 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
|||
|
||||
absl::Status ConfigureScoreCalibrationIfAny(
|
||||
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||
ClassificationPostprocessingOptions* options) {
|
||||
proto::ClassificationPostprocessingGraphOptions* options) {
|
||||
const auto* tensor_metadata =
|
||||
metadata_extractor.GetOutputTensorMetadata(tensor_index);
|
||||
if (tensor_metadata == nullptr) {
|
||||
|
@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
|
|||
// Fills in the TensorsToClassificationCalculatorOptions based on the
|
||||
// classifier options and the (optional) output tensor metadata.
|
||||
absl::Status ConfigureTensorsToClassificationCalculator(
|
||||
const ClassifierOptions& options,
|
||||
const proto::ClassifierOptions& options,
|
||||
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||
TensorsToClassificationCalculatorOptions* calculator_options) {
|
||||
const auto* tensor_metadata =
|
||||
|
@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator(
|
|||
|
||||
} // namespace
|
||||
|
||||
absl::Status ConfigureClassificationPostprocessing(
|
||||
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||
const ModelResources& model_resources,
|
||||
const ClassifierOptions& classifier_options,
|
||||
ClassificationPostprocessingOptions* options) {
|
||||
const proto::ClassifierOptions& classifier_options,
|
||||
proto::ClassificationPostprocessingGraphOptions* options) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
|
||||
ASSIGN_OR_RETURN(const auto heads_properties,
|
||||
GetClassificationHeadsProperties(model_resources));
|
||||
|
@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
|
||||
// raw tensors into ClassificationResult objects.
|
||||
// A "ClassificationPostprocessingGraph" converts raw tensors into
|
||||
// ClassificationResult objects.
|
||||
// - Accepts CPU input tensors.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing(
|
|||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The output aggregated classification results.
|
||||
//
|
||||
// The recommended way of using this subgraph is through the GraphBuilder API
|
||||
// using the 'ConfigureClassificationPostprocessing()' function. See header file
|
||||
// for more details.
|
||||
class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||
// The recommended way of using this graph is through the GraphBuilder API
|
||||
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
|
||||
// file for more details.
|
||||
class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
||||
public:
|
||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||
mediapipe::SubgraphContext* sc) override {
|
||||
|
@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
|||
ASSIGN_OR_RETURN(
|
||||
auto classification_result_out,
|
||||
BuildClassificationPostprocessing(
|
||||
sc->Options<ClassificationPostprocessingOptions>(),
|
||||
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||
classification_result_out >>
|
||||
|
@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
|||
}
|
||||
|
||||
private:
|
||||
// Adds an on-device classification postprocessing subgraph into the provided
|
||||
// builder::Graph instance. The classification postprocessing subgraph takes
|
||||
// Adds an on-device classification postprocessing graph into the provided
|
||||
// builder::Graph instance. The classification postprocessing graph takes
|
||||
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
|
||||
// stream containing the output classification results (ClassificationResult).
|
||||
//
|
||||
// options: the on-device ClassificationPostprocessingOptions.
|
||||
// options: the on-device ClassificationPostprocessingGraphOptions.
|
||||
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
|
||||
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
||||
// timestamps that a single ClassificationResult should aggregate.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<ClassificationResult>>
|
||||
BuildClassificationPostprocessing(
|
||||
const ClassificationPostprocessingOptions& options,
|
||||
const proto::ClassificationPostprocessingGraphOptions& options,
|
||||
Source<std::vector<Tensor>> tensors_in,
|
||||
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
|
||||
const int num_heads = options.tensors_to_classifications_options_size();
|
||||
|
@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
|||
kClassificationResultTag)];
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors::
|
||||
ClassificationPostprocessingGraph); // NOLINT
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,32 +13,33 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
// Configures a ClassificationPostprocessing subgraph using the provided model
|
||||
// Configures a ClassificationPostprocessingGraph using the provided model
|
||||
// resources and ClassifierOptions.
|
||||
// - Accepts CPU input tensors.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// auto& postprocessing =
|
||||
// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph");
|
||||
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
||||
// model_resources,
|
||||
// classifier_options,
|
||||
// &preprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
||||
// &preprocessing.GetOptions<ClassificationPostprocessingGraphOptions>()));
|
||||
//
|
||||
// The resulting ClassificationPostprocessing subgraph has the following I/O:
|
||||
// The resulting ClassificationPostprocessingGraph has the following I/O:
|
||||
// Inputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// The output tensors of an InferenceCalculator.
|
||||
|
@ -49,13 +50,14 @@ namespace components {
|
|||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The output aggregated classification results.
|
||||
absl::Status ConfigureClassificationPostprocessing(
|
||||
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||
const tasks::core::ModelResources& model_resources,
|
||||
const tasks::components::proto::ClassifierOptions& classifier_options,
|
||||
ClassificationPostprocessingOptions* options);
|
||||
const proto::ClassifierOptions& classifier_options,
|
||||
proto::ClassificationPostprocessingGraphOptions* options);
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
@ -42,9 +42,9 @@ limitations under the License.
|
|||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
|
@ -53,6 +53,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
|
@ -60,7 +61,7 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::proto::Approximately;
|
||||
|
@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.set_max_results(0);
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out);
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out);
|
||||
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
|
||||
|
@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_allowlist("foo");
|
||||
options_in.add_category_denylist("bar");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out);
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out);
|
||||
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
|
||||
|
@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_allowlist("foo");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out);
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out);
|
||||
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(
|
||||
|
@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||
R"pb(score_calibration_options: []
|
||||
|
@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.set_max_results(3);
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||
R"pb(score_calibration_options: []
|
||||
|
@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.set_score_threshold(0.5);
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||
R"pb(score_calibration_options: []
|
||||
|
@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Check label map size and two first elements.
|
||||
EXPECT_EQ(
|
||||
|
@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_allowlist("tench");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Clear label map and compare the rest of the options.
|
||||
options_out.mutable_tensors_to_classifications_options(0)
|
||||
|
@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_denylist("background");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Clear label map and compare the rest of the options.
|
||||
options_out.mutable_tensors_to_classifications_options(0)
|
||||
|
@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
|
|||
auto model_resources,
|
||||
CreateModelResourcesForModel(
|
||||
kQuantizedImageClassifierWithDummyScoreCalibration));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Check label map size and two first elements.
|
||||
EXPECT_EQ(
|
||||
|
@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
// Check label maps sizes and first two elements.
|
||||
EXPECT_EQ(
|
||||
options_out.tensors_to_classifications_options(0).label_items_size(),
|
||||
|
@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
|||
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
absl::string_view model_name, const ClassifierOptions& options,
|
||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||
bool connect_timestamps = false) {
|
||||
ASSIGN_OR_RETURN(auto model_resources,
|
||||
CreateModelResourcesForModel(model_name));
|
||||
|
||||
Graph graph;
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options,
|
||||
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
||||
&postprocessing
|
||||
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
||||
postprocessing.In(kTensorsTag);
|
||||
if (connect_timestamps) {
|
||||
|
@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
options.set_score_threshold(0.5);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
||||
|
@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller,
|
||||
|
@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller,
|
||||
|
@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
||||
|
@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
ClassifierOptions* options) {
|
||||
tasks::components::proto::ClassifierOptions options_proto;
|
||||
proto::ClassifierOptions options_proto;
|
||||
options_proto.set_display_names_locale(options->display_names_locale);
|
||||
options_proto.set_max_results(options->max_results);
|
||||
options_proto.set_score_threshold(options->score_threshold);
|
||||
|
@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
|||
return options_proto;
|
||||
}
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
||||
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
// Classifier options for MediaPipe C++ classification Tasks.
|
||||
struct ClassifierOptions {
|
||||
|
@ -49,11 +50,12 @@ struct ClassifierOptions {
|
|||
};
|
||||
|
||||
// Converts a ClassifierOptions to a ClassifierOptionsProto.
|
||||
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
ClassifierOptions* classifier_options);
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
|
@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "category_proto",
|
||||
srcs = ["category.proto"],
|
||||
name = "classifier_options_proto",
|
||||
srcs = ["classifier_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classifications_proto",
|
||||
srcs = ["classifications.proto"],
|
||||
name = "classification_postprocessing_graph_options_proto",
|
||||
srcs = ["classification_postprocessing_graph_options.proto"],
|
||||
deps = [
|
||||
":category_proto",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||
],
|
||||
)
|
|
@ -15,16 +15,16 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.components;
|
||||
package mediapipe.tasks.components.processors.proto;
|
||||
|
||||
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
|
||||
|
||||
message ClassificationPostprocessingOptions {
|
||||
message ClassificationPostprocessingGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ClassificationPostprocessingOptions ext = 460416950;
|
||||
optional ClassificationPostprocessingGraphOptions ext = 460416950;
|
||||
}
|
||||
|
||||
// Optional mapping between output tensor index and corresponding score
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.components.proto;
|
||||
package mediapipe.tasks.components.processors.proto;
|
||||
|
||||
// Shared options used by all classification tasks.
|
||||
message ClassifierOptions {
|
|
@ -23,11 +23,6 @@ mediapipe_proto_library(
|
|||
srcs = ["segmenter_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classifier_options_proto",
|
||||
srcs = ["classifier_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embedder_options_proto",
|
||||
srcs = ["embedder_options.proto"],
|
||||
|
|
|
@ -42,3 +42,16 @@ cc_test(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gate",
|
||||
hdrs = ["gate.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:gate_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Enable this test
|
||||
|
|
160
mediapipe/tasks/cc/components/utils/gate.h
Normal file
160
mediapipe/tasks/cc/components/utils/gate.h
Normal file
|
@ -0,0 +1,160 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/calculators/core/gate_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace utils {
|
||||
|
||||
// Utility class that simplifies allowing (gating) multiple streams.
|
||||
class AllowGate {
|
||||
public:
|
||||
AllowGate(api2::builder::Source<bool> allow, api2::builder::Graph& graph)
|
||||
: node_(AddSourceGate(allow, graph)) {}
|
||||
AllowGate(api2::builder::SideSource<bool> allow, api2::builder::Graph& graph)
|
||||
: node_(AddSideSourceGate(allow, graph)) {}
|
||||
|
||||
// Move-only
|
||||
AllowGate(AllowGate&& allow_gate) = default;
|
||||
AllowGate& operator=(AllowGate&& allow_gate) = default;
|
||||
|
||||
template <typename T>
|
||||
api2::builder::Source<T> Allow(api2::builder::Source<T> source) {
|
||||
source >> node_.In(index_);
|
||||
return node_.Out(index_++).Cast<T>();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
static api2::builder::GenericNode& AddSourceGate(
|
||||
T allow, api2::builder::Graph& graph) {
|
||||
auto& gate_node = graph.AddNode("GateCalculator");
|
||||
allow >> gate_node.In("ALLOW");
|
||||
return gate_node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static api2::builder::GenericNode& AddSideSourceGate(
|
||||
T allow, api2::builder::Graph& graph) {
|
||||
auto& gate_node = graph.AddNode("GateCalculator");
|
||||
allow >> gate_node.SideIn("ALLOW");
|
||||
return gate_node;
|
||||
}
|
||||
|
||||
api2::builder::GenericNode& node_;
|
||||
int index_ = 0;
|
||||
};
|
||||
|
||||
// Utility class that simplifies disallowing (gating) multiple streams.
|
||||
class DisallowGate {
|
||||
public:
|
||||
DisallowGate(api2::builder::Source<bool> disallow,
|
||||
api2::builder::Graph& graph)
|
||||
: node_(AddSourceGate(disallow, graph)) {}
|
||||
DisallowGate(api2::builder::SideSource<bool> disallow,
|
||||
api2::builder::Graph& graph)
|
||||
: node_(AddSideSourceGate(disallow, graph)) {}
|
||||
|
||||
// Move-only
|
||||
DisallowGate(DisallowGate&& disallow_gate) = default;
|
||||
DisallowGate& operator=(DisallowGate&& disallow_gate) = default;
|
||||
|
||||
template <typename T>
|
||||
api2::builder::Source<T> Disallow(api2::builder::Source<T> source) {
|
||||
source >> node_.In(index_);
|
||||
return node_.Out(index_++).Cast<T>();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
static api2::builder::GenericNode& AddSourceGate(
|
||||
T disallow, api2::builder::Graph& graph) {
|
||||
auto& gate_node = graph.AddNode("GateCalculator");
|
||||
auto& gate_node_opts =
|
||||
gate_node.GetOptions<mediapipe::GateCalculatorOptions>();
|
||||
// Supposedly, the most popular configuration for MediaPipe Tasks team
|
||||
// graphs. Hence, intentionally hard coded to catch and verify any other use
|
||||
// case (should help to workout a common approach and have a recommended way
|
||||
// of blocking streams).
|
||||
gate_node_opts.set_empty_packets_as_allow(true);
|
||||
disallow >> gate_node.In("DISALLOW");
|
||||
return gate_node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static api2::builder::GenericNode& AddSideSourceGate(
|
||||
T disallow, api2::builder::Graph& graph) {
|
||||
auto& gate_node = graph.AddNode("GateCalculator");
|
||||
auto& gate_node_opts =
|
||||
gate_node.GetOptions<mediapipe::GateCalculatorOptions>();
|
||||
gate_node_opts.set_empty_packets_as_allow(true);
|
||||
disallow >> gate_node.SideIn("DISALLOW");
|
||||
return gate_node;
|
||||
}
|
||||
|
||||
api2::builder::GenericNode& node_;
|
||||
int index_ = 0;
|
||||
};
|
||||
|
||||
// Updates graph to drop @value stream packet if corresponding @condition stream
|
||||
// packet holds true.
|
||||
template <class T>
|
||||
api2::builder::Source<T> DisallowIf(api2::builder::Source<T> value,
|
||||
api2::builder::Source<bool> condition,
|
||||
api2::builder::Graph& graph) {
|
||||
return DisallowGate(condition, graph).Disallow(value);
|
||||
}
|
||||
|
||||
// Updates graph to drop @value stream packet if corresponding @condition stream
|
||||
// packet holds true.
|
||||
template <class T>
|
||||
api2::builder::Source<T> DisallowIf(api2::builder::Source<T> value,
|
||||
api2::builder::SideSource<bool> condition,
|
||||
api2::builder::Graph& graph) {
|
||||
return DisallowGate(condition, graph).Disallow(value);
|
||||
}
|
||||
|
||||
// Updates graph to pass through @value stream packet if corresponding
|
||||
// @allow stream packet holds true.
|
||||
template <class T>
|
||||
api2::builder::Source<T> AllowIf(api2::builder::Source<T> value,
|
||||
api2::builder::Source<bool> allow,
|
||||
api2::builder::Graph& graph) {
|
||||
return AllowGate(allow, graph).Allow(value);
|
||||
}
|
||||
|
||||
// Updates graph to pass through @value stream packet if corresponding
|
||||
// @allow side stream packet holds true.
|
||||
template <class T>
|
||||
api2::builder::Source<T> AllowIf(api2::builder::Source<T> value,
|
||||
api2::builder::SideSource<bool> allow,
|
||||
api2::builder::Graph& graph) {
|
||||
return AllowGate(allow, graph).Allow(value);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_
|
229
mediapipe/tasks/cc/components/utils/gate_test.cc
Normal file
229
mediapipe/tasks/cc/components/utils/gate_test.cc
Normal file
|
@ -0,0 +1,229 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/utils/gate.h"
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_graph.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace utils {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::builder::SideSource;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
|
||||
TEST(DisallowGate, VerifyConfig) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
|
||||
Source<int> value1 = graph.In("VALUE_1").Cast<int>();
|
||||
Source<int> value2 = graph.In("VALUE_2").Cast<int>();
|
||||
Source<int> value3 = graph.In("VALUE_3").Cast<int>();
|
||||
|
||||
DisallowGate gate(condition, graph);
|
||||
gate.Disallow(value1).SetName("gated_stream1");
|
||||
gate.Disallow(value2).SetName("gated_stream2");
|
||||
gate.Disallow(value3).SetName("gated_stream3");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
testing::EqualsProto(
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "__stream_1"
|
||||
input_stream: "__stream_2"
|
||||
input_stream: "__stream_3"
|
||||
input_stream: "DISALLOW:__stream_0"
|
||||
output_stream: "gated_stream1"
|
||||
output_stream: "gated_stream2"
|
||||
output_stream: "gated_stream3"
|
||||
options {
|
||||
[mediapipe.GateCalculatorOptions.ext] {
|
||||
empty_packets_as_allow: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "CONDITION:__stream_0"
|
||||
input_stream: "VALUE_1:__stream_1"
|
||||
input_stream: "VALUE_2:__stream_2"
|
||||
input_stream: "VALUE_3:__stream_3"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(DisallowIf, VerifyConfig) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Source<int> value = graph.In("VALUE").Cast<int>();
|
||||
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
|
||||
|
||||
auto gated_stream = DisallowIf(value, condition, graph);
|
||||
gated_stream.SetName("gated_stream");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
testing::EqualsProto(
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "__stream_1"
|
||||
input_stream: "DISALLOW:__stream_0"
|
||||
output_stream: "gated_stream"
|
||||
options {
|
||||
[mediapipe.GateCalculatorOptions.ext] {
|
||||
empty_packets_as_allow: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "CONDITION:__stream_0"
|
||||
input_stream: "VALUE:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(DisallowIf, VerifyConfigWithSideCondition) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Source<int> value = graph.In("VALUE").Cast<int>();
|
||||
SideSource<bool> condition = graph.SideIn("CONDITION").Cast<bool>();
|
||||
|
||||
auto gated_stream = DisallowIf(value, condition, graph);
|
||||
gated_stream.SetName("gated_stream");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
testing::EqualsProto(
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "gated_stream"
|
||||
input_side_packet: "DISALLOW:__side_packet_1"
|
||||
options {
|
||||
[mediapipe.GateCalculatorOptions.ext] {
|
||||
empty_packets_as_allow: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "VALUE:__stream_0"
|
||||
input_side_packet: "CONDITION:__side_packet_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(AllowGate, VerifyConfig) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
|
||||
Source<int> value1 = graph.In("VALUE_1").Cast<int>();
|
||||
Source<int> value2 = graph.In("VALUE_2").Cast<int>();
|
||||
Source<int> value3 = graph.In("VALUE_3").Cast<int>();
|
||||
|
||||
AllowGate gate(condition, graph);
|
||||
gate.Allow(value1).SetName("gated_stream1");
|
||||
gate.Allow(value2).SetName("gated_stream2");
|
||||
gate.Allow(value3).SetName("gated_stream3");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
testing::EqualsProto(
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "__stream_1"
|
||||
input_stream: "__stream_2"
|
||||
input_stream: "__stream_3"
|
||||
input_stream: "ALLOW:__stream_0"
|
||||
output_stream: "gated_stream1"
|
||||
output_stream: "gated_stream2"
|
||||
output_stream: "gated_stream3"
|
||||
}
|
||||
input_stream: "CONDITION:__stream_0"
|
||||
input_stream: "VALUE_1:__stream_1"
|
||||
input_stream: "VALUE_2:__stream_2"
|
||||
input_stream: "VALUE_3:__stream_3"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(AllowIf, VerifyConfig) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Source<int> value = graph.In("VALUE").Cast<int>();
|
||||
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
|
||||
|
||||
auto gated_stream = AllowIf(value, condition, graph);
|
||||
gated_stream.SetName("gated_stream");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
testing::EqualsProto(
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "__stream_1"
|
||||
input_stream: "ALLOW:__stream_0"
|
||||
output_stream: "gated_stream"
|
||||
}
|
||||
input_stream: "CONDITION:__stream_0"
|
||||
input_stream: "VALUE:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(AllowIf, VerifyConfigWithSideConition) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Source<int> value = graph.In("VALUE").Cast<int>();
|
||||
SideSource<bool> condition = graph.SideIn("CONDITION").Cast<bool>();
|
||||
|
||||
auto gated_stream = AllowIf(value, condition, graph);
|
||||
gated_stream.SetName("gated_stream");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
testing::EqualsProto(
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "__stream_0"
|
||||
output_stream: "gated_stream"
|
||||
input_side_packet: "ALLOW:__side_packet_1"
|
||||
}
|
||||
input_stream: "VALUE:__stream_0"
|
||||
input_side_packet: "CONDITION:__side_packet_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace utils
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -23,6 +23,7 @@ cc_library(
|
|||
srcs = ["base_options.cc"],
|
||||
hdrs = ["base_options.h"],
|
||||
deps = [
|
||||
":mediapipe_builtin_op_resolver",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
|
@ -50,6 +51,21 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mediapipe_builtin_op_resolver",
|
||||
srcs = ["mediapipe_builtin_op_resolver.cc"],
|
||||
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||
"//mediapipe/util/tflite/operations:transform_landmarks",
|
||||
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
|
||||
"//mediapipe/util/tflite/operations:transpose_conv_bias",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
|
||||
# supports TFLite-in-GMSCore.
|
||||
cc_library(
|
||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
@ -63,7 +64,7 @@ struct BaseOptions {
|
|||
// A non-default OpResolver to support custom Ops or specify a subset of
|
||||
// built-in Ops.
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver =
|
||||
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
absl::make_unique<MediaPipeBuiltinOpResolver>();
|
||||
};
|
||||
|
||||
// Converts a BaseOptions to a BaseOptionsProto.
|
||||
|
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
|
||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
||||
|
@ -21,14 +21,11 @@ limitations under the License.
|
|||
#include "mediapipe/util/tflite/operations/transform_landmarks.h"
|
||||
#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h"
|
||||
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
|
||||
: BuiltinOpResolver() {
|
||||
namespace core {
|
||||
MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
|
||||
AddCustom("MaxPoolingWithArgmax2D",
|
||||
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
|
||||
AddCustom("MaxUnpooling2D",
|
||||
|
@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
|
|||
mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
|
||||
/*version=*/2);
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace core
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,25 +13,23 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
class SelfieSegmentationModelOpResolver
|
||||
namespace core {
|
||||
class MediaPipeBuiltinOpResolver
|
||||
: public tflite::ops::builtin::BuiltinOpResolver {
|
||||
public:
|
||||
SelfieSegmentationModelOpResolver();
|
||||
SelfieSegmentationModelOpResolver(
|
||||
const SelfieSegmentationModelOpResolver& r) = delete;
|
||||
MediaPipeBuiltinOpResolver();
|
||||
MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete;
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
} // namespace core
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
|
|
@ -18,18 +18,6 @@ package(default_visibility = [
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "hand_detector_op_resolver",
|
||||
srcs = ["hand_detector_op_resolver.cc"],
|
||||
hdrs = ["hand_detector_op_resolver.h"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||
"//mediapipe/util/tflite/operations:transpose_conv_bias",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hand_detector_graph",
|
||||
srcs = ["hand_detector_graph.cc"],
|
||||
|
|
|
@ -35,11 +35,11 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
@ -121,8 +121,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
|||
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kHandNormRectsTag)];
|
||||
|
||||
return TaskRunner::Create(graph.GetConfig(),
|
||||
absl::make_unique<HandDetectorOpResolver>());
|
||||
return TaskRunner::Create(
|
||||
graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||
}
|
||||
|
||||
HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) {
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
|
||||
|
||||
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
||||
#include "mediapipe/util/tflite/operations/max_unpooling.h"
|
||||
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
HandDetectorOpResolver::HandDetectorOpResolver() {
|
||||
AddCustom("MaxPoolingWithArgmax2D",
|
||||
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
|
||||
AddCustom("MaxUnpooling2D",
|
||||
mediapipe::tflite_operations::RegisterMaxUnpooling2D());
|
||||
AddCustom("Convolution2DTransposeBias",
|
||||
mediapipe::tflite_operations::RegisterConvolution2DTransposeBias());
|
||||
}
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -54,10 +54,10 @@ cc_library(
|
|||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
|
|
|
@ -27,9 +27,9 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
|
@ -49,6 +49,7 @@ using ::mediapipe::api2::Input;
|
|||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
|
||||
HandGestureRecognizerSubgraphOptions;
|
||||
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
|
||||
|
@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
|||
auto inference_output_tensors = inference.Out(kTensorsTag);
|
||||
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
model_resources, graph_options.classifier_options(),
|
||||
&postprocessing.GetOptions<
|
||||
tasks::components::ClassificationPostprocessingOptions>()));
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||
model_resources, graph_options.classifier_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ClassificationPostprocessingGraphOptions>()));
|
||||
inference_output_tensors >> postprocessing.In(kTensorsTag);
|
||||
auto classification_result =
|
||||
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
|
||||
|
|
|
@ -26,7 +26,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
@ -37,7 +37,5 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message HandGestureRecognizerSubgraphOptions {
|
||||
|
@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions {
|
|||
|
||||
// Options for configuring the gesture classifier behavior, such as score
|
||||
// threshold, number of results, etc.
|
||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||
|
||||
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
|
||||
// considered tracked successfully
|
||||
|
|
49
mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD
Normal file
49
mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD
Normal file
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = [
|
||||
"//mediapipe/app/xeno:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "hand_association_calculator_proto",
|
||||
srcs = ["hand_association_calculator.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hand_association_calculator",
|
||||
srcs = ["hand_association_calculator.cc"],
|
||||
deps = [
|
||||
":hand_association_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:association_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:rectangle",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:rectangle_util",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO: Enable this test
|
|
@ -0,0 +1,125 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/rectangle.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
|
||||
#include "mediapipe/util/rectangle_util.h"
|
||||
|
||||
namespace mediapipe::api2 {
|
||||
|
||||
// HandAssociationCalculator accepts multiple inputs of vectors of
|
||||
// NormalizedRect. The output is a vector of NormalizedRect that contains
|
||||
// rects from the input vectors that don't overlap with each other. When two
|
||||
// rects overlap, the rect that comes in from an earlier input stream is
|
||||
// kept in the output. If a rect has no ID (i.e. from detection stream),
|
||||
// then a unique rect ID is assigned for it.
|
||||
|
||||
// The rects in multiple input streams are effectively flattened to a single
|
||||
// list. For example:
|
||||
// Stream1 : rect 1, rect 2
|
||||
// Stream2: rect 3, rect 4
|
||||
// Stream3: rect 5, rect 6
|
||||
// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
|
||||
// In the flattened list, if a rect with a higher index overlaps with a rect a
|
||||
// lower index, beyond a specified IOU threshold, the rect with the lower
|
||||
// index will be in the output, and the rect with higher index will be
|
||||
// discarded.
|
||||
// TODO: Upgrade this to latest API for calculators
|
||||
class HandAssociationCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
// Initialize input and output streams.
|
||||
for (auto& input_stream : cc->Inputs()) {
|
||||
input_stream.Set<std::vector<NormalizedRect>>();
|
||||
}
|
||||
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
options_ = cc->Options<HandAssociationCalculatorOptions>();
|
||||
CHECK_GT(options_.min_similarity_threshold(), 0.0);
|
||||
CHECK_LE(options_.min_similarity_threshold(), 1.0);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
ASSIGN_OR_RETURN(auto result, GetNonOverlappingElements(cc));
|
||||
|
||||
auto output =
|
||||
std::make_unique<std::vector<NormalizedRect>>(std::move(result));
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
HandAssociationCalculatorOptions options_;
|
||||
|
||||
// Return a list of non-overlapping elements from all input streams, with
|
||||
// decreasing order of priority based on input stream index and indices
|
||||
// within an input stream.
|
||||
absl::StatusOr<std::vector<NormalizedRect>> GetNonOverlappingElements(
|
||||
CalculatorContext* cc) {
|
||||
std::vector<NormalizedRect> result;
|
||||
|
||||
for (const auto& input_stream : cc->Inputs()) {
|
||||
if (input_stream.IsEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
|
||||
ASSIGN_OR_RETURN(
|
||||
bool is_overlapping,
|
||||
mediapipe::DoesRectOverlap(rect, result,
|
||||
options_.min_similarity_threshold()));
|
||||
if (!is_overlapping) {
|
||||
if (!rect.has_rect_id()) {
|
||||
rect.set_rect_id(GetNextRectId());
|
||||
}
|
||||
result.push_back(rect);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
// Each NormalizedRect processed by the calculator will be assigned
|
||||
// an unique id, if it does not already have an ID. The starting ID will be 1.
|
||||
// Note: This rect_id_ is local to an instance of this calculator. And it is
|
||||
// expected that the hand tracking graph to have only one instance of
|
||||
// this association calculator.
|
||||
int64 rect_id_ = 1;
|
||||
|
||||
inline int GetNextRectId() { return rect_id_++; }
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(HandAssociationCalculator);
|
||||
|
||||
} // namespace mediapipe::api2
|
|
@ -13,22 +13,16 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
|
||||
syntax = "proto2";
|
||||
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
package mediapipe;
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver {
|
||||
public:
|
||||
HandDetectorOpResolver();
|
||||
HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete;
|
||||
};
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
message HandAssociationCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional HandAssociationCalculatorOptions ext = 408244367;
|
||||
}
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
|
||||
optional float min_similarity_threshold = 1 [default = 1.0];
|
||||
}
|
|
@ -0,0 +1,302 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
class HandAssociationCalculatorTest : public testing::Test {
|
||||
protected:
|
||||
HandAssociationCalculatorTest() {
|
||||
// 0.4 ================
|
||||
// | | | |
|
||||
// 0.3 ===================== | NR2 | |
|
||||
// | | | NR1 | | | NR4 |
|
||||
// 0.2 | NR0 | =========== ================
|
||||
// | | | | | |
|
||||
// 0.1 =====|=============== |
|
||||
// | NR3 | | |
|
||||
// 0.0 ================ |
|
||||
// | NR5 |
|
||||
// -0.1 ===========
|
||||
// 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2
|
||||
|
||||
// NormalizedRect nr_0.
|
||||
nr_0_.set_x_center(0.2);
|
||||
nr_0_.set_y_center(0.2);
|
||||
nr_0_.set_width(0.2);
|
||||
nr_0_.set_height(0.2);
|
||||
|
||||
// NormalizedRect nr_1.
|
||||
nr_1_.set_x_center(0.4);
|
||||
nr_1_.set_y_center(0.2);
|
||||
nr_1_.set_width(0.2);
|
||||
nr_1_.set_height(0.2);
|
||||
|
||||
// NormalizedRect nr_2.
|
||||
nr_2_.set_x_center(1.0);
|
||||
nr_2_.set_y_center(0.3);
|
||||
nr_2_.set_width(0.2);
|
||||
nr_2_.set_height(0.2);
|
||||
|
||||
// NormalizedRect nr_3.
|
||||
nr_3_.set_x_center(0.35);
|
||||
nr_3_.set_y_center(0.15);
|
||||
nr_3_.set_width(0.3);
|
||||
nr_3_.set_height(0.3);
|
||||
|
||||
// NormalizedRect nr_4.
|
||||
nr_4_.set_x_center(1.1);
|
||||
nr_4_.set_y_center(0.3);
|
||||
nr_4_.set_width(0.2);
|
||||
nr_4_.set_height(0.2);
|
||||
|
||||
// NormalizedRect nr_5.
|
||||
nr_5_.set_x_center(0.5);
|
||||
nr_5_.set_y_center(0.05);
|
||||
nr_5_.set_width(0.3);
|
||||
nr_5_.set_height(0.3);
|
||||
}
|
||||
|
||||
NormalizedRect nr_0_, nr_1_, nr_2_, nr_3_, nr_4_, nr_5_;
|
||||
};
|
||||
|
||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "HandAssociationCalculator"
|
||||
input_stream: "input_vec_0"
|
||||
input_stream: "input_vec_1"
|
||||
input_stream: "input_vec_2"
|
||||
output_stream: "output_vec"
|
||||
options {
|
||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
// Input Stream 0: nr_0, nr_1, nr_2.
|
||||
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec_0->push_back(nr_0_);
|
||||
input_vec_0->push_back(nr_1_);
|
||||
input_vec_0->push_back(nr_2_);
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 1: nr_3, nr_4.
|
||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec_1->push_back(nr_3_);
|
||||
input_vec_1->push_back(nr_4_);
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 2: nr_5.
|
||||
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec_2->push_back(nr_5_);
|
||||
runner.MutableInputs()->Index(2).packets.push_back(
|
||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output.size());
|
||||
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||
|
||||
// Rectangles are added in the following sequence:
|
||||
// nr_0 is added 1st.
|
||||
// nr_1 is added because it does not overlap with nr_0.
|
||||
// nr_2 is added because it does not overlap with nr_0 or nr_1.
|
||||
// nr_3 is NOT added because it overlaps with nr_0.
|
||||
// nr_4 is NOT added because it overlaps with nr_2.
|
||||
// nr_5 is NOT added because it overlaps with nr_1.
|
||||
EXPECT_EQ(3, assoc_rects.size());
|
||||
|
||||
// Check that IDs are filled in and contents match.
|
||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
||||
assoc_rects[0].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
|
||||
|
||||
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
|
||||
assoc_rects[1].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
|
||||
|
||||
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
|
||||
assoc_rects[2].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
|
||||
}
|
||||
|
||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "HandAssociationCalculator"
|
||||
input_stream: "input_vec_0"
|
||||
input_stream: "input_vec_1"
|
||||
input_stream: "input_vec_2"
|
||||
output_stream: "output_vec"
|
||||
options {
|
||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
// Input Stream 0: nr_0, nr_1. Tracked hands.
|
||||
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
// Setting ID to a negative number for test only, since newly generated
|
||||
// ID by HandAssociationCalculator are positive numbers.
|
||||
nr_0_.set_rect_id(-2);
|
||||
input_vec_0->push_back(nr_0_);
|
||||
nr_1_.set_rect_id(-1);
|
||||
input_vec_0->push_back(nr_1_);
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 1: nr_2, nr_3. Newly detected palms.
|
||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec_1->push_back(nr_2_);
|
||||
input_vec_1->push_back(nr_3_);
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output.size());
|
||||
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||
|
||||
// Rectangles are added in the following sequence:
|
||||
// nr_0 is added 1st.
|
||||
// nr_1 is added because it does not overlap with nr_0.
|
||||
// nr_2 is added because it does not overlap with nr_0 or nr_1.
|
||||
// nr_3 is NOT added because it overlaps with nr_0.
|
||||
EXPECT_EQ(3, assoc_rects.size());
|
||||
|
||||
// Check that IDs are filled in and contents match.
|
||||
EXPECT_EQ(assoc_rects[0].rect_id(), -2);
|
||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
|
||||
|
||||
EXPECT_EQ(assoc_rects[1].rect_id(), -1);
|
||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
|
||||
|
||||
EXPECT_EQ(assoc_rects[2].rect_id(), 1);
|
||||
assoc_rects[2].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
|
||||
}
|
||||
|
||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "HandAssociationCalculator"
|
||||
input_stream: "input_vec_0"
|
||||
input_stream: "input_vec_1"
|
||||
input_stream: "input_vec_2"
|
||||
output_stream: "output_vec"
|
||||
options {
|
||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
// Input Stream 0: nr_5.
|
||||
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec_0->push_back(nr_5_);
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 1: nr_4, nr_3
|
||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec_1->push_back(nr_4_);
|
||||
input_vec_1->push_back(nr_3_);
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 2: nr_2, nr_1, nr_0.
|
||||
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec_2->push_back(nr_2_);
|
||||
input_vec_2->push_back(nr_1_);
|
||||
input_vec_2->push_back(nr_0_);
|
||||
runner.MutableInputs()->Index(2).packets.push_back(
|
||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output.size());
|
||||
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||
|
||||
// Rectangles are added in the following sequence:
|
||||
// nr_5 is added 1st.
|
||||
// nr_4 is added because it does not overlap with nr_5.
|
||||
// nr_3 is NOT added because it overlaps with nr_5.
|
||||
// nr_2 is NOT added because it overlaps with nr_4.
|
||||
// nr_1 is NOT added because it overlaps with nr_5.
|
||||
// nr_0 is added because it does not overlap with nr_5 or nr_4.
|
||||
EXPECT_EQ(3, assoc_rects.size());
|
||||
|
||||
// Outputs are in same order as inputs, and IDs are filled in.
|
||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
||||
assoc_rects[0].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_));
|
||||
|
||||
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
|
||||
assoc_rects[1].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_));
|
||||
|
||||
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
|
||||
assoc_rects[2].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_));
|
||||
}
|
||||
|
||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "HandAssociationCalculator"
|
||||
input_stream: "input_vec"
|
||||
output_stream: "output_vec"
|
||||
options {
|
||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
// Just one input stream : nr_3, nr_5.
|
||||
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
|
||||
input_vec->push_back(nr_3_);
|
||||
input_vec->push_back(nr_5_);
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input_vec.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output.size());
|
||||
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||
|
||||
// Rectangles are added in the following sequence:
|
||||
// nr_3 is added 1st.
|
||||
// nr_5 is NOT added because it overlaps with nr_3.
|
||||
EXPECT_EQ(1, assoc_rects.size());
|
||||
|
||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
||||
assoc_rects[0].clear_rect_id();
|
||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -26,14 +26,14 @@ cc_library(
|
|||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -50,9 +50,9 @@ cc_library(
|
|||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
|
@ -61,7 +61,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
|
|
|
@ -26,9 +26,9 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
|
@ -36,11 +36,12 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -52,12 +53,11 @@ constexpr char kImageTag[] = "IMAGE";
|
|||
constexpr char kNormRectName[] = "norm_rect_in";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.ImageClassifierGraph";
|
||||
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using ImageClassifierOptionsProto =
|
||||
image_classifier::proto::ImageClassifierOptions;
|
||||
|
||||
// Builds a NormalizedRect covering the entire image.
|
||||
NormalizedRect BuildFullImageNormRect() {
|
||||
|
@ -70,17 +70,17 @@ NormalizedRect BuildFullImageNormRect() {
|
|||
}
|
||||
|
||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||
// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the
|
||||
// live stream mode, a "FlowLimiterCalculator" will be added to limit the number
|
||||
// of frames in flight.
|
||||
// type "ImageClassifierGraph". If the task is running in the live stream mode,
|
||||
// a "FlowLimiterCalculator" will be added to limit the number of frames in
|
||||
// flight.
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<ImageClassifierOptionsProto> options_proto,
|
||||
std::unique_ptr<proto::ImageClassifierGraphOptions> options_proto,
|
||||
bool enable_flow_limiting) {
|
||||
api2::builder::Graph graph;
|
||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectName);
|
||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
task_subgraph.GetOptions<ImageClassifierOptionsProto>().Swap(
|
||||
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
task_subgraph.Out(kClassificationResultTag)
|
||||
.SetName(kClassificationResultStreamName) >>
|
||||
|
@ -98,18 +98,18 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
}
|
||||
|
||||
// Converts the user-facing ImageClassifierOptions struct to the internal
|
||||
// ImageClassifierOptions proto.
|
||||
std::unique_ptr<ImageClassifierOptionsProto>
|
||||
// ImageClassifierGraphOptions proto.
|
||||
std::unique_ptr<proto::ImageClassifierGraphOptions>
|
||||
ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
|
||||
auto options_proto = std::make_unique<ImageClassifierOptionsProto>();
|
||||
auto options_proto = std::make_unique<proto::ImageClassifierGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||
options->running_mode != core::RunningMode::IMAGE);
|
||||
auto classifier_options_proto =
|
||||
std::make_unique<tasks::components::proto::ClassifierOptions>(
|
||||
components::ConvertClassifierOptionsToProto(
|
||||
std::make_unique<components::processors::proto::ClassifierOptions>(
|
||||
components::processors::ConvertClassifierOptionsToProto(
|
||||
&(options->classifier_options)));
|
||||
options_proto->mutable_classifier_options()->Swap(
|
||||
classifier_options_proto.get());
|
||||
|
@ -145,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
|
|||
};
|
||||
}
|
||||
return core::VisionTaskApiFactory::Create<ImageClassifier,
|
||||
ImageClassifierOptionsProto>(
|
||||
proto::ImageClassifierGraphOptions>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
|
@ -214,6 +214,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
|
|||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||
}
|
||||
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -23,8 +23,8 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
|
||||
// The options for configuring a Mediapipe image classifier task.
|
||||
struct ImageClassifierOptions {
|
||||
|
@ -50,12 +51,14 @@ struct ImageClassifierOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
components::ClassifierOptions classifier_options;
|
||||
components::processors::ClassifierOptions classifier_options;
|
||||
|
||||
// The user-defined result callback for processing live stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::LIVE_STREAM.
|
||||
std::function<void(absl::StatusOr<ClassificationResult>, const Image&, int64)>
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>,
|
||||
const Image&, int64)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
|
@ -112,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The image can be of any size with format RGB or RGBA.
|
||||
// TODO: describe exact preprocessing steps once
|
||||
// YUVToImageCalculator is integrated.
|
||||
absl::StatusOr<ClassificationResult> Classify(
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||
mediapipe::Image image,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
|
||||
|
@ -126,9 +129,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
// must be monotonically increasing.
|
||||
absl::StatusOr<ClassificationResult> ClassifyForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>
|
||||
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
|
||||
// Sends live image data to image classification, and the results will be
|
||||
// available via the "result_callback" provided in the ImageClassifierOptions.
|
||||
|
@ -161,6 +164,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -22,18 +22,19 @@ limitations under the License.
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -42,8 +43,7 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::GenericNode;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ImageClassifierOptionsProto =
|
||||
image_classifier::proto::ImageClassifierOptions;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
|
||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||
|
||||
|
@ -61,8 +61,7 @@ struct ImageClassifierOutputStreams {
|
|||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image
|
||||
// classification.
|
||||
// An "ImageClassifierGraph" performs image classification.
|
||||
// - Accepts CPU input images and outputs classifications on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -80,12 +79,12 @@ struct ImageClassifierOutputStreams {
|
|||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.ImageClassifierGraph"
|
||||
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
|
||||
// input_stream: "IMAGE:image_in"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||
// output_stream: "IMAGE:image_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext]
|
||||
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -104,13 +103,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<ImageClassifierOptionsProto>(sc));
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_resources,
|
||||
CreateModelResources<proto::ImageClassifierGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_streams,
|
||||
BuildImageClassificationTask(
|
||||
sc->Options<ImageClassifierOptionsProto>(), *model_resources,
|
||||
sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
output_streams.classification_result >>
|
||||
|
@ -125,13 +125,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
// (mediapipe::Image) as input and returns one classification result per input
|
||||
// image.
|
||||
//
|
||||
// task_options: the mediapipe tasks ImageClassifierOptions.
|
||||
// task_options: the mediapipe tasks ImageClassifierGraphOptions.
|
||||
// model_resources: the ModelSources object initialized from an image
|
||||
// classification model file with model metadata.
|
||||
// image_in: (mediapipe::Image) stream to run classification on.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask(
|
||||
const ImageClassifierOptionsProto& task_options,
|
||||
const proto::ImageClassifierGraphOptions& task_options,
|
||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
// Adds preprocessing calculators and connects them to the graph input image
|
||||
|
@ -153,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
|
||||
// Adds postprocessing calculators and connects them to the graph output.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing.GetOptions<
|
||||
tasks::components::ClassificationPostprocessingOptions>()));
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ClassificationPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Outputs the aggregated classification result as the subgraph output
|
||||
|
@ -168,8 +171,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph);
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph);
|
||||
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -32,8 +32,8 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -44,9 +44,13 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
|
@ -814,6 +818,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,12 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "image_classifier_options_proto",
|
||||
srcs = ["image_classifier_options.proto"],
|
||||
name = "image_classifier_graph_options_proto",
|
||||
srcs = ["image_classifier_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,12 +18,12 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.vision.image_classifier.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message ImageClassifierOptions {
|
||||
message ImageClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ImageClassifierOptions ext = 456383383;
|
||||
optional ImageClassifierGraphOptions ext = 456383383;
|
||||
}
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
||||
|
@ -31,5 +31,5 @@ message ImageClassifierOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||
}
|
|
@ -33,7 +33,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
@ -73,19 +72,4 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "image_segmenter_op_resolvers",
|
||||
srcs = ["image_segmenter_op_resolvers.cc"],
|
||||
hdrs = ["image_segmenter_op_resolvers.h"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||
"//mediapipe/util/tflite/operations:transform_landmarks",
|
||||
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
|
||||
"//mediapipe/util/tflite/operations:transpose_conv_bias",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: This test fails in OSS
|
||||
|
|
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
|
|
@ -31,7 +31,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
@ -260,8 +259,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
|
|||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
|
||||
options->base_options.op_resolver =
|
||||
absl::make_unique<SelfieSegmentationModelOpResolver>();
|
||||
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
|
||||
|
||||
|
@ -290,8 +287,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
|
|||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
|
||||
options->base_options.op_resolver =
|
||||
absl::make_unique<SelfieSegmentationModelOpResolver>();
|
||||
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
options->activation = ImageSegmenterOptions::Activation::NONE;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||
|
|
|
@ -11,3 +11,26 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
android_library(
|
||||
name = "category",
|
||||
srcs = ["Category.java"],
|
||||
deps = [
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "detection",
|
||||
srcs = ["Detection.java"],
|
||||
deps = [
|
||||
":category",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Category is a util class, contains a category name, its display name, a float value as score, and
|
||||
* the index of the label in the corresponding label file. Typically it's used as result of
|
||||
* classification or detection tasks.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class Category {
|
||||
private static final float TOLERANCE = 1e-6f;
|
||||
|
||||
/**
|
||||
* Creates a {@link Category} instance.
|
||||
*
|
||||
* @param score the probability score of this label category.
|
||||
* @param index the index of the label in the corresponding label file.
|
||||
* @param categoryName the label of this category object.
|
||||
* @param displayName the display name of the label.
|
||||
*/
|
||||
public static Category create(float score, int index, String categoryName, String displayName) {
|
||||
return new AutoValue_Category(score, index, categoryName, displayName);
|
||||
}
|
||||
|
||||
/** The probability score of this label category. */
|
||||
public abstract float score();
|
||||
|
||||
/** The index of the label in the corresponding label file. Returns -1 if the index is not set. */
|
||||
public abstract int index();
|
||||
|
||||
/** The label of this category object. */
|
||||
public abstract String categoryName();
|
||||
|
||||
/**
|
||||
* The display name of the label, which may be translated for different locales. For example, a
|
||||
* label, "apple", may be translated into Spanish for display purpose, so that the display name is
|
||||
* "manzana".
|
||||
*/
|
||||
public abstract String displayName();
|
||||
|
||||
@Override
|
||||
public final boolean equals(Object o) {
|
||||
if (!(o instanceof Category)) {
|
||||
return false;
|
||||
}
|
||||
Category other = (Category) o;
|
||||
return Math.abs(other.score() - this.score()) < TOLERANCE
|
||||
&& other.index() == this.index()
|
||||
&& other.categoryName().equals(this.categoryName())
|
||||
&& other.displayName().equals(this.displayName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(categoryName(), displayName(), score(), index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public final String toString() {
|
||||
return "<Category \""
|
||||
+ categoryName()
|
||||
+ "\" (displayName="
|
||||
+ displayName()
|
||||
+ " score="
|
||||
+ score()
|
||||
+ " index="
|
||||
+ index()
|
||||
+ ")>";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import android.graphics.RectF;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Represents one detected object in the results of {@link
|
||||
* com.google.mediapipe.tasks.version.objectdetector.ObjectDetector}.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class Detection {
|
||||
|
||||
/**
|
||||
* Creates a {@link Detection} instance from a list of {@link Category} and a bounding box.
|
||||
*
|
||||
* @param categories a list of {@link Category} objects that contain category name, display name,
|
||||
* score, and the label index.
|
||||
* @param boundingBox a {@link RectF} object to represent the bounding box.
|
||||
*/
|
||||
public static Detection create(List<Category> categories, RectF boundingBox) {
|
||||
|
||||
// As an open source project, we've been trying avoiding depending on common java libraries,
|
||||
// such as Guava, because it may introduce conflicts with clients who also happen to use those
|
||||
// libraries. Therefore, instead of using ImmutableList here, we convert the List into
|
||||
// unmodifiableList
|
||||
return new AutoValue_Detection(Collections.unmodifiableList(categories), boundingBox);
|
||||
}
|
||||
|
||||
/** A list of {@link Category} objects. */
|
||||
public abstract List<Category> categories();
|
||||
|
||||
/** A {@link RectF} object to represent the bounding box of the detected object. */
|
||||
public abstract RectF boundingBox();
|
||||
}
|
|
@ -11,3 +11,27 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
android_library(
|
||||
name = "core",
|
||||
srcs = glob(["*.java"]),
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
|
||||
"//mediapipe/framework:calculator_java_proto_lite",
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
"//third_party:autovalue",
|
||||
"@com_google_protobuf//:protobuf_javalite",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.util.Optional;
|
||||
|
||||
/** Options to configure MediaPipe Tasks in general. */
|
||||
@AutoValue
|
||||
public abstract class BaseOptions {
|
||||
/** Builder for {@link BaseOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/**
|
||||
* Sets the model path to a tflite model with metadata in the assets.
|
||||
*
|
||||
* <p>Note: when model path is set, both model file descriptor and model buffer should be empty.
|
||||
*/
|
||||
public abstract Builder setModelAssetPath(String value);
|
||||
|
||||
/**
|
||||
* Sets the native fd int of a tflite model with metadata.
|
||||
*
|
||||
* <p>Note: when model file descriptor is set, both model path and model buffer should be empty.
|
||||
*/
|
||||
public abstract Builder setModelAssetFileDescriptor(Integer value);
|
||||
|
||||
/**
|
||||
* Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model
|
||||
* with metadata.
|
||||
*
|
||||
* <p>Note: when model buffer is set, both model file and model file descriptor should be empty.
|
||||
*/
|
||||
public abstract Builder setModelAssetBuffer(ByteBuffer value);
|
||||
|
||||
/**
|
||||
* Sets device Delegate to run the MediaPipe pipeline. If the delegate is not set, default
|
||||
* delegate CPU is used.
|
||||
*/
|
||||
public abstract Builder setDelegate(Delegate delegate);
|
||||
|
||||
abstract BaseOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link BaseOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if {@link BaseOptions} is invalid, or the provided model
|
||||
* buffer is not a direct {@link ByteBuffer} or a {@link MappedByteBuffer}.
|
||||
*/
|
||||
public final BaseOptions build() {
|
||||
BaseOptions options = autoBuild();
|
||||
int modelAssetPathPresent = options.modelAssetPath().isPresent() ? 1 : 0;
|
||||
int modelAssetFileDescriptorPresent = options.modelAssetFileDescriptor().isPresent() ? 1 : 0;
|
||||
int modelAssetBufferPresent = options.modelAssetBuffer().isPresent() ? 1 : 0;
|
||||
|
||||
if (modelAssetPathPresent + modelAssetFileDescriptorPresent + modelAssetBufferPresent != 1) {
|
||||
throw new IllegalArgumentException(
|
||||
"Please specify only one of the model asset path, the model asset file descriptor, and"
|
||||
+ " the model asset buffer.");
|
||||
}
|
||||
if (options.modelAssetBuffer().isPresent()
|
||||
&& !(options.modelAssetBuffer().get().isDirect()
|
||||
|| options.modelAssetBuffer().get() instanceof MappedByteBuffer)) {
|
||||
throw new IllegalArgumentException(
|
||||
"The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract Optional<String> modelAssetPath();
|
||||
|
||||
abstract Optional<Integer> modelAssetFileDescriptor();
|
||||
|
||||
abstract Optional<ByteBuffer> modelAssetBuffer();
|
||||
|
||||
abstract Delegate delegate();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
/** MediaPipe Tasks delegate. */
|
||||
// TODO implement advanced delegate setting.
|
||||
public enum Delegate {
|
||||
CPU,
|
||||
GPU,
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
/** Interface for the customizable MediaPipe task error listener. */
|
||||
public interface ErrorListener {
|
||||
void onError(RuntimeException e);
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/** Facilitates creation and destruction of the native ModelResourcesCache. */
|
||||
class ModelResourcesCache {
|
||||
private final long nativeHandle;
|
||||
private final AtomicBoolean isHandleValid;
|
||||
|
||||
public ModelResourcesCache() {
|
||||
nativeHandle = nativeCreateModelResourcesCache();
|
||||
isHandleValid = new AtomicBoolean(true);
|
||||
}
|
||||
|
||||
public boolean isHandleValid() {
|
||||
return isHandleValid.get();
|
||||
}
|
||||
|
||||
public long getNativeHandle() {
|
||||
if (isHandleValid.get()) {
|
||||
return nativeHandle;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public void release() {
|
||||
if (isHandleValid.compareAndSet(true, false)) {
|
||||
nativeReleaseModelResourcesCache(nativeHandle);
|
||||
}
|
||||
}
|
||||
|
||||
private native long nativeCreateModelResourcesCache();
|
||||
|
||||
private native void nativeReleaseModelResourcesCache(long nativeHandle);
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import com.google.mediapipe.framework.GraphService;
|
||||
|
||||
/** Java wrapper for graph service of ModelResourcesCacheService. */
|
||||
class ModelResourcesCacheService implements GraphService<ModelResourcesCache> {
|
||||
public ModelResourcesCacheService() {}
|
||||
|
||||
@Override
|
||||
public void installServiceObject(long context, ModelResourcesCache object) {
|
||||
nativeInstallServiceObject(context, object.getNativeHandle());
|
||||
}
|
||||
|
||||
public native void nativeInstallServiceObject(long context, long object);
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import android.util.Log;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import java.util.List;
|
||||
|
||||
/** Base class for handling MediaPipe task graph outputs. */
|
||||
public class OutputHandler<OutputT extends TaskResult, InputT> {
|
||||
/**
|
||||
* Interface for converting MediaPipe graph output {@link Packet}s to task result object and task
|
||||
* input object.
|
||||
*/
|
||||
public interface OutputPacketConverter<OutputT extends TaskResult, InputT> {
|
||||
OutputT convertToTaskResult(List<Packet> packets);
|
||||
|
||||
InputT convertToTaskInput(List<Packet> packets);
|
||||
}
|
||||
|
||||
/** Interface for the customizable MediaPipe task result listener. */
|
||||
public interface ResultListener<OutputT extends TaskResult, InputT> {
|
||||
void run(OutputT result, InputT input);
|
||||
}
|
||||
|
||||
private static final String TAG = "OutputHandler";
|
||||
// A task-specific graph output packet converter that should be implemented per task.
|
||||
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
|
||||
// The user-defined task result listener.
|
||||
private ResultListener<OutputT, InputT> resultListener;
|
||||
// The user-defined error listener.
|
||||
protected ErrorListener errorListener;
|
||||
// The cached task result for non latency sensitive use cases.
|
||||
protected OutputT cachedTaskResult;
|
||||
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
|
||||
private boolean handleTimestampBoundChanges = false;
|
||||
|
||||
/**
|
||||
* Sets a callback to be invoked to convert a {@link Packet} list to a task result object and a
|
||||
* task input object.
|
||||
*
|
||||
* @param converter the task-specific {@link OutputPacketConverter} callback.
|
||||
*/
|
||||
public void setOutputPacketConverter(OutputPacketConverter<OutputT, InputT> converter) {
|
||||
this.outputPacketConverter = converter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a callback to be invoked when task result objects become available.
|
||||
*
|
||||
* @param listener the user-defined {@link ResultListener} callback.
|
||||
*/
|
||||
public void setResultListener(ResultListener<OutputT, InputT> listener) {
|
||||
this.resultListener = listener;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a callback to be invoked when exceptions are thrown from the task graph.
|
||||
*
|
||||
* @param listener The user-defined {@link ErrorListener} callback.
|
||||
*/
|
||||
public void setErrorListener(ErrorListener listener) {
|
||||
this.errorListener = listener;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets whether the output handler should react to the timestamp bound changes that are reprsented
|
||||
* as empty output {@link Packet}s.
|
||||
*
|
||||
* @param handleTimestampBoundChanges A boolean value.
|
||||
*/
|
||||
public void setHandleTimestampBoundChanges(boolean handleTimestampBoundChanges) {
|
||||
this.handleTimestampBoundChanges = handleTimestampBoundChanges;
|
||||
}
|
||||
|
||||
/** Returns true if the task graph is set to handle timestamp bound changes. */
|
||||
boolean handleTimestampBoundChanges() {
|
||||
return handleTimestampBoundChanges;
|
||||
}
|
||||
|
||||
/* Returns the cached task result object. */
|
||||
public OutputT retrieveCachedTaskResult() {
|
||||
OutputT taskResult = cachedTaskResult;
|
||||
cachedTaskResult = null;
|
||||
return taskResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles a list of output {@link Packet}s. Invoked when a packet list become available.
|
||||
*
|
||||
* @param packets A list of output {@link Packet}s.
|
||||
*/
|
||||
void run(List<Packet> packets) {
|
||||
OutputT taskResult = null;
|
||||
try {
|
||||
taskResult = outputPacketConverter.convertToTaskResult(packets);
|
||||
if (resultListener == null) {
|
||||
cachedTaskResult = taskResult;
|
||||
} else {
|
||||
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
|
||||
resultListener.run(taskResult, taskInput);
|
||||
}
|
||||
} catch (MediaPipeException e) {
|
||||
if (errorListener != null) {
|
||||
errorListener.onError(e);
|
||||
} else {
|
||||
Log.e(TAG, "Error occurs when getting MediaPipe vision task result. " + e);
|
||||
}
|
||||
} finally {
|
||||
for (Packet packet : packets) {
|
||||
if (packet != null) {
|
||||
packet.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig.Node;
|
||||
import com.google.mediapipe.proto.CalculatorProto.InputStreamInfo;
|
||||
import com.google.mediapipe.calculator.proto.FlowLimiterCalculatorProto.FlowLimiterCalculatorOptions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link TaskInfo} contains all needed informaton to initialize a MediaPipe Task {@link
|
||||
* com.google.mediapipe.framework.Graph}.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class TaskInfo<T extends TaskOptions> {
|
||||
/** Builder for {@link TaskInfo}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder<T extends TaskOptions> {
|
||||
/** Sets the MediaPipe task graph name. */
|
||||
public abstract Builder<T> setTaskGraphName(String value);
|
||||
|
||||
/** Sets a list of task graph input stream info {@link String}s in the form TAG:name. */
|
||||
public abstract Builder<T> setInputStreams(List<String> value);
|
||||
|
||||
/** Sets a list of task graph output stream info {@link String}s in the form TAG:name. */
|
||||
public abstract Builder<T> setOutputStreams(List<String> value);
|
||||
|
||||
/** Sets to true if the task requires a flow limiter. */
|
||||
public abstract Builder<T> setEnableFlowLimiting(Boolean value);
|
||||
|
||||
/**
|
||||
* Sets a task-specific options instance.
|
||||
*
|
||||
* @param value a task-specific options that is derived from {@link TaskOptions}.
|
||||
*/
|
||||
public abstract Builder<T> setTaskOptions(T value);
|
||||
|
||||
public abstract TaskInfo<T> autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link TaskInfo} instance. *
|
||||
*
|
||||
* @throws IllegalArgumentException if the required information such as task graph name, graph
|
||||
* input streams, and the graph output streams are empty.
|
||||
*/
|
||||
public final TaskInfo<T> build() {
|
||||
TaskInfo<T> taskInfo = autoBuild();
|
||||
if (taskInfo.taskGraphName().isEmpty()
|
||||
|| taskInfo.inputStreams().isEmpty()
|
||||
|| taskInfo.outputStreams().isEmpty()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Task graph's name, input streams, and output streams should be non-empty.");
|
||||
}
|
||||
return taskInfo;
|
||||
}
|
||||
}
|
||||
|
||||
abstract String taskGraphName();
|
||||
|
||||
abstract T taskOptions();
|
||||
|
||||
abstract List<String> inputStreams();
|
||||
|
||||
abstract List<String> outputStreams();
|
||||
|
||||
abstract Boolean enableFlowLimiting();
|
||||
|
||||
public static <T extends TaskOptions> Builder<T> builder() {
|
||||
return new AutoValue_TaskInfo.Builder<T>();
|
||||
}
|
||||
|
||||
/* Returns a list of the output stream names without the stream tags. */
|
||||
List<String> outputStreamNames() {
|
||||
List<String> streamNames = new ArrayList<>(outputStreams().size());
|
||||
for (String stream : outputStreams()) {
|
||||
streamNames.add(stream.substring(stream.lastIndexOf(':') + 1));
|
||||
}
|
||||
return streamNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a MediaPipe Task {@link CalculatorGraphConfig} protobuf message from the {@link
|
||||
* TaskInfo} instance.
|
||||
*/
|
||||
CalculatorGraphConfig generateGraphConfig() {
|
||||
CalculatorGraphConfig.Builder graphBuilder = CalculatorGraphConfig.newBuilder();
|
||||
Node.Builder taskSubgraphBuilder =
|
||||
Node.newBuilder()
|
||||
.setCalculator(taskGraphName())
|
||||
.setOptions(taskOptions().convertToCalculatorOptionsProto());
|
||||
for (String outputStream : outputStreams()) {
|
||||
taskSubgraphBuilder.addOutputStream(outputStream);
|
||||
graphBuilder.addOutputStream(outputStream);
|
||||
}
|
||||
if (!enableFlowLimiting()) {
|
||||
for (String inputStream : inputStreams()) {
|
||||
taskSubgraphBuilder.addInputStream(inputStream);
|
||||
graphBuilder.addInputStream(inputStream);
|
||||
}
|
||||
graphBuilder.addNode(taskSubgraphBuilder.build());
|
||||
return graphBuilder.build();
|
||||
}
|
||||
Node.Builder flowLimiterCalculatorBuilder =
|
||||
Node.newBuilder()
|
||||
.setCalculator("FlowLimiterCalculator")
|
||||
.addInputStreamInfo(
|
||||
InputStreamInfo.newBuilder().setTagIndex("FINISHED").setBackEdge(true).build())
|
||||
.setOptions(
|
||||
CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
FlowLimiterCalculatorOptions.ext,
|
||||
FlowLimiterCalculatorOptions.newBuilder()
|
||||
.setMaxInFlight(1)
|
||||
.setMaxInQueue(1)
|
||||
.build())
|
||||
.build());
|
||||
for (String inputStream : inputStreams()) {
|
||||
graphBuilder.addInputStream(inputStream);
|
||||
flowLimiterCalculatorBuilder.addInputStream(stripTagIndex(inputStream));
|
||||
String taskInputStream = addStreamNamePrefix(inputStream);
|
||||
flowLimiterCalculatorBuilder.addOutputStream(stripTagIndex(taskInputStream));
|
||||
taskSubgraphBuilder.addInputStream(taskInputStream);
|
||||
}
|
||||
flowLimiterCalculatorBuilder.addInputStream(
|
||||
"FINISHED:" + stripTagIndex(outputStreams().get(0)));
|
||||
graphBuilder.addNode(flowLimiterCalculatorBuilder.build());
|
||||
graphBuilder.addNode(taskSubgraphBuilder.build());
|
||||
return graphBuilder.build();
|
||||
}
|
||||
|
||||
private String stripTagIndex(String tagIndexName) {
|
||||
return tagIndexName.substring(tagIndexName.lastIndexOf(':') + 1);
|
||||
}
|
||||
|
||||
private String addStreamNamePrefix(String tagIndexName) {
|
||||
return tagIndexName.substring(0, tagIndexName.lastIndexOf(':') + 1)
|
||||
+ "throttled_"
|
||||
+ stripTagIndex(tagIndexName);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import com.google.mediapipe.calculator.proto.InferenceCalculatorProto;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.tasks.core.proto.AccelerationProto;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.core.proto.ExternalFileProto;
|
||||
import com.google.protobuf.ByteString;
|
||||
|
||||
/**
|
||||
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
|
||||
* {@link TaskOptions}.
|
||||
*/
|
||||
public abstract class TaskOptions {
|
||||
/**
|
||||
* Converts a MediaPipe Tasks task-specific options to a {@link CalculatorOptions} protobuf
|
||||
* message.
|
||||
*/
|
||||
public abstract CalculatorOptions convertToCalculatorOptionsProto();
|
||||
|
||||
/**
|
||||
* Converts a {@link BaseOptions} instance to a {@link BaseOptionsProto.BaseOptions} protobuf
|
||||
* message.
|
||||
*/
|
||||
protected BaseOptionsProto.BaseOptions convertBaseOptionsToProto(BaseOptions options) {
|
||||
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
|
||||
ExternalFileProto.ExternalFile.newBuilder();
|
||||
options.modelAssetPath().ifPresent(externalFileBuilder::setFileName);
|
||||
options
|
||||
.modelAssetFileDescriptor()
|
||||
.ifPresent(
|
||||
fd ->
|
||||
externalFileBuilder.setFileDescriptorMeta(
|
||||
ExternalFileProto.FileDescriptorMeta.newBuilder().setFd(fd).build()));
|
||||
options
|
||||
.modelAssetBuffer()
|
||||
.ifPresent(
|
||||
modelBuffer -> {
|
||||
modelBuffer.rewind();
|
||||
externalFileBuilder.setFileContent(ByteString.copyFrom(modelBuffer));
|
||||
});
|
||||
AccelerationProto.Acceleration.Builder accelerationBuilder =
|
||||
AccelerationProto.Acceleration.newBuilder();
|
||||
switch (options.delegate()) {
|
||||
case CPU:
|
||||
accelerationBuilder.setXnnpack(
|
||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
|
||||
.getDefaultInstance());
|
||||
break;
|
||||
case GPU:
|
||||
accelerationBuilder.setGpu(
|
||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.getDefaultInstance());
|
||||
break;
|
||||
}
|
||||
return BaseOptionsProto.BaseOptions.newBuilder()
|
||||
.setModelAsset(externalFileBuilder.build())
|
||||
.setAcceleration(accelerationBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
/**
|
||||
* Interface for the MediaPipe Task result. Any MediaPipe task-specific result class should
|
||||
* implement {@link TaskResult}.
|
||||
*/
|
||||
public interface TaskResult {
|
||||
/** Returns the timestamp that is associated with the task result object. */
|
||||
long timestampMs();
|
||||
}
|
|
@ -0,0 +1,265 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import com.google.mediapipe.framework.AndroidAssetUtil;
|
||||
import com.google.mediapipe.framework.AndroidPacketCreator;
|
||||
import com.google.mediapipe.framework.Graph;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/** The runner of MediaPipe task graphs. */
|
||||
public class TaskRunner implements AutoCloseable {
|
||||
private static final String TAG = TaskRunner.class.getSimpleName();
|
||||
private static final long TIMESATMP_UNITS_PER_SECOND = 1000000;
|
||||
|
||||
private final OutputHandler<? extends TaskResult, ?> outputHandler;
|
||||
private final AtomicBoolean graphStarted = new AtomicBoolean(false);
|
||||
private final Graph graph;
|
||||
private final ModelResourcesCache modelResourcesCache;
|
||||
private final AndroidPacketCreator packetCreator;
|
||||
private long lastSeenTimestamp = Long.MIN_VALUE;
|
||||
private ErrorListener errorListener;
|
||||
|
||||
/**
|
||||
* Create a {@link TaskRunner} instance.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param taskInfo a {@link TaskInfo} instance contains task graph name, task options, and graph
|
||||
* input and output stream names.
|
||||
* @param outputHandler a {@link OutputHandler} instance handles task result object and runtime
|
||||
* exception.
|
||||
* @throws MediaPipeException for any error during {@link TaskRunner} creation.
|
||||
*/
|
||||
public static TaskRunner create(
|
||||
Context context,
|
||||
TaskInfo<? extends TaskOptions> taskInfo,
|
||||
OutputHandler<? extends TaskResult, ?> outputHandler) {
|
||||
AndroidAssetUtil.initializeNativeAssetManager(context);
|
||||
Graph mediapipeGraph = new Graph();
|
||||
mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
|
||||
ModelResourcesCache graphModelResourcesCache = new ModelResourcesCache();
|
||||
mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
|
||||
mediapipeGraph.addMultiStreamCallback(
|
||||
taskInfo.outputStreamNames(),
|
||||
outputHandler::run,
|
||||
/*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges());
|
||||
mediapipeGraph.startRunningGraph();
|
||||
// Waits until all calculators are opened and the graph is fully started.
|
||||
mediapipeGraph.waitUntilGraphIdle();
|
||||
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a callback to be invoked when exceptions are thrown by the {@link TaskRunner} instance.
|
||||
*
|
||||
* @param listener an {@link ErrorListener} callback.
|
||||
*/
|
||||
public void setErrorListener(ErrorListener listener) {
|
||||
this.errorListener = listener;
|
||||
}
|
||||
|
||||
/** Returns the {@link AndroidPacketCreator} associated to the {@link TaskRunner} instance. */
|
||||
public AndroidPacketCreator getPacketCreator() {
|
||||
return packetCreator;
|
||||
}
|
||||
|
||||
/**
|
||||
* A synchronous method for processing batch data.
|
||||
*
|
||||
* <p>Note: This method is designed for processing batch data such as unrelated images and texts.
|
||||
* The call blocks the current thread until a failure status or a successful result is returned.
|
||||
* An internal timestamp will be assigend per invocation. This method is thread-safe and allows
|
||||
* clients to call it from different threads.
|
||||
*
|
||||
* @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs.
|
||||
*/
|
||||
public synchronized TaskResult process(Map<String, Packet> inputs) {
|
||||
addPackets(inputs, generateSyntheticTimestamp());
|
||||
graph.waitUntilGraphIdle();
|
||||
return outputHandler.retrieveCachedTaskResult();
|
||||
}
|
||||
|
||||
/**
|
||||
* A synchronous method for processing offline streaming data.
|
||||
*
|
||||
* <p>Note: This method is designed for processing offline streaming data such as the decoded
|
||||
* frames from a video file and an audio file. The call blocks the current thread until a failure
|
||||
* status or a successful result is returned. The caller must ensure that the input timestamp is
|
||||
* greater than the timestamps of previous invocations. This method is thread-unsafe and it is the
|
||||
* caller's responsibility to synchronize access to this method across multiple threads and to
|
||||
* ensure that the input packet timestamps are in order.
|
||||
*
|
||||
* @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs.
|
||||
* @param inputTimestamp the timestamp of the input packets.
|
||||
*/
|
||||
public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
|
||||
validateInputTimstamp(inputTimestamp);
|
||||
addPackets(inputs, inputTimestamp);
|
||||
graph.waitUntilGraphIdle();
|
||||
return outputHandler.retrieveCachedTaskResult();
|
||||
}
|
||||
|
||||
/**
|
||||
* An asynchronous method for handling live streaming data.
|
||||
*
|
||||
* <p>Note: This method that is designed for handling live streaming data such as live camera and
|
||||
* microphone data. A user-defined packets callback function must be provided in the constructor
|
||||
* to receive the output packets. The caller must ensure that the input packet timestamps are
|
||||
* monotonically increasing. This method is thread-unsafe and it is the caller's responsibility to
|
||||
* synchronize access to this method across multiple threads and to ensure that the input packet
|
||||
* timestamps are in order.
|
||||
*
|
||||
* @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs.
|
||||
* @param inputTimestamp the timestamp of the input packets.
|
||||
*/
|
||||
public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
|
||||
validateInputTimstamp(inputTimestamp);
|
||||
addPackets(inputs, inputTimestamp);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets and restarts the {@link TaskRunner} instance. This can be useful for resetting a
|
||||
* stateful task graph to process new data.
|
||||
*/
|
||||
public void restart() {
|
||||
if (graphStarted.get()) {
|
||||
try {
|
||||
graphStarted.set(false);
|
||||
graph.closeAllPacketSources();
|
||||
graph.waitUntilGraphDone();
|
||||
} catch (MediaPipeException e) {
|
||||
reportError(e);
|
||||
}
|
||||
}
|
||||
try {
|
||||
graph.startRunningGraph();
|
||||
// Waits until all calculators are opened and the graph is fully restarted.
|
||||
graph.waitUntilGraphIdle();
|
||||
graphStarted.set(true);
|
||||
} catch (MediaPipeException e) {
|
||||
reportError(e);
|
||||
}
|
||||
}
|
||||
|
||||
/** Closes and cleans up the {@link TaskRunner} instance. */
|
||||
@Override
|
||||
public void close() {
|
||||
if (!graphStarted.get()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
graphStarted.set(false);
|
||||
graph.closeAllPacketSources();
|
||||
graph.waitUntilGraphDone();
|
||||
if (modelResourcesCache != null) {
|
||||
modelResourcesCache.release();
|
||||
}
|
||||
} catch (MediaPipeException e) {
|
||||
// Note: errors during Process are reported at the earliest opportunity,
|
||||
// which may be addPacket or waitUntilDone, depending on timing. For consistency,
|
||||
// we want to always report them using the same async handler if installed.
|
||||
reportError(e);
|
||||
}
|
||||
try {
|
||||
graph.tearDown();
|
||||
} catch (MediaPipeException e) {
|
||||
reportError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
|
||||
if (!graphStarted.get()) {
|
||||
reportError(
|
||||
new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"The task graph hasn't been successfully started or error occurs during graph"
|
||||
+ " initializaton."));
|
||||
}
|
||||
try {
|
||||
for (Map.Entry<String, Packet> entry : inputs.entrySet()) {
|
||||
// addConsumablePacketToInputStream allows the graph to take exclusive ownership of the
|
||||
// packet, which may allow for more memory optimizations.
|
||||
graph.addConsumablePacketToInputStream(entry.getKey(), entry.getValue(), inputTimestamp);
|
||||
// If addConsumablePacket succeeded, we don't need to release the packet ourselves.
|
||||
entry.setValue(null);
|
||||
}
|
||||
} catch (MediaPipeException e) {
|
||||
// TODO: do not suppress exceptions here!
|
||||
if (errorListener == null) {
|
||||
Log.e(TAG, "Mediapipe error: ", e);
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
} finally {
|
||||
for (Packet packet : inputs.values()) {
|
||||
// In case of error, addConsumablePacketToInputStream will not release the packet, so we
|
||||
// have to release it ourselves.
|
||||
if (packet != null) {
|
||||
packet.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the input timestamp is strictly greater than the last timestamp that has been
|
||||
* processed.
|
||||
*
|
||||
* @param inputTimestamp the input timestamp.
|
||||
*/
|
||||
private void validateInputTimstamp(long inputTimestamp) {
|
||||
if (lastSeenTimestamp >= inputTimestamp) {
|
||||
reportError(
|
||||
new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"The received packets having a smaller timestamp than the processed timestamp."));
|
||||
}
|
||||
lastSeenTimestamp = inputTimestamp;
|
||||
}
|
||||
|
||||
/** Generates a synthetic input timestamp in the batch processing mode. */
|
||||
private long generateSyntheticTimestamp() {
|
||||
long timestamp =
|
||||
lastSeenTimestamp == Long.MIN_VALUE ? 0 : lastSeenTimestamp + TIMESATMP_UNITS_PER_SECOND;
|
||||
lastSeenTimestamp = timestamp;
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
/** Private constructor. */
|
||||
private TaskRunner(
|
||||
Graph graph,
|
||||
ModelResourcesCache modelResourcesCache,
|
||||
OutputHandler<? extends TaskResult, ?> outputHandler) {
|
||||
this.outputHandler = outputHandler;
|
||||
this.graph = graph;
|
||||
this.modelResourcesCache = modelResourcesCache;
|
||||
this.packetCreator = new AndroidPacketCreator(graph);
|
||||
graphStarted.set(true);
|
||||
}
|
||||
|
||||
/** Reports error. */
|
||||
private void reportError(MediaPipeException e) {
|
||||
if (errorListener != null) {
|
||||
errorListener.onError(e);
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "model_resources_cache_jni",
|
||||
srcs = [
|
||||
"model_resources_cache_jni.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"model_resources_cache_jni.h",
|
||||
],
|
||||
tflite_deps = [
|
||||
"//mediapipe/tasks/cc/core:model_resources_cache",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
||||
] + select({
|
||||
"//conditions:default": ["//third_party/java/jdk:jni"],
|
||||
"//mediapipe:android": [],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "model_resources_cache_jni",
|
||||
srcs = [
|
||||
"model_resources_cache_jni.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"model_resources_cache_jni.h",
|
||||
] + select({
|
||||
# The Android toolchain makes "jni.h" available in the include path.
|
||||
# For non-Android toolchains, generate jni.h and jni_md.h.
|
||||
"//mediapipe:android": [],
|
||||
"//conditions:default": [
|
||||
":jni.h",
|
||||
":jni_md.h",
|
||||
],
|
||||
}),
|
||||
tflite_deps = [
|
||||
"//mediapipe/tasks/cc/core:model_resources_cache",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
||||
] + select({
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:android": [],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Silly rules to make
|
||||
# #include <jni.h>
|
||||
# in the source headers work
|
||||
# (in combination with the "includes" attribute of the tf_cuda_library rule
|
||||
# above. Not needed when using the Android toolchain).
|
||||
#
|
||||
# Inspired from:
|
||||
# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
|
||||
# but hopefully there is a simpler alternative to this.
|
||||
genrule(
|
||||
name = "copy_jni_h",
|
||||
srcs = ["@bazel_tools//tools/jdk:jni_header"],
|
||||
outs = ["jni.h"],
|
||||
cmd = "cp -f $< $@",
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "copy_jni_md_h",
|
||||
srcs = select({
|
||||
"//mediapipe:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
|
||||
"//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
|
||||
}),
|
||||
outs = ["jni_md.h"],
|
||||
cmd = "cp -f $< $@",
|
||||
)
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::tasks::core::kModelResourcesCacheService;
|
||||
using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver;
|
||||
using ::mediapipe::tasks::core::ModelResourcesCache;
|
||||
using HandleType = std::shared_ptr<ModelResourcesCache>*;
|
||||
} // namespace
|
||||
|
||||
JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
|
||||
nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) {
|
||||
auto ptr = std::make_shared<ModelResourcesCache>(
|
||||
absl::make_unique<MediaPipeBuiltinOpResolver>());
|
||||
HandleType handle = new std::shared_ptr<ModelResourcesCache>(std::move(ptr));
|
||||
return reinterpret_cast<jlong>(handle);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
|
||||
nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
|
||||
jlong nativeHandle) {
|
||||
delete reinterpret_cast<HandleType>(nativeHandle);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
|
||||
nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
|
||||
jlong objectHandle) {
|
||||
mediapipe::android::GraphServiceHelper::SetServiceObject(
|
||||
contextHandle, kModelResourcesCacheService,
|
||||
*reinterpret_cast<HandleType>(objectHandle));
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
|
||||
#define JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
#define MODEL_RESOURCES_CACHE_METHOD(METHOD_NAME) \
|
||||
Java_com_google_mediapipe_tasks_core_ModelResourcesCache_##METHOD_NAME
|
||||
|
||||
#define MODEL_RESOURCES_CACHE_SERVICE_METHOD(METHOD_NAME) \
|
||||
Java_com_google_mediapipe_tasks_core_ModelResourcesCacheService_##METHOD_NAME
|
||||
|
||||
JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
|
||||
nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz);
|
||||
|
||||
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
|
||||
nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
|
||||
jlong nativeHandle);
|
||||
|
||||
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
|
||||
nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
|
||||
jlong objectHandle);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
|
|
@ -11,3 +11,38 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
android_library(
|
||||
name = "core",
|
||||
srcs = glob(["*.java"]),
|
||||
deps = [
|
||||
":libmediapipe_tasks_vision_jni_lib",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
# The native library of all MediaPipe vision tasks.
|
||||
cc_binary(
|
||||
name = "libmediapipe_tasks_vision_jni.so",
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "libmediapipe_tasks_vision_jni_lib",
|
||||
srcs = [":libmediapipe_tasks_vision_jni.so"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.core;
|
||||
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.image.Image;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** The base class of MediaPipe vision tasks. */
|
||||
public class BaseVisionTaskApi implements AutoCloseable {
|
||||
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
|
||||
private final TaskRunner runner;
|
||||
private final RunningMode runningMode;
|
||||
|
||||
static {
|
||||
System.loadLibrary("mediapipe_tasks_vision_jni");
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize an {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision
|
||||
* task {@link RunningMode}.
|
||||
*
|
||||
* @param runner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) {
|
||||
this.runner = runner;
|
||||
this.runningMode = runningMode;
|
||||
}
|
||||
|
||||
/**
|
||||
* A synchronous method to process single image inputs. The call blocks the current thread until a
|
||||
* failure status or a successful result is returned.
|
||||
*
|
||||
* @param imageStreamName the image input stream name.
|
||||
* @param image a MediaPipe {@link Image} object for processing.
|
||||
* @throws MediaPipeException if the task is not in the image mode.
|
||||
*/
|
||||
protected TaskResult processImageData(String imageStreamName, Image image) {
|
||||
if (runningMode != RunningMode.IMAGE) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the image mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||
return runner.process(inputPackets);
|
||||
}
|
||||
|
||||
/**
|
||||
* A synchronous method to process continuous video frames. The call blocks the current thread
|
||||
* until a failure status or a successful result is returned.
|
||||
*
|
||||
* @param imageStreamName the image input stream name.
|
||||
* @param image a MediaPipe {@link Image} object for processing.
|
||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||
* @throws MediaPipeException if the task is not in the video mode.
|
||||
*/
|
||||
protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) {
|
||||
if (runningMode != RunningMode.VIDEO) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the video mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
/**
|
||||
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
|
||||
* available in the user-defined result listener.
|
||||
*
|
||||
* @param imageStreamName the image input stream name.
|
||||
* @param image a MediaPipe {@link Image} object for processing.
|
||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||
* @throws MediaPipeException if the task is not in the video mode.
|
||||
*/
|
||||
protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) {
|
||||
if (runningMode != RunningMode.LIVE_STREAM) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the live stream mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
/** Closes and cleans up the MediaPipe vision task. */
|
||||
@Override
|
||||
public void close() {
|
||||
runner.close();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.core;
|
||||
|
||||
/**
|
||||
* MediaPipe vision task running mode. A MediaPipe vision task can be run with three different
|
||||
* modes:
|
||||
*
|
||||
* <ul>
|
||||
* <li>IMAGE: The mode for running a mediapipe vision task on single image inputs.
|
||||
* <li>VIDEO: The mode for running a mediapipe vision task on the decoded frames of a video.
|
||||
* <li>LIVE_STREAM: The mode for running a mediapipe vision task on a live stream of input data,
|
||||
* such as from camera.
|
||||
* </ul>
|
||||
*/
|
||||
public enum RunningMode {
|
||||
IMAGE,
|
||||
VIDEO,
|
||||
LIVE_STREAM
|
||||
}
|
|
@ -11,3 +11,34 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
android_library(
|
||||
name = "objectdetector",
|
||||
srcs = [
|
||||
"ObjectDetectionResult.java",
|
||||
"ObjectDetector.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
manifest = ":AndroidManifest.xml",
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/framework/formats:detection_java_proto_lite",
|
||||
"//mediapipe/framework/formats:location_data_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||
|
||||
import android.graphics.RectF;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the detection results generated by {@link ObjectDetector}. */
|
||||
@AutoValue
|
||||
public abstract class ObjectDetectionResult implements TaskResult {
|
||||
private static final int DEFAULT_CATEGORY_INDEX = -1;
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
|
||||
public abstract List<com.google.mediapipe.tasks.components.containers.Detection> detections();
|
||||
|
||||
/**
|
||||
* Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
|
||||
* messages.
|
||||
*
|
||||
* @param detectionList a list of {@link Detection} protobuf messages.
|
||||
*/
|
||||
static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
|
||||
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
|
||||
for (Detection detectionProto : detectionList) {
|
||||
List<Category> categories = new ArrayList<>();
|
||||
for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
|
||||
categories.add(
|
||||
Category.create(
|
||||
detectionProto.getScore(idx),
|
||||
detectionProto.getLabelIdCount() > idx
|
||||
? detectionProto.getLabelId(idx)
|
||||
: DEFAULT_CATEGORY_INDEX,
|
||||
detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
|
||||
detectionProto.getDisplayNameCount() > idx
|
||||
? detectionProto.getDisplayName(idx)
|
||||
: ""));
|
||||
}
|
||||
RectF boundingBox = new RectF();
|
||||
if (detectionProto.getLocationData().hasBoundingBox()) {
|
||||
BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
|
||||
boundingBox.set(
|
||||
/*left=*/ boundingBoxProto.getXmin(),
|
||||
/*top=*/ boundingBoxProto.getYmin(),
|
||||
/*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
|
||||
/*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
|
||||
}
|
||||
detections.add(
|
||||
com.google.mediapipe.tasks.components.containers.Detection.create(
|
||||
categories, boundingBox));
|
||||
}
|
||||
return new AutoValue_ObjectDetectionResult(
|
||||
timestampMs, Collections.unmodifiableList(detections));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,418 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.ParcelFileDescriptor;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.Image;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs object detection on images.
|
||||
*
|
||||
* <p>The API expects a TFLite model with <a
|
||||
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
|
||||
*
|
||||
* <p>The API supports models with one image input tensor and four output tensors. To be more
|
||||
* specific, here are the requirements.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
|
||||
* <ul>
|
||||
* <li>image input of size {@code [batch x height x width x channels]}.
|
||||
* <li>batch inference is not supported ({@code batch} is required to be 1).
|
||||
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
|
||||
* <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
|
||||
* to the metadata for input normalization.
|
||||
* </ul>
|
||||
* <li>Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e:
|
||||
* <ul>
|
||||
* <li>Location tensor ({@code kTfLiteFloat32}):
|
||||
* <ul>
|
||||
* <li>tensor of size {@code [1 x num_results x 4]}, the inner array representing
|
||||
* bounding boxes in the form [top, left, right, bottom].
|
||||
* <li>{@code BoundingBoxProperties} are required to be attached to the metadata and
|
||||
* must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}.
|
||||
* </ul>
|
||||
* <li>Classes tensor ({@code kTfLiteFloat32}):
|
||||
* <ul>
|
||||
* <li>tensor of size {@code [1 x num_results]}, each value representing the integer
|
||||
* index of a class.
|
||||
* <li>if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS}
|
||||
* associated files, they are used to convert the tensor values into labels.
|
||||
* </ul>
|
||||
* <li>scores tensor ({@code kTfLiteFloat32}):
|
||||
* <ul>
|
||||
* <li>tensor of size {@code [1 x num_results]}, each value representing the score of
|
||||
* the detected object.
|
||||
* </ul>
|
||||
* <li>Number of detection tensor ({@code kTfLiteFloat32}):
|
||||
* <ul>
|
||||
* <li>integer num_results as a tensor of size {@code [1]}.
|
||||
* </ul>
|
||||
* </ul>
|
||||
* </ul>
|
||||
*
|
||||
* <p>An example of such model can be found on <a
|
||||
* href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow
|
||||
* Hub.</a>.
|
||||
*/
|
||||
public final class ObjectDetector extends BaseVisionTaskApi {
|
||||
private static final String TAG = ObjectDetector.class.getSimpleName();
|
||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
|
||||
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
|
||||
|
||||
/**
|
||||
* Creates an {@link ObjectDetector} instance from a model file and the default {@link
|
||||
* ObjectDetectorOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelPath path to the detection model with metadata in the assets.
|
||||
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
|
||||
*/
|
||||
public static ObjectDetector createFromFile(Context context, String modelPath) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||
return createFromOptions(
|
||||
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ObjectDetector} instance from a model file and the default {@link
|
||||
* ObjectDetectorOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelFile the detection model {@link File} instance.
|
||||
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
|
||||
*/
|
||||
public static ObjectDetector createFromFile(Context context, File modelFile) throws IOException {
|
||||
try (ParcelFileDescriptor descriptor =
|
||||
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||
BaseOptions baseOptions =
|
||||
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||
return createFromOptions(
|
||||
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ObjectDetector} instance from a model buffer and the default {@link
|
||||
* ObjectDetectorOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
|
||||
* model.
|
||||
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
|
||||
*/
|
||||
public static ObjectDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
|
||||
return createFromOptions(
|
||||
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param detectorOptions a {@link ObjectDetectorOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
|
||||
*/
|
||||
public static ObjectDetector createFromOptions(
|
||||
Context context, ObjectDetectorOptions detectorOptions) {
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ObjectDetectionResult, Image> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, Image>() {
|
||||
@Override
|
||||
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
|
||||
return ObjectDetectionResult.create(
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
|
||||
packets.get(DETECTIONS_OUT_STREAM_INDEX).getTimestamp());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
detectorOptions.resultListener().ifPresent(handler::setResultListener);
|
||||
detectorOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ObjectDetectorOptions>builder()
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setTaskOptions(detectorOptions)
|
||||
.setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
detectorOptions.errorListener().ifPresent(runner::setErrorListener);
|
||||
return new ObjectDetector(runner, detectorOptions.runningMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize an {@link ObjectDetector} from a {@link TaskRunner} and a {@link
|
||||
* RunningMode}.
|
||||
*
|
||||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
super(taskRunner, runningMode);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs object detection on the provided single image. Only use this method when the {@link
|
||||
* ObjectDetector} is created with {@link RunningMode.IMAGE}.
|
||||
*
|
||||
* <p>{@link ObjectDetector} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ObjectDetectionResult detect(Image inputImage) {
|
||||
return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs object detection on the provided video frame. Only use this method when the {@link
|
||||
* ObjectDetector} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ObjectDetector} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
|
||||
return (ObjectDetectionResult)
|
||||
processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform object detection, and the results will be available via the
|
||||
* {@link ResultListener} provided in the {@link ObjectDetectorOptions}. Only use this method when
|
||||
* the {@link ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the object detector. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ObjectDetector} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void detectAsync(Image inputImage, long inputTimestampMs) {
|
||||
sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
|
||||
}
|
||||
|
||||
/** Options for setting up an {@link ObjectDetector}. */
|
||||
@AutoValue
|
||||
public abstract static class ObjectDetectorOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link ObjectDetectorOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the base options for the object detector task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
* Sets the running mode for the object detector task. Default to the image mode. Object
|
||||
* detector has three modes:
|
||||
*
|
||||
* <ul>
|
||||
* <li>IMAGE: The mode for detecting objects on single image inputs.
|
||||
* <li>VIDEO: The mode for detecting objects on the decoded frames of a video.
|
||||
* <li>LIVE_STREAM: The mode for for detecting objects on a live stream of input data, such
|
||||
* as from camera. In this mode, {@code setResultListener} must be called to set up a
|
||||
* listener to receive the detection results asynchronously.
|
||||
* </ul>
|
||||
*/
|
||||
public abstract Builder setRunningMode(RunningMode value);
|
||||
|
||||
/**
|
||||
* Sets the locale to use for display names specified through the TFLite Model Metadata, if
|
||||
* any. Defaults to English.
|
||||
*/
|
||||
public abstract Builder setDisplayNamesLocale(String value);
|
||||
|
||||
/**
|
||||
* Sets the optional maximum number of top-scored detection results to return.
|
||||
*
|
||||
* <p>Overrides the ones provided in the model metadata. Results below this value are
|
||||
* rejected.
|
||||
*/
|
||||
public abstract Builder setMaxResults(Integer value);
|
||||
|
||||
/**
|
||||
* Sets the optional score threshold that overrides the one provided in the model metadata (if
|
||||
* any). Results below this value are rejected.
|
||||
*/
|
||||
public abstract Builder setScoreThreshold(Float value);
|
||||
|
||||
/**
|
||||
* Sets the optional allowlist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryDenylist}.
|
||||
*/
|
||||
public abstract Builder setCategoryAllowlist(List<String> value);
|
||||
|
||||
/**
|
||||
* Sets the optional denylist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryAllowlist}.
|
||||
*/
|
||||
public abstract Builder setCategoryDenylist(List<String> value);
|
||||
|
||||
/**
|
||||
* Sets the result listener to receive the detection results asynchronously when the object
|
||||
* detector is in the live stream mode.
|
||||
*/
|
||||
public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);
|
||||
|
||||
/** Sets an optional error listener. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
||||
abstract ObjectDetectorOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link ObjectDetectorOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the object detector is
|
||||
* in the live stream mode.
|
||||
*/
|
||||
public final ObjectDetectorOptions build() {
|
||||
ObjectDetectorOptions options = autoBuild();
|
||||
if (options.runningMode() == RunningMode.LIVE_STREAM) {
|
||||
if (!options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The object detector is in the live stream mode, a user-defined result listener"
|
||||
+ " must be provided in ObjectDetectorOptions.");
|
||||
}
|
||||
} else if (options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The object detector is in the image or the video mode, a user-defined result"
|
||||
+ " listener shouldn't be provided in ObjectDetectorOptions.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<String> displayNamesLocale();
|
||||
|
||||
abstract Optional<Integer> maxResults();
|
||||
|
||||
abstract Optional<Float> scoreThreshold();
|
||||
|
||||
abstract List<String> categoryAllowlist();
|
||||
|
||||
abstract List<String> categoryDenylist();
|
||||
|
||||
abstract Optional<ResultListener<ObjectDetectionResult, Image>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ObjectDetector_ObjectDetectorOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setCategoryAllowlist(Collections.emptyList())
|
||||
.setCategoryDenylist(Collections.emptyList());
|
||||
}
|
||||
|
||||
/** Converts a {@link ObjectDetectorOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
ObjectDetectorOptionsProto.ObjectDetectorOptions.Builder taskOptionsBuilder =
|
||||
ObjectDetectorOptionsProto.ObjectDetectorOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
displayNamesLocale().ifPresent(taskOptionsBuilder::setDisplayNamesLocale);
|
||||
maxResults().ifPresent(taskOptionsBuilder::setMaxResults);
|
||||
scoreThreshold().ifPresent(taskOptionsBuilder::setScoreThreshold);
|
||||
if (!categoryAllowlist().isEmpty()) {
|
||||
taskOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||
}
|
||||
if (!categoryDenylist().isEmpty()) {
|
||||
taskOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||
}
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
ObjectDetectorOptionsProto.ObjectDetectorOptions.ext, taskOptionsBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,3 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
android_library(
|
||||
name = "test_utils",
|
||||
srcs = ["TestUtils.java"],
|
||||
deps = [
|
||||
"//third_party/java/android_libs/guava_jdk5:io",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core;
|
||||
|
||||
import android.content.Context;
|
||||
import android.content.res.AssetManager;
|
||||
import com.google.common.io.ByteStreams;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
|
||||
/** Helper class for the Java test in MediaPipe Tasks. */
|
||||
public final class TestUtils {
|
||||
|
||||
/**
|
||||
* Loads the file and create a {@link File} object by reading a file from the asset directory.
|
||||
* Simulates downloading or reading a file that's not precompiled with the app.
|
||||
*
|
||||
* @return a {@link File} object for the model.
|
||||
*/
|
||||
public static File loadFile(Context context, String fileName) {
|
||||
File target = new File(context.getFilesDir(), fileName);
|
||||
try (InputStream is = context.getAssets().open(fileName);
|
||||
FileOutputStream os = new FileOutputStream(target)) {
|
||||
ByteStreams.copy(is, os);
|
||||
} catch (IOException e) {
|
||||
throw new AssertionError("Failed to load model file at " + fileName, e);
|
||||
}
|
||||
return target;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads a file into a direct {@link ByteBuffer} object from the asset directory.
|
||||
*
|
||||
* @return a {@link ByteBuffer} object for the file.
|
||||
*/
|
||||
public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName)
|
||||
throws IOException {
|
||||
AssetManager assetManager = context.getAssets();
|
||||
InputStream inputStream = assetManager.open(fileName);
|
||||
byte[] bytes = ByteStreams.toByteArray(inputStream);
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
|
||||
buffer.put(bytes);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads a file into a non-direct {@link ByteBuffer} object from the asset directory.
|
||||
*
|
||||
* @return a {@link ByteBuffer} object for the file.
|
||||
*/
|
||||
public static ByteBuffer loadToNonDirectByteBuffer(Context context, String fileName)
|
||||
throws IOException {
|
||||
AssetManager assetManager = context.getAssets();
|
||||
InputStream inputStream = assetManager.open(fileName);
|
||||
byte[] bytes = ByteStreams.toByteArray(inputStream);
|
||||
return ByteBuffer.wrap(bytes);
|
||||
}
|
||||
|
||||
public enum ByteBufferType {
|
||||
DIRECT,
|
||||
BACK_UP_ARRAY,
|
||||
OTHER // Non-direct ByteBuffer without a back-up array.
|
||||
}
|
||||
|
||||
private TestUtils() {}
|
||||
}
|
|
@ -12,4 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable this in OSS
|
||||
|
|
|
@ -0,0 +1,456 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.graphics.BitmapFactory;
|
||||
import android.graphics.RectF;
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.Image;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.containers.Detection;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.TestUtils;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.junit.runners.Suite.SuiteClasses;
|
||||
|
||||
/** Test for {@link ObjectDetector}. */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class})
|
||||
public class ObjectDetectorTest {
|
||||
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
|
||||
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
|
||||
private static final int IMAGE_WIDTH = 1200;
|
||||
private static final int IMAGE_HEIGHT = 600;
|
||||
private static final float CAT_SCORE = 0.69f;
|
||||
private static final RectF catBoundingBox = new RectF(611, 164, 986, 596);
|
||||
// TODO: Figure out why android_x86 and android_arm tests have slightly different
|
||||
// scores (0.6875 vs 0.69921875).
|
||||
private static final float SCORE_DIFF_TOLERANCE = 0.01f;
|
||||
private static final float PIXEL_DIFF_TOLERANCE = 5.0f;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends ObjectDetectorTest {
|
||||
|
||||
@Test
|
||||
public void detect_successWithValidModels() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithNoOptions() throws Exception {
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Check if the object with the highest score is cat.
|
||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_succeedsWithMaxResultsOption() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setMaxResults(8)
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// results should have 8 detected objects because maxResults was set to 8.
|
||||
assertThat(results.detections()).hasSize(8);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_succeedsWithScoreThresholdOption() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setScoreThreshold(0.68f)
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// The score threshold should block all other other objects, except cat.
|
||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_succeedsWithAllowListOption() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setCategoryAllowlist(Arrays.asList("cat"))
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Because of the allowlist, results should only contain cat, and there are 6 detected
|
||||
// bounding boxes of cats in CAT_AND_DOG_IMAGE.
|
||||
assertThat(results.detections()).hasSize(5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_succeedsWithDenyListOption() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setCategoryDenylist(Arrays.asList("cat"))
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Because of the denylist, the highest result is not cat anymore.
|
||||
assertThat(results.detections().get(0).categories().get(0).categoryName())
|
||||
.isNotEqualTo("cat");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_succeedsWithModelFileObject() throws Exception {
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(),
|
||||
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Check if the object with the highest score is cat.
|
||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_succeedsWithModelBuffer() throws Exception {
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromBuffer(
|
||||
ApplicationProvider.getApplicationContext(),
|
||||
TestUtils.loadToDirectByteBuffer(
|
||||
ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Check if the object with the highest score is cat.
|
||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_succeedsWithModelBufferAndOptions() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetBuffer(
|
||||
TestUtils.loadToDirectByteBuffer(
|
||||
ApplicationProvider.getApplicationContext(), MODEL_FILE))
|
||||
.build())
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingModel() throws Exception {
|
||||
String nonexistentFile = "/path/to/non/existent/file";
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
ObjectDetector.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), nonexistentFile));
|
||||
assertThat(exception).hasMessageThat().contains(nonexistentFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithInvalidModelBuffer() throws Exception {
|
||||
// Create a non-direct model ByteBuffer.
|
||||
ByteBuffer modelBuffer =
|
||||
TestUtils.loadToNonDirectByteBuffer(
|
||||
ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ObjectDetector.createFromBuffer(
|
||||
ApplicationProvider.getApplicationContext(), modelBuffer));
|
||||
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithBothAllowAndDenyListOption() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setCategoryAllowlist(Arrays.asList("cat"))
|
||||
.setCategoryDenylist(Arrays.asList("dog"))
|
||||
.build();
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
ObjectDetector.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
|
||||
}
|
||||
|
||||
// TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
|
||||
// detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
|
||||
// detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class RunningModeTest extends ObjectDetectorTest {
|
||||
|
||||
@Test
|
||||
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
|
||||
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(mode)
|
||||
.setResultListener((objectDetectionResult, inputImage) -> {})
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener shouldn't be provided");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener must be provided");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((objectDetectionResult, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithImageMode() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithVideoMode() throws Exception {
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
ObjectDetectionResult results =
|
||||
objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i);
|
||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
|
||||
Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(objectDetectionResult, inputImage) -> {
|
||||
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
try (ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
objectDetector.detectAsync(image, 1);
|
||||
MediaPipeException exception =
|
||||
assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("having a smaller timestamp than the processed timestamp");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithLiveSteamMode() throws Exception {
|
||||
Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(objectDetectionResult, inputImage) -> {
|
||||
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
try (ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
objectDetector.detectAsync(image, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static Image getImageFromAsset(String filePath) throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||
}
|
||||
|
||||
// Checks if results has one and only detection result, which is a cat.
|
||||
private static void assertContainsOnlyCat(
|
||||
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) {
|
||||
assertThat(result.detections()).hasSize(1);
|
||||
Detection catResult = result.detections().get(0);
|
||||
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);
|
||||
// We only support one category for each detected object at this point.
|
||||
assertIsCat(catResult.categories().get(0), expectedScore);
|
||||
}
|
||||
|
||||
private static void assertIsCat(Category category, float expectedScore) {
|
||||
assertThat(category.categoryName()).isEqualTo("cat");
|
||||
// coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite does not support label locale.
|
||||
assertThat(category.displayName()).isEmpty();
|
||||
assertThat((double) category.score()).isWithin(SCORE_DIFF_TOLERANCE).of(expectedScore);
|
||||
assertThat(category.index()).isEqualTo(-1);
|
||||
}
|
||||
|
||||
private static void assertApproximatelyEqualBoundingBoxes(
|
||||
RectF boundingBox1, RectF boundingBox2) {
|
||||
assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left);
|
||||
assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top);
|
||||
assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right);
|
||||
assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom);
|
||||
}
|
||||
|
||||
private static void assertImageSizeIsExpected(Image inputImage) {
|
||||
assertThat(inputImage).isNotNull();
|
||||
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
|
||||
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
|
@ -14,3 +14,13 @@ flatbuffer_cc_library(
|
|||
name = "metadata_schema_cc",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "schema_py",
|
||||
srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "metadata_schema_py",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
|
|
@ -31,7 +31,7 @@ py_library(
|
|||
name = "category",
|
||||
srcs = ["category.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/containers:category_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import dataclasses
|
||||
from typing import Any
|
||||
|
||||
from mediapipe.tasks.cc.components.containers import category_pb2
|
||||
from mediapipe.tasks.cc.components.containers.proto import category_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_CategoryProto = category_pb2.Category
|
||||
|
|
|
@ -27,6 +27,7 @@ pybind_library(
|
|||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/python/pybind:util",
|
||||
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/python/pybind/util.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11_protobuf/native_proto_caster.h"
|
||||
|
@ -75,7 +76,7 @@ mode) or not (synchronous mode).)doc");
|
|||
}
|
||||
auto task_runner = TaskRunner::Create(
|
||||
std::move(graph_config),
|
||||
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
|
||||
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
|
||||
std::move(callback));
|
||||
RaisePyErrorIfNotOk(task_runner.status());
|
||||
return std::move(*task_runner);
|
||||
|
|
38
mediapipe/tasks/python/metadata/BUILD
Normal file
38
mediapipe/tasks/python/metadata/BUILD
Normal file
|
@ -0,0 +1,38 @@
|
|||
load("//mediapipe/tasks/metadata:build_defs.bzl", "stamp_metadata_parser_version")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
stamp_metadata_parser_version(
|
||||
name = "metadata_parser_py",
|
||||
srcs = ["metadata_parser.py.template"],
|
||||
outs = ["metadata_parser.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "metadata",
|
||||
srcs = [
|
||||
"metadata.py",
|
||||
":metadata_parser_py",
|
||||
],
|
||||
data = ["//mediapipe/tasks/metadata:metadata_schema.fbs"],
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||
"//mediapipe/tasks/metadata:schema_py",
|
||||
"//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "metadata_displayer_cli",
|
||||
srcs = ["metadata_displayer_cli.py"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [":metadata"],
|
||||
)
|
13
mediapipe/tasks/python/metadata/__init__.py
Normal file
13
mediapipe/tasks/python/metadata/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
20
mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD
Normal file
20
mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD
Normal file
|
@ -0,0 +1,20 @@
|
|||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/tasks:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_flatbuffers",
|
||||
srcs = [
|
||||
"flatbuffers_lib.cc",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pywrap_flatbuffers",
|
||||
deps = [
|
||||
"@flatbuffers",
|
||||
"@local_config_python//:python_headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,59 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "flatbuffers/idl.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
PYBIND11_MODULE(_pywrap_flatbuffers, m) {
|
||||
pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
|
||||
pybind11::class_<flatbuffers::Parser>(m, "Parser")
|
||||
.def(pybind11::init<const flatbuffers::IDLOptions&>())
|
||||
.def("parse",
|
||||
[](flatbuffers::Parser* self, const std::string& source) {
|
||||
return self->Parse(source.c_str());
|
||||
})
|
||||
.def_readonly("builder", &flatbuffers::Parser::builder_)
|
||||
.def_readonly("error", &flatbuffers::Parser::error_);
|
||||
pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
|
||||
.def("clear", &flatbuffers::FlatBufferBuilder::Clear)
|
||||
.def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
|
||||
const std::string& contents) {
|
||||
self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
|
||||
contents.length());
|
||||
});
|
||||
m.def("generate_text_file", &flatbuffers::GenerateTextFile);
|
||||
m.def(
|
||||
"generate_text",
|
||||
[](const flatbuffers::Parser& parser,
|
||||
const std::string& buffer) -> std::string {
|
||||
std::string text;
|
||||
if (!flatbuffers::GenerateText(
|
||||
parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
|
||||
return "";
|
||||
}
|
||||
return text;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace support
|
||||
} // namespace tflite
|
865
mediapipe/tasks/python/metadata/metadata.py
Normal file
865
mediapipe/tasks/python/metadata/metadata.py
Normal file
|
@ -0,0 +1,865 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorFlow Lite metadata tools."""
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
import zipfile
|
||||
|
||||
import flatbuffers
|
||||
from mediapipe.tasks.cc.metadata.python import _pywrap_metadata_version
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||
from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers
|
||||
|
||||
try:
|
||||
# If exists, optionally use TensorFlow to open and check files. Used to
|
||||
# support more than local file systems.
|
||||
# In pip requirements, we doesn't necessarily need tensorflow as a dep.
|
||||
import tensorflow as tf
|
||||
_open_file = tf.io.gfile.GFile
|
||||
_exists_file = tf.io.gfile.exists
|
||||
except ImportError as e:
|
||||
# If TensorFlow package doesn't exist, fall back to original open and exists.
|
||||
_open_file = open
|
||||
_exists_file = os.path.exists
|
||||
|
||||
|
||||
def _maybe_open_as_binary(filename, mode):
|
||||
"""Maybe open the binary file, and returns a file-like."""
|
||||
if hasattr(filename, "read"): # A file-like has read().
|
||||
return filename
|
||||
openmode = mode if "b" in mode else mode + "b" # Add binary explicitly.
|
||||
return _open_file(filename, openmode)
|
||||
|
||||
|
||||
def _open_as_zipfile(filename, mode="r"):
|
||||
"""Open file as a zipfile.
|
||||
|
||||
Args:
|
||||
filename: str or file-like or path-like, to the zipfile.
|
||||
mode: str, common file mode for zip.
|
||||
(See: https://docs.python.org/3/library/zipfile.html)
|
||||
|
||||
Returns:
|
||||
A ZipFile object.
|
||||
"""
|
||||
file_like = _maybe_open_as_binary(filename, mode)
|
||||
return zipfile.ZipFile(file_like, mode)
|
||||
|
||||
|
||||
def _is_zipfile(filename):
|
||||
"""Checks whether it is a zipfile."""
|
||||
with _maybe_open_as_binary(filename, "r") as f:
|
||||
return zipfile.is_zipfile(f)
|
||||
|
||||
|
||||
def get_path_to_datafile(path):
|
||||
"""Gets the path to the specified file in the data dependencies.
|
||||
|
||||
The path is relative to the file calling the function.
|
||||
|
||||
It's a simple replacement of
|
||||
"tensorflow.python.platform.resource_loader.get_path_to_datafile".
|
||||
|
||||
Args:
|
||||
path: a string resource path relative to the calling file.
|
||||
|
||||
Returns:
|
||||
The path to the specified file present in the data attribute of py_test
|
||||
or py_binary.
|
||||
"""
|
||||
data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1))) # pylint: disable=protected-access
|
||||
return os.path.join(data_files_path, path)
|
||||
|
||||
|
||||
_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile(
|
||||
"../../metadata/metadata_schema.fbs")
|
||||
|
||||
|
||||
# TODO: add delete method for associated files.
|
||||
class MetadataPopulator(object):
|
||||
"""Packs metadata and associated files into TensorFlow Lite model file.
|
||||
|
||||
MetadataPopulator can be used to populate metadata and model associated files
|
||||
into a model file or a model buffer (in bytearray). It can also help to
|
||||
inspect list of files that have been packed into the model or are supposed to
|
||||
be packed into the model.
|
||||
|
||||
The metadata file (or buffer) should be generated based on the metadata
|
||||
schema:
|
||||
third_party/tensorflow/lite/schema/metadata_schema.fbs
|
||||
|
||||
Example usage:
|
||||
Populate matadata and label file into an image classifier model.
|
||||
|
||||
First, based on metadata_schema.fbs, generate the metadata for this image
|
||||
classifer model using Flatbuffers API. Attach the label file onto the ouput
|
||||
tensor (the tensor of probabilities) in the metadata.
|
||||
|
||||
Then, pack the metadata and label file into the model as follows.
|
||||
|
||||
```python
|
||||
# Populating a metadata file (or a metadta buffer) and associated files to
|
||||
a model file:
|
||||
populator = MetadataPopulator.with_model_file(model_file)
|
||||
# For metadata buffer (bytearray read from the metadata file), use:
|
||||
# populator.load_metadata_buffer(metadata_buf)
|
||||
populator.load_metadata_file(metadata_file)
|
||||
populator.load_associated_files([label.txt])
|
||||
# For associated file buffer (bytearray read from the file), use:
|
||||
# populator.load_associated_file_buffers({"label.txt": b"file content"})
|
||||
populator.populate()
|
||||
|
||||
# Populating a metadata file (or a metadata buffer) and associated files to
|
||||
a model buffer:
|
||||
populator = MetadataPopulator.with_model_buffer(model_buf)
|
||||
populator.load_metadata_file(metadata_file)
|
||||
populator.load_associated_files([label.txt])
|
||||
populator.populate()
|
||||
# Writing the updated model buffer into a file.
|
||||
updated_model_buf = populator.get_model_buffer()
|
||||
with open("updated_model.tflite", "wb") as f:
|
||||
f.write(updated_model_buf)
|
||||
|
||||
# Transferring metadata and associated files from another TFLite model:
|
||||
populator = MetadataPopulator.with_model_buffer(model_buf)
|
||||
populator_dst.load_metadata_and_associated_files(src_model_buf)
|
||||
populator_dst.populate()
|
||||
updated_model_buf = populator.get_model_buffer()
|
||||
with open("updated_model.tflite", "wb") as f:
|
||||
f.write(updated_model_buf)
|
||||
```
|
||||
|
||||
Note that existing metadata buffer (if applied) will be overridden by the new
|
||||
metadata buffer.
|
||||
"""
|
||||
# As Zip API is used to concatenate associated files after tflite model file,
|
||||
# the populating operation is developed based on a model file. For in-memory
|
||||
# model buffer, we create a tempfile to serve the populating operation.
|
||||
# Creating the deleting such a tempfile is handled by the class,
|
||||
# _MetadataPopulatorWithBuffer.
|
||||
|
||||
METADATA_FIELD_NAME = "TFLITE_METADATA"
|
||||
TFLITE_FILE_IDENTIFIER = b"TFL3"
|
||||
METADATA_FILE_IDENTIFIER = b"M001"
|
||||
|
||||
def __init__(self, model_file):
|
||||
"""Constructor for MetadataPopulator.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: the model does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
_assert_model_file_identifier(model_file)
|
||||
self._model_file = model_file
|
||||
self._metadata_buf = None
|
||||
# _associated_files is a dict of file name and file buffer.
|
||||
self._associated_files = {}
|
||||
|
||||
@classmethod
|
||||
def with_model_file(cls, model_file):
|
||||
"""Creates a MetadataPopulator object that populates data to a model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Returns:
|
||||
MetadataPopulator object.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: the model does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
return cls(model_file)
|
||||
|
||||
# TODO: investigate if type check can be applied to model_buf for
|
||||
# FB.
|
||||
@classmethod
|
||||
def with_model_buffer(cls, model_buf):
|
||||
"""Creates a MetadataPopulator object that populates data to a model buffer.
|
||||
|
||||
Args:
|
||||
model_buf: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Returns:
|
||||
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
|
||||
|
||||
Raises:
|
||||
ValueError: the model does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
return _MetadataPopulatorWithBuffer(model_buf)
|
||||
|
||||
def get_model_buffer(self):
|
||||
"""Gets the buffer of the model with packed metadata and associated files.
|
||||
|
||||
Returns:
|
||||
Model buffer (in bytearray).
|
||||
"""
|
||||
with _open_file(self._model_file, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def get_packed_associated_file_list(self):
|
||||
"""Gets a list of associated files packed to the model file.
|
||||
|
||||
Returns:
|
||||
List of packed associated files.
|
||||
"""
|
||||
if not _is_zipfile(self._model_file):
|
||||
return []
|
||||
|
||||
with _open_as_zipfile(self._model_file, "r") as zf:
|
||||
return zf.namelist()
|
||||
|
||||
def get_recorded_associated_file_list(self):
|
||||
"""Gets a list of associated files recorded in metadata of the model file.
|
||||
|
||||
Associated files may be attached to a model, a subgraph, or an input/output
|
||||
tensor.
|
||||
|
||||
Returns:
|
||||
List of recorded associated files.
|
||||
"""
|
||||
if not self._metadata_buf:
|
||||
return []
|
||||
|
||||
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
|
||||
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(
|
||||
self._metadata_buf, 0))
|
||||
|
||||
return [
|
||||
file.name.decode("utf-8")
|
||||
for file in self._get_recorded_associated_file_object_list(metadata)
|
||||
]
|
||||
|
||||
def load_associated_file_buffers(self, associated_files):
|
||||
"""Loads the associated file buffers (in bytearray) to be populated.
|
||||
|
||||
Args:
|
||||
associated_files: a dictionary of associated file names and corresponding
|
||||
file buffers, such as {"file.txt": b"file content"}. If pass in file
|
||||
paths for the file name, only the basename will be populated.
|
||||
"""
|
||||
|
||||
self._associated_files.update({
|
||||
os.path.basename(name): buffers
|
||||
for name, buffers in associated_files.items()
|
||||
})
|
||||
|
||||
def load_associated_files(self, associated_files):
|
||||
"""Loads associated files that to be concatenated after the model file.
|
||||
|
||||
Args:
|
||||
associated_files: list of file paths.
|
||||
|
||||
Raises:
|
||||
IOError:
|
||||
File not found.
|
||||
"""
|
||||
for af_name in associated_files:
|
||||
_assert_file_exist(af_name)
|
||||
with _open_file(af_name, "rb") as af:
|
||||
self.load_associated_file_buffers({af_name: af.read()})
|
||||
|
||||
def load_metadata_buffer(self, metadata_buf):
|
||||
"""Loads the metadata buffer (in bytearray) to be populated.
|
||||
|
||||
Args:
|
||||
metadata_buf: metadata buffer (in bytearray) to be populated.
|
||||
|
||||
Raises:
|
||||
ValueError: The metadata to be populated is empty.
|
||||
ValueError: The metadata does not have the expected flatbuffer identifer.
|
||||
ValueError: Cannot get minimum metadata parser version.
|
||||
ValueError: The number of SubgraphMetadata is not 1.
|
||||
ValueError: The number of input/output tensors does not match the number
|
||||
of input/output tensor metadata.
|
||||
"""
|
||||
if not metadata_buf:
|
||||
raise ValueError("The metadata to be populated is empty.")
|
||||
|
||||
self._validate_metadata(metadata_buf)
|
||||
|
||||
# Gets the minimum metadata parser version of the metadata_buf.
|
||||
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
|
||||
bytes(metadata_buf))
|
||||
|
||||
# Inserts in the minimum metadata parser version into the metadata_buf.
|
||||
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
|
||||
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
|
||||
metadata.minParserVersion = min_version
|
||||
|
||||
# Remove local file directory in the `name` field of `AssociatedFileT`, and
|
||||
# make it consistent with the name of the actual file packed in the model.
|
||||
self._use_basename_for_associated_files_in_metadata(metadata)
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf_with_version = b.Output()
|
||||
|
||||
self._metadata_buf = metadata_buf_with_version
|
||||
|
||||
def load_metadata_file(self, metadata_file):
|
||||
"""Loads the metadata file to be populated.
|
||||
|
||||
Args:
|
||||
metadata_file: path to the metadata file to be populated.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: The metadata to be populated is empty.
|
||||
ValueError: The metadata does not have the expected flatbuffer identifer.
|
||||
ValueError: Cannot get minimum metadata parser version.
|
||||
ValueError: The number of SubgraphMetadata is not 1.
|
||||
ValueError: The number of input/output tensors does not match the number
|
||||
of input/output tensor metadata.
|
||||
"""
|
||||
_assert_file_exist(metadata_file)
|
||||
with _open_file(metadata_file, "rb") as f:
|
||||
metadata_buf = f.read()
|
||||
self.load_metadata_buffer(bytearray(metadata_buf))
|
||||
|
||||
def load_metadata_and_associated_files(self, src_model_buf):
|
||||
"""Loads the metadata and associated files from another model buffer.
|
||||
|
||||
Args:
|
||||
src_model_buf: source model buffer (in bytearray) with metadata and
|
||||
associated files.
|
||||
"""
|
||||
# Load the model metadata from src_model_buf if exist.
|
||||
metadata_buffer = get_metadata_buffer(src_model_buf)
|
||||
if metadata_buffer:
|
||||
self.load_metadata_buffer(metadata_buffer)
|
||||
|
||||
# Load the associated files from src_model_buf if exist.
|
||||
if _is_zipfile(io.BytesIO(src_model_buf)):
|
||||
with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf:
|
||||
self.load_associated_file_buffers(
|
||||
{f: zf.read(f) for f in zf.namelist()})
|
||||
|
||||
def populate(self):
|
||||
"""Populates loaded metadata and associated files into the model file."""
|
||||
self._assert_validate()
|
||||
self._populate_metadata_buffer()
|
||||
self._populate_associated_files()
|
||||
|
||||
def _assert_validate(self):
|
||||
"""Validates the metadata and associated files to be populated.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
File is recorded in the metadata, but is not going to be populated.
|
||||
File has already been packed.
|
||||
"""
|
||||
# Gets files that are recorded in metadata.
|
||||
recorded_files = self.get_recorded_associated_file_list()
|
||||
|
||||
# Gets files that have been packed to self._model_file.
|
||||
packed_files = self.get_packed_associated_file_list()
|
||||
|
||||
# Gets the file name of those associated files to be populated.
|
||||
to_be_populated_files = self._associated_files.keys()
|
||||
|
||||
# Checks all files recorded in the metadata will be populated.
|
||||
for rf in recorded_files:
|
||||
if rf not in to_be_populated_files and rf not in packed_files:
|
||||
raise ValueError("File, '{0}', is recorded in the metadata, but has "
|
||||
"not been loaded into the populator.".format(rf))
|
||||
|
||||
for f in to_be_populated_files:
|
||||
if f in packed_files:
|
||||
raise ValueError("File, '{0}', has already been packed.".format(f))
|
||||
|
||||
if f not in recorded_files:
|
||||
warnings.warn(
|
||||
"File, '{0}', does not exist in the metadata. But packing it to "
|
||||
"tflite model is still allowed.".format(f))
|
||||
|
||||
def _copy_archived_files(self, src_zip, file_list, dst_zip):
|
||||
"""Copy archieved files in file_list from src_zip ro dst_zip."""
|
||||
|
||||
if not _is_zipfile(src_zip):
|
||||
raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
|
||||
|
||||
with _open_as_zipfile(src_zip, "r") as src_zf, \
|
||||
_open_as_zipfile(dst_zip, "a") as dst_zf:
|
||||
src_list = src_zf.namelist()
|
||||
for f in file_list:
|
||||
if f not in src_list:
|
||||
raise ValueError(
|
||||
"File, '{0}', does not exist in the zipfile, {1}.".format(
|
||||
f, src_zip))
|
||||
file_buffer = src_zf.read(f)
|
||||
dst_zf.writestr(f, file_buffer)
|
||||
|
||||
def _get_associated_files_from_process_units(self, table, field_name):
|
||||
"""Gets the files that are attached the process units field of a table.
|
||||
|
||||
Args:
|
||||
table: a Flatbuffers table object that contains fields of an array of
|
||||
ProcessUnit, such as TensorMetadata and SubGraphMetadata.
|
||||
field_name: the name of the field in the table that represents an array of
|
||||
ProcessUnit. If the table is TensorMetadata, field_name can be
|
||||
"ProcessUnits". If the table is SubGraphMetadata, field_name can be
|
||||
either "InputProcessUnits" or "OutputProcessUnits".
|
||||
|
||||
Returns:
|
||||
A list of AssociatedFileT objects.
|
||||
"""
|
||||
|
||||
if table is None:
|
||||
return []
|
||||
|
||||
file_list = []
|
||||
process_units = getattr(table, field_name)
|
||||
# If the process_units field is not populated, it will be None. Use an
|
||||
# empty list to skip the check.
|
||||
for process_unit in process_units or []:
|
||||
options = process_unit.options
|
||||
if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
|
||||
_metadata_fb.RegexTokenizerOptionsT)):
|
||||
file_list += self._get_associated_files_from_table(options, "vocabFile")
|
||||
elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
|
||||
file_list += self._get_associated_files_from_table(
|
||||
options, "sentencePieceModel")
|
||||
file_list += self._get_associated_files_from_table(options, "vocabFile")
|
||||
return file_list
|
||||
|
||||
def _get_associated_files_from_table(self, table, field_name):
|
||||
"""Gets the associated files that are attached a table directly.
|
||||
|
||||
Args:
|
||||
table: a Flatbuffers table object that contains fields of an array of
|
||||
AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
|
||||
field_name: the name of the field in the table that represents an array of
|
||||
ProcessUnit. If the table is TensorMetadata, field_name can be
|
||||
"AssociatedFiles". If the table is BertTokenizerOptions, field_name can
|
||||
be "VocabFile".
|
||||
|
||||
Returns:
|
||||
A list of AssociatedFileT objects.
|
||||
"""
|
||||
|
||||
if table is None:
|
||||
return []
|
||||
|
||||
# If the associated file field is not populated,
|
||||
# `getattr(table, field_name)` will be None. Return an empty list.
|
||||
return getattr(table, field_name) or []
|
||||
|
||||
def _get_recorded_associated_file_object_list(self, metadata):
|
||||
"""Gets a list of AssociatedFileT objects recorded in the metadata.
|
||||
|
||||
Associated files may be attached to a model, a subgraph, or an input/output
|
||||
tensor.
|
||||
|
||||
Args:
|
||||
metadata: the ModelMetadataT object.
|
||||
|
||||
Returns:
|
||||
List of recorded AssociatedFileT objects.
|
||||
"""
|
||||
recorded_files = []
|
||||
|
||||
# Add associated files attached to ModelMetadata.
|
||||
recorded_files += self._get_associated_files_from_table(
|
||||
metadata, "associatedFiles")
|
||||
|
||||
# Add associated files attached to each SubgraphMetadata.
|
||||
for subgraph in metadata.subgraphMetadata or []:
|
||||
recorded_files += self._get_associated_files_from_table(
|
||||
subgraph, "associatedFiles")
|
||||
|
||||
# Add associated files attached to each input tensor.
|
||||
for tensor_metadata in subgraph.inputTensorMetadata or []:
|
||||
recorded_files += self._get_associated_files_from_table(
|
||||
tensor_metadata, "associatedFiles")
|
||||
recorded_files += self._get_associated_files_from_process_units(
|
||||
tensor_metadata, "processUnits")
|
||||
|
||||
# Add associated files attached to each output tensor.
|
||||
for tensor_metadata in subgraph.outputTensorMetadata or []:
|
||||
recorded_files += self._get_associated_files_from_table(
|
||||
tensor_metadata, "associatedFiles")
|
||||
recorded_files += self._get_associated_files_from_process_units(
|
||||
tensor_metadata, "processUnits")
|
||||
|
||||
# Add associated files attached to the input_process_units.
|
||||
recorded_files += self._get_associated_files_from_process_units(
|
||||
subgraph, "inputProcessUnits")
|
||||
|
||||
# Add associated files attached to the output_process_units.
|
||||
recorded_files += self._get_associated_files_from_process_units(
|
||||
subgraph, "outputProcessUnits")
|
||||
|
||||
return recorded_files
|
||||
|
||||
def _populate_associated_files(self):
|
||||
"""Concatenates associated files after TensorFlow Lite model file.
|
||||
|
||||
If the MetadataPopulator object is created using the method,
|
||||
with_model_file(model_file), the model file will be updated.
|
||||
"""
|
||||
# Opens up the model file in "appending" mode.
|
||||
# If self._model_file already has pack files, zipfile will concatenate
|
||||
# addition files after self._model_file. For example, suppose we have
|
||||
# self._model_file = old_tflite_file | label1.txt | label2.txt
|
||||
# Then after trigger populate() to add label3.txt, self._model_file becomes
|
||||
# self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
|
||||
with tempfile.SpooledTemporaryFile() as temp:
|
||||
# (1) Copy content from model file of to temp file.
|
||||
with _open_file(self._model_file, "rb") as f:
|
||||
shutil.copyfileobj(f, temp)
|
||||
|
||||
# (2) Append of to a temp file as a zip.
|
||||
with _open_as_zipfile(temp, "a") as zf:
|
||||
for file_name, file_buffer in self._associated_files.items():
|
||||
zf.writestr(file_name, file_buffer)
|
||||
|
||||
# (3) Copy temp file to model file.
|
||||
temp.seek(0)
|
||||
with _open_file(self._model_file, "wb") as f:
|
||||
shutil.copyfileobj(temp, f)
|
||||
|
||||
def _populate_metadata_buffer(self):
|
||||
"""Populates the metadata buffer (in bytearray) into the model file.
|
||||
|
||||
Inserts metadata_buf into the metadata field of schema.Model. If the
|
||||
MetadataPopulator object is created using the method,
|
||||
with_model_file(model_file), the model file will be updated.
|
||||
|
||||
Existing metadata buffer (if applied) will be overridden by the new metadata
|
||||
buffer.
|
||||
"""
|
||||
|
||||
with _open_file(self._model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
|
||||
model = _schema_fb.ModelT.InitFromObj(
|
||||
_schema_fb.Model.GetRootAsModel(model_buf, 0))
|
||||
buffer_field = _schema_fb.BufferT()
|
||||
buffer_field.data = self._metadata_buf
|
||||
|
||||
is_populated = False
|
||||
if not model.metadata:
|
||||
model.metadata = []
|
||||
else:
|
||||
# Check if metadata has already been populated.
|
||||
for meta in model.metadata:
|
||||
if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
|
||||
is_populated = True
|
||||
model.buffers[meta.buffer] = buffer_field
|
||||
|
||||
if not is_populated:
|
||||
if not model.buffers:
|
||||
model.buffers = []
|
||||
model.buffers.append(buffer_field)
|
||||
# Creates a new metadata field.
|
||||
metadata_field = _schema_fb.MetadataT()
|
||||
metadata_field.name = self.METADATA_FIELD_NAME
|
||||
metadata_field.buffer = len(model.buffers) - 1
|
||||
model.metadata.append(metadata_field)
|
||||
|
||||
# Packs model back to a flatbuffer binaray file.
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
|
||||
model_buf = b.Output()
|
||||
|
||||
# Saves the updated model buffer to model file.
|
||||
# Gets files that have been packed to self._model_file.
|
||||
packed_files = self.get_packed_associated_file_list()
|
||||
if packed_files:
|
||||
# Writes the updated model buffer and associated files into a new model
|
||||
# file (in memory). Then overwrites the original model file.
|
||||
with tempfile.SpooledTemporaryFile() as temp:
|
||||
temp.write(model_buf)
|
||||
self._copy_archived_files(self._model_file, packed_files, temp)
|
||||
temp.seek(0)
|
||||
with _open_file(self._model_file, "wb") as f:
|
||||
shutil.copyfileobj(temp, f)
|
||||
else:
|
||||
with _open_file(self._model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
def _use_basename_for_associated_files_in_metadata(self, metadata):
|
||||
"""Removes any associated file local directory (if exists)."""
|
||||
for file in self._get_recorded_associated_file_object_list(metadata):
|
||||
file.name = os.path.basename(file.name)
|
||||
|
||||
def _validate_metadata(self, metadata_buf):
|
||||
"""Validates the metadata to be populated."""
|
||||
_assert_metadata_buffer_identifier(metadata_buf)
|
||||
|
||||
# Verify the number of SubgraphMetadata is exactly one.
|
||||
# TFLite currently only support one subgraph.
|
||||
model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
|
||||
metadata_buf, 0)
|
||||
if model_meta.SubgraphMetadataLength() != 1:
|
||||
raise ValueError("The number of SubgraphMetadata should be exactly one, "
|
||||
"but got {0}.".format(
|
||||
model_meta.SubgraphMetadataLength()))
|
||||
|
||||
# Verify if the number of tensor metadata matches the number of tensors.
|
||||
with _open_file(self._model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
|
||||
|
||||
num_input_tensors = model.Subgraphs(0).InputsLength()
|
||||
num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
|
||||
if num_input_tensors != num_input_meta:
|
||||
raise ValueError(
|
||||
"The number of input tensors ({0}) should match the number of "
|
||||
"input tensor metadata ({1})".format(num_input_tensors,
|
||||
num_input_meta))
|
||||
num_output_tensors = model.Subgraphs(0).OutputsLength()
|
||||
num_output_meta = model_meta.SubgraphMetadata(
|
||||
0).OutputTensorMetadataLength()
|
||||
if num_output_tensors != num_output_meta:
|
||||
raise ValueError(
|
||||
"The number of output tensors ({0}) should match the number of "
|
||||
"output tensor metadata ({1})".format(num_output_tensors,
|
||||
num_output_meta))
|
||||
|
||||
|
||||
class _MetadataPopulatorWithBuffer(MetadataPopulator):
|
||||
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
|
||||
|
||||
This class is used to populate metadata into a in-memory model buffer. As we
|
||||
use Zip API to concatenate associated files after tflite model file, the
|
||||
populating operation is developed based on a model file. For in-memory model
|
||||
buffer, we create a tempfile to serve the populating operation. This class is
|
||||
then used to generate this tempfile, and delete the file when the
|
||||
MetadataPopulator object is deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, model_buf):
|
||||
"""Constructor for _MetadataPopulatorWithBuffer.
|
||||
|
||||
Args:
|
||||
model_buf: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Raises:
|
||||
ValueError: model_buf is empty.
|
||||
ValueError: model_buf does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
if not model_buf:
|
||||
raise ValueError("model_buf cannot be empty.")
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
model_file = temp.name
|
||||
|
||||
with _open_file(model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
super().__init__(model_file)
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor of _MetadataPopulatorWithBuffer.
|
||||
|
||||
Deletes the tempfile.
|
||||
"""
|
||||
if os.path.exists(self._model_file):
|
||||
os.remove(self._model_file)
|
||||
|
||||
|
||||
class MetadataDisplayer(object):
|
||||
"""Displays metadata and associated file info in human-readable format."""
|
||||
|
||||
def __init__(self, model_buffer, metadata_buffer, associated_file_list):
|
||||
"""Constructor for MetadataDisplayer.
|
||||
|
||||
Args:
|
||||
model_buffer: valid buffer of the model file.
|
||||
metadata_buffer: valid buffer of the metadata file.
|
||||
associated_file_list: list of associate files in the model file.
|
||||
"""
|
||||
_assert_model_buffer_identifier(model_buffer)
|
||||
_assert_metadata_buffer_identifier(metadata_buffer)
|
||||
self._model_buffer = model_buffer
|
||||
self._metadata_buffer = metadata_buffer
|
||||
self._associated_file_list = associated_file_list
|
||||
|
||||
@classmethod
|
||||
def with_model_file(cls, model_file):
|
||||
"""Creates a MetadataDisplayer object for the model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Returns:
|
||||
MetadataDisplayer object.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: The model does not have metadata.
|
||||
"""
|
||||
_assert_file_exist(model_file)
|
||||
with _open_file(model_file, "rb") as f:
|
||||
return cls.with_model_buffer(f.read())
|
||||
|
||||
@classmethod
|
||||
def with_model_buffer(cls, model_buffer):
|
||||
"""Creates a MetadataDisplayer object for a file buffer.
|
||||
|
||||
Args:
|
||||
model_buffer: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Returns:
|
||||
MetadataDisplayer object.
|
||||
"""
|
||||
if not model_buffer:
|
||||
raise ValueError("model_buffer cannot be empty.")
|
||||
metadata_buffer = get_metadata_buffer(model_buffer)
|
||||
if not metadata_buffer:
|
||||
raise ValueError("The model does not have metadata.")
|
||||
associated_file_list = cls._parse_packed_associted_file_list(model_buffer)
|
||||
return cls(model_buffer, metadata_buffer, associated_file_list)
|
||||
|
||||
def get_associated_file_buffer(self, filename):
|
||||
"""Get the specified associated file content in bytearray.
|
||||
|
||||
Args:
|
||||
filename: name of the file to be extracted.
|
||||
|
||||
Returns:
|
||||
The file content in bytearray.
|
||||
|
||||
Raises:
|
||||
ValueError: if the file does not exist in the model.
|
||||
"""
|
||||
if filename not in self._associated_file_list:
|
||||
raise ValueError(
|
||||
"The file, {}, does not exist in the model.".format(filename))
|
||||
|
||||
with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf:
|
||||
return zf.read(filename)
|
||||
|
||||
def get_metadata_buffer(self):
|
||||
"""Get the metadata buffer in bytearray out from the model."""
|
||||
return copy.deepcopy(self._metadata_buffer)
|
||||
|
||||
def get_metadata_json(self):
|
||||
"""Converts the metadata into a json string."""
|
||||
return convert_to_json(self._metadata_buffer)
|
||||
|
||||
def get_packed_associated_file_list(self):
|
||||
"""Returns a list of associated files that are packed in the model.
|
||||
|
||||
Returns:
|
||||
A name list of associated files.
|
||||
"""
|
||||
return copy.deepcopy(self._associated_file_list)
|
||||
|
||||
@staticmethod
|
||||
def _parse_packed_associted_file_list(model_buf):
|
||||
"""Gets a list of associated files packed to the model file.
|
||||
|
||||
Args:
|
||||
model_buf: valid file buffer.
|
||||
|
||||
Returns:
|
||||
List of packed associated files.
|
||||
"""
|
||||
|
||||
try:
|
||||
with _open_as_zipfile(io.BytesIO(model_buf)) as zf:
|
||||
return zf.namelist()
|
||||
except zipfile.BadZipFile:
|
||||
return []
|
||||
|
||||
|
||||
# Create an individual method for getting the metadata json file, so that it can
|
||||
# be used as a standalone util.
|
||||
def convert_to_json(metadata_buffer):
|
||||
"""Converts the metadata into a json string.
|
||||
|
||||
Args:
|
||||
metadata_buffer: valid metadata buffer in bytes.
|
||||
|
||||
Returns:
|
||||
Metadata in JSON format.
|
||||
|
||||
Raises:
|
||||
ValueError: error occured when parsing the metadata schema file.
|
||||
"""
|
||||
|
||||
opt = _pywrap_flatbuffers.IDLOptions()
|
||||
opt.strict_json = True
|
||||
parser = _pywrap_flatbuffers.Parser(opt)
|
||||
with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f:
|
||||
metadata_schema_content = f.read()
|
||||
if not parser.parse(metadata_schema_content):
|
||||
raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
|
||||
return _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
|
||||
|
||||
|
||||
def _assert_file_exist(filename):
|
||||
"""Checks if a file exists."""
|
||||
if not _exists_file(filename):
|
||||
raise IOError("File, '{0}', does not exist.".format(filename))
|
||||
|
||||
|
||||
def _assert_model_file_identifier(model_file):
|
||||
"""Checks if a model file has the expected TFLite schema identifier."""
|
||||
_assert_file_exist(model_file)
|
||||
with _open_file(model_file, "rb") as f:
|
||||
_assert_model_buffer_identifier(f.read())
|
||||
|
||||
|
||||
def _assert_model_buffer_identifier(model_buf):
|
||||
if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
|
||||
raise ValueError(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.")
|
||||
|
||||
|
||||
def _assert_metadata_buffer_identifier(metadata_buf):
|
||||
"""Checks if a metadata buffer has the expected Metadata schema identifier."""
|
||||
if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
|
||||
metadata_buf, 0):
|
||||
raise ValueError(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.")
|
||||
|
||||
|
||||
def get_metadata_buffer(model_buf):
|
||||
"""Returns the metadata in the model file as a buffer.
|
||||
|
||||
Args:
|
||||
model_buf: valid buffer of the model file.
|
||||
|
||||
Returns:
|
||||
Metadata buffer. Returns `None` if the model does not have metadata.
|
||||
"""
|
||||
tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
|
||||
|
||||
# Gets metadata from the model file.
|
||||
for i in range(tflite_model.MetadataLength()):
|
||||
meta = tflite_model.Metadata(i)
|
||||
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
|
||||
buffer_index = meta.Buffer()
|
||||
metadata = tflite_model.Buffers(buffer_index)
|
||||
return metadata.DataAsNumpy().tobytes()
|
||||
|
||||
return None
|
34
mediapipe/tasks/python/metadata/metadata_displayer_cli.py
Normal file
34
mediapipe/tasks/python/metadata/metadata_displayer_cli.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""CLI tool for display metadata."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.')
|
||||
flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.')
|
||||
|
||||
|
||||
def main(_):
|
||||
displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path)
|
||||
with open(FLAGS.export_json_path, 'w') as f:
|
||||
f.write(displayer.get_metadata_json())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user