Merge branch 'master' into image-embedder-python

This commit is contained in:
Kinar R 2022-11-03 23:24:31 +05:30 committed by GitHub
commit 5a68ba84b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
244 changed files with 12284 additions and 2008 deletions

View File

@ -30,6 +30,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
git \ git \
wget \ wget \
unzip \ unzip \
nodejs \
npm \
python3-dev \ python3-dev \
python3-opencv \ python3-opencv \
python3-pip \ python3-pip \

View File

@ -172,6 +172,10 @@ http_archive(
urls = [ urls = [
"https://github.com/google/sentencepiece/archive/1.0.0.zip", "https://github.com/google/sentencepiece/archive/1.0.0.zip",
], ],
patches = [
"//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
],
patch_args = ["-p1"],
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"}, repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
) )

14
docs/BUILD Normal file
View File

@ -0,0 +1,14 @@
# Placeholder for internal Python strict binary compatibility macro.
py_binary(
name = "build_py_api_docs",
srcs = ["build_py_api_docs.py"],
deps = [
"//mediapipe",
"//third_party/py/absl:app",
"//third_party/py/absl/flags",
"//third_party/py/tensorflow_docs",
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
"//third_party/py/tensorflow_docs/api_generator:public_api",
],
)

85
docs/build_py_api_docs.py Normal file
View File

