Merge branch 'master' into image-embedder-python

This commit is contained in:
Kinar R 2022-11-03 23:24:31 +05:30 committed by GitHub
commit 5a68ba84b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
244 changed files with 12284 additions and 2008 deletions

View File

@ -30,6 +30,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
git \
wget \
unzip \
nodejs \
npm \
python3-dev \
python3-opencv \
python3-pip \

View File

@ -172,6 +172,10 @@ http_archive(
urls = [
"https://github.com/google/sentencepiece/archive/1.0.0.zip",
],
patches = [
"//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
],
patch_args = ["-p1"],
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
)

14
docs/BUILD Normal file
View File

@ -0,0 +1,14 @@
# Placeholder for internal Python strict binary compatibility macro.
py_binary(
name = "build_py_api_docs",
srcs = ["build_py_api_docs.py"],
deps = [
"//mediapipe",
"//third_party/py/absl:app",
"//third_party/py/absl/flags",
"//third_party/py/tensorflow_docs",
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
"//third_party/py/tensorflow_docs/api_generator:public_api",
],
)

85
docs/build_py_api_docs.py Normal file
View File

@ -0,0 +1,85 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""MediaPipe reference docs generation script.
This script generates API reference docs for the `mediapipe` PIP package.
$> pip install -U git+https://github.com/tensorflow/docs mediapipe
$> python build_py_api_docs.py
"""
import os
from absl import app
from absl import flags
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
try:
# mediapipe has not been set up to work with bazel yet, so catch & report.
import mediapipe # pytype: disable=import-error
except ImportError as e:
raise ImportError('Please `pip install mediapipe`.') from e
PROJECT_SHORT_NAME = 'mp'
PROJECT_FULL_NAME = 'MediaPipe'
_OUTPUT_DIR = flags.DEFINE_string(
'output_dir',
default='/tmp/generated_docs',
help='Where to write the resulting docs.')
_URL_PREFIX = flags.DEFINE_string(
'code_url_prefix',
'https://github.com/google/mediapipe/tree/master/mediapipe',
'The url prefix for links to code.')
_SEARCH_HINTS = flags.DEFINE_bool(
'search_hints', True,
'Include metadata search hints in the generated files')
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
'Path prefix in the _toc.yaml')
def gen_api_docs():
"""Generates API docs for the mediapipe package."""
doc_generator = generate_lib.DocGenerator(
root_title=PROJECT_FULL_NAME,
py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
base_dir=os.path.dirname(mediapipe.__file__),
code_url_prefix=_URL_PREFIX.value,
search_hints=_SEARCH_HINTS.value,
site_path=_SITE_PATH.value,
# This callback ensures that docs are only generated for objects that
# are explicitly imported in your __init__.py files. There are other
# options but this is a good starting point.
callbacks=[public_api.explicit_package_contents_filter],
)
doc_generator.build(_OUTPUT_DIR.value)
print('Docs output to:', _OUTPUT_DIR.value)
def main(_):
gen_api_docs()
if __name__ == '__main__':
app.run(main)

View File

@ -222,10 +222,10 @@ cc_library(
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",
@ -328,6 +328,7 @@ cc_library(
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -344,6 +345,7 @@ cc_test(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",

View File

@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
async_selection: true
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
}
}
@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
async_selection: true
contained_node: {
calculator: "BypassCalculator"
node_options: {

View File

@ -18,6 +18,7 @@
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
};
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
class ConcatenateClassificationListCalculator
: public ConcatenateListsCalculator<Classification, ClassificationList> {
protected:
int ListSize(const ClassificationList& list) const override {
return list.classification_size();
}
const Classification GetItem(const ClassificationList& list,
int idx) const override {
return list.classification(idx);
}
Classification* AddItem(ClassificationList& list) const override {
return list.add_classification();
}
};
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -18,6 +18,7 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -70,6 +71,16 @@ void AddInputLandmarkLists(
}
}
void AddInputClassificationLists(
const std::vector<ClassificationList>& input_classifications_vec,
int64 timestamp, CalculatorRunner* runner) {
for (int i = 0; i < input_classifications_vec.size(); ++i) {
runner->MutableInputs()->Index(i).packets.push_back(
MakePacket<ClassificationList>(input_classifications_vec[i])
.At(Timestamp(timestamp)));
}
}
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
/*options_string=*/"", /*num_inputs=*/3,
@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size());
}
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
CalculatorRunner runner("ConcatenateClassificationListCalculator",
/*options_string=*/
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
"{only_emit_if_all_present: true}",
/*num_inputs=*/2,
/*num_outputs=*/1, /*num_side_packets=*/0);
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
)pb");
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb");
std::vector<ClassificationList> inputs = {input_0, input_1};
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
auto result = outputs[0].Get<ClassificationList>();
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb"),
EqualsProto(result));
}
} // namespace mediapipe

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "tensorflow/lite/interpreter.h"
@ -58,4 +59,7 @@ typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
EndLoopDetectionCalculator;
REGISTER_CALCULATOR(EndLoopDetectionCalculator);
typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
REGISTER_CALCULATOR(EndLoopMatrixCalculator);
} // namespace mediapipe

View File

@ -50,7 +50,7 @@ namespace mediapipe {
// calculator: "EndLoopWithOutputCalculator"
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts
// }
template <typename IterableT>
class EndLoopCalculator : public CalculatorBase {

View File

@ -109,6 +109,56 @@ cc_test(
],
)
mediapipe_proto_library(
name = "tensors_to_audio_calculator_proto",
srcs = ["tensors_to_audio_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "tensors_to_audio_calculator",
srcs = ["tensors_to_audio_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":tensors_to_audio_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_audio_tools//audio/dsp:window_functions",
"@pffft",
],
alwayslink = 1,
)
cc_test(
name = "tensors_to_audio_calculator_test",
srcs = ["tensors_to_audio_calculator_test.cc"],
deps = [
":audio_to_tensor_calculator",
":audio_to_tensor_calculator_cc_proto",
":tensors_to_audio_calculator",
":tensors_to_audio_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
name = "feedback_tensors_calculator_proto",
srcs = ["feedback_tensors_calculator.proto"],
@ -253,6 +303,26 @@ cc_library(
alwayslink = 1,
)
cc_test(
name = "regex_preprocessor_calculator_test",
srcs = ["regex_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":regex_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:sink",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"],
@ -304,6 +374,28 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "universal_sentence_encoder_preprocessor_calculator_test",
srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
deps = [
":universal_sentence_encoder_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
@ -438,6 +530,7 @@ cc_library(
}),
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/status:statusor",
],
@ -458,6 +551,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":inference_runner",
"//mediapipe/framework:mediapipe_profiling",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
@ -1200,13 +1294,30 @@ cc_library(
name = "image_to_tensor_utils",
srcs = ["image_to_tensor_utils.cc"],
hdrs = ["image_to_tensor_utils.h"],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":image_to_tensor_calculator_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/types:optional",
],
"//mediapipe/gpu:gpu_origin_cc_proto",
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": ["//mediapipe/gpu:gpu_buffer"],
}),
)
cc_test(
@ -1216,6 +1327,8 @@ cc_test(
":image_to_tensor_utils",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)

View File

@ -133,7 +133,7 @@ bool IsValidFftSize(int size) {
// invocation. In the non-streaming mode, the vector contains all of the
// output timestamps for an input audio buffer.
// DC_AND_NYQUIST - std::pair<float, float> @Optional.
// A pair of dc component and nyquest component. Only can be connected when
// A pair of dc component and nyquist component. Only can be connected when
// the calculator performs fft (the fft_size is set in the calculator
// options).
//

View File

@ -54,13 +54,6 @@
namespace mediapipe {
namespace api2 {
#if MEDIAPIPE_DISABLE_GPU
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
using GpuBuffer = AnyType;
#else
using GpuBuffer = mediapipe::GpuBuffer;
#endif // MEDIAPIPE_DISABLE_GPU
// Converts image into Tensor, possibly with cropping, resizing and
// normalization, according to specified inputs and options.
//
@ -141,42 +134,7 @@ class ImageToTensorCalculator : public Node {
const auto& options =
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
RET_CHECK(options.has_output_tensor_float_range() ||
options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required.";
if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max())
<< "Valid output float tensor range is required.";
}
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max())
<< "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be greater than "
"or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or "
"equal to 127.";
}
RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required.";
RET_CHECK_GT(options.output_tensor_height(), 0)
<< "Valid output tensor height is required.";
RET_CHECK_OK(ValidateOptionOutputDims(options));
RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected())
<< "One and only one of IMAGE and IMAGE_GPU input is expected.";
@ -198,21 +156,7 @@ class ImageToTensorCalculator : public Node {
absl::Status Open(CalculatorContext* cc) {
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
output_width_ = options_.output_tensor_width();
output_height_ = options_.output_tensor_height();
is_float_output_ = options_.has_output_tensor_float_range();
if (options_.has_output_tensor_uint_range()) {
range_min_ =
static_cast<float>(options_.output_tensor_uint_range().min());
range_max_ =
static_cast<float>(options_.output_tensor_uint_range().max());
} else if (options_.has_output_tensor_int_range()) {
range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
} else {
range_min_ = options_.output_tensor_float_range().min();
range_max_ = options_.output_tensor_float_range().max();
}
params_ = GetOutputTensorParams(options_);
return absl::OkStatus();
}
@ -242,7 +186,13 @@ class ImageToTensorCalculator : public Node {
}
}
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
#if MEDIAPIPE_DISABLE_GPU
ASSIGN_OR_RETURN(auto image, GetInputImage(kIn(cc)));
#else
const bool is_input_gpu = kInGpu(cc).IsConnected();
ASSIGN_OR_RETURN(auto image, is_input_gpu ? GetInputImage(kInGpu(cc))
: GetInputImage(kIn(cc)));
#endif // MEDIAPIPE_DISABLE_GPU
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
@ -263,11 +213,13 @@ class ImageToTensorCalculator : public Node {
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
Tensor::ElementType output_tensor_type =
GetOutputTensorType(image->UsesGpu());
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
GetOutputTensorType(image->UsesGpu(), params_);
Tensor tensor(output_tensor_type,
{1, params_.output_height, params_.output_width,
GetNumOutputChannels(*image)});
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
->Convert(*image, roi, range_min_, range_max_,
->Convert(*image, roi, params_.range_min,
params_.range_max,
/*tensor_buffer_offset=*/0, tensor));
auto result = std::make_unique<std::vector<Tensor>>();
@ -278,81 +230,11 @@ class ImageToTensorCalculator : public Node {
}
private:
bool DoesGpuInputStartAtBottom() {
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}
BorderMode GetBorderMode() {
switch (options_.border_mode()) {
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
return BorderMode::kReplicate;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
return BorderMode::kZero;
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
return BorderMode::kReplicate;
}
}
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
if (!uses_gpu) {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
int GetNumOutputChannels(const Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
if (image.UsesGpu()) {
return 4;
}
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
// All of the processors except for Metal expect 3 channels.
return 3;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
CalculatorContext* cc) {
if (kIn(cc).IsConnected()) {
const auto& packet = kIn(cc).packet();
return kIn(cc).Visit(
[&packet](const mediapipe::Image&) {
return SharedPtrWithPacket<mediapipe::Image>(packet);
},
[&packet](const mediapipe::ImageFrame&) {
return std::make_shared<const mediapipe::Image>(
std::const_pointer_cast<mediapipe::ImageFrame>(
SharedPtrWithPacket<mediapipe::ImageFrame>(packet)));
});
} else { // if (kInGpu(cc).IsConnected())
#if !MEDIAPIPE_DISABLE_GPU
const GpuBuffer& input = *kInGpu(cc);
// A shallow copy is okay since the resulting 'image' object is local in
// Process(), and thus never outlives 'input'.
return std::make_shared<const mediapipe::Image>(input);
#else
return absl::UnimplementedError(
"GPU processing is disabled in build flags");
#endif // !MEDIAPIPE_DISABLE_GPU
}
}
absl::Status InitConverterIfNecessary(CalculatorContext* cc,
const Image& image) {
// Lazy initialization of the GPU or CPU converter.
if (image.UsesGpu()) {
if (!is_float_output_) {
if (!params_.is_float_output) {
return absl::UnimplementedError(
"ImageToTensorConverter for the input GPU image currently doesn't "
"support quantization.");
@ -360,18 +242,20 @@ class ImageToTensorCalculator : public Node {
if (!gpu_converter_) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
ASSIGN_OR_RETURN(gpu_converter_,
CreateMetalConverter(cc, GetBorderMode()));
ASSIGN_OR_RETURN(
gpu_converter_,
CreateMetalConverter(cc, GetBorderMode(options_.border_mode())));
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
ASSIGN_OR_RETURN(gpu_converter_,
CreateImageToGlBufferTensorConverter(
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
cc, DoesGpuInputStartAtBottom(options_),
GetBorderMode(options_.border_mode())));
#else
if (!gpu_converter_) {
ASSIGN_OR_RETURN(
gpu_converter_,
ASSIGN_OR_RETURN(gpu_converter_,
CreateImageToGlTextureTensorConverter(
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
cc, DoesGpuInputStartAtBottom(options_),
GetBorderMode(options_.border_mode())));
}
if (!gpu_converter_) {
return absl::UnimplementedError(
@ -383,10 +267,10 @@ class ImageToTensorCalculator : public Node {
} else {
if (!cpu_converter_) {
#if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(
cpu_converter_,
CreateOpenCvConverter(cc, GetBorderMode(),
GetOutputTensorType(/*uses_gpu=*/false)));
ASSIGN_OR_RETURN(cpu_converter_,
CreateOpenCvConverter(
cc, GetBorderMode(options_.border_mode()),
GetOutputTensorType(/*uses_gpu=*/false, params_)));
#else
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined.";
@ -399,11 +283,7 @@ class ImageToTensorCalculator : public Node {
std::unique_ptr<ImageToTensorConverter> gpu_converter_;
std::unique_ptr<ImageToTensorConverter> cpu_converter_;
mediapipe::ImageToTensorCalculatorOptions options_;
int output_width_ = 0;
int output_height_ = 0;
bool is_float_output_ = false;
float range_min_ = 0.0f;
float range_max_ = 1.0f;
OutputTensorParams params_;
};
MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);

View File

@ -27,12 +27,6 @@ struct Size {
int height;
};
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
enum class BorderMode { kZero, kReplicate };
// Converts image to tensor.
class ImageToTensorConverter {
public:

View File

@ -270,10 +270,10 @@ class GlProcessor : public ImageToTensorConverter {
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
"Unsupported format: ", static_cast<uint32_t>(input.format())));
}
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@ -281,12 +281,13 @@ class GlProcessor : public ImageToTensorConverter {
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
[this, &output_tensor, &input, &roi, &output_shape, range_min,
range_max, tensor_buffer_offset]() -> absl::Status {
constexpr int kRgbaNumChannels = 4;
const int input_num_channels = input.channels();
auto source_texture = gl_helper_.CreateSourceTexture(input);
tflite::gpu::gl::GlTexture input_texture(
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
GL_TEXTURE_2D, source_texture.name(),
input_num_channels == 4 ? GL_RGB : GL_RGBA,
source_texture.width() * source_texture.height() *
kRgbaNumChannels * sizeof(uint8_t),
input_num_channels * sizeof(uint8_t),
/*layer=*/0,
/*owned=*/false);

View File

@ -174,10 +174,10 @@ class GlProcessor : public ImageToTensorConverter {
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
"Unsupported format: ", static_cast<uint32_t>(input.format())));
}
// TODO: support tensor_buffer_offset > 0 scenario.
RET_CHECK_EQ(tensor_buffer_offset, 0)

View File

@ -16,7 +16,9 @@
#include <array>
#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
@ -214,4 +216,68 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
matrix[15] = 1.0f;
}
BorderMode GetBorderMode(
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode) {
switch (mode) {
case mediapipe::
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
return BorderMode::kReplicate;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
return BorderMode::kZero;
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
return BorderMode::kReplicate;
}
}
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
const OutputTensorParams& params) {
if (!uses_gpu) {
if (params.is_float_output) {
return Tensor::ElementType::kFloat32;
}
if (params.range_min < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
int GetNumOutputChannels(const mediapipe::Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
if (image.UsesGpu()) {
return 4;
}
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
// All of the processors except for Metal expect 3 channels.
return 3;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
image_packet) {
return image_packet.Visit(
[&image_packet](const mediapipe::Image&) {
return SharedPtrWithPacket<mediapipe::Image>(image_packet);
},
[&image_packet](const mediapipe::ImageFrame&) {
return std::make_shared<const mediapipe::Image>(
std::const_pointer_cast<mediapipe::ImageFrame>(
SharedPtrWithPacket<mediapipe::ImageFrame>(image_packet)));
});
}
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet) {
// A shallow copy is okay since the resulting 'image' object is local in
// Process(), and thus never outlives 'input'.
return std::make_shared<const mediapipe::Image>(image_gpu_packet.Get());
}
#endif // !MEDIAPIPE_DISABLE_GPU
} // namespace mediapipe

View File

@ -18,8 +18,18 @@
#include <array>
#include "absl/types/optional.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_buffer.h"
#endif // !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_origin.pb.h"
namespace mediapipe {
@ -31,6 +41,24 @@ struct RotatedRect {
float rotation;
};
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
// TODO: Consider moving this to a separate border_mode.h file.
enum class BorderMode { kZero, kReplicate };
// Struct that host commonly accessed parameters used in the
// ImageTo[Batch]TensorCalculator.
struct OutputTensorParams {
int output_height;
int output_width;
int output_batch;
bool is_float_output;
float range_min;
float range_max;
};
// Generates a new ROI or converts it from normalized rect.
RotatedRect GetRoi(int input_width, int input_height,
absl::optional<mediapipe::NormalizedRect> norm_rect);
@ -95,6 +123,103 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
const RotatedRect& sub_rect, int rect_width, int rect_height,
bool flip_horizontaly, std::array<float, 16>* matrix);
// Validates the output dimensions set in the option proto. The input option
// proto is expected to have to following fields:
// output_tensor_float_range, output_tensor_int_range, output_tensor_uint_range
// output_tensor_width, output_tensor_height.
// See ImageToTensorCalculatorOptions for the description of each field.
template <typename T>
absl::Status ValidateOptionOutputDims(const T& options) {
RET_CHECK(options.has_output_tensor_float_range() ||
options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required.";
if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max())
<< "Valid output float tensor range is required.";
}
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max())
<< "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be greater than "
"or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or "
"equal to 127.";
}
RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required.";
RET_CHECK_GT(options.output_tensor_height(), 0)
<< "Valid output tensor height is required.";
return absl::OkStatus();
}
template <typename T>
OutputTensorParams GetOutputTensorParams(const T& options) {
OutputTensorParams params;
if (options.has_output_tensor_uint_range()) {
params.range_min =
static_cast<float>(options.output_tensor_uint_range().min());
params.range_max =
static_cast<float>(options.output_tensor_uint_range().max());
} else if (options.has_output_tensor_int_range()) {
params.range_min =
static_cast<float>(options.output_tensor_int_range().min());
params.range_max =
static_cast<float>(options.output_tensor_int_range().max());
} else {
params.range_min = options.output_tensor_float_range().min();
params.range_max = options.output_tensor_float_range().max();
}
params.output_width = options.output_tensor_width();
params.output_height = options.output_tensor_height();
params.is_float_output = options.has_output_tensor_float_range();
params.output_batch = 1;
return params;
}
// Returns whether the GPU input format starts at the bottom.
template <typename T>
bool DoesGpuInputStartAtBottom(const T& options) {
return options.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}
// Converts the BorderMode proto into struct.
BorderMode GetBorderMode(
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode);
// Gets the output tensor type.
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
const OutputTensorParams& params);
// Gets the number of output channels from the input Image format.
int GetNumOutputChannels(const mediapipe::Image& image);
// Converts the packet that hosts different format (Image, ImageFrame,
// GpuBuffer) into the mediapipe::Image format.
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
image_packet);
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet);
#endif // !MEDIAPIPE_DISABLE_GPU
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_

View File

@ -16,6 +16,8 @@
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
@ -23,6 +25,7 @@ namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::HasSubstr;
testing::Matcher<RotatedRect> EqRotatedRect(float width, float height,
float center_x, float center_y,
@ -157,5 +160,95 @@ TEST(GetValueRangeTransformation, FloatToPixel) {
EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f));
}
constexpr char kValidFloatProto[] = R"(
output_tensor_float_range { min: 0.0 max: 1.0 }
output_tensor_width: 100
output_tensor_height: 200
)";
constexpr char kValidIntProto[] = R"(
output_tensor_float_range { min: 0 max: 255 }
output_tensor_width: 100
output_tensor_height: 200
)";
TEST(ValidateOptionOutputDims, ValidProtos) {
const auto float_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
kValidFloatProto);
MP_EXPECT_OK(ValidateOptionOutputDims(float_options));
}
TEST(ValidateOptionOutputDims, EmptyProto) {
mediapipe::ImageToTensorCalculatorOptions options;
// No output tensor range set.
EXPECT_THAT(ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Output tensor range is required")));
// Invalid output float tensor range.
options.mutable_output_tensor_float_range()->set_min(1.0);
options.mutable_output_tensor_float_range()->set_max(0.0);
EXPECT_THAT(
ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Valid output float tensor range is required")));
// Output width/height is not set.
options.mutable_output_tensor_float_range()->set_min(0.0);
options.mutable_output_tensor_float_range()->set_max(1.0);
EXPECT_THAT(ValidateOptionOutputDims(options),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("Valid output tensor width is required")));
}
TEST(GetOutputTensorParams, SetValues) {
// Test int range with ImageToTensorCalculatorOptions.
const auto int_options =
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
kValidIntProto);
const auto params2 = GetOutputTensorParams(int_options);
EXPECT_EQ(params2.range_min, 0.0f);
EXPECT_EQ(params2.range_max, 255.0f);
EXPECT_EQ(params2.output_batch, 1);
EXPECT_EQ(params2.output_width, 100);
EXPECT_EQ(params2.output_height, 200);
}
TEST(GetBorderMode, GetBorderMode) {
// Default to REPLICATE.
auto border_mode =
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED;
EXPECT_EQ(BorderMode::kReplicate, GetBorderMode(border_mode));
// Set to ZERO.
border_mode =
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO;
EXPECT_EQ(BorderMode::kZero, GetBorderMode(border_mode));
}
TEST(GetOutputTensorType, GetOutputTensorType) {
OutputTensorParams params;
// Return float32 when GPU is enabled.
EXPECT_EQ(Tensor::ElementType::kFloat32,
GetOutputTensorType(/*uses_gpu=*/true, params));
// Return float32 when is_float_output is set to true.
params.is_float_output = true;
EXPECT_EQ(Tensor::ElementType::kFloat32,
GetOutputTensorType(/*uses_gpu=*/false, params));
// Return int8 when range_min is negative.
params.is_float_output = false;
params.range_min = -255.0f;
EXPECT_EQ(Tensor::ElementType::kInt8,
GetOutputTensorType(/*uses_gpu=*/false, params));
// Return 8int8 when range_min is non-negative.
params.range_min = 0.0f;
EXPECT_EQ(Tensor::ElementType::kUInt8,
GetOutputTensorType(/*uses_gpu=*/false, params));
}
} // namespace
} // namespace mediapipe

View File

@ -72,7 +72,7 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
RET_CHECK(!input_tensors.empty());
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
inference_runner_->Run(input_tensors));
inference_runner_->Run(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus();
}

View File

@ -26,6 +26,8 @@
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe {
namespace api2 {
@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
std::vector<Tensor>& output_tensors) {
return gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status {
[this, cc, &input_tensors, &output_tensors]() -> absl::Status {
// Explicitly copy input.
for (int i = 0; i < input_tensors.size(); ++i) {
glBindBuffer(GL_COPY_READ_BUFFER,
@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
}
// Run inference.
{
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
output_tensors.reserve(output_size_);
for (int i = 0; i < output_size_; ++i) {

View File

@ -32,6 +32,8 @@
#include "mediapipe/util/android/file/base/helpers.h"
#endif // MEDIAPIPE_ANDROID
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe {
namespace api2 {
@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
const mediapipe::InferenceCalculatorOptions::Delegate& delegate);
absl::StatusOr<std::vector<Tensor>> Process(
const std::vector<Tensor>& input_tensors);
CalculatorContext* cc, const std::vector<Tensor>& input_tensors);
absl::Status Close();
@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
absl::StatusOr<std::vector<Tensor>>
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
const std::vector<Tensor>& input_tensors) {
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
std::vector<Tensor> output_tensors;
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status {
[this, cc, &input_tensors, &output_tensors]() -> absl::Status {
for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
input_tensors[i].GetOpenGlBufferReadView().name(), i));
@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
output_tensors.back().GetOpenGlBufferWriteView().name(), i));
}
// Run inference.
{
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
return tflite_gpu_runner_->Invoke();
}
}));
return output_tensors;
@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
ASSIGN_OR_RETURN(*output_tensors,
gpu_inference_runner_->Process(input_tensors));
gpu_inference_runner_->Process(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus();

View File

@ -70,7 +70,7 @@ absl::Status InferenceCalculatorXnnpackImpl::Process(CalculatorContext* cc) {
RET_CHECK(!input_tensors.empty());
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
inference_runner_->Run(input_tensors));
inference_runner_->Run(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus();
}

View File

@ -20,12 +20,15 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe {
namespace {
@ -79,7 +82,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
delegate_(std::move(delegate)) {}
absl::StatusOr<std::vector<Tensor>> Run(
const std::vector<Tensor>& input_tensors) override;
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) override;
private:
api2::Packet<TfLiteModelPtr> model_;
@ -88,7 +91,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
};
absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
const std::vector<Tensor>& input_tensors) {
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
// Read CPU input into tensors.
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
for (int i = 0; i < input_tensors.size(); ++i) {
@ -131,8 +134,10 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
}
// Run inference.
{
MEDIAPIPE_PROFILING(CPU_TASK_INVOKE, cc);
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
// Output result tensors (CPU).
const auto& tensor_indexes = interpreter_->outputs();
std::vector<Tensor> output_tensors;

View File

@ -2,6 +2,7 @@
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
@ -11,7 +12,7 @@ class InferenceRunner {
public:
virtual ~InferenceRunner() = default;
virtual absl::StatusOr<std::vector<Tensor>> Run(
const std::vector<Tensor>& inputs) = 0;
CalculatorContext* cc, const std::vector<Tensor>& inputs) = 0;
};
} // namespace mediapipe

View File

@ -0,0 +1,197 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstring>
#include <new>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "pffft.h"
namespace mediapipe {
namespace api2 {
namespace {
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
std::vector<float> hann_window(window_size);
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
if (sqrt_hann) {
absl::c_transform(hann_window, hann_window.begin(),
[](double x) { return std::sqrt(x); });
}
return hann_window;
}
// Note that the InvHannWindow function may only work for 50% overlapping case.
std::vector<float> InvHannWindow(int window_size, bool sqrt_hann) {
std::vector<float> window = HannWindow(window_size, sqrt_hann);
std::vector<float> inv_window(window.size());
if (sqrt_hann) {
absl::c_copy(window, inv_window.begin());
} else {
const int kHalfWindowSize = window.size() / 2;
absl::c_transform(window, inv_window.begin(),
[](double x) { return x * x; });
for (int i = 0; i < kHalfWindowSize; ++i) {
double sum = inv_window[i] + inv_window[kHalfWindowSize + i];
inv_window[i] = window[i] / sum;
inv_window[kHalfWindowSize + i] = window[kHalfWindowSize + i] / sum;
}
}
return inv_window;
}
// PFFFT only supports transforms for inputs of length N of the form
// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT.
bool IsValidFftSize(int size) {
if (size <= 0) {
return false;
}
constexpr int kFactors[] = {2, 3, 5};
int factorization[] = {0, 0, 0};
int n = static_cast<int>(size);
for (int i = 0; i < 3; ++i) {
while (n % kFactors[i] == 0) {
n = n / kFactors[i];
++factorization[i];
}
}
return factorization[0] >= 5 && n == 1;
}
} // namespace
// Converts 2D MediaPipe float Tensors to audio buffers.
// The calculator will perform ifft on the complex DFT and apply the window
// function (Inverse Hann) afterwards. The input 2D MediaPipe Tensor must
// have the DFT real parts in its first row and the DFT imagery parts in its
// second row. A valid "fft_size" must be set in the CalculatorOptions.
//
// Inputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor that represents the audio's complex DFT
// results.
// DC_AND_NYQUIST - std::pair<float, float>
// A pair of dc component and nyquist component.
//
// Outputs:
// AUDIO - mediapipe::Matrix
// The audio data represented as mediapipe::Matrix.
//
// Example:
// node {
// calculator: "TensorsToAudioCalculator"
// input_stream: "TENSORS:tensors"
// input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
// output_stream: "AUDIO:audio"
// options {
// [mediapipe.AudioToTensorCalculatorOptions.ext] {
// fft_size: 256
// }
// }
// }
class TensorsToAudioCalculator : public Node {
public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Input<std::pair<float, float>> kDcAndNyquistIn{
"DC_AND_NYQUIST"};
static constexpr Output<Matrix> kAudioOut{"AUDIO"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kDcAndNyquistIn, kAudioOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
private:
// The internal state of the FFT library.
PFFFT_Setup* fft_state_ = nullptr;
int fft_size_ = 0;
float inverse_fft_size_ = 0;
std::vector<float, Eigen::aligned_allocator<float>> input_dft_;
std::vector<float> inv_fft_window_;
std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
// pffft requires memory to work with to avoid using the stack.
std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
};
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
const auto& options =
cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
RET_CHECK(IsValidFftSize(options.fft_size()))
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
">=0 and c >= 0 and a >= 5, the requested fft size is "
<< options.fft_size();
fft_size_ = options.fft_size();
inverse_fft_size_ = 1.0f / fft_size_;
fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
input_dft_.resize(fft_size_);
inv_fft_window_ = InvHannWindow(fft_size_, /* sqrt_hann = */ false);
fft_input_buffer_.resize(fft_size_);
fft_workplace_.resize(fft_size_);
fft_output_.resize(fft_size_);
return absl::OkStatus();
}
absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
if (kTensorsIn(cc).IsEmpty() || kDcAndNyquistIn(cc).IsEmpty()) {
return absl::OkStatus();
}
const auto& input_tensors = *kTensorsIn(cc);
RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
auto view = input_tensors[0].GetCpuReadView();
// DC's real part.
input_dft_[0] = kDcAndNyquistIn(cc)->first;
// Nyquist's real part is the penultimate element of the tensor buffer.
// pffft ignores the Nyquist's imagery part. No need to fetch the last value
// from the tensor buffer.
input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
(fft_size_ - 2) * sizeof(float));
pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
fft_workplace_.data(), PFFFT_BACKWARD);
// Applies the inverse window function.
std::transform(
fft_output_.begin(), fft_output_.end(), inv_fft_window_.begin(),
fft_output_.begin(),
[this](float a, float b) { return a * b * inverse_fft_size_; });
Matrix matrix = Eigen::Map<Matrix>(fft_output_.data(), 1, fft_output_.size());
kAudioOut(cc).Send(std::move(matrix));
return absl::OkStatus();
}
absl::Status TensorsToAudioCalculator::Close(CalculatorContext* cc) {
if (fft_state_) {
pffft_destroy_setup(fft_state_);
}
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(TensorsToAudioCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorsToAudioCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorsToAudioCalculatorOptions ext = 484297136;
}
// Size of the fft in number of bins. If set, the calculator will do ifft
// on the input tensor.
optional int64 fft_size = 1;
}

View File

@ -0,0 +1,149 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <new>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
protected:
// Creates an audio matrix containing a single sample of 1.0 at a specified
// offset.
Matrix CreateImpulseSignalData(int64 num_samples, int impulse_offset_idx) {
Matrix impulse = Matrix::Zero(1, num_samples);
impulse(0, impulse_offset_idx) = 1.0;
return impulse;
}
void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "audio_in"
input_stream: "sample_rate"
output_stream: "audio_out"
node {
calculator: "AudioToTensorCalculator"
input_stream: "AUDIO:audio_in"
input_stream: "SAMPLE_RATE:sample_rate"
output_stream: "TENSORS:tensors"
output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
options {
[mediapipe.AudioToTensorCalculatorOptions.ext] {
num_channels: 1
num_samples: $0
num_overlapping_samples: 0
target_sample_rate: $1
fft_size: $2
}
}
}
node {
calculator: "TensorsToAudioCalculator"
input_stream: "TENSORS:tensors"
input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
output_stream: "AUDIO:audio_out"
options {
[mediapipe.TensorsToAudioCalculatorOptions.ext] {
fft_size: $2
}
}
}
)",
/*$0=*/num_samples,
/*$1=*/sample_rate,
/*$2=*/fft_size));
tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
}
void RunGraph(const Matrix& input_data, double sample_rate) {
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"audio_in", MakePacket<Matrix>(input_data).At(Timestamp(0))));
MP_ASSERT_OK(graph_.CloseAllInputStreams());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
std::vector<Packet> audio_out_packets_;
CalculatorGraphConfig graph_config_;
CalculatorGraph graph_;
};
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
ConfigGraph(320, 16000, 103);
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
auto status = graph_.WaitUntilIdle();
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_THAT(status.message(),
::testing::HasSubstr("FFT size must be of the form"));
}
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// The impulse signal at the center is not affected by the window function.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
}
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// As the impulse signal sits at the 1/4 of the hann window, the inverse
// window function reduces it by half.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data / 2);
}
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// As the impulse signal sits at the beginning of the hann window, the inverse
// window function completely removes it.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
}
} // namespace
} // namespace mediapipe

View File

@ -289,8 +289,15 @@ class NodeBase {
template <typename T>
T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true;
return *options_.MutableExtension(T::ext);
return *options_.MutableExtension(extension);
}
protected:
@ -386,8 +393,15 @@ class PacketGenerator {
template <typename T>
T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true;
return *options_.MutableExtension(T::ext);
return *options_.MutableExtension(extension);
}
template <typename B, typename T, bool kIsOptional, bool kIsMultiple>

View File

@ -185,7 +185,7 @@ class CalculatorBaseFactory {
// Functions for checking that the calculator has the required GetContract.
template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
typedef absl::Status (*GetContractType)(CalculatorContract * cc);
typedef absl::Status (*GetContractType)(CalculatorContract* cc);
return std::is_same<decltype(&T::GetContract), GetContractType>::value;
}
template <class T>

View File

@ -133,7 +133,13 @@ message GraphTrace {
TPU_TASK = 13;
GPU_CALIBRATION = 14;
PACKET_QUEUED = 15;
GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17;
CPU_TASK_INVOKE = 18;
}
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
// )
// The timing for one packet set being processed at one caclulator node.
message CalculatorTrace {

View File

@ -293,7 +293,6 @@ mediapipe_proto_library(
name = "rect_proto",
srcs = ["rect.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/formats:location_data_proto"],
)
mediapipe_register_type(

View File

@ -109,6 +109,13 @@ struct TraceEvent {
static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/calculator_profile.proto:event_type,
// )
};
// Packet trace log buffer.

View File

@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(
// Returns a PacketSequencerCalculator node.
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
bool synchronize_io) {
bool async_selection) {
CalculatorGraphConfig::Node* result = config->add_node();
*result->mutable_calculator() = "PacketSequencerCalculator";
if (synchronize_io) {
if (!async_selection) {
*result->mutable_input_stream_handler()->mutable_input_stream_handler() =
"DefaultInputStreamHandler";
}
@ -239,6 +239,15 @@ bool HasTag(const proto_ns::RepeatedPtrField<std::string>& streams,
return tags.count({tag, 0}) > 0;
}
// Returns true if a set of "TAG::index" includes a TagIndex.
bool ContainsTag(const proto_ns::RepeatedPtrField<std::string>& tags,
TagIndex item) {
for (const std::string& t : tags) {
if (ParseTagIndex(t) == item) return true;
}
return false;
}
absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
const Subgraph::SubgraphOptions& options) {
CalculatorGraphConfig config;
@ -263,17 +272,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
std::string enable_stream = "ENABLE:gate_enable";
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
bool synchronize_io =
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
.synchronize_io();
const auto& switch_options =
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options);
bool async_selection = switch_options.async_selection();
if (HasTag(container_node.input_stream(), "SELECT")) {
select_node = BuildTimestampNode(&config, synchronize_io);
select_node = BuildTimestampNode(&config, async_selection);
select_node->add_input_stream("INPUT:gate_select");
select_node->add_output_stream("OUTPUT:gate_select_timed");
select_stream = "SELECT:gate_select_timed";
}
if (HasTag(container_node.input_stream(), "ENABLE")) {
enable_node = BuildTimestampNode(&config, synchronize_io);
enable_node = BuildTimestampNode(&config, async_selection);
enable_node->add_input_stream("INPUT:gate_enable");
enable_node->add_output_stream("OUTPUT:gate_enable_timed");
enable_stream = "ENABLE:gate_enable_timed";
@ -296,7 +305,7 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
mux->add_input_side_packet("SELECT:gate_select");
mux->add_input_side_packet("ENABLE:gate_enable");
// Add input streams for graph and demux and the timestamper.
// Add input streams for graph and demux.
config.add_input_stream("SELECT:gate_select");
config.add_input_stream("ENABLE:gate_enable");
config.add_input_side_packet("SELECT:gate_select");
@ -306,6 +315,12 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
std::string stream = CatStream(p.first, p.second);
config.add_input_stream(stream);
demux->add_input_stream(stream);
}
// Add input streams for the timestamper.
auto& tick_streams = switch_options.tick_input_stream();
for (const auto& p : input_tags) {
if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue;
TagIndex tick_tag{"TICK", tick_index++};
if (select_node) {
select_node->add_input_stream(CatStream(tick_tag, p.second));

View File

@ -25,6 +25,14 @@ message SwitchContainerOptions {
// Activates channel 1 for enable = true, channel 0 otherwise.
optional bool enable = 4;
// Use DefaultInputStreamHandler for muxing & demuxing.
// Use DefaultInputStreamHandler for demuxing.
optional bool synchronize_io = 5;
// Use ImmediateInputStreamHandler for channel selection.
optional bool async_selection = 6;
// Specifies an input stream, "TAG:index", that defines the processed
// timestamps. SwitchContainer awaits output at the last processed
// timestamp before advancing from one selected channel to the next.
repeated string tick_input_stream = 7;
}

View File

@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
input_stream: "INPUT:enable"
input_stream: "TICK:foo"
output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
input_stream_handler {
input_stream_handler: "DefaultInputStreamHandler"
}
}
node {
name: "switchcontainer__SwitchDemuxCalculator"
@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
// Shows the SwitchContainer container runs with a pair of simple subnodes.
TEST(SwitchContainerTest, RunsWithSubnodes) {
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
CalculatorGraphConfig supergraph = SubnodeContainerExample();
CalculatorGraphConfig supergraph =
SubnodeContainerExample("async_selection: true");
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
RunTestContainer(supergraph);
}

View File

@ -14,6 +14,7 @@
#include <algorithm>
#include <memory>
#include <queue>
#include <set>
#include <string>
@ -54,21 +55,47 @@ namespace mediapipe {
// contained subgraph or calculator nodes.
//
class SwitchDemuxCalculator : public CalculatorBase {
static constexpr char kSelectTag[] = "SELECT";
static constexpr char kEnableTag[] = "ENABLE";
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
absl::Status RecordPackets(CalculatorContext* cc);
int ChannelIndex(Timestamp timestamp);
absl::Status SendActivePackets(CalculatorContext* cc);
private:
int channel_index_;
std::set<std::string> channel_tags_;
using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
PacketQueue input_queue_;
std::map<Timestamp, int> channel_history_;
};
REGISTER_CALCULATOR(SwitchDemuxCalculator);
namespace {
static constexpr char kSelectTag[] = "SELECT";
static constexpr char kEnableTag[] = "ENABLE";
// Returns the last received timestamp for an input stream.
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
return input.Value().Timestamp();
}
// Returns the last received timestamp for channel selection.
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
Timestamp result = Timestamp::Done();
if (cc->Inputs().HasTag(kEnableTag)) {
result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
} else if (cc->Inputs().HasTag(kSelectTag)) {
result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
}
return result;
}
} // namespace
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
// Allow any one of kSelectTag, kEnableTag.
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
channel_tags_ = ChannelTags(cc->Outputs().TagMap());
channel_history_[Timestamp::Unstarted()] = channel_index_;
// Relay side packets to all channels.
// Note: This is necessary because Calculator::Open only proceeds when every
@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
}
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
// Update the input channel index if specified.
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
MP_RETURN_IF_ERROR(RecordPackets(cc));
MP_RETURN_IF_ERROR(SendActivePackets(cc));
return absl::OkStatus();
}
// Relay packets and timestamps only to channel_index_.
// Enqueue all arriving packets and bounds.
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
// Enqueue any new arriving packets.
for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto& input = cc->Inputs().Get(tag, index);
std::string output_tag = tool::ChannelTag(tag, channel_index_);
auto output_id = cc->Outputs().GetId(output_tag, index);
if (output_id.IsValid()) {
auto& output = cc->Outputs().Get(output_tag, index);
tool::Relay(input, &output);
auto input_id = cc->Inputs().GetId(tag, index);
Packet packet = cc->Inputs().Get(input_id).Value();
if (packet.Timestamp() == cc->InputTimestamp()) {
input_queue_[input_id].push(packet);
}
}
}
// Enque any new input channel and its activation timestamp.
Timestamp channel_settled = ChannelSettledTimestamp(cc);
int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
if (channel_settled == cc->InputTimestamp() &&
new_channel_index != channel_index_) {
channel_index_ = new_channel_index;
channel_history_[channel_settled] = channel_index_;
}
return absl::OkStatus();
}
// Returns the channel index for a Timestamp.
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
auto it = std::prev(channel_history_.upper_bound(timestamp));
return it->second;
}
// Dispatches all queued input packets with known channels.
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
// Dispatch any queued input packets with a defined channel_index.
Timestamp channel_settled = ChannelSettledTimestamp(cc);
for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto input_id = cc->Inputs().GetId(tag, index);
auto& queue = input_queue_[input_id];
while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
int channel_index = ChannelIndex(queue.front().Timestamp());
std::string output_tag = tool::ChannelTag(tag, channel_index);
auto output_id = cc->Outputs().GetId(output_tag, index);
if (output_id.IsValid()) {
cc->Outputs().Get(output_id).AddPacket(queue.front());
}
queue.pop();
}
}
}
// Discard all select packets not needed for any remaining input packets.
Timestamp input_settled = Timestamp::Done();
for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto input_id = cc->Inputs().GetId(tag, index);
Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
if (!input_queue_[input_id].empty()) {
Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
stream_settled =
std::min(stream_settled, stream_bound.PreviousAllowedInStream());
}
}
}
Timestamp input_bound = input_settled.NextAllowedInStream();
auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
channel_history_.erase(channel_history_.begin(), history_bound);
return absl::OkStatus();
}

View File

@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<mediapipe::SwitchContainerOptions>();
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
channel_tags_ = ChannelTags(cc->Inputs().TagMap());
channel_history_[Timestamp::Unset()] = channel_index_;
channel_history_[Timestamp::Unstarted()] = channel_index_;
// Relay side packets only from channel_index_.
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {

View File

@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__)
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
EGL_NO_CONTEXT);
#else
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
// parameter for eglMakeCurrent. This behavior is not portable to all EGL
// implementations, and should be considered as an undocumented vendor
// extension.
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
//
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
eglReleaseThread();
}

View File

@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.framework.image.ImageProperties;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer;
// TODO: use Preconditions in this file.
@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
}
/**
* Creates an Image packet from an {@link Image}.
* Creates a MediaPipe Image packet from a {@link MPImage}.
*
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
*/
public Packet createImage(Image image) {
public Packet createImage(MPImage image) {
// TODO: Choose the best storage from multiple containers.
ImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
MPImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
ByteBuffer buffer = ByteBufferExtractor.extract(image);
int numChannels = 0;
switch (properties.getImageFormat()) {
case Image.IMAGE_FORMAT_RGBA:
case MPImage.IMAGE_FORMAT_RGBA:
numChannels = 4;
break;
case Image.IMAGE_FORMAT_RGB:
case MPImage.IMAGE_FORMAT_RGB:
numChannels = 3;
break;
case Image.IMAGE_FORMAT_ALPHA:
case MPImage.IMAGE_FORMAT_ALPHA:
numChannels = 1;
break;
default: // fall out
@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
int height = image.getHeight();
return createImage(buffer, width, height, numChannels);
}
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
if (properties.getStorageType() == MPImage.STORAGE_TYPE_BITMAP) {
Bitmap bitmap = BitmapExtractor.extract(image);
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
@ -100,7 +100,7 @@ public class AndroidPacketCreator extends PacketCreator {
// Unsupported type.
throw new UnsupportedOperationException(
"Unsupported Image container type: " + properties.getImageFormat());
"Unsupported Image container type: " + properties.getStorageType());
}
/**

View File

@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;
/**
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
* Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
* {@link IllegalArgumentException} will be thrown.
*/
public final class BitmapExtractor {
/**
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
* Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
*
* @param image the image to extract {@link android.graphics.Bitmap} from.
* @return the {@link android.graphics.Bitmap} stored in {@link Image}
* @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions.
*/
public static Bitmap extract(Image image) {
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
public static Bitmap extract(MPImage image) {
MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
if (imageContainer != null) {
return ((BitmapImageContainer) imageContainer).getBitmap();
} else {
// TODO: Support ByteBuffer -> Bitmap conversion.
throw new IllegalArgumentException(
"Extracting Bitmap from an Image created by objects other than Bitmap is not"
"Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
+ " supported");
}
}

View File

@ -22,7 +22,7 @@ import android.provider.MediaStore;
import java.io.IOException;
/**
* Builds {@link Image} from {@link android.graphics.Bitmap}.
* Builds {@link MPImage} from {@link android.graphics.Bitmap}.
*
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
@ -49,7 +49,7 @@ public class BitmapImageBuilder {
}
/**
* Creates the builder to build {@link Image} from a file.
* Creates the builder to build {@link MPImage} from a file.
*
* @param context the application context.
* @param uri the path to the resource file.
@ -58,15 +58,15 @@ public class BitmapImageBuilder {
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
}
/** Sets value for {@link Image#getTimestamp()}. */
/** Sets value for {@link MPImage#getTimestamp()}. */
BitmapImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(
/** Builds a {@link MPImage} instance. */
public MPImage build() {
return new MPImage(
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
}
}

View File

@ -16,19 +16,19 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
class BitmapImageContainer implements ImageContainer {
class BitmapImageContainer implements MPImageContainer {
private final Bitmap bitmap;
private final ImageProperties properties;
private final MPImageProperties properties;
public BitmapImageContainer(Bitmap bitmap) {
this.bitmap = bitmap;
this.properties =
ImageProperties.builder()
MPImageProperties.builder()
.setImageFormat(convertFormatCode(bitmap.getConfig()))
.setStorageType(Image.STORAGE_TYPE_BITMAP)
.setStorageType(MPImage.STORAGE_TYPE_BITMAP)
.build();
}
@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
}
@Override
public ImageProperties getImageProperties() {
public MPImageProperties getImageProperties() {
return properties;
}
@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
bitmap.recycle();
}
@ImageFormat
@MPImageFormat
static int convertFormatCode(Bitmap.Config config) {
switch (config) {
case ALPHA_8:
return Image.IMAGE_FORMAT_ALPHA;
return MPImage.IMAGE_FORMAT_ALPHA;
case ARGB_8888:
return Image.IMAGE_FORMAT_RGBA;
return MPImage.IMAGE_FORMAT_RGBA;
default:
return Image.IMAGE_FORMAT_UNKNOWN;
return MPImage.IMAGE_FORMAT_UNKNOWN;
}
}
}

View File

@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Locale;
/**
* Utility for extracting {@link ByteBuffer} from {@link Image}.
* Utility for extracting {@link ByteBuffer} from {@link MPImage}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
* {@link IllegalArgumentException} will be thrown.
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
* otherwise {@link IllegalArgumentException} will be thrown.
*/
public class ByteBufferExtractor {
/**
* Extracts a {@link ByteBuffer} from an {@link Image}.
* Extracts a {@link ByteBuffer} from a {@link MPImage}.
*
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
* MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
*
* @see Image#getContainedImageProperties()
* @see MPImage#getContainedImageProperties()
* @return A read-only {@link ByteBuffer}.
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
*/
@SuppressLint("SwitchIntDef")
public static ByteBuffer extract(Image image) {
ImageContainer container = image.getContainer();
public static ByteBuffer extract(MPImage image) {
MPImageContainer container = image.getContainer();
switch (container.getImageProperties().getStorageType()) {
case Image.STORAGE_TYPE_BYTEBUFFER:
case MPImage.STORAGE_TYPE_BYTEBUFFER:
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
default:
throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
"Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
+ " supported");
}
}
/**
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
*
* <p>Format conversion spec:
*
@ -70,26 +70,26 @@ public class ByteBufferExtractor {
*
* @param image the image to extract buffer from.
* @param targetFormat the image format of the result bytebuffer.
* @return the readonly {@link ByteBuffer} stored in {@link Image}
* @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions.
*/
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
ImageContainer container;
ImageProperties byteBufferProperties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
MPImageContainer container;
MPImageProperties byteBufferProperties =
MPImageProperties.builder()
.setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(targetFormat)
.build();
if ((container = image.getContainer(byteBufferProperties)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
@MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
.asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
ByteBuffer byteBuffer =
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
@ -98,85 +98,89 @@ public class ByteBufferExtractor {
return byteBuffer;
} else {
throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by objects other than Bitmap or"
"Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
+ " Bytebuffer is not supported");
}
}
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
/** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
@AutoValue
abstract static class Result {
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
/**
* Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
public abstract ByteBuffer buffer();
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
@ImageFormat
/**
* Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
@MPImageFormat
public abstract int format();
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
}
}
/**
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
*
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
*
* @return the readonly {@link ByteBuffer} stored in {@link Image}
* @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
* given {@code imageFormat}
*/
static Result extractInRecommendedFormat(Image image) {
ImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
static Result extractInRecommendedFormat(MPImage image) {
MPImageContainer container;
if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
@ImageFormat int format = adviseImageFormat(bitmap);
@MPImageFormat int format = adviseImageFormat(bitmap);
Result result =
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
boolean unused =
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
return result;
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return Result.create(
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
byteBufferImageContainer.getImageFormat());
} else {
throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
"Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
+ " is not supported");
}
}
@ImageFormat
@MPImageFormat
private static int adviseImageFormat(Bitmap bitmap) {
if (bitmap.getConfig() == Config.ARGB_8888) {
return Image.IMAGE_FORMAT_RGBA;
return MPImage.IMAGE_FORMAT_RGBA;
} else {
throw new IllegalArgumentException(
String.format(
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
"Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
+ " supported",
bitmap.getConfig()));
}
}
private static ByteBuffer extractByteBufferFromBitmap(
Bitmap bitmap, @ImageFormat int imageFormat) {
Bitmap bitmap, @MPImageFormat int imageFormat) {
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
"Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
+ " supported");
}
if (bitmap.getConfig() == Config.ARGB_8888) {
if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
bitmap.copyPixelsToBuffer(buffer);
buffer.rewind();
return buffer;
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
} else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
int w = bitmap.getWidth();
int h = bitmap.getHeight();
@ -196,14 +200,14 @@ public class ByteBufferExtractor {
}
throw new IllegalArgumentException(
String.format(
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
"Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
+ " %d is not supported",
bitmap.getConfig(), imageFormat));
}
private static ByteBuffer convertByteBuffer(
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the
// array reversely to convert in-place.
@ -221,7 +225,8 @@ public class ByteBufferExtractor {
target.put(array, 0, target.capacity());
target.rewind();
return target;
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
} else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
&& targetFormat == MPImage.IMAGE_FORMAT_RGB) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
// array to convert in-place.

View File

@ -15,11 +15,11 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
/**
* Builds a {@link Image} from a {@link ByteBuffer}.
* Builds a {@link MPImage} from a {@link ByteBuffer}.
*
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
private final ByteBuffer buffer;
private final int width;
private final int height;
@ImageFormat private final int imageFormat;
@MPImageFormat private final int imageFormat;
// Optional fields.
private long timestamp;
@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
* @param imageFormat how the data encode the image.
*/
public ByteBufferImageBuilder(
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
this.buffer = byteBuffer;
this.width = width;
this.height = height;
@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
this.timestamp = 0;
}
/** Sets value for {@link Image#getTimestamp()}. */
/** Sets value for {@link MPImage#getTimestamp()}. */
ByteBufferImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
/** Builds a {@link MPImage} instance. */
public MPImage build() {
return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
}
}

View File

@ -15,21 +15,19 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
class ByteBufferImageContainer implements ImageContainer {
class ByteBufferImageContainer implements MPImageContainer {
private final ByteBuffer buffer;
private final ImageProperties properties;
private final MPImageProperties properties;
public ByteBufferImageContainer(
ByteBuffer buffer,
@ImageFormat int imageFormat) {
public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
this.buffer = buffer;
this.properties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
MPImageProperties.builder()
.setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(imageFormat)
.build();
}
@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
}
@Override
public ImageProperties getImageProperties() {
public MPImageProperties getImageProperties() {
return properties;
}
/**
* Returns the image format.
*/
@ImageFormat
/** Returns the image format. */
@MPImageFormat
public int getImageFormat() {
return properties.getImageFormat();
}

View File

@ -29,10 +29,10 @@ import java.util.Map.Entry;
/**
* The wrapper class for image objects.
*
* <p>{@link Image} is designed to be an immutable image container, which could be shared
* <p>{@link MPImage} is designed to be an immutable image container, which could be shared
* cross-platforms.
*
* <p>To construct an {@link Image}, use the provided builders:
* <p>To construct a {@link MPImage}, use the provided builders:
*
* <ul>
* <li>{@link ByteBufferImageBuilder}
@ -40,7 +40,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageBuilder}
* </ul>
*
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
* <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release
* internal storage earlier, otherwise Java garbage collection will release the storage eventually.
*
@ -53,7 +53,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageExtractor}
* </ul>
*/
public class Image implements Closeable {
public class MPImage implements Closeable {
/** Specifies the image format of an image. */
@IntDef({
@ -69,7 +69,7 @@ public class Image implements Closeable {
IMAGE_FORMAT_JPEG,
})
@Retention(RetentionPolicy.SOURCE)
public @interface ImageFormat {}
public @interface MPImageFormat {}
public static final int IMAGE_FORMAT_UNKNOWN = 0;
public static final int IMAGE_FORMAT_RGBA = 1;
@ -98,14 +98,14 @@ public class Image implements Closeable {
public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
/**
* Returns a list of supported image properties for this {@link Image}.
* Returns a list of supported image properties for this {@link MPImage}.
*
* <p>Currently {@link Image} only support single storage type so the size of return list will
* <p>Currently {@link MPImage} only support single storage type so the size of return list will
* always be 1.
*
* @see ImageProperties
* @see MPImageProperties
*/
public List<ImageProperties> getContainedImageProperties() {
public List<MPImageProperties> getContainedImageProperties() {
return Collections.singletonList(getContainer().getImageProperties());
}
@ -124,7 +124,7 @@ public class Image implements Closeable {
return height;
}
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
/** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
private synchronized void acquire() {
referenceCount += 1;
}
@ -132,7 +132,7 @@ public class Image implements Closeable {
/**
* Removes a reference that was previously acquired or init.
*
* <p>When {@link Image} is created, it has 1 reference count.
* <p>When {@link MPImage} is created, it has 1 reference count.
*
* <p>When the reference count becomes 0, it will release the resource under the hood.
*/
@ -141,24 +141,24 @@ public class Image implements Closeable {
public synchronized void close() {
referenceCount -= 1;
if (referenceCount == 0) {
for (ImageContainer imageContainer : containerMap.values()) {
for (MPImageContainer imageContainer : containerMap.values()) {
imageContainer.close();
}
}
}
/** Advanced API access for {@link Image}. */
/** Advanced API access for {@link MPImage}. */
static final class Internal {
/**
* Acquires a reference on this {@link Image}. This will increase the reference count by 1.
* Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
*
* <p>This method is more useful for image consumer to acquire a reference so image resource
* will not be closed accidentally. As image creator, normal developer doesn't need to call this
* method.
*
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link Image} anymore.
* <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link MPImage} anymore.
*
* @see #close()
*/
@ -166,10 +166,10 @@ public class Image implements Closeable {
image.acquire();
}
private final Image image;
private final MPImage image;
// Only Image creates the internal helper.
private Internal(Image image) {
// Only MPImage creates the internal helper.
private Internal(MPImage image) {
this.image = image;
}
}
@ -179,15 +179,15 @@ public class Image implements Closeable {
return new Internal(this);
}
private final Map<ImageProperties, ImageContainer> containerMap;
private final Map<MPImageProperties, MPImageContainer> containerMap;
private final long timestamp;
private final int width;
private final int height;
private int referenceCount;
/** Constructs an {@link Image} with a built container. */
Image(ImageContainer container, long timestamp, int width, int height) {
/** Constructs a {@link MPImage} with a built container. */
MPImage(MPImageContainer container, long timestamp, int width, int height) {
this.containerMap = new HashMap<>();
containerMap.put(container.getImageProperties(), container);
this.timestamp = timestamp;
@ -201,10 +201,10 @@ public class Image implements Closeable {
*
* @return the current container.
*/
ImageContainer getContainer() {
MPImageContainer getContainer() {
// According to the design, in the future we will support multiple containers in one image.
// Currently just return the original container.
// TODO: Cache multiple containers in Image.
// TODO: Cache multiple containers in MPImage.
return containerMap.values().iterator().next();
}
@ -214,8 +214,8 @@ public class Image implements Closeable {
* <p>If there are multiple containers with required {@code storageType}, returns the first one.
*/
@Nullable
ImageContainer getContainer(@StorageType int storageType) {
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
MPImageContainer getContainer(@StorageType int storageType) {
for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
if (entry.getKey().getStorageType() == storageType) {
return entry.getValue();
}
@ -225,13 +225,13 @@ public class Image implements Closeable {
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
@Nullable
ImageContainer getContainer(ImageProperties imageProperties) {
MPImageContainer getContainer(MPImageProperties imageProperties) {
return containerMap.get(imageProperties);
}
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
boolean addContainer(ImageContainer container) {
ImageProperties imageProperties = container.getImageProperties();
boolean addContainer(MPImageContainer container) {
MPImageProperties imageProperties = container.getImageProperties();
if (containerMap.containsKey(imageProperties)) {
return false;
}

View File

@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that can receive {@link Image} */
public interface ImageConsumer {
/** Lightweight abstraction for an object that can receive {@link MPImage} */
public interface MPImageConsumer {
/**
* Called when an {@link Image} is available.
* Called when a {@link MPImage} is available.
*
* <p>The argument is only guaranteed to be available until this method returns. if you need to
* extend its life time, acquire it, then release it when done.
*/
void onNewImage(Image image);
void onNewMPImage(MPImage image);
}

View File

@ -16,9 +16,9 @@ limitations under the License.
package com.google.mediapipe.framework.image;
/** Manages internal image data storage. The interface is package-private. */
interface ImageContainer {
interface MPImageContainer {
/** Returns the properties of the contained image. */
ImageProperties getImageProperties();
MPImageProperties getImageProperties();
/** Close the image container and releases the image resource inside. */
void close();

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that produce {@link Image} */
public interface ImageProducer {
/** Lightweight abstraction for an object that produce {@link MPImage} */
public interface MPImageProducer {
/** Sets the consumer that receives the {@link Image}. */
void setImageConsumer(ImageConsumer imageConsumer);
/** Sets the consumer that receives the {@link MPImage}. */
void setMPImageConsumer(MPImageConsumer imageConsumer);
}

View File

@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;
import com.google.auto.value.AutoValue;
import com.google.auto.value.extension.memoized.Memoized;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.Image.StorageType;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import com.google.mediapipe.framework.image.MPImage.StorageType;
/** Groups a set of properties to describe how an image is stored. */
@AutoValue
public abstract class ImageProperties {
public abstract class MPImageProperties {
/**
* Gets the pixel format of the image.
*
* @see Image.ImageFormat
* @see MPImage.MPImageFormat
*/
@ImageFormat
@MPImageFormat
public abstract int getImageFormat();
/**
* Gets the storage type of the image.
*
* @see Image.StorageType
* @see MPImage.StorageType
*/
@StorageType
public abstract int getStorageType();
@ -45,36 +45,36 @@ public abstract class ImageProperties {
public abstract int hashCode();
/**
* Creates a builder of {@link ImageProperties}.
* Creates a builder of {@link MPImageProperties}.
*
* @see ImageProperties.Builder
* @see MPImageProperties.Builder
*/
static Builder builder() {
return new AutoValue_ImageProperties.Builder();
return new AutoValue_MPImageProperties.Builder();
}
/** Builds a {@link ImageProperties}. */
/** Builds a {@link MPImageProperties}. */
@AutoValue.Builder
abstract static class Builder {
/**
* Sets the {@link Image.ImageFormat}.
* Sets the {@link MPImage.MPImageFormat}.
*
* @see ImageProperties#getImageFormat
* @see MPImageProperties#getImageFormat
*/
abstract Builder setImageFormat(@ImageFormat int value);
abstract Builder setImageFormat(@MPImageFormat int value);
/**
* Sets the {@link Image.StorageType}.
* Sets the {@link MPImage.StorageType}.
*
* @see ImageProperties#getStorageType
* @see MPImageProperties#getStorageType
*/
abstract Builder setStorageType(@StorageType int value);
/** Builds the {@link ImageProperties}. */
abstract ImageProperties build();
/** Builds the {@link MPImageProperties}. */
abstract MPImageProperties build();
}
// Hide the constructor.
ImageProperties() {}
MPImageProperties() {}
}

View File

@ -15,11 +15,12 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
/**
* Builds {@link Image} from {@link android.media.Image}.
* Builds {@link MPImage} from {@link android.media.Image}.
*
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
* content in it.
@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
public class MediaImageBuilder {
// Mandatory fields.
private final android.media.Image mediaImage;
private final Image mediaImage;
// Optional fields.
private long timestamp;
@ -40,20 +41,20 @@ public class MediaImageBuilder {
*
* @param mediaImage image data object.
*/
public MediaImageBuilder(android.media.Image mediaImage) {
public MediaImageBuilder(Image mediaImage) {
this.mediaImage = mediaImage;
this.timestamp = 0;
}
/** Sets value for {@link Image#getTimestamp()}. */
/** Sets value for {@link MPImage#getTimestamp()}. */
MediaImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(
/** Builds a {@link MPImage} instance. */
public MPImage build() {
return new MPImage(
new MediaImageContainer(mediaImage),
timestamp,
mediaImage.getWidth(),

View File

@ -15,33 +15,34 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
@RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer {
class MediaImageContainer implements MPImageContainer {
private final android.media.Image mediaImage;
private final ImageProperties properties;
private final Image mediaImage;
private final MPImageProperties properties;
public MediaImageContainer(android.media.Image mediaImage) {
public MediaImageContainer(Image mediaImage) {
this.mediaImage = mediaImage;
this.properties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
MPImageProperties.builder()
.setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
.setImageFormat(convertFormatCode(mediaImage.getFormat()))
.build();
}
public android.media.Image getImage() {
public Image getImage() {
return mediaImage;
}
@Override
public ImageProperties getImageProperties() {
public MPImageProperties getImageProperties() {
return properties;
}
@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
mediaImage.close();
}
@ImageFormat
@MPImageFormat
static int convertFormatCode(int graphicsFormat) {
// We only cover the format mentioned in
// https://developer.android.com/reference/android/media/Image#getFormat()
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
return Image.IMAGE_FORMAT_RGBA;
return MPImage.IMAGE_FORMAT_RGBA;
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
return Image.IMAGE_FORMAT_RGB;
return MPImage.IMAGE_FORMAT_RGB;
}
}
switch (graphicsFormat) {
case android.graphics.ImageFormat.JPEG:
return Image.IMAGE_FORMAT_JPEG;
return MPImage.IMAGE_FORMAT_JPEG;
case android.graphics.ImageFormat.YUV_420_888:
return Image.IMAGE_FORMAT_YUV_420_888;
return MPImage.IMAGE_FORMAT_YUV_420_888;
default:
return Image.IMAGE_FORMAT_UNKNOWN;
return MPImage.IMAGE_FORMAT_UNKNOWN;
}
}
}

View File

@ -15,13 +15,14 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
/**
* Utility for extracting {@link android.media.Image} from {@link Image}.
* Utility for extracting {@link android.media.Image} from {@link MPImage}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
* otherwise {@link IllegalArgumentException} will be thrown.
*/
@RequiresApi(VERSION_CODES.KITKAT)
@ -30,20 +31,20 @@ public class MediaImageExtractor {
private MediaImageExtractor() {}
/**
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
* {@link Image} that built from {@link MediaImageBuilder}.
* Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
* {@link MPImage} that built from {@link MediaImageBuilder}.
*
* @param image the image to extract {@link android.media.Image} from.
* @return {@link android.media.Image} that stored in {@link Image}.
* @return {@link android.media.Image} that stored in {@link MPImage}.
* @throws IllegalArgumentException if the extraction failed.
*/
public static android.media.Image extract(Image image) {
ImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
public static Image extract(MPImage image) {
MPImageContainer container;
if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
return ((MediaImageContainer) container).getImage();
}
throw new IllegalArgumentException(
"Extract Media Image from an Image created by objects other than Media Image"
"Extract Media Image from a MPImage created by objects other than Media Image"
+ " is not supported");
}
}

View File

@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors.
# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""):
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""):
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
))
return proto_src_list

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],

View File

@ -13,6 +13,7 @@
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
@ -23,15 +24,12 @@ package(
py_library(
name = "data_util",
srcs = ["data_util.py"],
srcs_version = "PY3",
)
py_test(
name = "data_util_test",
srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":data_util"],
)
@ -44,8 +42,6 @@ py_library(
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":dataset",
"//mediapipe/model_maker/python/core/utils:test_util",
@ -55,14 +51,11 @@ py_test(
py_library(
name = "classification_dataset",
srcs = ["classification_dataset.py"],
srcs_version = "PY3",
deps = [":dataset"],
)
py_test(
name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":classification_dataset"],
)

View File

@ -13,7 +13,7 @@
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
from typing import List, Tuple
import tensorflow as tf
@ -21,15 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models."""
"""Dataset Loader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str]):
super().__init__(dataset, size)
self.index_to_label = index_to_label
self._label_names = label_names
@property
def num_classes(self: ds._DatasetT) -> int:
return len(self.index_to_label)
return len(self._label_names)
@property
def label_names(self: ds._DatasetT) -> List[str]:
return self._label_names
def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@ -44,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
Returns:
The splitted two sub datasets.
"""
return self._split(fraction, self.index_to_label)
return self._split(fraction, self._label_names)

