Merge branch 'google:master' into gesture-recognizer-python

This commit is contained in:
Kinar R 2022-10-28 14:08:52 +05:30 committed by GitHub
commit 0f7c5d5e90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 901 additions and 81 deletions

View File

@ -172,6 +172,10 @@ http_archive(
urls = [ urls = [
"https://github.com/google/sentencepiece/archive/1.0.0.zip", "https://github.com/google/sentencepiece/archive/1.0.0.zip",
], ],
patches = [
"//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
],
patch_args = ["-p1"],
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"}, repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
) )

14
docs/BUILD Normal file
View File

@ -0,0 +1,14 @@
# Placeholder for internal Python strict binary compatibility macro.
py_binary(
name = "build_py_api_docs",
srcs = ["build_py_api_docs.py"],
deps = [
"//mediapipe",
"//third_party/py/absl:app",
"//third_party/py/absl/flags",
"//third_party/py/tensorflow_docs",
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
"//third_party/py/tensorflow_docs/api_generator:public_api",
],
)

85
docs/build_py_api_docs.py Normal file
View File

@ -0,0 +1,85 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""MediaPipe reference docs generation script.
This script generates API reference docs for the `mediapipe` PIP package.
$> pip install -U git+https://github.com/tensorflow/docs mediapipe
$> python build_py_api_docs.py
"""
import os
from absl import app
from absl import flags
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
try:
# mediapipe has not been set up to work with bazel yet, so catch & report.
import mediapipe # pytype: disable=import-error
except ImportError as e:
raise ImportError('Please `pip install mediapipe`.') from e
PROJECT_SHORT_NAME = 'mp'
PROJECT_FULL_NAME = 'MediaPipe'
_OUTPUT_DIR = flags.DEFINE_string(
'output_dir',
default='/tmp/generated_docs',
help='Where to write the resulting docs.')
_URL_PREFIX = flags.DEFINE_string(
'code_url_prefix',
'https://github.com/google/mediapipe/tree/master/mediapipe',
'The url prefix for links to code.')
_SEARCH_HINTS = flags.DEFINE_bool(
'search_hints', True,
'Include metadata search hints in the generated files')
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
'Path prefix in the _toc.yaml')
def gen_api_docs():
"""Generates API docs for the mediapipe package."""
doc_generator = generate_lib.DocGenerator(
root_title=PROJECT_FULL_NAME,
py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
base_dir=os.path.dirname(mediapipe.__file__),
code_url_prefix=_URL_PREFIX.value,
search_hints=_SEARCH_HINTS.value,
site_path=_SITE_PATH.value,
# This callback ensures that docs are only generated for objects that
# are explicitly imported in your __init__.py files. There are other
# options but this is a good starting point.
callbacks=[public_api.explicit_package_contents_filter],
)
doc_generator.build(_OUTPUT_DIR.value)
print('Docs output to:', _OUTPUT_DIR.value)
def main(_):
gen_api_docs()
if __name__ == '__main__':
app.run(main)

View File

@ -289,8 +289,15 @@ class NodeBase {
template <typename T> template <typename T>
T& GetOptions() { T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true; options_used_ = true;
return *options_.MutableExtension(T::ext); return *options_.MutableExtension(extension);
} }
protected: protected:
@ -386,8 +393,15 @@ class PacketGenerator {
template <typename T> template <typename T>
T& GetOptions() { T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true; options_used_ = true;
return *options_.MutableExtension(T::ext); return *options_.MutableExtension(extension);
} }
template <typename B, typename T, bool kIsOptional, bool kIsMultiple> template <typename B, typename T, bool kIsOptional, bool kIsMultiple>

View File

@ -161,7 +161,7 @@ class Texture {
~Texture() { ~Texture() {
if (is_owned_) { if (is_owned_) {
glDeleteProgram(handle_); glDeleteTextures(1, &handle_);
} }
} }

View File

@ -49,6 +49,7 @@ cc_library(
"//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
@ -129,7 +130,7 @@ absl::Status ConfigureImageToTensorCalculator(
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) / options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
std); std);
} }
// TODO: need to.support different GPU origin on differnt // TODO: need to support different GPU origin on differnt
// platforms or applications. // platforms or applications.
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT); options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
return absl::OkStatus(); return absl::OkStatus();
@ -137,7 +138,13 @@ absl::Status ConfigureImageToTensorCalculator(
} // namespace } // namespace
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration) {
return acceleration.has_gpu();
}
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
bool use_gpu,
ImagePreprocessingOptions* options) { ImagePreprocessingOptions* options) {
ASSIGN_OR_RETURN(auto image_tensor_specs, ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildImageTensorSpecs(model_resources)); BuildImageTensorSpecs(model_resources));
@ -145,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
image_tensor_specs, options->mutable_image_to_tensor_options())); image_tensor_specs, options->mutable_image_to_tensor_options()));
// The GPU backend isn't able to process int data. If the input tensor is // The GPU backend isn't able to process int data. If the input tensor is
// quantized, forces the image preprocessing graph to use CPU backend. // quantized, forces the image preprocessing graph to use CPU backend.
if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) { if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
} else {
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -19,20 +19,26 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
// Configures an ImagePreprocessing subgraph using the provided model resources. // Configures an ImagePreprocessing subgraph using the provided model resources
// When use_gpu is true, use GPU as backend to convert image to tensor.
// - Accepts CPU input images and outputs CPU tensors. // - Accepts CPU input images and outputs CPU tensors.
// //
// Example usage: // Example usage:
// //
// auto& preprocessing = // auto& preprocessing =
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); // graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
// core::proto::Acceleration acceleration;
// acceleration.mutable_xnnpack();
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( // MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
// model_resources, // model_resources,
// use_gpu,
// &preprocessing.GetOptions<ImagePreprocessingOptions>())); // &preprocessing.GetOptions<ImagePreprocessingOptions>()));
// //
// The resulting ImagePreprocessing subgraph has the following I/O: // The resulting ImagePreprocessing subgraph has the following I/O:
@ -56,9 +62,14 @@ namespace components {
// The image that has the pixel data stored on the target storage (CPU vs // The image that has the pixel data stored on the target storage (CPU vs
// GPU). // GPU).
absl::Status ConfigureImagePreprocessing( absl::Status ConfigureImagePreprocessing(
const core::ModelResources& model_resources, const core::ModelResources& model_resources, bool use_gpu,
ImagePreprocessingOptions* options); ImagePreprocessingOptions* options);
// Determine if the image preprocessing subgraph should use GPU as the backend
// according to the given acceleration setting.
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration);
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -93,3 +93,46 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
mediapipe_proto_library(
name = "combined_prediction_calculator_proto",
srcs = ["combined_prediction_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "combined_prediction_calculator",
srcs = ["combined_prediction_calculator.cc"],
deps = [
":combined_prediction_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
alwayslink = 1,
)
cc_test(
name = "combined_prediction_calculator_test",
srcs = ["combined_prediction_calculator_test.cc"],
deps = [
":combined_prediction_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",
],
)

View File

@ -0,0 +1,187 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/btree_map.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
constexpr char kPredictionTag[] = "PREDICTION";
Classification GetMaxScoringClassification(
const ClassificationList& classifications) {
Classification max_classification;
max_classification.set_score(0);
for (const auto& input : classifications.classification()) {
if (max_classification.score() < input.score()) {
max_classification = input;
}
}
return max_classification;
}
float GetScoreThreshold(
const std::string& input_label,
const absl::btree_map<std::string, float>& classwise_thresholds,
const std::string& background_label, const float default_threshold) {
float threshold = default_threshold;
auto it = classwise_thresholds.find(input_label);
if (it != classwise_thresholds.end()) {
threshold = it->second;
}
return threshold;
}
std::unique_ptr<ClassificationList> GetWinningPrediction(
const ClassificationList& classification_list,
const absl::btree_map<std::string, float>& classwise_thresholds,
const std::string& background_label, const float default_threshold) {
auto prediction_list = std::make_unique<ClassificationList>();
if (classification_list.classification().empty()) {
return prediction_list;
}
Classification& prediction = *prediction_list->add_classification();
auto argmax_prediction = GetMaxScoringClassification(classification_list);
float argmax_prediction_thresh =
GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
background_label, default_threshold);
if (argmax_prediction.score() >= argmax_prediction_thresh) {
prediction.set_label(argmax_prediction.label());
prediction.set_score(argmax_prediction.score());
} else {
for (const auto& input : classification_list.classification()) {
if (input.label() == background_label) {
prediction.set_label(input.label());
prediction.set_score(input.score());
break;
}
}
}
return prediction_list;
}
} // namespace
// This calculator accepts multiple ClassificationList input streams. Each
// ClassificationList should contain classifications with labels and
// corresponding softmax scores. The calculator computes the best prediction for
// each ClassificationList input stream via argmax and thresholding. Thresholds
// for all classes can be specified in the
// `CombinedPredictionCalculatorOptions`, along with a default global
// threshold.
// Please note that for this calculator to work as designed, the class names
// other than the background class in the ClassificationList objects must be
// different, but the background class name has to be the same. This background
// label name can be set via `background_label` in
// `CombinedPredictionCalculatorOptions`.
// The ClassificationList in the PREDICTION output stream contains the label of
// the winning class and corresponding softmax score. If none of the
// ClassificationList objects has a non-background winning class, the output
// contains the background class and score of the background class in the first
// ClassificationList. If multiple ClassificationList objects have a
// non-background winning class, the output contains the winning prediction from
// the ClassificationList with the highest priority. Priority is in decreasing
// order of input streams to the graph node using this calculator.
// Input:
// At least one stream with ClassificationList.
// Output:
// PREDICTION - A ClassificationList with the winning label as the only item.
//
// Usage example:
// node {
// calculator: "CombinedPredictionCalculator"
// input_stream: "classification_list_0"
// input_stream: "classification_list_1"
// output_stream: "PREDICTION:prediction"
// options {
// [mediapipe.CombinedPredictionCalculatorOptions.ext] {
// class {
// label: "A"
// score_threshold: 0.7
// }
// default_global_threshold: 0.1
// background_label: "B"
// }
// }
// }
class CombinedPredictionCalculator : public Node {
public:
static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
""};
static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<CombinedPredictionCalculatorOptions>();
for (const auto& input : options_.class_()) {
classwise_thresholds_[input.label()] = input.score_threshold();
}
classwise_thresholds_[options_.background_label()] = 0;
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
// After loop, if have winning prediction return. Otherwise empty packet.
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
auto collection = kClassificationListIn(cc);
for (int idx = 0; idx < collection.Count(); ++idx) {
const auto& packet = collection[idx];
if (packet.IsEmpty()) {
continue;
}
auto prediction = GetWinningPrediction(
packet.Get(), classwise_thresholds_, options_.background_label(),
options_.default_global_threshold());
if (prediction->classification(0).label() !=
options_.background_label()) {
kPredictionOut(cc).Send(std::move(prediction));
return absl::OkStatus();
}
if (first_winning_prediction == nullptr) {
first_winning_prediction = std::move(prediction);
}
}
if (first_winning_prediction != nullptr) {
kPredictionOut(cc).Send(std::move(first_winning_prediction));
}
return absl::OkStatus();
}
private:
CombinedPredictionCalculatorOptions options_;
absl::btree_map<std::string, float> classwise_thresholds_;
};
MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,41 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message CombinedPredictionCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional CombinedPredictionCalculatorOptions ext = 483738635;
}
message Class {
optional string label = 1;
optional float score_threshold = 2;
}
// List of classes with score thresholds.
repeated Class class = 1;
// Default score threshold applied to a label.
optional float default_global_threshold = 2 [default = 0];
// Name of the background class whose input scores will be ignored while
// thresholding.
optional string background_label = 3;
}

View File

@ -0,0 +1,315 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
constexpr char kPredictionTag[] = "PREDICTION";
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
float drama_thresh, float llama_thresh, float bazinga_thresh,
float joy_thresh, float peace_thresh) {
constexpr absl::string_view kCalculatorProto = R"pb(
calculator: "CombinedPredictionCalculator"
input_stream: "custom_softmax_scores"
input_stream: "canned_softmax_scores"
output_stream: "PREDICTION:prediction"
options {
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
class { label: "CustomDrama" score_threshold: $0 }
class { label: "CustomLlama" score_threshold: $1 }
class { label: "CannedBazinga" score_threshold: $2 }
class { label: "CannedJoy" score_threshold: $3 }
class { label: "CannedPeace" score_threshold: $4 }
background_label: "Negative"
}
}
)pb";
auto runner = std::make_unique<CalculatorRunner>(
absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
bazinga_thresh, joy_thresh, peace_thresh));
return runner;
}
std::unique_ptr<ClassificationList> BuildCustomScoreInput(
const float negative_score, const float drama_score,
const float llama_score) {
auto custom_scores = std::make_unique<ClassificationList>();
auto custom_negative = custom_scores->add_classification();
custom_negative->set_label("Negative");
custom_negative->set_score(negative_score);
auto drama = custom_scores->add_classification();
drama->set_label("CustomDrama");
drama->set_score(drama_score);
auto llama = custom_scores->add_classification();
llama->set_label("CustomLlama");
llama->set_score(llama_score);
return custom_scores;
}
std::unique_ptr<ClassificationList> BuildCannedScoreInput(
const float negative_score, const float bazinga_score,
const float joy_score, const float peace_score) {
auto canned_scores = std::make_unique<ClassificationList>();
auto canned_negative = canned_scores->add_classification();
canned_negative->set_label("Negative");
canned_negative->set_score(negative_score);
auto bazinga = canned_scores->add_classification();
bazinga->set_label("CannedBazinga");
bazinga->set_score(bazinga_score);
auto joy = canned_scores->add_classification();
joy->set_label("CannedJoy");
joy->set_score(joy_score);
auto peace = canned_scores->add_classification();
peace->set_label("CannedPeace");
peace->set_score(peace_score);
return canned_scores;
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomEmpty_CannedEmpty_ResultIsEmpty) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomEmpty_CannedNotEmpty_ResultIsCanned) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
/*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
auto canned_scores = BuildCannedScoreInput(
/*negative_score=*/0.1,
/*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
runner->MutableInputs()->Index(1).packets.push_back(
Adopt(canned_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), "CannedJoy");
EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomNotEmpty_CannedEmpty_ResultIsCustom) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
auto custom_scores =
BuildCustomScoreInput(/*negative_score=*/0.1,
/*drama_score=*/0.2, /*llama_score=*/0.7);
runner->MutableInputs()->Index(0).packets.push_back(
Adopt(custom_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), "CustomLlama");
EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
}
struct CombinedPredictionCalculatorTestCase {
std::string test_name;
float custom_negative_score;
float drama_score;
float llama_score;
float drama_thresh;
float llama_thresh;
float canned_negative_score;
float bazinga_score;
float joy_score;
float peace_score;
float bazinga_thresh;
float joy_thresh;
float peace_thresh;
std::string max_scoring_label;
float max_score;
};
using CombinedPredictionCalculatorTest =
testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
const CombinedPredictionCalculatorTestCase& test_case = GetParam();
auto runner = BuildNodeRunnerWithOptions(
test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
test_case.joy_thresh, test_case.peace_thresh);
auto custom_scores =
BuildCustomScoreInput(test_case.custom_negative_score,
test_case.drama_score, test_case.llama_score);
runner->MutableInputs()->Index(0).packets.push_back(
Adopt(custom_scores.release()).At(Timestamp(1)));
auto canned_scores = BuildCannedScoreInput(
test_case.canned_negative_score, test_case.bazinga_score,
test_case.joy_score, test_case.peace_score);
runner->MutableInputs()->Index(1).packets.push_back(
Adopt(canned_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
}
INSTANTIATE_TEST_CASE_P(
CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
{
.test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
.custom_negative_score = 0.1,
.drama_score = 0.5,
.llama_score = 0.3,
.drama_thresh = 0.25,
.llama_thresh = 0.7,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "CustomDrama",
.max_score = 0.5,
},
{
.test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
.custom_negative_score = 0.1,
.drama_score = 0.3,
.llama_score = 0.6,
.drama_thresh = 0.4,
.llama_thresh = 0.8,
.canned_negative_score = 0.1,
.bazinga_score = 0.4,
.joy_score = 0.3,
.peace_score = 0.2,
.bazinga_thresh = 0.0,
.joy_thresh = 0.0,
.peace_thresh = 0.0,
.max_scoring_label = "CannedBazinga",
.max_score = 0.4,
},
{
.test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
.custom_negative_score = 0.5,
.drama_score = 0.1,
.llama_score = 0.4,
.drama_thresh = 0.1,
.llama_thresh = 0.05,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.5,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
.custom_negative_score = 0.8,
.drama_score = 0.1,
.llama_score = 0.1,
.drama_thresh = 0.25,
.llama_thresh = 0.7,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.8,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
.custom_negative_score = 0.1,
.drama_score = 0.2,
.llama_score = 0.7,
.drama_thresh = 1.1,
.llama_thresh = 1.1,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.1,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
.custom_negative_score = 0.1,
.drama_score = 0.3,
.llama_score = 0.6,
.drama_thresh = 0.4,
.llama_thresh = 0.8,
.canned_negative_score = 0.3,
.bazinga_score = 0.2,
.joy_score = 0.3,
.peace_score = 0.2,
.bazinga_thresh = 0.5,
.joy_thresh = 0.5,
.peace_thresh = 0.5,
.max_scoring_label = "Negative",
.max_score = 0.1,
},
}),
[](const testing::TestParamInfo<
CombinedPredictionCalculatorTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace mediapipe

View File

@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph {
image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode( image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In("IMAGE"); image_in >> preprocessing.In("IMAGE");

View File

@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In("IMAGE"); image_in >> preprocessing.In("IMAGE");

View File

@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -243,8 +243,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -40,6 +40,10 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
] ]
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
]
def mediapipe_tasks_core_aar(name, srcs, manifest): def mediapipe_tasks_core_aar(name, srcs, manifest):
"""Builds medaipipe tasks core AAR. """Builds medaipipe tasks core AAR.
@ -60,6 +64,11 @@ def mediapipe_tasks_core_aar(name, srcs, manifest):
_mediapipe_tasks_java_proto_src_extractor(target = target), _mediapipe_tasks_java_proto_src_extractor(target = target),
) )
for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS:
mediapipe_tasks_java_proto_srcs.append(
_mediapipe_tasks_java_proto_src_extractor(target = target),
)
mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor( mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java", src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java",
@ -81,32 +90,35 @@ def mediapipe_tasks_core_aar(name, srcs, manifest):
], ],
manifest = manifest, manifest = manifest,
deps = [ deps = [
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite", "//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/framework:calculator_profile_java_proto_lite", "//mediapipe/framework:calculator_profile_java_proto_lite",
"//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework:mediapipe_options_java_proto_lite", "//mediapipe/framework:mediapipe_options_java_proto_lite",
"//mediapipe/framework:packet_factory_java_proto_lite", "//mediapipe/framework:packet_factory_java_proto_lite",
"//mediapipe/framework:packet_generator_java_proto_lite", "//mediapipe/framework:packet_generator_java_proto_lite",
"//mediapipe/framework:status_handler_java_proto_lite", "//mediapipe/framework:status_handler_java_proto_lite",
"//mediapipe/framework:stream_handler_java_proto_lite", "//mediapipe/framework:stream_handler_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite", "//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite", "//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite", "//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/framework/formats:rect_java_proto_lite", "//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
"//third_party:androidx_annotation", "//third_party:androidx_annotation",
"//third_party:autovalue", "//third_party:autovalue",
"@com_google_protobuf//:protobuf_javalite", "@com_google_protobuf//:protobuf_javalite",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
"@maven//:com_google_flogger_flogger", "@maven//:com_google_flogger_flogger",
"@maven//:com_google_flogger_flogger_system_backend", "@maven//:com_google_flogger_flogger_system_backend",
"@maven//:com_google_code_findbugs_jsr305", "@maven//:com_google_code_findbugs_jsr305",
] + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS, ] +
_CORE_TASKS_JAVA_PROTO_LITE_TARGETS +
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS +
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
) )
def mediapipe_tasks_vision_aar(name, srcs, native_library): def mediapipe_tasks_vision_aar(name, srcs, native_library):
@ -142,6 +154,39 @@ EOF
native_library = native_library, native_library = native_library,
) )
def mediapipe_tasks_text_aar(name, srcs, native_library):
"""Builds medaipipe tasks text AAR.
Args:
name: The bazel target name.
srcs: MediaPipe Text Tasks' source files.
native_library: The native library that contains text tasks' graph and calculators.
"""
native.genrule(
name = name + "tasks_manifest_generator",
outs = ["AndroidManifest.xml"],
cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.text">
<uses-sdk
android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>
EOF
""",
)
_mediapipe_tasks_aar(
name = name,
srcs = srcs,
manifest = "AndroidManifest.xml",
java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
native_library = native_library,
)
def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library): def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library):
"""Builds medaipipe tasks AAR.""" """Builds medaipipe tasks AAR."""