@ -0,0 +1,85 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""MediaPipe reference docs generation script.
This script generates API reference docs for the `mediapipe` PIP package.
$> pip install -U git+https://github.com/tensorflow/docs mediapipe
$> python build_py_api_docs.py
"""
import os
from absl import app
from absl import flags
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
try:
# mediapipe has not been set up to work with bazel yet, so catch & report.
import mediapipe # pytype: disable=import-error
except ImportError as e:
raise ImportError('Please `pip install mediapipe`.') from e
PROJECT_SHORT_NAME = 'mp'
PROJECT_FULL_NAME = 'MediaPipe'
_OUTPUT_DIR = flags.DEFINE_string(
'output_dir',
default='/tmp/generated_docs',
help='Where to write the resulting docs.')
_URL_PREFIX = flags.DEFINE_string(
'code_url_prefix',
'https://github.com/google/mediapipe/tree/master/mediapipe',
'The url prefix for links to code.')
_SEARCH_HINTS = flags.DEFINE_bool(
'search_hints', True,
'Include metadata search hints in the generated files')
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
'Path prefix in the _toc.yaml')
def gen_api_docs():
"""Generates API docs for the mediapipe package."""
doc_generator = generate_lib.DocGenerator(
root_title=PROJECT_FULL_NAME,
py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
base_dir=os.path.dirname(mediapipe.__file__),
code_url_prefix=_URL_PREFIX.value,
search_hints=_SEARCH_HINTS.value,
site_path=_SITE_PATH.value,
# This callback ensures that docs are only generated for objects that
# are explicitly imported in your __init__.py files. There are other
# options but this is a good starting point.
callbacks=[public_api.explicit_package_contents_filter],
)
doc_generator.build(_OUTPUT_DIR.value)
print('Docs output to:', _OUTPUT_DIR.value)
def main(_):
gen_api_docs()
if __name__ == '__main__':
app.run(main)

View File

@ -222,10 +222,10 @@ cc_library(
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
@ -328,6 +328,7 @@ cc_library(
":concatenate_vector_calculator_cc_proto", ":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -344,6 +345,7 @@ cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",

View File

@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options { options {
[mediapipe.SwitchContainerOptions.ext] { [mediapipe.SwitchContainerOptions.ext] {
async_selection: true
contained_node: { calculator: "AppearancesPassThroughSubgraph" } contained_node: { calculator: "AppearancesPassThroughSubgraph" }
} }
} }
@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options { options {
[mediapipe.SwitchContainerOptions.ext] { [mediapipe.SwitchContainerOptions.ext] {
async_selection: true
contained_node: { contained_node: {
calculator: "BypassCalculator" calculator: "BypassCalculator"
node_options: { node_options: {

View File

@ -18,6 +18,7 @@
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
}; };
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator); MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
class ConcatenateClassificationListCalculator
: public ConcatenateListsCalculator<Classification, ClassificationList> {
protected:
int ListSize(const ClassificationList& list) const override {
return list.classification_size();
}
const Classification GetItem(const ClassificationList& list,
int idx) const override {
return list.classification(idx);
}
Classification* AddItem(ClassificationList& list) const override {
return list.add_classification();
}
};
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -18,6 +18,7 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -70,6 +71,16 @@ void AddInputLandmarkLists(
} }
} }
void AddInputClassificationLists(
const std::vector<ClassificationList>& input_classifications_vec,
int64 timestamp, CalculatorRunner* runner) {
for (int i = 0; i < input_classifications_vec.size(); ++i) {
runner->MutableInputs()->Index(i).packets.push_back(
MakePacket<ClassificationList>(input_classifications_vec[i])
.At(Timestamp(timestamp)));
}
}
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) { TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator", CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
/*options_string=*/"", /*num_inputs=*/3, /*options_string=*/"", /*num_inputs=*/3,
@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size()); EXPECT_EQ(0, outputs.size());
} }
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
CalculatorRunner runner("ConcatenateClassificationListCalculator",
/*options_string=*/
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
"{only_emit_if_all_present: true}",
/*num_inputs=*/2,
/*num_outputs=*/1, /*num_side_packets=*/0);
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
)pb");
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb");
std::vector<ClassificationList> inputs = {input_0, input_1};
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
auto result = outputs[0].Get<ClassificationList>();
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb"),
EqualsProto(result));
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h" #include "mediapipe/util/render_data.pb.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
@ -58,4 +59,7 @@ typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
EndLoopDetectionCalculator; EndLoopDetectionCalculator;
REGISTER_CALCULATOR(EndLoopDetectionCalculator); REGISTER_CALCULATOR(EndLoopDetectionCalculator);
typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
REGISTER_CALCULATOR(EndLoopMatrixCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -50,7 +50,7 @@ namespace mediapipe {
// calculator: "EndLoopWithOutputCalculator" // calculator: "EndLoopWithOutputCalculator"
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts // input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts // input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts // output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts
// } // }
template <typename IterableT> template <typename IterableT>
class EndLoopCalculator : public CalculatorBase { class EndLoopCalculator : public CalculatorBase {

View File

@ -109,6 +109,56 @@ cc_test(
], ],
) )
mediapipe_proto_library(
name = "tensors_to_audio_calculator_proto",
srcs = ["tensors_to_audio_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "tensors_to_audio_calculator",
srcs = ["tensors_to_audio_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":tensors_to_audio_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_audio_tools//audio/dsp:window_functions",
"@pffft",
],
alwayslink = 1,
)
cc_test(
name = "tensors_to_audio_calculator_test",
srcs = ["tensors_to_audio_calculator_test.cc"],
deps = [
":audio_to_tensor_calculator",
":audio_to_tensor_calculator_cc_proto",
":tensors_to_audio_calculator",
":tensors_to_audio_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "feedback_tensors_calculator_proto", name = "feedback_tensors_calculator_proto",
srcs = ["feedback_tensors_calculator.proto"], srcs = ["feedback_tensors_calculator.proto"],
@ -253,6 +303,26 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "regex_preprocessor_calculator_test",
srcs = ["regex_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":regex_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:sink",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "text_to_tensor_calculator", name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"], srcs = ["text_to_tensor_calculator.cc"],
@ -304,6 +374,28 @@ cc_library(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1,
)
cc_test(
name = "universal_sentence_encoder_preprocessor_calculator_test",
srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
deps = [
":universal_sentence_encoder_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
) )
mediapipe_proto_library( mediapipe_proto_library(
@ -438,6 +530,7 @@ cc_library(
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
@ -458,6 +551,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":inference_runner", ":inference_runner",
"//mediapipe/framework:mediapipe_profiling",
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
@ -1200,13 +1294,30 @@ cc_library(
name = "image_to_tensor_utils", name = "image_to_tensor_utils",
srcs = ["image_to_tensor_utils.cc"], srcs = ["image_to_tensor_utils.cc"],
hdrs = ["image_to_tensor_utils.h"], hdrs = ["image_to_tensor_utils.h"],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":image_to_tensor_calculator_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"@com_google_absl//absl/types:optional", "//mediapipe/gpu:gpu_origin_cc_proto",
], ] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": ["//mediapipe/gpu:gpu_buffer"],
}),
) )
cc_test( cc_test(
@ -1216,6 +1327,8 @@ cc_test(
":image_to_tensor_utils", ":image_to_tensor_utils",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
], ],
) )

View File

@ -133,7 +133,7 @@ bool IsValidFftSize(int size) {
// invocation. In the non-streaming mode, the vector contains all of the // invocation. In the non-streaming mode, the vector contains all of the
// output timestamps for an input audio buffer. // output timestamps for an input audio buffer.
// DC_AND_NYQUIST - std::pair<float, float> @Optional. // DC_AND_NYQUIST - std::pair<float, float> @Optional.
// A pair of dc component and nyquest component. Only can be connected when // A pair of dc component and nyquist component. Only can be connected when
// the calculator performs fft (the fft_size is set in the calculator // the calculator performs fft (the fft_size is set in the calculator
// options). // options).
// //

View File

@ -54,13 +54,6 @@
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
#if MEDIAPIPE_DISABLE_GPU
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
using GpuBuffer = AnyType;
#else
using GpuBuffer = mediapipe::GpuBuffer;
#endif // MEDIAPIPE_DISABLE_GPU
// Converts image into Tensor, possibly with cropping, resizing and // Converts image into Tensor, possibly with cropping, resizing and
// normalization, according to specified inputs and options. // normalization, according to specified inputs and options.
// //
@ -141,42 +134,7 @@ class ImageToTensorCalculator : public Node {
const auto& options = const auto& options =
cc->Options<mediapipe::ImageToTensorCalculatorOptions>(); cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
RET_CHECK(options.has_output_tensor_float_range() || RET_CHECK_OK(ValidateOptionOutputDims(options));
options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required.";
if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max())
<< "Valid output float tensor range is required.";
}
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max())
<< "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be greater than "
"or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or "
"equal to 127.";
}
RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required.";
RET_CHECK_GT(options.output_tensor_height(), 0)
<< "Valid output tensor height is required.";
RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected()) RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected())
<< "One and only one of IMAGE and IMAGE_GPU input is expected."; << "One and only one of IMAGE and IMAGE_GPU input is expected.";
@ -198,21 +156,7 @@ class ImageToTensorCalculator : public Node {
absl::Status Open(CalculatorContext* cc) { absl::Status Open(CalculatorContext* cc) {
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>(); options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
output_width_ = options_.output_tensor_width(); params_ = GetOutputTensorParams(options_);
output_height_ = options_.output_tensor_height();
is_float_output_ = options_.has_output_tensor_float_range();
if (options_.has_output_tensor_uint_range()) {
range_min_ =
static_cast<float>(options_.output_tensor_uint_range().min());
range_max_ =
static_cast<float>(options_.output_tensor_uint_range().max());
} else if (options_.has_output_tensor_int_range()) {
range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
} else {
range_min_ = options_.output_tensor_float_range().min();
range_max_ = options_.output_tensor_float_range().max();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -242,7 +186,13 @@ class ImageToTensorCalculator : public Node {
} }
} }
ASSIGN_OR_RETURN(auto image, GetInputImage(cc)); #if MEDIAPIPE_DISABLE_GPU
ASSIGN_OR_RETURN(auto image, GetInputImage(kIn(cc)));
#else
const bool is_input_gpu = kInGpu(cc).IsConnected();
ASSIGN_OR_RETURN(auto image, is_input_gpu ? GetInputImage(kInGpu(cc))
: GetInputImage(kIn(cc)));
#endif // MEDIAPIPE_DISABLE_GPU
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect); RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(), ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
@ -263,11 +213,13 @@ class ImageToTensorCalculator : public Node {
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get())); MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
Tensor::ElementType output_tensor_type = Tensor::ElementType output_tensor_type =
GetOutputTensorType(image->UsesGpu()); GetOutputTensorType(image->UsesGpu(), params_);
Tensor tensor(output_tensor_type, {1, output_height_, output_width_, Tensor tensor(output_tensor_type,
GetNumOutputChannels(*image)}); {1, params_.output_height, params_.output_width,
GetNumOutputChannels(*image)});
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_) MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
->Convert(*image, roi, range_min_, range_max_, ->Convert(*image, roi, params_.range_min,
params_.range_max,
/*tensor_buffer_offset=*/0, tensor)); /*tensor_buffer_offset=*/0, tensor));
auto result = std::make_unique<std::vector<Tensor>>(); auto result = std::make_unique<std::vector<Tensor>>();
@ -278,81 +230,11 @@ class ImageToTensorCalculator : public Node {
} }
private: private:
bool DoesGpuInputStartAtBottom() {
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}
BorderMode GetBorderMode() {
switch (options_.border_mode()) {
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
return BorderMode::kReplicate;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
return BorderMode::kZero;
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
return BorderMode::kReplicate;
}
}
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
if (!uses_gpu) {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
int GetNumOutputChannels(const Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
if (image.UsesGpu()) {
return 4;
}
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
// All of the processors except for Metal expect 3 channels.
return 3;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
CalculatorContext* cc) {
if (kIn(cc).IsConnected()) {
const auto& packet = kIn(cc).packet();
return kIn(cc).Visit(
[&packet](const mediapipe::Image&) {
return SharedPtrWithPacket<mediapipe::Image>(packet);
},
[&packet](const mediapipe::ImageFrame&) {
return std::make_shared<const mediapipe::Image>(
std::const_pointer_cast<mediapipe::ImageFrame>(
SharedPtrWithPacket<mediapipe::ImageFrame>(packet)));
});
} else { // if (kInGpu(cc).IsConnected())
#if !MEDIAPIPE_DISABLE_GPU
const GpuBuffer& input = *kInGpu(cc);
// A shallow copy is okay since the resulting 'image' object is local in
// Process(), and thus never outlives 'input'.
return std::make_shared<const mediapipe::Image>(input);
#else
return absl::UnimplementedError(
"GPU processing is disabled in build flags");
#endif // !MEDIAPIPE_DISABLE_GPU
}
}
absl::Status InitConverterIfNecessary(CalculatorContext* cc, absl::Status InitConverterIfNecessary(CalculatorContext* cc,
const Image& image) { const Image& image) {
// Lazy initialization of the GPU or CPU converter. // Lazy initialization of the GPU or CPU converter.
if (image.UsesGpu()) { if (image.UsesGpu()) {
if (!is_float_output_) { if (!params_.is_float_output) {
return absl::UnimplementedError( return absl::UnimplementedError(
"ImageToTensorConverter for the input GPU image currently doesn't " "ImageToTensorConverter for the input GPU image currently doesn't "
"support quantization."); "support quantization.");
@ -360,18 +242,20 @@ class ImageToTensorCalculator : public Node {
if (!gpu_converter_) { if (!gpu_converter_) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
ASSIGN_OR_RETURN(gpu_converter_, ASSIGN_OR_RETURN(
CreateMetalConverter(cc, GetBorderMode())); gpu_converter_,
CreateMetalConverter(cc, GetBorderMode(options_.border_mode())));
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
ASSIGN_OR_RETURN(gpu_converter_, ASSIGN_OR_RETURN(gpu_converter_,
CreateImageToGlBufferTensorConverter( CreateImageToGlBufferTensorConverter(
cc, DoesGpuInputStartAtBottom(), GetBorderMode())); cc, DoesGpuInputStartAtBottom(options_),
GetBorderMode(options_.border_mode())));
#else #else
if (!gpu_converter_) { if (!gpu_converter_) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(gpu_converter_,
gpu_converter_, CreateImageToGlTextureTensorConverter(
CreateImageToGlTextureTensorConverter( cc, DoesGpuInputStartAtBottom(options_),
cc, DoesGpuInputStartAtBottom(), GetBorderMode())); GetBorderMode(options_.border_mode())));
} }
if (!gpu_converter_) { if (!gpu_converter_) {
return absl::UnimplementedError( return absl::UnimplementedError(
@ -383,10 +267,10 @@ class ImageToTensorCalculator : public Node {
} else { } else {
if (!cpu_converter_) { if (!cpu_converter_) {
#if !MEDIAPIPE_DISABLE_OPENCV #if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(cpu_converter_,
cpu_converter_, CreateOpenCvConverter(
CreateOpenCvConverter(cc, GetBorderMode(), cc, GetBorderMode(options_.border_mode()),
GetOutputTensorType(/*uses_gpu=*/false))); GetOutputTensorType(/*uses_gpu=*/false, params_)));
#else #else
LOG(FATAL) << "Cannot create image to tensor opencv converter since " LOG(FATAL) << "Cannot create image to tensor opencv converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined."; "MEDIAPIPE_DISABLE_OPENCV is defined.";
@ -399,11 +283,7 @@ class ImageToTensorCalculator : public Node {
std::unique_ptr<ImageToTensorConverter> gpu_converter_; std::unique_ptr<ImageToTensorConverter> gpu_converter_;
std::unique_ptr<ImageToTensorConverter> cpu_converter_; std::unique_ptr<ImageToTensorConverter> cpu_converter_;
mediapipe::ImageToTensorCalculatorOptions options_; mediapipe::ImageToTensorCalculatorOptions options_;
int output_width_ = 0; OutputTensorParams params_;
int output_height_ = 0;
bool is_float_output_ = false;
float range_min_ = 0.0f;
float range_max_ = 1.0f;
}; };
MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator); MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);

View File

@ -27,12 +27,6 @@ struct Size {
int height; int height;
}; };
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
enum class BorderMode { kZero, kReplicate };
// Converts image to tensor. // Converts image to tensor.
class ImageToTensorConverter { class ImageToTensorConverter {
public: public:

View File

@ -270,10 +270,10 @@ class GlProcessor : public ImageToTensorConverter {
Tensor& output_tensor) override { Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
return InvalidArgumentError(absl::StrCat( return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ", "Unsupported format: ", static_cast<uint32_t>(input.format())));
static_cast<uint32_t>(input.format())));
} }
const auto& output_shape = output_tensor.shape(); const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@ -281,12 +281,13 @@ class GlProcessor : public ImageToTensorConverter {
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
[this, &output_tensor, &input, &roi, &output_shape, range_min, [this, &output_tensor, &input, &roi, &output_shape, range_min,
range_max, tensor_buffer_offset]() -> absl::Status { range_max, tensor_buffer_offset]() -> absl::Status {
constexpr int kRgbaNumChannels = 4; const int input_num_channels = input.channels();
auto source_texture = gl_helper_.CreateSourceTexture(input); auto source_texture = gl_helper_.CreateSourceTexture(input);
tflite::gpu::gl::GlTexture input_texture( tflite::gpu::gl::GlTexture input_texture(
GL_TEXTURE_2D, source_texture.name(), GL_RGBA, GL_TEXTURE_2D, source_texture.name(),
input_num_channels == 4 ? GL_RGB : GL_RGBA,
source_texture.width() * source_texture.height() * source_texture.width() * source_texture.height() *
kRgbaNumChannels * sizeof(uint8_t), input_num_channels * sizeof(uint8_t),
/*layer=*/0, /*layer=*/0,
/*owned=*/false); /*owned=*/false);

View File

@ -174,10 +174,10 @@ class GlProcessor : public ImageToTensorConverter {
Tensor& output_tensor) override { Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
return InvalidArgumentError(absl::StrCat( return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ", "Unsupported format: ", static_cast<uint32_t>(input.format())));
static_cast<uint32_t>(input.format())));
} }
// TODO: support tensor_buffer_offset > 0 scenario. // TODO: support tensor_buffer_offset > 0 scenario.
RET_CHECK_EQ(tensor_buffer_offset, 0) RET_CHECK_EQ(tensor_buffer_offset, 0)

View File

@ -16,7 +16,9 @@
#include <array> #include <array>
#include "absl/status/status.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
@ -214,4 +216,68 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
matrix[15] = 1.0f; matrix[15] = 1.0f;
} }
BorderMode GetBorderMode(
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode) {
switch (mode) {
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
return BorderMode::kReplicate;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
return BorderMode::kZero;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
return BorderMode::kReplicate;
}
}
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
const OutputTensorParams& params) {
if (!uses_gpu) {
if (params.is_float_output) {
return Tensor::ElementType::kFloat32;
}
if (params.range_min < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
int GetNumOutputChannels(const mediapipe::Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
if (image.UsesGpu()) {
return 4;
}
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
// All of the processors except for Metal expect 3 channels.
return 3;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
image_packet) {
return image_packet.Visit(
[&image_packet](const mediapipe::Image&) {
return SharedPtrWithPacket<mediapipe::Image>(image_packet);
},
[&image_packet](const mediapipe::ImageFrame&) {
return std::make_shared<const mediapipe::Image>(
std::const_pointer_cast<mediapipe::ImageFrame>(
SharedPtrWithPacket<mediapipe::ImageFrame>(image_packet)));
});
}
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet) {
// A shallow copy is okay since the resulting 'image' object is local in
// Process(), and thus never outlives 'input'.
return std::make_shared<const mediapipe::Image>(image_gpu_packet.Get());
}
#endif // !MEDIAPIPE_DISABLE_GPU
} // namespace mediapipe } // namespace mediapipe

View File

@ -18,8 +18,18 @@
#include <array> #include <array>
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_buffer.h"
#endif // !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_origin.pb.h"
namespace mediapipe { namespace mediapipe {
@ -31,6 +41,24 @@ struct RotatedRect {
float rotation; float rotation;
}; };
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
// TODO: Consider moving this to a separate border_mode.h file.
enum class BorderMode { kZero, kReplicate };
// Struct that host commonly accessed parameters used in the
// ImageTo[Batch]TensorCalculator.
struct OutputTensorParams {
int output_height;
int output_width;
int output_batch;
bool is_float_output;
float range_min;
float range_max;
};
// Generates a new ROI or converts it from normalized rect. // Generates a new ROI or converts it from normalized rect.
RotatedRect GetRoi(int input_width, int input_height, RotatedRect GetRoi(int input_width, int input_height,
absl::optional<mediapipe::NormalizedRect> norm_rect); absl::optional<mediapipe::NormalizedRect> norm_rect);
@ -95,6 +123,103 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
const RotatedRect& sub_rect, int rect_width, int rect_height, const RotatedRect& sub_rect, int rect_width, int rect_height,
bool flip_horizontaly, std::array<float, 16>* matrix); bool flip_horizontaly, std::array<float, 16>* matrix);
// Validates the output dimensions set in the option proto. The input option
// proto is expected to have to following fields:
// output_tensor_float_range, output_tensor_int_range, output_tensor_uint_range
// output_tensor_width, output_tensor_height.
// See ImageToTensorCalculatorOptions for the description of each field.
template <typename T>
absl::Status ValidateOptionOutputDims(const T& options) {
RET_CHECK(options.has_output_tensor_float_range() ||
options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required.";
if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max())
<< "Valid output float tensor range is required.";
}
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max())
<< "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be greater than "
"or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or "
"equal to 127.";
}
RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required.";
RET_CHECK_GT(options.output_tensor_height(), 0)
<< "Valid output tensor height is required.";
return absl::OkStatus();
}
template <typename T>
OutputTensorParams GetOutputTensorParams(const T& options) {
OutputTensorParams params;
if (options.has_output_tensor_uint_range()) {
params.range_min =
static_cast<float>(options.output_tensor_uint_range().min());
params.range_max =
static_cast<float>(options.output_tensor_uint_range().max());
} else if (options.has_output_tensor_int_range()) {
params.range_min =
static_cast<float>(options.output_tensor_int_range().min());
params.range_max =
static_cast<float>(options.output_tensor_int_range().max());
} else {
params.range_min = options.output_tensor_float_range().min();
params.range_max = options.output_tensor_float_range().max();
}
params.output_width = options.output_tensor_width();
params.output_height = options.output_tensor_height();
params.is_float_output = options.has_output_tensor_float_range();
params.output_batch = 1;
return params;
}
// Returns whether the GPU input format starts at the bottom.
template <typename T>
bool DoesGpuInputStartAtBottom(const T& options) {
return options.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}
// Converts the BorderMode proto into struct.
BorderMode GetBorderMode(
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode);
// Gets the output tensor type.
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
const OutputTensorParams& params);
// Gets the number of output channels from the input Image format.
int GetNumOutputChannels(const mediapipe::Image& image);
// Converts the packet that hosts different format (Image, ImageFrame,
// GpuBuffer) into the mediapipe::Image format.
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
image_packet);
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet);
#endif // !MEDIAPIPE_DISABLE_GPU
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_ #endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_

View File

@ -16,6 +16,8 @@
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe { namespace mediapipe {
@ -23,6 +25,7 @@ namespace {
using ::testing::ElementsAre; using ::testing::ElementsAre;
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
using ::testing::HasSubstr;
testing::Matcher<RotatedRect> EqRotatedRect(float width, float height, testing::Matcher<RotatedRect> EqRotatedRect(float width, float height,
float center_x, float center_y, float center_x, float center_y,
@ -157,5 +160,95 @@ TEST(GetValueRangeTransformation, FloatToPixel) {
EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f)); EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f));
} }
constexpr char kValidFloatProto[] = R"(
output_tensor_float_range { min: 0.0 max: 1.0 }
output_tensor_width: 100
output_tensor_height: 200
)";
constexpr char kValidIntProto[] = R"(
output_tensor_float_range { min: 0 max: 255 }
output_tensor_width: 100
output_tensor_height: 200
)";
TEST(ValidateOptionOutputDims, ValidProtos) {
const auto float_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
kValidFloatProto);
MP_EXPECT_OK(ValidateOptionOutputDims(float_options));
}
TEST(ValidateOptionOutputDims, EmptyProto) {
mediapipe::ImageToTensorCalculatorOptions options;
// No output tensor range set.
EXPECT_THAT(ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Output tensor range is required")));
// Invalid output float tensor range.
options.mutable_output_tensor_float_range()->set_min(1.0);
options.mutable_output_tensor_float_range()->set_max(0.0);
EXPECT_THAT(
ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Valid output float tensor range is required")));
// Output width/height is not set.
options.mutable_output_tensor_float_range()->set_min(0.0);
options.mutable_output_tensor_float_range()->set_max(1.0);
EXPECT_THAT(ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Valid output tensor width is required")));
}
TEST(GetOutputTensorParams, SetValues) {
// Test int range with ImageToTensorCalculatorOptions.
const auto int_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
kValidIntProto);
const auto params2 = GetOutputTensorParams(int_options);
EXPECT_EQ(params2.range_min, 0.0f);
EXPECT_EQ(params2.range_max, 255.0f);
EXPECT_EQ(params2.output_batch, 1);
EXPECT_EQ(params2.output_width, 100);
EXPECT_EQ(params2.output_height, 200);
}
TEST(GetBorderMode, GetBorderMode) {
// Default to REPLICATE.
auto border_mode =
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED;
EXPECT_EQ(BorderMode::kReplicate, GetBorderMode(border_mode));
// Set to ZERO.
border_mode =
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO;
EXPECT_EQ(BorderMode::kZero, GetBorderMode(border_mode));
}
TEST(GetOutputTensorType, GetOutputTensorType) {
OutputTensorParams params;
// Return float32 when GPU is enabled.
EXPECT_EQ(Tensor::ElementType::kFloat32,
GetOutputTensorType(/*uses_gpu=*/true, params));
// Return float32 when is_float_output is set to true.
params.is_float_output = true;
EXPECT_EQ(Tensor::ElementType::kFloat32,
GetOutputTensorType(/*uses_gpu=*/false, params));
// Return int8 when range_min is negative.
params.is_float_output = false;
params.range_min = -255.0f;
EXPECT_EQ(Tensor::ElementType::kInt8,
GetOutputTensorType(/*uses_gpu=*/false, params));
// Return 8int8 when range_min is non-negative.
params.range_min = 0.0f;
EXPECT_EQ(Tensor::ElementType::kUInt8,
GetOutputTensorType(/*uses_gpu=*/false, params));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -72,7 +72,7 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
RET_CHECK(!input_tensors.empty()); RET_CHECK(!input_tensors.empty());
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors, ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
inference_runner_->Run(input_tensors)); inference_runner_->Run(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors)); kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -26,6 +26,8 @@
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
CalculatorContext* cc, const std::vector<Tensor>& input_tensors, CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
std::vector<Tensor>& output_tensors) { std::vector<Tensor>& output_tensors) {
return gpu_helper_.RunInGlContext( return gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status { [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
// Explicitly copy input. // Explicitly copy input.
for (int i = 0; i < input_tensors.size(); ++i) { for (int i = 0; i < input_tensors.size(); ++i) {
glBindBuffer(GL_COPY_READ_BUFFER, glBindBuffer(GL_COPY_READ_BUFFER,
@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
} }
// Run inference. // Run inference.
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); {
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
output_tensors.reserve(output_size_); output_tensors.reserve(output_size_);
for (int i = 0; i < output_size_; ++i) { for (int i = 0; i < output_size_; ++i) {

View File

@ -32,6 +32,8 @@
#include "mediapipe/util/android/file/base/helpers.h" #include "mediapipe/util/android/file/base/helpers.h"
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
const mediapipe::InferenceCalculatorOptions::Delegate& delegate); const mediapipe::InferenceCalculatorOptions::Delegate& delegate);
absl::StatusOr<std::vector<Tensor>> Process( absl::StatusOr<std::vector<Tensor>> Process(
const std::vector<Tensor>& input_tensors); CalculatorContext* cc, const std::vector<Tensor>& input_tensors);
absl::Status Close(); absl::Status Close();
@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
absl::StatusOr<std::vector<Tensor>> absl::StatusOr<std::vector<Tensor>>
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
const std::vector<Tensor>& input_tensors) { CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
std::vector<Tensor> output_tensors; std::vector<Tensor> output_tensors;
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status { [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
for (int i = 0; i < input_tensors.size(); ++i) { for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
input_tensors[i].GetOpenGlBufferReadView().name(), i)); input_tensors[i].GetOpenGlBufferReadView().name(), i));
@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
output_tensors.back().GetOpenGlBufferWriteView().name(), i)); output_tensors.back().GetOpenGlBufferWriteView().name(), i));
} }
// Run inference. // Run inference.
return tflite_gpu_runner_->Invoke(); {
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
return tflite_gpu_runner_->Invoke();
}
})); }));
return output_tensors; return output_tensors;
@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
auto output_tensors = absl::make_unique<std::vector<Tensor>>(); auto output_tensors = absl::make_unique<std::vector<Tensor>>();
ASSIGN_OR_RETURN(*output_tensors, ASSIGN_OR_RETURN(*output_tensors,
gpu_inference_runner_->Process(input_tensors)); gpu_inference_runner_->Process(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors)); kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus(); return absl::OkStatus();

View File

@ -70,7 +70,7 @@ absl::Status InferenceCalculatorXnnpackImpl::Process(CalculatorContext* cc) {
RET_CHECK(!input_tensors.empty()); RET_CHECK(!input_tensors.empty());
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors, ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
inference_runner_->Run(input_tensors)); inference_runner_->Run(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors)); kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -20,12 +20,15 @@
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h" #include "tensorflow/lite/string_util.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe { namespace mediapipe {
namespace { namespace {
@ -79,7 +82,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
delegate_(std::move(delegate)) {} delegate_(std::move(delegate)) {}
absl::StatusOr<std::vector<Tensor>> Run( absl::StatusOr<std::vector<Tensor>> Run(
const std::vector<Tensor>& input_tensors) override; CalculatorContext* cc, const std::vector<Tensor>& input_tensors) override;
private: private:
api2::Packet<TfLiteModelPtr> model_; api2::Packet<TfLiteModelPtr> model_;
@ -88,7 +91,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
}; };
absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run( absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
const std::vector<Tensor>& input_tensors) { CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
// Read CPU input into tensors. // Read CPU input into tensors.
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size()); RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
for (int i = 0; i < input_tensors.size(); ++i) { for (int i = 0; i < input_tensors.size(); ++i) {
@ -131,8 +134,10 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
} }
// Run inference. // Run inference.
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); {
MEDIAPIPE_PROFILING(CPU_TASK_INVOKE, cc);
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
// Output result tensors (CPU). // Output result tensors (CPU).
const auto& tensor_indexes = interpreter_->outputs(); const auto& tensor_indexes = interpreter_->outputs();
std::vector<Tensor> output_tensors; std::vector<Tensor> output_tensors;

View File

@ -2,6 +2,7 @@
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_ #define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
namespace mediapipe { namespace mediapipe {
@ -11,7 +12,7 @@ class InferenceRunner {
public: public:
virtual ~InferenceRunner() = default; virtual ~InferenceRunner() = default;
virtual absl::StatusOr<std::vector<Tensor>> Run( virtual absl::StatusOr<std::vector<Tensor>> Run(
const std::vector<Tensor>& inputs) = 0; CalculatorContext* cc, const std::vector<Tensor>& inputs) = 0;
}; };
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,197 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstring>
#include <new>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "pffft.h"
namespace mediapipe {
namespace api2 {
namespace {
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
std::vector<float> hann_window(window_size);
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
if (sqrt_hann) {
absl::c_transform(hann_window, hann_window.begin(),
[](double x) { return std::sqrt(x); });
}
return hann_window;
}
// Note that the InvHannWindow function may only work for 50% overlapping case.
std::vector<float> InvHannWindow(int window_size, bool sqrt_hann) {
std::vector<float> window = HannWindow(window_size, sqrt_hann);
std::vector<float> inv_window(window.size());
if (sqrt_hann) {
absl::c_copy(window, inv_window.begin());
} else {
const int kHalfWindowSize = window.size() / 2;
absl::c_transform(window, inv_window.begin(),
[](double x) { return x * x; });
for (int i = 0; i < kHalfWindowSize; ++i) {
double sum = inv_window[i] + inv_window[kHalfWindowSize + i];
inv_window[i] = window[i] / sum;
inv_window[kHalfWindowSize + i] = window[kHalfWindowSize + i] / sum;
}
}
return inv_window;
}
// PFFFT only supports transforms for inputs of length N of the form
// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT.
bool IsValidFftSize(int size) {
if (size <= 0) {
return false;
}
constexpr int kFactors[] = {2, 3, 5};
int factorization[] = {0, 0, 0};
int n = static_cast<int>(size);
for (int i = 0; i < 3; ++i) {
while (n % kFactors[i] == 0) {
n = n / kFactors[i];
++factorization[i];
}
}
return factorization[0] >= 5 && n == 1;
}
} // namespace
// Converts 2D MediaPipe float Tensors to audio buffers.
// The calculator will perform ifft on the complex DFT and apply the window
// function (Inverse Hann) afterwards. The input 2D MediaPipe Tensor must
// have the DFT real parts in its first row and the DFT imagery parts in its
// second row. A valid "fft_size" must be set in the CalculatorOptions.
//
// Inputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor that represents the audio's complex DFT
// results.
// DC_AND_NYQUIST - std::pair<float, float>
// A pair of dc component and nyquist component.
//
// Outputs:
// AUDIO - mediapipe::Matrix
// The audio data represented as mediapipe::Matrix.
//
// Example:
// node {
// calculator: "TensorsToAudioCalculator"
// input_stream: "TENSORS:tensors"
// input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
// output_stream: "AUDIO:audio"
// options {
// [mediapipe.AudioToTensorCalculatorOptions.ext] {
// fft_size: 256
// }
// }
// }
class TensorsToAudioCalculator : public Node {
public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Input<std::pair<float, float>> kDcAndNyquistIn{
"DC_AND_NYQUIST"};
static constexpr Output<Matrix> kAudioOut{"AUDIO"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kDcAndNyquistIn, kAudioOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
private:
// The internal state of the FFT library.
PFFFT_Setup* fft_state_ = nullptr;
int fft_size_ = 0;
float inverse_fft_size_ = 0;
std::vector<float, Eigen::aligned_allocator<float>> input_dft_;
std::vector<float> inv_fft_window_;
std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
// pffft requires memory to work with to avoid using the stack.
std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
};
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
const auto& options =
cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
RET_CHECK(IsValidFftSize(options.fft_size()))
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
">=0 and c >= 0 and a >= 5, the requested fft size is "
<< options.fft_size();
fft_size_ = options.fft_size();
inverse_fft_size_ = 1.0f / fft_size_;
fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
input_dft_.resize(fft_size_);
inv_fft_window_ = InvHannWindow(fft_size_, /* sqrt_hann = */ false);
fft_input_buffer_.resize(fft_size_);
fft_workplace_.resize(fft_size_);
fft_output_.resize(fft_size_);
return absl::OkStatus();
}
absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
if (kTensorsIn(cc).IsEmpty() || kDcAndNyquistIn(cc).IsEmpty()) {
return absl::OkStatus();
}
const auto& input_tensors = *kTensorsIn(cc);
RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
auto view = input_tensors[0].GetCpuReadView();
// DC's real part.
input_dft_[0] = kDcAndNyquistIn(cc)->first;
// Nyquist's real part is the penultimate element of the tensor buffer.
// pffft ignores the Nyquist's imagery part. No need to fetch the last value
// from the tensor buffer.
input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
(fft_size_ - 2) * sizeof(float));
pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
fft_workplace_.data(), PFFFT_BACKWARD);
// Applies the inverse window function.
std::transform(
fft_output_.begin(), fft_output_.end(), inv_fft_window_.begin(),
fft_output_.begin(),
[this](float a, float b) { return a * b * inverse_fft_size_; });
Matrix matrix = Eigen::Map<Matrix>(fft_output_.data(), 1, fft_output_.size());
kAudioOut(cc).Send(std::move(matrix));
return absl::OkStatus();
}
absl::Status TensorsToAudioCalculator::Close(CalculatorContext* cc) {
if (fft_state_) {
pffft_destroy_setup(fft_state_);
}
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(TensorsToAudioCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorsToAudioCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorsToAudioCalculatorOptions ext = 484297136;
}
// Size of the fft in number of bins. If set, the calculator will do ifft
// on the input tensor.
optional int64 fft_size = 1;
}

View File

@ -0,0 +1,149 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <new>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
protected:
// Creates an audio matrix containing a single sample of 1.0 at a specified
// offset.
Matrix CreateImpulseSignalData(int64 num_samples, int impulse_offset_idx) {
Matrix impulse = Matrix::Zero(1, num_samples);
impulse(0, impulse_offset_idx) = 1.0;
return impulse;
}
void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "audio_in"
input_stream: "sample_rate"
output_stream: "audio_out"
node {
calculator: "AudioToTensorCalculator"
input_stream: "AUDIO:audio_in"
input_stream: "SAMPLE_RATE:sample_rate"
output_stream: "TENSORS:tensors"
output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
options {
[mediapipe.AudioToTensorCalculatorOptions.ext] {
num_channels: 1
num_samples: $0
num_overlapping_samples: 0
target_sample_rate: $1
fft_size: $2
}
}
}
node {
calculator: "TensorsToAudioCalculator"
input_stream: "TENSORS:tensors"
input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
output_stream: "AUDIO:audio_out"
options {
[mediapipe.TensorsToAudioCalculatorOptions.ext] {
fft_size: $2
}
}
}
)",
/*$0=*/num_samples,
/*$1=*/sample_rate,
/*$2=*/fft_size));
tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
}
void RunGraph(const Matrix& input_data, double sample_rate) {
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"audio_in", MakePacket<Matrix>(input_data).At(Timestamp(0))));
MP_ASSERT_OK(graph_.CloseAllInputStreams());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
std::vector<Packet> audio_out_packets_;
CalculatorGraphConfig graph_config_;
CalculatorGraph graph_;
};
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
ConfigGraph(320, 16000, 103);
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
auto status = graph_.WaitUntilIdle();
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_THAT(status.message(),
::testing::HasSubstr("FFT size must be of the form"));
}
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// The impulse signal at the center is not affected by the window function.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
}
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// As the impulse signal sits at the 1/4 of the hann window, the inverse
// window function reduces it by half.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data / 2);
}
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// As the impulse signal sits at the beginning of the hann window, the inverse
// window function completely removes it.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
}
} // namespace
} // namespace mediapipe

View File

@ -289,8 +289,15 @@ class NodeBase {
template <typename T> template <typename T>
T& GetOptions() { T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true; options_used_ = true;
return *options_.MutableExtension(T::ext); return *options_.MutableExtension(extension);
} }
protected: protected:
@ -386,8 +393,15 @@ class PacketGenerator {
template <typename T> template <typename T>
T& GetOptions() { T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true; options_used_ = true;
return *options_.MutableExtension(T::ext); return *options_.MutableExtension(extension);
} }
template <typename B, typename T, bool kIsOptional, bool kIsMultiple> template <typename B, typename T, bool kIsOptional, bool kIsMultiple>

View File

@ -185,7 +185,7 @@ class CalculatorBaseFactory {
// Functions for checking that the calculator has the required GetContract. // Functions for checking that the calculator has the required GetContract.
template <class T> template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) { constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
typedef absl::Status (*GetContractType)(CalculatorContract * cc); typedef absl::Status (*GetContractType)(CalculatorContract* cc);
return std::is_same<decltype(&T::GetContract), GetContractType>::value; return std::is_same<decltype(&T::GetContract), GetContractType>::value;
} }
template <class T> template <class T>

View File

@ -133,7 +133,13 @@ message GraphTrace {
TPU_TASK = 13; TPU_TASK = 13;
GPU_CALIBRATION = 14; GPU_CALIBRATION = 14;
PACKET_QUEUED = 15; PACKET_QUEUED = 15;
GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17;
CPU_TASK_INVOKE = 18;
} }
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
// )
// The timing for one packet set being processed at one caclulator node. // The timing for one packet set being processed at one caclulator node.
message CalculatorTrace { message CalculatorTrace {

View File

@ -293,7 +293,6 @@ mediapipe_proto_library(
name = "rect_proto", name = "rect_proto",
srcs = ["rect.proto"], srcs = ["rect.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/formats:location_data_proto"],
) )
mediapipe_register_type( mediapipe_register_type(

View File

@ -109,6 +109,13 @@ struct TraceEvent {
static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK; static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION; static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED; static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/calculator_profile.proto:event_type,
// )
}; };
// Packet trace log buffer. // Packet trace log buffer.

View File

@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(
// Returns a PacketSequencerCalculator node. // Returns a PacketSequencerCalculator node.
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config, CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
bool synchronize_io) { bool async_selection) {
CalculatorGraphConfig::Node* result = config->add_node(); CalculatorGraphConfig::Node* result = config->add_node();
*result->mutable_calculator() = "PacketSequencerCalculator"; *result->mutable_calculator() = "PacketSequencerCalculator";
if (synchronize_io) { if (!async_selection) {
*result->mutable_input_stream_handler()->mutable_input_stream_handler() = *result->mutable_input_stream_handler()->mutable_input_stream_handler() =
"DefaultInputStreamHandler"; "DefaultInputStreamHandler";
} }
@ -239,6 +239,15 @@ bool HasTag(const proto_ns::RepeatedPtrField<std::string>& streams,
return tags.count({tag, 0}) > 0; return tags.count({tag, 0}) > 0;
} }
// Returns true if a set of "TAG::index" includes a TagIndex.
bool ContainsTag(const proto_ns::RepeatedPtrField<std::string>& tags,
TagIndex item) {
for (const std::string& t : tags) {
if (ParseTagIndex(t) == item) return true;
}
return false;
}
absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig( absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
const Subgraph::SubgraphOptions& options) { const Subgraph::SubgraphOptions& options) {
CalculatorGraphConfig config; CalculatorGraphConfig config;
@ -263,17 +272,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
std::string enable_stream = "ENABLE:gate_enable"; std::string enable_stream = "ENABLE:gate_enable";
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams. // Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
bool synchronize_io = const auto& switch_options =
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options) Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options);
.synchronize_io(); bool async_selection = switch_options.async_selection();
if (HasTag(container_node.input_stream(), "SELECT")) { if (HasTag(container_node.input_stream(), "SELECT")) {
select_node = BuildTimestampNode(&config, synchronize_io); select_node = BuildTimestampNode(&config, async_selection);
select_node->add_input_stream("INPUT:gate_select"); select_node->add_input_stream("INPUT:gate_select");
select_node->add_output_stream("OUTPUT:gate_select_timed"); select_node->add_output_stream("OUTPUT:gate_select_timed");
select_stream = "SELECT:gate_select_timed"; select_stream = "SELECT:gate_select_timed";
} }
if (HasTag(container_node.input_stream(), "ENABLE")) { if (HasTag(container_node.input_stream(), "ENABLE")) {
enable_node = BuildTimestampNode(&config, synchronize_io); enable_node = BuildTimestampNode(&config, async_selection);
enable_node->add_input_stream("INPUT:gate_enable"); enable_node->add_input_stream("INPUT:gate_enable");
enable_node->add_output_stream("OUTPUT:gate_enable_timed"); enable_node->add_output_stream("OUTPUT:gate_enable_timed");
enable_stream = "ENABLE:gate_enable_timed"; enable_stream = "ENABLE:gate_enable_timed";
@ -296,7 +305,7 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
mux->add_input_side_packet("SELECT:gate_select"); mux->add_input_side_packet("SELECT:gate_select");
mux->add_input_side_packet("ENABLE:gate_enable"); mux->add_input_side_packet("ENABLE:gate_enable");
// Add input streams for graph and demux and the timestamper. // Add input streams for graph and demux.
config.add_input_stream("SELECT:gate_select"); config.add_input_stream("SELECT:gate_select");
config.add_input_stream("ENABLE:gate_enable"); config.add_input_stream("ENABLE:gate_enable");
config.add_input_side_packet("SELECT:gate_select"); config.add_input_side_packet("SELECT:gate_select");
@ -306,6 +315,12 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
std::string stream = CatStream(p.first, p.second); std::string stream = CatStream(p.first, p.second);
config.add_input_stream(stream); config.add_input_stream(stream);
demux->add_input_stream(stream); demux->add_input_stream(stream);
}
// Add input streams for the timestamper.
auto& tick_streams = switch_options.tick_input_stream();
for (const auto& p : input_tags) {
if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue;
TagIndex tick_tag{"TICK", tick_index++}; TagIndex tick_tag{"TICK", tick_index++};
if (select_node) { if (select_node) {
select_node->add_input_stream(CatStream(tick_tag, p.second)); select_node->add_input_stream(CatStream(tick_tag, p.second));

View File

@ -25,6 +25,14 @@ message SwitchContainerOptions {
// Activates channel 1 for enable = true, channel 0 otherwise. // Activates channel 1 for enable = true, channel 0 otherwise.
optional bool enable = 4; optional bool enable = 4;
// Use DefaultInputStreamHandler for muxing & demuxing. // Use DefaultInputStreamHandler for demuxing.
optional bool synchronize_io = 5; optional bool synchronize_io = 5;
// Use ImmediateInputStreamHandler for channel selection.
optional bool async_selection = 6;
// Specifies an input stream, "TAG:index", that defines the processed
// timestamps. SwitchContainer awaits output at the last processed
// timestamp before advancing from one selected channel to the next.
repeated string tick_input_stream = 7;
} }

View File

@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
input_stream: "INPUT:enable" input_stream: "INPUT:enable"
input_stream: "TICK:foo" input_stream: "TICK:foo"
output_stream: "OUTPUT:switchcontainer__gate_enable_timed" output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
input_stream_handler {
input_stream_handler: "DefaultInputStreamHandler"
}
} }
node { node {
name: "switchcontainer__SwitchDemuxCalculator" name: "switchcontainer__SwitchDemuxCalculator"
@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
// Shows the SwitchContainer container runs with a pair of simple subnodes. // Shows the SwitchContainer container runs with a pair of simple subnodes.
TEST(SwitchContainerTest, RunsWithSubnodes) { TEST(SwitchContainerTest, RunsWithSubnodes) {
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer")); EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
CalculatorGraphConfig supergraph = SubnodeContainerExample(); CalculatorGraphConfig supergraph =
SubnodeContainerExample("async_selection: true");
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
RunTestContainer(supergraph); RunTestContainer(supergraph);
} }

View File

@ -14,6 +14,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <queue>
#include <set> #include <set>
#include <string> #include <string>
@ -54,21 +55,47 @@ namespace mediapipe {
// contained subgraph or calculator nodes. // contained subgraph or calculator nodes.
// //
class SwitchDemuxCalculator : public CalculatorBase { class SwitchDemuxCalculator : public CalculatorBase {
static constexpr char kSelectTag[] = "SELECT";
static constexpr char kEnableTag[] = "ENABLE";
public: public:
static absl::Status GetContract(CalculatorContract* cc); static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override; absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
private:
absl::Status RecordPackets(CalculatorContext* cc);
int ChannelIndex(Timestamp timestamp);
absl::Status SendActivePackets(CalculatorContext* cc);
private: private:
int channel_index_; int channel_index_;
std::set<std::string> channel_tags_; std::set<std::string> channel_tags_;
using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
PacketQueue input_queue_;
std::map<Timestamp, int> channel_history_;
}; };
REGISTER_CALCULATOR(SwitchDemuxCalculator); REGISTER_CALCULATOR(SwitchDemuxCalculator);
namespace {
static constexpr char kSelectTag[] = "SELECT";
static constexpr char kEnableTag[] = "ENABLE";
// Returns the last received timestamp for an input stream.
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
return input.Value().Timestamp();
}
// Returns the last received timestamp for channel selection.
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
Timestamp result = Timestamp::Done();
if (cc->Inputs().HasTag(kEnableTag)) {
result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
} else if (cc->Inputs().HasTag(kSelectTag)) {
result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
}
return result;
}
} // namespace
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
// Allow any one of kSelectTag, kEnableTag. // Allow any one of kSelectTag, kEnableTag.
cc->Inputs().Tag(kSelectTag).Set<int>().Optional(); cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
channel_tags_ = ChannelTags(cc->Outputs().TagMap()); channel_tags_ = ChannelTags(cc->Outputs().TagMap());
channel_history_[Timestamp::Unstarted()] = channel_index_;
// Relay side packets to all channels. // Relay side packets to all channels.
// Note: This is necessary because Calculator::Open only proceeds when every // Note: This is necessary because Calculator::Open only proceeds when every
@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
} }
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) { absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
// Update the input channel index if specified. MP_RETURN_IF_ERROR(RecordPackets(cc));
channel_index_ = tool::GetChannelIndex(*cc, channel_index_); MP_RETURN_IF_ERROR(SendActivePackets(cc));
return absl::OkStatus();
}
// Relay packets and timestamps only to channel_index_. // Enqueue all arriving packets and bounds.
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
// Enqueue any new arriving packets.
for (const std::string& tag : channel_tags_) { for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto& input = cc->Inputs().Get(tag, index); auto input_id = cc->Inputs().GetId(tag, index);
std::string output_tag = tool::ChannelTag(tag, channel_index_); Packet packet = cc->Inputs().Get(input_id).Value();
auto output_id = cc->Outputs().GetId(output_tag, index); if (packet.Timestamp() == cc->InputTimestamp()) {
if (output_id.IsValid()) { input_queue_[input_id].push(packet);
auto& output = cc->Outputs().Get(output_tag, index);
tool::Relay(input, &output);
} }
} }
} }
// Enque any new input channel and its activation timestamp.
Timestamp channel_settled = ChannelSettledTimestamp(cc);
int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
if (channel_settled == cc->InputTimestamp() &&
new_channel_index != channel_index_) {
channel_index_ = new_channel_index;
channel_history_[channel_settled] = channel_index_;
}
return absl::OkStatus();
}
// Returns the channel index for a Timestamp.
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
auto it = std::prev(channel_history_.upper_bound(timestamp));
return it->second;
}
// Dispatches all queued input packets with known channels.
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
// Dispatch any queued input packets with a defined channel_index.
Timestamp channel_settled = ChannelSettledTimestamp(cc);
for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto input_id = cc->Inputs().GetId(tag, index);
auto& queue = input_queue_[input_id];
while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
int channel_index = ChannelIndex(queue.front().Timestamp());
std::string output_tag = tool::ChannelTag(tag, channel_index);
auto output_id = cc->Outputs().GetId(output_tag, index);
if (output_id.IsValid()) {
cc->Outputs().Get(output_id).AddPacket(queue.front());
}
queue.pop();
}
}
}
// Discard all select packets not needed for any remaining input packets.
Timestamp input_settled = Timestamp::Done();
for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto input_id = cc->Inputs().GetId(tag, index);
Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
if (!input_queue_[input_id].empty()) {
Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
stream_settled =
std::min(stream_settled, stream_bound.PreviousAllowedInStream());
}
}
}
Timestamp input_bound = input_settled.NextAllowedInStream();
auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
channel_history_.erase(channel_history_.begin(), history_bound);
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<mediapipe::SwitchContainerOptions>(); options_ = cc->Options<mediapipe::SwitchContainerOptions>();
channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
channel_tags_ = ChannelTags(cc->Inputs().TagMap()); channel_tags_ = ChannelTags(cc->Inputs().TagMap());
channel_history_[Timestamp::Unset()] = channel_index_; channel_history_[Timestamp::Unstarted()] = channel_index_;
// Relay side packets only from channel_index_. // Relay side packets only from channel_index_.
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) { for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {

View File

@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
static void EglThreadExitCallback(void* key_value) { static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__)
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
EGL_NO_CONTEXT);
#else
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
// parameter for eglMakeCurrent. This behavior is not portable to all EGL // parameter for eglMakeCurrent. This behavior is not portable to all EGL
// implementations, and should be considered as an undocumented vendor // implementations, and should be considered as an undocumented vendor
// extension. // extension.
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
//
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
EGL_NO_SURFACE, EGL_NO_CONTEXT); EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
eglReleaseThread(); eglReleaseThread();
} }

View File

@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.ImageProperties; import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
// TODO: use Preconditions in this file. // TODO: use Preconditions in this file.
@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
} }
/** /**
* Creates an Image packet from an {@link Image}. * Creates a MediaPipe Image packet from a {@link MPImage}.
* *
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. * <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
*/ */
public Packet createImage(Image image) { public Packet createImage(MPImage image) {
// TODO: Choose the best storage from multiple containers. // TODO: Choose the best storage from multiple containers.
ImageProperties properties = image.getContainedImageProperties().get(0); MPImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
ByteBuffer buffer = ByteBufferExtractor.extract(image); ByteBuffer buffer = ByteBufferExtractor.extract(image);
int numChannels = 0; int numChannels = 0;
switch (properties.getImageFormat()) { switch (properties.getImageFormat()) {
case Image.IMAGE_FORMAT_RGBA: case MPImage.IMAGE_FORMAT_RGBA:
numChannels = 4; numChannels = 4;
break; break;
case Image.IMAGE_FORMAT_RGB: case MPImage.IMAGE_FORMAT_RGB:
numChannels = 3; numChannels = 3;
break; break;
case Image.IMAGE_FORMAT_ALPHA: case MPImage.IMAGE_FORMAT_ALPHA:
numChannels = 1; numChannels = 1;
break; break;
default: // fall out default: // fall out
@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
int height = image.getHeight(); int height = image.getHeight();
return createImage(buffer, width, height, numChannels); return createImage(buffer, width, height, numChannels);
} }
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { if (properties.getStorageType() == MPImage.STORAGE_TYPE_BITMAP) {
Bitmap bitmap = BitmapExtractor.extract(image); Bitmap bitmap = BitmapExtractor.extract(image);
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
@ -100,7 +100,7 @@ public class AndroidPacketCreator extends PacketCreator {
// Unsupported type. // Unsupported type.
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Unsupported Image container type: " + properties.getImageFormat()); "Unsupported Image container type: " + properties.getStorageType());
} }
/** /**

View File

@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
import android.graphics.Bitmap; import android.graphics.Bitmap;
/** /**
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
* {@link IllegalArgumentException} will be thrown. * {@link IllegalArgumentException} will be thrown.
*/ */
public final class BitmapExtractor { public final class BitmapExtractor {
/** /**
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}. * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
* *
* @param image the image to extract {@link android.graphics.Bitmap} from. * @param image the image to extract {@link android.graphics.Bitmap} from.
* @return the {@link android.graphics.Bitmap} stored in {@link Image} * @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type * @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions. * conversions.
*/ */
public static Bitmap extract(Image image) { public static Bitmap extract(MPImage image) {
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
if (imageContainer != null) { if (imageContainer != null) {
return ((BitmapImageContainer) imageContainer).getBitmap(); return ((BitmapImageContainer) imageContainer).getBitmap();
} else { } else {
// TODO: Support ByteBuffer -> Bitmap conversion. // TODO: Support ByteBuffer -> Bitmap conversion.
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting Bitmap from an Image created by objects other than Bitmap is not" "Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
+ " supported"); + " supported");
} }
} }

View File

@ -22,7 +22,7 @@ import android.provider.MediaStore;
import java.io.IOException; import java.io.IOException;
/** /**
* Builds {@link Image} from {@link android.graphics.Bitmap}. * Builds {@link MPImage} from {@link android.graphics.Bitmap}.
* *
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
@ -49,7 +49,7 @@ public class BitmapImageBuilder {
} }
/** /**
* Creates the builder to build {@link Image} from a file. * Creates the builder to build {@link MPImage} from a file.
* *
* @param context the application context. * @param context the application context.
* @param uri the path to the resource file. * @param uri the path to the resource file.
@ -58,15 +58,15 @@ public class BitmapImageBuilder {
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
BitmapImageBuilder setTimestamp(long timestamp) { BitmapImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image( return new MPImage(
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
} }
} }

View File

@ -16,19 +16,19 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
class BitmapImageContainer implements ImageContainer { class BitmapImageContainer implements MPImageContainer {
private final Bitmap bitmap; private final Bitmap bitmap;
private final ImageProperties properties; private final MPImageProperties properties;
public BitmapImageContainer(Bitmap bitmap) { public BitmapImageContainer(Bitmap bitmap) {
this.bitmap = bitmap; this.bitmap = bitmap;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setImageFormat(convertFormatCode(bitmap.getConfig())) .setImageFormat(convertFormatCode(bitmap.getConfig()))
.setStorageType(Image.STORAGE_TYPE_BITMAP) .setStorageType(MPImage.STORAGE_TYPE_BITMAP)
.build(); .build();
} }
@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
bitmap.recycle(); bitmap.recycle();
} }
@ImageFormat @MPImageFormat
static int convertFormatCode(Bitmap.Config config) { static int convertFormatCode(Bitmap.Config config) {
switch (config) { switch (config) {
case ALPHA_8: case ALPHA_8:
return Image.IMAGE_FORMAT_ALPHA; return MPImage.IMAGE_FORMAT_ALPHA;
case ARGB_8888: case ARGB_8888:
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
default: default:
return Image.IMAGE_FORMAT_UNKNOWN; return MPImage.IMAGE_FORMAT_UNKNOWN;
} }
} }
} }

View File

@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
import android.os.Build.VERSION; import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.util.Locale; import java.util.Locale;
/** /**
* Utility for extracting {@link ByteBuffer} from {@link Image}. * Utility for extracting {@link ByteBuffer} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
* {@link IllegalArgumentException} will be thrown. * otherwise {@link IllegalArgumentException} will be thrown.
*/ */
public class ByteBufferExtractor { public class ByteBufferExtractor {
/** /**
* Extracts a {@link ByteBuffer} from an {@link Image}. * Extracts a {@link ByteBuffer} from a {@link MPImage}.
* *
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
* *
* @see Image#getContainedImageProperties() * @see MPImage#getContainedImageProperties()
* @return A read-only {@link ByteBuffer}. * @return A read-only {@link ByteBuffer}.
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
*/ */
@SuppressLint("SwitchIntDef") @SuppressLint("SwitchIntDef")
public static ByteBuffer extract(Image image) { public static ByteBuffer extract(MPImage image) {
ImageContainer container = image.getContainer(); MPImageContainer container = image.getContainer();
switch (container.getImageProperties().getStorageType()) { switch (container.getImageProperties().getStorageType()) {
case Image.STORAGE_TYPE_BYTEBUFFER: case MPImage.STORAGE_TYPE_BYTEBUFFER:
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
default: default:
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
+ " supported"); + " supported");
} }
} }
/** /**
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
* *
* <p>Format conversion spec: * <p>Format conversion spec:
* *
@ -70,26 +70,26 @@ public class ByteBufferExtractor {
* *
* @param image the image to extract buffer from. * @param image the image to extract buffer from.
* @param targetFormat the image format of the result bytebuffer. * @param targetFormat the image format of the result bytebuffer.
* @return the readonly {@link ByteBuffer} stored in {@link Image} * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type * @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions. * conversions.
*/ */
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
ImageContainer container; MPImageContainer container;
ImageProperties byteBufferProperties = MPImageProperties byteBufferProperties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(targetFormat) .setImageFormat(targetFormat)
.build(); .build();
if ((container = image.getContainer(byteBufferProperties)) != null) { if ((container = image.getContainer(byteBufferProperties)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
.asReadOnlyBuffer(); .asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
ByteBuffer byteBuffer = ByteBuffer byteBuffer =
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
@ -98,85 +98,89 @@ public class ByteBufferExtractor {
return byteBuffer; return byteBuffer;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by objects other than Bitmap or" "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
+ " Bytebuffer is not supported"); + " Bytebuffer is not supported");
} }
} }
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
@AutoValue @AutoValue
abstract static class Result { abstract static class Result {
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ /**
* Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
public abstract ByteBuffer buffer(); public abstract ByteBuffer buffer();
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ /**
@ImageFormat * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
@MPImageFormat
public abstract int format(); public abstract int format();
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
} }
} }
/** /**
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
* *
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
* *
* @return the readonly {@link ByteBuffer} stored in {@link Image} * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
* given {@code imageFormat} * given {@code imageFormat}
*/ */
static Result extractInRecommendedFormat(Image image) { static Result extractInRecommendedFormat(MPImage image) {
ImageContainer container; MPImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
@ImageFormat int format = adviseImageFormat(bitmap); @MPImageFormat int format = adviseImageFormat(bitmap);
Result result = Result result =
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
boolean unused = boolean unused =
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
return result; return result;
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return Result.create( return Result.create(
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
byteBufferImageContainer.getImageFormat()); byteBufferImageContainer.getImageFormat());
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
+ " is not supported"); + " is not supported");
} }
} }
@ImageFormat @MPImageFormat
private static int adviseImageFormat(Bitmap bitmap) { private static int adviseImageFormat(Bitmap bitmap) {
if (bitmap.getConfig() == Config.ARGB_8888) { if (bitmap.getConfig() == Config.ARGB_8888) {
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
String.format( String.format(
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
+ " supported", + " supported",
bitmap.getConfig())); bitmap.getConfig()));
} }
} }
private static ByteBuffer extractByteBufferFromBitmap( private static ByteBuffer extractByteBufferFromBitmap(
Bitmap bitmap, @ImageFormat int imageFormat) { Bitmap bitmap, @MPImageFormat int imageFormat) {
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
+ " supported"); + " supported");
} }
if (bitmap.getConfig() == Config.ARGB_8888) { if (bitmap.getConfig() == Config.ARGB_8888) {
if (imageFormat == Image.IMAGE_FORMAT_RGBA) { if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
bitmap.copyPixelsToBuffer(buffer); bitmap.copyPixelsToBuffer(buffer);
buffer.rewind(); buffer.rewind();
return buffer; return buffer;
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) { } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster. // TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
int w = bitmap.getWidth(); int w = bitmap.getWidth();
int h = bitmap.getHeight(); int h = bitmap.getHeight();
@ -196,14 +200,14 @@ public class ByteBufferExtractor {
} }
throw new IllegalArgumentException( throw new IllegalArgumentException(
String.format( String.format(
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
+ " %d is not supported", + " %d is not supported",
bitmap.getConfig(), imageFormat)); bitmap.getConfig(), imageFormat));
} }
private static ByteBuffer convertByteBuffer( private static ByteBuffer convertByteBuffer(
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
// array reversely to convert in-place. // array reversely to convert in-place.
@ -221,7 +225,8 @@ public class ByteBufferExtractor {
target.put(array, 0, target.capacity()); target.put(array, 0, target.capacity());
target.rewind(); target.rewind();
return target; return target;
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
&& targetFormat == MPImage.IMAGE_FORMAT_RGB) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
// array to convert in-place. // array to convert in-place.

View File

@ -15,11 +15,11 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
/** /**
* Builds a {@link Image} from a {@link ByteBuffer}. * Builds a {@link MPImage} from a {@link ByteBuffer}.
* *
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
private final ByteBuffer buffer; private final ByteBuffer buffer;
private final int width; private final int width;
private final int height; private final int height;
@ImageFormat private final int imageFormat; @MPImageFormat private final int imageFormat;
// Optional fields. // Optional fields.
private long timestamp; private long timestamp;
@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
* @param imageFormat how the data encode the image. * @param imageFormat how the data encode the image.
*/ */
public ByteBufferImageBuilder( public ByteBufferImageBuilder(
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
this.buffer = byteBuffer; this.buffer = byteBuffer;
this.width = width; this.width = width;
this.height = height; this.height = height;
@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
this.timestamp = 0; this.timestamp = 0;
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
ByteBufferImageBuilder setTimestamp(long timestamp) { ByteBufferImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
} }
} }

View File

@ -15,21 +15,19 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
class ByteBufferImageContainer implements ImageContainer { class ByteBufferImageContainer implements MPImageContainer {
private final ByteBuffer buffer; private final ByteBuffer buffer;
private final ImageProperties properties; private final MPImageProperties properties;
public ByteBufferImageContainer( public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
ByteBuffer buffer,
@ImageFormat int imageFormat) {
this.buffer = buffer; this.buffer = buffer;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(imageFormat) .setImageFormat(imageFormat)
.build(); .build();
} }
@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
/** /** Returns the image format. */
* Returns the image format. @MPImageFormat
*/
@ImageFormat
public int getImageFormat() { public int getImageFormat() {
return properties.getImageFormat(); return properties.getImageFormat();
} }

View File

@ -29,10 +29,10 @@ import java.util.Map.Entry;
/** /**
* The wrapper class for image objects. * The wrapper class for image objects.
* *
* <p>{@link Image} is designed to be an immutable image container, which could be shared * <p>{@link MPImage} is designed to be an immutable image container, which could be shared
* cross-platforms. * cross-platforms.
* *
* <p>To construct an {@link Image}, use the provided builders: * <p>To construct a {@link MPImage}, use the provided builders:
* *
* <ul> * <ul>
* <li>{@link ByteBufferImageBuilder} * <li>{@link ByteBufferImageBuilder}
@ -40,7 +40,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageBuilder} * <li>{@link MediaImageBuilder}
* </ul> * </ul>
* *
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the * <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release * reference count is 1. Developer can call {@link #close()} to reduce reference count to release
* internal storage earlier, otherwise Java garbage collection will release the storage eventually. * internal storage earlier, otherwise Java garbage collection will release the storage eventually.
* *
@ -53,7 +53,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageExtractor} * <li>{@link MediaImageExtractor}
* </ul> * </ul>
*/ */
public class Image implements Closeable { public class MPImage implements Closeable {
/** Specifies the image format of an image. */ /** Specifies the image format of an image. */
@IntDef({ @IntDef({
@ -69,7 +69,7 @@ public class Image implements Closeable {
IMAGE_FORMAT_JPEG, IMAGE_FORMAT_JPEG,
}) })
@Retention(RetentionPolicy.SOURCE) @Retention(RetentionPolicy.SOURCE)
public @interface ImageFormat {} public @interface MPImageFormat {}
public static final int IMAGE_FORMAT_UNKNOWN = 0; public static final int IMAGE_FORMAT_UNKNOWN = 0;
public static final int IMAGE_FORMAT_RGBA = 1; public static final int IMAGE_FORMAT_RGBA = 1;
@ -98,14 +98,14 @@ public class Image implements Closeable {
public static final int STORAGE_TYPE_IMAGE_PROXY = 4; public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
/** /**
* Returns a list of supported image properties for this {@link Image}. * Returns a list of supported image properties for this {@link MPImage}.
* *
* <p>Currently {@link Image} only support single storage type so the size of return list will * <p>Currently {@link MPImage} only support single storage type so the size of return list will
* always be 1. * always be 1.
* *
* @see ImageProperties * @see MPImageProperties
*/ */
public List<ImageProperties> getContainedImageProperties() { public List<MPImageProperties> getContainedImageProperties() {
return Collections.singletonList(getContainer().getImageProperties()); return Collections.singletonList(getContainer().getImageProperties());
} }
@ -124,7 +124,7 @@ public class Image implements Closeable {
return height; return height;
} }
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
private synchronized void acquire() { private synchronized void acquire() {
referenceCount += 1; referenceCount += 1;
} }
@ -132,7 +132,7 @@ public class Image implements Closeable {
/** /**
* Removes a reference that was previously acquired or init. * Removes a reference that was previously acquired or init.
* *
* <p>When {@link Image} is created, it has 1 reference count. * <p>When {@link MPImage} is created, it has 1 reference count.
* *
* <p>When the reference count becomes 0, it will release the resource under the hood. * <p>When the reference count becomes 0, it will release the resource under the hood.
*/ */
@ -141,24 +141,24 @@ public class Image implements Closeable {
public synchronized void close() { public synchronized void close() {
referenceCount -= 1; referenceCount -= 1;
if (referenceCount == 0) { if (referenceCount == 0) {
for (ImageContainer imageContainer : containerMap.values()) { for (MPImageContainer imageContainer : containerMap.values()) {
imageContainer.close(); imageContainer.close();
} }
} }
} }
/** Advanced API access for {@link Image}. */ /** Advanced API access for {@link MPImage}. */
static final class Internal { static final class Internal {
/** /**
* Acquires a reference on this {@link Image}. This will increase the reference count by 1. * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
* *
* <p>This method is more useful for image consumer to acquire a reference so image resource * <p>This method is more useful for image consumer to acquire a reference so image resource
* will not be closed accidentally. As image creator, normal developer doesn't need to call this * will not be closed accidentally. As image creator, normal developer doesn't need to call this
* method. * method.
* *
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link * <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link Image} anymore. * #close()} to indicate it doesn't need this {@link MPImage} anymore.
* *
* @see #close() * @see #close()
*/ */
@ -166,10 +166,10 @@ public class Image implements Closeable {
image.acquire(); image.acquire();
} }
private final Image image; private final MPImage image;
// Only Image creates the internal helper. // Only MPImage creates the internal helper.
private Internal(Image image) { private Internal(MPImage image) {
this.image = image; this.image = image;
} }
} }
@ -179,15 +179,15 @@ public class Image implements Closeable {
return new Internal(this); return new Internal(this);
} }
private final Map<ImageProperties, ImageContainer> containerMap; private final Map<MPImageProperties, MPImageContainer> containerMap;
private final long timestamp; private final long timestamp;
private final int width; private final int width;
private final int height; private final int height;
private int referenceCount; private int referenceCount;
/** Constructs an {@link Image} with a built container. */ /** Constructs a {@link MPImage} with a built container. */
Image(ImageContainer container, long timestamp, int width, int height) { MPImage(MPImageContainer container, long timestamp, int width, int height) {
this.containerMap = new HashMap<>(); this.containerMap = new HashMap<>();
containerMap.put(container.getImageProperties(), container); containerMap.put(container.getImageProperties(), container);
this.timestamp = timestamp; this.timestamp = timestamp;
@ -201,10 +201,10 @@ public class Image implements Closeable {
* *
* @return the current container. * @return the current container.
*/ */
ImageContainer getContainer() { MPImageContainer getContainer() {
// According to the design, in the future we will support multiple containers in one image. // According to the design, in the future we will support multiple containers in one image.
// Currently just return the original container. // Currently just return the original container.
// TODO: Cache multiple containers in Image. // TODO: Cache multiple containers in MPImage.
return containerMap.values().iterator().next(); return containerMap.values().iterator().next();
} }
@ -214,8 +214,8 @@ public class Image implements Closeable {
* <p>If there are multiple containers with required {@code storageType}, returns the first one. * <p>If there are multiple containers with required {@code storageType}, returns the first one.
*/ */
@Nullable @Nullable
ImageContainer getContainer(@StorageType int storageType) { MPImageContainer getContainer(@StorageType int storageType) {
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) { for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
if (entry.getKey().getStorageType() == storageType) { if (entry.getKey().getStorageType() == storageType) {
return entry.getValue(); return entry.getValue();
} }
@ -225,13 +225,13 @@ public class Image implements Closeable {
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
@Nullable @Nullable
ImageContainer getContainer(ImageProperties imageProperties) { MPImageContainer getContainer(MPImageProperties imageProperties) {
return containerMap.get(imageProperties); return containerMap.get(imageProperties);
} }
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
boolean addContainer(ImageContainer container) { boolean addContainer(MPImageContainer container) {
ImageProperties imageProperties = container.getImageProperties(); MPImageProperties imageProperties = container.getImageProperties();
if (containerMap.containsKey(imageProperties)) { if (containerMap.containsKey(imageProperties)) {
return false; return false;
} }

View File

@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that can receive {@link Image} */ /** Lightweight abstraction for an object that can receive {@link MPImage} */
public interface ImageConsumer { public interface MPImageConsumer {
/** /**
* Called when an {@link Image} is available. * Called when a {@link MPImage} is available.
* *
* <p>The argument is only guaranteed to be available until this method returns. if you need to * <p>The argument is only guaranteed to be available until this method returns. if you need to
* extend its life time, acquire it, then release it when done. * extend its life time, acquire it, then release it when done.
*/ */
void onNewImage(Image image); void onNewMPImage(MPImage image);
} }

View File

@ -16,9 +16,9 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Manages internal image data storage. The interface is package-private. */ /** Manages internal image data storage. The interface is package-private. */
interface ImageContainer { interface MPImageContainer {
/** Returns the properties of the contained image. */ /** Returns the properties of the contained image. */
ImageProperties getImageProperties(); MPImageProperties getImageProperties();
/** Close the image container and releases the image resource inside. */ /** Close the image container and releases the image resource inside. */
void close(); void close();

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that produce {@link Image} */ /** Lightweight abstraction for an object that produce {@link MPImage} */
public interface ImageProducer { public interface MPImageProducer {
/** Sets the consumer that receives the {@link Image}. */ /** Sets the consumer that receives the {@link MPImage}. */
void setImageConsumer(ImageConsumer imageConsumer); void setMPImageConsumer(MPImageConsumer imageConsumer);
} }

View File

@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.auto.value.extension.memoized.Memoized; import com.google.auto.value.extension.memoized.Memoized;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import com.google.mediapipe.framework.image.Image.StorageType; import com.google.mediapipe.framework.image.MPImage.StorageType;
/** Groups a set of properties to describe how an image is stored. */ /** Groups a set of properties to describe how an image is stored. */
@AutoValue @AutoValue
public abstract class ImageProperties { public abstract class MPImageProperties {
/** /**
* Gets the pixel format of the image. * Gets the pixel format of the image.
* *
* @see Image.ImageFormat * @see MPImage.MPImageFormat
*/ */
@ImageFormat @MPImageFormat
public abstract int getImageFormat(); public abstract int getImageFormat();
/** /**
* Gets the storage type of the image. * Gets the storage type of the image.
* *
* @see Image.StorageType * @see MPImage.StorageType
*/ */
@StorageType @StorageType
public abstract int getStorageType(); public abstract int getStorageType();
@ -45,36 +45,36 @@ public abstract class ImageProperties {
public abstract int hashCode(); public abstract int hashCode();
/** /**
* Creates a builder of {@link ImageProperties}. * Creates a builder of {@link MPImageProperties}.
* *
* @see ImageProperties.Builder * @see MPImageProperties.Builder
*/ */
static Builder builder() { static Builder builder() {
return new AutoValue_ImageProperties.Builder(); return new AutoValue_MPImageProperties.Builder();
} }
/** Builds a {@link ImageProperties}. */ /** Builds a {@link MPImageProperties}. */
@AutoValue.Builder @AutoValue.Builder
abstract static class Builder { abstract static class Builder {
/** /**
* Sets the {@link Image.ImageFormat}. * Sets the {@link MPImage.MPImageFormat}.
* *
* @see ImageProperties#getImageFormat * @see MPImageProperties#getImageFormat
*/ */
abstract Builder setImageFormat(@ImageFormat int value); abstract Builder setImageFormat(@MPImageFormat int value);
/** /**
* Sets the {@link Image.StorageType}. * Sets the {@link MPImage.StorageType}.
* *
* @see ImageProperties#getStorageType * @see MPImageProperties#getStorageType
*/ */
abstract Builder setStorageType(@StorageType int value); abstract Builder setStorageType(@StorageType int value);
/** Builds the {@link ImageProperties}. */ /** Builds the {@link MPImageProperties}. */
abstract ImageProperties build(); abstract MPImageProperties build();
} }
// Hide the constructor. // Hide the constructor.
ImageProperties() {} MPImageProperties() {}
} }

View File

@ -15,11 +15,12 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
/** /**
* Builds {@link Image} from {@link android.media.Image}. * Builds {@link MPImage} from {@link android.media.Image}.
* *
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
* content in it. * content in it.
@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
public class MediaImageBuilder { public class MediaImageBuilder {
// Mandatory fields. // Mandatory fields.
private final android.media.Image mediaImage; private final Image mediaImage;
// Optional fields. // Optional fields.
private long timestamp; private long timestamp;
@ -40,20 +41,20 @@ public class MediaImageBuilder {
* *
* @param mediaImage image data object. * @param mediaImage image data object.
*/ */
public MediaImageBuilder(android.media.Image mediaImage) { public MediaImageBuilder(Image mediaImage) {
this.mediaImage = mediaImage; this.mediaImage = mediaImage;
this.timestamp = 0; this.timestamp = 0;
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
MediaImageBuilder setTimestamp(long timestamp) { MediaImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image( return new MPImage(
new MediaImageContainer(mediaImage), new MediaImageContainer(mediaImage),
timestamp, timestamp,
mediaImage.getWidth(), mediaImage.getWidth(),

View File

@ -15,33 +15,34 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build; import android.os.Build;
import android.os.Build.VERSION; import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
@RequiresApi(VERSION_CODES.KITKAT) @RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer { class MediaImageContainer implements MPImageContainer {
private final android.media.Image mediaImage; private final Image mediaImage;
private final ImageProperties properties; private final MPImageProperties properties;
public MediaImageContainer(android.media.Image mediaImage) { public MediaImageContainer(Image mediaImage) {
this.mediaImage = mediaImage; this.mediaImage = mediaImage;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
.setImageFormat(convertFormatCode(mediaImage.getFormat())) .setImageFormat(convertFormatCode(mediaImage.getFormat()))
.build(); .build();
} }
public android.media.Image getImage() { public Image getImage() {
return mediaImage; return mediaImage;
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
mediaImage.close(); mediaImage.close();
} }
@ImageFormat @MPImageFormat
static int convertFormatCode(int graphicsFormat) { static int convertFormatCode(int graphicsFormat) {
// We only cover the format mentioned in // We only cover the format mentioned in
// https://developer.android.com/reference/android/media/Image#getFormat() // https://developer.android.com/reference/android/media/Image#getFormat()
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
return Image.IMAGE_FORMAT_RGB; return MPImage.IMAGE_FORMAT_RGB;
} }
} }
switch (graphicsFormat) { switch (graphicsFormat) {
case android.graphics.ImageFormat.JPEG: case android.graphics.ImageFormat.JPEG:
return Image.IMAGE_FORMAT_JPEG; return MPImage.IMAGE_FORMAT_JPEG;
case android.graphics.ImageFormat.YUV_420_888: case android.graphics.ImageFormat.YUV_420_888:
return Image.IMAGE_FORMAT_YUV_420_888; return MPImage.IMAGE_FORMAT_YUV_420_888;
default: default:
return Image.IMAGE_FORMAT_UNKNOWN; return MPImage.IMAGE_FORMAT_UNKNOWN;
} }
} }
} }

View File

@ -15,13 +15,14 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
/** /**
* Utility for extracting {@link android.media.Image} from {@link Image}. * Utility for extracting {@link android.media.Image} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
* otherwise {@link IllegalArgumentException} will be thrown. * otherwise {@link IllegalArgumentException} will be thrown.
*/ */
@RequiresApi(VERSION_CODES.KITKAT) @RequiresApi(VERSION_CODES.KITKAT)
@ -30,20 +31,20 @@ public class MediaImageExtractor {
private MediaImageExtractor() {} private MediaImageExtractor() {}
/** /**
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
* {@link Image} that built from {@link MediaImageBuilder}. * {@link MPImage} that built from {@link MediaImageBuilder}.
* *
* @param image the image to extract {@link android.media.Image} from. * @param image the image to extract {@link android.media.Image} from.
* @return {@link android.media.Image} that stored in {@link Image}. * @return {@link android.media.Image} that stored in {@link MPImage}.
* @throws IllegalArgumentException if the extraction failed. * @throws IllegalArgumentException if the extraction failed.
*/ */
public static android.media.Image extract(Image image) { public static Image extract(MPImage image) {
ImageContainer container; MPImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
return ((MediaImageContainer) container).getImage(); return ((MediaImageContainer) container).getImage();
} }
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract Media Image from an Image created by objects other than Media Image" "Extract Media Image from a MPImage created by objects other than Media Image"
+ " is not supported"); + " is not supported");
} }
} }

View File

@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors. # Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -209,9 +209,9 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
def mediapipe_build_aar_with_jni(name, android_library): def mediapipe_build_aar_with_jni(name, android_library):
"""Builds MediaPipe AAR with jni. """Builds MediaPipe AAR with jni.
Args: Args:
name: The bazel target name. name: The bazel target name.
android_library: the android library that contains jni. android_library: the android library that contains jni.
""" """
# Generates dummy AndroidManifest.xml for dummy apk usage # Generates dummy AndroidManifest.xml for dummy apk usage
@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""):
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java", src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite", target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite", target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""):
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite", target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
)) ))
return proto_src_list return proto_src_list

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
package( package(
default_visibility = ["//mediapipe:__subpackages__"], default_visibility = ["//mediapipe:__subpackages__"],

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"]) licenses(["notice"])
@ -23,15 +24,12 @@ package(
py_library( py_library(
name = "data_util", name = "data_util",
srcs = ["data_util.py"], srcs = ["data_util.py"],
srcs_version = "PY3",
) )
py_test( py_test(
name = "data_util_test", name = "data_util_test",
srcs = ["data_util_test.py"], srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"], data = ["//mediapipe/model_maker/python/core/data/testdata"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":data_util"], deps = [":data_util"],
) )
@ -44,8 +42,6 @@ py_library(
py_test( py_test(
name = "dataset_test", name = "dataset_test",
srcs = ["dataset_test.py"], srcs = ["dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":dataset", ":dataset",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",
@ -55,14 +51,11 @@ py_test(
py_library( py_library(
name = "classification_dataset", name = "classification_dataset",
srcs = ["classification_dataset.py"], srcs = ["classification_dataset.py"],
srcs_version = "PY3",
deps = [":dataset"], deps = [":dataset"],
) )
py_test( py_test(
name = "classification_dataset_test", name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"], srcs = ["classification_dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":classification_dataset"], deps = [":classification_dataset"],
) )

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Common classification dataset library.""" """Common classification dataset library."""
from typing import Any, Tuple from typing import List, Tuple
import tensorflow as tf import tensorflow as tf
@ -21,15 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset): class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models.""" """Dataset Loader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any): def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str]):
super().__init__(dataset, size) super().__init__(dataset, size)
self.index_to_label = index_to_label self._label_names = label_names
@property @property
def num_classes(self: ds._DatasetT) -> int: def num_classes(self: ds._DatasetT) -> int:
return len(self.index_to_label) return len(self._label_names)
@property
def label_names(self: ds._DatasetT) -> List[str]:
return self._label_names
def split(self: ds._DatasetT, def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]: fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@ -44,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
Returns: Returns:
The splitted two sub datasets. The splitted two sub datasets.
""" """
return self._split(fraction, self.index_to_label) return self._split(fraction, self._label_names)