View File

@ -12,45 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
_DatasetT = TypeVar(
'_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
class ClassificationDataLoaderTest(tf.test.TestCase):
class ClassificationDatasetTest(tf.test.TestCase):
def test_split(self):
class MagicClassificationDataLoader(
class MagicClassificationDataset(
classification_dataset.ClassificationDataset):
"""A mock classification dataset class for testing purpose.
def __init__(self, dataset, size, index_to_label, value):
super(MagicClassificationDataLoader,
self).__init__(dataset, size, index_to_label)
Attributes:
value: A value variable stored by the mock dataset class for testing.
"""
def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str], value: Any):
super().__init__(dataset=dataset, size=size, label_names=label_names)
self.value = value
def split(self, fraction):
return self._split(fraction, self.index_to_label, self.value)
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.label_names, self.value)
# Some dummy inputs.
magic_value = 42
num_classes = 2
index_to_label = (False, True)
label_names = ['foo', 'bar']
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
magic_value)
data = MagicClassificationDataset(
dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
# Train/Test data split.
fraction = .25
train_data, test_data = data.split(fraction)
train_data, test_data = data.split(fraction=fraction)
# `split` should return instances of child DataLoader.
self.assertIsInstance(train_data, MagicClassificationDataLoader)
self.assertIsInstance(test_data, MagicClassificationDataLoader)
self.assertIsInstance(train_data, MagicClassificationDataset)
self.assertIsInstance(test_data, MagicClassificationDataset)
# Make sure number of entries are right.
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
@ -59,7 +69,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase):
# Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_to_label, index_to_label)
self.assertEqual(test_data.label_names, label_names)
self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
@ -23,7 +24,6 @@ licenses(["notice"])
py_library(
name = "custom_model",
srcs = ["custom_model.py"],
srcs_version = "PY3",
deps = [
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/utils:model_util",
@ -34,8 +34,6 @@ py_library(
py_test(
name = "custom_model_test",
srcs = ["custom_model_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":custom_model",
"//mediapipe/model_maker/python/core/utils:test_util",
@ -45,7 +43,6 @@ py_test(
py_library(
name = "classifier",
srcs = ["classifier.py"],
srcs_version = "PY3",
deps = [
":custom_model",
"//mediapipe/model_maker/python/core/data:dataset",
@ -55,8 +52,6 @@ py_library(
py_test(
name = "classifier_test",
srcs = ["classifier_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":classifier",
"//mediapipe/model_maker/python/core/utils:test_util",

View File

@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
full_train: bool):
"""Initilizes a classifier with its specifications.
Args:
model_spec: Specification for the model.
index_to_label: A list that map from index to label class name.
label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top
classification layers.
"""
super(Classifier, self).__init__(model_spec, shuffle)
self._index_to_label = index_to_label
self._label_names = label_names
self._full_train = full_train
self._num_classes = len(index_to_label)
self._num_classes = len(label_names)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset.
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
label_filepath = os.path.join(export_dir, label_filename)
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self._index_to_label))
f.write('\n'.join(self._label_names))

View File

@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
def setUp(self):
super(ClassifierTest, self).setUp()
index_to_label = ['cat', 'dog']
label_names = ['cat', 'dog']
self.model = MockClassifier(
model_spec=None,
index_to_label=index_to_label,
label_names=label_names,
shuffle=False,
full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)

View File

@ -21,8 +21,6 @@ import abc
import os
from typing import Any, Callable, Optional
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
tflite_filepath = os.path.join(export_dir, tflite_filename)
# TODO: Populate metadata to the exported TFLite model.
model_util.export_tflite(
self._model,
tflite_filepath,
quantization_config,
model=self._model,
tflite_filepath=tflite_filepath,
quantization_config=quantization_config,
preprocess=preprocess)
tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

View File

@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):
def setUp(self):
super(CustomModelTest, self).setUp()
self.model = MockCustomModel(model_spec=None, shuffle=False)
self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
self._model = MockCustomModel(model_spec=None, shuffle=False)
self._model._model = test_util.build_model(input_shape=[4], num_classes=2)
def _check_nonempty_file(self, filepath):
self.assertTrue(os.path.isfile(filepath))
@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
def test_export_tflite(self):
export_path = os.path.join(self.get_temp_dir(), 'export/')
self.model.export_tflite(export_dir=export_path)
self._model.export_tflite(export_dir=export_path)
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
if __name__ == '__main__':

View File

@ -13,6 +13,7 @@
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
@ -24,31 +25,15 @@ py_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.py"],
srcs_version = "PY3",
deps = [
":model_util",
"//mediapipe/model_maker/python/core/data:dataset",
],
)
py_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.py"],
srcs_version = "PY3",
)
py_test(
name = "image_preprocessing_test",
srcs = ["image_preprocessing_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":image_preprocessing"],
)
py_library(
name = "model_util",
srcs = ["model_util.py"],
srcs_version = "PY3",
deps = [
":quantization",
"//mediapipe/model_maker/python/core/data:dataset",
@ -58,8 +43,6 @@ py_library(
py_test(
name = "model_util_test",
srcs = ["model_util_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":model_util",
":quantization",
@ -76,8 +59,6 @@ py_library(
py_test(
name = "loss_functions_test",
srcs = ["loss_functions_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":loss_functions"],
)
@ -91,8 +72,6 @@ py_library(
py_test(
name = "quantization_test",
srcs = ["quantization_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":quantization",
":test_util",

View File

@ -56,7 +56,7 @@ class FocalLoss(tf.keras.losses.Loss):
class_weight: A weight to apply to the loss, one for each class. The
weight is applied for each input where the ground truth label matches.
"""
super(tf.keras.losses.Loss, self).__init__()
super().__init__()
# Used for clipping min/max values of probability values in y_pred to avoid
# NaNs and Infs in computation.
self._epsilon = 1e-7

View File

@ -104,8 +104,8 @@ def export_tflite(
quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature,
label, and is_training.
quantization. The callable takes three arguments in order: feature, label,
and is_training.
"""
if tflite_filepath is None:
raise ValueError(

View File

@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.export_tflite(model, tflite_file)
self._test_tflite(model, tflite_file, input_dim)
test_util.test_tflite(
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
@parameterized.named_parameters(
dict(
@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
input_dim = 16
num_classes = 2
max_input_value = 5
model = test_util.build_model([input_dim], num_classes)
model = test_util.build_model(
input_shape=[input_dim], num_classes=num_classes)
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
model_util.export_tflite(model, tflite_file, config)
self._test_tflite(
model, tflite_file, input_dim, max_input_value, atol=1e-00)
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
def _test_tflite(self,
keras_model: tf.keras.Model,
tflite_model_file: str,
input_dim: int,
max_input_value: int = 1000,
atol: float = 1e-04):
random_input = test_util.create_random_sample(
size=[1, input_dim], high=max_input_value)
random_input = tf.convert_to_tensor(random_input)
model_util.export_tflite(
model=model, tflite_filepath=tflite_file, quantization_config=config)
self.assertTrue(
test_util.is_same_output(
tflite_model_file, keras_model, random_input, atol=atol))
test_util.test_tflite(
keras_model=model,
tflite_file=tflite_file,
size=[1, input_dim],
high=max_input_value,
atol=1e-00))
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
if __name__ == '__main__':

View File

@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
keras_output = keras_model.predict_on_batch(input_tensors)
return np.allclose(lite_output, keras_output, atol=atol)
def test_tflite(keras_model: tf.keras.Model,
tflite_file: str,
size: Union[int, List[int]],
high: float = 1,
atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical.
Args:
keras_model: Input TensorFlow Keras model.
tflite_file: Input TFLite model file.
size: Size of the input tesnor.
high: Higher boundary of the values in input tensors.
atol: Absolute tolerance of the difference between the outputs of Keras
model and TFLite model.
Returns:
True if the output of TFLite model and TF Keras model are identical.
Otherwise, False.
"""
random_input = create_random_sample(size=size, high=high)
random_input = tf.convert_to_tensor(random_input)
return is_same_output(
tflite_file=tflite_file,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)

View File

@ -0,0 +1,4 @@
# MediaPipe Model Maker Internal Library
This directory contains model maker library for internal users and experimental
purposes.

View File

@ -0,0 +1 @@
"""Model maker internal library."""

View File

@ -0,0 +1,33 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.py"],
)
py_test(
name = "image_preprocessing_test",
srcs = ["image_preprocessing_test.py"],
deps = [":image_preprocessing"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -13,11 +13,7 @@
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
IMAGE_SIZE = 224

View File

@ -12,15 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.vision.core import image_preprocessing
def _get_preprocessed_image(preprocessor, is_training=False):

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python library rule.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python library rule.
licenses(["notice"])
@ -78,9 +78,9 @@ py_library(
":train_image_classifier_lib",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:image_preprocessing",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
],
)

View File

@ -16,7 +16,7 @@
import os
import random
from typing import List, Optional, Tuple
from typing import List, Optional
import tensorflow as tf
import tensorflow_datasets as tfds
@ -84,10 +84,10 @@ class Dataset(classification_dataset.ClassificationDataset):
name for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name)))
all_label_size = len(label_names)
label_to_index = dict(
index_by_label = dict(
(name, index) for index, name in enumerate(label_names))
all_image_labels = [
label_to_index[os.path.basename(os.path.dirname(path))]
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
@ -106,33 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names))
return Dataset(
dataset=image_label_ds, size=all_image_size, index_to_label=label_names)
@classmethod
def load_tf_dataset(
cls, name: str
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset]]:
"""Loads data from tensorflow_datasets.
Args:
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
documentation of tfds.load for more details.
Returns:
A tuple of Datasets for the train/validation/test.
Raises:
ValueError: if the input tf dataset does not have train/validation/test
labels.
"""
data, info = tfds.load(name, with_info=True)
if 'label' not in info.features:
raise ValueError('info.features need to contain \'label\' key.')
label_names = info.features['label'].names
train_data = _create_data('train', data, info, label_names)
validation_data = _create_data('validation', data, info, label_names)
test_data = _create_data('test', data, info, label_names)
return train_data, validation_data, test_data
dataset=image_label_ds, size=all_image_size, label_names=label_names)

View File

@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
train_data, test_data = data.split(0.5)
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2)
for i, elem in enumerate(train_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
self.assertEqual(train_data.label_names, ['pos', 'neg'])
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
self.assertEqual(test_data.label_names, ['pos', 'neg'])
def test_from_folder(self):
data = dataset.Dataset.from_folder(self.image_path)
data = dataset.Dataset.from_folder(dirname=self.image_path)
self.assertLen(data, 2)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
self.assertEqual(data.label_names, ['daisy', 'tulips'])
for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0:
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(train_data, 1034)
self.assertEqual(train_data.num_classes, 3)
self.assertEqual(train_data.index_to_label,
self.assertEqual(train_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(validation_data, 133)
self.assertEqual(validation_data.num_classes, 3)
self.assertEqual(validation_data.index_to_label,
self.assertEqual(validation_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(test_data, 128)
self.assertEqual(test_data.num_classes, 3)
self.assertEqual(test_data.index_to_label,
self.assertEqual(test_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])

View File

@ -13,16 +13,16 @@
# limitations under the License.
"""APIs to train image classifier model."""
from typing import Any, List, Optional
from typing import List, Optional
import tensorflow as tf
import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
hparams: hp.HParams):
"""Initializes ImageClassifier class.
Args:
model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name.
label_names: A list of label names for the classes.
hparams: The hyperparameters for training image classifier.
"""
super(ImageClassifier, self).__init__(
super().__init__(
model_spec=model_spec,
index_to_label=index_to_label,
label_names=label_names,
shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning)
self._hparams = hparams
@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
spec = ms.SupportedModels.get(model_spec)
image_classifier = cls(
model_spec=spec,
index_to_label=train_data.index_to_label,
hparams=hparams)
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
image_classifier._create_model()

View File

@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
return model.fit(
x=train_ds,
epochs=hparams.train_epochs,
steps_per_epoch=hparams.steps_per_epoch,
validation_data=validation_ds,
callbacks=callbacks)

View File

@ -161,7 +161,7 @@ class Texture {
~Texture() {
if (is_owned_) {
glDeleteProgram(handle_);
glDeleteTextures(1, &handle_);
}
}

View File

@ -87,6 +87,9 @@ cc_library(
cc_library(
name = "builtin_task_graphs",
deps = [
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
],

View File

@ -14,7 +14,7 @@
"""The public facing packet getter APIs."""
from typing import List, Type
from typing import List
from google.protobuf import message
from google.protobuf import symbol_database
@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
get_matrix = _packet_getter.get_matrix
def get_proto(packet: mp_packet.Packet) -> Type[message.Message]:
def get_proto(packet: mp_packet.Packet) -> message.Message:
"""Get the content of a MediaPipe proto Packet as a proto message.
Args:

View File

@ -46,8 +46,10 @@ cc_library(
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",

View File

@ -44,6 +44,30 @@ cc_library(
alwayslink = 1,
)
cc_test(
name = "classification_aggregation_calculator_test",
srcs = ["classification_aggregation_calculator_test.cc"],
deps = [
":classification_aggregation_calculator",
":classification_aggregation_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:output_stream_poller",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)
mediapipe_proto_library(
name = "score_calibration_calculator_proto",
srcs = ["score_calibration_calculator.proto"],

View File

@ -31,37 +31,62 @@
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
// Aggregates ClassificationLists into a single ClassificationResult that has
// 3 dimensions: (classification head, classification timestamp, classification
// category).
// Aggregates ClassificationLists into either a ClassificationResult object
// representing the classification results aggregated by classifier head, or
// into an std::vector<ClassificationResult> representing the classification
// results aggregated first by timestamp then by classifier head.
//
// Inputs:
// CLASSIFICATIONS - ClassificationList
// CLASSIFICATIONS - ClassificationList @Multiple
// ClassificationList per classification head.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that a single ClassificationResult
// should aggragate. This stream is optional, and the timestamp information
// will only be populated to the ClassificationResult proto when this stream
// is connected.
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
// output is used for results. Otherwise as no timestamp aggregation is
// required the CLASSIFICATIONS output is used for results.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// Example:
// Example without timestamp aggregation:
// node {
// calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c"
// output_stream: "CLASSIFICATIONS:classifications"
// options {
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a"
// head_names: "head_name_b"
// head_names: "head_name_c"
// }
// }
// }
//
// Example with timestamp aggregation:
// node {
// calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c"
// input_stream: "TIMESTAMPS:timestamps"
// output_stream: "CLASSIFICATION_RESULT:classification_result"
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
// options {
// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] {
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a"
// head_names: "head_name_b"
// head_names: "head_name_c"
@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
"CLASSIFICATIONS"};
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
"TIMESTAMPS"};
static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut);
static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
"CLASSIFICATIONS"};
static constexpr Output<std::vector<ClassificationResult>>::Optional
kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
static constexpr Output<ClassificationResult>::Optional
kClassificationResultOut{"CLASSIFICATION_RESULT"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
kClassificationsOut, kTimestampedClassificationsOut,
kClassificationResultOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc);
@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
cached_classifications_;
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
CalculatorContext* cc);
// TODO: deprecate this function once migration is over.
ClassificationResult LegacyConvertToClassificationResult(
CalculatorContext* cc);
};
absl::Status ClassificationAggregationCalculator::UpdateContract(
@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
<< "The size of classifications input streams should match the "
"size of head names specified in the calculator options";
}
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
// not connected. All dependent tasks must be updated to use these outputs
// first.
return absl::OkStatus();
}
@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
[](const auto& elem) -> ClassificationList { return elem.Get(); });
cached_classifications_[cc->InputTimestamp().Value()] =
std::move(classification_lists);
if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) {
ClassificationResult classification_result;
if (time_aggregation_enabled_) {
if (kTimestampsIn(cc).IsEmpty()) {
return absl::OkStatus();
}
kOut(cc).Send(ConvertToClassificationResult(cc));
classification_result = LegacyConvertToClassificationResult(cc);
kTimestampedClassificationsOut(cc).Send(
ConvertToTimestampedClassificationResults(cc));
} else {
classification_result = LegacyConvertToClassificationResult(cc);
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
}
kClassificationResultOut(cc).Send(classification_result);
RET_CHECK(cached_classifications_.empty());
return absl::OkStatus();
}
@ -136,6 +186,50 @@ ClassificationResult
ClassificationAggregationCalculator::ConvertToClassificationResult(
CalculatorContext* cc) {
ClassificationResult result;
auto& classification_lists =
cached_classifications_[cc->InputTimestamp().Value()];
for (int i = 0; i < classification_lists.size(); ++i) {
auto classifications = result.add_classifications();
classifications->set_head_index(i);
if (!head_names_.empty()) {
classifications->set_head_name(head_names_[i]);
}
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
cached_classifications_.erase(cc->InputTimestamp().Value());
return result;
}
std::vector<ClassificationResult>
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
CalculatorContext* cc) {
auto timestamps = kTimestampsIn(cc).Get();
std::vector<ClassificationResult> results;
results.reserve(timestamps.size());
for (const auto& timestamp : timestamps) {
ClassificationResult result;
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000);
auto& classification_lists = cached_classifications_[timestamp.Value()];
for (int i = 0; i < classification_lists.size(); ++i) {
auto classifications = result.add_classifications();
classifications->set_head_index(i);
if (!head_names_.empty()) {
classifications->set_head_name(head_names_[i]);
}
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
cached_classifications_.erase(timestamp.Value());
results.push_back(std::move(result));
}
return results;
}
ClassificationResult
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
CalculatorContext* cc) {
ClassificationResult result;
Timestamp first_timestamp(0);
std::vector<Timestamp> timestamps;
if (time_aggregation_enabled_) {
@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
1000);
}
cached_classifications_.erase(timestamp.Value());
}
return result;
}

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks;
package mediapipe;
import "mediapipe/framework/calculator.proto";

View File

@ -0,0 +1,213 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace {
using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::Pointwise;
constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
constexpr char kClassificationInput0Name[] = "classifications_0";
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
constexpr char kClassificationInput1Name[] = "classifications_1";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications";
ClassificationList MakeClassificationList(int class_index) {
return ParseTextProtoOrDie<ClassificationList>(absl::StrFormat(
R"pb(
classification { index: %d }
)pb",
class_index));
}
class ClassificationAggregationCalculatorTest
: public tflite_shims::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
bool connect_timestamps = false) {
Graph graph;
auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
calculator
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
ParseTextProtoOrDie<
mediapipe::ClassificationAggregationCalculatorOptions>(
R"pb(head_names: "foo" head_names: "bar")pb");
graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
kClassificationInput0Name) >>
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
kClassificationInput1Name) >>
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));
if (connect_timestamps) {
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
kTimestampsName) >>
calculator.In(kTimestampsTag);
calculator.Out(kTimestampedClassificationsTag)
.SetName(kTimestampedClassificationsName) >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
} else {
calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
graph[Output<ClassificationResult>(kClassificationsTag)];
}
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
if (connect_timestamps) {
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kTimestampedClassificationsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kClassificationsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
absl::Status Send(
std::vector<ClassificationList> classifications, int timestamp = 0,
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kClassificationInput0Name,
MakePacket<ClassificationList>(classifications[0])
.At(Timestamp(timestamp))));
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kClassificationInput1Name,
MakePacket<ClassificationList>(classifications[1])
.At(Timestamp(timestamp))));
if (aggregation_timestamps.has_value()) {
auto packet = std::make_unique<std::vector<Timestamp>>();
for (const auto& timestamp : *aggregation_timestamps) {
packet->emplace_back(Timestamp(timestamp));
}
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
}
return absl::OkStatus();
}
template <typename T>
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
Packet packet;
if (!poller.Next(&packet)) {
return absl::InternalError("Unable to get output packet");
}
auto result = packet.Get<T>();
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
return result;
}
private:
CalculatorGraph calculator_graph_;
};
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 0 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 1 } }
})pb")));
}
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK(Send(
{MakeClassificationList(2), MakeClassificationList(3)},
/*timestamp=*/1000,
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
MP_ASSERT_OK_AND_ASSIGN(auto result,
GetResult<std::vector<ClassificationResult>>(poller));
EXPECT_THAT(result,
Pointwise(EqualsProto(),
{ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0,
classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 0 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 1 } }
}
)pb"),
ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 1,
classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 2 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 3 } }
}
)pb")}));
}
} // namespace
} // namespace mediapipe

