Move TextPreprocessing to "processors" folder.

PiperOrigin-RevId: 490532670
This commit is contained in:
MediaPipe Team 2022-11-23 10:17:46 -08:00 committed by Copybara-Service
parent 54d1744c8f
commit bfa57310c4
11 changed files with 80 additions and 93 deletions

View File

@ -1,43 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(
name = "text_preprocessing_graph",
srcs = ["text_preprocessing_graph.cc"],
hdrs = ["text_preprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:bert_preprocessor_calculator",
"//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:text_to_tensor_calculator",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)

View File

@ -133,3 +133,29 @@ cc_library(
)
# TODO: Enable this test
# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(
name = "text_preprocessing_graph",
srcs = ["text_preprocessing_graph.cc"],
hdrs = ["text_preprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:bert_preprocessor_calculator",
"//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:text_to_tensor_calculator",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)

View File

@ -59,3 +59,12 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks.components.proto;
package mediapipe.tasks.components.processors.proto;
import "mediapipe/framework/calculator.proto";

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include <string>
@ -25,13 +25,14 @@ limitations under the License.
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
@ -41,7 +42,8 @@ using ::mediapipe::api2::SideInput;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SideSource;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions;
using ::mediapipe::tasks::components::processors::proto::
TextPreprocessingGraphOptions;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
@ -169,7 +171,7 @@ absl::StatusOr<int> GetMaxSeqLen(const tflite::SubGraph& model_graph) {
}
} // namespace
absl::Status ConfigureTextPreprocessingSubgraph(
absl::Status ConfigureTextPreprocessingGraph(
const ModelResources& model_resources,
TextPreprocessingGraphOptions& options) {
if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) {
@ -200,8 +202,7 @@ absl::Status ConfigureTextPreprocessingSubgraph(
return absl::OkStatus();
}
// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text
// preprocessing.
// A TextPreprocessingGraph performs text preprocessing.
// - Accepts a std::string input and outputs CPU tensors.
//
// Inputs:
@ -216,9 +217,9 @@ absl::Status ConfigureTextPreprocessingSubgraph(
// Vector containing the preprocessed input tensors for the TFLite model.
//
// The recommended way of using this subgraph is through the GraphBuilder API
// using the 'ConfigureTextPreprocessing()' function. See header file for more
// details.
class TextPreprocessingSubgraph : public mediapipe::Subgraph {
// using the 'ConfigureTextPreprocessingGraph()' function. See header file for
// more details.
class TextPreprocessingGraph : public mediapipe::Subgraph {
public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override {
@ -267,8 +268,9 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::TextPreprocessingSubgraph);
::mediapipe::tasks::components::processors::TextPreprocessingGraph);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -13,26 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
// Configures a TextPreprocessing subgraph using the provided `model_resources`
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
// Configures a TextPreprocessingGraph using the provided `model_resources`
// and TextPreprocessingGraphOptions.
// - Accepts a std::string input and outputs CPU tensors.
//
// Example usage:
//
// auto& preprocessing =
// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
// graph.AddNode("mediapipe.tasks.components.processors.TextPreprocessingSubgraph");
// MP_RETURN_IF_ERROR(ConfigureTextPreprocessingSubgraph(
// model_resources,
// &preprocessing.GetOptions<TextPreprocessingGraphOptions>()));
//
// The resulting TextPreprocessing subgraph has the following I/O:
// The resulting TextPreprocessingGraph has the following I/O:
// Inputs:
// TEXT - std::string
// The text to preprocess.
@ -43,16 +48,13 @@ limitations under the License.
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the preprocessed input tensors for the TFLite model.
namespace mediapipe {
namespace tasks {
namespace components {
absl::Status ConfigureTextPreprocessingSubgraph(
const tasks::core::ModelResources& model_resources,
tasks::components::proto::TextPreprocessingGraphOptions& options);
absl::Status ConfigureTextPreprocessingGraph(
const core::ModelResources& model_resources,
proto::TextPreprocessingGraphOptions& options);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_

View File

@ -22,12 +22,3 @@ mediapipe_proto_library(
name = "segmenter_options_proto",
srcs = ["segmenter_options.proto"],
)
mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)

View File

@ -52,11 +52,11 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_resources_calculator",
"//mediapipe/tasks/cc/core:model_task_graph",

View File

@ -25,8 +25,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
@ -115,12 +115,12 @@ class TextClassifierGraph : public core::ModelTaskGraph {
Graph& graph) {
// Adds preprocessing calculators and connects them to the text input
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.TextPreprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph(
model_resources,
preprocessing.GetOptions<
tasks::components::proto::TextPreprocessingGraphOptions>()));
components::processors::proto::TextPreprocessingGraphOptions>()));
text_in >> preprocessing.In(kTextTag);
// Adds both InferenceCalculator and ModelResourcesCalculator.

View File

@ -54,11 +54,11 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
@ -107,12 +107,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
Graph& graph) {
// Adds preprocessing calculators and connects them to the text input
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.TextPreprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph(
model_resources,
preprocessing.GetOptions<
tasks::components::proto::TextPreprocessingGraphOptions>()));
components::processors::proto::TextPreprocessingGraphOptions>()));
text_in >> preprocessing.In(kTextTag);
// Adds both InferenceCalculator and ModelResourcesCalculator.