View File

@ -12,45 +12,55 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List, Tuple, TypeVar
# Dependency imports # Dependency imports
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.core.data import classification_dataset
_DatasetT = TypeVar(
'_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
class ClassificationDataLoaderTest(tf.test.TestCase):
class ClassificationDatasetTest(tf.test.TestCase):
def test_split(self): def test_split(self):
class MagicClassificationDataLoader( class MagicClassificationDataset(
classification_dataset.ClassificationDataset): classification_dataset.ClassificationDataset):
"""A mock classification dataset class for testing purpose.
def __init__(self, dataset, size, index_to_label, value): Attributes:
super(MagicClassificationDataLoader, value: A value variable stored by the mock dataset class for testing.
self).__init__(dataset, size, index_to_label) """
def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str], value: Any):
super().__init__(dataset=dataset, size=size, label_names=label_names)
self.value = value self.value = value
def split(self, fraction): def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.index_to_label, self.value) return self._split(fraction, self.label_names, self.value)
# Some dummy inputs. # Some dummy inputs.
magic_value = 42 magic_value = 42
num_classes = 2 num_classes = 2
index_to_label = (False, True) label_names = ['foo', 'bar']
# Create data loader from sample data. # Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataLoader(ds, len(ds), index_to_label, data = MagicClassificationDataset(
magic_value) dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
# Train/Test data split. # Train/Test data split.
fraction = .25 fraction = .25
train_data, test_data = data.split(fraction) train_data, test_data = data.split(fraction=fraction)
# `split` should return instances of child DataLoader. # `split` should return instances of child DataLoader.
self.assertIsInstance(train_data, MagicClassificationDataLoader) self.assertIsInstance(train_data, MagicClassificationDataset)
self.assertIsInstance(test_data, MagicClassificationDataLoader) self.assertIsInstance(test_data, MagicClassificationDataset)
# Make sure number of entries are right. # Make sure number of entries are right.
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data)) self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
@ -59,7 +69,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase):
# Make sure attributes propagated correctly. # Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes) self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_to_label, index_to_label) self.assertEqual(test_data.label_names, label_names)
self.assertEqual(train_data.value, magic_value) self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value) self.assertEqual(test_data.value, magic_value)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
package( package(
default_visibility = ["//mediapipe:__subpackages__"], default_visibility = ["//mediapipe:__subpackages__"],
@ -23,7 +24,6 @@ licenses(["notice"])
py_library( py_library(
name = "custom_model", name = "custom_model",
srcs = ["custom_model.py"], srcs = ["custom_model.py"],
srcs_version = "PY3",
deps = [ deps = [
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
@ -34,8 +34,6 @@ py_library(
py_test( py_test(
name = "custom_model_test", name = "custom_model_test",
srcs = ["custom_model_test.py"], srcs = ["custom_model_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":custom_model", ":custom_model",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",
@ -45,7 +43,6 @@ py_test(
py_library( py_library(
name = "classifier", name = "classifier",
srcs = ["classifier.py"], srcs = ["classifier.py"],
srcs_version = "PY3",
deps = [ deps = [
":custom_model", ":custom_model",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
@ -55,8 +52,6 @@ py_library(
py_test( py_test(
name = "classifier_test", name = "classifier_test",
srcs = ["classifier_test.py"], srcs = ["classifier_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":classifier", ":classifier",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",

View File

@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel): class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier.""" """An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool, def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
full_train: bool): full_train: bool):
"""Initilizes a classifier with its specifications. """Initilizes a classifier with its specifications.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
index_to_label: A list that map from index to label class name. label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled. shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top and the classification layers on top. Otherwise, only train the top
classification layers. classification layers.
""" """
super(Classifier, self).__init__(model_spec, shuffle) super(Classifier, self).__init__(model_spec, shuffle)
self._index_to_label = index_to_label self._label_names = label_names
self._full_train = full_train self._full_train = full_train
self._num_classes = len(index_to_label) self._num_classes = len(label_names)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset. """Evaluates the classifier with the provided evaluation dataset.
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
label_filepath = os.path.join(export_dir, label_filename) label_filepath = os.path.join(export_dir, label_filename)
tf.compat.v1.logging.info('Saving labels in %s', label_filepath) tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f: with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self._index_to_label)) f.write('\n'.join(self._label_names))

View File

@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(ClassifierTest, self).setUp() super(ClassifierTest, self).setUp()
index_to_label = ['cat', 'dog'] label_names = ['cat', 'dog']
self.model = MockClassifier( self.model = MockClassifier(
model_spec=None, model_spec=None,
index_to_label=index_to_label, label_names=label_names,
shuffle=False, shuffle=False,
full_train=False) full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2) self.model.model = test_util.build_model(input_shape=[4], num_classes=2)

View File

@ -21,8 +21,6 @@ import abc
import os import os
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
# Dependency imports
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.data import dataset
@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
tflite_filepath = os.path.join(export_dir, tflite_filename) tflite_filepath = os.path.join(export_dir, tflite_filename)
# TODO: Populate metadata to the exported TFLite model. # TODO: Populate metadata to the exported TFLite model.
model_util.export_tflite( model_util.export_tflite(
self._model, model=self._model,
tflite_filepath, tflite_filepath=tflite_filepath,
quantization_config, quantization_config=quantization_config,
preprocess=preprocess) preprocess=preprocess)
tf.compat.v1.logging.info( tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully: %s' % tflite_filepath) 'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

View File

@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(CustomModelTest, self).setUp() super(CustomModelTest, self).setUp()
self.model = MockCustomModel(model_spec=None, shuffle=False) self._model = MockCustomModel(model_spec=None, shuffle=False)
self.model._model = test_util.build_model(input_shape=[4], num_classes=2) self._model._model = test_util.build_model(input_shape=[4], num_classes=2)
def _check_nonempty_file(self, filepath): def _check_nonempty_file(self, filepath):
self.assertTrue(os.path.isfile(filepath)) self.assertTrue(os.path.isfile(filepath))
@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
def test_export_tflite(self): def test_export_tflite(self):
export_path = os.path.join(self.get_temp_dir(), 'export/') export_path = os.path.join(self.get_temp_dir(), 'export/')
self.model.export_tflite(export_dir=export_path) self._model.export_tflite(export_dir=export_path)
self._check_nonempty_file(os.path.join(export_path, 'model.tflite')) self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"]) licenses(["notice"])
@ -24,31 +25,15 @@ py_library(
name = "test_util", name = "test_util",
testonly = 1, testonly = 1,
srcs = ["test_util.py"], srcs = ["test_util.py"],
srcs_version = "PY3",
deps = [ deps = [
":model_util", ":model_util",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
], ],
) )
py_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.py"],
srcs_version = "PY3",
)
py_test(
name = "image_preprocessing_test",
srcs = ["image_preprocessing_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":image_preprocessing"],
)
py_library( py_library(
name = "model_util", name = "model_util",
srcs = ["model_util.py"], srcs = ["model_util.py"],
srcs_version = "PY3",
deps = [ deps = [
":quantization", ":quantization",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
@ -58,8 +43,6 @@ py_library(
py_test( py_test(
name = "model_util_test", name = "model_util_test",
srcs = ["model_util_test.py"], srcs = ["model_util_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":model_util", ":model_util",
":quantization", ":quantization",
@ -76,8 +59,6 @@ py_library(
py_test( py_test(
name = "loss_functions_test", name = "loss_functions_test",
srcs = ["loss_functions_test.py"], srcs = ["loss_functions_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":loss_functions"], deps = [":loss_functions"],
) )
@ -91,8 +72,6 @@ py_library(
py_test( py_test(
name = "quantization_test", name = "quantization_test",
srcs = ["quantization_test.py"], srcs = ["quantization_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":quantization", ":quantization",
":test_util", ":test_util",

View File

@ -56,7 +56,7 @@ class FocalLoss(tf.keras.losses.Loss):
class_weight: A weight to apply to the loss, one for each class. The class_weight: A weight to apply to the loss, one for each class. The
weight is applied for each input where the ground truth label matches. weight is applied for each input where the ground truth label matches.
""" """
super(tf.keras.losses.Loss, self).__init__() super().__init__()
# Used for clipping min/max values of probability values in y_pred to avoid # Used for clipping min/max values of probability values in y_pred to avoid
# NaNs and Infs in computation. # NaNs and Infs in computation.
self._epsilon = 1e-7 self._epsilon = 1e-7

View File

@ -104,8 +104,8 @@ def export_tflite(
quantization_config: Configuration for post-training quantization. quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file. supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, quantization. The callable takes three arguments in order: feature, label,
label, and is_training. and is_training.
""" """
if tflite_filepath is None: if tflite_filepath is None:
raise ValueError( raise ValueError(

View File

@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
model = test_util.build_model(input_shape=[input_dim], num_classes=2) model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.export_tflite(model, tflite_file) model_util.export_tflite(model, tflite_file)
self._test_tflite(model, tflite_file, input_dim) test_util.test_tflite(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(
@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
input_dim = 16 input_dim = 16
num_classes = 2 num_classes = 2
max_input_value = 5 max_input_value = 5
model = test_util.build_model([input_dim], num_classes) model = test_util.build_model(
input_shape=[input_dim], num_classes=num_classes)
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite') tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
model_util.export_tflite(model, tflite_file, config) model_util.export_tflite(
self._test_tflite( model=model, tflite_filepath=tflite_file, quantization_config=config)
model, tflite_file, input_dim, max_input_value, atol=1e-00)
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
def _test_tflite(self,
keras_model: tf.keras.Model,
tflite_model_file: str,
input_dim: int,
max_input_value: int = 1000,
atol: float = 1e-04):
random_input = test_util.create_random_sample(
size=[1, input_dim], high=max_input_value)
random_input = tf.convert_to_tensor(random_input)
self.assertTrue( self.assertTrue(
test_util.is_same_output( test_util.test_tflite(
tflite_model_file, keras_model, random_input, atol=atol)) keras_model=model,
tflite_file=tflite_file,
size=[1, input_dim],
high=max_input_value,
atol=1e-00))
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
keras_output = keras_model.predict_on_batch(input_tensors) keras_output = keras_model.predict_on_batch(input_tensors)
return np.allclose(lite_output, keras_output, atol=atol) return np.allclose(lite_output, keras_output, atol=atol)
def test_tflite(keras_model: tf.keras.Model,
tflite_file: str,
size: Union[int, List[int]],
high: float = 1,
atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical.
Args:
keras_model: Input TensorFlow Keras model.
tflite_file: Input TFLite model file.
size: Size of the input tesnor.
high: Higher boundary of the values in input tensors.
atol: Absolute tolerance of the difference between the outputs of Keras
model and TFLite model.
Returns:
True if the output of TFLite model and TF Keras model are identical.
Otherwise, False.
"""
random_input = create_random_sample(size=size, high=high)
random_input = tf.convert_to_tensor(random_input)
return is_same_output(
tflite_file=tflite_file,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)

View File

@ -0,0 +1,4 @@
# MediaPipe Model Maker Internal Library
This directory contains model maker library for internal users and experimental
purposes.

View File

@ -0,0 +1 @@
"""Model maker internal library."""

View File

@ -0,0 +1,33 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.py"],
)
py_test(
name = "image_preprocessing_test",
srcs = ["image_preprocessing_test.py"],
deps = [":image_preprocessing"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -13,11 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""ImageNet preprocessing.""" """ImageNet preprocessing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf import tensorflow as tf
IMAGE_SIZE = 224 IMAGE_SIZE = 224

View File

@ -12,15 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import image_preprocessing from mediapipe.model_maker.python.vision.core import image_preprocessing
def _get_preprocessed_image(preprocessor, is_training=False): def _get_preprocessed_image(preprocessor, is_training=False):

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python library rule.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python library rule.
licenses(["notice"]) licenses(["notice"])
@ -78,9 +78,9 @@ py_library(
":train_image_classifier_lib", ":train_image_classifier_lib",
"//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:image_preprocessing",
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
], ],
) )

View File

@ -16,7 +16,7 @@
import os import os
import random import random
from typing import List, Optional, Tuple from typing import List, Optional
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
@ -84,10 +84,10 @@ class Dataset(classification_dataset.ClassificationDataset):
name for name in os.listdir(data_root) name for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))) if os.path.isdir(os.path.join(data_root, name)))
all_label_size = len(label_names) all_label_size = len(label_names)
label_to_index = dict( index_by_label = dict(
(name, index) for index, name in enumerate(label_names)) (name, index) for index, name in enumerate(label_names))
all_image_labels = [ all_image_labels = [
label_to_index[os.path.basename(os.path.dirname(path))] index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths for path in all_image_paths
] ]
@ -106,33 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names)) all_label_size, ', '.join(label_names))
return Dataset( return Dataset(
dataset=image_label_ds, size=all_image_size, index_to_label=label_names) dataset=image_label_ds, size=all_image_size, label_names=label_names)
@classmethod
def load_tf_dataset(
cls, name: str
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset]]:
"""Loads data from tensorflow_datasets.
Args:
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
documentation of tfds.load for more details.
Returns:
A tuple of Datasets for the train/validation/test.
Raises:
ValueError: if the input tf dataset does not have train/validation/test
labels.
"""
data, info = tfds.load(name, with_info=True)
if 'label' not in info.features:
raise ValueError('info.features need to contain \'label\' key.')
label_names = info.features['label'].names
train_data = _create_data('train', data, info, label_names)
validation_data = _create_data('validation', data, info, label_names)
test_data = _create_data('test', data, info, label_names)
return train_data, validation_data, test_data

View File

@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
def test_split(self): def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(ds, 4, ['pos', 'neg']) data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
train_data, test_data = data.split(0.5) train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2) self.assertLen(train_data, 2)
for i, elem in enumerate(train_data._dataset): for i, elem in enumerate(train_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 1])).all()) self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertEqual(train_data.num_classes, 2) self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.index_to_label, ['pos', 'neg']) self.assertEqual(train_data.label_names, ['pos', 'neg'])
self.assertLen(test_data, 2) self.assertLen(test_data, 2)
for i, elem in enumerate(test_data._dataset): for i, elem in enumerate(test_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 0])).all()) self.assertTrue((elem.numpy() == np.array([i, 0])).all())
self.assertEqual(test_data.num_classes, 2) self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.index_to_label, ['pos', 'neg']) self.assertEqual(test_data.label_names, ['pos', 'neg'])
def test_from_folder(self): def test_from_folder(self):
data = dataset.Dataset.from_folder(self.image_path) data = dataset.Dataset.from_folder(dirname=self.image_path)
self.assertLen(data, 2) self.assertLen(data, 2)
self.assertEqual(data.num_classes, 2) self.assertEqual(data.num_classes, 2)
self.assertEqual(data.index_to_label, ['daisy', 'tulips']) self.assertEqual(data.label_names, ['daisy', 'tulips'])
for image, label in data.gen_tf_dataset(): for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0) self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0: if label.numpy() == 0:
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(train_data, 1034) self.assertLen(train_data, 1034)
self.assertEqual(train_data.num_classes, 3) self.assertEqual(train_data.num_classes, 3)
self.assertEqual(train_data.index_to_label, self.assertEqual(train_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(validation_data, 133) self.assertLen(validation_data, 133)
self.assertEqual(validation_data.num_classes, 3) self.assertEqual(validation_data.num_classes, 3)
self.assertEqual(validation_data.index_to_label, self.assertEqual(validation_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(test_data, 128) self.assertLen(test_data, 128)
self.assertEqual(test_data.num_classes, 3) self.assertEqual(test_data.num_classes, 3)
self.assertEqual(test_data.index_to_label, self.assertEqual(test_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])

View File

@ -13,16 +13,16 @@
# limitations under the License. # limitations under the License.
"""APIs to train image classifier model.""" """APIs to train image classifier model."""
from typing import Any, List, Optional from typing import List, Optional
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier): class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model.""" """ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any], def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
hparams: hp.HParams): hparams: hp.HParams):
"""Initializes ImageClassifier class. """Initializes ImageClassifier class.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name. label_names: A list of label names for the classes.
hparams: The hyperparameters for training image classifier. hparams: The hyperparameters for training image classifier.
""" """
super(ImageClassifier, self).__init__( super().__init__(
model_spec=model_spec, model_spec=model_spec,
index_to_label=index_to_label, label_names=label_names,
shuffle=hparams.shuffle, shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning) full_train=hparams.do_fine_tuning)
self._hparams = hparams self._hparams = hparams
@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
spec = ms.SupportedModels.get(model_spec) spec = ms.SupportedModels.get(model_spec)
image_classifier = cls( image_classifier = cls(
model_spec=spec, model_spec=spec, label_names=train_data.label_names, hparams=hparams)
index_to_label=train_data.index_to_label,
hparams=hparams)
image_classifier._create_model() image_classifier._create_model()

