Open-source the TextClassifier C++ API.

PiperOrigin-RevId: 482218721
This commit is contained in:
MediaPipe Team 2022-10-19 09:01:56 -07:00 committed by Copybara-Service
parent 7a6ae97a0e
commit 70df9e2419
9 changed files with 915 additions and 0 deletions

View File

@ -0,0 +1,84 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Implementation of the TextClassifier MediaPipe task subgraph.
# alwayslink = 1 keeps the graph's static registration
# (REGISTER_MEDIAPIPE_GRAPH) from being dropped by the linker when no symbol
# in this library is referenced directly.
cc_library(
    name = "text_classifier_graph",
    srcs = ["text_classifier_graph.cc"],
    deps = [
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/tasks/cc/components:text_preprocessing_graph",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_resources_calculator",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)
# Public C++ API for the text classifier task. Depends on the graph target so
# that creating a TextClassifier automatically links in its subgraph.
cc_library(
    name = "text_classifier",
    srcs = ["text_classifier.cc"],
    hdrs = ["text_classifier.h"],
    deps = [
        ":text_classifier_graph",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:task_api_factory",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)
# Test-only helpers providing a custom op resolver for string-input /
# bool-output models. Visibility is private to this package.
cc_library(
    name = "text_classifier_test_utils",
    srcs = ["text_classifier_test_utils.cc"],
    hdrs = ["text_classifier_test_utils.h"],
    visibility = ["//visibility:private"],
    deps = [
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc:common",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@org_tensorflow//tensorflow/lite:mutable_op_resolver",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite:type_to_tflitetype",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
)

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Proto options for configuring the TextClassifierGraph subgraph.
mediapipe_proto_library(
    name = "text_classifier_graph_options_proto",
    srcs = ["text_classifier_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)

View File

@ -0,0 +1,35 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.text.text_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
// Options for configuring the TextClassifierGraph task subgraph.
message TextClassifierGraphOptions {
  extend mediapipe.CalculatorOptions {
    // Unique extension field number registered for these options.
    optional TextClassifierGraphOptions ext = 462704549;
  }

  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
  // model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Options for configuring the classifier behavior, such as score threshold,
  // number of results, etc.
  optional components.processors.proto.ClassifierOptions classifier_options = 2;
}

View File

@ -0,0 +1,104 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

// Stream names and tags used to wire the single-node wrapper graph built in
// CreateGraphConfig below.
constexpr char kTextStreamName[] = "text_in";
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
// Registered type name of the task subgraph instantiated by the wrapper graph.
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
// Creates a MediaPipe graph config that only contains a single subgraph node of
// type "TextClassifierGraph".
// Builds a CalculatorGraphConfig that holds a single node of type
// "TextClassifierGraph", exposing its TEXT input and CLASSIFICATION_RESULT
// output as top-level graph streams.
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<proto::TextClassifierGraphOptions> options) {
  api2::builder::Graph graph;
  auto& classifier_node = graph.AddNode(kSubgraphTypeName);
  // Hand the task options over to the subgraph node.
  classifier_node.GetOptions<proto::TextClassifierGraphOptions>().Swap(
      options.get());
  // Expose the subgraph's classification output on the top-level graph.
  classifier_node.Out(kClassificationResultTag)
          .SetName(kClassificationResultStreamName) >>
      graph.Out(kClassificationResultTag);
  // Feed the top-level text input stream into the subgraph.
  graph.In(kTextTag).SetName(kTextStreamName) >> classifier_node.In(kTextTag);
  return graph.GetConfig();
}
// Converts the user-facing TextClassifierOptions struct to the internal
// TextClassifierGraphOptions proto.
// Converts the user-facing TextClassifierOptions struct to the internal
// TextClassifierGraphOptions proto.
std::unique_ptr<proto::TextClassifierGraphOptions>
ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) {
  auto options_proto = std::make_unique<proto::TextClassifierGraphOptions>();
  // Translate the base options (model asset, acceleration, ...) and move the
  // result straight into the proto.
  *options_proto->mutable_base_options() =
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options));
  // Same for the classifier options (score threshold, max results, ...).
  *options_proto->mutable_classifier_options() =
      components::processors::ConvertClassifierOptionsToProto(
          &(options->classifier_options));
  return options_proto;
}
} // namespace
// Creates a TextClassifier by translating the user options into graph options,
// wrapping them in a one-node graph config, and handing both the config and
// the (optional) custom op resolver to the task API factory.
absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
    std::unique_ptr<TextClassifierOptions> options) {
  std::unique_ptr<proto::TextClassifierGraphOptions> options_proto =
      ConvertTextClassifierOptionsToProto(options.get());
  CalculatorGraphConfig graph_config =
      CreateGraphConfig(std::move(options_proto));
  return core::TaskApiFactory::Create<TextClassifier,
                                      proto::TextClassifierGraphOptions>(
      std::move(graph_config), std::move(options->base_options.op_resolver));
}
// Runs a single synchronous classification: packs `text` into the input
// stream, invokes the underlying task runner, and unpacks the aggregated
// result from the output stream.
absl::StatusOr<ClassificationResult> TextClassifier::Classify(
    absl::string_view text) {
  Packet input_packet = MakePacket<std::string>(std::string(text));
  ASSIGN_OR_RETURN(auto output_packets,
                   runner_->Process({{kTextStreamName, input_packet}}));
  return output_packets[kClassificationResultStreamName]
      .Get<ClassificationResult>();
}
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,96 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
// The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // Options for configuring the classifier behavior, such as score threshold,
  // number of results, etc.
  components::processors::ClassifierOptions classifier_options;
};
// Performs classification on text.
//
// This API expects a TFLite model with (optional) TFLite Model Metadata that
// contains the mandatory (described below) input tensors, output tensor,
// and the optional (but recommended) label items as AssociatedFiles with type
// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for
// models with int32 input tensors because it contains the input process unit
// for the model's Tokenizer. No metadata is required for models with string
// input tensors.
//
// Input tensors:
// (kTfLiteInt32)
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing
// the input ids, segment ids, and mask ids
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
// input ids
// or (kTfLiteString)
// - 1 input tensor that is shapeless or has shape [1] containing the input
// string
// At least one output tensor with:
// (kTfLiteFloat32/kBool)
// - `[1 x N]` array with `N` represents the number of categories.
// - optional (but recommended) label items as AssociatedFiles with type
// TENSOR_AXIS_LABELS, containing one label per line. The first such
// AssociatedFile (if any) is used to fill the `category_name` field of the
// results. The `display_name` field is filled from the AssociatedFile (if
// any) whose locale matches the `display_names_locale` field of the
// `TextClassifierOptions` used at creation time ("en" by default, i.e.
// English). If none of these are available, only the `index` field of the
// results will be filled.
class TextClassifier : core::BaseTaskApi {
 public:
  // Inherits the constructors of BaseTaskApi.
  using BaseTaskApi::BaseTaskApi;