View File

@ -29,3 +29,23 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library(
name = "category",
srcs = ["category.cc"],
hdrs = ["category.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
],
)
cc_library(
name = "classification_result",
srcs = ["classification_result.cc"],
hdrs = ["classification_result.h"],
deps = [
":category",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/category.h"
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
Category ConvertToCategory(const mediapipe::Classification& proto) {
Category category;
category.index = proto.index();
category.score = proto.score();
if (proto.has_label()) {
category.category_name = proto.label();
}
if (proto.has_display_name()) {
category.display_name = proto.display_name();
}
return category;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
// The index of the category in the classification model output.
int index;
// The score for this category, e.g. (but not necessarily) a probability in
// [0,1].
float score;
// The optional ID for the category, read from the label map packed in the
// TFLite Model Metadata if present. Not necessarily human-readable.
std::optional<std::string> category_name = std::nullopt;
// The optional human-readable name for the category, read from the label map
// packed in the TFLite Model Metadata if present.
std::optional<std::string> display_name = std::nullopt;
};
// Utility function to convert from mediapipe::Classification proto to Category
// struct.
Category ConvertToCategory(const mediapipe::Classification& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
Classifications ConvertToClassifications(const proto::Classifications& proto) {
Classifications classifications;
classifications.categories.reserve(
proto.classification_list().classification_size());
for (const auto& classification :
proto.classification_list().classification()) {
classifications.categories.push_back(ConvertToCategory(classification));
}
classifications.head_index = proto.head_index();
if (proto.has_head_name()) {
classifications.head_name = proto.head_name();
}
return classifications;
}
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto) {
ClassificationResult classification_result;
classification_result.classifications.reserve(proto.classifications_size());
for (const auto& classifications : proto.classifications()) {
classification_result.classifications.push_back(
ConvertToClassifications(classifications));
}
if (proto.has_timestamp_ms()) {
classification_result.timestamp_ms = proto.timestamp_ms();
}
return classification_result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,68 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines classification results for a given classifier head.
struct Classifications {
// The array of predicted categories, usually sorted by descending scores,
// e.g. from high to low probability.
std::vector<Category> categories;
// The index of the classifier head (i.e. output tensor) these categories
// refer to. This is useful for multi-head models.
int head_index;
// The optional name of the classifier head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
std::optional<std::string> head_name = std::nullopt;
};
// Defines classification results of a model.
struct ClassificationResult {
// The classification results for each head of the model.
std::vector<Classifications> classifications;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
std::optional<int64_t> timestamp_ms = std::nullopt;
};
// Utility function to convert from Classifications proto to
// Classifications struct.
Classifications ConvertToClassifications(const proto::Classifications& proto);
// Utility function to convert from ClassificationResult proto to
// ClassificationResult struct.
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

View File

@ -28,6 +28,7 @@ mediapipe_proto_library(
srcs = ["classifications.proto"],
deps = [
":category_proto",
"//mediapipe/framework/formats:classification_proto",
],
)

View File

@ -17,9 +17,10 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto";
// TODO: deprecate this message once migration is over.
// A single classification result.
message Category {
// The index of the category in the corresponding label map, usually packed in

View File

@ -17,11 +17,13 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto";
// TODO: deprecate this message once migration is over.
// List of predicted categories with an optional timestamp.
message ClassificationEntry {
// The array of predicted categories, usually sorted by descending scores,
@ -33,9 +35,12 @@ message ClassificationEntry {
optional int64 timestamp_ms = 2;
}
// Classifications for a given classifier head.
// Classifications for a given classifier head, i.e. for a given output tensor.
message Classifications {
// TODO: deprecate this field once migration is over.
repeated ClassificationEntry entries = 1;
// The classification results for this head.
optional mediapipe.ClassificationList classification_list = 4;
// The index of the classifier head these categories refer to. This is useful
// for multi-head models.
optional int32 head_index = 2;
@ -45,7 +50,17 @@ message Classifications {
optional string head_name = 3;
}
// Contains one set of results per classifier head.
// Classifications for a given classifier model.
message ClassificationResult {
// The classification results for each model head, i.e. one for each output
// tensor.
repeated Classifications classifications = 1;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
optional int64 timestamp_ms = 2;
}

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "EmbeddingsProto";
// Defines a dense floating-point embedding.
message FloatEmbedding {
repeated float values = 1 [packed = true];

View File

@ -30,9 +30,11 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -128,12 +130,21 @@ absl::Status ConfigureImageToTensorCalculator(
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
std);
}
// TODO: need to support different GPU origin on differnt
// platforms or applications.
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
return absl::OkStatus();
}
} // namespace
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration) {
return acceleration.has_gpu();
}
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
bool use_gpu,
ImagePreprocessingOptions* options) {
ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildImageTensorSpecs(model_resources));
@ -141,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
image_tensor_specs, options->mutable_image_to_tensor_options()));
// The GPU backend isn't able to process int data. If the input tensor is
// quantized, forces the image preprocessing graph to use CPU backend.
if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) {
if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
} else {
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
}
return absl::OkStatus();