View File

@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
return model.fit( return model.fit(
x=train_ds, x=train_ds,
epochs=hparams.train_epochs, epochs=hparams.train_epochs,
steps_per_epoch=hparams.steps_per_epoch,
validation_data=validation_ds, validation_data=validation_ds,
callbacks=callbacks) callbacks=callbacks)

View File

@ -161,7 +161,7 @@ class Texture {
~Texture() { ~Texture() {
if (is_owned_) { if (is_owned_) {
glDeleteProgram(handle_); glDeleteTextures(1, &handle_);
} }
} }

View File

@ -87,6 +87,9 @@ cc_library(
cc_library( cc_library(
name = "builtin_task_graphs", name = "builtin_task_graphs",
deps = [ deps = [
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
], ],

View File

@ -14,7 +14,7 @@
"""The public facing packet getter APIs.""" """The public facing packet getter APIs."""
from typing import List, Type from typing import List
from google.protobuf import message from google.protobuf import message
from google.protobuf import symbol_database from google.protobuf import symbol_database
@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
get_matrix = _packet_getter.get_matrix get_matrix = _packet_getter.get_matrix
def get_proto(packet: mp_packet.Packet) -> Type[message.Message]: def get_proto(packet: mp_packet.Packet) -> message.Message:
"""Get the content of a MediaPipe proto Packet as a proto message. """Get the content of a MediaPipe proto Packet as a proto message.
Args: Args:

View File

@ -46,8 +46,10 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",

View File

@ -44,6 +44,30 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "classification_aggregation_calculator_test",
srcs = ["classification_aggregation_calculator_test.cc"],
deps = [
":classification_aggregation_calculator",
":classification_aggregation_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:output_stream_poller",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "score_calibration_calculator_proto", name = "score_calibration_calculator_proto",
srcs = ["score_calibration_calculator.proto"], srcs = ["score_calibration_calculator.proto"],

View File

@ -31,37 +31,62 @@
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications; using ::mediapipe::tasks::components::containers::proto::Classifications;
// Aggregates ClassificationLists into a single ClassificationResult that has // Aggregates ClassificationLists into either a ClassificationResult object
// 3 dimensions: (classification head, classification timestamp, classification // representing the classification results aggregated by classifier head, or
// category). // into an std::vector<ClassificationResult> representing the classification
// results aggregated first by timestamp then by classifier head.
// //
// Inputs: // Inputs:
// CLASSIFICATIONS - ClassificationList // CLASSIFICATIONS - ClassificationList @Multiple
// ClassificationList per classification head. // ClassificationList per classification head.
// TIMESTAMPS - std::vector<Timestamp> @Optional // TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that a single ClassificationResult // The collection of the timestamps that this calculator should aggregate.
// should aggragate. This stream is optional, and the timestamp information // This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
// will only be populated to the ClassificationResult proto when this stream // output is used for results. Otherwise as no timestamp aggregation is
// is connected. // required the CLASSIFICATIONS output is used for results.
// //
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result. // The aggregated classification result.
// //
// Example: // Example without timestamp aggregation:
// node {
// calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c"
// output_stream: "CLASSIFICATIONS:classifications"
// options {
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a"
// head_names: "head_name_b"
// head_names: "head_name_c"
// }
// }
// }
//
// Example with timestamp aggregation:
// node { // node {
// calculator: "ClassificationAggregationCalculator" // calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a" // input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b" // input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c" // input_stream: "CLASSIFICATIONS:2:stream_c"
// input_stream: "TIMESTAMPS:timestamps" // input_stream: "TIMESTAMPS:timestamps"
// output_stream: "CLASSIFICATION_RESULT:classification_result" // output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
// options { // options {
// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] { // [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a" // head_names: "head_name_a"
// head_names: "head_name_b" // head_names: "head_name_b"
// head_names: "head_name_c" // head_names: "head_name_c"
@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
"CLASSIFICATIONS"}; "CLASSIFICATIONS"};
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{ static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
"TIMESTAMPS"}; "TIMESTAMPS"};
static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"}; static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut); "CLASSIFICATIONS"};
static constexpr Output<std::vector<ClassificationResult>>::Optional
kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
static constexpr Output<ClassificationResult>::Optional
kClassificationResultOut{"CLASSIFICATION_RESULT"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
kClassificationsOut, kTimestampedClassificationsOut,
kClassificationResultOut);
static absl::Status UpdateContract(CalculatorContract* cc); static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc); absl::Status Open(CalculatorContext* cc);
@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
cached_classifications_; cached_classifications_;
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
CalculatorContext* cc);
// TODO: deprecate this function once migration is over.
ClassificationResult LegacyConvertToClassificationResult(
CalculatorContext* cc);
}; };
absl::Status ClassificationAggregationCalculator::UpdateContract( absl::Status ClassificationAggregationCalculator::UpdateContract(
@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
<< "The size of classifications input streams should match the " << "The size of classifications input streams should match the "
"size of head names specified in the calculator options"; "size of head names specified in the calculator options";
} }
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
// not connected. All dependent tasks must be updated to use these outputs
// first.
return absl::OkStatus(); return absl::OkStatus();
} }
@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
[](const auto& elem) -> ClassificationList { return elem.Get(); }); [](const auto& elem) -> ClassificationList { return elem.Get(); });
cached_classifications_[cc->InputTimestamp().Value()] = cached_classifications_[cc->InputTimestamp().Value()] =
std::move(classification_lists); std::move(classification_lists);
if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) { ClassificationResult classification_result;
return absl::OkStatus(); if (time_aggregation_enabled_) {
if (kTimestampsIn(cc).IsEmpty()) {
return absl::OkStatus();
}
classification_result = LegacyConvertToClassificationResult(cc);
kTimestampedClassificationsOut(cc).Send(
ConvertToTimestampedClassificationResults(cc));
} else {
classification_result = LegacyConvertToClassificationResult(cc);
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
} }
kOut(cc).Send(ConvertToClassificationResult(cc)); kClassificationResultOut(cc).Send(classification_result);
RET_CHECK(cached_classifications_.empty()); RET_CHECK(cached_classifications_.empty());
return absl::OkStatus(); return absl::OkStatus();
} }
@ -136,6 +186,50 @@ ClassificationResult
ClassificationAggregationCalculator::ConvertToClassificationResult( ClassificationAggregationCalculator::ConvertToClassificationResult(
CalculatorContext* cc) { CalculatorContext* cc) {
ClassificationResult result; ClassificationResult result;
auto& classification_lists =
cached_classifications_[cc->InputTimestamp().Value()];
for (int i = 0; i < classification_lists.size(); ++i) {
auto classifications = result.add_classifications();
classifications->set_head_index(i);
if (!head_names_.empty()) {
classifications->set_head_name(head_names_[i]);
}
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
cached_classifications_.erase(cc->InputTimestamp().Value());
return result;
}
std::vector<ClassificationResult>
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
CalculatorContext* cc) {
auto timestamps = kTimestampsIn(cc).Get();
std::vector<ClassificationResult> results;
results.reserve(timestamps.size());
for (const auto& timestamp : timestamps) {
ClassificationResult result;
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000);
auto& classification_lists = cached_classifications_[timestamp.Value()];
for (int i = 0; i < classification_lists.size(); ++i) {
auto classifications = result.add_classifications();
classifications->set_head_index(i);
if (!head_names_.empty()) {
classifications->set_head_name(head_names_[i]);
}
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
cached_classifications_.erase(timestamp.Value());
results.push_back(std::move(result));
}
return results;
}
ClassificationResult
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
CalculatorContext* cc) {
ClassificationResult result;
Timestamp first_timestamp(0); Timestamp first_timestamp(0);
std::vector<Timestamp> timestamps; std::vector<Timestamp> timestamps;
if (time_aggregation_enabled_) { if (time_aggregation_enabled_) {
@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
1000); 1000);
} }
cached_classifications_.erase(timestamp.Value());
} }
return result; return result;
} }

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";

