Merge branch 'google:master' into gesture-recognizer-python
This commit is contained in:
commit
0f7c5d5e90
|
@ -172,6 +172,10 @@ http_archive(
|
|||
urls = [
|
||||
"https://github.com/google/sentencepiece/archive/1.0.0.zip",
|
||||
],
|
||||
patches = [
|
||||
"//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
|
||||
],
|
||||
patch_args = ["-p1"],
|
||||
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
|
||||
)
|
||||
|
||||
|
|
14
docs/BUILD
Normal file
14
docs/BUILD
Normal file
|
@ -0,0 +1,14 @@
|
|||
# Placeholder for internal Python strict binary compatibility macro.
|
||||
|
||||
py_binary(
|
||||
name = "build_py_api_docs",
|
||||
srcs = ["build_py_api_docs.py"],
|
||||
deps = [
|
||||
"//mediapipe",
|
||||
"//third_party/py/absl:app",
|
||||
"//third_party/py/absl/flags",
|
||||
"//third_party/py/tensorflow_docs",
|
||||
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
|
||||
"//third_party/py/tensorflow_docs/api_generator:public_api",
|
||||
],
|
||||
)
|
85
docs/build_py_api_docs.py
Normal file
85
docs/build_py_api_docs.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""MediaPipe reference docs generation script.
|
||||
|
||||
This script generates API reference docs for the `mediapipe` PIP package.
|
||||
|
||||
$> pip install -U git+https://github.com/tensorflow/docs mediapipe
|
||||
$> python build_py_api_docs.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from tensorflow_docs.api_generator import generate_lib
|
||||
from tensorflow_docs.api_generator import public_api
|
||||
|
||||
try:
|
||||
# mediapipe has not been set up to work with bazel yet, so catch & report.
|
||||
import mediapipe # pytype: disable=import-error
|
||||
except ImportError as e:
|
||||
raise ImportError('Please `pip install mediapipe`.') from e
|
||||
|
||||
|
||||
PROJECT_SHORT_NAME = 'mp'
|
||||
PROJECT_FULL_NAME = 'MediaPipe'
|
||||
|
||||
_OUTPUT_DIR = flags.DEFINE_string(
|
||||
'output_dir',
|
||||
default='/tmp/generated_docs',
|
||||
help='Where to write the resulting docs.')
|
||||
|
||||
_URL_PREFIX = flags.DEFINE_string(
|
||||
'code_url_prefix',
|
||||
'https://github.com/google/mediapipe/tree/master/mediapipe',
|
||||
'The url prefix for links to code.')
|
||||
|
||||
_SEARCH_HINTS = flags.DEFINE_bool(
|
||||
'search_hints', True,
|
||||
'Include metadata search hints in the generated files')
|
||||
|
||||
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
|
||||
'Path prefix in the _toc.yaml')
|
||||
|
||||
|
||||
def gen_api_docs():
|
||||
"""Generates API docs for the mediapipe package."""
|
||||
|
||||
doc_generator = generate_lib.DocGenerator(
|
||||
root_title=PROJECT_FULL_NAME,
|
||||
py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
|
||||
base_dir=os.path.dirname(mediapipe.__file__),
|
||||
code_url_prefix=_URL_PREFIX.value,
|
||||
search_hints=_SEARCH_HINTS.value,
|
||||
site_path=_SITE_PATH.value,
|
||||
# This callback ensures that docs are only generated for objects that
|
||||
# are explicitly imported in your __init__.py files. There are other
|
||||
# options but this is a good starting point.
|
||||
callbacks=[public_api.explicit_package_contents_filter],
|
||||
)
|
||||
|
||||
doc_generator.build(_OUTPUT_DIR.value)
|
||||
|
||||
print('Docs output to:', _OUTPUT_DIR.value)
|
||||
|
||||
|
||||
def main(_):
|
||||
gen_api_docs()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
|
@ -289,8 +289,15 @@ class NodeBase {
|
|||
|
||||
template <typename T>
|
||||
T& GetOptions() {
|
||||
return GetOptions(T::ext);
|
||||
}
|
||||
|
||||
// Use this API when the proto extension does not follow the "ext" naming
|
||||
// convention.
|
||||
template <typename E>
|
||||
auto& GetOptions(const E& extension) {
|
||||
options_used_ = true;
|
||||
return *options_.MutableExtension(T::ext);
|
||||
return *options_.MutableExtension(extension);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -386,8 +393,15 @@ class PacketGenerator {
|
|||
|
||||
template <typename T>
|
||||
T& GetOptions() {
|
||||
return GetOptions(T::ext);
|
||||
}
|
||||
|
||||
// Use this API when the proto extension does not follow the "ext" naming
|
||||
// convention.
|
||||
template <typename E>
|
||||
auto& GetOptions(const E& extension) {
|
||||
options_used_ = true;
|
||||
return *options_.MutableExtension(T::ext);
|
||||
return *options_.MutableExtension(extension);
|
||||
}
|
||||
|
||||
template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
|
||||
|
|
|
@ -161,7 +161,7 @@ class Texture {
|
|||
|
||||
~Texture() {
|
||||
if (is_owned_) {
|
||||
glDeleteProgram(handle_);
|
||||
glDeleteTextures(1, &handle_);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ cc_library(
|
|||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
|
@ -129,7 +130,7 @@ absl::Status ConfigureImageToTensorCalculator(
|
|||
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
|
||||
std);
|
||||
}
|
||||
// TODO: need to.support different GPU origin on differnt
|
||||
// TODO: need to support different GPU origin on differnt
|
||||
// platforms or applications.
|
||||
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
|
||||
return absl::OkStatus();
|
||||
|
@ -137,7 +138,13 @@ absl::Status ConfigureImageToTensorCalculator(
|
|||
|
||||
} // namespace
|
||||
|
||||
bool DetermineImagePreprocessingGpuBackend(
|
||||
const core::proto::Acceleration& acceleration) {
|
||||
return acceleration.has_gpu();
|
||||
}
|
||||
|
||||
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
|
||||
bool use_gpu,
|
||||
ImagePreprocessingOptions* options) {
|
||||
ASSIGN_OR_RETURN(auto image_tensor_specs,
|
||||
BuildImageTensorSpecs(model_resources));
|
||||
|
@ -145,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
|
|||
image_tensor_specs, options->mutable_image_to_tensor_options()));
|
||||
// The GPU backend isn't able to process int data. If the input tensor is
|
||||
// quantized, forces the image preprocessing graph to use CPU backend.
|
||||
if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) {
|
||||
if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
|
||||
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
|
||||
} else {
|
||||
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -19,20 +19,26 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
|
||||
// Configures an ImagePreprocessing subgraph using the provided model resources.
|
||||
// Configures an ImagePreprocessing subgraph using the provided model resources
|
||||
// When use_gpu is true, use GPU as backend to convert image to tensor.
|
||||
// - Accepts CPU input images and outputs CPU tensors.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// auto& preprocessing =
|
||||
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
// core::proto::Acceleration acceleration;
|
||||
// acceleration.mutable_xnnpack();
|
||||
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
|
||||
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
// model_resources,
|
||||
// use_gpu,
|
||||
// &preprocessing.GetOptions<ImagePreprocessingOptions>()));
|
||||
//
|
||||
// The resulting ImagePreprocessing subgraph has the following I/O:
|
||||
|
@ -56,9 +62,14 @@ namespace components {
|
|||
// The image that has the pixel data stored on the target storage (CPU vs
|
||||
// GPU).
|
||||
absl::Status ConfigureImagePreprocessing(
|
||||
const core::ModelResources& model_resources,
|
||||
const core::ModelResources& model_resources, bool use_gpu,
|
||||
ImagePreprocessingOptions* options);
|
||||
|
||||
// Determine if the image preprocessing subgraph should use GPU as the backend
|
||||
// according to the given acceleration setting.
|
||||
bool DetermineImagePreprocessingGpuBackend(
|
||||
const core::proto::Acceleration& acceleration);
|
||||
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -93,3 +93,46 @@ cc_test(
|
|||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "combined_prediction_calculator_proto",
|
||||
srcs = ["combined_prediction_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "combined_prediction_calculator",
|
||||
srcs = ["combined_prediction_calculator.cc"],
|
||||
deps = [
|
||||
":combined_prediction_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:collection",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"@com_google_absl//absl/container:btree",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "combined_prediction_calculator_test",
|
||||
srcs = ["combined_prediction_calculator_test.cc"],
|
||||
deps = [
|
||||
":combined_prediction_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/btree_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
constexpr char kPredictionTag[] = "PREDICTION";
|
||||
|
||||
Classification GetMaxScoringClassification(
|
||||
const ClassificationList& classifications) {
|
||||
Classification max_classification;
|
||||
max_classification.set_score(0);
|
||||
for (const auto& input : classifications.classification()) {
|
||||
if (max_classification.score() < input.score()) {
|
||||
max_classification = input;
|
||||
}
|
||||
}
|
||||
return max_classification;
|
||||
}
|
||||
|
||||
float GetScoreThreshold(
|
||||
const std::string& input_label,
|
||||
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||
const std::string& background_label, const float default_threshold) {
|
||||
float threshold = default_threshold;
|
||||
auto it = classwise_thresholds.find(input_label);
|
||||
if (it != classwise_thresholds.end()) {
|
||||
threshold = it->second;
|
||||
}
|
||||
return threshold;
|
||||
}
|
||||
|
||||
std::unique_ptr<ClassificationList> GetWinningPrediction(
|
||||
const ClassificationList& classification_list,
|
||||
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||
const std::string& background_label, const float default_threshold) {
|
||||
auto prediction_list = std::make_unique<ClassificationList>();
|
||||
if (classification_list.classification().empty()) {
|
||||
return prediction_list;
|
||||
}
|
||||
Classification& prediction = *prediction_list->add_classification();
|
||||
auto argmax_prediction = GetMaxScoringClassification(classification_list);
|
||||
float argmax_prediction_thresh =
|
||||
GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
|
||||
background_label, default_threshold);
|
||||
if (argmax_prediction.score() >= argmax_prediction_thresh) {
|
||||
prediction.set_label(argmax_prediction.label());
|
||||
prediction.set_score(argmax_prediction.score());
|
||||
} else {
|
||||
for (const auto& input : classification_list.classification()) {
|
||||
if (input.label() == background_label) {
|
||||
prediction.set_label(input.label());
|
||||
prediction.set_score(input.score());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return prediction_list;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// This calculator accepts multiple ClassificationList input streams. Each
|
||||
// ClassificationList should contain classifications with labels and
|
||||
// corresponding softmax scores. The calculator computes the best prediction for
|
||||
// each ClassificationList input stream via argmax and thresholding. Thresholds
|
||||
// for all classes can be specified in the
|
||||
// `CombinedPredictionCalculatorOptions`, along with a default global
|
||||
// threshold.
|
||||
// Please note that for this calculator to work as designed, the class names
|
||||
// other than the background class in the ClassificationList objects must be
|
||||
// different, but the background class name has to be the same. This background
|
||||
// label name can be set via `background_label` in
|
||||
// `CombinedPredictionCalculatorOptions`.
|
||||
// The ClassificationList in the PREDICTION output stream contains the label of
|
||||
// the winning class and corresponding softmax score. If none of the
|
||||
// ClassificationList objects has a non-background winning class, the output
|
||||
// contains the background class and score of the background class in the first
|
||||
// ClassificationList. If multiple ClassificationList objects have a
|
||||
// non-background winning class, the output contains the winning prediction from
|
||||
// the ClassificationList with the highest priority. Priority is in decreasing
|
||||
// order of input streams to the graph node using this calculator.
|
||||
// Input:
|
||||
// At least one stream with ClassificationList.
|
||||
// Output:
|
||||
// PREDICTION - A ClassificationList with the winning label as the only item.
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
// calculator: "CombinedPredictionCalculator"
|
||||
// input_stream: "classification_list_0"
|
||||
// input_stream: "classification_list_1"
|
||||
// output_stream: "PREDICTION:prediction"
|
||||
// options {
|
||||
// [mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||
// class {
|
||||
// label: "A"
|
||||
// score_threshold: 0.7
|
||||
// }
|
||||
// default_global_threshold: 0.1
|
||||
// background_label: "B"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
class CombinedPredictionCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
|
||||
""};
|
||||
static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
options_ = cc->Options<CombinedPredictionCalculatorOptions>();
|
||||
for (const auto& input : options_.class_()) {
|
||||
classwise_thresholds_[input.label()] = input.score_threshold();
|
||||
}
|
||||
classwise_thresholds_[options_.background_label()] = 0;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
// After loop, if have winning prediction return. Otherwise empty packet.
|
||||
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
|
||||
auto collection = kClassificationListIn(cc);
|
||||
for (int idx = 0; idx < collection.Count(); ++idx) {
|
||||
const auto& packet = collection[idx];
|
||||
if (packet.IsEmpty()) {
|
||||
continue;
|
||||
}
|
||||
auto prediction = GetWinningPrediction(
|
||||
packet.Get(), classwise_thresholds_, options_.background_label(),
|
||||
options_.default_global_threshold());
|
||||
if (prediction->classification(0).label() !=
|
||||
options_.background_label()) {
|
||||
kPredictionOut(cc).Send(std::move(prediction));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
if (first_winning_prediction == nullptr) {
|
||||
first_winning_prediction = std::move(prediction);
|
||||
}
|
||||
}
|
||||
if (first_winning_prediction != nullptr) {
|
||||
kPredictionOut(cc).Send(std::move(first_winning_prediction));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
CombinedPredictionCalculatorOptions options_;
|
||||
absl::btree_map<std::string, float> classwise_thresholds_;
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,41 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message CombinedPredictionCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional CombinedPredictionCalculatorOptions ext = 483738635;
|
||||
}
|
||||
|
||||
message Class {
|
||||
optional string label = 1;
|
||||
optional float score_threshold = 2;
|
||||
}
|
||||
|
||||
// List of classes with score thresholds.
|
||||
repeated Class class = 1;
|
||||
|
||||
// Default score threshold applied to a label.
|
||||
optional float default_global_threshold = 2 [default = 0];
|
||||
|
||||
// Name of the background class whose input scores will be ignored while
|
||||
// thresholding.
|
||||
optional string background_label = 3;
|
||||
}
|
|
@ -0,0 +1,315 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kPredictionTag[] = "PREDICTION";
|
||||
|
||||
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
|
||||
float drama_thresh, float llama_thresh, float bazinga_thresh,
|
||||
float joy_thresh, float peace_thresh) {
|
||||
constexpr absl::string_view kCalculatorProto = R"pb(
|
||||
calculator: "CombinedPredictionCalculator"
|
||||
input_stream: "custom_softmax_scores"
|
||||
input_stream: "canned_softmax_scores"
|
||||
output_stream: "PREDICTION:prediction"
|
||||
options {
|
||||
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||
class { label: "CustomDrama" score_threshold: $0 }
|
||||
class { label: "CustomLlama" score_threshold: $1 }
|
||||
class { label: "CannedBazinga" score_threshold: $2 }
|
||||
class { label: "CannedJoy" score_threshold: $3 }
|
||||
class { label: "CannedPeace" score_threshold: $4 }
|
||||
background_label: "Negative"
|
||||
}
|
||||
}
|
||||
)pb";
|
||||
auto runner = std::make_unique<CalculatorRunner>(
|
||||
absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
|
||||
bazinga_thresh, joy_thresh, peace_thresh));
|
||||
return runner;
|
||||
}
|
||||
|
||||
std::unique_ptr<ClassificationList> BuildCustomScoreInput(
|
||||
const float negative_score, const float drama_score,
|
||||
const float llama_score) {
|
||||
auto custom_scores = std::make_unique<ClassificationList>();
|
||||
auto custom_negative = custom_scores->add_classification();
|
||||
custom_negative->set_label("Negative");
|
||||
custom_negative->set_score(negative_score);
|
||||
auto drama = custom_scores->add_classification();
|
||||
drama->set_label("CustomDrama");
|
||||
drama->set_score(drama_score);
|
||||
auto llama = custom_scores->add_classification();
|
||||
llama->set_label("CustomLlama");
|
||||
llama->set_score(llama_score);
|
||||
return custom_scores;
|
||||
}
|
||||
|
||||
std::unique_ptr<ClassificationList> BuildCannedScoreInput(
|
||||
const float negative_score, const float bazinga_score,
|
||||
const float joy_score, const float peace_score) {
|
||||
auto canned_scores = std::make_unique<ClassificationList>();
|
||||
auto canned_negative = canned_scores->add_classification();
|
||||
canned_negative->set_label("Negative");
|
||||
canned_negative->set_score(negative_score);
|
||||
auto bazinga = canned_scores->add_classification();
|
||||
bazinga->set_label("CannedBazinga");
|
||||
bazinga->set_score(bazinga_score);
|
||||
auto joy = canned_scores->add_classification();
|
||||
joy->set_label("CannedJoy");
|
||||
joy->set_score(joy_score);
|
||||
auto peace = canned_scores->add_classification();
|
||||
peace->set_label("CannedPeace");
|
||||
peace->set_score(peace_score);
|
||||
return canned_scores;
|
||||
}
|
||||
|
||||
TEST(CombinedPredictionCalculatorPacketTest,
|
||||
CustomEmpty_CannedEmpty_ResultIsEmpty) {
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
|
||||
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
|
||||
}
|
||||
|
||||
TEST(CombinedPredictionCalculatorPacketTest,
|
||||
CustomEmpty_CannedNotEmpty_ResultIsCanned) {
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
|
||||
/*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
|
||||
auto canned_scores = BuildCannedScoreInput(
|
||||
/*negative_score=*/0.1,
|
||||
/*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
|
||||
runner->MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
|
||||
auto output_prediction_packets =
|
||||
runner->Outputs().Tag(kPredictionTag).packets;
|
||||
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||
Classification output_prediction =
|
||||
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||
|
||||
EXPECT_EQ(output_prediction.label(), "CannedJoy");
|
||||
EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
|
||||
}
|
||||
|
||||
TEST(CombinedPredictionCalculatorPacketTest,
|
||||
CustomNotEmpty_CannedEmpty_ResultIsCustom) {
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
/*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
|
||||
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||
auto custom_scores =
|
||||
BuildCustomScoreInput(/*negative_score=*/0.1,
|
||||
/*drama_score=*/0.2, /*llama_score=*/0.7);
|
||||
runner->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
|
||||
auto output_prediction_packets =
|
||||
runner->Outputs().Tag(kPredictionTag).packets;
|
||||
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||
Classification output_prediction =
|
||||
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||
|
||||
EXPECT_EQ(output_prediction.label(), "CustomLlama");
|
||||
EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
|
||||
}
|
||||
|
||||
struct CombinedPredictionCalculatorTestCase {
|
||||
std::string test_name;
|
||||
float custom_negative_score;
|
||||
float drama_score;
|
||||
float llama_score;
|
||||
float drama_thresh;
|
||||
float llama_thresh;
|
||||
float canned_negative_score;
|
||||
float bazinga_score;
|
||||
float joy_score;
|
||||
float peace_score;
|
||||
float bazinga_thresh;
|
||||
float joy_thresh;
|
||||
float peace_thresh;
|
||||
std::string max_scoring_label;
|
||||
float max_score;
|
||||
};
|
||||
|
||||
using CombinedPredictionCalculatorTest =
|
||||
testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
|
||||
|
||||
TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
|
||||
const CombinedPredictionCalculatorTestCase& test_case = GetParam();
|
||||
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
|
||||
test_case.joy_thresh, test_case.peace_thresh);
|
||||
|
||||
auto custom_scores =
|
||||
BuildCustomScoreInput(test_case.custom_negative_score,
|
||||
test_case.drama_score, test_case.llama_score);
|
||||
|
||||
runner->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||
|
||||
auto canned_scores = BuildCannedScoreInput(
|
||||
test_case.canned_negative_score, test_case.bazinga_score,
|
||||
test_case.joy_score, test_case.peace_score);
|
||||
runner->MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
|
||||
auto output_prediction_packets =
|
||||
runner->Outputs().Tag(kPredictionTag).packets;
|
||||
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||
Classification output_prediction =
|
||||
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||
|
||||
EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
|
||||
EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
|
||||
testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
|
||||
{
|
||||
.test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.5,
|
||||
.llama_score = 0.3,
|
||||
.drama_thresh = 0.25,
|
||||
.llama_thresh = 0.7,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "CustomDrama",
|
||||
.max_score = 0.5,
|
||||
},
|
||||
{
|
||||
.test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.3,
|
||||
.llama_score = 0.6,
|
||||
.drama_thresh = 0.4,
|
||||
.llama_thresh = 0.8,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.4,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.2,
|
||||
.bazinga_thresh = 0.0,
|
||||
.joy_thresh = 0.0,
|
||||
.peace_thresh = 0.0,
|
||||
.max_scoring_label = "CannedBazinga",
|
||||
.max_score = 0.4,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
|
||||
.custom_negative_score = 0.5,
|
||||
.drama_score = 0.1,
|
||||
.llama_score = 0.4,
|
||||
.drama_thresh = 0.1,
|
||||
.llama_thresh = 0.05,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.5,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
|
||||
.custom_negative_score = 0.8,
|
||||
.drama_score = 0.1,
|
||||
.llama_score = 0.1,
|
||||
.drama_thresh = 0.25,
|
||||
.llama_thresh = 0.7,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.8,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.2,
|
||||
.llama_score = 0.7,
|
||||
.drama_thresh = 1.1,
|
||||
.llama_thresh = 1.1,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.1,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.3,
|
||||
.llama_score = 0.6,
|
||||
.drama_thresh = 0.4,
|
||||
.llama_thresh = 0.8,
|
||||
.canned_negative_score = 0.3,
|
||||
.bazinga_score = 0.2,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.2,
|
||||
.bazinga_thresh = 0.5,
|
||||
.joy_thresh = 0.5,
|
||||
.peace_thresh = 0.5,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.1,
|
||||
},
|
||||
}),
|
||||
[](const testing::TestParamInfo<
|
||||
CombinedPredictionCalculatorTest::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mediapipe
|
|
@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
|||
image_to_tensor_options.set_keep_aspect_ratio(true);
|
||||
image_to_tensor_options.set_border_mode(
|
||||
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
|
||||
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||
subgraph_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
model_resources,
|
||||
model_resources, use_gpu,
|
||||
&preprocessing
|
||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||
image_in >> preprocessing.In("IMAGE");
|
||||
|
|
|
@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||
subgraph_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
model_resources,
|
||||
model_resources, use_gpu,
|
||||
&preprocessing
|
||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||
image_in >> preprocessing.In("IMAGE");
|
||||
|
|
|
@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
// stream.
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||
task_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
model_resources,
|
||||
model_resources, use_gpu,
|
||||
&preprocessing
|
||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||
image_in >> preprocessing.In(kImageTag);
|
||||
|
|
|
@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
|||
// stream.
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||
task_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
model_resources,
|
||||
model_resources, use_gpu,
|
||||
&preprocessing
|
||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||
image_in >> preprocessing.In(kImageTag);
|
||||
|
|
|
@ -243,8 +243,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
// stream.
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||
task_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
model_resources,
|
||||
model_resources, use_gpu,
|
||||
&preprocessing
|
||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||
image_in >> preprocessing.In(kImageTag);
|
||||
|
|
|
@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
|
|||
// stream.
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||
task_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
model_resources,
|
||||
model_resources, use_gpu,
|
||||
&preprocessing
|
||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||
image_in >> preprocessing.In(kImageTag);
|
||||
|
|
|
@ -40,6 +40,10 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||
]
|
||||
|
||||
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||
]
|
||||
|
||||
def mediapipe_tasks_core_aar(name, srcs, manifest):
|
||||
"""Builds medaipipe tasks core AAR.
|
||||
|
||||
|
@ -60,6 +64,11 @@ def mediapipe_tasks_core_aar(name, srcs, manifest):
|
|||
_mediapipe_tasks_java_proto_src_extractor(target = target),
|
||||
)
|
||||
|
||||
for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS:
|
||||
mediapipe_tasks_java_proto_srcs.append(
|
||||
_mediapipe_tasks_java_proto_src_extractor(target = target),
|
||||
)
|
||||
|
||||
mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java",
|
||||
|
@ -106,7 +115,10 @@ def mediapipe_tasks_core_aar(name, srcs, manifest):
|
|||
"@maven//:com_google_flogger_flogger",
|
||||
"@maven//:com_google_flogger_flogger_system_backend",
|
||||
"@maven//:com_google_code_findbugs_jsr305",
|
||||
] + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS,
|
||||
] +
|
||||
_CORE_TASKS_JAVA_PROTO_LITE_TARGETS +
|
||||
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS +
|
||||
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
|
||||
)
|
||||
|
||||
def mediapipe_tasks_vision_aar(name, srcs, native_library):
|
||||
|
@ -142,6 +154,39 @@ EOF
|
|||
native_library = native_library,
|
||||
)
|
||||
|
||||
def mediapipe_tasks_text_aar(name, srcs, native_library):
|
||||
"""Builds medaipipe tasks text AAR.
|
||||
|
||||
Args:
|
||||
name: The bazel target name.
|
||||
srcs: MediaPipe Text Tasks' source files.
|
||||
native_library: The native library that contains text tasks' graph and calculators.
|
||||
"""
|
||||
|
||||
native.genrule(
|
||||
name = name + "tasks_manifest_generator",
|
||||
outs = ["AndroidManifest.xml"],
|
||||
cmd = """
|
||||
cat > $(OUTS) <<EOF
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.text">
|
||||
<uses-sdk
|
||||
android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
</manifest>
|
||||
EOF
|
||||
""",
|
||||
)
|
||||
|
||||
_mediapipe_tasks_aar(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
manifest = "AndroidManifest.xml",
|
||||
java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
|
||||
native_library = native_library,
|
||||
)
|
||||
|
||||
def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library):
|
||||
"""Builds medaipipe tasks AAR."""
|
||||
|
||||
|
|
|
@ -61,3 +61,11 @@ android_library(
|
|||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
|
||||
|
||||
mediapipe_tasks_text_aar(
|
||||
name = "tasks_text",
|
||||
srcs = glob(["**/*.java"]),
|
||||
native_library = ":libmediapipe_tasks_text_jni_lib",
|
||||
)
|
||||
|
|
|
@ -97,20 +97,15 @@ cc_library(
|
|||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
|
||||
],
|
||||
}) +
|
||||
select({
|
||||
"//mediapipe:android": [
|
||||
"@com_google_absl//absl/strings",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:api",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/status",
|
||||
"//mediapipe/framework:port",
|
||||
"@org_tensorflow//tensorflow/lite/core/api",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -34,7 +34,7 @@
|
|||
|
||||
// This code should be enabled as soon as TensorFlow version, which mediapipe
|
||||
// uses, will include this module.
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
||||
#endif
|
||||
|
||||
|
@ -82,7 +82,7 @@ ObjectDef GetSSBOObjectDef(int channels) {
|
|||
return gpu_object_def;
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
|
||||
cl::InferenceOptions result{};
|
||||
|
@ -106,7 +106,7 @@ absl::Status VerifyShapes(const std::vector<TensorObjectDef>& actual,
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
#endif // __ANDROID__
|
||||
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -225,7 +225,7 @@ absl::Status TFLiteGPURunner::InitializeOpenGL(
|
|||
|
||||
absl::Status TFLiteGPURunner::InitializeOpenCL(
|
||||
std::unique_ptr<InferenceBuilder>* builder) {
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
cl::InferenceEnvironmentOptions env_options;
|
||||
if (!serialized_binary_cache_.empty()) {
|
||||
env_options.serialized_binary_cache = serialized_binary_cache_;
|
||||
|
@ -254,11 +254,12 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
|
|||
|
||||
return absl::OkStatus();
|
||||
#else
|
||||
return mediapipe::UnimplementedError("Currently only Android is supported");
|
||||
#endif // __ANDROID__
|
||||
return mediapipe::UnimplementedError(
|
||||
"Currently only Android & ChromeOS are supported");
|
||||
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
|
||||
std::unique_ptr<InferenceBuilder>* builder) {
|
||||
|
@ -283,7 +284,7 @@ absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
|
|||
return serialized_model;
|
||||
}
|
||||
|
||||
#endif // __ANDROID__
|
||||
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -28,9 +29,9 @@
|
|||
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
||||
#endif // __ANDROID__
|
||||
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
@ -83,7 +84,7 @@ class TFLiteGPURunner {
|
|||
return output_shape_from_model_;
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
void SetSerializedBinaryCache(std::vector<uint8_t>&& cache) {
|
||||
serialized_binary_cache_ = std::move(cache);
|
||||
}
|
||||
|
@ -98,26 +99,26 @@ class TFLiteGPURunner {
|
|||
}
|
||||
|
||||
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
|
||||
#endif // __ANDROID__
|
||||
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
private:
|
||||
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
|
||||
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
absl::Status InitializeOpenCLFromSerializedModel(
|
||||
std::unique_ptr<InferenceBuilder>* builder);
|
||||
#endif // __ANDROID__
|
||||
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
InferenceOptions options_;
|
||||
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
|
||||
|
||||
std::vector<uint8_t> serialized_binary_cache_;
|
||||
std::vector<uint8_t> serialized_model_;
|
||||
bool serialized_model_used_ = false;
|
||||
#endif // __ANDROID__
|
||||
#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
||||
// graph_gl_ is maintained temporarily and becomes invalid after runner_ is
|
||||
// ready
|
||||
|
|
34
third_party/com_google_sentencepiece_no_gflag_no_gtest.diff
vendored
Normal file
34
third_party/com_google_sentencepiece_no_gflag_no_gtest.diff
vendored
Normal file
|
@ -0,0 +1,34 @@
|
|||
diff --git a/src/BUILD b/src/BUILD
|
||||
index b4298d2..f3877a3 100644
|
||||
--- a/src/BUILD
|
||||
+++ b/src/BUILD
|
||||
@@ -71,9 +71,7 @@ cc_library(
|
||||
":common",
|
||||
":sentencepiece_cc_proto",
|
||||
":sentencepiece_model_cc_proto",
|
||||
- "@com_github_gflags_gflags//:gflags",
|
||||
"@com_google_glog//:glog",
|
||||
- "@com_google_googletest//:gtest",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
diff --git a/src/normalizer.h b/src/normalizer.h
|
||||
index c16ac16..2af58be 100644
|
||||
--- a/src/normalizer.h
|
||||
+++ b/src/normalizer.h
|
||||
@@ -21,7 +21,6 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
-#include "gtest/gtest_prod.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "third_party/darts_clone/include/darts.h"
|
||||
#include "src/common.h"
|
||||
@@ -97,7 +96,6 @@ class Normalizer {
|
||||
friend class Builder;
|
||||
|
||||
private:
|
||||
- FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest);
|
||||
|
||||
void Init();
|
||||
|
Loading…
Reference in New Issue
Block a user