Move TextPreprocessing to "processors" folder.

PiperOrigin-RevId: 490532670
This commit is contained in:
MediaPipe Team 2022-11-23 10:17:46 -08:00 committed by Copybara-Service
parent 54d1744c8f
commit bfa57310c4
11 changed files with 80 additions and 93 deletions

View File

@ -1,43 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(
name = "text_preprocessing_graph",
srcs = ["text_preprocessing_graph.cc"],
hdrs = ["text_preprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:bert_preprocessor_calculator",
"//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:text_to_tensor_calculator",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)

View File

@ -133,3 +133,29 @@ cc_library(
) )
# TODO: Enable this test # TODO: Enable this test
# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(
name = "text_preprocessing_graph",
srcs = ["text_preprocessing_graph.cc"],
hdrs = ["text_preprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:bert_preprocessor_calculator",
"//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:text_to_tensor_calculator",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)

View File

@ -59,3 +59,12 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
], ],
) )
mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include <string> #include <string>
@ -25,13 +25,14 @@ limitations under the License.
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/subgraph.h" #include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
@ -41,7 +42,8 @@ using ::mediapipe::api2::SideInput;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::SideSource;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions; using ::mediapipe::tasks::components::processors::proto::
TextPreprocessingGraphOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
@ -169,7 +171,7 @@ absl::StatusOr<int> GetMaxSeqLen(const tflite::SubGraph& model_graph) {
} }
} // namespace } // namespace
absl::Status ConfigureTextPreprocessingSubgraph( absl::Status ConfigureTextPreprocessingGraph(
const ModelResources& model_resources, const ModelResources& model_resources,
TextPreprocessingGraphOptions& options) { TextPreprocessingGraphOptions& options) {
if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) { if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) {
@ -200,8 +202,7 @@ absl::Status ConfigureTextPreprocessingSubgraph(
return absl::OkStatus(); return absl::OkStatus();
} }
// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text // A TextPreprocessingGraph performs text preprocessing.
// preprocessing.
// - Accepts a std::string input and outputs CPU tensors. // - Accepts a std::string input and outputs CPU tensors.
// //
// Inputs: // Inputs:
@ -216,9 +217,9 @@ absl::Status ConfigureTextPreprocessingSubgraph(
// Vector containing the preprocessed input tensors for the TFLite model. // Vector containing the preprocessed input tensors for the TFLite model.
// //
// The recommended way of using this subgraph is through the GraphBuilder API // The recommended way of using this subgraph is through the GraphBuilder API
// using the 'ConfigureTextPreprocessing()' function. See header file for more // using the 'ConfigureTextPreprocessingGraph()' function. See header file for
// details. // more details.
class TextPreprocessingSubgraph : public mediapipe::Subgraph { class TextPreprocessingGraph : public mediapipe::Subgraph {
public: public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig( absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override { mediapipe::SubgraphContext* sc) override {
@ -267,8 +268,9 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
} }
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::TextPreprocessingSubgraph); ::mediapipe::tasks::components::processors::TextPreprocessingGraph);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,26 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
// Configures a TextPreprocessing subgraph using the provided `model_resources` namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
// Configures a TextPreprocessingGraph using the provided `model_resources`
// and TextPreprocessingGraphOptions. // and TextPreprocessingGraphOptions.
// - Accepts a std::string input and outputs CPU tensors. // - Accepts a std::string input and outputs CPU tensors.
// //
// Example usage: // Example usage:
// //
// auto& preprocessing = // auto& preprocessing =
// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); // graph.AddNode("mediapipe.tasks.components.processors.TextPreprocessingGraph");
// MP_RETURN_IF_ERROR(ConfigureTextPreprocessingGraph(
// model_resources, // model_resources,
// &preprocessing.GetOptions<TextPreprocessingGraphOptions>())); // &preprocessing.GetOptions<TextPreprocessingGraphOptions>()));
// //
// The resulting TextPreprocessing subgraph has the following I/O: // The resulting TextPreprocessingGraph has the following I/O:
// Inputs: // Inputs:
// TEXT - std::string // TEXT - std::string
// The text to preprocess. // The text to preprocess.
@ -43,16 +48,13 @@ limitations under the License.
// Outputs: // Outputs:
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// Vector containing the preprocessed input tensors for the TFLite model. // Vector containing the preprocessed input tensors for the TFLite model.
namespace mediapipe { absl::Status ConfigureTextPreprocessingGraph(
namespace tasks { const core::ModelResources& model_resources,
namespace components { proto::TextPreprocessingGraphOptions& options);
absl::Status ConfigureTextPreprocessingSubgraph(
const tasks::core::ModelResources& model_resources,
tasks::components::proto::TextPreprocessingGraphOptions& options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_

View File

@ -22,12 +22,3 @@ mediapipe_proto_library(
name = "segmenter_options_proto", name = "segmenter_options_proto",
srcs = ["segmenter_options.proto"], srcs = ["segmenter_options.proto"],
) )
mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)

View File

@ -52,11 +52,11 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_resources_calculator", "//mediapipe/tasks/cc/core:model_resources_calculator",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",

View File

@ -25,8 +25,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
@ -115,12 +115,12 @@ class TextClassifierGraph : public core::ModelTaskGraph {
Graph& graph) { Graph& graph) {
// Adds preprocessing calculators and connects them to the text input // Adds preprocessing calculators and connects them to the text input
// stream. // stream.
auto& preprocessing = auto& preprocessing = graph.AddNode(
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); "mediapipe.tasks.components.processors.TextPreprocessingGraph");
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph(
model_resources, model_resources,
preprocessing.GetOptions< preprocessing.GetOptions<
tasks::components::proto::TextPreprocessingGraphOptions>())); components::processors::proto::TextPreprocessingGraphOptions>()));
text_in >> preprocessing.In(kTextTag); text_in >> preprocessing.In(kTextTag);
// Adds both InferenceCalculator and ModelResourcesCalculator. // Adds both InferenceCalculator and ModelResourcesCalculator.

View File

@ -54,11 +54,11 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
@ -107,12 +107,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
Graph& graph) { Graph& graph) {
// Adds preprocessing calculators and connects them to the text input // Adds preprocessing calculators and connects them to the text input
// stream. // stream.
auto& preprocessing = auto& preprocessing = graph.AddNode(
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); "mediapipe.tasks.components.processors.TextPreprocessingGraph");
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph(
model_resources, model_resources,
preprocessing.GetOptions< preprocessing.GetOptions<
tasks::components::proto::TextPreprocessingGraphOptions>())); components::processors::proto::TextPreprocessingGraphOptions>()));
text_in >> preprocessing.In(kTextTag); text_in >> preprocessing.In(kTextTag);
// Adds both InferenceCalculator and ModelResourcesCalculator. // Adds both InferenceCalculator and ModelResourcesCalculator.