View File

@ -0,0 +1,213 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace {
using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::Pointwise;
constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
constexpr char kClassificationInput0Name[] = "classifications_0";
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
constexpr char kClassificationInput1Name[] = "classifications_1";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications";
ClassificationList MakeClassificationList(int class_index) {
return ParseTextProtoOrDie<ClassificationList>(absl::StrFormat(
R"pb(
classification { index: %d }
)pb",
class_index));
}
class ClassificationAggregationCalculatorTest
: public tflite_shims::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
bool connect_timestamps = false) {
Graph graph;
auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
calculator
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
ParseTextProtoOrDie<
mediapipe::ClassificationAggregationCalculatorOptions>(
R"pb(head_names: "foo" head_names: "bar")pb");
graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
kClassificationInput0Name) >>
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
kClassificationInput1Name) >>
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));
if (connect_timestamps) {
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
kTimestampsName) >>
calculator.In(kTimestampsTag);
calculator.Out(kTimestampedClassificationsTag)
.SetName(kTimestampedClassificationsName) >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
} else {
calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
graph[Output<ClassificationResult>(kClassificationsTag)];
}
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
if (connect_timestamps) {
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kTimestampedClassificationsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kClassificationsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
absl::Status Send(
std::vector<ClassificationList> classifications, int timestamp = 0,
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kClassificationInput0Name,
MakePacket<ClassificationList>(classifications[0])
.At(Timestamp(timestamp))));
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kClassificationInput1Name,
MakePacket<ClassificationList>(classifications[1])
.At(Timestamp(timestamp))));
if (aggregation_timestamps.has_value()) {
auto packet = std::make_unique<std::vector<Timestamp>>();
for (const auto& timestamp : *aggregation_timestamps) {
packet->emplace_back(Timestamp(timestamp));
}
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
}
return absl::OkStatus();
}
template <typename T>
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
Packet packet;
if (!poller.Next(&packet)) {
return absl::InternalError("Unable to get output packet");
}
auto result = packet.Get<T>();
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
return result;
}
private:
CalculatorGraph calculator_graph_;
};
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 0 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 1 } }
})pb")));
}
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK(Send(
{MakeClassificationList(2), MakeClassificationList(3)},
/*timestamp=*/1000,
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
MP_ASSERT_OK_AND_ASSIGN(auto result,
GetResult<std::vector<ClassificationResult>>(poller));
EXPECT_THAT(result,
Pointwise(EqualsProto(),
{ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0,
classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 0 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 1 } }
}
)pb"),
ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 1,
classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 2 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 3 } }
}
)pb")}));
}
} // namespace
} // namespace mediapipe