View File

@ -61,3 +61,11 @@ android_library(
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
mediapipe_tasks_text_aar(
name = "tasks_text",
srcs = glob(["**/*.java"]),
native_library = ":libmediapipe_tasks_text_jni_lib",
)

View File

@ -84,33 +84,28 @@ cc_library(
"//conditions:default": ["tflite_gpu_runner.h"], "//conditions:default": ["tflite_gpu_runner.h"],
}), }),
deps = select({ deps = select({
"//mediapipe:ios": [], "//mediapipe:ios": [],
"//mediapipe:macos": [], "//mediapipe:macos": [],
"//conditions:default": [ "//conditions:default": [
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/gpu:api", "@org_tensorflow//tensorflow/lite/delegates/gpu:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
], ],
"//mediapipe:android": [ }) +
"@com_google_absl//absl/strings", select({
"//mediapipe/framework/port:ret_check", "//mediapipe:android": [
"//mediapipe/framework/port:status", "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
"//mediapipe/framework/port:statusor", ],
"@org_tensorflow//tensorflow/lite:framework", "//conditions:default": [],
"@org_tensorflow//tensorflow/lite/delegates/gpu:api", }) + [
"@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
],
}) + [
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"//mediapipe/framework:port",
"@org_tensorflow//tensorflow/lite/core/api", "@org_tensorflow//tensorflow/lite/core/api",
], ],
) )

