diff --git a/WORKSPACE b/WORKSPACE index 146916c5c..5a47cf6b7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -172,6 +172,10 @@ http_archive( urls = [ "https://github.com/google/sentencepiece/archive/1.0.0.zip", ], + patches = [ + "//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff", + ], + patch_args = ["-p1"], repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"}, ) diff --git a/docs/BUILD b/docs/BUILD new file mode 100644 index 000000000..cb8794dab --- /dev/null +++ b/docs/BUILD @@ -0,0 +1,14 @@ +# Placeholder for internal Python strict binary compatibility macro. + +py_binary( + name = "build_py_api_docs", + srcs = ["build_py_api_docs.py"], + deps = [ + "//mediapipe", + "//third_party/py/absl:app", + "//third_party/py/absl/flags", + "//third_party/py/tensorflow_docs", + "//third_party/py/tensorflow_docs/api_generator:generate_lib", + "//third_party/py/tensorflow_docs/api_generator:public_api", + ], +) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py new file mode 100644 index 000000000..9911d0736 --- /dev/null +++ b/docs/build_py_api_docs.py @@ -0,0 +1,85 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""MediaPipe reference docs generation script. + +This script generates API reference docs for the `mediapipe` PIP package. + +$> pip install -U git+https://github.com/tensorflow/docs mediapipe +$> python build_py_api_docs.py +""" + +import os + +from absl import app +from absl import flags + +from tensorflow_docs.api_generator import generate_lib +from tensorflow_docs.api_generator import public_api + +try: + # mediapipe has not been set up to work with bazel yet, so catch & report. + import mediapipe # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe`.') from e + + +PROJECT_SHORT_NAME = 'mp' +PROJECT_FULL_NAME = 'MediaPipe' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe)], + base_dir=os.path.dirname(mediapipe.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + # This callback ensures that docs are only generated for objects that + # are explicitly imported in your __init__.py files. There are other + # options but this is a good starting point. + callbacks=[public_api.explicit_package_contents_filter], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 82905d2f5..7dce211c8 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -289,8 +289,15 @@ class NodeBase { template T& GetOptions() { + return GetOptions(T::ext); + } + + // Use this API when the proto extension does not follow the "ext" naming + // convention. + template + auto& GetOptions(const E& extension) { options_used_ = true; - return *options_.MutableExtension(T::ext); + return *options_.MutableExtension(extension); } protected: @@ -386,8 +393,15 @@ class PacketGenerator { template T& GetOptions() { + return GetOptions(T::ext); + } + + // Use this API when the proto extension does not follow the "ext" naming + // convention. + template + auto& GetOptions(const E& extension) { options_used_ = true; - return *options_.MutableExtension(T::ext); + return *options_.MutableExtension(extension); } template diff --git a/mediapipe/modules/face_geometry/libs/effect_renderer.cc b/mediapipe/modules/face_geometry/libs/effect_renderer.cc index 27a54e011..73f473084 100644 --- a/mediapipe/modules/face_geometry/libs/effect_renderer.cc +++ b/mediapipe/modules/face_geometry/libs/effect_renderer.cc @@ -161,7 +161,7 @@ class Texture { ~Texture() { if (is_owned_) { - glDeleteProgram(handle_); + glDeleteTextures(1, &handle_); } } diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index f563fbf64..344fafb4e 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -49,6 +49,7 @@ cc_library( "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/image_preprocessing.cc index f3f3b6863..7940080e1 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/image_preprocessing.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -129,7 +130,7 @@ absl::Status ConfigureImageToTensorCalculator( options->mutable_output_tensor_float_range()->set_max((255.0f - mean) / std); } - // TODO: need to.support different GPU origin on differnt + // TODO: need to support different GPU origin on differnt // platforms or applications. options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT); return absl::OkStatus(); @@ -137,7 +138,13 @@ absl::Status ConfigureImageToTensorCalculator( } // namespace +bool DetermineImagePreprocessingGpuBackend( + const core::proto::Acceleration& acceleration) { + return acceleration.has_gpu(); +} + absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, + bool use_gpu, ImagePreprocessingOptions* options) { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); @@ -145,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, image_tensor_specs, options->mutable_image_to_tensor_options())); // The GPU backend isn't able to process int data. If the input tensor is // quantized, forces the image preprocessing graph to use CPU backend. - if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) { + if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { + options->set_backend(ImagePreprocessingOptions::GPU_BACKEND); + } else { options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); } return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/image_preprocessing.h index a5b767f3a..6963b6556 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/image_preprocessing.h @@ -19,20 +19,26 @@ limitations under the License. #include "absl/status/status.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" namespace mediapipe { namespace tasks { namespace components { -// Configures an ImagePreprocessing subgraph using the provided model resources. +// Configures an ImagePreprocessing subgraph using the provided model resources +// When use_gpu is true, use GPU as backend to convert image to tensor. // - Accepts CPU input images and outputs CPU tensors. // // Example usage: // // auto& preprocessing = // graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); +// core::proto::Acceleration acceleration; +// acceleration.mutable_xnnpack(); +// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); // MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( // model_resources, +// use_gpu, // &preprocessing.GetOptions())); // // The resulting ImagePreprocessing subgraph has the following I/O: @@ -56,9 +62,14 @@ namespace components { // The image that has the pixel data stored on the target storage (CPU vs // GPU). absl::Status ConfigureImagePreprocessing( - const core::ModelResources& model_resources, + const core::ModelResources& model_resources, bool use_gpu, ImagePreprocessingOptions* options); +// Determine if the image preprocessing subgraph should use GPU as the backend +// according to the given acceleration setting. +bool DetermineImagePreprocessingGpuBackend( + const core::proto::Acceleration& acceleration); + } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index 08f7f45d0..8c2c2e593 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -93,3 +93,46 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +mediapipe_proto_library( + name = "combined_prediction_calculator_proto", + srcs = ["combined_prediction_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "combined_prediction_calculator", + srcs = ["combined_prediction_calculator.cc"], + deps = [ + ":combined_prediction_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) + +cc_test( + name = "combined_prediction_calculator_test", + srcs = ["combined_prediction_calculator_test.cc"], + deps = [ + ":combined_prediction_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc new file mode 100644 index 000000000..c7147ea6e --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc @@ -0,0 +1,187 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +constexpr char kPredictionTag[] = "PREDICTION"; + +Classification GetMaxScoringClassification( + const ClassificationList& classifications) { + Classification max_classification; + max_classification.set_score(0); + for (const auto& input : classifications.classification()) { + if (max_classification.score() < input.score()) { + max_classification = input; + } + } + return max_classification; +} + +float GetScoreThreshold( + const std::string& input_label, + const absl::btree_map& classwise_thresholds, + const std::string& background_label, const float default_threshold) { + float threshold = default_threshold; + auto it = classwise_thresholds.find(input_label); + if (it != classwise_thresholds.end()) { + threshold = it->second; + } + return threshold; +} + +std::unique_ptr GetWinningPrediction( + const ClassificationList& classification_list, + const absl::btree_map& classwise_thresholds, + const std::string& background_label, const float default_threshold) { + auto prediction_list = std::make_unique(); + if (classification_list.classification().empty()) { + return prediction_list; + } + Classification& prediction = *prediction_list->add_classification(); + auto argmax_prediction = GetMaxScoringClassification(classification_list); + float argmax_prediction_thresh = + GetScoreThreshold(argmax_prediction.label(), classwise_thresholds, + background_label, default_threshold); + if (argmax_prediction.score() >= argmax_prediction_thresh) { + prediction.set_label(argmax_prediction.label()); + prediction.set_score(argmax_prediction.score()); + } else { + for (const auto& input : classification_list.classification()) { + if (input.label() == background_label) { + prediction.set_label(input.label()); + prediction.set_score(input.score()); + break; + } + } + } + return prediction_list; +} + +} // namespace + +// This calculator accepts multiple ClassificationList input streams. Each +// ClassificationList should contain classifications with labels and +// corresponding softmax scores. The calculator computes the best prediction for +// each ClassificationList input stream via argmax and thresholding. Thresholds +// for all classes can be specified in the +// `CombinedPredictionCalculatorOptions`, along with a default global +// threshold. +// Please note that for this calculator to work as designed, the class names +// other than the background class in the ClassificationList objects must be +// different, but the background class name has to be the same. This background +// label name can be set via `background_label` in +// `CombinedPredictionCalculatorOptions`. +// The ClassificationList in the PREDICTION output stream contains the label of +// the winning class and corresponding softmax score. If none of the +// ClassificationList objects has a non-background winning class, the output +// contains the background class and score of the background class in the first +// ClassificationList. If multiple ClassificationList objects have a +// non-background winning class, the output contains the winning prediction from +// the ClassificationList with the highest priority. Priority is in decreasing +// order of input streams to the graph node using this calculator. +// Input: +// At least one stream with ClassificationList. +// Output: +// PREDICTION - A ClassificationList with the winning label as the only item. +// +// Usage example: +// node { +// calculator: "CombinedPredictionCalculator" +// input_stream: "classification_list_0" +// input_stream: "classification_list_1" +// output_stream: "PREDICTION:prediction" +// options { +// [mediapipe.CombinedPredictionCalculatorOptions.ext] { +// class { +// label: "A" +// score_threshold: 0.7 +// } +// default_global_threshold: 0.1 +// background_label: "B" +// } +// } +// } + +class CombinedPredictionCalculator : public Node { + public: + static constexpr Input::Multiple kClassificationListIn{ + ""}; + static constexpr Output kPredictionOut{"PREDICTION"}; + MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut); + + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + for (const auto& input : options_.class_()) { + classwise_thresholds_[input.label()] = input.score_threshold(); + } + classwise_thresholds_[options_.background_label()] = 0; + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + // After loop, if have winning prediction return. Otherwise empty packet. + std::unique_ptr first_winning_prediction = nullptr; + auto collection = kClassificationListIn(cc); + for (int idx = 0; idx < collection.Count(); ++idx) { + const auto& packet = collection[idx]; + if (packet.IsEmpty()) { + continue; + } + auto prediction = GetWinningPrediction( + packet.Get(), classwise_thresholds_, options_.background_label(), + options_.default_global_threshold()); + if (prediction->classification(0).label() != + options_.background_label()) { + kPredictionOut(cc).Send(std::move(prediction)); + return absl::OkStatus(); + } + if (first_winning_prediction == nullptr) { + first_winning_prediction = std::move(prediction); + } + } + if (first_winning_prediction != nullptr) { + kPredictionOut(cc).Send(std::move(first_winning_prediction)); + } + return absl::OkStatus(); + } + + private: + CombinedPredictionCalculatorOptions options_; + absl::btree_map classwise_thresholds_; +}; + +MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto new file mode 100644 index 000000000..730e7dd78 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto @@ -0,0 +1,41 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message CombinedPredictionCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional CombinedPredictionCalculatorOptions ext = 483738635; + } + + message Class { + optional string label = 1; + optional float score_threshold = 2; + } + + // List of classes with score thresholds. + repeated Class class = 1; + + // Default score threshold applied to a label. + optional float default_global_threshold = 2 [default = 0]; + + // Name of the background class whose input scores will be ignored while + // thresholding. + optional string background_label = 3; +} diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc new file mode 100644 index 000000000..ecf49795b --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc @@ -0,0 +1,315 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +constexpr char kPredictionTag[] = "PREDICTION"; + +std::unique_ptr BuildNodeRunnerWithOptions( + float drama_thresh, float llama_thresh, float bazinga_thresh, + float joy_thresh, float peace_thresh) { + constexpr absl::string_view kCalculatorProto = R"pb( + calculator: "CombinedPredictionCalculator" + input_stream: "custom_softmax_scores" + input_stream: "canned_softmax_scores" + output_stream: "PREDICTION:prediction" + options { + [mediapipe.CombinedPredictionCalculatorOptions.ext] { + class { label: "CustomDrama" score_threshold: $0 } + class { label: "CustomLlama" score_threshold: $1 } + class { label: "CannedBazinga" score_threshold: $2 } + class { label: "CannedJoy" score_threshold: $3 } + class { label: "CannedPeace" score_threshold: $4 } + background_label: "Negative" + } + } + )pb"; + auto runner = std::make_unique( + absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh, + bazinga_thresh, joy_thresh, peace_thresh)); + return runner; +} + +std::unique_ptr BuildCustomScoreInput( + const float negative_score, const float drama_score, + const float llama_score) { + auto custom_scores = std::make_unique(); + auto custom_negative = custom_scores->add_classification(); + custom_negative->set_label("Negative"); + custom_negative->set_score(negative_score); + auto drama = custom_scores->add_classification(); + drama->set_label("CustomDrama"); + drama->set_score(drama_score); + auto llama = custom_scores->add_classification(); + llama->set_label("CustomLlama"); + llama->set_score(llama_score); + return custom_scores; +} + +std::unique_ptr BuildCannedScoreInput( + const float negative_score, const float bazinga_score, + const float joy_score, const float peace_score) { + auto canned_scores = std::make_unique(); + auto canned_negative = canned_scores->add_classification(); + canned_negative->set_label("Negative"); + canned_negative->set_score(negative_score); + auto bazinga = canned_scores->add_classification(); + bazinga->set_label("CannedBazinga"); + bazinga->set_score(bazinga_score); + auto joy = canned_scores->add_classification(); + joy->set_label("CannedJoy"); + joy->set_score(joy_score); + auto peace = canned_scores->add_classification(); + peace->set_label("CannedPeace"); + peace->set_score(peace_score); + return canned_scores; +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomEmpty_CannedEmpty_ResultIsEmpty) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0, + /*joy_thresh=*/0.0, /*peace_thresh=*/0.0); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty()); +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomEmpty_CannedNotEmpty_ResultIsCanned) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9, + /*joy_thresh=*/0.5, /*peace_thresh=*/0.8); + auto canned_scores = BuildCannedScoreInput( + /*negative_score=*/0.1, + /*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2); + runner->MutableInputs()->Index(1).packets.push_back( + Adopt(canned_scores.release()).At(Timestamp(1))); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), "CannedJoy"); + EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4); +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomNotEmpty_CannedEmpty_ResultIsCustom) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0, + /*joy_thresh=*/0.0, /*peace_thresh=*/0.0); + auto custom_scores = + BuildCustomScoreInput(/*negative_score=*/0.1, + /*drama_score=*/0.2, /*llama_score=*/0.7); + runner->MutableInputs()->Index(0).packets.push_back( + Adopt(custom_scores.release()).At(Timestamp(1))); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), "CustomLlama"); + EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4); +} + +struct CombinedPredictionCalculatorTestCase { + std::string test_name; + float custom_negative_score; + float drama_score; + float llama_score; + float drama_thresh; + float llama_thresh; + float canned_negative_score; + float bazinga_score; + float joy_score; + float peace_score; + float bazinga_thresh; + float joy_thresh; + float peace_thresh; + std::string max_scoring_label; + float max_score; +}; + +using CombinedPredictionCalculatorTest = + testing::TestWithParam; + +TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) { + const CombinedPredictionCalculatorTestCase& test_case = GetParam(); + + auto runner = BuildNodeRunnerWithOptions( + test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh, + test_case.joy_thresh, test_case.peace_thresh); + + auto custom_scores = + BuildCustomScoreInput(test_case.custom_negative_score, + test_case.drama_score, test_case.llama_score); + + runner->MutableInputs()->Index(0).packets.push_back( + Adopt(custom_scores.release()).At(Timestamp(1))); + + auto canned_scores = BuildCannedScoreInput( + test_case.canned_negative_score, test_case.bazinga_score, + test_case.joy_score, test_case.peace_score); + runner->MutableInputs()->Index(1).packets.push_back( + Adopt(canned_scores.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label); + EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4); +} + +INSTANTIATE_TEST_CASE_P( + CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest, + testing::ValuesIn({ + { + .test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh", + .custom_negative_score = 0.1, + .drama_score = 0.5, + .llama_score = 0.3, + .drama_thresh = 0.25, + .llama_thresh = 0.7, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "CustomDrama", + .max_score = 0.5, + }, + { + .test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.1, + .bazinga_score = 0.4, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.0, + .joy_thresh = 0.0, + .peace_thresh = 0.0, + .max_scoring_label = "CannedBazinga", + .max_score = 0.4, + }, + { + .test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh", + .custom_negative_score = 0.5, + .drama_score = 0.1, + .llama_score = 0.4, + .drama_thresh = 0.1, + .llama_thresh = 0.05, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.5, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh", + .custom_negative_score = 0.8, + .drama_score = 0.1, + .llama_score = 0.1, + .drama_thresh = 0.25, + .llama_thresh = 0.7, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.8, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2", + .custom_negative_score = 0.1, + .drama_score = 0.2, + .llama_score = 0.7, + .drama_thresh = 1.1, + .llama_thresh = 1.1, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.3, + .bazinga_score = 0.2, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.5, + .joy_thresh = 0.5, + .peace_thresh = 0.5, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + }), + [](const testing::TestParamInfo< + CombinedPredictionCalculatorTest::ParamType>& info) { + return info.param.test_name; + }); + +} // namespace + +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index e876d7d09..06bb2e549 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph { image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 23521790d..1f127deb8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 9a0078c5c..8a1b17ce9 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index fff0f4366..f0f440986 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 629b940aa..d3e522d92 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -243,8 +243,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index 07e912cfc..b149cea0f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index e0b9c79ed..0260e3fab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -40,6 +40,10 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", ] +_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", +] + def mediapipe_tasks_core_aar(name, srcs, manifest): """Builds medaipipe tasks core AAR. @@ -60,6 +64,11 @@ def mediapipe_tasks_core_aar(name, srcs, manifest): _mediapipe_tasks_java_proto_src_extractor(target = target), ) + for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS: + mediapipe_tasks_java_proto_srcs.append( + _mediapipe_tasks_java_proto_src_extractor(target = target), + ) + mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor( target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java", @@ -81,32 +90,35 @@ def mediapipe_tasks_core_aar(name, srcs, manifest): ], manifest = manifest, deps = [ - "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", - "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", - "//mediapipe/framework:calculator_java_proto_lite", - "//mediapipe/framework:calculator_profile_java_proto_lite", - "//mediapipe/framework:calculator_options_java_proto_lite", - "//mediapipe/framework:mediapipe_options_java_proto_lite", - "//mediapipe/framework:packet_factory_java_proto_lite", - "//mediapipe/framework:packet_generator_java_proto_lite", - "//mediapipe/framework:status_handler_java_proto_lite", - "//mediapipe/framework:stream_handler_java_proto_lite", - "//mediapipe/framework/formats:classification_java_proto_lite", - "//mediapipe/framework/formats:detection_java_proto_lite", - "//mediapipe/framework/formats:landmark_java_proto_lite", - "//mediapipe/framework/formats:location_data_java_proto_lite", - "//mediapipe/framework/formats:rect_java_proto_lite", - "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", - "//third_party:androidx_annotation", - "//third_party:autovalue", - "@com_google_protobuf//:protobuf_javalite", - "@maven//:com_google_guava_guava", - "@maven//:com_google_flogger_flogger", - "@maven//:com_google_flogger_flogger_system_backend", - "@maven//:com_google_code_findbugs_jsr305", - ] + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS, + "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", + "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework:calculator_profile_java_proto_lite", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework:mediapipe_options_java_proto_lite", + "//mediapipe/framework:packet_factory_java_proto_lite", + "//mediapipe/framework:packet_generator_java_proto_lite", + "//mediapipe/framework:status_handler_java_proto_lite", + "//mediapipe/framework:stream_handler_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/framework/formats:rect_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + "//third_party:androidx_annotation", + "//third_party:autovalue", + "@com_google_protobuf//:protobuf_javalite", + "@maven//:com_google_guava_guava", + "@maven//:com_google_flogger_flogger", + "@maven//:com_google_flogger_flogger_system_backend", + "@maven//:com_google_code_findbugs_jsr305", + ] + + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS + + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS, ) def mediapipe_tasks_vision_aar(name, srcs, native_library): @@ -142,6 +154,39 @@ EOF native_library = native_library, ) +def mediapipe_tasks_text_aar(name, srcs, native_library): + """Builds medaipipe tasks text AAR. + + Args: + name: The bazel target name. + srcs: MediaPipe Text Tasks' source files. + native_library: The native library that contains text tasks' graph and calculators. + """ + + native.genrule( + name = name + "tasks_manifest_generator", + outs = ["AndroidManifest.xml"], + cmd = """ +cat > $(OUTS) < + + + +EOF +""", + ) + + _mediapipe_tasks_aar( + name = name, + srcs = srcs, + manifest = "AndroidManifest.xml", + java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS, + native_library = native_library, + ) + def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library): """Builds medaipipe tasks AAR.""" diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 1719707d8..fa2a547c2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -61,3 +61,11 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar") + +mediapipe_tasks_text_aar( + name = "tasks_text", + srcs = glob(["**/*.java"]), + native_library = ":libmediapipe_tasks_text_jni_lib", +) diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 9d37b60a0..e9b8bfa03 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -84,33 +84,28 @@ cc_library( "//conditions:default": ["tflite_gpu_runner.h"], }), deps = select({ - "//mediapipe:ios": [], - "//mediapipe:macos": [], - "//conditions:default": [ - "@com_google_absl//absl/strings", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:statusor", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/delegates/gpu:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", - ], - "//mediapipe:android": [ - "@com_google_absl//absl/strings", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:statusor", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/delegates/gpu:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", - ], - }) + [ + "//mediapipe:ios": [], + "//mediapipe:macos": [], + "//conditions:default": [ + "@com_google_absl//absl/strings", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/delegates/gpu:api", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", + ], + }) + + select({ + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api", + ], + "//conditions:default": [], + }) + [ "@com_google_absl//absl/status", + "//mediapipe/framework:port", "@org_tensorflow//tensorflow/lite/core/api", ], ) diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc index 4c422835a..4e40975cb 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.cc +++ b/mediapipe/util/tflite/tflite_gpu_runner.cc @@ -34,7 +34,7 @@ // This code should be enabled as soon as TensorFlow version, which mediapipe // uses, will include this module. -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) #include "tensorflow/lite/delegates/gpu/cl/api.h" #endif @@ -82,7 +82,7 @@ ObjectDef GetSSBOObjectDef(int channels) { return gpu_object_def; } -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) { cl::InferenceOptions result{}; @@ -106,7 +106,7 @@ absl::Status VerifyShapes(const std::vector& actual, return absl::OkStatus(); } -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) } // namespace @@ -225,7 +225,7 @@ absl::Status TFLiteGPURunner::InitializeOpenGL( absl::Status TFLiteGPURunner::InitializeOpenCL( std::unique_ptr* builder) { -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) cl::InferenceEnvironmentOptions env_options; if (!serialized_binary_cache_.empty()) { env_options.serialized_binary_cache = serialized_binary_cache_; @@ -254,11 +254,12 @@ absl::Status TFLiteGPURunner::InitializeOpenCL( return absl::OkStatus(); #else - return mediapipe::UnimplementedError("Currently only Android is supported"); -#endif // __ANDROID__ + return mediapipe::UnimplementedError( + "Currently only Android & ChromeOS are supported"); +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) } -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel( std::unique_ptr* builder) { @@ -283,7 +284,7 @@ absl::StatusOr> TFLiteGPURunner::GetSerializedModel() { return serialized_model; } -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) } // namespace gpu } // namespace tflite diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h index 88d3914f7..dfbc8d659 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.h +++ b/mediapipe/util/tflite/tflite_gpu_runner.h @@ -20,6 +20,7 @@ #include #include "absl/status/status.h" +#include "mediapipe/framework/port.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -28,9 +29,9 @@ #include "tensorflow/lite/delegates/gpu/gl/api2.h" #include "tensorflow/lite/model.h" -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) #include "tensorflow/lite/delegates/gpu/cl/api.h" -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) namespace tflite { namespace gpu { @@ -83,7 +84,7 @@ class TFLiteGPURunner { return output_shape_from_model_; } -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) void SetSerializedBinaryCache(std::vector&& cache) { serialized_binary_cache_ = std::move(cache); } @@ -98,26 +99,26 @@ class TFLiteGPURunner { } absl::StatusOr> GetSerializedModel(); -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) private: absl::Status InitializeOpenGL(std::unique_ptr* builder); absl::Status InitializeOpenCL(std::unique_ptr* builder); -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) absl::Status InitializeOpenCLFromSerializedModel( std::unique_ptr* builder); -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) InferenceOptions options_; std::unique_ptr gl_environment_; -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) std::unique_ptr cl_environment_; std::vector serialized_binary_cache_; std::vector serialized_model_; bool serialized_model_used_ = false; -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) // graph_gl_ is maintained temporarily and becomes invalid after runner_ is // ready diff --git a/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff b/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff new file mode 100644 index 000000000..a084d9262 --- /dev/null +++ b/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff @@ -0,0 +1,34 @@ +diff --git a/src/BUILD b/src/BUILD +index b4298d2..f3877a3 100644 +--- a/src/BUILD ++++ b/src/BUILD +@@ -71,9 +71,7 @@ cc_library( + ":common", + ":sentencepiece_cc_proto", + ":sentencepiece_model_cc_proto", +- "@com_github_gflags_gflags//:gflags", + "@com_google_glog//:glog", +- "@com_google_googletest//:gtest", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:flat_hash_map", +diff --git a/src/normalizer.h b/src/normalizer.h +index c16ac16..2af58be 100644 +--- a/src/normalizer.h ++++ b/src/normalizer.h +@@ -21,7 +21,6 @@ + #include + #include + +-#include "gtest/gtest_prod.h" + #include "absl/strings/string_view.h" + #include "third_party/darts_clone/include/darts.h" + #include "src/common.h" +@@ -97,7 +96,6 @@ class Normalizer { + friend class Builder; + + private: +- FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest); + + void Init(); +