View File

@ -29,3 +29,23 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
], ],
) )
cc_library(
name = "category",
srcs = ["category.cc"],
hdrs = ["category.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
],
)
cc_library(
name = "classification_result",
srcs = ["classification_result.cc"],
hdrs = ["classification_result.h"],
deps = [
":category",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/category.h"
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
Category ConvertToCategory(const mediapipe::Classification& proto) {
Category category;
category.index = proto.index();
category.score = proto.score();
if (proto.has_label()) {
category.category_name = proto.label();
}
if (proto.has_display_name()) {
category.display_name = proto.display_name();
}
return category;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
// The index of the category in the classification model output.
int index;
// The score for this category, e.g. (but not necessarily) a probability in
// [0,1].
float score;
// The optional ID for the category, read from the label map packed in the
// TFLite Model Metadata if present. Not necessarily human-readable.
std::optional<std::string> category_name = std::nullopt;
// The optional human-readable name for the category, read from the label map
// packed in the TFLite Model Metadata if present.
std::optional<std::string> display_name = std::nullopt;
};
// Utility function to convert from mediapipe::Classification proto to Category
// struct.
Category ConvertToCategory(const mediapipe::Classification& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
Classifications ConvertToClassifications(const proto::Classifications& proto) {
Classifications classifications;
classifications.categories.reserve(
proto.classification_list().classification_size());
for (const auto& classification :
proto.classification_list().classification()) {
classifications.categories.push_back(ConvertToCategory(classification));
}
classifications.head_index = proto.head_index();
if (proto.has_head_name()) {
classifications.head_name = proto.head_name();
}
return classifications;
}
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto) {
ClassificationResult classification_result;
classification_result.classifications.reserve(proto.classifications_size());
for (const auto& classifications : proto.classifications()) {
classification_result.classifications.push_back(
ConvertToClassifications(classifications));
}
if (proto.has_timestamp_ms()) {
classification_result.timestamp_ms = proto.timestamp_ms();
}
return classification_result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,68 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines classification results for a given classifier head.
struct Classifications {
// The array of predicted categories, usually sorted by descending scores,
// e.g. from high to low probability.
std::vector<Category> categories;
// The index of the classifier head (i.e. output tensor) these categories
// refer to. This is useful for multi-head models.
int head_index;
// The optional name of the classifier head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
std::optional<std::string> head_name = std::nullopt;
};
// Defines classification results of a model.
struct ClassificationResult {
// The classification results for each head of the model.
std::vector<Classifications> classifications;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
std::optional<int64_t> timestamp_ms = std::nullopt;
};
// Utility function to convert from Classifications proto to
// Classifications struct.
Classifications ConvertToClassifications(const proto::Classifications& proto);
// Utility function to convert from ClassificationResult proto to
// ClassificationResult struct.
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

View File

@ -28,6 +28,7 @@ mediapipe_proto_library(
srcs = ["classifications.proto"], srcs = ["classifications.proto"],
deps = [ deps = [
":category_proto", ":category_proto",
"//mediapipe/framework/formats:classification_proto",
], ],
) )

View File

@ -17,9 +17,10 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto; package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.container.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto"; option java_outer_classname = "CategoryProto";
// TODO: deprecate this message once migration is over.
// A single classification result. // A single classification result.
message Category { message Category {
// The index of the category in the corresponding label map, usually packed in // The index of the category in the corresponding label map, usually packed in

View File

@ -17,11 +17,13 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto; package mediapipe.tasks.components.containers.proto;
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto"; import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.container.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto"; option java_outer_classname = "ClassificationsProto";
// TODO: deprecate this message once migration is over.
// List of predicted categories with an optional timestamp. // List of predicted categories with an optional timestamp.
message ClassificationEntry { message ClassificationEntry {
// The array of predicted categories, usually sorted by descending scores, // The array of predicted categories, usually sorted by descending scores,
@ -33,9 +35,12 @@ message ClassificationEntry {
optional int64 timestamp_ms = 2; optional int64 timestamp_ms = 2;
} }
// Classifications for a given classifier head. // Classifications for a given classifier head, i.e. for a given output tensor.
message Classifications { message Classifications {
// TODO: deprecate this field once migration is over.
repeated ClassificationEntry entries = 1; repeated ClassificationEntry entries = 1;
// The classification results for this head.
optional mediapipe.ClassificationList classification_list = 4;
// The index of the classifier head these categories refer to. This is useful // The index of the classifier head these categories refer to. This is useful
// for multi-head models. // for multi-head models.
optional int32 head_index = 2; optional int32 head_index = 2;
@ -45,7 +50,17 @@ message Classifications {
optional string head_name = 3; optional string head_name = 3;
} }
// Contains one set of results per classifier head. // Classifications for a given classifier model.
message ClassificationResult { message ClassificationResult {
// The classification results for each model head, i.e. one for each output
// tensor.
repeated Classifications classifications = 1; repeated Classifications classifications = 1;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
optional int64 timestamp_ms = 2;
} }

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto; package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "EmbeddingsProto";
// Defines a dense floating-point embedding. // Defines a dense floating-point embedding.
message FloatEmbedding { message FloatEmbedding {
repeated float values = 1 [packed = true]; repeated float values = 1 [packed = true];

View File

@ -30,9 +30,11 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
@ -128,12 +130,21 @@ absl::Status ConfigureImageToTensorCalculator(
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) / options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
std); std);
} }
// TODO: need to support different GPU origin on differnt
// platforms or applications.
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
return absl::OkStatus(); return absl::OkStatus();
} }
} // namespace } // namespace
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration) {
return acceleration.has_gpu();
}
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
bool use_gpu,
ImagePreprocessingOptions* options) { ImagePreprocessingOptions* options) {
ASSIGN_OR_RETURN(auto image_tensor_specs, ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildImageTensorSpecs(model_resources)); BuildImageTensorSpecs(model_resources));
@ -141,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
image_tensor_specs, options->mutable_image_to_tensor_options())); image_tensor_specs, options->mutable_image_to_tensor_options()));
// The GPU backend isn't able to process int data. If the input tensor is // The GPU backend isn't able to process int data. If the input tensor is
// quantized, forces the image preprocessing graph to use CPU backend. // quantized, forces the image preprocessing graph to use CPU backend.
if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) { if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
} else {
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -19,20 +19,26 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
// Configures an ImagePreprocessing subgraph using the provided model resources. // Configures an ImagePreprocessing subgraph using the provided model resources
// When use_gpu is true, use GPU as backend to convert image to tensor.
// - Accepts CPU input images and outputs CPU tensors. // - Accepts CPU input images and outputs CPU tensors.
// //
// Example usage: // Example usage:
// //
// auto& preprocessing = // auto& preprocessing =
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); // graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
// core::proto::Acceleration acceleration;
// acceleration.mutable_xnnpack();
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( // MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
// model_resources, // model_resources,
// use_gpu,
// &preprocessing.GetOptions<ImagePreprocessingOptions>())); // &preprocessing.GetOptions<ImagePreprocessingOptions>()));
// //
// The resulting ImagePreprocessing subgraph has the following I/O: // The resulting ImagePreprocessing subgraph has the following I/O:
@ -56,9 +62,14 @@ namespace components {
// The image that has the pixel data stored on the target storage (CPU vs // The image that has the pixel data stored on the target storage (CPU vs
// GPU). // GPU).
absl::Status ConfigureImagePreprocessing( absl::Status ConfigureImagePreprocessing(
const core::ModelResources& model_resources, const core::ModelResources& model_resources, bool use_gpu,
ImagePreprocessingOptions* options); ImagePreprocessingOptions* options);
// Determine if the image preprocessing subgraph should use GPU as the backend
// according to the given acceleration setting.
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration);
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES"; constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
// Struct holding the different output streams produced by the graph.
struct ClassificationPostprocessingOutputStreams {
Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
// Performs sanity checks on provided ClassifierOptions. // Performs sanity checks on provided ClassifierOptions.
absl::Status SanityCheckClassifierOptions( absl::Status SanityCheckClassifierOptions(
@ -286,7 +294,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
void ConfigureClassificationAggregationCalculator( void ConfigureClassificationAggregationCalculator(
const ModelMetadataExtractor& metadata_extractor, const ModelMetadataExtractor& metadata_extractor,
ClassificationAggregationCalculatorOptions* options) { mediapipe::ClassificationAggregationCalculatorOptions* options) {
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata(); auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
if (output_tensors_metadata == nullptr) { if (output_tensors_metadata == nullptr) {
return; return;
@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph(
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator. // The output tensors of an InferenceCalculator.
// TIMESTAMPS - std::vector<Timestamp> @Optional // TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of timestamps that a single ClassificationResult should // The collection of the timestamps that this calculator should aggregate.
// aggregate. This is mostly useful for classifiers working on time series, // This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
// e.g. audio or video classification. // output is used for results. Otherwise as no timestamp aggregation is
// required the CLASSIFICATIONS output is used for results.
//
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATIONS - ClassificationResult @Optional
// The output aggregated classification results. // The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
// //
// The recommended way of using this graph is through the GraphBuilder API // The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header // using the 'ConfigureClassificationPostprocessingGraph()' function. See header
@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
mediapipe::SubgraphContext* sc) override { mediapipe::SubgraphContext* sc) override {
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto classification_result_out, auto output_streams,
BuildClassificationPostprocessing( BuildClassificationPostprocessing(
sc->Options<proto::ClassificationPostprocessingGraphOptions>(), sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph[Input<std::vector<Tensor>>(kTensorsTag)],
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph)); graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
classification_result_out >> output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)]; graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
private: private:
// Adds an on-device classification postprocessing graph into the provided // Adds an on-device classification postprocessing graph into the provided
// builder::Graph instance. The classification postprocessing graph takes // builder::Graph instance. The classification postprocessing graph takes
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output // tensors (std::vector<mediapipe::Tensor>) and optional timestamps
// stream containing the output classification results (ClassificationResult). // (std::vector<Timestamp>) as input and returns two output streams:
// - classification results aggregated by classifier head as a
// ClassificationResult proto, used when no timestamps are passed in
// the graph,
// - classification results aggregated by timestamp then by classifier head
// as a std::vector<ClassificationResult>, used when timestamps are passed
// in the graph.
// //
// options: the on-device ClassificationPostprocessingGraphOptions. // options: the on-device ClassificationPostprocessingGraphOptions.
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess. // tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that a single ClassificationResult should aggregate. // timestamps that should be used to aggregate classification results.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> absl::StatusOr<ClassificationPostprocessingOutputStreams>
BuildClassificationPostprocessing( BuildClassificationPostprocessing(
const proto::ClassificationPostprocessingGraphOptions& options, const proto::ClassificationPostprocessingGraphOptions& options,
Source<std::vector<Tensor>> tensors_in, Source<std::vector<Tensor>> tensors_in,
@ -494,7 +524,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
// Aggregates Classifications into a single ClassificationResult. // Aggregates Classifications into a single ClassificationResult.
auto& result_aggregation = auto& result_aggregation =
graph.AddNode("ClassificationAggregationCalculator"); graph.AddNode("ClassificationAggregationCalculator");
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>() result_aggregation
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
.CopyFrom(options.classification_aggregation_options()); .CopyFrom(options.classification_aggregation_options());
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >> tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
@ -504,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
timestamps_in >> result_aggregation.In(kTimestampsTag); timestamps_in >> result_aggregation.In(kTimestampsTag);
// Connects output. // Connects output.
return result_aggregation[Output<ClassificationResult>( ClassificationPostprocessingOutputStreams output_streams{
kClassificationResultTag)]; /*classification_result=*/result_aggregation
[Output<ClassificationResult>(kClassificationResultTag)],
/*classifications=*/
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
/*timestamped_classifications=*/
result_aggregation[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)]};
return output_streams;
} }
}; };

Some files were not shown because too many files have changed in this diff Show More