  // Creates a TextClassifier from the provided `options`. Fails if the model
  // asset is missing or cannot be loaded.
  static absl::StatusOr<std::unique_ptr<TextClassifier>> Create(
      std::unique_ptr<TextClassifierOptions> options);

  // Performs classification on the input `text` and returns the aggregated
  // result (one entry per classification head of the model).
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
      absl::string_view text);

  // Shuts down the TextClassifier when all the work is done.
  absl::Status Close() { return runner_->Close(); }
};
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_

View File

@ -0,0 +1,162 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <string>
#include <type_traits>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;

// Stream and side-packet tags used to wire the subgraph internals.
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";
} // namespace
// A "TextClassifierGraph" performs Natural Language classification (including
// BERT-based text classification).
// - Accepts input text and outputs classification results on CPU.
//
// Inputs:
// TEXT - std::string
// Input text to perform classification on.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The aggregated classification result object that has 3 dimensions:
// (classification head, classification timestamp, classification category).
//
// Example:
// node {
// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
// input_stream: "TEXT:text_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// options {
// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "/path/to/model.tflite"
// }
// }
// }
// }
// }
// Subgraph implementation for the text classifier task; see the file-level
// comment above for the I/O contract and an example node config.
class TextClassifierGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    // Loads the model (and its metadata) referenced by the graph options.
    ASSIGN_OR_RETURN(
        const ModelResources* model_resources,
        CreateModelResources<proto::TextClassifierGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        Source<ClassificationResult> classification_result_out,
        BuildTextClassifierTask(
            sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
            graph[Input<std::string>(kTextTag)], graph));
    // Expose the task's result stream as the subgraph output.
    classification_result_out >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    return graph.GetConfig();
  }

 private:
  // Adds a mediapipe TextClassifier task graph into the provided
  // builder::Graph instance. The TextClassifier task takes an input
  // text (std::string) and returns one classification result per output head
  // specified by the model.
  //
  // task_options: the mediapipe tasks TextClassifierGraphOptions proto.
  // model_resources: the ModelResources object initialized from a
  //   TextClassifier model file with model metadata.
  // text_in: (std::string) stream to run text classification on.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
      const proto::TextClassifierGraphOptions& task_options,
      const ModelResources& model_resources, Source<std::string> text_in,
      Graph& graph) {
    // Adds preprocessing calculators and connects them to the text input
    // stream.
    auto& preprocessing =
        graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
    MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
        model_resources,
        preprocessing.GetOptions<
            tasks::components::proto::TextPreprocessingGraphOptions>()));
    text_in >> preprocessing.In(kTextTag);

    // Adds both InferenceCalculator and ModelResourcesCalculator.
    auto& inference = AddInference(
        model_resources, task_options.base_options().acceleration(), graph);
    // The metadata extractor side-output comes from the
    // ModelResourcesCalculator.
    inference.SideOut(kMetadataExtractorTag) >>
        preprocessing.SideIn(kMetadataExtractorTag);
    preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);

    // Adds postprocessing calculators and connects them to the graph output.
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "ClassificationPostprocessingGraph");
    MP_RETURN_IF_ERROR(
        components::processors::ConfigureClassificationPostprocessingGraph(
            model_resources, task_options.classifier_options(),
            &postprocessing
                 .GetOptions<components::processors::proto::
                                 ClassificationPostprocessingGraphOptions>()));
    inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

    // Outputs the aggregated classification result as the subgraph output
    // stream.
    return postprocessing[Output<ClassificationResult>(
        kClassificationResultTag)];
  }
};
// Statically registers the subgraph so graph configs can instantiate it by
// its fully qualified type name.
REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::text::text_classifier::TextClassifierGraph);
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,238 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
#include <cmath>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {
using ::mediapipe::EqualsProto;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::proto::Approximately;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
using ::testing::proto::Partially;
// Tolerance used for approximate floating-point score comparisons.
constexpr float kEpsilon = 0.001;
// Presumed max sequence length of the BERT test model; used by
// BertLongPositive to build an over-length input — TODO confirm against model.
constexpr int kMaxSeqLen = 128;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
// Deliberately nonexistent path for exercising model-loading failure.
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
constexpr char kTestRegexModelPath[] =
    "test_model_text_classifier_with_regex_tokenizer.tflite";
