Merge branch 'google:master' into gesture-recognizer-python
This commit is contained in:
commit
0f7c5d5e90
|
@ -172,6 +172,10 @@ http_archive(
|
||||||
urls = [
|
urls = [
|
||||||
"https://github.com/google/sentencepiece/archive/1.0.0.zip",
|
"https://github.com/google/sentencepiece/archive/1.0.0.zip",
|
||||||
],
|
],
|
||||||
|
patches = [
|
||||||
|
"//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
|
||||||
|
],
|
||||||
|
patch_args = ["-p1"],
|
||||||
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
|
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
14
docs/BUILD
Normal file
14
docs/BUILD
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# Placeholder for internal Python strict binary compatibility macro.
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "build_py_api_docs",
|
||||||
|
srcs = ["build_py_api_docs.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe",
|
||||||
|
"//third_party/py/absl:app",
|
||||||
|
"//third_party/py/absl/flags",
|
||||||
|
"//third_party/py/tensorflow_docs",
|
||||||
|
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
|
||||||
|
"//third_party/py/tensorflow_docs/api_generator:public_api",
|
||||||
|
],
|
||||||
|
)
|
85
docs/build_py_api_docs.py
Normal file
85
docs/build_py_api_docs.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
r"""MediaPipe reference docs generation script.
|
||||||
|
|
||||||
|
This script generates API reference docs for the `mediapipe` PIP package.
|
||||||
|
|
||||||
|
$> pip install -U git+https://github.com/tensorflow/docs mediapipe
|
||||||
|
$> python build_py_api_docs.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
from tensorflow_docs.api_generator import generate_lib
|
||||||
|
from tensorflow_docs.api_generator import public_api
|
||||||
|
|
||||||
|
try:
|
||||||
|
# mediapipe has not been set up to work with bazel yet, so catch & report.
|
||||||
|
import mediapipe # pytype: disable=import-error
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError('Please `pip install mediapipe`.') from e
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_SHORT_NAME = 'mp'
|
||||||
|
PROJECT_FULL_NAME = 'MediaPipe'
|
||||||
|
|
||||||
|
_OUTPUT_DIR = flags.DEFINE_string(
|
||||||
|
'output_dir',
|
||||||
|
default='/tmp/generated_docs',
|
||||||
|
help='Where to write the resulting docs.')
|
||||||
|
|
||||||
|
_URL_PREFIX = flags.DEFINE_string(
|
||||||
|
'code_url_prefix',
|
||||||
|
'https://github.com/google/mediapipe/tree/master/mediapipe',
|
||||||
|
'The url prefix for links to code.')
|
||||||
|
|
||||||
|
_SEARCH_HINTS = flags.DEFINE_bool(
|
||||||
|
'search_hints', True,
|
||||||
|
'Include metadata search hints in the generated files')
|
||||||
|
|
||||||
|
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
|
||||||
|
'Path prefix in the _toc.yaml')
|
||||||
|
|
||||||
|
|
||||||
|
def gen_api_docs():
|
||||||
|
"""Generates API docs for the mediapipe package."""
|
||||||
|
|
||||||
|
doc_generator = generate_lib.DocGenerator(
|
||||||
|
root_title=PROJECT_FULL_NAME,
|
||||||
|
py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
|
||||||
|
base_dir=os.path.dirname(mediapipe.__file__),
|
||||||
|
code_url_prefix=_URL_PREFIX.value,
|
||||||
|
search_hints=_SEARCH_HINTS.value,
|
||||||
|
site_path=_SITE_PATH.value,
|
||||||
|
# This callback ensures that docs are only generated for objects that
|
||||||
|
# are explicitly imported in your __init__.py files. There are other
|
||||||
|
# options but this is a good starting point.
|
||||||
|
callbacks=[public_api.explicit_package_contents_filter],
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_generator.build(_OUTPUT_DIR.value)
|
||||||
|
|
||||||
|
print('Docs output to:', _OUTPUT_DIR.value)
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
gen_api_docs()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main)
|
|
@ -289,8 +289,15 @@ class NodeBase {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T& GetOptions() {
|
T& GetOptions() {
|
||||||
|
return GetOptions(T::ext);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use this API when the proto extension does not follow the "ext" naming
|
||||||
|
// convention.
|
||||||
|
template <typename E>
|
||||||
|
auto& GetOptions(const E& extension) {
|
||||||
options_used_ = true;
|
options_used_ = true;
|
||||||
return *options_.MutableExtension(T::ext);
|
return *options_.MutableExtension(extension);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -386,8 +393,15 @@ class PacketGenerator {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T& GetOptions() {
|
T& GetOptions() {
|
||||||
|
return GetOptions(T::ext);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use this API when the proto extension does not follow the "ext" naming
|
||||||
|
// convention.
|
||||||
|
template <typename E>
|
||||||
|
auto& GetOptions(const E& extension) {
|
||||||
options_used_ = true;
|
options_used_ = true;
|
||||||
return *options_.MutableExtension(T::ext);
|
return *options_.MutableExtension(extension);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
|
template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
|
||||||
|
|
|
@ -161,7 +161,7 @@ class Texture {
|
||||||
|
|
||||||
~Texture() {
|
~Texture() {
|
||||||
if (is_owned_) {
|
if (is_owned_) {
|
||||||
glDeleteProgram(handle_);
|
glDeleteTextures(1, &handle_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ cc_library(
|
||||||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
@ -129,7 +130,7 @@ absl::Status ConfigureImageToTensorCalculator(
|
||||||
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
|
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
|
||||||
std);
|
std);
|
||||||
}
|
}
|
||||||
// TODO: need to.support different GPU origin on differnt
|
// TODO: need to support different GPU origin on differnt
|
||||||
// platforms or applications.
|
// platforms or applications.
|
||||||
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
|
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -137,7 +138,13 @@ absl::Status ConfigureImageToTensorCalculator(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
bool DetermineImagePreprocessingGpuBackend(
|
||||||
|
const core::proto::Acceleration& acceleration) {
|
||||||
|
return acceleration.has_gpu();
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
|
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
|
||||||
|
bool use_gpu,
|
||||||
ImagePreprocessingOptions* options) {
|
ImagePreprocessingOptions* options) {
|
||||||
ASSIGN_OR_RETURN(auto image_tensor_specs,
|
ASSIGN_OR_RETURN(auto image_tensor_specs,
|
||||||
BuildImageTensorSpecs(model_resources));
|
BuildImageTensorSpecs(model_resources));
|
||||||
|
@ -145,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
|
||||||
image_tensor_specs, options->mutable_image_to_tensor_options()));
|
image_tensor_specs, options->mutable_image_to_tensor_options()));
|
||||||
// The GPU backend isn't able to process int data. If the input tensor is
|
// The GPU backend isn't able to process int data. If the input tensor is
|
||||||
// quantized, forces the image preprocessing graph to use CPU backend.
|
// quantized, forces the image preprocessing graph to use CPU backend.
|
||||||
if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) {
|
if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
|
||||||
|
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
|
||||||
|
} else {
|
||||||
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
|
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -19,20 +19,26 @@ limitations under the License.
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
|
||||||
// Configures an ImagePreprocessing subgraph using the provided model resources.
|
// Configures an ImagePreprocessing subgraph using the provided model resources
|
||||||
|
// When use_gpu is true, use GPU as backend to convert image to tensor.
|
||||||
// - Accepts CPU input images and outputs CPU tensors.
|
// - Accepts CPU input images and outputs CPU tensors.
|
||||||
//
|
//
|
||||||
// Example usage:
|
// Example usage:
|
||||||
//
|
//
|
||||||
// auto& preprocessing =
|
// auto& preprocessing =
|
||||||
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||||
|
// core::proto::Acceleration acceleration;
|
||||||
|
// acceleration.mutable_xnnpack();
|
||||||
|
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
|
||||||
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||||
// model_resources,
|
// model_resources,
|
||||||
|
// use_gpu,
|
||||||
// &preprocessing.GetOptions<ImagePreprocessingOptions>()));
|
// &preprocessing.GetOptions<ImagePreprocessingOptions>()));
|
||||||
//
|
//
|
||||||
// The resulting ImagePreprocessing subgraph has the following I/O:
|
// The resulting ImagePreprocessing subgraph has the following I/O:
|
||||||
|
@ -56,9 +62,14 @@ namespace components {
|
||||||
// The image that has the pixel data stored on the target storage (CPU vs
|
// The image that has the pixel data stored on the target storage (CPU vs
|
||||||
// GPU).
|
// GPU).
|
||||||
absl::Status ConfigureImagePreprocessing(
|
absl::Status ConfigureImagePreprocessing(
|
||||||
const core::ModelResources& model_resources,
|
const core::ModelResources& model_resources, bool use_gpu,
|
||||||
ImagePreprocessingOptions* options);
|
ImagePreprocessingOptions* options);
|
||||||
|
|
||||||
|
// Determine if the image preprocessing subgraph should use GPU as the backend
|
||||||
|
// according to the given acceleration setting.
|
||||||
|
bool DetermineImagePreprocessingGpuBackend(
|
||||||
|
const core::proto::Acceleration& acceleration);
|
||||||
|
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -93,3 +93,46 @@ cc_test(
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "combined_prediction_calculator_proto",
|
||||||
|
srcs = ["combined_prediction_calculator.proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "combined_prediction_calculator",
|
||||||
|
srcs = ["combined_prediction_calculator.cc"],
|
||||||
|
deps = [
|
||||||
|
":combined_prediction_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:collection",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"@com_google_absl//absl/container:btree",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "combined_prediction_calculator_test",
|
||||||
|
srcs = ["combined_prediction_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":combined_prediction_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,187 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/btree_map.h"
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/collection.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kPredictionTag[] = "PREDICTION";
|
||||||
|
|
||||||
|
Classification GetMaxScoringClassification(
|
||||||
|
const ClassificationList& classifications) {
|
||||||
|
Classification max_classification;
|
||||||
|
max_classification.set_score(0);
|
||||||
|
for (const auto& input : classifications.classification()) {
|
||||||
|
if (max_classification.score() < input.score()) {
|
||||||
|
max_classification = input;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max_classification;
|
||||||
|
}
|
||||||
|
|
||||||
|
float GetScoreThreshold(
|
||||||
|
const std::string& input_label,
|
||||||
|
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||||
|
const std::string& background_label, const float default_threshold) {
|
||||||
|
float threshold = default_threshold;
|
||||||
|
auto it = classwise_thresholds.find(input_label);
|
||||||
|
if (it != classwise_thresholds.end()) {
|
||||||
|
threshold = it->second;
|
||||||
|
}
|
||||||
|
return threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ClassificationList> GetWinningPrediction(
|
||||||
|
const ClassificationList& classification_list,
|
||||||
|
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||||
|
const std::string& background_label, const float default_threshold) {
|
||||||
|
auto prediction_list = std::make_unique<ClassificationList>();
|
||||||
|
if (classification_list.classification().empty()) {
|
||||||
|
return prediction_list;
|
||||||
|
}
|
||||||
|
Classification& prediction = *prediction_list->add_classification();
|
||||||
|
auto argmax_prediction = GetMaxScoringClassification(classification_list);
|
||||||
|
float argmax_prediction_thresh =
|
||||||
|
GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
|
||||||
|
background_label, default_threshold);
|
||||||
|
if (argmax_prediction.score() >= argmax_prediction_thresh) {
|
||||||
|
prediction.set_label(argmax_prediction.label());
|
||||||
|
prediction.set_score(argmax_prediction.score());
|
||||||
|
} else {
|
||||||
|
for (const auto& input : classification_list.classification()) {
|
||||||
|
if (input.label() == background_label) {
|
||||||
|
prediction.set_label(input.label());
|
||||||
|
prediction.set_score(input.score());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prediction_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// This calculator accepts multiple ClassificationList input streams. Each
|
||||||
|
// ClassificationList should contain classifications with labels and
|
||||||
|
// corresponding softmax scores. The calculator computes the best prediction for
|
||||||
|
// each ClassificationList input stream via argmax and thresholding. Thresholds
|
||||||
|
// for all classes can be specified in the
|
||||||
|
// `CombinedPredictionCalculatorOptions`, along with a default global
|
||||||
|
// threshold.
|
||||||
|
// Please note that for this calculator to work as designed, the class names
|
||||||
|
// other than the background class in the ClassificationList objects must be
|
||||||
|
// different, but the background class name has to be the same. This background
|
||||||
|
// label name can be set via `background_label` in
|
||||||
|
// `CombinedPredictionCalculatorOptions`.
|
||||||
|
// The ClassificationList in the PREDICTION output stream contains the label of
|
||||||
|
// the winning class and corresponding softmax score. If none of the
|
||||||
|
// ClassificationList objects has a non-background winning class, the output
|
||||||
|
// contains the background class and score of the background class in the first
|
||||||
|
// ClassificationList. If multiple ClassificationList objects have a
|
||||||
|
// non-background winning class, the output contains the winning prediction from
|
||||||
|
// the ClassificationList with the highest priority. Priority is in decreasing
|
||||||
|
// order of input streams to the graph node using this calculator.
|
||||||
|
// Input:
|
||||||
|
// At least one stream with ClassificationList.
|
||||||
|
// Output:
|
||||||
|
// PREDICTION - A ClassificationList with the winning label as the only item.
|
||||||
|
//
|
||||||
|
// Usage example:
|
||||||
|
// node {
|
||||||
|
// calculator: "CombinedPredictionCalculator"
|
||||||
|
// input_stream: "classification_list_0"
|
||||||
|
// input_stream: "classification_list_1"
|
||||||
|
// output_stream: "PREDICTION:prediction"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||||
|
// class {
|
||||||
|
// label: "A"
|
||||||
|
// score_threshold: 0.7
|
||||||
|
// }
|
||||||
|
// default_global_threshold: 0.1
|
||||||
|
// background_label: "B"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
class CombinedPredictionCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
|
||||||
|
""};
|
||||||
|
static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
options_ = cc->Options<CombinedPredictionCalculatorOptions>();
|
||||||
|
for (const auto& input : options_.class_()) {
|
||||||
|
classwise_thresholds_[input.label()] = input.score_threshold();
|
||||||
|
}
|
||||||
|
classwise_thresholds_[options_.background_label()] = 0;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
// After loop, if have winning prediction return. Otherwise empty packet.
|
||||||
|
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
|
||||||
|
auto collection = kClassificationListIn(cc);
|
||||||
|
for (int idx = 0; idx < collection.Count(); ++idx) {
|
||||||
|
const auto& packet = collection[idx];
|
||||||
|
if (packet.IsEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto prediction = GetWinningPrediction(
|
||||||
|
packet.Get(), classwise_thresholds_, options_.background_label(),
|
||||||
|
options_.default_global_threshold());
|
||||||
|
if (prediction->classification(0).label() !=
|
||||||
|
options_.background_label()) {
|
||||||
|
kPredictionOut(cc).Send(std::move(prediction));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
if (first_winning_prediction == nullptr) {
|
||||||
|
first_winning_prediction = std::move(prediction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (first_winning_prediction != nullptr) {
|
||||||
|
kPredictionOut(cc).Send(std::move(first_winning_prediction));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CombinedPredictionCalculatorOptions options_;
|
||||||
|
absl::btree_map<std::string, float> classwise_thresholds_;
|
||||||
|
};
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,41 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message CombinedPredictionCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional CombinedPredictionCalculatorOptions ext = 483738635;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Class {
|
||||||
|
optional string label = 1;
|
||||||
|
optional float score_threshold = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// List of classes with score thresholds.
|
||||||
|
repeated Class class = 1;
|
||||||
|
|
||||||
|
// Default score threshold applied to a label.
|
||||||
|
optional float default_global_threshold = 2 [default = 0];
|
||||||
|
|
||||||
|
// Name of the background class whose input scores will be ignored while
|
||||||
|
// thresholding.
|
||||||
|
optional string background_label = 3;
|
||||||
|
}
|
|
@ -0,0 +1,315 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kPredictionTag[] = "PREDICTION";
|
||||||
|
|
||||||
|
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
|
||||||
|
float drama_thresh, float llama_thresh, float bazinga_thresh,
|
||||||
|
float joy_thresh, float peace_thresh) {
|
||||||
|
constexpr absl::string_view kCalculatorProto = R"pb(
|
||||||
|
calculator: "CombinedPredictionCalculator"
|
||||||
|
input_stream: "custom_softmax_scores"
|
||||||
|
input_stream: "canned_softmax_scores"
|
||||||
|
output_stream: "PREDICTION:prediction"
|
||||||
|
options {
|
||||||
|
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||||
|
class { label: "CustomDrama" score_threshold: $0 }
|
||||||
|
class { label: "CustomLlama" score_threshold: $1 }
|
||||||
|
class { label: "CannedBazinga" score_threshold: $2 }
|
||||||
|
class { label: "CannedJoy" score_threshold: $3 }
|
||||||
|
class { label: "CannedPeace" score_threshold: $4 }
|
||||||
|
background_label: "Negative"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb";
|
||||||
|
auto runner = std::make_unique<CalculatorRunner>(
|
||||||
|
absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
|
||||||
|
bazinga_thresh, joy_thresh, peace_thresh));
|
||||||
|
return runner;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ClassificationList> BuildCustomScoreInput(
|
||||||
|
const float negative_score, const float drama_score,
|
||||||
|
const float llama_score) {
|
||||||
|
auto custom_scores = std::make_unique<ClassificationList>();
|
||||||
|
auto custom_negative = custom_scores->add_classification();
|
||||||
|
custom_negative->set_label("Negative");
|
||||||
|
custom_negative->set_score(negative_score);
|
||||||
|
auto drama = custom_scores->add_classification();
|
||||||
|
drama->set_label("CustomDrama");
|
||||||
|
drama->set_score(drama_score);
|
||||||
|
auto llama = custom_scores->add_classification();
|
||||||
|
llama->set_label("CustomLlama");
|
||||||
|
llama->set_score(llama_score);
|
||||||
|
return custom_scores;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ClassificationList> BuildCannedScoreInput(
|
||||||
|
const float negative_score, const float bazinga_score,
|
||||||
|
const float joy_score, const float peace_score) {
|
||||||
|
auto canned_scores = std::make_unique<ClassificationList>();
|
||||||
|
auto canned_negative = canned_scores->add_classification();
|
||||||
|
canned_negative->set_label("Negative");
|
||||||
|
canned_negative->set_score(negative_score);
|
||||||
|
auto bazinga = canned_scores->add_classification();
|
||||||
|
bazinga->set_label("CannedBazinga");
|
||||||
|
bazinga->set_score(bazinga_score);
|
||||||
|
auto joy = canned_scores->add_classification();
|
||||||
|
joy->set_label("CannedJoy");
|
||||||
|
joy->set_score(joy_score);
|
||||||
|
auto peace = canned_scores->add_classification();
|
||||||
|
peace->set_label("CannedPeace");
|
||||||
|
peace->set_score(peace_score);
|
||||||
|
return canned_scores;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CombinedPredictionCalculatorPacketTest,
|
||||||
|
CustomEmpty_CannedEmpty_ResultIsEmpty) {
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
|
||||||
|
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CombinedPredictionCalculatorPacketTest,
|
||||||
|
CustomEmpty_CannedNotEmpty_ResultIsCanned) {
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
|
||||||
|
/*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
|
||||||
|
auto canned_scores = BuildCannedScoreInput(
|
||||||
|
/*negative_score=*/0.1,
|
||||||
|
/*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
|
||||||
|
runner->MutableInputs()->Index(1).packets.push_back(
|
||||||
|
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
|
||||||
|
auto output_prediction_packets =
|
||||||
|
runner->Outputs().Tag(kPredictionTag).packets;
|
||||||
|
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||||
|
Classification output_prediction =
|
||||||
|
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_prediction.label(), "CannedJoy");
|
||||||
|
EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CombinedPredictionCalculatorPacketTest,
|
||||||
|
CustomNotEmpty_CannedEmpty_ResultIsCustom) {
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
/*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
|
||||||
|
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||||
|
auto custom_scores =
|
||||||
|
BuildCustomScoreInput(/*negative_score=*/0.1,
|
||||||
|
/*drama_score=*/0.2, /*llama_score=*/0.7);
|
||||||
|
runner->MutableInputs()->Index(0).packets.push_back(
|
||||||
|
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
|
||||||
|
auto output_prediction_packets =
|
||||||
|
runner->Outputs().Tag(kPredictionTag).packets;
|
||||||
|
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||||
|
Classification output_prediction =
|
||||||
|
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_prediction.label(), "CustomLlama");
|
||||||
|
EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CombinedPredictionCalculatorTestCase {
|
||||||
|
std::string test_name;
|
||||||
|
float custom_negative_score;
|
||||||
|
float drama_score;
|
||||||
|
float llama_score;
|
||||||
|
float drama_thresh;
|
||||||
|
float llama_thresh;
|
||||||
|
float canned_negative_score;
|
||||||
|
float bazinga_score;
|
||||||
|
float joy_score;
|
||||||
|
float peace_score;
|
||||||
|
float bazinga_thresh;
|
||||||
|
float joy_thresh;
|
||||||
|
float peace_thresh;
|
||||||
|
std::string max_scoring_label;
|
||||||
|
float max_score;
|
||||||
|
};
|
||||||
|
|
||||||
|
using CombinedPredictionCalculatorTest =
|
||||||
|
testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
|
||||||
|
|
||||||
|
TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
|
||||||
|
const CombinedPredictionCalculatorTestCase& test_case = GetParam();
|
||||||
|
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
|
||||||
|
test_case.joy_thresh, test_case.peace_thresh);
|
||||||
|
|
||||||
|
auto custom_scores =
|
||||||
|
BuildCustomScoreInput(test_case.custom_negative_score,
|
||||||
|
test_case.drama_score, test_case.llama_score);
|
||||||
|
|
||||||
|
runner->MutableInputs()->Index(0).packets.push_back(
|
||||||
|
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
auto canned_scores = BuildCannedScoreInput(
|
||||||
|
test_case.canned_negative_score, test_case.bazinga_score,
|
||||||
|
test_case.joy_score, test_case.peace_score);
|
||||||
|
runner->MutableInputs()->Index(1).packets.push_back(
|
||||||
|
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
|
||||||
|
auto output_prediction_packets =
|
||||||
|
runner->Outputs().Tag(kPredictionTag).packets;
|
||||||
|
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||||
|
Classification output_prediction =
|
||||||
|
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
|
||||||
|
EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
|
||||||
|
testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
|
||||||
|
{
|
||||||
|
.test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.5,
|
||||||
|
.llama_score = 0.3,
|
||||||
|
.drama_thresh = 0.25,
|
||||||
|
.llama_thresh = 0.7,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "CustomDrama",
|
||||||
|
.max_score = 0.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.3,
|
||||||
|
.llama_score = 0.6,
|
||||||
|
.drama_thresh = 0.4,
|
||||||
|
.llama_thresh = 0.8,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.4,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.2,
|
||||||
|
.bazinga_thresh = 0.0,
|
||||||
|
.joy_thresh = 0.0,
|
||||||
|
.peace_thresh = 0.0,
|
||||||
|
.max_scoring_label = "CannedBazinga",
|
||||||
|
.max_score = 0.4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.5,
|
||||||
|
.drama_score = 0.1,
|
||||||
|
.llama_score = 0.4,
|
||||||
|
.drama_thresh = 0.1,
|
||||||
|
.llama_thresh = 0.05,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.8,
|
||||||
|
.drama_score = 0.1,
|
||||||
|
.llama_score = 0.1,
|
||||||
|
.drama_thresh = 0.25,
|
||||||
|
.llama_thresh = 0.7,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.8,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.2,
|
||||||
|
.llama_score = 0.7,
|
||||||
|
.drama_thresh = 1.1,
|
||||||
|
.llama_thresh = 1.1,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.3,
|
||||||
|
.llama_score = 0.6,
|
||||||
|
.drama_thresh = 0.4,
|
||||||
|
.llama_thresh = 0.8,
|
||||||
|
.canned_negative_score = 0.3,
|
||||||
|
.bazinga_score = 0.2,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.2,
|
||||||
|
.bazinga_thresh = 0.5,
|
||||||
|
.joy_thresh = 0.5,
|
||||||
|
.peace_thresh = 0.5,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.1,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
[](const testing::TestParamInfo<
|
||||||
|
CombinedPredictionCalculatorTest::ParamType>& info) {
|
||||||
|
return info.param.test_name;
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
|
@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
||||||
image_to_tensor_options.set_keep_aspect_ratio(true);
|
image_to_tensor_options.set_keep_aspect_ratio(true);
|
||||||
image_to_tensor_options.set_border_mode(
|
image_to_tensor_options.set_border_mode(
|
||||||
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
|
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
|
||||||
|
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||||
|
subgraph_options.base_options().acceleration());
|
||||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||||
model_resources,
|
model_resources, use_gpu,
|
||||||
&preprocessing
|
&preprocessing
|
||||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||||
image_in >> preprocessing.In("IMAGE");
|
image_in >> preprocessing.In("IMAGE");
|
||||||
|
|
|
@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
|
|
||||||
auto& preprocessing =
|
auto& preprocessing =
|
||||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||||
|
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||||
|
subgraph_options.base_options().acceleration());
|
||||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||||
model_resources,
|
model_resources, use_gpu,
|
||||||
&preprocessing
|
&preprocessing
|
||||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||||
image_in >> preprocessing.In("IMAGE");
|
image_in >> preprocessing.In("IMAGE");
|
||||||
|
|
|
@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
// stream.
|
// stream.
|
||||||
auto& preprocessing =
|
auto& preprocessing =
|
||||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||||
|
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||||
|
task_options.base_options().acceleration());
|
||||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||||
model_resources,
|
model_resources, use_gpu,
|
||||||
&preprocessing
|
&preprocessing
|
||||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||||
image_in >> preprocessing.In(kImageTag);
|
image_in >> preprocessing.In(kImageTag);
|
||||||
|
|
|
@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
||||||
// stream.
|
// stream.
|
||||||
auto& preprocessing =
|
auto& preprocessing =
|
||||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||||
|
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||||
|
task_options.base_options().acceleration());
|
||||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||||
model_resources,
|
model_resources, use_gpu,
|
||||||
&preprocessing
|
&preprocessing
|
||||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||||
image_in >> preprocessing.In(kImageTag);
|
image_in >> preprocessing.In(kImageTag);
|
||||||
|
|
|
@ -243,8 +243,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
// stream.
|
// stream.
|
||||||
auto& preprocessing =
|
auto& preprocessing =
|
||||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||||
|
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||||
|
task_options.base_options().acceleration());
|
||||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||||
model_resources,
|
model_resources, use_gpu,
|
||||||
&preprocessing
|
&preprocessing
|
||||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||||
image_in >> preprocessing.In(kImageTag);
|
image_in >> preprocessing.In(kImageTag);
|
||||||
|
|
|
@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
|
||||||
// stream.
|
// stream.
|
||||||
auto& preprocessing =
|
auto& preprocessing =
|
||||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||||
|
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||||
|
task_options.base_options().acceleration());
|
||||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||||
model_resources,
|
model_resources, use_gpu,
|
||||||
&preprocessing
|
&preprocessing
|
||||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||||
image_in >> preprocessing.In(kImageTag);
|
image_in >> preprocessing.In(kImageTag);
|
||||||
|
|
|
@ -40,6 +40,10 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||||
|
]
|
||||||
|
|
||||||
def mediapipe_tasks_core_aar(name, srcs, manifest):
|
def mediapipe_tasks_core_aar(name, srcs, manifest):
|
||||||
"""Builds medaipipe tasks core AAR.
|
"""Builds medaipipe tasks core AAR.
|
||||||
|
|
||||||
|
@ -60,6 +64,11 @@ def mediapipe_tasks_core_aar(name, srcs, manifest):
|
||||||
_mediapipe_tasks_java_proto_src_extractor(target = target),
|
_mediapipe_tasks_java_proto_src_extractor(target = target),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS:
|
||||||
|
mediapipe_tasks_java_proto_srcs.append(
|
||||||
|
_mediapipe_tasks_java_proto_src_extractor(target = target),
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
|
mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
|
||||||
target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
|
target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
|
||||||
src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java",
|
src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java",
|
||||||
|
@ -81,32 +90,35 @@ def mediapipe_tasks_core_aar(name, srcs, manifest):
|
||||||
],
|
],
|
||||||
manifest = manifest,
|
manifest = manifest,
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
|
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
|
||||||
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
|
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
|
||||||
"//mediapipe/framework:calculator_java_proto_lite",
|
"//mediapipe/framework:calculator_java_proto_lite",
|
||||||
"//mediapipe/framework:calculator_profile_java_proto_lite",
|
"//mediapipe/framework:calculator_profile_java_proto_lite",
|
||||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
"//mediapipe/framework:mediapipe_options_java_proto_lite",
|
"//mediapipe/framework:mediapipe_options_java_proto_lite",
|
||||||
"//mediapipe/framework:packet_factory_java_proto_lite",
|
"//mediapipe/framework:packet_factory_java_proto_lite",
|
||||||
"//mediapipe/framework:packet_generator_java_proto_lite",
|
"//mediapipe/framework:packet_generator_java_proto_lite",
|
||||||
"//mediapipe/framework:status_handler_java_proto_lite",
|
"//mediapipe/framework:status_handler_java_proto_lite",
|
||||||
"//mediapipe/framework:stream_handler_java_proto_lite",
|
"//mediapipe/framework:stream_handler_java_proto_lite",
|
||||||
"//mediapipe/framework/formats:classification_java_proto_lite",
|
"//mediapipe/framework/formats:classification_java_proto_lite",
|
||||||
"//mediapipe/framework/formats:detection_java_proto_lite",
|
"//mediapipe/framework/formats:detection_java_proto_lite",
|
||||||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||||
"//mediapipe/framework/formats:location_data_java_proto_lite",
|
"//mediapipe/framework/formats:location_data_java_proto_lite",
|
||||||
"//mediapipe/framework/formats:rect_java_proto_lite",
|
"//mediapipe/framework/formats:rect_java_proto_lite",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||||
"//third_party:androidx_annotation",
|
"//third_party:androidx_annotation",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
"@com_google_protobuf//:protobuf_javalite",
|
"@com_google_protobuf//:protobuf_javalite",
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
"@maven//:com_google_flogger_flogger",
|
"@maven//:com_google_flogger_flogger",
|
||||||
"@maven//:com_google_flogger_flogger_system_backend",
|
"@maven//:com_google_flogger_flogger_system_backend",
|
||||||
"@maven//:com_google_code_findbugs_jsr305",
|
"@maven//:com_google_code_findbugs_jsr305",
|
||||||
] + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS,
|
] +
|
||||||
|
_CORE_TASKS_JAVA_PROTO_LITE_TARGETS +
|
||||||
|
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS +
|
||||||
|
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
|
||||||
)
|
)
|
||||||
|
|
||||||
def mediapipe_tasks_vision_aar(name, srcs, native_library):
|
def mediapipe_tasks_vision_aar(name, srcs, native_library):
|
||||||
|
@ -142,6 +154,39 @@ EOF
|
||||||
native_library = native_library,
|
native_library = native_library,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def mediapipe_tasks_text_aar(name, srcs, native_library):
|
||||||
|
"""Builds medaipipe tasks text AAR.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The bazel target name.
|
||||||
|
srcs: MediaPipe Text Tasks' source files.
|
||||||
|
native_library: The native library that contains text tasks' graph and calculators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
native.genrule(
|
||||||
|
name = name + "tasks_manifest_generator",
|
||||||
|
outs = ["AndroidManifest.xml"],
|
||||||
|
cmd = """
|
||||||
|
cat > $(OUTS) <<EOF
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.text">
|
||||||
|
<uses-sdk
|
||||||
|
android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
</manifest>
|
||||||
|
EOF
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
_mediapipe_tasks_aar(
|
||||||
|
name = name,
|
||||||
|
srcs = srcs,
|
||||||
|
manifest = "AndroidManifest.xml",
|
||||||
|
java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
|
||||||
|
native_library = native_library,
|
||||||
|
)
|
||||||
|
|
||||||
def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library):
|
def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library):
|
||||||
"""Builds medaipipe tasks AAR."""
|
"""Builds medaipipe tasks AAR."""
|
||||||
|
|
||||||
|
|
|
@ -61,3 +61,11 @@ android_library(
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
|
||||||
|
|
||||||
|
mediapipe_tasks_text_aar(
|
||||||
|
name = "tasks_text",
|
||||||
|
srcs = glob(["**/*.java"]),
|
||||||
|
native_library = ":libmediapipe_tasks_text_jni_lib",
|
||||||
|
)
|
||||||
|
|
|
@ -84,33 +84,28 @@ cc_library(
|
||||||
"//conditions:default": ["tflite_gpu_runner.h"],
|
"//conditions:default": ["tflite_gpu_runner.h"],
|
||||||
}),
|
}),
|
||||||
deps = select({
|
deps = select({
|
||||||
"//mediapipe:ios": [],
|
"//mediapipe:ios": [],
|
||||||
"//mediapipe:macos": [],
|
"//mediapipe:macos": [],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:statusor",
|
"//mediapipe/framework/port:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:api",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu:api",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
|
||||||
],
|
],
|
||||||
"//mediapipe:android": [
|
}) +
|
||||||
"@com_google_absl//absl/strings",
|
select({
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe:android": [
|
||||||
"//mediapipe/framework/port:status",
|
"@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
|
||||||
"//mediapipe/framework/port:statusor",
|
],
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"//conditions:default": [],
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:api",
|
}) + [
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
|
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
|
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder",
|
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
|
|
||||||
],
|
|
||||||
}) + [
|
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
|
"//mediapipe/framework:port",
|
||||||
"@org_tensorflow//tensorflow/lite/core/api",
|
"@org_tensorflow//tensorflow/lite/core/api",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -34,7 +34,7 @@
|
||||||
|
|
||||||
// This code should be enabled as soon as TensorFlow version, which mediapipe
|
// This code should be enabled as soon as TensorFlow version, which mediapipe
|
||||||
// uses, will include this module.
|
// uses, will include this module.
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ ObjectDef GetSSBOObjectDef(int channels) {
|
||||||
return gpu_object_def;
|
return gpu_object_def;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
|
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
|
||||||
cl::InferenceOptions result{};
|
cl::InferenceOptions result{};
|
||||||
|
@ -106,7 +106,7 @@ absl::Status VerifyShapes(const std::vector<TensorObjectDef>& actual,
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // __ANDROID__
|
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -225,7 +225,7 @@ absl::Status TFLiteGPURunner::InitializeOpenGL(
|
||||||
|
|
||||||
absl::Status TFLiteGPURunner::InitializeOpenCL(
|
absl::Status TFLiteGPURunner::InitializeOpenCL(
|
||||||
std::unique_ptr<InferenceBuilder>* builder) {
|
std::unique_ptr<InferenceBuilder>* builder) {
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
cl::InferenceEnvironmentOptions env_options;
|
cl::InferenceEnvironmentOptions env_options;
|
||||||
if (!serialized_binary_cache_.empty()) {
|
if (!serialized_binary_cache_.empty()) {
|
||||||
env_options.serialized_binary_cache = serialized_binary_cache_;
|
env_options.serialized_binary_cache = serialized_binary_cache_;
|
||||||
|
@ -254,11 +254,12 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
#else
|
#else
|
||||||
return mediapipe::UnimplementedError("Currently only Android is supported");
|
return mediapipe::UnimplementedError(
|
||||||
#endif // __ANDROID__
|
"Currently only Android & ChromeOS are supported");
|
||||||
|
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
|
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
|
||||||
std::unique_ptr<InferenceBuilder>* builder) {
|
std::unique_ptr<InferenceBuilder>* builder) {
|
||||||
|
@ -283,7 +284,7 @@ absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
|
||||||
return serialized_model;
|
return serialized_model;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // __ANDROID__
|
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/framework/port.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
@ -28,9 +29,9 @@
|
||||||
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
|
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
||||||
#endif // __ANDROID__
|
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
@ -83,7 +84,7 @@ class TFLiteGPURunner {
|
||||||
return output_shape_from_model_;
|
return output_shape_from_model_;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
void SetSerializedBinaryCache(std::vector<uint8_t>&& cache) {
|
void SetSerializedBinaryCache(std::vector<uint8_t>&& cache) {
|
||||||
serialized_binary_cache_ = std::move(cache);
|
serialized_binary_cache_ = std::move(cache);
|
||||||
}
|
}
|
||||||
|
@ -98,26 +99,26 @@ class TFLiteGPURunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
|
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
|
||||||
#endif // __ANDROID__
|
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
|
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
|
||||||
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
|
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
absl::Status InitializeOpenCLFromSerializedModel(
|
absl::Status InitializeOpenCLFromSerializedModel(
|
||||||
std::unique_ptr<InferenceBuilder>* builder);
|
std::unique_ptr<InferenceBuilder>* builder);
|
||||||
#endif // __ANDROID__
|
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
InferenceOptions options_;
|
InferenceOptions options_;
|
||||||
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
|
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
|
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
|
||||||
|
|
||||||
std::vector<uint8_t> serialized_binary_cache_;
|
std::vector<uint8_t> serialized_binary_cache_;
|
||||||
std::vector<uint8_t> serialized_model_;
|
std::vector<uint8_t> serialized_model_;
|
||||||
bool serialized_model_used_ = false;
|
bool serialized_model_used_ = false;
|
||||||
#endif // __ANDROID__
|
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
||||||
// graph_gl_ is maintained temporarily and becomes invalid after runner_ is
|
// graph_gl_ is maintained temporarily and becomes invalid after runner_ is
|
||||||
// ready
|
// ready
|
||||||
|
|
34
third_party/com_google_sentencepiece_no_gflag_no_gtest.diff
vendored
Normal file
34
third_party/com_google_sentencepiece_no_gflag_no_gtest.diff
vendored
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
diff --git a/src/BUILD b/src/BUILD
|
||||||
|
index b4298d2..f3877a3 100644
|
||||||
|
--- a/src/BUILD
|
||||||
|
+++ b/src/BUILD
|
||||||
|
@@ -71,9 +71,7 @@ cc_library(
|
||||||
|
":common",
|
||||||
|
":sentencepiece_cc_proto",
|
||||||
|
":sentencepiece_model_cc_proto",
|
||||||
|
- "@com_github_gflags_gflags//:gflags",
|
||||||
|
"@com_google_glog//:glog",
|
||||||
|
- "@com_google_googletest//:gtest",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
diff --git a/src/normalizer.h b/src/normalizer.h
|
||||||
|
index c16ac16..2af58be 100644
|
||||||
|
--- a/src/normalizer.h
|
||||||
|
+++ b/src/normalizer.h
|
||||||
|
@@ -21,7 +21,6 @@
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
-#include "gtest/gtest_prod.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "third_party/darts_clone/include/darts.h"
|
||||||
|
#include "src/common.h"
|
||||||
|
@@ -97,7 +96,6 @@ class Normalizer {
|
||||||
|
friend class Builder;
|
||||||
|
|
||||||
|
private:
|
||||||
|
- FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest);
|
||||||
|
|
||||||
|
void Init();
|
||||||
|
|
Loading…
Reference in New Issue
Block a user