Open-source the TextClassifier C++ API.
PiperOrigin-RevId: 482218721
parent 7a6ae97a0e
commit 70df9e2419
mediapipe/tasks/cc/text/text_classifier/BUILD (new file, +84 lines)

@@ -0,0 +1,84 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

cc_library(
    name = "text_classifier_graph",
    srcs = ["text_classifier_graph.cc"],
    deps = [
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/tasks/cc/components:text_preprocessing_graph",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_resources_calculator",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

cc_library(
    name = "text_classifier",
    srcs = ["text_classifier.cc"],
    hdrs = ["text_classifier.h"],
    deps = [
        ":text_classifier_graph",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:task_api_factory",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)

cc_library(
    name = "text_classifier_test_utils",
    srcs = ["text_classifier_test_utils.cc"],
    hdrs = ["text_classifier_test_utils.h"],
    visibility = ["//visibility:private"],
    deps = [
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc:common",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@org_tensorflow//tensorflow/lite:mutable_op_resolver",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite:type_to_tflitetype",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
)
mediapipe/tasks/cc/text/text_classifier/proto/BUILD (new file, +30 lines)

@@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_proto_library(
    name = "text_classifier_graph_options_proto",
    srcs = ["text_classifier_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)
mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto (new file, +35 lines)

@@ -0,0 +1,35 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.text.text_classifier.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

message TextClassifierGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional TextClassifierGraphOptions ext = 462704549;
  }
  // Base options for configuring MediaPipe Tasks, such as specifying the
  // TfLite model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Options for configuring the classifier behavior, such as score threshold,
  // number of results, etc.
  optional components.processors.proto.ClassifierOptions classifier_options = 2;
}
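As a side note, these options reach the graph through the proto2 extension on mediapipe.CalculatorOptions declared above. A minimal sketch (not part of this commit) of reading that extension from an already-populated CalculatorOptions proto; the helper name is hypothetical:

#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"

// Extracts the TextClassifierGraphOptions carried inside a node's
// CalculatorOptions, which is how the options travel in a CalculatorGraphConfig.
mediapipe::tasks::text::text_classifier::proto::TextClassifierGraphOptions
GetTextClassifierGraphOptions(const mediapipe::CalculatorOptions& node_options) {
  return node_options.GetExtension(
      mediapipe::tasks::text::text_classifier::proto::TextClassifierGraphOptions::ext);
}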
mediapipe/tasks/cc/text/text_classifier/text_classifier.cc (new file, +104 lines)

@@ -0,0 +1,104 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {

namespace {

using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

constexpr char kTextStreamName[] = "text_in";
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";

// Creates a MediaPipe graph config that only contains a single subgraph node of
// type "TextClassifierGraph".
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<proto::TextClassifierGraphOptions> options) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kSubgraphTypeName);
  subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
  graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
  subgraph.Out(kClassificationResultTag)
          .SetName(kClassificationResultStreamName) >>
      graph.Out(kClassificationResultTag);
  return graph.GetConfig();
}

// Converts the user-facing TextClassifierOptions struct to the internal
// TextClassifierGraphOptions proto.
std::unique_ptr<proto::TextClassifierGraphOptions>
ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) {
  auto options_proto = std::make_unique<proto::TextClassifierGraphOptions>();
  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
  options_proto->mutable_base_options()->Swap(base_options_proto.get());
  auto classifier_options_proto =
      std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
          components::processors::ConvertClassifierOptionsToProto(
              &(options->classifier_options)));
  options_proto->mutable_classifier_options()->Swap(
      classifier_options_proto.get());
  return options_proto;
}

}  // namespace

absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
    std::unique_ptr<TextClassifierOptions> options) {
  auto options_proto = ConvertTextClassifierOptionsToProto(options.get());
  return core::TaskApiFactory::Create<TextClassifier,
                                      proto::TextClassifierGraphOptions>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver));
}

absl::StatusOr<ClassificationResult> TextClassifier::Classify(
    absl::string_view text) {
  ASSIGN_OR_RETURN(
      auto output_packets,
      runner_->Process(
          {{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
  return output_packets[kClassificationResultStreamName]
      .Get<ClassificationResult>();
}

}  // namespace text_classifier
}  // namespace text
}  // namespace tasks
}  // namespace mediapipe
mediapipe/tasks/cc/text/text_classifier/text_classifier.h (new file, +96 lines)

@@ -0,0 +1,96 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_

#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {

// The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // Options for configuring the classifier behavior, such as score threshold,
  // number of results, etc.
  components::processors::ClassifierOptions classifier_options;
};

// Performs classification on text.
//
// This API expects a TFLite model with (optional) TFLite Model Metadata that
// contains the mandatory (described below) input tensors, output tensor,
// and the optional (but recommended) label items as AssociatedFiles with type
// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for
// models with int32 input tensors because it contains the input process unit
// for the model's Tokenizer. No metadata is required for models with string
// input tensors.
//
// Input tensors:
//   (kTfLiteInt32)
//    - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing
//      the input ids, segment ids, and mask ids
//    - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
//      input ids
//   or (kTfLiteString)
//    - 1 input tensor that is shapeless or has shape [1] containing the input
//      string
// At least one output tensor with:
//   (kTfLiteFloat32/kBool)
//    - `[1 x N]` array, where `N` is the number of categories.
//    - optional (but recommended) label items as AssociatedFiles with type
//      TENSOR_AXIS_LABELS, containing one label per line. The first such
//      AssociatedFile (if any) is used to fill the `category_name` field of the
//      results. The `display_name` field is filled from the AssociatedFile (if
//      any) whose locale matches the `display_names_locale` field of the
//      `TextClassifierOptions` used at creation time ("en" by default, i.e.
//      English). If none of these are available, only the `index` field of the
//      results will be filled.
class TextClassifier : core::BaseTaskApi {
 public:
  using BaseTaskApi::BaseTaskApi;

  // Creates a TextClassifier from the provided `options`.
  static absl::StatusOr<std::unique_ptr<TextClassifier>> Create(
      std::unique_ptr<TextClassifierOptions> options);

  // Performs classification on the input `text`.
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
      absl::string_view text);

  // Shuts down the TextClassifier when all the work is done.
  absl::Status Close() { return runner_->Close(); }
};

}  // namespace text_classifier
}  // namespace text
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
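A minimal end-to-end usage sketch of the API above (not part of this commit): the model path is a placeholder, and the ClassificationResult field accessors follow the result layout exercised by the tests later in this commit.

#include <iostream>
#include <memory>
#include <utility>

#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

int main() {
  using ::mediapipe::tasks::text::text_classifier::TextClassifier;
  using ::mediapipe::tasks::text::text_classifier::TextClassifierOptions;

  // Configure the task with a local model file (placeholder path).
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path =
      "/path/to/bert_text_classifier.tflite";

  // Create() returns an absl::StatusOr wrapping the task object.
  auto classifier_or = TextClassifier::Create(std::move(options));
  if (!classifier_or.ok()) {
    std::cerr << classifier_or.status().message() << std::endl;
    return 1;
  }
  std::unique_ptr<TextClassifier> classifier = std::move(*classifier_or);

  // Classify a sentence and print one category per line.
  auto result_or =
      classifier->Classify("it's a charming and often affecting journey");
  if (result_or.ok()) {
    for (const auto& category :
         result_or->classifications(0).entries(0).categories()) {
      std::cout << category.category_name() << ": " << category.score()
                << std::endl;
    }
  }
  classifier->Close().IgnoreError();
  return 0;
}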
mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc (new file, +162 lines)

@@ -0,0 +1,162 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <string>
#include <type_traits>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {

namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;

constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";

}  // namespace

// A "TextClassifierGraph" performs Natural Language classification (including
// BERT-based text classification).
// - Accepts input text and outputs classification results on CPU.
//
// Inputs:
//   TEXT - std::string
//     Input text to perform classification on.
//
// Outputs:
//   CLASSIFICATION_RESULT - ClassificationResult
//     The aggregated classification result object that has 3 dimensions:
//     (classification head, classification timestamp, classification category).
//
// Example:
// node {
//   calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
//   input_stream: "TEXT:text_in"
//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
//   options {
//     [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
//     {
//       base_options {
//         model_asset {
//           file_name: "/path/to/model.tflite"
//         }
//       }
//     }
//   }
// }
class TextClassifierGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    ASSIGN_OR_RETURN(
        const ModelResources* model_resources,
        CreateModelResources<proto::TextClassifierGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        Source<ClassificationResult> classification_result_out,
        BuildTextClassifierTask(
            sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
            graph[Input<std::string>(kTextTag)], graph));
    classification_result_out >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    return graph.GetConfig();
  }

 private:
  // Adds a mediapipe TextClassifier task graph into the provided
  // builder::Graph instance. The TextClassifier task takes an input
  // text (std::string) and returns one classification result per output head
  // specified by the model.
  //
  // task_options: the mediapipe tasks TextClassifierGraphOptions proto.
  // model_resources: the ModelResources object initialized from a
  //   TextClassifier model file with model metadata.
  // text_in: (std::string) stream to run text classification on.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
      const proto::TextClassifierGraphOptions& task_options,
      const ModelResources& model_resources, Source<std::string> text_in,
      Graph& graph) {
    // Adds preprocessing calculators and connects them to the text input
    // stream.
    auto& preprocessing =
        graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
    MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
        model_resources,
        preprocessing.GetOptions<
            tasks::components::proto::TextPreprocessingGraphOptions>()));
    text_in >> preprocessing.In(kTextTag);

    // Adds both InferenceCalculator and ModelResourcesCalculator.
    auto& inference = AddInference(
        model_resources, task_options.base_options().acceleration(), graph);
    // The metadata extractor side-output comes from the
    // ModelResourcesCalculator.
    inference.SideOut(kMetadataExtractorTag) >>
        preprocessing.SideIn(kMetadataExtractorTag);
    preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);

    // Adds postprocessing calculators and connects them to the graph output.
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "ClassificationPostprocessingGraph");
    MP_RETURN_IF_ERROR(
        components::processors::ConfigureClassificationPostprocessingGraph(
            model_resources, task_options.classifier_options(),
            &postprocessing
                 .GetOptions<components::processors::proto::
                                 ClassificationPostprocessingGraphOptions>()));
    inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

    // Outputs the aggregated classification result as the subgraph output
    // stream.
    return postprocessing[Output<ClassificationResult>(
        kClassificationResultTag)];
  }
};

REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::text::text_classifier::TextClassifierGraph);

}  // namespace text_classifier
}  // namespace text
}  // namespace tasks
}  // namespace mediapipe
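Because TextClassifierGraph is registered with REGISTER_MEDIAPIPE_GRAPH, it can also be driven directly from a CalculatorGraph using a config like the one in the class comment above. A rough sketch (not part of this commit): the model path is a placeholder and error handling is minimal.

#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

absl::Status RunTextClassifierGraph() {
  using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

  // Single-node graph wrapping TextClassifierGraph, mirroring the example
  // config in the class comment.
  mediapipe::CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        input_stream: "text_in"
        output_stream: "classification_result_out"
        node {
          calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
          input_stream: "TEXT:text_in"
          output_stream: "CLASSIFICATION_RESULT:classification_result_out"
          options {
            [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext] {
              base_options {
                model_asset { file_name: "/path/to/model.tflite" }
              }
            }
          }
        }
      )pb");

  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller,
                   graph.AddOutputStreamPoller("classification_result_out"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text_in", mediapipe::MakePacket<std::string>("hello world")
                     .At(mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.CloseInputStream("text_in"));

  mediapipe::Packet packet;
  while (poller.Next(&packet)) {
    const auto& result = packet.Get<ClassificationResult>();
    (void)result;  // Inspect result.classifications() here.
  }
  return graph.WaitUntilDone();
}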
mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc (new file, +238 lines)

@@ -0,0 +1,238 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

#include <cmath>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {

using ::mediapipe::EqualsProto;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::proto::Approximately;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
using ::testing::proto::Partially;

constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
constexpr char kTestRegexModelPath[] =
    "test_model_text_classifier_with_regex_tokenizer.tflite";
constexpr char kStringToBoolModelPath[] =
    "test_model_text_classifier_bool_output.tflite";

std::string GetFullPath(absl::string_view file_name) {
  return JoinPath("./", kTestDataDirectory, file_name);
}

class TextClassifierTest : public tflite_shims::testing::Test {};

TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}

TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) {
  auto options = std::make_unique<TextClassifierOptions>();
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(options));

  EXPECT_EQ(classifier.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      classifier.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}

TEST_F(TextClassifierTest, CreateFailsWithMissingModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kInvalidModelPath);
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(options));

  EXPECT_EQ(classifier.status().code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(classifier.status().message(),
              HasSubstr("Unable to open file at"));
  EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}

TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}

TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult negative_result,
      classifier->Classify("unflinchingly bleak and desperate"));
  ASSERT_THAT(negative_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.956 }
                        categories { category_name: "positive" score: 0.044 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult positive_result,
      classifier->Classify("it's a charming and often affecting journey"));
  ASSERT_THAT(positive_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.0 }
                        categories { category_name: "positive" score: 1.0 }
                      }
                    }
                  )pb"),
                  kEpsilon))));
  MP_ASSERT_OK(classifier->Close());
}

TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
                          classifier->Classify("What a waste of my time."));
  ASSERT_THAT(negative_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "Negative" score: 0.813 }
                        categories { category_name: "Positive" score: 0.187 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult positive_result,
      classifier->Classify("This is the best movie I’ve seen in recent years. "
                           "Strongly recommend it!"));
  ASSERT_THAT(positive_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "Negative" score: 0.487 }
                        categories { category_name: "Positive" score: 0.513 }
                      }
                    }
                  )pb"),
                  kEpsilon))));
  MP_ASSERT_OK(classifier->Close());
}

TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
  options->base_options.op_resolver = CreateCustomResolver();
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
                          classifier->Classify("hello"));
  ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
                classifications {
                  entries {
                    categories { index: 1 score: 1 }
                    categories { index: 0 score: 1 }
                    categories { index: 2 score: 0 }
                  }
                }
              )pb"))));
}

TEST_F(TextClassifierTest, BertLongPositive) {
  std::stringstream ss_for_positive_review;
  ss_for_positive_review
      << "it's a charming and often affecting journey and this is a long";
  for (int i = 0; i < kMaxSeqLen; ++i) {
    ss_for_positive_review << " long";
  }
  ss_for_positive_review << " movie review";
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
                          classifier->Classify(ss_for_positive_review.str()));
  ASSERT_THAT(result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.014 }
                        categories { category_name: "positive" score: 0.986 }
                      }
                    }
                  )pb"),
                  kEpsilon))));
  MP_ASSERT_OK(classifier->Close());
}

}  // namespace
}  // namespace text_classifier
}  // namespace text
}  // namespace tasks
}  // namespace mediapipe
mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc (new file, +131 lines)

@@ -0,0 +1,131 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"

#include <cstring>
#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/string_util.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace {

using ::mediapipe::tasks::CreateStatusWithPayload;
using ::tflite::GetInput;
using ::tflite::GetOutput;
using ::tflite::GetString;
using ::tflite::StringRef;

constexpr absl::string_view kInputStr = "hello";
constexpr bool kBooleanData[] = {true, true, false};
constexpr size_t kBooleanDataSize = std::size(kBooleanData);

// Checks and returns type of a tensor, fails if tensor type is not T.
template <typename T>
absl::StatusOr<T*> AssertAndReturnTypedTensor(const TfLiteTensor* tensor) {
  if (!tensor->data.raw) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("Tensor (%s) has no raw data.", tensor->name));
  }

  // Checks if data type of tensor is T and returns the pointer casted to T if
  // applicable, returns nullptr if tensor type is not T.
  // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType.
  if (tensor->type == tflite::typeToTfLiteType<T>()) {
    return reinterpret_cast<T*>(tensor->data.raw);
  }
  return CreateStatusWithPayload(
      absl::StatusCode::kInternal,
      absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.",
                      tensor->name, tflite::typeToTfLiteType<T>(),
                      tensor->bytes));
}

// Populates tensor with array of data, fails if data type doesn't match tensor
// type or they don't have the same number of elements.
template <typename T, typename = std::enable_if_t<
                          std::negation_v<std::is_same<T, std::string>>>>
absl::Status PopulateTensor(const T* data, int num_elements,
                            TfLiteTensor* tensor) {
  ASSIGN_OR_RETURN(T * v, AssertAndReturnTypedTensor<T>(tensor));
  size_t bytes = num_elements * sizeof(T);
  if (tensor->bytes != bytes) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes,
                        bytes));
  }
  std::memcpy(v, data, bytes);
  return absl::OkStatus();
}

TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);
  TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
  dims->data[0] = kBooleanDataSize;
  return context->ResizeTensor(context, output, dims);
}

TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input_tensor = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, input_tensor != nullptr);
  StringRef input_str_ref = GetString(input_tensor, 0);
  std::string input_str(input_str_ref.str, input_str_ref.len);
  if (input_str != kInputStr) {
    return kTfLiteError;
  }
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE(context, PopulateTensor(kBooleanData, 3, output).ok());
  return kTfLiteOk;
}

// This custom op takes a string tensor in and outputs a bool tensor with
// value {true, true, false}; it's used to mimic a real text classification
// model which classifies a string into scores of different categories.
TfLiteRegistration* RegisterStringToBool() {
  // Dummy implementation of custom OP.
  // This op takes string as input and outputs bool[].
  static TfLiteRegistration r = {/* init= */ nullptr, /* free= */ nullptr,
                                 /* prepare= */ PrepareStringToBool,
                                 /* invoke= */ InvokeStringToBool};
  return &r;
}
}  // namespace

std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool());
  return std::make_unique<tflite::MutableOpResolver>(resolver);
}

}  // namespace text
}  // namespace tasks
}  // namespace mediapipe
mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h (new file, +35 lines)

@@ -0,0 +1,35 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_

#include <memory>

#include "tensorflow/lite/mutable_op_resolver.h"

namespace mediapipe {
namespace tasks {
namespace text {

// Creates a custom MutableOpResolver to provide custom OP implementations to
// mimic classification behavior.
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver();

}  // namespace text
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_