Open-source the TextClassifier C++ API.
PiperOrigin-RevId: 482218721
This commit is contained in:
parent
7a6ae97a0e
commit
70df9e2419
84
mediapipe/tasks/cc/text/text_classifier/BUILD
Normal file
84
mediapipe/tasks/cc/text/text_classifier/BUILD
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Graph implementation of the text classifier task, built from
# text_classifier_graph.cc.
cc_library(
    name = "text_classifier_graph",
    srcs = ["text_classifier_graph.cc"],
    deps = [
        # MediaPipe framework and calculators.
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        # MediaPipe Tasks components, core utilities and protos.
        "//mediapipe/tasks/cc/components:text_preprocessing_graph",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_resources_calculator",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        # Abseil.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    # The source registers its graph through REGISTER_MEDIAPIPE_GRAPH, so the
    # object code has to be linked in even when no symbol is referenced directly.
    alwayslink = 1,
)
|
||||||
|
|
||||||
|
# User-facing C++ API of the text classifier task (text_classifier.h/.cc).
cc_library(
    name = "text_classifier",
    srcs = ["text_classifier.cc"],
    hdrs = ["text_classifier.h"],
    deps = [
        # The graph that this API instantiates by its registered type name.
        ":text_classifier_graph",
        # MediaPipe framework.
        "//mediapipe/framework:packet",
        "//mediapipe/framework/api2:builder",
        # MediaPipe Tasks components, core utilities and protos.
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:task_api_factory",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        # Abseil and TensorFlow Lite.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)
|
||||||
|
|
||||||
|
# Helpers used by the text classifier tests
# (text_classifier_test_utils.h/.cc).
cc_library(
    name = "text_classifier_test_utils",
    srcs = ["text_classifier_test_utils.cc"],
    hdrs = ["text_classifier_test_utils.h"],
    # Overrides the package default: only targets in this package may depend
    # on the test helpers.
    visibility = ["//visibility:private"],
    deps = [
        # MediaPipe.
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc:common",
        # Abseil.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        # TensorFlow Lite.
        "@org_tensorflow//tensorflow/lite:mutable_op_resolver",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite:type_to_tflitetype",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
)
|
30
mediapipe/tasks/cc/text/text_classifier/proto/BUILD
Normal file
30
mediapipe/tasks/cc/text/text_classifier/proto/BUILD
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Options proto of the text classifier graph
# (text_classifier_graph_options.proto).
mediapipe_proto_library(
    name = "text_classifier_graph_options_proto",
    srcs = ["text_classifier_graph_options.proto"],
    deps = [
        # Framework protos; the options message extends CalculatorOptions.
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        # Messages embedded as fields of the options message.
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe.tasks.text.text_classifier.proto;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||||
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
|
// Options consumed by the text classifier graph.
message TextClassifierGraphOptions {
  // Common MediaPipe Tasks configuration: the TfLite model file (with
  // metadata), accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Classifier behavior: score threshold, number of results, etc.
  optional components.processors.proto.ClassifierOptions classifier_options = 2;

  // Lets this message be carried as an extension of
  // mediapipe.CalculatorOptions.
  extend mediapipe.CalculatorOptions {
    optional TextClassifierGraphOptions ext = 462704549;
  }
}
|
104
mediapipe/tasks/cc/text/text_classifier/text_classifier.cc
Normal file
104
mediapipe/tasks/cc/text/text_classifier/text_classifier.cc
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
|
||||||
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace text {
|
||||||
|
namespace text_classifier {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
|
||||||
|
// Tags of the subgraph node's input and output ports.
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";

// Names given to the streams attached to those ports.
constexpr char kTextStreamName[] = "text_in";
constexpr char kClassificationResultStreamName[] = "classification_result_out";

// Type name under which the text classifier graph is added as a node.
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||||
|
|
||||||
|
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||||
|
// type "TextClassifierGraph".
|
||||||
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
|
std::unique_ptr<proto::TextClassifierGraphOptions> options) {
|
||||||
|
api2::builder::Graph graph;
|
||||||
|
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
|
subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
|
||||||
|
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
|
||||||
|
subgraph.Out(kClassificationResultTag)
|
||||||
|
.SetName(kClassificationResultStreamName) >>
|
||||||
|
graph.Out(kClassificationResultTag);
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the user-facing TextClassifierOptions struct to the internal
|
||||||
|
// TextClassifierGraphOptions proto.
|
||||||
|
std::unique_ptr<proto::TextClassifierGraphOptions>
|
||||||
|
ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) {
|
||||||
|
auto options_proto = std::make_unique<proto::TextClassifierGraphOptions>();
|
||||||
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
|
auto classifier_options_proto =
|
||||||
|
std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
|
||||||
|
components::processors::ConvertClassifierOptionsToProto(
|
||||||
|
&(options->classifier_options)));
|
||||||
|
options_proto->mutable_classifier_options()->Swap(
|
||||||
|
classifier_options_proto.get());
|
||||||
|
return options_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Creates a TextClassifier from the provided `options`.
//
// Converts `options` to the graph options proto, builds the single-node task
// graph config and hands both the config and the op resolver stored in
// `options->base_options` to the task API factory.
//
// Returns an InvalidArgument error when `options` is null; otherwise returns
// whatever the task API factory returns (the created classifier or an error).
absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
    std::unique_ptr<TextClassifierOptions> options) {
  // Reject a null options object up front instead of dereferencing it below.
  if (options == nullptr) {
    return absl::InvalidArgumentError(
        "TextClassifierOptions must not be null.");
  }
  auto options_proto = ConvertTextClassifierOptionsToProto(options.get());
  return core::TaskApiFactory::Create<TextClassifier,
                                      proto::TextClassifierGraphOptions>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver));
}
|
||||||
|
|
||||||
|
// Runs the task graph on `text` and returns the classification result read
// from the "classification_result_out" stream, or the error reported by the
// task runner.
absl::StatusOr<ClassificationResult> TextClassifier::Classify(
    absl::string_view text) {
  // The input stream carries std::string packets, so the view is copied into
  // an owned string.
  Packet text_packet = MakePacket<std::string>(std::string(text));
  ASSIGN_OR_RETURN(auto result_packets,
                   runner_->Process({{kTextStreamName, text_packet}}));

  const Packet& classification_packet =
      result_packets[kClassificationResultStreamName];
  return classification_packet.Get<ClassificationResult>();
}
|
||||||
|
|
||||||
|
} // namespace text_classifier
|
||||||
|
} // namespace text
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
96
mediapipe/tasks/cc/text/text_classifier/text_classifier.h
Normal file
96
mediapipe/tasks/cc/text/text_classifier/text_classifier.h
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace text {
|
||||||
|
namespace text_classifier {
|
||||||
|
|
||||||
|
// The options for configuring a MediaPipe text classifier task.
|
||||||
|
struct TextClassifierOptions {
|
||||||
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
|
// file with metadata, accelerator options, op resolver, etc.
|
||||||
|
tasks::core::BaseOptions base_options;
|
||||||
|
|
||||||
|
// Options for configuring the classifier behavior, such as score threshold,
|
||||||
|
// number of results, etc.
|
||||||
|
components::processors::ClassifierOptions classifier_options;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Performs classification on text.
|
||||||
|
//
|
||||||
|
// This API expects a TFLite model with (optional) TFLite Model Metadata that
|
||||||
|
// contains the mandatory (described below) input tensors, output tensor,
|
||||||
|
// and the optional (but recommended) label items as AssociatedFiles with type
|
||||||
|
// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for
|
||||||
|
// models with int32 input tensors because it contains the input process unit
|
||||||
|
// for the model's Tokenizer. No metadata is required for models with string
|
||||||
|
// input tensors.
|
||||||
|
//
|
||||||
|
// Input tensors:
|
||||||
|
// (kTfLiteInt32)
|
||||||
|
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing
|
||||||
|
// the input ids, segment ids, and mask ids
|
||||||
|
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
|
||||||
|
// input ids
|
||||||
|
// or (kTfLiteString)
|
||||||
|
// - 1 input tensor that is shapeless or has shape [1] containing the input
|
||||||
|
// string
|
||||||
|
// At least one output tensor with:
|
||||||
|
// (kTfLiteFloat32/kTfLiteBool)
|
||||||
|
// - `[1 x N]` array where `N` is the number of categories.
|
||||||
|
// - optional (but recommended) label items as AssociatedFiles with type
|
||||||
|
// TENSOR_AXIS_LABELS, containing one label per line. The first such
|
||||||
|
// AssociatedFile (if any) is used to fill the `category_name` field of the
|
||||||
|
// results. The `display_name` field is filled from the AssociatedFile (if
|
||||||
|
// any) whose locale matches the `display_names_locale` field of the
|
||||||
|
// `TextClassifierOptions` used at creation time ("en" by default, i.e.
|
||||||
|
// English). If none of these are available, only the `index` field of the
|
||||||
|
// results will be filled.
|
||||||
|
class TextClassifier : core::BaseTaskApi {
|
||||||
|
public:
|
||||||
|
using BaseTaskApi::BaseTaskApi;
|
||||||
|
|
||||||
|
// Creates a TextClassifier from the provided `options`.
|
||||||
|
static absl::StatusOr<std::unique_ptr<TextClassifier>> Create(
|
||||||
|
std::unique_ptr<TextClassifierOptions> options);
|
||||||
|
|
||||||
|
// Performs classification on the input `text`.
|
||||||
|
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||||
|
absl::string_view text);
|
||||||
|
|
||||||
|
// Shuts down the TextClassifier when all the work is done.
|
||||||
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace text_classifier
|
||||||
|
} // namespace text
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
|
162
mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc
Normal file
162
mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc
Normal file
|
@ -0,0 +1,162 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace text {
|
||||||
|
namespace text_classifier {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
|
|
||||||
|
// Tags of the graph's own input and output streams.
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";

// Tags used to wire the nodes inside the graph together.
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// A "TextClassifierGraph" performs Natural Language classification (including
|
||||||
|
// BERT-based text classification).
|
||||||
|
// - Accepts input text and outputs classification results on CPU.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// TEXT - std::string
|
||||||
|
// Input text to perform classification on.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// CLASSIFICATION_RESULT - ClassificationResult
|
||||||
|
// The aggregated classification result object that has 3 dimensions:
|
||||||
|
// (classification head, classification timestamp, classification category).
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
|
||||||
|
// input_stream: "TEXT:text_in"
|
||||||
|
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
|
||||||
|
// {
|
||||||
|
// base_options {
|
||||||
|
// model_asset {
|
||||||
|
// file_name: "/path/to/model.tflite"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
class TextClassifierGraph : public core::ModelTaskGraph {
|
||||||
|
public:
|
||||||
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
|
SubgraphContext* sc) override {
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
const ModelResources* model_resources,
|
||||||
|
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
|
||||||
|
Graph graph;
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
Source<ClassificationResult> classification_result_out,
|
||||||
|
BuildTextClassifierTask(
|
||||||
|
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
|
||||||
|
graph[Input<std::string>(kTextTag)], graph));
|
||||||
|
classification_result_out >>
|
||||||
|
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Adds a mediapipe TextClassifier task graph into the provided
|
||||||
|
// builder::Graph instance. The TextClassifier task takes an input
|
||||||
|
// text (std::string) and returns one classification result per output head
|
||||||
|
// specified by the model.
|
||||||
|
//
|
||||||
|
// task_options: the mediapipe tasks TextClassifierGraphOptions proto.
|
||||||
|
// model_resources: the ModelResources object initialized from a
|
||||||
|
// TextClassifier model file with model metadata.
|
||||||
|
// text_in: (std::string) stream to run text classification on.
|
||||||
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
|
absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
|
||||||
|
const proto::TextClassifierGraphOptions& task_options,
|
||||||
|
const ModelResources& model_resources, Source<std::string> text_in,
|
||||||
|
Graph& graph) {
|
||||||
|
// Adds preprocessing calculators and connects them to the text input
|
||||||
|
// stream.
|
||||||
|
auto& preprocessing =
|
||||||
|
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
|
||||||
|
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
|
||||||
|
model_resources,
|
||||||
|
preprocessing.GetOptions<
|
||||||
|
tasks::components::proto::TextPreprocessingGraphOptions>()));
|
||||||
|
text_in >> preprocessing.In(kTextTag);
|
||||||
|
|
||||||
|
// Adds both InferenceCalculator and ModelResourcesCalculator.
|
||||||
|
auto& inference = AddInference(
|
||||||
|
model_resources, task_options.base_options().acceleration(), graph);
|
||||||
|
// The metadata extractor side-output comes from the
|
||||||
|
// ModelResourcesCalculator.
|
||||||
|
inference.SideOut(kMetadataExtractorTag) >>
|
||||||
|
preprocessing.SideIn(kMetadataExtractorTag);
|
||||||
|
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
|
||||||
|
|
||||||
|
// Adds postprocessing calculators and connects them to the graph output.
|
||||||
|
auto& postprocessing = graph.AddNode(
|
||||||
|
"mediapipe.tasks.components.processors."
|
||||||
|
"ClassificationPostprocessingGraph");
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||||
|
model_resources, task_options.classifier_options(),
|
||||||
|
&postprocessing
|
||||||
|
.GetOptions<components::processors::proto::
|
||||||
|
ClassificationPostprocessingGraphOptions>()));
|
||||||
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
|
||||||
|
// Outputs the aggregated classification result as the subgraph output
|
||||||
|
// stream.
|
||||||
|
return postprocessing[Output<ClassificationResult>(
|
||||||
|
kClassificationResultTag)];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
|
::mediapipe::tasks::text::text_classifier::TextClassifierGraph);
|
||||||
|
|
||||||
|
} // namespace text_classifier
|
||||||
|
} // namespace text
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
238
mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
Normal file
238
mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
Normal file
|
@ -0,0 +1,238 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/flags/flag.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/cord.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
||||||
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace text {
|
||||||
|
namespace text_classifier {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::EqualsProto;
|
||||||
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::kMediaPipeTasksPayload;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
using ::testing::Optional;
|
||||||
|
using ::testing::proto::Approximately;
|
||||||
|
using ::testing::proto::IgnoringRepeatedFieldOrdering;
|
||||||
|
using ::testing::proto::Partially;
|
||||||
|
|
||||||
|
// Numeric parameters used by the tests.
constexpr int kMaxSeqLen = 128;
constexpr float kEpsilon = 0.001;

// Directory holding the test models; joined onto "./" by GetFullPath().
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";

// Model file names, relative to kTestDataDirectory.
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
constexpr char kTestRegexModelPath[] =
    "test_model_text_classifier_with_regex_tokenizer.tflite";
constexpr char kStringToBoolModelPath[] =
    "test_model_text_classifier_bool_output.tflite";

// A path that is expected not to exist.
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
|
||||||
|
|
||||||
|
// Returns the path of test data file `file_name`, formed by joining "./",
// the test data directory and the file name.
std::string GetFullPath(absl::string_view file_name) {
  std::string full_path = JoinPath("./", kTestDataDirectory, file_name);
  return full_path;
}
|
||||||
|
|
||||||
|
// Fixture shared by all TextClassifier tests; adds nothing to the
// tflite_shims test base class.
class TextClassifierTest : public tflite_shims::testing::Test {
};
|
||||||
|
|
||||||
|
// Creating a classifier from the BERT test model must succeed.
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kTestBertModelPath);

  MP_ASSERT_OK(TextClassifier::Create(std::move(classifier_options)));
}
|
||||||
|
|
||||||
|
// Creating a classifier from options that never had a model set must fail
// with InvalidArgument and a runner-initialization payload.
TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) {
  const StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::make_unique<TextClassifierOptions>());
  const absl::Status& creation_status = classifier.status();

  EXPECT_EQ(creation_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      creation_status.message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(creation_status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
|
||||||
|
|
||||||
|
// Creating a classifier from a model path that does not exist must fail with
// NotFound and a runner-initialization payload.
TEST_F(TextClassifierTest, CreateFailsWithMissingModel) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kInvalidModelPath);

  const StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(classifier_options));
  const absl::Status& creation_status = classifier.status();

  EXPECT_EQ(creation_status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(creation_status.message(), HasSubstr("Unable to open file at"));
  EXPECT_THAT(creation_status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
|
||||||
|
|
||||||
|
// Creating a classifier from the regex-tokenizer test model must succeed.
TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kTestRegexModelPath);

  MP_ASSERT_OK(TextClassifier::Create(std::move(classifier_options)));
}
|
||||||
|
|
||||||
|
// Classifies one negative and one positive review with the BERT test model
// and compares the category scores with the expected protos within kEpsilon.
TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TextClassifier> bert_classifier,
      TextClassifier::Create(std::move(classifier_options)));

  // Negative review.
  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult negative,
      bert_classifier->Classify("unflinchingly bleak and desperate"));
  ASSERT_THAT(negative,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.956 }
                        categories { category_name: "positive" score: 0.044 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  // Positive review.
  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult positive,
      bert_classifier->Classify("it's a charming and often affecting journey"));
  ASSERT_THAT(positive,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.0 }
                        categories { category_name: "positive" score: 1.0 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK(bert_classifier->Close());
}
|
||||||
|
|
||||||
|
// Classifies one negative and one positive review with the regex test model
// and compares the category scores with the expected protos within kEpsilon.
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TextClassifier> regex_classifier,
      TextClassifier::Create(std::move(classifier_options)));

  // Negative review.
  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult negative,
      regex_classifier->Classify("What a waste of my time."));
  ASSERT_THAT(negative,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "Negative" score: 0.813 }
                        categories { category_name: "Positive" score: 0.187 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  // Positive review.
  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult positive,
      regex_classifier->Classify(
          "This is the best movie I’ve seen in recent years. "
          "Strongly recommend it!"));
  ASSERT_THAT(positive,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "Negative" score: 0.487 }
                        categories { category_name: "Positive" score: 0.513 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK(regex_classifier->Close());
}
|
||||||
|
|
||||||
|
// Runs the string-to-bool test model, whose custom op is supplied through
// CreateCustomResolver(), on the input "hello" and checks the per-index
// scores. Category order is ignored when comparing.
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
  options->base_options.op_resolver = CreateCustomResolver();
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
                          classifier->Classify("hello"));
  ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
                classifications {
                  entries {
                    categories { index: 1 score: 1 }
                    categories { index: 0 score: 1 }
                    categories { index: 2 score: 0 }
                  }
                }
              )pb"))));
  // Close explicitly, as the other classification tests in this file do, so
  // that an error reported by Close() fails the test.
  MP_ASSERT_OK(classifier->Close());
}
|
||||||
|
|
||||||
|
// Classifies a positive review padded with kMaxSeqLen extra " long" words
// using the BERT test model and checks the scores within kEpsilon.
TEST_F(TextClassifierTest, BertLongPositive) {
  std::string positive_review =
      "it's a charming and often affecting journey and this is a long";
  for (int i = 0; i < kMaxSeqLen; ++i) {
    positive_review += " long";
  }
  positive_review += " movie review";

  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TextClassifier> bert_classifier,
      TextClassifier::Create(std::move(classifier_options)));

  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult long_review_result,
                          bert_classifier->Classify(positive_review));
  ASSERT_THAT(long_review_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.014 }
                        categories { category_name: "positive" score: 0.986 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK(bert_classifier->Close());
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace text_classifier
|
||||||
|
} // namespace text
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,131 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/portable_type_to_tflitetype.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace text {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::CreateStatusWithPayload;
|
||||||
|
using ::tflite::GetInput;
|
||||||
|
using ::tflite::GetOutput;
|
||||||
|
using ::tflite::GetString;
|
||||||
|
using ::tflite::StringRef;
|
||||||
|
|
||||||
|
// The only input string InvokeStringToBool accepts; for any other input the
// op returns kTfLiteError.
constexpr absl::string_view kInputStr = "hello";
|
||||||
|
// Values that InvokeStringToBool copies into the op's output tensor.
constexpr bool kBooleanData[] = {true, true, false};
|
||||||
|
// Number of elements in kBooleanData; PrepareStringToBool uses it as the
// single dimension of the output tensor.
constexpr size_t kBooleanDataSize = std::size(kBooleanData);
|
||||||
|
|
||||||
|
// Checks and returns type of a tensor, fails if tensor type is not T.
|
||||||
|
template <typename T>
|
||||||
|
absl::StatusOr<T*> AssertAndReturnTypedTensor(const TfLiteTensor* tensor) {
|
||||||
|
if (!tensor->data.raw) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInternal,
|
||||||
|
absl::StrFormat("Tensor (%s) has no raw data.", tensor->name));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if data type of tensor is T and returns the pointer casted to T if
|
||||||
|
// applicable, returns nullptr if tensor type is not T.
|
||||||
|
// See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType.
|
||||||
|
if (tensor->type == tflite::typeToTfLiteType<T>()) {
|
||||||
|
return reinterpret_cast<T*>(tensor->data.raw);
|
||||||
|
}
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInternal,
|
||||||
|
absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.",
|
||||||
|
tensor->name, tflite::typeToTfLiteType<T>(),
|
||||||
|
tensor->bytes));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populates tensor with array of data, fails if data type doesn't match tensor
|
||||||
|
// type or they don't have the same number of elements.
|
||||||
|
template <typename T, typename = std::enable_if_t<
|
||||||
|
std::negation_v<std::is_same<T, std::string>>>>
|
||||||
|
absl::Status PopulateTensor(const T* data, int num_elements,
|
||||||
|
TfLiteTensor* tensor) {
|
||||||
|
ASSIGN_OR_RETURN(T * v, AssertAndReturnTypedTensor<T>(tensor));
|
||||||
|
size_t bytes = num_elements * sizeof(T);
|
||||||
|
if (tensor->bytes != bytes) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInternal,
|
||||||
|
absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes,
|
||||||
|
bytes));
|
||||||
|
}
|
||||||
|
std::memcpy(v, data, bytes);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare step of the test op: resizes output 0 to a one-dimensional tensor
// with kBooleanDataSize elements. Fails if the output tensor is missing.
TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output_tensor = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output_tensor != nullptr);

  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(/*size=*/1);
  output_shape->data[0] = kBooleanDataSize;
  return context->ResizeTensor(context, output_tensor, output_shape);
}
|
||||||
|
|
||||||
|
// Invoke step of the test op: reads string 0 of input tensor 0 and, if it
// equals kInputStr, fills output tensor 0 with kBooleanData.
//
// Returns kTfLiteError when the input string differs from kInputStr. A missing
// input or output tensor, or a failure to populate the output, fails through
// TF_LITE_ENSURE.
TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input_tensor = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, input_tensor != nullptr);

  // Compare through a view so no temporary std::string is allocated.
  const StringRef input_ref = GetString(input_tensor, 0);
  if (absl::string_view(input_ref.str, input_ref.len) != kInputStr) {
    return kTfLiteError;
  }

  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);
  // Take the element count from kBooleanDataSize so it cannot drift from
  // kBooleanData or from the shape set in PrepareStringToBool.
  TF_LITE_ENSURE(context,
                 PopulateTensor(kBooleanData,
                                static_cast<int>(kBooleanDataSize), output)
                     .ok());
  return kTfLiteOk;
}
|
||||||
|
|
||||||
|
// This custom op takes a string tensor in and outputs a bool tensor with
|
||||||
|
// value{true, true, false}, it's used to mimic a real text classification model
|
||||||
|
// which classifies a string into scores of different categories.
|
||||||
|
TfLiteRegistration* RegisterStringToBool() {
|
||||||
|
// Dummy implementation of custom OP
|
||||||
|
// This op takes string as input and outputs bool[]
|
||||||
|
static TfLiteRegistration r = {/* init= */ nullptr, /* free= */ nullptr,
|
||||||
|
/* prepare= */ PrepareStringToBool,
|
||||||
|
/* invoke= */ InvokeStringToBool};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver() {
|
||||||
|
tflite::MutableOpResolver resolver;
|
||||||
|
resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool());
|
||||||
|
return std::make_unique<tflite::MutableOpResolver>(resolver);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace text
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace text {
|
||||||
|
|
||||||
|
// Create a custom MutableOpResolver to provide custom OP implementations to
|
||||||
|
// mimic classification behavior.
|
||||||
|
// The returned resolver is owned by the caller. The definition in
// text_classifier_test_utils.cc registers a single custom op under the name
// "CUSTOM_OP_STRING_TO_BOOLS".
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver();
|
||||||
|
|
||||||
|
} // namespace text
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
|
Loading…
Reference in New Issue
Block a user