From 70df9e24194ba534c2301664f181935f4f4f59e9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 19 Oct 2022 09:01:56 -0700 Subject: [PATCH] Open-source the TextClassifier C++ API. PiperOrigin-RevId: 482218721 --- mediapipe/tasks/cc/text/text_classifier/BUILD | 84 +++++++ .../tasks/cc/text/text_classifier/proto/BUILD | 30 +++ .../proto/text_classifier_graph_options.proto | 35 +++ .../text/text_classifier/text_classifier.cc | 104 ++++++++ .../cc/text/text_classifier/text_classifier.h | 96 +++++++ .../text_classifier/text_classifier_graph.cc | 162 ++++++++++++ .../text_classifier/text_classifier_test.cc | 238 ++++++++++++++++++ .../text_classifier_test_utils.cc | 131 ++++++++++ .../text_classifier_test_utils.h | 35 +++ 9 files changed, 915 insertions(+) create mode 100644 mediapipe/tasks/cc/text/text_classifier/BUILD create mode 100644 mediapipe/tasks/cc/text/text_classifier/proto/BUILD create mode 100644 mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto create mode 100644 mediapipe/tasks/cc/text/text_classifier/text_classifier.cc create mode 100644 mediapipe/tasks/cc/text/text_classifier/text_classifier.h create mode 100644 mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc create mode 100644 mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc create mode 100644 mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc create mode 100644 mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD new file mode 100644 index 000000000..a85538631 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -0,0 +1,84 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_classifier_graph", + srcs = ["text_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components:text_preprocessing_graph", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_calculator", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_library( + name = "text_classifier", + srcs = ["text_classifier.cc"], + hdrs = ["text_classifier.h"], + deps = [ + ":text_classifier_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + +cc_library( + name = "text_classifier_test_utils", + srcs = ["text_classifier_test_utils.cc"], + hdrs = ["text_classifier_test_utils.h"], + visibility = ["//visibility:private"], + deps = [ + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite:mutable_op_resolver", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], +) diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/BUILD b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD new file mode 100644 index 000000000..f2b544d87 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "text_classifier_graph_options_proto", + srcs = ["text_classifier_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto new file mode 100644 index 000000000..58d5fa9f8 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -0,0 +1,35 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.text.text_classifier.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +message TextClassifierGraphOptions { + extend mediapipe.CalculatorOptions { + optional TextClassifierGraphOptions ext = 462704549; + } + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + optional components.processors.proto.ClassifierOptions classifier_options = 2; +} diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc new file mode 100644 index 000000000..699f15bc0 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc @@ -0,0 +1,104 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/task_api_factory.h" +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" +#include "tensorflow/lite/core/api/op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +namespace { + +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; + +constexpr char kTextStreamName[] = "text_in"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kClassificationResultStreamName[] = "classification_result_out"; +constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// type "TextClassifierGraph". +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kSubgraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); + subgraph.Out(kClassificationResultTag) + .SetName(kClassificationResultStreamName) >> + graph.Out(kClassificationResultTag); + return graph.GetConfig(); +} + +// Converts the user-facing TextClassifierOptions struct to the internal +// TextClassifierGraphOptions proto. 
+std::unique_ptr +ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + auto classifier_options_proto = + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( + &(options->classifier_options))); + options_proto->mutable_classifier_options()->Swap( + classifier_options_proto.get()); + return options_proto; +} + +} // namespace + +absl::StatusOr> TextClassifier::Create( + std::unique_ptr options) { + auto options_proto = ConvertTextClassifierOptionsToProto(options.get()); + return core::TaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver)); +} + +absl::StatusOr TextClassifier::Classify( + absl::string_view text) { + ASSIGN_OR_RETURN( + auto output_packets, + runner_->Process( + {{kTextStreamName, MakePacket(std::string(text))}})); + return output_packets[kClassificationResultStreamName] + .Get(); +} + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h new file mode 100644 index 000000000..b027a9787 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h @@ -0,0 +1,96 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +// The options for configuring a MediaPipe text classifier task. +struct TextClassifierOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + components::processors::ClassifierOptions classifier_options; +}; + +// Performs classification on text. 
+// +// This API expects a TFLite model with (optional) TFLite Model Metadata that +// contains the mandatory (described below) input tensors, output tensor, +// and the optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for +// models with int32 input tensors because it contains the input process unit +// for the model's Tokenizer. No metadata is required for models with string +// input tensors. +// +// Input tensors: +// (kTfLiteInt32) +// - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing +// the input ids, segment ids, and mask ids +// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the +// input ids +// or (kTfLiteString) +// - 1 input tensor that is shapeless or has shape [1] containing the input +// string +// At least one output tensor with: +// (kTfLiteFloat32/kBool) +// - `[1 x N]` array with `N` represents the number of categories. +// - optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS, containing one label per line. The first such +// AssociatedFile (if any) is used to fill the `category_name` field of the +// results. The `display_name` field is filled from the AssociatedFile (if +// any) whose locale matches the `display_names_locale` field of the +// `TextClassifierOptions` used at creation time ("en" by default, i.e. +// English). If none of these are available, only the `index` field of the +// results will be filled. +class TextClassifier : core::BaseTaskApi { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a TextClassifier from the provided `options`. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs classification on the input `text`. + absl::StatusOr Classify( + absl::string_view text); + + // Shuts down the TextClassifier when all the work is done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc new file mode 100644 index 000000000..9706db4d8 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -0,0 +1,162 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::core::ModelResources; + +constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; +constexpr char kTensorsTag[] = "TENSORS"; + +} // namespace + +// A "TextClassifierGraph" performs Natural Language classification (including +// BERT-based text classification). +// - Accepts input text and outputs classification results on CPU. +// +// Inputs: +// TEXT - std::string +// Input text to perform classification on. +// +// Outputs: +// CLASSIFICATION_RESULT - ClassificationResult +// The aggregated classification result object that has 3 dimensions: +// (classification head, classification timestamp, classification category). +// +// Example: +// node { +// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph" +// input_stream: "TEXT:text_in" +// output_stream: "CLASSIFICATION_RESULT:classification_result_out" +// options { +// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// } +// } +// } +class TextClassifierGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN( + const ModelResources* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN( + Source classification_result_out, + BuildTextClassifierTask( + sc->Options(), *model_resources, + graph[Input(kTextTag)], graph)); + classification_result_out >> + graph[Output(kClassificationResultTag)]; + return graph.GetConfig(); + } + + private: + // Adds a mediapipe TextClassifier task graph into the provided + // builder::Graph instance. The TextClassifier task takes an input + // text (std::string) and returns one classification result per output head + // specified by the model. + // + // task_options: the mediapipe tasks TextClassifierGraphOptions proto. + // model_resources: the ModelResources object initialized from a + // TextClassifier model file with model metadata. 
+ // text_in: (std::string) stream to run text classification on. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr> BuildTextClassifierTask( + const proto::TextClassifierGraphOptions& task_options, + const ModelResources& model_resources, Source text_in, + Graph& graph) { + // Adds preprocessing calculators and connects them to the text input + // stream. + auto& preprocessing = + graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); + MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + model_resources, + preprocessing.GetOptions< + tasks::components::proto::TextPreprocessingGraphOptions>())); + text_in >> preprocessing.In(kTextTag); + + // Adds both InferenceCalculator and ModelResourcesCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + // The metadata extractor side-output comes from the + // ModelResourcesCalculator. + inference.SideOut(kMetadataExtractorTag) >> + preprocessing.SideIn(kMetadataExtractorTag); + preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); + + // Adds postprocessing calculators and connects them to the graph output. + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); + inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); + + // Outputs the aggregated classification result as the subgraph output + // stream. + return postprocessing[Output( + kClassificationResultTag)]; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::text::text_classifier::TextClassifierGraph); + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc new file mode 100644 index 000000000..5b33f6606 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { +namespace { + +using ::mediapipe::EqualsProto; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::kMediaPipeTasksPayload; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::proto::Approximately; +using ::testing::proto::IgnoringRepeatedFieldOrdering; +using ::testing::proto::Partially; + +constexpr float kEpsilon = 0.001; +constexpr int kMaxSeqLen = 128; +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; +constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; +constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; +constexpr char kTestRegexModelPath[] = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +constexpr char kStringToBoolModelPath[] = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +class TextClassifierTest : public tflite_shims::testing::Test {}; + +TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK(TextClassifier::Create(std::move(options))); +} + +TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) { + auto options = std::make_unique(); + StatusOr> classifier = + TextClassifier::Create(std::move(options)); + + EXPECT_EQ(classifier.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + classifier.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(TextClassifierTest, CreateFailsWithMissingModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kInvalidModelPath); + StatusOr> classifier = + TextClassifier::Create(std::move(options)); + + EXPECT_EQ(classifier.status().code(), absl::StatusCode::kNotFound); + EXPECT_THAT(classifier.status().message(), + HasSubstr("Unable to open file at")); + EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { + auto options = std::make_unique(); + 
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); + MP_ASSERT_OK(TextClassifier::Create(std::move(options))); +} + +TEST_F(TextClassifierTest, TextClassifierWithBert) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + ClassificationResult negative_result, + classifier->Classify("unflinchingly bleak and desperate")); + ASSERT_THAT(negative_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "negative" score: 0.956 } + categories { category_name: "positive" score: 0.044 } + } + } + )pb"), + kEpsilon)))); + + MP_ASSERT_OK_AND_ASSIGN( + ClassificationResult positive_result, + classifier->Classify("it's a charming and often affecting journey")); + ASSERT_THAT(positive_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "negative" score: 0.0 } + categories { category_name: "positive" score: 1.0 } + } + } + )pb"), + kEpsilon)))); + MP_ASSERT_OK(classifier->Close()); +} + +TEST_F(TextClassifierTest, TextClassifierWithIntInputs) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result, + classifier->Classify("What a waste of my time.")); + ASSERT_THAT(negative_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "Negative" score: 0.813 } + categories { category_name: "Positive" score: 0.187 } + } + } + )pb"), + kEpsilon)))); + + MP_ASSERT_OK_AND_ASSIGN( + ClassificationResult positive_result, + classifier->Classify("This is the best movie I’ve seen in recent years. 
" + "Strongly recommend it!")); + ASSERT_THAT(positive_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "Negative" score: 0.487 } + categories { category_name: "Positive" score: 0.513 } + } + } + )pb"), + kEpsilon)))); + MP_ASSERT_OK(classifier->Close()); +} + +TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath); + options->base_options.op_resolver = CreateCustomResolver(); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, + classifier->Classify("hello")); + ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb( + classifications { + entries { + categories { index: 1 score: 1 } + categories { index: 0 score: 1 } + categories { index: 2 score: 0 } + } + } + )pb")))); +} + +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, + classifier->Classify(ss_for_positive_review.str())); + ASSERT_THAT(result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "negative" score: 0.014 } + categories { category_name: "positive" score: 0.986 } + } + } + )pb"), + kEpsilon)))); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc new file mode 100644 index 000000000..d12370372 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc @@ -0,0 +1,131 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" +#include "tensorflow/lite/string_util.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace { + +using ::mediapipe::tasks::CreateStatusWithPayload; +using ::tflite::GetInput; +using ::tflite::GetOutput; +using ::tflite::GetString; +using ::tflite::StringRef; + +constexpr absl::string_view kInputStr = "hello"; +constexpr bool kBooleanData[] = {true, true, false}; +constexpr size_t kBooleanDataSize = std::size(kBooleanData); + +// Checks and returns type of a tensor, fails if tensor type is not T. +template +absl::StatusOr AssertAndReturnTypedTensor(const TfLiteTensor* tensor) { + if (!tensor->data.raw) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Tensor (%s) has no raw data.", tensor->name)); + } + + // Checks if data type of tensor is T and returns the pointer casted to T if + // applicable, returns nullptr if tensor type is not T. + // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType. + if (tensor->type == tflite::typeToTfLiteType()) { + return reinterpret_cast(tensor->data.raw); + } + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.", + tensor->name, tflite::typeToTfLiteType(), + tensor->bytes)); +} + +// Populates tensor with array of data, fails if data type doesn't match tensor +// type or they don't have the same number of elements. 
+template >>> +absl::Status PopulateTensor(const T* data, int num_elements, + TfLiteTensor* tensor) { + ASSIGN_OR_RETURN(T * v, AssertAndReturnTypedTensor(tensor)); + size_t bytes = num_elements * sizeof(T); + if (tensor->bytes != bytes) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes, + bytes)); + } + std::memcpy(v, data, bytes); + return absl::OkStatus(); +} + +TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteIntArray* dims = TfLiteIntArrayCreate(1); + dims->data[0] = kBooleanDataSize; + return context->ResizeTensor(context, output, dims); +} + +TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, 0); + TF_LITE_ENSURE(context, input_tensor != nullptr); + StringRef input_str_ref = GetString(input_tensor, 0); + std::string input_str(input_str_ref.str, input_str_ref.len); + if (input_str != kInputStr) { + return kTfLiteError; + } + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TF_LITE_ENSURE(context, PopulateTensor(kBooleanData, 3, output).ok()); + return kTfLiteOk; +} + +// This custom op takes a string tensor in and outputs a bool tensor with +// value{true, true, false}, it's used to mimic a real text classification model +// which classifies a string into scores of different categories. +TfLiteRegistration* RegisterStringToBool() { + // Dummy implementation of custom OP + // This op takes string as input and outputs bool[] + static TfLiteRegistration r = {/* init= */ nullptr, /* free= */ nullptr, + /* prepare= */ PrepareStringToBool, + /* invoke= */ InvokeStringToBool}; + return &r; +} +} // namespace + +std::unique_ptr CreateCustomResolver() { + tflite::MutableOpResolver resolver; + resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool()); + return std::make_unique(resolver); +} + +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h new file mode 100644 index 000000000..a427b561c --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ + +#include + +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace text { + +// Create a custom MutableOpResolver to provide custom OP implementations to +// mimic classification behavior. 
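+// The resolver registers a single custom op, "CUSTOM_OP_STRING_TO_BOOLS",
+// which accepts the input string "hello" and writes the boolean tensor
+// {true, true, false}; see text_classifier_test_utils.cc above.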
+std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver();
+
+}  // namespace text
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
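For readers integrating the new API, the following is a minimal usage sketch based on the declarations in text_classifier.h from this patch. The model path is a placeholder and the error handling is abbreviated; it is an illustration of the intended call sequence (Create, Classify, Close), not part of the change itself.

#include <iostream>
#include <memory>
#include <utility>

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

int main() {
  using ::mediapipe::tasks::text::text_classifier::TextClassifier;
  using ::mediapipe::tasks::text::text_classifier::TextClassifierOptions;

  auto options = std::make_unique<TextClassifierOptions>();
  // Placeholder path; any model satisfying the tensor/metadata requirements
  // documented in text_classifier.h (e.g. the BERT or regex test models)
  // can be used here.
  options->base_options.model_asset_path = "/path/to/text_classifier.tflite";

  // Create() builds and initializes the underlying TextClassifierGraph and
  // returns an error status if the options or model are invalid.
  auto classifier = TextClassifier::Create(std::move(options));
  if (!classifier.ok()) {
    std::cerr << classifier.status().message() << "\n";
    return 1;
  }

  // Classify() runs synchronously and yields a ClassificationResult proto
  // with one classification entry per output head of the model.
  auto result = (*classifier)->Classify("unflinchingly bleak and desperate");
  if (result.ok() && result->classifications_size() > 0 &&
      result->classifications(0).entries_size() > 0) {
    for (const auto& category :
         result->classifications(0).entries(0).categories()) {
      // category_name is only populated when the model carries label metadata.
      std::cout << category.category_name() << ": " << category.score() << "\n";
    }
  }

  // Close() shuts down the task once all the work is done.
  return (*classifier)->Close().ok() ? 0 : 1;
}

In a Bazel workspace, such a binary would typically depend on the ":text_classifier" target added by this patch.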