constexpr char kStringToBoolModelPath[] =
    "test_model_text_classifier_bool_output.tflite";
// Resolves a bare test file name to its full path under the test data
// directory.
std::string GetFullPath(absl::string_view file_name) {
  std::string full_path = JoinPath("./", kTestDataDirectory, file_name);
  return full_path;
}

// Shared fixture for all TextClassifier tests.
class TextClassifierTest : public tflite_shims::testing::Test {};
// Smoke test: creating a classifier from the BERT test model succeeds.
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}
// Creating without any model file specified must fail with kInvalidArgument
// and carry the runner-initialization error payload.
TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) {
  auto options = std::make_unique<TextClassifierOptions>();
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(options));

  EXPECT_EQ(classifier.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      classifier.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
// A model path that does not exist must fail with kNotFound and carry the
// runner-initialization error payload.
TEST_F(TextClassifierTest, CreateFailsWithMissingModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kInvalidModelPath);
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(options));

  EXPECT_EQ(classifier.status().code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(classifier.status().message(),
              HasSubstr("Unable to open file at"));
  EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
// Smoke test: creating a classifier from the regex-tokenizer model succeeds.
TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}
// End-to-end check with the BERT model: a clearly negative and a clearly
// positive review should yield the expected per-category scores (within
// kEpsilon, ignoring category order).
TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult negative_result,
      classifier->Classify("unflinchingly bleak and desperate"));
  ASSERT_THAT(negative_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.956 }
                        categories { category_name: "positive" score: 0.044 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult positive_result,
      classifier->Classify("it's a charming and often affecting journey"));
  ASSERT_THAT(positive_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.0 }
                        categories { category_name: "positive" score: 1.0 }
                      }
                    }
                  )pb"),
                  kEpsilon))));
  MP_ASSERT_OK(classifier->Close());
}
// End-to-end check with the regex-tokenizer model (int input tensors):
// verifies the expected Negative/Positive scores for two reviews.
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
                          classifier->Classify("What a waste of my time."));
  ASSERT_THAT(negative_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "Negative" score: 0.813 }
                        categories { category_name: "Positive" score: 0.187 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult positive_result,
      classifier->Classify("This is the best movie Ive seen in recent years. "
                           "Strongly recommend it!"));
  ASSERT_THAT(positive_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "Negative" score: 0.487 }
                        categories { category_name: "Positive" score: 0.513 }
                      }
                    }
                  )pb"),
                  kEpsilon))));
  MP_ASSERT_OK(classifier->Close());
}
// Exercises a string-input / bool-output model via the custom
// CUSTOM_OP_STRING_TO_BOOLS resolver from the test utils.
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
  options->base_options.op_resolver = CreateCustomResolver();
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
                          classifier->Classify("hello"));
  ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
                classifications {
                  entries {
                    categories { index: 1 score: 1 }
                    categories { index: 0 score: 1 }
                    categories { index: 2 score: 0 }
                  }
                }
              )pb"))));
  // Close explicitly, consistent with the other classification tests, so any
  // shutdown error is surfaced rather than silently dropped.
  MP_ASSERT_OK(classifier->Close());
}
// Feeds the BERT model a review padded well past kMaxSeqLen words —
// presumably exercising input truncation — and checks it still classifies as
// positive.
TEST_F(TextClassifierTest, BertLongPositive) {
  std::stringstream ss_for_positive_review;
  ss_for_positive_review
      << "it's a charming and often affecting journey and this is a long";
  // Pad the review beyond the model's max sequence length.
  for (int i = 0; i < kMaxSeqLen; ++i) {
    ss_for_positive_review << " long";
  }
  ss_for_positive_review << " movie review";
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
                          classifier->Classify(ss_for_positive_review.str()));
  ASSERT_THAT(result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.014 }
                        categories { category_name: "positive" score: 0.986 }
                      }
                    }
                  )pb"),
                  kEpsilon))));
  MP_ASSERT_OK(classifier->Close());
}
} // namespace
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,131 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include <cstring>
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace {
using ::mediapipe::tasks::CreateStatusWithPayload;
using ::tflite::GetInput;
using ::tflite::GetOutput;
using ::tflite::GetString;
using ::tflite::StringRef;
// The only input string the custom op accepts (anything else is an error).
constexpr absl::string_view kInputStr = "hello";
// Fixed output the custom op writes for the accepted input.
constexpr bool kBooleanData[] = {true, true, false};
constexpr size_t kBooleanDataSize = std::size(kBooleanData);
// Checks and returns type of a tensor, fails if tensor type is not T.
template <typename T>
absl::StatusOr<T*> AssertAndReturnTypedTensor(const TfLiteTensor* tensor) {
  if (!tensor->data.raw) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("Tensor (%s) has no raw data.", tensor->name));
  }
  // Checks if data type of tensor is T and returns the pointer casted to T if
  // applicable, returns nullptr if tensor type is not T.
  // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType.
  if (tensor->type == tflite::typeToTfLiteType<T>()) {
    return reinterpret_cast<T*>(tensor->data.raw);
  }
  return CreateStatusWithPayload(
      absl::StatusCode::kInternal,
      absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.",
                      tensor->name, tflite::typeToTfLiteType<T>(),
                      // Report the tensor's actual type; the original passed
                      // tensor->bytes here, which is a size, not a TfLiteType.
                      tensor->type));
}
// Populates tensor with array of data, fails if data type doesn't match tensor
// type or they don't have the same number of elements. Disabled for
// std::string, which needs TFLite's dynamic string packing instead of memcpy.
template <typename T, typename = std::enable_if_t<
                          std::negation_v<std::is_same<T, std::string>>>>