View File

@ -34,7 +34,7 @@
// This code should be enabled as soon as TensorFlow version, which mediapipe // This code should be enabled as soon as TensorFlow version, which mediapipe
// uses, will include this module. // uses, will include this module.
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
#include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif #endif
@ -82,7 +82,7 @@ ObjectDef GetSSBOObjectDef(int channels) {
return gpu_object_def; return gpu_object_def;
} }
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) { cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
cl::InferenceOptions result{}; cl::InferenceOptions result{};
@ -106,7 +106,7 @@ absl::Status VerifyShapes(const std::vector<TensorObjectDef>& actual,
return absl::OkStatus(); return absl::OkStatus();
} }
#endif // __ANDROID__ #endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
} // namespace } // namespace
@ -225,7 +225,7 @@ absl::Status TFLiteGPURunner::InitializeOpenGL(
absl::Status TFLiteGPURunner::InitializeOpenCL( absl::Status TFLiteGPURunner::InitializeOpenCL(
std::unique_ptr<InferenceBuilder>* builder) { std::unique_ptr<InferenceBuilder>* builder) {
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
cl::InferenceEnvironmentOptions env_options; cl::InferenceEnvironmentOptions env_options;
if (!serialized_binary_cache_.empty()) { if (!serialized_binary_cache_.empty()) {
env_options.serialized_binary_cache = serialized_binary_cache_; env_options.serialized_binary_cache = serialized_binary_cache_;
@ -254,11 +254,12 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
return absl::OkStatus(); return absl::OkStatus();
#else #else
return mediapipe::UnimplementedError("Currently only Android is supported"); return mediapipe::UnimplementedError(
#endif // __ANDROID__ "Currently only Android & ChromeOS are supported");
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
} }
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel( absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
std::unique_ptr<InferenceBuilder>* builder) { std::unique_ptr<InferenceBuilder>* builder) {
@ -283,7 +284,7 @@ absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
return serialized_model; return serialized_model;
} }
#endif // __ANDROID__ #endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite

View File

@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -28,9 +29,9 @@
#include "tensorflow/lite/delegates/gpu/gl/api2.h" #include "tensorflow/lite/delegates/gpu/gl/api2.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
#include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif // __ANDROID__ #endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
@ -83,7 +84,7 @@ class TFLiteGPURunner {
return output_shape_from_model_; return output_shape_from_model_;
} }
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
void SetSerializedBinaryCache(std::vector<uint8_t>&& cache) { void SetSerializedBinaryCache(std::vector<uint8_t>&& cache) {
serialized_binary_cache_ = std::move(cache); serialized_binary_cache_ = std::move(cache);
} }
@ -98,26 +99,26 @@ class TFLiteGPURunner {
} }
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel(); absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
#endif // __ANDROID__ #endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
private: private:
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder); absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder); absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
absl::Status InitializeOpenCLFromSerializedModel( absl::Status InitializeOpenCLFromSerializedModel(
std::unique_ptr<InferenceBuilder>* builder); std::unique_ptr<InferenceBuilder>* builder);
#endif // __ANDROID__ #endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
InferenceOptions options_; InferenceOptions options_;
std::unique_ptr<gl::InferenceEnvironment> gl_environment_; std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
#ifdef __ANDROID__ #if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
std::unique_ptr<cl::InferenceEnvironment> cl_environment_; std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
std::vector<uint8_t> serialized_binary_cache_; std::vector<uint8_t> serialized_binary_cache_;
std::vector<uint8_t> serialized_model_; std::vector<uint8_t> serialized_model_;
bool serialized_model_used_ = false; bool serialized_model_used_ = false;
#endif // __ANDROID__ #endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
// graph_gl_ is maintained temporarily and becomes invalid after runner_ is // graph_gl_ is maintained temporarily and becomes invalid after runner_ is
// ready // ready

View File

@ -0,0 +1,34 @@
diff --git a/src/BUILD b/src/BUILD
index b4298d2..f3877a3 100644
--- a/src/BUILD
+++ b/src/BUILD
@@ -71,9 +71,7 @@ cc_library(
":common",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
- "@com_github_gflags_gflags//:gflags",
"@com_google_glog//:glog",
- "@com_google_googletest//:gtest",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/container:flat_hash_map",
diff --git a/src/normalizer.h b/src/normalizer.h
index c16ac16..2af58be 100644
--- a/src/normalizer.h
+++ b/src/normalizer.h
@@ -21,7 +21,6 @@
#include <utility>
#include <vector>
-#include "gtest/gtest_prod.h"
#include "absl/strings/string_view.h"
#include "third_party/darts_clone/include/darts.h"
#include "src/common.h"
@@ -97,7 +96,6 @@ class Normalizer {
friend class Builder;
private:
- FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest);
void Init();