View File

@ -19,20 +19,26 @@ limitations under the License.
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
// Configures an ImagePreprocessing subgraph using the provided model resources.
// Configures an ImagePreprocessing subgraph using the provided model resources
// When use_gpu is true, use GPU as backend to convert image to tensor.
// - Accepts CPU input images and outputs CPU tensors.
//
// Example usage:
//
// auto& preprocessing =
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
// core::proto::Acceleration acceleration;
// acceleration.mutable_xnnpack();
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
// model_resources,
// use_gpu,
// &preprocessing.GetOptions<ImagePreprocessingOptions>()));
//
// The resulting ImagePreprocessing subgraph has the following I/O:
@ -56,9 +62,14 @@ namespace components {
// The image that has the pixel data stored on the target storage (CPU vs
// GPU).
absl::Status ConfigureImagePreprocessing(
const core::ModelResources& model_resources,
const core::ModelResources& model_resources, bool use_gpu,
ImagePreprocessingOptions* options);
// Determine if the image preprocessing subgraph should use GPU as the backend
// according to the given acceleration setting.
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration);
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
// Struct holding the different output streams produced by the graph.
struct ClassificationPostprocessingOutputStreams {
Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
// Performs sanity checks on provided ClassifierOptions.
absl::Status SanityCheckClassifierOptions(
@ -286,7 +294,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
void ConfigureClassificationAggregationCalculator(
const ModelMetadataExtractor& metadata_extractor,
ClassificationAggregationCalculatorOptions* options) {
mediapipe::ClassificationAggregationCalculatorOptions* options) {
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
if (output_tensors_metadata == nullptr) {
return;
@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph(
// TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of timestamps that a single ClassificationResult should
// aggregate. This is mostly useful for classifiers working on time series,
// e.g. audio or video classification.
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
// output is used for results. Otherwise as no timestamp aggregation is
// required the CLASSIFICATIONS output is used for results.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results.
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
mediapipe::SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto classification_result_out,
auto output_streams,
BuildClassificationPostprocessing(
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)],
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
classification_result_out >>
output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
return graph.GetConfig();
}
private:
// Adds an on-device classification postprocessing graph into the provided
// builder::Graph instance. The classification postprocessing graph takes
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
// stream containing the output classification results (ClassificationResult).
// tensors (std::vector<mediapipe::Tensor>) and optional timestamps
// (std::vector<Timestamp>) as input and returns two output streams:
// - classification results aggregated by classifier head as a
// ClassificationResult proto, used when no timestamps are passed in
// the graph,
// - classification results aggregated by timestamp then by classifier head
// as a std::vector<ClassificationResult>, used when timestamps are passed
// in the graph.
//
// options: the on-device ClassificationPostprocessingGraphOptions.
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that a single ClassificationResult should aggregate.
// timestamps that should be used to aggregate classification results.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>>
absl::StatusOr<ClassificationPostprocessingOutputStreams>
BuildClassificationPostprocessing(
const proto::ClassificationPostprocessingGraphOptions& options,
Source<std::vector<Tensor>> tensors_in,
@ -494,7 +524,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
// Aggregates Classifications into a single ClassificationResult.
auto& result_aggregation =
graph.AddNode("ClassificationAggregationCalculator");
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>()
result_aggregation
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
.CopyFrom(options.classification_aggregation_options());
for (int i = 0; i < num_heads; ++i) {
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
@ -504,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
timestamps_in >> result_aggregation.In(kTimestampsTag);
// Connects output.
return result_aggregation[Output<ClassificationResult>(
kClassificationResultTag)];
ClassificationPostprocessingOutputStreams output_streams{
/*classification_result=*/result_aggregation
[Output<ClassificationResult>(kClassificationResultTag)],
/*classifications=*/
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
/*timestamped_classifications=*/
result_aggregation[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)]};
return output_streams;
}
};

Some files were not shown because too many files have changed in this diff Show More