absl::Status PopulateTensor(const T* data, int num_elements,
                            TfLiteTensor* tensor) {
  // Type check first so a mismatched tensor reports the type error, not the
  // size error.
  ASSIGN_OR_RETURN(T * typed_buffer, AssertAndReturnTypedTensor<T>(tensor));
  const size_t expected_bytes = num_elements * sizeof(T);
  if (tensor->bytes != expected_bytes) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes,
                        expected_bytes));
  }
  std::memcpy(typed_buffer, data, expected_bytes);
  return absl::OkStatus();
}
// Prepare step for the custom op: resizes the single output tensor to a 1-D
// array big enough to hold the canned boolean outputs.
TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output_tensor = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output_tensor != nullptr);
  TfLiteIntArray* output_dims = TfLiteIntArrayCreate(1);
  output_dims->data[0] = kBooleanDataSize;
  return context->ResizeTensor(context, output_tensor, output_dims);
}
// Invoke step for the custom op: verifies the input string equals kInputStr
// and writes the canned boolean outputs.
TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input_tensor = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, input_tensor != nullptr);
  StringRef input_str_ref = GetString(input_tensor, 0);
  std::string input_str(input_str_ref.str, input_str_ref.len);
  // The op only supports the fixed test input; anything else is an error.
  if (input_str != kInputStr) {
    return kTfLiteError;
  }
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);
  // Use kBooleanDataSize instead of a magic `3` so the element count cannot
  // drift out of sync with kBooleanData.
  TF_LITE_ENSURE(context,
                 PopulateTensor(kBooleanData, kBooleanDataSize, output).ok());
  return kTfLiteOk;
}
// Returns the registration for the test-only custom op. The op takes a string
// tensor in and outputs a bool tensor with value {true, true, false}; it's
// used to mimic a real text classification model which classifies a string
// into scores of different categories.
TfLiteRegistration* RegisterStringToBool() {
  // Dummy implementation of custom OP: string in, bool[] out. init/free are
  // not needed since the op holds no per-instance state.
  static TfLiteRegistration registration = {/* init= */ nullptr,
                                            /* free= */ nullptr,
                                            /* prepare= */ PrepareStringToBool,
                                            /* invoke= */ InvokeStringToBool};
  return &registration;
}
} // namespace
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver() {
tflite::MutableOpResolver resolver;
resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool());
return std::make_unique<tflite::MutableOpResolver>(resolver);
}
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,35 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
#include <memory>
#include "tensorflow/lite/mutable_op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace text {
// Creates a custom MutableOpResolver that provides custom op implementations
// to mimic classification behavior in tests.
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver();
} // namespace text
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_