From ace098f370d316ec3d60684bdd9fe5eb85aef922 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 11:55:04 -0800 Subject: [PATCH 01/10] Add proper Cast for MultiPort PiperOrigin-RevId: 487012509 --- mediapipe/framework/api2/builder.h | 7 ++++ mediapipe/framework/api2/builder_test.cc | 51 ++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index bf7f2b399..5af9ee5e0 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -106,6 +106,13 @@ class MultiPort : public Single { return Single{&GetWithAutoGrow(&vec_, index)}; } + template + auto Cast() { + using SingleCastT = + std::invoke_result_t), Single*>; + return MultiPort(&vec_); + } + private: std::vector>& vec_; }; diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 810c52527..3bf3ec198 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -445,6 +445,57 @@ TEST(BuilderTest, AnyTypeCanBeCast) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } +TEST(BuilderTest, MultiPortIsCastToMultiPort) { + builder::Graph graph; + builder::MultiSource any_input = graph.In("ANY_INPUT"); + builder::MultiSource int_input = any_input.Cast(); + builder::MultiDestination any_output = graph.Out("ANY_OUTPUT"); + builder::MultiDestination int_output = any_output.Cast(); + int_input >> int_output; + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "ANY_INPUT:__stream_0" + output_stream: "ANY_OUTPUT:__stream_0" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + +TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { + builder::Graph graph; + builder::MultiSource any_multi_input = graph.In("ANY_INPUT"); + builder::Source any_input = any_multi_input; + builder::MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); + 
builder::Destination any_output = any_multi_output; + any_input >> any_output; + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "ANY_INPUT:__stream_0" + output_stream: "ANY_OUTPUT:__stream_0" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + +TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { + builder::Graph graph; + builder::Source int_input = graph.In("INT_INPUT").Cast(); + builder::Source any_input = graph.In("ANY_OUTPUT"); + builder::Destination int_output = graph.Out("INT_OUTPUT").Cast(); + builder::Destination any_output = graph.Out("ANY_OUTPUT"); + int_input >> int_output; + any_input >> any_output; + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "ANY_OUTPUT:__stream_0" + input_stream: "INT_INPUT:__stream_1" + output_stream: "ANY_OUTPUT:__stream_0" + output_stream: "INT_OUTPUT:__stream_1" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + } // namespace test } // namespace api2 } // namespace mediapipe From 37930609ffd27d1aca55dd705290b400cf1b0d59 Mon Sep 17 00:00:00 2001 From: Kinar R <42828719+kinaryml@users.noreply.github.com> Date: Wed, 9 Nov 2022 02:58:38 +0530 Subject: [PATCH 02/10] Update landmark_detection_result.py --- .../python/components/containers/landmark_detection_result.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py index 5e68efda1..d115cb52f 100644 --- a/mediapipe/tasks/python/components/containers/landmark_detection_result.py +++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -97,8 +97,7 @@ class LandmarksDetectionResult: landmarks.append(_NormalizedLandmark.create_from_pb2(landmark)) for landmark in pb2_obj.world_landmarks.landmark: - world_landmarks.append(_Landmark.create_from_pb2(landmark) - ) + 
world_landmarks.append(_Landmark.create_from_pb2(landmark)) return LandmarksDetectionResult( landmarks=landmarks, categories=categories, From 0363d60511676379ac558838592cf373c4416f8e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 13:46:50 -0800 Subject: [PATCH 03/10] Open-sources TextEmbedder. PiperOrigin-RevId: 487041832 --- mediapipe/tasks/cc/text/text_embedder/BUILD | 87 +++++++++++ .../tasks/cc/text/text_embedder/proto/BUILD | 30 ++++ .../proto/text_embedder_graph_options.proto | 36 +++++ .../cc/text/text_embedder/text_embedder.cc | 104 +++++++++++++ .../cc/text/text_embedder/text_embedder.h | 96 ++++++++++++ .../text/text_embedder/text_embedder_graph.cc | 145 ++++++++++++++++++ .../text/text_embedder/text_embedder_test.cc | 143 +++++++++++++++++ 7 files changed, 641 insertions(+) create mode 100644 mediapipe/tasks/cc/text/text_embedder/BUILD create mode 100644 mediapipe/tasks/cc/text/text_embedder/proto/BUILD create mode 100644 mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder.cc create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder.h create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD new file mode 100644 index 000000000..331902362 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -0,0 +1,87 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_embedder", + srcs = ["text_embedder.cc"], + hdrs = ["text_embedder.h"], + deps = [ + ":text_embedder_graph", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "text_embedder_graph", + srcs = ["text_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components:text_preprocessing_graph", 
+ "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_test( + name = "text_embedder_test", + srcs = ["text_embedder_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + ], + deps = [ + ":text_embedder", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/BUILD b/mediapipe/tasks/cc/text/text_embedder/proto/BUILD new file mode 100644 index 000000000..146483af1 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "text_embedder_graph_options_proto", + srcs = ["text_embedder_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto new file mode 100644 index 000000000..6b8d41a57 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -0,0 +1,36 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.text.text_embedder.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +message TextEmbedderGraphOptions { + extend mediapipe.CalculatorOptions { + optional TextEmbedderGraphOptions ext = 477589892; + } + + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for configuring the embedder behavior, such as normalization or + // quantization. + optional components.processors.proto.EmbedderOptions embedder_options = 2; +} diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc new file mode 100644 index 000000000..375058d57 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc @@ -0,0 +1,104 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h" + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/task_api_factory.h" +#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" + +namespace mediapipe::tasks::text::text_embedder { +namespace { + +constexpr char kTextTag[] = "TEXT"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kTextInStreamName[] = "text_in"; +constexpr char kEmbeddingsStreamName[] = "embeddings_out"; +constexpr char kGraphTypeName[] = + "mediapipe.tasks.text.text_embedder.TextEmbedderGraph"; + +using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; + +// Creates a MediaPipe graph config that contains a single node of type +// "mediapipe.tasks.text.text_embedder.TextEmbedderGraph". 
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options_proto) { + api2::builder::Graph graph; + auto& task_graph = graph.AddNode(kGraphTypeName); + task_graph.GetOptions().Swap( + options_proto.get()); + graph.In(kTextTag).SetName(kTextInStreamName) >> task_graph.In(kTextTag); + task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >> + graph.Out(kEmbeddingsTag); + return graph.GetConfig(); +} + +// Converts the user-facing TextEmbedderOptions struct to the internal +// TextEmbedderGraphOptions proto. +std::unique_ptr +ConvertTextEmbedderOptionsToProto(TextEmbedderOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + auto embedder_options_proto = + std::make_unique( + components::processors::ConvertEmbedderOptionsToProto( + &(options->embedder_options))); + options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); + return options_proto; +} + +} // namespace + +absl::StatusOr> TextEmbedder::Create( + std::unique_ptr options) { + std::unique_ptr options_proto = + ConvertTextEmbedderOptionsToProto(options.get()); + return core::TaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver)); +} + +absl::StatusOr TextEmbedder::Embed(absl::string_view text) { + ASSIGN_OR_RETURN( + auto output_packets, + runner_->Process( + {{kTextInStreamName, MakePacket(std::string(text))}})); + return ConvertToEmbeddingResult( + output_packets[kEmbeddingsStreamName].Get()); +} + +absl::StatusOr TextEmbedder::CosineSimilarity( + const components::containers::Embedding& u, + const components::containers::Embedding& v) { + return components::utils::CosineSimilarity(u, v); +} + +} // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder.h 
b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h new file mode 100644 index 000000000..81f90fd27 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h @@ -0,0 +1,96 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" + +namespace mediapipe::tasks::text::text_embedder { + +// Alias the shared EmbeddingResult struct as result type. +using TextEmbedderResult = + ::mediapipe::tasks::components::containers::EmbeddingResult; + +// Options for configuring a MediaPipe text embedder task. +struct TextEmbedderOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // Options for configuring the embedder behavior, such as L2-normalization or + // scalar-quantization.
+ components::processors::EmbedderOptions embedder_options; +}; + +// Performs embedding extraction on text. +// +// This API expects a TFLite model with TFLite Model Metadata that contains the +// mandatory (described below) input tensors and output tensors. Metadata should +// contain the input process unit for the model's Tokenizer as well as input / +// output tensor metadata. +// +// TODO: Support Universal Sentence Encoder. +// Input tensors: +// (kTfLiteInt32) +// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names +// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and +// segment ids respectively +// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the +// input ids +// +// At least one output tensor with: +// (kTfLiteFloat32) +// - `N` components corresponding to the `N` dimensions of the returned +// feature vector for this output layer. +// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. +class TextEmbedder : core::BaseTaskApi { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a TextEmbedder from the provided `options`. A non-default + // OpResolver can be specified in the BaseOptions in order to support custom + // Ops or specify a subset of built-in Ops. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs embedding extraction on the input `text`. + absl::StatusOr Embed(absl::string_view text); + + // Shuts down the TextEmbedder when all the work is done. + absl::Status Close() { return runner_->Close(); } + + // Utility function to compute cosine similarity [1] between two embeddings. + // May return an InvalidArgumentError if e.g. the embeddings are of different + // types (quantized vs. float), have different sizes, or have an L2-norm of + // 0.
+ // + // [1]: https://en.wikipedia.org/wiki/Cosine_similarity + static absl::StatusOr CosineSimilarity( + const components::containers::Embedding& u, + const components::containers::Embedding& v); +}; + +} // namespace mediapipe::tasks::text::text_embedder + +#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc new file mode 100644 index 000000000..79eedb6b5 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -0,0 +1,145 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" +#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" + +namespace mediapipe::tasks::text::text_embedder { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::core::ModelResources; + +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; +constexpr char kTensorsTag[] = "TENSORS"; + +} // namespace + +// A "mediapipe.tasks.text.text_embedder.TextEmbedderGraph" performs text +// embedding extraction. +// - Accepts input text and outputs embeddings on CPU. +// +// Inputs: +// TEXT - std::string +// Input text to perform embedding extraction on. +// +// Outputs: +// EMBEDDINGS - EmbeddingResult +// The embedding result.
+// +// Example: +// node { +// calculator: "mediapipe.tasks.text.text_embedder.TextEmbedderGraph" +// input_stream: "TEXT:text_in" +// output_stream: "EMBEDDINGS:embedding_result_out" +// options { +// [mediapipe.tasks.text.text_embedder.proto.TextEmbedderGraphOptions.ext] { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// } +// } +// } +class TextEmbedderGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + CHECK(sc != nullptr); + ASSIGN_OR_RETURN(const ModelResources* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN( + Source embedding_result_out, + BuildTextEmbedderTask(sc->Options(), + *model_resources, + graph[Input(kTextTag)], graph)); + embedding_result_out >> graph[Output(kEmbeddingsTag)]; + return graph.GetConfig(); + } + + private: + // Adds a mediapipe TextEmbedder task graph into the provided + // builder::Graph instance. The TextEmbedder task takes an input + // text (std::string) and returns an embedding result. + // + // task_options: the mediapipe tasks TextEmbedderGraphOptions proto. + // model_resources: the ModelResources object initialized from a + // TextEmbedder model file with model metadata. + // text_in: (std::string) stream to run embedding extraction on. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr> BuildTextEmbedderTask( + const proto::TextEmbedderGraphOptions& task_options, + const ModelResources& model_resources, Source text_in, + Graph& graph) { + // Adds preprocessing calculators and connects them to the text input + // stream.
+ auto& preprocessing = + graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); + MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + model_resources, + preprocessing.GetOptions< + tasks::components::proto::TextPreprocessingGraphOptions>())); + text_in >> preprocessing.In(kTextTag); + + // Adds both InferenceCalculator and ModelResourcesCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + // The metadata extractor side-output comes from the + // ModelResourcesCalculator. + inference.SideOut(kMetadataExtractorTag) >> + preprocessing.SideIn(kMetadataExtractorTag); + preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); + + // Adds postprocessing calculators and connects its input stream to the + // inference results. + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( + model_resources, task_options.embedder_options(), + &postprocessing.GetOptions())); + inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); + + // Outputs the embedding result. + return postprocessing[Output(kEmbeddingsTag)]; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::text::text_embedder::TextEmbedderGraph); + +} // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc new file mode 100644 index 000000000..fa3d8af91 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe::tasks::text::text_embedder { +namespace { + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; + +// Note that these models use dynamic-sized tensors. +// Embedding model with BERT preprocessing. +constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite"; +// Embedding model with regex preprocessing. +constexpr char kRegexOneEmbeddingModel[] = + "regex_one_embedding_with_metadata.tflite"; + +// Tolerance for embedding vector coordinate values. +constexpr float kEpsilon = 1e-4; +// Tolerance for cosine similarity evaluation.
+constexpr double kSimilarityTolerancy = 1e-6; + +using ::mediapipe::file::JoinPath; +using ::testing::HasSubstr; +using ::testing::Optional; + +class EmbedderTest : public tflite_shims::testing::Test {}; + +TEST_F(EmbedderTest, FailsWithMissingModel) { + auto text_embedder = + TextEmbedder::Create(std::make_unique()); + ASSERT_EQ(text_embedder.status().code(), absl::StatusCode::kInvalidArgument); + ASSERT_THAT( + text_embedder.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + ASSERT_THAT(text_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(EmbedderTest, SucceedsWithMobileBert) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileBert); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("it's a charming and often affecting journey")); + ASSERT_EQ(result0.embeddings.size(), 1); + ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512); + ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon); + + MP_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + ASSERT_EQ(result1.embeddings.size(), 1); + ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512); + ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon); + + // Check cosine similarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + +TEST(EmbedTest, SucceedsWithRegexOneEmbeddingModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kRegexOneEmbeddingModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + auto result0, + text_embedder->Embed("it's a charming and often affecting journey")); + EXPECT_EQ(result0.embeddings.size(), 1); + EXPECT_EQ(result0.embeddings[0].float_embedding.size(), 16); + + EXPECT_NEAR(result0.embeddings[0].float_embedding[0], 0.0309356f, kEpsilon); + + MP_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + EXPECT_EQ(result1.embeddings.size(), 1); + EXPECT_EQ(result1.embeddings[0].float_embedding.size(), 16); + + EXPECT_NEAR(result1.embeddings[0].float_embedding[0], 0.0312863f, kEpsilon); + + // Check cosine similarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + EXPECT_NEAR(similarity, 0.999937, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + +TEST_F(EmbedderTest, SucceedsWithQuantization) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileBert); + options->embedder_options.quantize = true; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result, + text_embedder->Embed("it's a charming and often affecting journey")); + ASSERT_EQ(result.embeddings.size(), 1); + ASSERT_EQ(result.embeddings[0].quantized_embedding.size(), 512); + + MP_ASSERT_OK(text_embedder->Close()); +} + +} // namespace +} // namespace mediapipe::tasks::text::text_embedder From b3d19fa1af3b23a57993bf3a006e390184459e9f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 13:50:28 -0800 Subject: [PATCH 04/10] Use model bundle writer when exporting models in gesture recognizer PiperOrigin-RevId: 487042776 --- mediapipe/model_maker/python/core/utils/BUILD | 13 +++++++ .../python/core/utils/file_util.py | 36 +++++++++++++++++++ .../python/core/utils/file_util_test.py | 29 +++++++++++++++ .../python/core/utils/model_util.py | 26 +++++++++----- .../python/core/utils/model_util_test.py | 15 +++++++- .../python/core/utils/testdata/BUILD | 23 ++++++++++++ .../python/core/utils/testdata/test.txt | 0 mediapipe/tasks/testdata/vision/BUILD | 5 +++ 8 files changed, 138 insertions(+), 9 deletions(-) create mode 100644 mediapipe/model_maker/python/core/utils/file_util.py create mode 100644 mediapipe/model_maker/python/core/utils/file_util_test.py create mode 100644 mediapipe/model_maker/python/core/utils/testdata/BUILD create mode 100644 mediapipe/model_maker/python/core/utils/testdata/test.txt diff --git 
a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index a2ec52044..12fef631f 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -35,6 +35,7 @@ py_library( name = "model_util", srcs = ["model_util.py"], deps = [ + ":file_util", ":quantization", "//mediapipe/model_maker/python/core/data:dataset", ], @@ -50,6 +51,18 @@ py_test( ], ) +py_library( + name = "file_util", + srcs = ["file_util.py"], +) + +py_test( + name = "file_util_test", + srcs = ["file_util_test.py"], + data = ["//mediapipe/model_maker/python/core/utils/testdata"], + deps = [":file_util"], +) + py_library( name = "loss_functions", srcs = ["loss_functions.py"], diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py new file mode 100644 index 000000000..bccf928e2 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -0,0 +1,36 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for files.""" + +import os + +# resources dependency + + +def get_absolute_path(file_path: str) -> str: + """Gets the absolute path of a file. + + Args: + file_path: The path to a file relative to the `mediapipe` dir + + Returns: + The full path of the file + """ + # Extract the file path before mediapipe/ as the `base_dir`. 
By joining it + # with the `path` which defines the relative path under mediapipe/, it + # yields to the absolute path of the model files directory. + cwd = os.path.dirname(__file__) + base_dir = cwd[:cwd.rfind('mediapipe')] + absolute_path = os.path.join(base_dir, file_path) + return absolute_path diff --git a/mediapipe/model_maker/python/core/utils/file_util_test.py b/mediapipe/model_maker/python/core/utils/file_util_test.py new file mode 100644 index 000000000..4a2d6dcfb --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/file_util_test.py @@ -0,0 +1,29 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from absl.testing import absltest +from mediapipe.model_maker.python.core.utils import file_util + + +class FileUtilTest(absltest.TestCase): + + def test_get_absolute_path(self): + test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt' + absolute_path = file_util.get_absolute_path(test_file) + self.assertTrue(os.path.exists(absolute_path)) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index ada0a61e3..01e301e43 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for keras models.""" +"""Utilities for models.""" from __future__ import absolute_import from __future__ import division @@ -26,8 +26,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import tensorflow as tf -# resources dependency from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.core.utils import quantization DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 @@ -62,16 +62,26 @@ def load_keras_model(model_path: str, Returns: A tensorflow Keras model. """ - # Extract the file path before mediapipe/ as the `base_dir`. By joining it - # with the `model_path` which defines the relative path under mediapipe/, it - # yields to the aboslution path of the model files directory. - cwd = os.path.dirname(__file__) - base_dir = cwd[:cwd.rfind('mediapipe')] - absolute_path = os.path.join(base_dir, model_path) + absolute_path = file_util.get_absolute_path(model_path) return tf.keras.models.load_model( absolute_path, custom_objects={'tf': tf}, compile=compile_on_load) +def load_tflite_model_buffer(model_path: str) -> bytearray: + """Loads a TFLite model buffer from file. 
+ + Args: + model_path: Relative path to a TFLite file + + Returns: + A TFLite model buffer + """ + absolute_path = file_util.get_absolute_path(model_path) + with tf.io.gfile.GFile(absolute_path, 'rb') as f: + tflite_model_buffer = f.read() + return tflite_model_buffer + + def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, batch_size: Optional[int] = None, train_data: Optional[dataset.Dataset] = None) -> int: diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 1f9e0f1db..bef9c8a97 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -24,7 +24,7 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - def test_load_model(self): + def test_load_keras_model(self): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') @@ -36,6 +36,19 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) + def test_load_tflite_model_buffer(self): + input_dim = 4 + model = test_util.build_model(input_shape=[input_dim], num_classes=2) + tflite_model = model_util.convert_to_tflite(model) + tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') + model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) + + tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) + test_util.test_tflite( + keras_model=model, + tflite_model=tflite_model_buffer, + size=[1, input_dim]) + @parameterized.named_parameters( dict( testcase_name='input_only_steps_per_epoch', diff --git a/mediapipe/model_maker/python/core/utils/testdata/BUILD 
b/mediapipe/model_maker/python/core/utils/testdata/BUILD new file mode 100644 index 000000000..8eed72f78 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/BUILD @@ -0,0 +1,23 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "testdata", + srcs = ["test.txt"], +) diff --git a/mediapipe/model_maker/python/core/utils/testdata/test.txt b/mediapipe/model_maker/python/core/utils/testdata/test.txt new file mode 100644 index 000000000..e69de29bb diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index e23c4a66c..55d386185 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -144,8 +144,13 @@ filegroup( ) # Gestures related models. Visible to model_maker. 
+# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval filegroup( name = "test_gesture_models", + srcs = [ + "hand_landmark_full.tflite", + "palm_detection_full.tflite", + ], visibility = [ "//mediapipe/model_maker:__subpackages__", "//mediapipe/tasks:internal", From 0917e8cb8eccb732994c9fdc031128a71d2f21e7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 14:47:19 -0800 Subject: [PATCH 05/10] Support continual training image classifier from saved checkpoint files. PiperOrigin-RevId: 487057612 --- .../image_classifier/image_classifier.py | 5 +- .../image_classifier/image_classifier_test.py | 53 ++++++++++++++++--- .../train_image_classifier_lib.py | 18 ++++++- 3 files changed, 65 insertions(+), 11 deletions(-) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 29b3025d8..f6edbeab4 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -68,7 +68,10 @@ class ImageClassifier(classifier.Classifier): ) -> 'ImageClassifier': """Creates and trains an image classifier. - Loads data and trains the model based on data for image classification. + Loads data and trains the model based on data for image classification. If a + checkpoint file exists in the {options.hparams.export_dir}/checkpoint/ + directory, the training process will load the weight from the checkpoint + file for continual training. Args: train_data: Training data. 
diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 8446df18e..252659edc 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -13,10 +13,13 @@ # limitations under the License. import filecmp +import io import os +import tempfile -from unittest import mock +from unittest import mock as unittest_mock from absl.testing import parameterized +import mock import numpy as np import tensorflow as tf @@ -63,14 +66,20 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): options=image_classifier.ImageClassifierOptions( supported_model=image_classifier.SupportedModels.MOBILENET_V2, hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite0', options=image_classifier.ImageClassifierOptions( supported_model=( image_classifier.SupportedModels.EFFICIENTNET_LITE0), hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite0_change_dropout_rate', options=image_classifier.ImageClassifierOptions( @@ -78,21 +87,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): image_classifier.SupportedModels.EFFICIENTNET_LITE0), model_options=image_classifier.ModelOptions(dropout_rate=0.1), hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite2', options=image_classifier.ImageClassifierOptions( supported_model=( image_classifier.SupportedModels.EFFICIENTNET_LITE2), hparams=image_classifier.HParams( - 
epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), dict( testcase_name='efficientnet_lite4', options=image_classifier.ImageClassifierOptions( supported_model=( image_classifier.SupportedModels.EFFICIENTNET_LITE4), hparams=image_classifier.HParams( - epochs=1, batch_size=1, shuffle=True))), + epochs=1, + batch_size=1, + shuffle=True, + export_dir=tempfile.mkdtemp()))), ) def test_create_and_train_model( self, options: image_classifier.ImageClassifierOptions): @@ -117,16 +135,35 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + def test_continual_training_by_loading_checkpoint(self): + mock_stdout = io.StringIO() + with mock.patch('sys.stdout', mock_stdout): + options = image_classifier.ImageClassifierOptions( + supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0, + hparams=image_classifier.HParams( + epochs=5, batch_size=1, shuffle=True)) + model = image_classifier.ImageClassifier.create( + train_data=self._train_data, + validation_data=self._test_data, + options=options) + model = image_classifier.ImageClassifier.create( + train_data=self._train_data, + validation_data=self._test_data, + options=options) + self._test_accuracy(model) + + self.assertRegex(mock_stdout.getvalue(), 'Resuming from') + def _test_accuracy(self, model, threshold=0.0): _, accuracy = model.evaluate(self._test_data) self.assertGreaterEqual(accuracy, threshold) - @mock.patch.object( + @unittest_mock.patch.object( image_classifier.hyperparameters, 'HParams', autospec=True, return_value=image_classifier.HParams(epochs=1)) - @mock.patch.object( + @unittest_mock.patch.object( image_classifier.model_options, 'ImageClassifierModelOptions', autospec=True, diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py 
b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py index 4adddefeb..c5b28cff5 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py +++ b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py @@ -13,6 +13,7 @@ # limitations under the License. """Library to train model.""" +import os import tensorflow as tf from mediapipe.model_maker.python.core.utils import model_util @@ -78,11 +79,24 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams, loss = tf.keras.losses.CategoricalCrossentropy( label_smoothing=hparams.label_smoothing) model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) - callbacks = model_util.get_default_callbacks(export_dir=hparams.export_dir) + + summary_dir = os.path.join(hparams.export_dir, 'summaries') + summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) + # Save checkpoint every 5 epochs. + checkpoint_path = os.path.join(hparams.export_dir, 'checkpoint') + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + os.path.join(checkpoint_path, 'model-{epoch:04d}'), + save_weights_only=True, + period=5) + + latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) + if latest_checkpoint: + print(f'Resuming from {latest_checkpoint}') + model.load_weights(latest_checkpoint) # Train the model. 
return model.fit( x=train_ds, epochs=hparams.epochs, validation_data=validation_ds, - callbacks=callbacks) + callbacks=[summary_callback, checkpoint_callback]) From 253ff0f85c9829551db2efeb7a997c71e8935793 Mon Sep 17 00:00:00 2001 From: Lu Wang Date: Tue, 8 Nov 2022 14:54:10 -0800 Subject: [PATCH 06/10] Update the Java doc for model asset for BaseOptions PiperOrigin-RevId: 487059371 --- .../com/google/mediapipe/tasks/core/BaseOptions.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java index 7f2903503..d1db08893 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java @@ -26,22 +26,23 @@ public abstract class BaseOptions { @AutoValue.Builder public abstract static class Builder { /** - * Sets the model path to a tflite model with metadata in the assets. + * Sets the model path to a model asset file (a tflite model or a model asset bundle file) in + * the Android app assets folder. * *

Note: when model path is set, both model file descriptor and model buffer should be empty. */ public abstract Builder setModelAssetPath(String value); /** - * Sets the native fd int of a tflite model with metadata. + * Sets the native fd int of a model asset file (a tflite model or a model asset bundle file). * *

Note: when model file descriptor is set, both model path and model buffer should be empty. */ public abstract Builder setModelAssetFileDescriptor(Integer value); /** - * Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model - * with metadata. + * Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a model asset + * file (a tflite model or a model asset bundle file). * *

Note: when model buffer is set, both model file and model file descriptor should be empty. */ From 669d5395519584bc9d405ab9c4823c8a7a50fe4f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 8 Nov 2022 15:44:35 -0800 Subject: [PATCH 07/10] NPM package definitions for MediaPipe Tasks PiperOrigin-RevId: 487071334 --- mediapipe/tasks/web/BUILD | 106 ++++++++++ mediapipe/tasks/web/audio.ts | 17 ++ mediapipe/tasks/web/audio/BUILD | 2 + mediapipe/tasks/web/package.json | 15 ++ mediapipe/tasks/web/rollup.config.mjs | 9 + mediapipe/tasks/web/text.ts | 17 ++ mediapipe/tasks/web/text/BUILD | 2 + .../tasks/web/text/text_classifier/BUILD | 4 +- ...ptions.d.ts => text_classifier_options.ts} | 0 ..._result.d.ts => text_classifier_result.ts} | 0 mediapipe/tasks/web/vision.ts | 17 ++ mediapipe/tasks/web/vision/BUILD | 2 + .../tasks/web/vision/gesture_recognizer/BUILD | 4 +- ...ons.d.ts => gesture_recognizer_options.ts} | 0 ...sult.d.ts => gesture_recognizer_result.ts} | 0 .../tasks/web/vision/object_detector/BUILD | 4 +- ...ptions.d.ts => object_detector_options.ts} | 0 ..._result.d.ts => object_detector_result.ts} | 0 package.json | 4 + yarn.lock | 182 ++++++++++++++++-- 20 files changed, 366 insertions(+), 19 deletions(-) create mode 100644 mediapipe/tasks/web/BUILD create mode 100644 mediapipe/tasks/web/audio.ts create mode 100644 mediapipe/tasks/web/package.json create mode 100644 mediapipe/tasks/web/rollup.config.mjs create mode 100644 mediapipe/tasks/web/text.ts rename mediapipe/tasks/web/text/text_classifier/{text_classifier_options.d.ts => text_classifier_options.ts} (100%) rename mediapipe/tasks/web/text/text_classifier/{text_classifier_result.d.ts => text_classifier_result.ts} (100%) create mode 100644 mediapipe/tasks/web/vision.ts rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_options.d.ts => gesture_recognizer_options.ts} (100%) rename mediapipe/tasks/web/vision/gesture_recognizer/{gesture_recognizer_result.d.ts => 
gesture_recognizer_result.ts} (100%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_options.d.ts => object_detector_options.ts} (100%) rename mediapipe/tasks/web/vision/object_detector/{object_detector_result.d.ts => object_detector_result.ts} (100%) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD new file mode 100644 index 000000000..9e4b52417 --- /dev/null +++ b/mediapipe/tasks/web/BUILD @@ -0,0 +1,106 @@ +# This contains the MediaPipe Tasks NPM package definitions. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +# Audio + +mediapipe_ts_library( + name = "audio_lib", + srcs = ["audio.ts"], + deps = ["//mediapipe/tasks/web/audio:audio_lib"], +) + +rollup_bundle( + name = "audio_bundle", + config_file = "rollup.config.mjs", + entry_point = "audio.ts", + output_dir = False, + deps = [ + ":audio_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + ], +) + +pkg_npm( + name = "audio_pkg", + package_name = "__PACKAGE_NAME__", + srcs = ["package.json"], + substitutions = { + "__PACKAGE_NAME__": "@mediapipe/tasks-audio", + "__DESCRIPTION__": "MediaPipe Audio Tasks", + "__BUNDLE__": "audio_bundle.js", + }, + tgz = "audio.tgz", + deps = [":audio_bundle"], +) + +# Text + +mediapipe_ts_library( + name = "text_lib", + srcs = ["text.ts"], + deps = ["//mediapipe/tasks/web/text:text_lib"], +) + +rollup_bundle( + name = "text_bundle", + config_file = "rollup.config.mjs", + entry_point = "text.ts", + output_dir = False, + deps = [ + ":text_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + ], +) + +pkg_npm( + name = "text_pkg", + package_name = "__PACKAGE_NAME__", + srcs = ["package.json"], + substitutions = { + "__PACKAGE_NAME__": "@mediapipe/tasks-text", + "__DESCRIPTION__": 
"MediaPipe Text Tasks", + "__BUNDLE__": "text_bundle.js", + }, + tgz = "text.tgz", + deps = [":text_bundle"], +) + +# Vision + +mediapipe_ts_library( + name = "vision_lib", + srcs = ["vision.ts"], + deps = ["//mediapipe/tasks/web/vision:vision_lib"], +) + +rollup_bundle( + name = "vision_bundle", + config_file = "rollup.config.mjs", + entry_point = "vision.ts", + output_dir = False, + deps = [ + ":vision_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + ], +) + +pkg_npm( + name = "vision_pkg", + package_name = "__PACKAGE_NAME__", + srcs = ["package.json"], + substitutions = { + "__PACKAGE_NAME__": "@mediapipe/tasks-vision", + "__DESCRIPTION__": "MediaPipe Vision Tasks", + "__BUNDLE__": "vision_bundle.js", + }, + tgz = "vision.tgz", + deps = [":vision_bundle"], +) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts new file mode 100644 index 000000000..4a3b80594 --- /dev/null +++ b/mediapipe/tasks/web/audio.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export * from '../../tasks/web/audio/index'; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 0263738e6..4f6e48b28 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -2,6 +2,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +package(default_visibility = ["//mediapipe/tasks:internal"]) + mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], diff --git a/mediapipe/tasks/web/package.json b/mediapipe/tasks/web/package.json new file mode 100644 index 000000000..d3bd8b669 --- /dev/null +++ b/mediapipe/tasks/web/package.json @@ -0,0 +1,15 @@ +{ + "name": "__PACKAGE_NAME__", + "version": "__VERSION__", + "description": "__DESCRIPTION__", + "main": "__BUNDLE__", + "module": "__BUNDLE__", + "author": "mediapipe@google.com", + "license": "Apache-2.0", + "type": "module", + "dependencies": { + "google-protobuf": "^3.21.2" + }, + "homepage": "http://mediapipe.dev", + "keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ] +} diff --git a/mediapipe/tasks/web/rollup.config.mjs b/mediapipe/tasks/web/rollup.config.mjs new file mode 100644 index 000000000..392b235fc --- /dev/null +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -0,0 +1,9 @@ +import resolve from '@rollup/plugin-node-resolve'; +import commonjs from '@rollup/plugin-commonjs'; + +export default { + plugins: [ + resolve(), + commonjs() + ] +} diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts new file mode 100644 index 000000000..f8a0b6457 --- /dev/null +++ b/mediapipe/tasks/web/text.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../tasks/web/text/index'; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index d3a797f83..a369d0af0 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -2,6 +2,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +package(default_visibility = ["//mediapipe/tasks:internal"]) + mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index e984a9554..25a8817d4 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -13,8 +13,8 @@ mediapipe_ts_library( name = "text_classifier", srcs = [ "text_classifier.ts", - "text_classifier_options.d.ts", - "text_classifier_result.d.ts", + "text_classifier_options.ts", + "text_classifier_result.ts", ], deps = [ "//mediapipe/framework:calculator_jspb_proto", diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts similarity index 100% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts similarity index 100% rename from 
mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts new file mode 100644 index 000000000..6ff8f725b --- /dev/null +++ b/mediapipe/tasks/web/vision.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../tasks/web/vision/index'; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 3684f88ef..abdbc54ea 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -2,6 +2,8 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +package(default_visibility = ["//mediapipe/tasks:internal"]) + mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 8988c4794..6b99f6ce4 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -13,8 +13,8 @@ mediapipe_ts_library( name = "gesture_recognizer", srcs = [ "gesture_recognizer.ts", - "gesture_recognizer_options.d.ts", - "gesture_recognizer_result.d.ts", + "gesture_recognizer_options.ts", + "gesture_recognizer_result.ts", ], deps = [ 
"//mediapipe/framework:calculator_jspb_proto", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.ts diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts similarity index 100% rename from mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts rename to mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.ts diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 888537bd1..095a84b52 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -13,8 +13,8 @@ mediapipe_ts_library( name = "object_detector", srcs = [ "object_detector.ts", - "object_detector_options.d.ts", - "object_detector_result.d.ts", + "object_detector_options.ts", + "object_detector_result.ts", ], deps = [ "//mediapipe/framework:calculator_jspb_proto", diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.ts similarity index 100% rename from mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts rename to mediapipe/tasks/web/vision/object_detector/object_detector_options.ts diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.ts similarity index 100% rename from mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts rename to 
mediapipe/tasks/web/vision/object_detector/object_detector_result.ts diff --git a/package.json b/package.json index f8478a159..298157cbc 100644 --- a/package.json +++ b/package.json @@ -3,12 +3,16 @@ "version": "0.0.0-alphga", "description": "MediaPipe GitHub repo", "devDependencies": { + "@bazel/rollup": "^5.7.1", "@bazel/typescript": "^5.7.1", + "@rollup/plugin-commonjs": "^23.0.2", + "@rollup/plugin-node-resolve": "^15.0.1", "@types/google-protobuf": "^3.15.6", "@types/offscreencanvas": "^2019.7.0", "google-protobuf": "^3.21.2", "protobufjs": "^7.1.2", "protobufjs-cli": "^1.0.2", + "rollup": "^2.3.0", "ts-protoc-gen": "^0.15.0", "typescript": "^4.8.4" } diff --git a/yarn.lock b/yarn.lock index e6398fb1f..a5ec6fb13 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3,9 +3,16 @@ "@babel/parser@^7.9.4": - version "7.20.1" - resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.1.tgz#3e045a92f7b4623cafc2425eddcb8cf2e54f9cc5" - integrity sha512-hp0AYxaZJhxULfM1zyp7Wgr+pSUKBcP3M+PHnSzWGdXOzg/kHWIgiUWARvubhUKGOEw3xqY4x+lyZ9ytBVcELw== + version "7.20.3" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.3.tgz#5358cf62e380cf69efcb87a7bb922ff88bfac6e2" + integrity sha512-OP/s5a94frIPXwjzEcv5S/tpQfc6XhxYUnmWpgdqMWGgYCuErA3SzozaRAMQgSZWKeTJxht9aWAkUY+0UzvOFg== + +"@bazel/rollup@^5.7.1": + version "5.7.1" + resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.1.tgz#6f644c2d493a5bd9cd3724a6f239e609585c6e37" + integrity sha512-LLNogoK2Qx9GIJVywQ+V/czjud8236mnaRX//g7qbOyXoWZDQvAEgsxRHq+lS/XX9USbh+zJJlfb+Dfp/PXx4A== + dependencies: + "@bazel/worker" "5.7.1" "@bazel/typescript@^5.7.1": version "5.7.1" @@ -77,6 +84,44 @@ resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570" integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw== +"@rollup/plugin-commonjs@^23.0.2": + version "23.0.2" + resolved 
"https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.2.tgz#3a3a5b7b1b1cb29037eb4992edcaae997d7ebd92" + integrity sha512-e9ThuiRf93YlVxc4qNIurvv+Hp9dnD+4PjOqQs5vAYfcZ3+AXSrcdzXnVjWxcGQOa6KGJFcRZyUI3ktWLavFjg== + dependencies: + "@rollup/pluginutils" "^5.0.1" + commondir "^1.0.1" + estree-walker "^2.0.2" + glob "^8.0.3" + is-reference "1.2.1" + magic-string "^0.26.4" + +"@rollup/plugin-node-resolve@^15.0.1": + version "15.0.1" + resolved "https://registry.yarnpkg.com/@rollup/plugin-node-resolve/-/plugin-node-resolve-15.0.1.tgz#72be449b8e06f6367168d5b3cd5e2802e0248971" + integrity sha512-ReY88T7JhJjeRVbfCyNj+NXAG3IIsVMsX9b5/9jC98dRP8/yxlZdz7mHZbHk5zHr24wZZICS5AcXsFZAXYUQEg== + dependencies: + "@rollup/pluginutils" "^5.0.1" + "@types/resolve" "1.20.2" + deepmerge "^4.2.2" + is-builtin-module "^3.2.0" + is-module "^1.0.0" + resolve "^1.22.1" + +"@rollup/pluginutils@^5.0.1": + version "5.0.2" + resolved "https://registry.yarnpkg.com/@rollup/pluginutils/-/pluginutils-5.0.2.tgz#012b8f53c71e4f6f9cb317e311df1404f56e7a33" + integrity sha512-pTd9rIsP92h+B6wWwFbW8RkZv4hiR/xKsqre4SIuAOaOEQRxi0lqLke9k2/7WegC85GgUs9pjmOjCUi3In4vwA== + dependencies: + "@types/estree" "^1.0.0" + estree-walker "^2.0.2" + picomatch "^2.3.1" + +"@types/estree@*", "@types/estree@^1.0.0": + version "1.0.0" + resolved "https://registry.yarnpkg.com/@types/estree/-/estree-1.0.0.tgz#5fb2e536c1ae9bf35366eed879e827fa59ca41c2" + integrity sha512-WulqXMDUTYAXCjZnk6JtIHPigp55cVtDgDrO2gHRwhyJto21+1zbVCtOYB2L1F9w4qCQ0rOGWBnBe0FNTiEJIQ== + "@types/google-protobuf@^3.15.6": version "3.15.6" resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504" @@ -110,6 +155,11 @@ resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.0.tgz#e4a932069db47bb3eabeb0b305502d01586fa90d" integrity sha512-PGcyveRIpL1XIqK8eBsmRBt76eFgtzuPiSTyKHZxnGemp2yzGzWpjYKAfK3wIMiU7eH+851yEpiuP8JZerTmWg== 
+"@types/resolve@1.20.2": + version "1.20.2" + resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-1.20.2.tgz#97d26e00cd4a0423b4af620abecf3e6f442b7975" + integrity sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q== + acorn-jsx@^5.3.2: version "5.3.2" resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" @@ -162,6 +212,11 @@ buffer-from@^1.0.0: resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.2.tgz#2b146a6fd72e80b4f55d255f35ed59a3a9a41bd5" integrity sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ== +builtin-modules@^3.3.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-3.3.0.tgz#cae62812b89801e9656336e46223e030386be7b6" + integrity sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw== + catharsis@^0.9.0: version "0.9.0" resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121" @@ -189,6 +244,11 @@ color-name@~1.1.4: resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== +commondir@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b" + integrity sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg== + concat-map@0.0.1: version "0.0.1" resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" @@ -199,6 +259,11 @@ deep-is@~0.1.3: resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" integrity 
sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ== +deepmerge@^4.2.2: + version "4.2.2" + resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955" + integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg== + entities@~2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5" @@ -227,9 +292,9 @@ eslint-visitor-keys@^3.3.0: integrity sha512-mQ+suqKJVyeuwGYHAdjMFqjCyfl8+Ldnxuyp3ldiMBFKkvytrXUZWaiPCEav8qDHKty44bD+qV1IP4T+w+xXRA== espree@^9.0.0: - version "9.4.0" - resolved "https://registry.yarnpkg.com/espree/-/espree-9.4.0.tgz#cd4bc3d6e9336c433265fc0aa016fc1aaf182f8a" - integrity sha512-DQmnRpLj7f6TgN/NYb0MTzJXL+vJF9h3pHy4JhCIs3zwcgez8xmGg3sXHcEO97BrmO2OSvCwMdfdlyl+E9KjOw== + version "9.4.1" + resolved "https://registry.yarnpkg.com/espree/-/espree-9.4.1.tgz#51d6092615567a2c2cff7833445e37c28c0065bd" + integrity sha512-XwctdmTO6SIvCzd9810yyNzIrOrqNYV9Koizx4C/mRhf9uq0o4yHoCEU/670pOxOL/MSraektvSAji79kX90Vg== dependencies: acorn "^8.8.0" acorn-jsx "^5.3.2" @@ -250,6 +315,11 @@ estraverse@^5.1.0: resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-5.3.0.tgz#2eea5290702f26ab8fe5370370ff86c965d21123" integrity sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA== +estree-walker@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/estree-walker/-/estree-walker-2.0.2.tgz#52f010178c2a4c117a7757cfe942adb7d2da4cac" + integrity sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w== + esutils@^2.0.2: version "2.0.3" resolved "https://registry.yarnpkg.com/esutils/-/esutils-2.0.3.tgz#74d2eb4de0b8da1293711910d50775b9b710ef64" @@ -265,6 +335,16 @@ fs.realpath@^1.0.0: resolved 
"https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw== +fsevents@~2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" + integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== + +function-bind@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" + integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== + glob@^7.1.3: version "7.2.3" resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" @@ -277,7 +357,7 @@ glob@^7.1.3: once "^1.3.0" path-is-absolute "^1.0.0" -glob@^8.0.0: +glob@^8.0.0, glob@^8.0.3: version "8.0.3" resolved "https://registry.yarnpkg.com/glob/-/glob-8.0.3.tgz#415c6eb2deed9e502c68fa44a272e6da6eeca42e" integrity sha512-ull455NHSHI/Y1FqGaaYFaLGkNMMJbavMrEGFXG/PGrg6y7sutWHUHrz6gy6WEBH6akM1M414dWKCNs+IhKdiQ== @@ -303,6 +383,13 @@ has-flag@^4.0.0: resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== +has@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" + integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== + dependencies: + function-bind "^1.1.1" + inflight@^1.0.4: version "1.0.6" resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" @@ -316,6 +403,32 @@ inherits@2: resolved 
"https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== +is-builtin-module@^3.2.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/is-builtin-module/-/is-builtin-module-3.2.0.tgz#bb0310dfe881f144ca83f30100ceb10cf58835e0" + integrity sha512-phDA4oSGt7vl1n5tJvTWooWWAsXLY+2xCnxNqvKhGEzujg+A43wPlPOyDg3C8XQHN+6k/JTQWJ/j0dQh/qr+Hw== + dependencies: + builtin-modules "^3.3.0" + +is-core-module@^2.9.0: + version "2.11.0" + resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.11.0.tgz#ad4cb3e3863e814523c96f3f58d26cc570ff0144" + integrity sha512-RRjxlvLDkD1YJwDbroBHMb+cukurkDWNyHx7D3oNB5x9rb5ogcksMC5wHCadcXoo67gVr/+3GFySh3134zi6rw== + dependencies: + has "^1.0.3" + +is-module@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-module/-/is-module-1.0.0.tgz#3258fb69f78c14d5b815d664336b4cffb6441591" + integrity sha512-51ypPSPCoTEIN9dy5Oy+h4pShgJmPCygKfyRCISBI+JoWT/2oJvK8QPxmwv7b/p239jXrm9M1mlQbyKJ5A152g== + +is-reference@1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/is-reference/-/is-reference-1.2.1.tgz#8b2dac0b371f4bc994fdeaba9eb542d03002d0b7" + integrity sha512-U82MsXXiFIrjCK4otLT+o2NA2Cd2g5MLoOVXUZjIOhLurrRxpEXzI8O0KZHr3IjLvlAH1kTPYSuqer5T9ZVBKQ== + dependencies: + "@types/estree" "*" + js2xmlparser@^4.0.2: version "4.0.2" resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a" @@ -372,9 +485,9 @@ lodash@^4.17.14, lodash@^4.17.15: integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== long@^5.0.0: - version "5.2.0" - resolved "https://registry.yarnpkg.com/long/-/long-5.2.0.tgz#2696dadf4b4da2ce3f6f6b89186085d94d52fd61" - integrity sha512-9RTUNjK60eJbx3uz+TEGF7fUr29ZDxR5QzXcyDpeSfeH28S9ycINflOgOlppit5U+4kNTe83KQnMEerw7GmE8w== + version "5.2.1" + 
resolved "https://registry.yarnpkg.com/long/-/long-5.2.1.tgz#e27595d0083d103d2fa2c20c7699f8e0c92b897f" + integrity sha512-GKSNGeNAtw8IryjjkhZxuKB3JzlcLTwjtiQCHKvqQet81I93kXslhDQruGI/QsddO83mcDToBVy7GqGS/zYf/A== lru-cache@^6.0.0: version "6.0.0" @@ -383,6 +496,13 @@ lru-cache@^6.0.0: dependencies: yallist "^4.0.0" +magic-string@^0.26.4: + version "0.26.7" + resolved "https://registry.yarnpkg.com/magic-string/-/magic-string-0.26.7.tgz#caf7daf61b34e9982f8228c4527474dac8981d6f" + integrity sha512-hX9XH3ziStPoPhJxLq1syWuZMxbDvGNbVchfrdCtanC7D13888bMFow61x8axrx+GfHLtVeAx2kxL7tTGRl+Ow== + dependencies: + sourcemap-codec "^1.4.8" + markdown-it-anchor@^8.4.1: version "8.6.5" resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8" @@ -400,9 +520,9 @@ markdown-it@^12.3.2: uc.micro "^1.0.5" marked@^4.0.10: - version "4.2.1" - resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.1.tgz#eaa32594e45b4e58c02e4d118531fd04345de3b4" - integrity sha512-VK1/jNtwqDLvPktNpL0Fdg3qoeUZhmRsuiIjPEy/lHwXW4ouLoZfO4XoWd4ClDt+hupV1VLpkZhEovjU0W/kqA== + version "4.2.2" + resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.2.tgz#1d2075ad6cdfe42e651ac221c32d949a26c0672a" + integrity sha512-JjBTFTAvuTgANXx82a5vzK9JLSMoV6V3LBVn4Uhdso6t7vXrGx7g1Cd2r6NYSsxrYbQGFCMqBDhFHyK5q2UvcQ== mdurl@^1.0.1: version "1.0.1" @@ -457,6 +577,16 @@ path-is-absolute@^1.0.0: resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== +path-parse@^1.0.7: + version "1.0.7" + resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" + integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== + +picomatch@^2.3.1: + version "2.3.1" + resolved 
"https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" + integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== + prelude-ls@~1.1.2: version "1.1.2" resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.1.2.tgz#21932a549f5e52ffd9a827f570e04be62a97da54" @@ -503,6 +633,15 @@ requizzle@^0.2.3: dependencies: lodash "^4.17.14" +resolve@^1.22.1: + version "1.22.1" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177" + integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw== + dependencies: + is-core-module "^2.9.0" + path-parse "^1.0.7" + supports-preserve-symlinks-flag "^1.0.0" + rimraf@^3.0.0: version "3.0.2" resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-3.0.2.tgz#f1a5402ba6220ad52cc1282bac1ae3aa49fd061a" @@ -510,6 +649,13 @@ rimraf@^3.0.0: dependencies: glob "^7.1.3" +rollup@^2.3.0: + version "2.79.1" + resolved "https://registry.yarnpkg.com/rollup/-/rollup-2.79.1.tgz#bedee8faef7c9f93a2647ac0108748f497f081c7" + integrity sha512-uKxbd0IhMZOhjAiD5oAFp7BqvkA4Dv47qpOCtaNvng4HBwdbWtdOh8f5nZNuk2rp51PMGk3bzfWu5oayNEuYnw== + optionalDependencies: + fsevents "~2.3.2" + semver@5.6.0: version "5.6.0" resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004" @@ -535,6 +681,11 @@ source-map@^0.6.0, source-map@~0.6.1: resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== +sourcemap-codec@^1.4.8: + version "1.4.8" + resolved "https://registry.yarnpkg.com/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz#ea804bd94857402e6992d05a38ef1ae35a9ab4c4" + integrity sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA== + 
strip-json-comments@^3.1.0: version "3.1.1" resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006" @@ -547,6 +698,11 @@ supports-color@^7.1.0: dependencies: has-flag "^4.0.0" +supports-preserve-symlinks-flag@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" + integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== + taffydb@2.6.2: version "2.6.2" resolved "https://registry.yarnpkg.com/taffydb/-/taffydb-2.6.2.tgz#7cbcb64b5a141b6a2efc2c5d2c67b4e150b2a268" From c31aaa94a6ba862b5148922ebff8ec0f6f127316 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 16:47:48 -0800 Subject: [PATCH 08/10] Adds a `BertClassifier`. PiperOrigin-RevId: 487086744 --- mediapipe/model_maker/python/core/tasks/BUILD | 3 + .../python/core/tasks/classifier.py | 67 +++++++++++++++++-- .../image_classifier/image_classifier.py | 2 +- 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/mediapipe/model_maker/python/core/tasks/BUILD b/mediapipe/model_maker/python/core/tasks/BUILD index 124de621a..8c5448556 100644 --- a/mediapipe/model_maker/python/core/tasks/BUILD +++ b/mediapipe/model_maker/python/core/tasks/BUILD @@ -45,7 +45,10 @@ py_library( srcs = ["classifier.py"], deps = [ ":custom_model", + "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/data:dataset", + "//mediapipe/model_maker/python/core/utils:model_util", ], ) diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 5d0fbd066..200726864 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -13,24 
+13,24 @@ # limitations under the License. """Custom classifier.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import os -from typing import Any, List +from typing import Any, Callable, Optional, Sequence, Union import tensorflow as tf +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.tasks import custom_model +from mediapipe.model_maker.python.core.utils import model_util class Classifier(custom_model.CustomModel): """An abstract base class that represents a TensorFlow classifier.""" - def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool): - """Initilizes a classifier with its specifications. + def __init__(self, model_spec: Any, label_names: Sequence[str], + shuffle: bool): + """Initializes a classifier with its specifications. Args: model_spec: Specification for the model. @@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel): super(Classifier, self).__init__(model_spec, shuffle) self._label_names = label_names self._num_classes = len(label_names) + self._model: tf.keras.Model = None + self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None + self._loss_function: Union[str, tf.keras.losses.Loss] = None + self._metric_function: Union[str, tf.keras.metrics.Metric] = None + self._callbacks: Sequence[tf.keras.callbacks.Callback] = None + self._hparams: hp.BaseHParams = None + self._history: tf.keras.callbacks.History = None + + # TODO: Integrate this into all Model Maker tasks. + def _train_model(self, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + preprocessor: Optional[Callable[..., bool]] = None): + """Trains the classifier model. 
+ + Compiles and fits the tf.keras `_model` and records the `_history`. + + Args: + train_data: Training data. + validation_data: Validation data. + preprocessor: An optional data preprocessor that can be used when + generating a tf.data.Dataset. + """ + tf.compat.v1.logging.info('Training the models...') + if len(train_data) < self._hparams.batch_size: + raise ValueError( + f'The size of the train_data {len(train_data)} can\'t be smaller than' + f' batch_size {self._hparams.batch_size}. To solve this problem, set' + ' the batch_size smaller or increase the size of the train_data.') + + train_dataset = train_data.gen_tf_dataset( + batch_size=self._hparams.batch_size, + is_training=True, + shuffle=self._shuffle, + preprocess=preprocessor) + self._hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=self._hparams.steps_per_epoch, + batch_size=self._hparams.batch_size, + train_data=train_data) + train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch) + validation_dataset = validation_data.gen_tf_dataset( + batch_size=self._hparams.batch_size, + is_training=False, + preprocess=preprocessor) + self._model.compile( + optimizer=self._optimizer, + loss=self._loss_function, + metrics=[self._metric_function]) + self._history = self._model.fit( + x=train_dataset, + epochs=self._hparams.epochs, + validation_data=validation_dataset, + callbacks=self._callbacks) def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: """Evaluates the classifier with the provided evaluation dataset. 
diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index f6edbeab4..1ff6132b4 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -193,7 +193,7 @@ class ImageClassifier(classifier.Classifier): tflite_model, self._model_spec.mean_rgb, self._model_spec.stddev_rgb, - labels=metadata_writer.Labels().add(self._label_names)) + labels=metadata_writer.Labels().add(list(self._label_names))) tflite_model_with_metadata, metadata_json = writer.populate() model_util.save_tflite(tflite_model_with_metadata, tflite_file) with open(metadata_file, 'w') as f: From a5bcb97d888df5ee721b80c37dea601291d888c0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 18:27:53 -0800 Subject: [PATCH 09/10] Adds an `AverageWordVecModel`. PiperOrigin-RevId: 487104909 --- mediapipe/model_maker/python/core/utils/model_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 01e301e43..f10d9390c 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -39,11 +39,10 @@ def get_default_callbacks( """Gets default callbacks.""" summary_dir = os.path.join(export_dir, 'summaries') summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) - # Save checkpoint every 20 epochs. 
checkpoint_path = os.path.join(export_dir, 'checkpoint') checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( - checkpoint_path, save_weights_only=True, period=20) + checkpoint_path, save_weights_only=True) return [summary_callback, checkpoint_callback] From b4e1833dd06f6fe3a4fe2e396df45e1fa0bf902e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 20:34:23 -0800 Subject: [PATCH 10/10] Internal change PiperOrigin-RevId: 487125366 --- mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 39353b226..3d3ca5ae7 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -198,7 +198,7 @@ export class AudioClassifier extends TaskRunner { classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM); classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); classifierNode.addOutputStream( - 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + 'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM); classifierNode.setOptions(calculatorOptions); graphConfig.addNode(classifierNode);