Merge branch 'master' into image-embedder-python
This commit is contained in:
commit
5a68ba84b6
|
@ -30,6 +30,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
git \
|
||||
wget \
|
||||
unzip \
|
||||
nodejs \
|
||||
npm \
|
||||
python3-dev \
|
||||
python3-opencv \
|
||||
python3-pip \
|
||||
|
|
|
@ -172,6 +172,10 @@ http_archive(
|
|||
urls = [
|
||||
"https://github.com/google/sentencepiece/archive/1.0.0.zip",
|
||||
],
|
||||
patches = [
|
||||
"//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
|
||||
],
|
||||
patch_args = ["-p1"],
|
||||
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
|
||||
)
|
||||
|
||||
|
|
14
docs/BUILD
Normal file
14
docs/BUILD
Normal file
|
@ -0,0 +1,14 @@
|
|||
# Placeholder for internal Python strict binary compatibility macro.
|
||||
|
||||
py_binary(
|
||||
name = "build_py_api_docs",
|
||||
srcs = ["build_py_api_docs.py"],
|
||||
deps = [
|
||||
"//mediapipe",
|
||||
"//third_party/py/absl:app",
|
||||
"//third_party/py/absl/flags",
|
||||
"//third_party/py/tensorflow_docs",
|
||||
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
|
||||
"//third_party/py/tensorflow_docs/api_generator:public_api",
|
||||
],
|
||||
)
|
85
docs/build_py_api_docs.py
Normal file
85
docs/build_py_api_docs.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""MediaPipe reference docs generation script.
|
||||
|
||||
This script generates API reference docs for the `mediapipe` PIP package.
|
||||
|
||||
$> pip install -U git+https://github.com/tensorflow/docs mediapipe
|
||||
$> python build_py_api_docs.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from tensorflow_docs.api_generator import generate_lib
|
||||
from tensorflow_docs.api_generator import public_api
|
||||
|
||||
try:
|
||||
# mediapipe has not been set up to work with bazel yet, so catch & report.
|
||||
import mediapipe # pytype: disable=import-error
|
||||
except ImportError as e:
|
||||
raise ImportError('Please `pip install mediapipe`.') from e
|
||||
|
||||
|
||||
PROJECT_SHORT_NAME = 'mp'
|
||||
PROJECT_FULL_NAME = 'MediaPipe'
|
||||
|
||||
_OUTPUT_DIR = flags.DEFINE_string(
|
||||
'output_dir',
|
||||
default='/tmp/generated_docs',
|
||||
help='Where to write the resulting docs.')
|
||||
|
||||
_URL_PREFIX = flags.DEFINE_string(
|
||||
'code_url_prefix',
|
||||
'https://github.com/google/mediapipe/tree/master/mediapipe',
|
||||
'The url prefix for links to code.')
|
||||
|
||||
_SEARCH_HINTS = flags.DEFINE_bool(
|
||||
'search_hints', True,
|
||||
'Include metadata search hints in the generated files')
|
||||
|
||||
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
|
||||
'Path prefix in the _toc.yaml')
|
||||
|
||||
|
||||
def gen_api_docs():
|
||||
"""Generates API docs for the mediapipe package."""
|
||||
|
||||
doc_generator = generate_lib.DocGenerator(
|
||||
root_title=PROJECT_FULL_NAME,
|
||||
py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
|
||||
base_dir=os.path.dirname(mediapipe.__file__),
|
||||
code_url_prefix=_URL_PREFIX.value,
|
||||
search_hints=_SEARCH_HINTS.value,
|
||||
site_path=_SITE_PATH.value,
|
||||
# This callback ensures that docs are only generated for objects that
|
||||
# are explicitly imported in your __init__.py files. There are other
|
||||
# options but this is a good starting point.
|
||||
callbacks=[public_api.explicit_package_contents_filter],
|
||||
)
|
||||
|
||||
doc_generator.build(_OUTPUT_DIR.value)
|
||||
|
||||
print('Docs output to:', _OUTPUT_DIR.value)
|
||||
|
||||
|
||||
def main(_):
|
||||
gen_api_docs()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
|
@ -222,10 +222,10 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:collection_item_id",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
@ -328,6 +328,7 @@ cc_library(
|
|||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -344,6 +345,7 @@ cc_test(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
|
|
|
@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
|
|||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||
options {
|
||||
[mediapipe.SwitchContainerOptions.ext] {
|
||||
async_selection: true
|
||||
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
|
||||
}
|
||||
}
|
||||
|
@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
|
|||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||
options {
|
||||
[mediapipe.SwitchContainerOptions.ext] {
|
||||
async_selection: true
|
||||
contained_node: {
|
||||
calculator: "BypassCalculator"
|
||||
node_options: {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
|
|||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
|
||||
|
||||
class ConcatenateClassificationListCalculator
|
||||
: public ConcatenateListsCalculator<Classification, ClassificationList> {
|
||||
protected:
|
||||
int ListSize(const ClassificationList& list) const override {
|
||||
return list.classification_size();
|
||||
}
|
||||
const Classification GetItem(const ClassificationList& list,
|
||||
int idx) const override {
|
||||
return list.classification(idx);
|
||||
}
|
||||
Classification* AddItem(ClassificationList& list) const override {
|
||||
return list.add_classification();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -70,6 +71,16 @@ void AddInputLandmarkLists(
|
|||
}
|
||||
}
|
||||
|
||||
void AddInputClassificationLists(
|
||||
const std::vector<ClassificationList>& input_classifications_vec,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < input_classifications_vec.size(); ++i) {
|
||||
runner->MutableInputs()->Index(i).packets.push_back(
|
||||
MakePacket<ClassificationList>(input_classifications_vec[i])
|
||||
.At(Timestamp(timestamp)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
|
@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
|
|||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("ConcatenateClassificationListCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||
)pb");
|
||||
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||
)pb");
|
||||
std::vector<ClassificationList> inputs = {input_0, input_1};
|
||||
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
auto result = outputs[0].Get<ClassificationList>();
|
||||
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||
)pb"),
|
||||
EqualsProto(result));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
@ -58,4 +59,7 @@ typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
|
|||
EndLoopDetectionCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopDetectionCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopMatrixCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -50,7 +50,7 @@ namespace mediapipe {
|
|||
// calculator: "EndLoopWithOutputCalculator"
|
||||
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
|
||||
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
|
||||
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
|
||||
// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts
|
||||
// }
|
||||
template <typename IterableT>
|
||||
class EndLoopCalculator : public CalculatorBase {
|
||||
|
|
|
@ -109,6 +109,56 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "tensors_to_audio_calculator_proto",
|
||||
srcs = ["tensors_to_audio_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensors_to_audio_calculator",
|
||||
srcs = ["tensors_to_audio_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":tensors_to_audio_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@pffft",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensors_to_audio_calculator_test",
|
||||
srcs = ["tensors_to_audio_calculator_test.cc"],
|
||||
deps = [
|
||||
":audio_to_tensor_calculator",
|
||||
":audio_to_tensor_calculator_cc_proto",
|
||||
":tensors_to_audio_calculator",
|
||||
":tensors_to_audio_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "feedback_tensors_calculator_proto",
|
||||
srcs = ["feedback_tensors_calculator.proto"],
|
||||
|
@ -253,6 +303,26 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "regex_preprocessor_calculator_test",
|
||||
srcs = ["regex_preprocessor_calculator_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":regex_preprocessor_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:sink",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "text_to_tensor_calculator",
|
||||
srcs = ["text_to_tensor_calculator.cc"],
|
||||
|
@ -304,6 +374,28 @@ cc_library(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "universal_sentence_encoder_preprocessor_calculator_test",
|
||||
srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
|
||||
deps = [
|
||||
":universal_sentence_encoder_preprocessor_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:options_map",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
|
@ -438,6 +530,7 @@ cc_library(
|
|||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
|
@ -458,6 +551,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_runner",
|
||||
"//mediapipe/framework:mediapipe_profiling",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
@ -1200,13 +1294,30 @@ cc_library(
|
|||
name = "image_to_tensor_utils",
|
||||
srcs = ["image_to_tensor_utils.cc"],
|
||||
hdrs = ["image_to_tensor_utils.h"],
|
||||
copts = select({
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_to_tensor_calculator_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": ["//mediapipe/gpu:gpu_buffer"],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_test(
|
||||
|
@ -1216,6 +1327,8 @@ cc_test(
|
|||
":image_to_tensor_utils",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -133,7 +133,7 @@ bool IsValidFftSize(int size) {
|
|||
// invocation. In the non-streaming mode, the vector contains all of the
|
||||
// output timestamps for an input audio buffer.
|
||||
// DC_AND_NYQUIST - std::pair<float, float> @Optional.
|
||||
// A pair of dc component and nyquest component. Only can be connected when
|
||||
// A pair of dc component and nyquist component. Only can be connected when
|
||||
// the calculator performs fft (the fft_size is set in the calculator
|
||||
// options).
|
||||
//
|
||||
|
|
|
@ -54,13 +54,6 @@
|
|||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
|
||||
using GpuBuffer = AnyType;
|
||||
#else
|
||||
using GpuBuffer = mediapipe::GpuBuffer;
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
// Converts image into Tensor, possibly with cropping, resizing and
|
||||
// normalization, according to specified inputs and options.
|
||||
//
|
||||
|
@ -141,42 +134,7 @@ class ImageToTensorCalculator : public Node {
|
|||
const auto& options =
|
||||
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
|
||||
RET_CHECK(options.has_output_tensor_float_range() ||
|
||||
options.has_output_tensor_int_range() ||
|
||||
options.has_output_tensor_uint_range())
|
||||
<< "Output tensor range is required.";
|
||||
if (options.has_output_tensor_float_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_float_range().min(),
|
||||
options.output_tensor_float_range().max())
|
||||
<< "Valid output float tensor range is required.";
|
||||
}
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_uint_range().min(),
|
||||
options.output_tensor_uint_range().max())
|
||||
<< "Valid output uint tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
|
||||
<< "The minimum of the output uint tensor range must be "
|
||||
"non-negative.";
|
||||
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
|
||||
<< "The maximum of the output uint tensor range must be less than or "
|
||||
"equal to 255.";
|
||||
}
|
||||
if (options.has_output_tensor_int_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_int_range().min(),
|
||||
options.output_tensor_int_range().max())
|
||||
<< "Valid output int tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
|
||||
<< "The minimum of the output int tensor range must be greater than "
|
||||
"or equal to -128.";
|
||||
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
|
||||
<< "The maximum of the output int tensor range must be less than or "
|
||||
"equal to 127.";
|
||||
}
|
||||
RET_CHECK_GT(options.output_tensor_width(), 0)
|
||||
<< "Valid output tensor width is required.";
|
||||
RET_CHECK_GT(options.output_tensor_height(), 0)
|
||||
<< "Valid output tensor height is required.";
|
||||
|
||||
RET_CHECK_OK(ValidateOptionOutputDims(options));
|
||||
RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected())
|
||||
<< "One and only one of IMAGE and IMAGE_GPU input is expected.";
|
||||
|
||||
|
@ -198,21 +156,7 @@ class ImageToTensorCalculator : public Node {
|
|||
|
||||
absl::Status Open(CalculatorContext* cc) {
|
||||
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
output_width_ = options_.output_tensor_width();
|
||||
output_height_ = options_.output_tensor_height();
|
||||
is_float_output_ = options_.has_output_tensor_float_range();
|
||||
if (options_.has_output_tensor_uint_range()) {
|
||||
range_min_ =
|
||||
static_cast<float>(options_.output_tensor_uint_range().min());
|
||||
range_max_ =
|
||||
static_cast<float>(options_.output_tensor_uint_range().max());
|
||||
} else if (options_.has_output_tensor_int_range()) {
|
||||
range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
|
||||
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
|
||||
} else {
|
||||
range_min_ = options_.output_tensor_float_range().min();
|
||||
range_max_ = options_.output_tensor_float_range().max();
|
||||
}
|
||||
params_ = GetOutputTensorParams(options_);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -242,7 +186,13 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
ASSIGN_OR_RETURN(auto image, GetInputImage(kIn(cc)));
|
||||
#else
|
||||
const bool is_input_gpu = kInGpu(cc).IsConnected();
|
||||
ASSIGN_OR_RETURN(auto image, is_input_gpu ? GetInputImage(kInGpu(cc))
|
||||
: GetInputImage(kIn(cc)));
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
|
||||
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
|
||||
|
@ -263,11 +213,13 @@ class ImageToTensorCalculator : public Node {
|
|||
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
|
||||
|
||||
Tensor::ElementType output_tensor_type =
|
||||
GetOutputTensorType(image->UsesGpu());
|
||||
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
|
||||
GetOutputTensorType(image->UsesGpu(), params_);
|
||||
Tensor tensor(output_tensor_type,
|
||||
{1, params_.output_height, params_.output_width,
|
||||
GetNumOutputChannels(*image)});
|
||||
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
|
||||
->Convert(*image, roi, range_min_, range_max_,
|
||||
->Convert(*image, roi, params_.range_min,
|
||||
params_.range_max,
|
||||
/*tensor_buffer_offset=*/0, tensor));
|
||||
|
||||
auto result = std::make_unique<std::vector<Tensor>>();
|
||||
|
@ -278,81 +230,11 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
|
||||
private:
|
||||
bool DoesGpuInputStartAtBottom() {
|
||||
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
||||
}
|
||||
|
||||
BorderMode GetBorderMode() {
|
||||
switch (options_.border_mode()) {
|
||||
case mediapipe::
|
||||
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
|
||||
return BorderMode::kReplicate;
|
||||
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
|
||||
return BorderMode::kZero;
|
||||
case mediapipe::
|
||||
ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
|
||||
return BorderMode::kReplicate;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
|
||||
if (!uses_gpu) {
|
||||
if (is_float_output_) {
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
if (range_min_ < 0) {
|
||||
return Tensor::ElementType::kInt8;
|
||||
} else {
|
||||
return Tensor::ElementType::kUInt8;
|
||||
}
|
||||
}
|
||||
// Always use float32 when GPU is enabled.
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
|
||||
int GetNumOutputChannels(const Image& image) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
if (image.UsesGpu()) {
|
||||
return 4;
|
||||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
// All of the processors except for Metal expect 3 channels.
|
||||
return 3;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
CalculatorContext* cc) {
|
||||
if (kIn(cc).IsConnected()) {
|
||||
const auto& packet = kIn(cc).packet();
|
||||
return kIn(cc).Visit(
|
||||
[&packet](const mediapipe::Image&) {
|
||||
return SharedPtrWithPacket<mediapipe::Image>(packet);
|
||||
},
|
||||
[&packet](const mediapipe::ImageFrame&) {
|
||||
return std::make_shared<const mediapipe::Image>(
|
||||
std::const_pointer_cast<mediapipe::ImageFrame>(
|
||||
SharedPtrWithPacket<mediapipe::ImageFrame>(packet)));
|
||||
});
|
||||
} else { // if (kInGpu(cc).IsConnected())
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
const GpuBuffer& input = *kInGpu(cc);
|
||||
// A shallow copy is okay since the resulting 'image' object is local in
|
||||
// Process(), and thus never outlives 'input'.
|
||||
return std::make_shared<const mediapipe::Image>(input);
|
||||
#else
|
||||
return absl::UnimplementedError(
|
||||
"GPU processing is disabled in build flags");
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status InitConverterIfNecessary(CalculatorContext* cc,
|
||||
const Image& image) {
|
||||
// Lazy initialization of the GPU or CPU converter.
|
||||
if (image.UsesGpu()) {
|
||||
if (!is_float_output_) {
|
||||
if (!params_.is_float_output) {
|
||||
return absl::UnimplementedError(
|
||||
"ImageToTensorConverter for the input GPU image currently doesn't "
|
||||
"support quantization.");
|
||||
|
@ -360,18 +242,20 @@ class ImageToTensorCalculator : public Node {
|
|||
if (!gpu_converter_) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateMetalConverter(cc, GetBorderMode()));
|
||||
ASSIGN_OR_RETURN(
|
||||
gpu_converter_,
|
||||
CreateMetalConverter(cc, GetBorderMode(options_.border_mode())));
|
||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateImageToGlBufferTensorConverter(
|
||||
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||
cc, DoesGpuInputStartAtBottom(options_),
|
||||
GetBorderMode(options_.border_mode())));
|
||||
#else
|
||||
if (!gpu_converter_) {
|
||||
ASSIGN_OR_RETURN(
|
||||
gpu_converter_,
|
||||
ASSIGN_OR_RETURN(gpu_converter_,
|
||||
CreateImageToGlTextureTensorConverter(
|
||||
cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
|
||||
cc, DoesGpuInputStartAtBottom(options_),
|
||||
GetBorderMode(options_.border_mode())));
|
||||
}
|
||||
if (!gpu_converter_) {
|
||||
return absl::UnimplementedError(
|
||||
|
@ -383,10 +267,10 @@ class ImageToTensorCalculator : public Node {
|
|||
} else {
|
||||
if (!cpu_converter_) {
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
ASSIGN_OR_RETURN(
|
||||
cpu_converter_,
|
||||
CreateOpenCvConverter(cc, GetBorderMode(),
|
||||
GetOutputTensorType(/*uses_gpu=*/false)));
|
||||
ASSIGN_OR_RETURN(cpu_converter_,
|
||||
CreateOpenCvConverter(
|
||||
cc, GetBorderMode(options_.border_mode()),
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params_)));
|
||||
#else
|
||||
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
|
||||
"MEDIAPIPE_DISABLE_OPENCV is defined.";
|
||||
|
@ -399,11 +283,7 @@ class ImageToTensorCalculator : public Node {
|
|||
std::unique_ptr<ImageToTensorConverter> gpu_converter_;
|
||||
std::unique_ptr<ImageToTensorConverter> cpu_converter_;
|
||||
mediapipe::ImageToTensorCalculatorOptions options_;
|
||||
int output_width_ = 0;
|
||||
int output_height_ = 0;
|
||||
bool is_float_output_ = false;
|
||||
float range_min_ = 0.0f;
|
||||
float range_max_ = 1.0f;
|
||||
OutputTensorParams params_;
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);
|
||||
|
|
|
@ -27,12 +27,6 @@ struct Size {
|
|||
int height;
|
||||
};
|
||||
|
||||
// Pixel extrapolation method.
|
||||
// When converting image to tensor it may happen that tensor needs to read
|
||||
// pixels outside image boundaries. Border mode helps to specify how such pixels
|
||||
// will be calculated.
|
||||
enum class BorderMode { kZero, kReplicate };
|
||||
|
||||
// Converts image to tensor.
|
||||
class ImageToTensorConverter {
|
||||
public:
|
||||
|
|
|
@ -270,10 +270,10 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
Tensor& output_tensor) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
|
||||
return InvalidArgumentError(absl::StrCat(
|
||||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
"Unsupported format: ", static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
const auto& output_shape = output_tensor.shape();
|
||||
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||
|
@ -281,12 +281,13 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
|
||||
[this, &output_tensor, &input, &roi, &output_shape, range_min,
|
||||
range_max, tensor_buffer_offset]() -> absl::Status {
|
||||
constexpr int kRgbaNumChannels = 4;
|
||||
const int input_num_channels = input.channels();
|
||||
auto source_texture = gl_helper_.CreateSourceTexture(input);
|
||||
tflite::gpu::gl::GlTexture input_texture(
|
||||
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
|
||||
GL_TEXTURE_2D, source_texture.name(),
|
||||
input_num_channels == 4 ? GL_RGB : GL_RGBA,
|
||||
source_texture.width() * source_texture.height() *
|
||||
kRgbaNumChannels * sizeof(uint8_t),
|
||||
input_num_channels * sizeof(uint8_t),
|
||||
/*layer=*/0,
|
||||
/*owned=*/false);
|
||||
|
||||
|
|
|
@ -174,10 +174,10 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
Tensor& output_tensor) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
|
||||
return InvalidArgumentError(absl::StrCat(
|
||||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
"Unsupported format: ", static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
// TODO: support tensor_buffer_offset > 0 scenario.
|
||||
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
|
||||
|
@ -214,4 +216,68 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
|
|||
matrix[15] = 1.0f;
|
||||
}
|
||||
|
||||
BorderMode GetBorderMode(
|
||||
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode) {
|
||||
switch (mode) {
|
||||
case mediapipe::
|
||||
ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
|
||||
return BorderMode::kReplicate;
|
||||
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
|
||||
return BorderMode::kZero;
|
||||
case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
|
||||
return BorderMode::kReplicate;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
|
||||
const OutputTensorParams& params) {
|
||||
if (!uses_gpu) {
|
||||
if (params.is_float_output) {
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
if (params.range_min < 0) {
|
||||
return Tensor::ElementType::kInt8;
|
||||
} else {
|
||||
return Tensor::ElementType::kUInt8;
|
||||
}
|
||||
}
|
||||
// Always use float32 when GPU is enabled.
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
|
||||
int GetNumOutputChannels(const mediapipe::Image& image) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
if (image.UsesGpu()) {
|
||||
return 4;
|
||||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
// All of the processors except for Metal expect 3 channels.
|
||||
return 3;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
|
||||
image_packet) {
|
||||
return image_packet.Visit(
|
||||
[&image_packet](const mediapipe::Image&) {
|
||||
return SharedPtrWithPacket<mediapipe::Image>(image_packet);
|
||||
},
|
||||
[&image_packet](const mediapipe::ImageFrame&) {
|
||||
return std::make_shared<const mediapipe::Image>(
|
||||
std::const_pointer_cast<mediapipe::ImageFrame>(
|
||||
SharedPtrWithPacket<mediapipe::ImageFrame>(image_packet)));
|
||||
});
|
||||
}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet) {
|
||||
// A shallow copy is okay since the resulting 'image' object is local in
|
||||
// Process(), and thus never outlives 'input'.
|
||||
return std::make_shared<const mediapipe::Image>(image_gpu_packet.Get());
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -18,8 +18,18 @@
|
|||
#include <array>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -31,6 +41,24 @@ struct RotatedRect {
|
|||
float rotation;
|
||||
};
|
||||
|
||||
// Pixel extrapolation method.
|
||||
// When converting image to tensor it may happen that tensor needs to read
|
||||
// pixels outside image boundaries. Border mode helps to specify how such pixels
|
||||
// will be calculated.
|
||||
// TODO: Consider moving this to a separate border_mode.h file.
|
||||
enum class BorderMode { kZero, kReplicate };
|
||||
|
||||
// Struct that host commonly accessed parameters used in the
|
||||
// ImageTo[Batch]TensorCalculator.
|
||||
struct OutputTensorParams {
|
||||
int output_height;
|
||||
int output_width;
|
||||
int output_batch;
|
||||
bool is_float_output;
|
||||
float range_min;
|
||||
float range_max;
|
||||
};
|
||||
|
||||
// Generates a new ROI or converts it from normalized rect.
|
||||
RotatedRect GetRoi(int input_width, int input_height,
|
||||
absl::optional<mediapipe::NormalizedRect> norm_rect);
|
||||
|
@ -95,6 +123,103 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
|
|||
const RotatedRect& sub_rect, int rect_width, int rect_height,
|
||||
bool flip_horizontaly, std::array<float, 16>* matrix);
|
||||
|
||||
// Validates the output dimensions set in the option proto. The input option
|
||||
// proto is expected to have to following fields:
|
||||
// output_tensor_float_range, output_tensor_int_range, output_tensor_uint_range
|
||||
// output_tensor_width, output_tensor_height.
|
||||
// See ImageToTensorCalculatorOptions for the description of each field.
|
||||
template <typename T>
|
||||
absl::Status ValidateOptionOutputDims(const T& options) {
|
||||
RET_CHECK(options.has_output_tensor_float_range() ||
|
||||
options.has_output_tensor_int_range() ||
|
||||
options.has_output_tensor_uint_range())
|
||||
<< "Output tensor range is required.";
|
||||
if (options.has_output_tensor_float_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_float_range().min(),
|
||||
options.output_tensor_float_range().max())
|
||||
<< "Valid output float tensor range is required.";
|
||||
}
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_uint_range().min(),
|
||||
options.output_tensor_uint_range().max())
|
||||
<< "Valid output uint tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
|
||||
<< "The minimum of the output uint tensor range must be "
|
||||
"non-negative.";
|
||||
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
|
||||
<< "The maximum of the output uint tensor range must be less than or "
|
||||
"equal to 255.";
|
||||
}
|
||||
if (options.has_output_tensor_int_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_int_range().min(),
|
||||
options.output_tensor_int_range().max())
|
||||
<< "Valid output int tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
|
||||
<< "The minimum of the output int tensor range must be greater than "
|
||||
"or equal to -128.";
|
||||
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
|
||||
<< "The maximum of the output int tensor range must be less than or "
|
||||
"equal to 127.";
|
||||
}
|
||||
RET_CHECK_GT(options.output_tensor_width(), 0)
|
||||
<< "Valid output tensor width is required.";
|
||||
RET_CHECK_GT(options.output_tensor_height(), 0)
|
||||
<< "Valid output tensor height is required.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputTensorParams GetOutputTensorParams(const T& options) {
|
||||
OutputTensorParams params;
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
params.range_min =
|
||||
static_cast<float>(options.output_tensor_uint_range().min());
|
||||
params.range_max =
|
||||
static_cast<float>(options.output_tensor_uint_range().max());
|
||||
} else if (options.has_output_tensor_int_range()) {
|
||||
params.range_min =
|
||||
static_cast<float>(options.output_tensor_int_range().min());
|
||||
params.range_max =
|
||||
static_cast<float>(options.output_tensor_int_range().max());
|
||||
} else {
|
||||
params.range_min = options.output_tensor_float_range().min();
|
||||
params.range_max = options.output_tensor_float_range().max();
|
||||
}
|
||||
params.output_width = options.output_tensor_width();
|
||||
params.output_height = options.output_tensor_height();
|
||||
params.is_float_output = options.has_output_tensor_float_range();
|
||||
params.output_batch = 1;
|
||||
return params;
|
||||
}
|
||||
|
||||
// Returns whether the GPU input format starts at the bottom.
|
||||
template <typename T>
|
||||
bool DoesGpuInputStartAtBottom(const T& options) {
|
||||
return options.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
||||
}
|
||||
|
||||
// Converts the BorderMode proto into struct.
|
||||
BorderMode GetBorderMode(
|
||||
const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode);
|
||||
|
||||
// Gets the output tensor type.
|
||||
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
|
||||
const OutputTensorParams& params);
|
||||
|
||||
// Gets the number of output channels from the input Image format.
|
||||
int GetNumOutputChannels(const mediapipe::Image& image);
|
||||
|
||||
// Converts the packet that hosts different format (Image, ImageFrame,
|
||||
// GpuBuffer) into the mediapipe::Image format.
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
|
||||
image_packet);
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet);
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -23,6 +25,7 @@ namespace {
|
|||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
testing::Matcher<RotatedRect> EqRotatedRect(float width, float height,
|
||||
float center_x, float center_y,
|
||||
|
@ -157,5 +160,95 @@ TEST(GetValueRangeTransformation, FloatToPixel) {
|
|||
EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f));
|
||||
}
|
||||
|
||||
constexpr char kValidFloatProto[] = R"(
|
||||
output_tensor_float_range { min: 0.0 max: 1.0 }
|
||||
output_tensor_width: 100
|
||||
output_tensor_height: 200
|
||||
)";
|
||||
|
||||
constexpr char kValidIntProto[] = R"(
|
||||
output_tensor_float_range { min: 0 max: 255 }
|
||||
output_tensor_width: 100
|
||||
output_tensor_height: 200
|
||||
)";
|
||||
|
||||
TEST(ValidateOptionOutputDims, ValidProtos) {
|
||||
const auto float_options =
|
||||
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
|
||||
kValidFloatProto);
|
||||
MP_EXPECT_OK(ValidateOptionOutputDims(float_options));
|
||||
}
|
||||
|
||||
TEST(ValidateOptionOutputDims, EmptyProto) {
|
||||
mediapipe::ImageToTensorCalculatorOptions options;
|
||||
// No output tensor range set.
|
||||
EXPECT_THAT(ValidateOptionOutputDims(options),
|
||||
StatusIs(absl::StatusCode::kInternal,
|
||||
HasSubstr("Output tensor range is required")));
|
||||
|
||||
// Invalid output float tensor range.
|
||||
options.mutable_output_tensor_float_range()->set_min(1.0);
|
||||
options.mutable_output_tensor_float_range()->set_max(0.0);
|
||||
EXPECT_THAT(
|
||||
ValidateOptionOutputDims(options),
|
||||
StatusIs(absl::StatusCode::kInternal,
|
||||
HasSubstr("Valid output float tensor range is required")));
|
||||
|
||||
// Output width/height is not set.
|
||||
options.mutable_output_tensor_float_range()->set_min(0.0);
|
||||
options.mutable_output_tensor_float_range()->set_max(1.0);
|
||||
EXPECT_THAT(ValidateOptionOutputDims(options),
|
||||
StatusIs(absl::StatusCode::kInternal,
|
||||
HasSubstr("Valid output tensor width is required")));
|
||||
}
|
||||
|
||||
TEST(GetOutputTensorParams, SetValues) {
|
||||
// Test int range with ImageToTensorCalculatorOptions.
|
||||
const auto int_options =
|
||||
mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
|
||||
kValidIntProto);
|
||||
const auto params2 = GetOutputTensorParams(int_options);
|
||||
EXPECT_EQ(params2.range_min, 0.0f);
|
||||
EXPECT_EQ(params2.range_max, 255.0f);
|
||||
EXPECT_EQ(params2.output_batch, 1);
|
||||
EXPECT_EQ(params2.output_width, 100);
|
||||
EXPECT_EQ(params2.output_height, 200);
|
||||
}
|
||||
|
||||
TEST(GetBorderMode, GetBorderMode) {
|
||||
// Default to REPLICATE.
|
||||
auto border_mode =
|
||||
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED;
|
||||
EXPECT_EQ(BorderMode::kReplicate, GetBorderMode(border_mode));
|
||||
|
||||
// Set to ZERO.
|
||||
border_mode =
|
||||
mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO;
|
||||
EXPECT_EQ(BorderMode::kZero, GetBorderMode(border_mode));
|
||||
}
|
||||
|
||||
TEST(GetOutputTensorType, GetOutputTensorType) {
|
||||
OutputTensorParams params;
|
||||
// Return float32 when GPU is enabled.
|
||||
EXPECT_EQ(Tensor::ElementType::kFloat32,
|
||||
GetOutputTensorType(/*uses_gpu=*/true, params));
|
||||
|
||||
// Return float32 when is_float_output is set to true.
|
||||
params.is_float_output = true;
|
||||
EXPECT_EQ(Tensor::ElementType::kFloat32,
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params));
|
||||
|
||||
// Return int8 when range_min is negative.
|
||||
params.is_float_output = false;
|
||||
params.range_min = -255.0f;
|
||||
EXPECT_EQ(Tensor::ElementType::kInt8,
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params));
|
||||
|
||||
// Return 8int8 when range_min is non-negative.
|
||||
params.range_min = 0.0f;
|
||||
EXPECT_EQ(Tensor::ElementType::kUInt8,
|
||||
GetOutputTensorType(/*uses_gpu=*/false, params));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -72,7 +72,7 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
|
|||
RET_CHECK(!input_tensors.empty());
|
||||
|
||||
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
|
||||
inference_runner_->Run(input_tensors));
|
||||
inference_runner_->Run(cc, input_tensors));
|
||||
kOutTensors(cc).Send(std::move(output_tensors));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||
|
||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
|
@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
|
|||
CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
|
||||
std::vector<Tensor>& output_tensors) {
|
||||
return gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors, &output_tensors]() -> absl::Status {
|
||||
[this, cc, &input_tensors, &output_tensors]() -> absl::Status {
|
||||
// Explicitly copy input.
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
glBindBuffer(GL_COPY_READ_BUFFER,
|
||||
|
@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
|
|||
}
|
||||
|
||||
// Run inference.
|
||||
{
|
||||
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
}
|
||||
|
||||
output_tensors.reserve(output_size_);
|
||||
for (int i = 0; i < output_size_; ++i) {
|
||||
|
|
|
@ -32,6 +32,8 @@
|
|||
#include "mediapipe/util/android/file/base/helpers.h"
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
|
||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
|
@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
|
|||
const mediapipe::InferenceCalculatorOptions::Delegate& delegate);
|
||||
|
||||
absl::StatusOr<std::vector<Tensor>> Process(
|
||||
const std::vector<Tensor>& input_tensors);
|
||||
CalculatorContext* cc, const std::vector<Tensor>& input_tensors);
|
||||
|
||||
absl::Status Close();
|
||||
|
||||
|
@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
|
|||
|
||||
absl::StatusOr<std::vector<Tensor>>
|
||||
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
|
||||
const std::vector<Tensor>& input_tensors) {
|
||||
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
|
||||
std::vector<Tensor> output_tensors;
|
||||
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors, &output_tensors]() -> absl::Status {
|
||||
[this, cc, &input_tensors, &output_tensors]() -> absl::Status {
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
|
||||
input_tensors[i].GetOpenGlBufferReadView().name(), i));
|
||||
|
@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
|
|||
output_tensors.back().GetOpenGlBufferWriteView().name(), i));
|
||||
}
|
||||
// Run inference.
|
||||
{
|
||||
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
|
||||
return tflite_gpu_runner_->Invoke();
|
||||
}
|
||||
}));
|
||||
|
||||
return output_tensors;
|
||||
|
@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
|
|||
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
|
||||
|
||||
ASSIGN_OR_RETURN(*output_tensors,
|
||||
gpu_inference_runner_->Process(input_tensors));
|
||||
gpu_inference_runner_->Process(cc, input_tensors));
|
||||
|
||||
kOutTensors(cc).Send(std::move(output_tensors));
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -70,7 +70,7 @@ absl::Status InferenceCalculatorXnnpackImpl::Process(CalculatorContext* cc) {
|
|||
RET_CHECK(!input_tensors.empty());
|
||||
|
||||
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
|
||||
inference_runner_->Run(input_tensors));
|
||||
inference_runner_->Run(cc, input_tensors));
|
||||
kOutTensors(cc).Send(std::move(output_tensors));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -20,12 +20,15 @@
|
|||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
@ -79,7 +82,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
|||
delegate_(std::move(delegate)) {}
|
||||
|
||||
absl::StatusOr<std::vector<Tensor>> Run(
|
||||
const std::vector<Tensor>& input_tensors) override;
|
||||
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) override;
|
||||
|
||||
private:
|
||||
api2::Packet<TfLiteModelPtr> model_;
|
||||
|
@ -88,7 +91,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
|||
};
|
||||
|
||||
absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
||||
const std::vector<Tensor>& input_tensors) {
|
||||
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
|
||||
// Read CPU input into tensors.
|
||||
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
|
@ -131,8 +134,10 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
|||
}
|
||||
|
||||
// Run inference.
|
||||
{
|
||||
MEDIAPIPE_PROFILING(CPU_TASK_INVOKE, cc);
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
|
||||
}
|
||||
// Output result tensors (CPU).
|
||||
const auto& tensor_indexes = interpreter_->outputs();
|
||||
std::vector<Tensor> output_tensors;
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -11,7 +12,7 @@ class InferenceRunner {
|
|||
public:
|
||||
virtual ~InferenceRunner() = default;
|
||||
virtual absl::StatusOr<std::vector<Tensor>> Run(
|
||||
const std::vector<Tensor>& inputs) = 0;
|
||||
CalculatorContext* cc, const std::vector<Tensor>& inputs) = 0;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
197
mediapipe/calculators/tensor/tensors_to_audio_calculator.cc
Normal file
197
mediapipe/calculators/tensor/tensors_to_audio_calculator.cc
Normal file
|
@ -0,0 +1,197 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <new>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "pffft.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
|
||||
std::vector<float> hann_window(window_size);
|
||||
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
|
||||
if (sqrt_hann) {
|
||||
absl::c_transform(hann_window, hann_window.begin(),
|
||||
[](double x) { return std::sqrt(x); });
|
||||
}
|
||||
return hann_window;
|
||||
}
|
||||
|
||||
// Note that the InvHannWindow function may only work for 50% overlapping case.
|
||||
std::vector<float> InvHannWindow(int window_size, bool sqrt_hann) {
|
||||
std::vector<float> window = HannWindow(window_size, sqrt_hann);
|
||||
std::vector<float> inv_window(window.size());
|
||||
if (sqrt_hann) {
|
||||
absl::c_copy(window, inv_window.begin());
|
||||
} else {
|
||||
const int kHalfWindowSize = window.size() / 2;
|
||||
absl::c_transform(window, inv_window.begin(),
|
||||
[](double x) { return x * x; });
|
||||
for (int i = 0; i < kHalfWindowSize; ++i) {
|
||||
double sum = inv_window[i] + inv_window[kHalfWindowSize + i];
|
||||
inv_window[i] = window[i] / sum;
|
||||
inv_window[kHalfWindowSize + i] = window[kHalfWindowSize + i] / sum;
|
||||
}
|
||||
}
|
||||
return inv_window;
|
||||
}
|
||||
|
||||
// PFFFT only supports transforms for inputs of length N of the form
|
||||
// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT.
|
||||
bool IsValidFftSize(int size) {
|
||||
if (size <= 0) {
|
||||
return false;
|
||||
}
|
||||
constexpr int kFactors[] = {2, 3, 5};
|
||||
int factorization[] = {0, 0, 0};
|
||||
int n = static_cast<int>(size);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
while (n % kFactors[i] == 0) {
|
||||
n = n / kFactors[i];
|
||||
++factorization[i];
|
||||
}
|
||||
}
|
||||
return factorization[0] >= 5 && n == 1;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Converts 2D MediaPipe float Tensors to audio buffers.
|
||||
// The calculator will perform ifft on the complex DFT and apply the window
|
||||
// function (Inverse Hann) afterwards. The input 2D MediaPipe Tensor must
|
||||
// have the DFT real parts in its first row and the DFT imagery parts in its
|
||||
// second row. A valid "fft_size" must be set in the CalculatorOptions.
|
||||
//
|
||||
// Inputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// Vector containing a single Tensor that represents the audio's complex DFT
|
||||
// results.
|
||||
// DC_AND_NYQUIST - std::pair<float, float>
|
||||
// A pair of dc component and nyquist component.
|
||||
//
|
||||
// Outputs:
|
||||
// AUDIO - mediapipe::Matrix
|
||||
// The audio data represented as mediapipe::Matrix.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "TensorsToAudioCalculator"
|
||||
// input_stream: "TENSORS:tensors"
|
||||
// input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
|
||||
// output_stream: "AUDIO:audio"
|
||||
// options {
|
||||
// [mediapipe.AudioToTensorCalculatorOptions.ext] {
|
||||
// fft_size: 256
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TensorsToAudioCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
||||
static constexpr Input<std::pair<float, float>> kDcAndNyquistIn{
|
||||
"DC_AND_NYQUIST"};
|
||||
static constexpr Output<Matrix> kAudioOut{"AUDIO"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kDcAndNyquistIn, kAudioOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// The internal state of the FFT library.
|
||||
PFFFT_Setup* fft_state_ = nullptr;
|
||||
int fft_size_ = 0;
|
||||
float inverse_fft_size_ = 0;
|
||||
std::vector<float, Eigen::aligned_allocator<float>> input_dft_;
|
||||
std::vector<float> inv_fft_window_;
|
||||
std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
|
||||
// pffft requires memory to work with to avoid using the stack.
|
||||
std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
|
||||
std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
|
||||
};
|
||||
|
||||
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
|
||||
RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
|
||||
RET_CHECK(IsValidFftSize(options.fft_size()))
|
||||
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
|
||||
">=0 and c >= 0 and a >= 5, the requested fft size is "
|
||||
<< options.fft_size();
|
||||
fft_size_ = options.fft_size();
|
||||
inverse_fft_size_ = 1.0f / fft_size_;
|
||||
fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
|
||||
input_dft_.resize(fft_size_);
|
||||
inv_fft_window_ = InvHannWindow(fft_size_, /* sqrt_hann = */ false);
|
||||
fft_input_buffer_.resize(fft_size_);
|
||||
fft_workplace_.resize(fft_size_);
|
||||
fft_output_.resize(fft_size_);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
|
||||
if (kTensorsIn(cc).IsEmpty() || kDcAndNyquistIn(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
const auto& input_tensors = *kTensorsIn(cc);
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
||||
auto view = input_tensors[0].GetCpuReadView();
|
||||
// DC's real part.
|
||||
input_dft_[0] = kDcAndNyquistIn(cc)->first;
|
||||
// Nyquist's real part is the penultimate element of the tensor buffer.
|
||||
// pffft ignores the Nyquist's imagery part. No need to fetch the last value
|
||||
// from the tensor buffer.
|
||||
input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
|
||||
std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
|
||||
(fft_size_ - 2) * sizeof(float));
|
||||
pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
|
||||
fft_workplace_.data(), PFFFT_BACKWARD);
|
||||
// Applies the inverse window function.
|
||||
std::transform(
|
||||
fft_output_.begin(), fft_output_.end(), inv_fft_window_.begin(),
|
||||
fft_output_.begin(),
|
||||
[this](float a, float b) { return a * b * inverse_fft_size_; });
|
||||
Matrix matrix = Eigen::Map<Matrix>(fft_output_.data(), 1, fft_output_.size());
|
||||
kAudioOut(cc).Send(std::move(matrix));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorsToAudioCalculator::Close(CalculatorContext* cc) {
|
||||
if (fft_state_) {
|
||||
pffft_destroy_setup(fft_state_);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(TensorsToAudioCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message TensorsToAudioCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TensorsToAudioCalculatorOptions ext = 484297136;
|
||||
}
|
||||
|
||||
// Size of the fft in number of bins. If set, the calculator will do ifft
|
||||
// on the input tensor.
|
||||
optional int64 fft_size = 1;
|
||||
}
|
149
mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc
Normal file
149
mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc
Normal file
|
@ -0,0 +1,149 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <new>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
|
||||
protected:
|
||||
// Creates an audio matrix containing a single sample of 1.0 at a specified
|
||||
// offset.
|
||||
Matrix CreateImpulseSignalData(int64 num_samples, int impulse_offset_idx) {
|
||||
Matrix impulse = Matrix::Zero(1, num_samples);
|
||||
impulse(0, impulse_offset_idx) = 1.0;
|
||||
return impulse;
|
||||
}
|
||||
|
||||
void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
|
||||
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(R"(
|
||||
input_stream: "audio_in"
|
||||
input_stream: "sample_rate"
|
||||
output_stream: "audio_out"
|
||||
node {
|
||||
calculator: "AudioToTensorCalculator"
|
||||
input_stream: "AUDIO:audio_in"
|
||||
input_stream: "SAMPLE_RATE:sample_rate"
|
||||
output_stream: "TENSORS:tensors"
|
||||
output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
|
||||
options {
|
||||
[mediapipe.AudioToTensorCalculatorOptions.ext] {
|
||||
num_channels: 1
|
||||
num_samples: $0
|
||||
num_overlapping_samples: 0
|
||||
target_sample_rate: $1
|
||||
fft_size: $2
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "TensorsToAudioCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
|
||||
output_stream: "AUDIO:audio_out"
|
||||
options {
|
||||
[mediapipe.TensorsToAudioCalculatorOptions.ext] {
|
||||
fft_size: $2
|
||||
}
|
||||
}
|
||||
}
|
||||
)",
|
||||
/*$0=*/num_samples,
|
||||
/*$1=*/sample_rate,
|
||||
/*$2=*/fft_size));
|
||||
tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
|
||||
}
|
||||
|
||||
void RunGraph(const Matrix& input_data, double sample_rate) {
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config_));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"audio_in", MakePacket<Matrix>(input_data).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(graph_.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
std::vector<Packet> audio_out_packets_;
|
||||
CalculatorGraphConfig graph_config_;
|
||||
CalculatorGraph graph_;
|
||||
};
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
|
||||
ConfigGraph(320, 16000, 103);
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config_));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
auto status = graph_.WaitUntilIdle();
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
|
||||
EXPECT_THAT(status.message(),
|
||||
::testing::HasSubstr("FFT size must be of the form"));
|
||||
}
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
|
||||
constexpr int sample_size = 320;
|
||||
constexpr double sample_rate = 16000;
|
||||
ConfigGraph(sample_size, sample_rate, 320);
|
||||
|
||||
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
|
||||
RunGraph(impulse_data, sample_rate);
|
||||
ASSERT_EQ(1, audio_out_packets_.size());
|
||||
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
|
||||
// The impulse signal at the center is not affected by the window function.
|
||||
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
|
||||
}
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
|
||||
constexpr int sample_size = 320;
|
||||
constexpr double sample_rate = 16000;
|
||||
ConfigGraph(sample_size, sample_rate, 320);
|
||||
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
|
||||
RunGraph(impulse_data, sample_rate);
|
||||
ASSERT_EQ(1, audio_out_packets_.size());
|
||||
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
|
||||
// As the impulse signal sits at the 1/4 of the hann window, the inverse
|
||||
// window function reduces it by half.
|
||||
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data / 2);
|
||||
}
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
|
||||
constexpr int sample_size = 320;
|
||||
constexpr double sample_rate = 16000;
|
||||
ConfigGraph(sample_size, sample_rate, 320);
|
||||
Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
|
||||
RunGraph(impulse_data, sample_rate);
|
||||
ASSERT_EQ(1, audio_out_packets_.size());
|
||||
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
|
||||
// As the impulse signal sits at the beginning of the hann window, the inverse
|
||||
// window function completely removes it.
|
||||
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -289,8 +289,15 @@ class NodeBase {
|
|||
|
||||
template <typename T>
|
||||
T& GetOptions() {
|
||||
return GetOptions(T::ext);
|
||||
}
|
||||
|
||||
// Use this API when the proto extension does not follow the "ext" naming
|
||||
// convention.
|
||||
template <typename E>
|
||||
auto& GetOptions(const E& extension) {
|
||||
options_used_ = true;
|
||||
return *options_.MutableExtension(T::ext);
|
||||
return *options_.MutableExtension(extension);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -386,8 +393,15 @@ class PacketGenerator {
|
|||
|
||||
template <typename T>
|
||||
T& GetOptions() {
|
||||
return GetOptions(T::ext);
|
||||
}
|
||||
|
||||
// Use this API when the proto extension does not follow the "ext" naming
|
||||
// convention.
|
||||
template <typename E>
|
||||
auto& GetOptions(const E& extension) {
|
||||
options_used_ = true;
|
||||
return *options_.MutableExtension(T::ext);
|
||||
return *options_.MutableExtension(extension);
|
||||
}
|
||||
|
||||
template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
|
||||
|
|
|
@ -185,7 +185,7 @@ class CalculatorBaseFactory {
|
|||
// Functions for checking that the calculator has the required GetContract.
|
||||
template <class T>
|
||||
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
|
||||
typedef absl::Status (*GetContractType)(CalculatorContract * cc);
|
||||
typedef absl::Status (*GetContractType)(CalculatorContract* cc);
|
||||
return std::is_same<decltype(&T::GetContract), GetContractType>::value;
|
||||
}
|
||||
template <class T>
|
||||
|
|
|
@ -133,7 +133,13 @@ message GraphTrace {
|
|||
TPU_TASK = 13;
|
||||
GPU_CALIBRATION = 14;
|
||||
PACKET_QUEUED = 15;
|
||||
GPU_TASK_INVOKE = 16;
|
||||
TPU_TASK_INVOKE = 17;
|
||||
CPU_TASK_INVOKE = 18;
|
||||
}
|
||||
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
|
||||
// //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
|
||||
// )
|
||||
|
||||
// The timing for one packet set being processed at one caclulator node.
|
||||
message CalculatorTrace {
|
||||
|
|
|
@ -293,7 +293,6 @@ mediapipe_proto_library(
|
|||
name = "rect_proto",
|
||||
srcs = ["rect.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework/formats:location_data_proto"],
|
||||
)
|
||||
|
||||
mediapipe_register_type(
|
||||
|
|
|
@ -109,6 +109,13 @@ struct TraceEvent {
|
|||
static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
|
||||
static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
|
||||
static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
|
||||
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
|
||||
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
|
||||
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
|
||||
|
||||
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
|
||||
// //depot/mediapipe/framework/calculator_profile.proto:event_type,
|
||||
// )
|
||||
};
|
||||
|
||||
// Packet trace log buffer.
|
||||
|
|
|
@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(
|
|||
|
||||
// Returns a PacketSequencerCalculator node.
|
||||
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
|
||||
bool synchronize_io) {
|
||||
bool async_selection) {
|
||||
CalculatorGraphConfig::Node* result = config->add_node();
|
||||
*result->mutable_calculator() = "PacketSequencerCalculator";
|
||||
if (synchronize_io) {
|
||||
if (!async_selection) {
|
||||
*result->mutable_input_stream_handler()->mutable_input_stream_handler() =
|
||||
"DefaultInputStreamHandler";
|
||||
}
|
||||
|
@ -239,6 +239,15 @@ bool HasTag(const proto_ns::RepeatedPtrField<std::string>& streams,
|
|||
return tags.count({tag, 0}) > 0;
|
||||
}
|
||||
|
||||
// Returns true if a set of "TAG::index" includes a TagIndex.
|
||||
bool ContainsTag(const proto_ns::RepeatedPtrField<std::string>& tags,
|
||||
TagIndex item) {
|
||||
for (const std::string& t : tags) {
|
||||
if (ParseTagIndex(t) == item) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
||||
const Subgraph::SubgraphOptions& options) {
|
||||
CalculatorGraphConfig config;
|
||||
|
@ -263,17 +272,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
|||
std::string enable_stream = "ENABLE:gate_enable";
|
||||
|
||||
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
|
||||
bool synchronize_io =
|
||||
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
|
||||
.synchronize_io();
|
||||
const auto& switch_options =
|
||||
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options);
|
||||
bool async_selection = switch_options.async_selection();
|
||||
if (HasTag(container_node.input_stream(), "SELECT")) {
|
||||
select_node = BuildTimestampNode(&config, synchronize_io);
|
||||
select_node = BuildTimestampNode(&config, async_selection);
|
||||
select_node->add_input_stream("INPUT:gate_select");
|
||||
select_node->add_output_stream("OUTPUT:gate_select_timed");
|
||||
select_stream = "SELECT:gate_select_timed";
|
||||
}
|
||||
if (HasTag(container_node.input_stream(), "ENABLE")) {
|
||||
enable_node = BuildTimestampNode(&config, synchronize_io);
|
||||
enable_node = BuildTimestampNode(&config, async_selection);
|
||||
enable_node->add_input_stream("INPUT:gate_enable");
|
||||
enable_node->add_output_stream("OUTPUT:gate_enable_timed");
|
||||
enable_stream = "ENABLE:gate_enable_timed";
|
||||
|
@ -296,7 +305,7 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
|||
mux->add_input_side_packet("SELECT:gate_select");
|
||||
mux->add_input_side_packet("ENABLE:gate_enable");
|
||||
|
||||
// Add input streams for graph and demux and the timestamper.
|
||||
// Add input streams for graph and demux.
|
||||
config.add_input_stream("SELECT:gate_select");
|
||||
config.add_input_stream("ENABLE:gate_enable");
|
||||
config.add_input_side_packet("SELECT:gate_select");
|
||||
|
@ -306,6 +315,12 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
|||
std::string stream = CatStream(p.first, p.second);
|
||||
config.add_input_stream(stream);
|
||||
demux->add_input_stream(stream);
|
||||
}
|
||||
|
||||
// Add input streams for the timestamper.
|
||||
auto& tick_streams = switch_options.tick_input_stream();
|
||||
for (const auto& p : input_tags) {
|
||||
if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue;
|
||||
TagIndex tick_tag{"TICK", tick_index++};
|
||||
if (select_node) {
|
||||
select_node->add_input_stream(CatStream(tick_tag, p.second));
|
||||
|
|
|
@ -25,6 +25,14 @@ message SwitchContainerOptions {
|
|||
// Activates channel 1 for enable = true, channel 0 otherwise.
|
||||
optional bool enable = 4;
|
||||
|
||||
// Use DefaultInputStreamHandler for muxing & demuxing.
|
||||
// Use DefaultInputStreamHandler for demuxing.
|
||||
optional bool synchronize_io = 5;
|
||||
|
||||
// Use ImmediateInputStreamHandler for channel selection.
|
||||
optional bool async_selection = 6;
|
||||
|
||||
// Specifies an input stream, "TAG:index", that defines the processed
|
||||
// timestamps. SwitchContainer awaits output at the last processed
|
||||
// timestamp before advancing from one selected channel to the next.
|
||||
repeated string tick_input_stream = 7;
|
||||
}
|
||||
|
|
|
@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
|||
input_stream: "INPUT:enable"
|
||||
input_stream: "TICK:foo"
|
||||
output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
|
||||
input_stream_handler {
|
||||
input_stream_handler: "DefaultInputStreamHandler"
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "switchcontainer__SwitchDemuxCalculator"
|
||||
|
@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
|||
// Shows the SwitchContainer container runs with a pair of simple subnodes.
|
||||
TEST(SwitchContainerTest, RunsWithSubnodes) {
|
||||
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
||||
CalculatorGraphConfig supergraph = SubnodeContainerExample();
|
||||
CalculatorGraphConfig supergraph =
|
||||
SubnodeContainerExample("async_selection: true");
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||
RunTestContainer(supergraph);
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
|
@ -54,21 +55,47 @@ namespace mediapipe {
|
|||
// contained subgraph or calculator nodes.
|
||||
//
|
||||
class SwitchDemuxCalculator : public CalculatorBase {
|
||||
static constexpr char kSelectTag[] = "SELECT";
|
||||
static constexpr char kEnableTag[] = "ENABLE";
|
||||
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status RecordPackets(CalculatorContext* cc);
|
||||
int ChannelIndex(Timestamp timestamp);
|
||||
absl::Status SendActivePackets(CalculatorContext* cc);
|
||||
|
||||
private:
|
||||
int channel_index_;
|
||||
std::set<std::string> channel_tags_;
|
||||
using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
|
||||
PacketQueue input_queue_;
|
||||
std::map<Timestamp, int> channel_history_;
|
||||
};
|
||||
REGISTER_CALCULATOR(SwitchDemuxCalculator);
|
||||
|
||||
namespace {
|
||||
static constexpr char kSelectTag[] = "SELECT";
|
||||
static constexpr char kEnableTag[] = "ENABLE";
|
||||
|
||||
// Returns the last received timestamp for an input stream.
|
||||
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
|
||||
return input.Value().Timestamp();
|
||||
}
|
||||
|
||||
// Returns the last received timestamp for channel selection.
|
||||
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
|
||||
Timestamp result = Timestamp::Done();
|
||||
if (cc->Inputs().HasTag(kEnableTag)) {
|
||||
result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
|
||||
} else if (cc->Inputs().HasTag(kSelectTag)) {
|
||||
result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
||||
// Allow any one of kSelectTag, kEnableTag.
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
|
||||
|
@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
|||
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
channel_tags_ = ChannelTags(cc->Outputs().TagMap());
|
||||
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||
|
||||
// Relay side packets to all channels.
|
||||
// Note: This is necessary because Calculator::Open only proceeds when every
|
||||
|
@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
|
||||
// Update the input channel index if specified.
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
MP_RETURN_IF_ERROR(RecordPackets(cc));
|
||||
MP_RETURN_IF_ERROR(SendActivePackets(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Relay packets and timestamps only to channel_index_.
|
||||
// Enqueue all arriving packets and bounds.
|
||||
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
|
||||
// Enqueue any new arriving packets.
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto& input = cc->Inputs().Get(tag, index);
|
||||
std::string output_tag = tool::ChannelTag(tag, channel_index_);
|
||||
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||
if (output_id.IsValid()) {
|
||||
auto& output = cc->Outputs().Get(output_tag, index);
|
||||
tool::Relay(input, &output);
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
Packet packet = cc->Inputs().Get(input_id).Value();
|
||||
if (packet.Timestamp() == cc->InputTimestamp()) {
|
||||
input_queue_[input_id].push(packet);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enque any new input channel and its activation timestamp.
|
||||
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||
int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
|
||||
if (channel_settled == cc->InputTimestamp() &&
|
||||
new_channel_index != channel_index_) {
|
||||
channel_index_ = new_channel_index;
|
||||
channel_history_[channel_settled] = channel_index_;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns the channel index for a Timestamp.
|
||||
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
|
||||
auto it = std::prev(channel_history_.upper_bound(timestamp));
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Dispatches all queued input packets with known channels.
|
||||
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
|
||||
// Dispatch any queued input packets with a defined channel_index.
|
||||
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
auto& queue = input_queue_[input_id];
|
||||
while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
|
||||
int channel_index = ChannelIndex(queue.front().Timestamp());
|
||||
std::string output_tag = tool::ChannelTag(tag, channel_index);
|
||||
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||
if (output_id.IsValid()) {
|
||||
cc->Outputs().Get(output_id).AddPacket(queue.front());
|
||||
}
|
||||
queue.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Discard all select packets not needed for any remaining input packets.
|
||||
Timestamp input_settled = Timestamp::Done();
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
|
||||
if (!input_queue_[input_id].empty()) {
|
||||
Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
|
||||
stream_settled =
|
||||
std::min(stream_settled, stream_bound.PreviousAllowedInStream());
|
||||
}
|
||||
}
|
||||
}
|
||||
Timestamp input_bound = input_settled.NextAllowedInStream();
|
||||
auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
|
||||
channel_history_.erase(channel_history_.begin(), history_bound);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
|
|||
options_ = cc->Options<mediapipe::SwitchContainerOptions>();
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
channel_tags_ = ChannelTags(cc->Inputs().TagMap());
|
||||
channel_history_[Timestamp::Unset()] = channel_index_;
|
||||
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||
|
||||
// Relay side packets only from channel_index_.
|
||||
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {
|
||||
|
|
|
@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
|
|||
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
|
||||
|
||||
static void EglThreadExitCallback(void* key_value) {
|
||||
#if defined(__ANDROID__)
|
||||
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
|
||||
EGL_NO_CONTEXT);
|
||||
#else
|
||||
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
|
||||
// parameter for eglMakeCurrent. This behavior is not portable to all EGL
|
||||
// implementations, and should be considered as an undocumented vendor
|
||||
// extension.
|
||||
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
|
||||
//
|
||||
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
|
||||
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
|
||||
EGL_NO_SURFACE, EGL_NO_CONTEXT);
|
||||
#endif
|
||||
eglReleaseThread();
|
||||
}
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
|
|||
import android.graphics.Bitmap;
|
||||
import com.google.mediapipe.framework.image.BitmapExtractor;
|
||||
import com.google.mediapipe.framework.image.ByteBufferExtractor;
|
||||
import com.google.mediapipe.framework.image.Image;
|
||||
import com.google.mediapipe.framework.image.ImageProperties;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.framework.image.MPImageProperties;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
// TODO: use Preconditions in this file.
|
||||
|
@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
|
|||
}
|
||||
|
||||
/**
|
||||
* Creates an Image packet from an {@link Image}.
|
||||
* Creates a MediaPipe Image packet from a {@link MPImage}.
|
||||
*
|
||||
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
|
||||
*/
|
||||
public Packet createImage(Image image) {
|
||||
public Packet createImage(MPImage image) {
|
||||
// TODO: Choose the best storage from multiple containers.
|
||||
ImageProperties properties = image.getContainedImageProperties().get(0);
|
||||
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
|
||||
MPImageProperties properties = image.getContainedImageProperties().get(0);
|
||||
if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
|
||||
ByteBuffer buffer = ByteBufferExtractor.extract(image);
|
||||
int numChannels = 0;
|
||||
switch (properties.getImageFormat()) {
|
||||
case Image.IMAGE_FORMAT_RGBA:
|
||||
case MPImage.IMAGE_FORMAT_RGBA:
|
||||
numChannels = 4;
|
||||
break;
|
||||
case Image.IMAGE_FORMAT_RGB:
|
||||
case MPImage.IMAGE_FORMAT_RGB:
|
||||
numChannels = 3;
|
||||
break;
|
||||
case Image.IMAGE_FORMAT_ALPHA:
|
||||
case MPImage.IMAGE_FORMAT_ALPHA:
|
||||
numChannels = 1;
|
||||
break;
|
||||
default: // fall out
|
||||
|
@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
|
|||
int height = image.getHeight();
|
||||
return createImage(buffer, width, height, numChannels);
|
||||
}
|
||||
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
|
||||
if (properties.getStorageType() == MPImage.STORAGE_TYPE_BITMAP) {
|
||||
Bitmap bitmap = BitmapExtractor.extract(image);
|
||||
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
|
||||
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
|
||||
|
@ -100,7 +100,7 @@ public class AndroidPacketCreator extends PacketCreator {
|
|||
|
||||
// Unsupported type.
|
||||
throw new UnsupportedOperationException(
|
||||
"Unsupported Image container type: " + properties.getImageFormat());
|
||||
"Unsupported Image container type: " + properties.getStorageType());
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
|
|||
import android.graphics.Bitmap;
|
||||
|
||||
/**
|
||||
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
|
||||
* Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
|
||||
*
|
||||
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
|
||||
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
|
||||
* {@link IllegalArgumentException} will be thrown.
|
||||
*/
|
||||
public final class BitmapExtractor {
|
||||
|
||||
/**
|
||||
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
|
||||
* Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
|
||||
*
|
||||
* @param image the image to extract {@link android.graphics.Bitmap} from.
|
||||
* @return the {@link android.graphics.Bitmap} stored in {@link Image}
|
||||
* @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
|
||||
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
|
||||
* conversions.
|
||||
*/
|
||||
public static Bitmap extract(Image image) {
|
||||
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
|
||||
public static Bitmap extract(MPImage image) {
|
||||
MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
|
||||
if (imageContainer != null) {
|
||||
return ((BitmapImageContainer) imageContainer).getBitmap();
|
||||
} else {
|
||||
// TODO: Support ByteBuffer -> Bitmap conversion.
|
||||
throw new IllegalArgumentException(
|
||||
"Extracting Bitmap from an Image created by objects other than Bitmap is not"
|
||||
"Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
|
||||
+ " supported");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ import android.provider.MediaStore;
|
|||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Builds {@link Image} from {@link android.graphics.Bitmap}.
|
||||
* Builds {@link MPImage} from {@link android.graphics.Bitmap}.
|
||||
*
|
||||
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
|
||||
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
|
||||
|
@ -49,7 +49,7 @@ public class BitmapImageBuilder {
|
|||
}
|
||||
|
||||
/**
|
||||
* Creates the builder to build {@link Image} from a file.
|
||||
* Creates the builder to build {@link MPImage} from a file.
|
||||
*
|
||||
* @param context the application context.
|
||||
* @param uri the path to the resource file.
|
||||
|
@ -58,15 +58,15 @@ public class BitmapImageBuilder {
|
|||
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
|
||||
}
|
||||
|
||||
/** Sets value for {@link Image#getTimestamp()}. */
|
||||
/** Sets value for {@link MPImage#getTimestamp()}. */
|
||||
BitmapImageBuilder setTimestamp(long timestamp) {
|
||||
this.timestamp = timestamp;
|
||||
return this;
|
||||
}
|
||||
|
||||
/** Builds an {@link Image} instance. */
|
||||
public Image build() {
|
||||
return new Image(
|
||||
/** Builds a {@link MPImage} instance. */
|
||||
public MPImage build() {
|
||||
return new MPImage(
|
||||
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,19 +16,19 @@ limitations under the License.
|
|||
package com.google.mediapipe.framework.image;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
|
||||
|
||||
class BitmapImageContainer implements ImageContainer {
|
||||
class BitmapImageContainer implements MPImageContainer {
|
||||
|
||||
private final Bitmap bitmap;
|
||||
private final ImageProperties properties;
|
||||
private final MPImageProperties properties;
|
||||
|
||||
public BitmapImageContainer(Bitmap bitmap) {
|
||||
this.bitmap = bitmap;
|
||||
this.properties =
|
||||
ImageProperties.builder()
|
||||
MPImageProperties.builder()
|
||||
.setImageFormat(convertFormatCode(bitmap.getConfig()))
|
||||
.setStorageType(Image.STORAGE_TYPE_BITMAP)
|
||||
.setStorageType(MPImage.STORAGE_TYPE_BITMAP)
|
||||
.build();
|
||||
}
|
||||
|
||||
|
@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
|
|||
}
|
||||
|
||||
@Override
|
||||
public ImageProperties getImageProperties() {
|
||||
public MPImageProperties getImageProperties() {
|
||||
return properties;
|
||||
}
|
||||
|
||||
|
@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
|
|||
bitmap.recycle();
|
||||
}
|
||||
|
||||
@ImageFormat
|
||||
@MPImageFormat
|
||||
static int convertFormatCode(Bitmap.Config config) {
|
||||
switch (config) {
|
||||
case ALPHA_8:
|
||||
return Image.IMAGE_FORMAT_ALPHA;
|
||||
return MPImage.IMAGE_FORMAT_ALPHA;
|
||||
case ARGB_8888:
|
||||
return Image.IMAGE_FORMAT_RGBA;
|
||||
return MPImage.IMAGE_FORMAT_RGBA;
|
||||
default:
|
||||
return Image.IMAGE_FORMAT_UNKNOWN;
|
||||
return MPImage.IMAGE_FORMAT_UNKNOWN;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
|
|||
import android.os.Build.VERSION;
|
||||
import android.os.Build.VERSION_CODES;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* Utility for extracting {@link ByteBuffer} from {@link Image}.
|
||||
* Utility for extracting {@link ByteBuffer} from {@link MPImage}.
|
||||
*
|
||||
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
|
||||
* {@link IllegalArgumentException} will be thrown.
|
||||
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
|
||||
* otherwise {@link IllegalArgumentException} will be thrown.
|
||||
*/
|
||||
public class ByteBufferExtractor {
|
||||
|
||||
/**
|
||||
* Extracts a {@link ByteBuffer} from an {@link Image}.
|
||||
* Extracts a {@link ByteBuffer} from a {@link MPImage}.
|
||||
*
|
||||
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
|
||||
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
|
||||
* MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
|
||||
*
|
||||
* @see Image#getContainedImageProperties()
|
||||
* @see MPImage#getContainedImageProperties()
|
||||
* @return A read-only {@link ByteBuffer}.
|
||||
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
|
||||
*/
|
||||
@SuppressLint("SwitchIntDef")
|
||||
public static ByteBuffer extract(Image image) {
|
||||
ImageContainer container = image.getContainer();
|
||||
public static ByteBuffer extract(MPImage image) {
|
||||
MPImageContainer container = image.getContainer();
|
||||
switch (container.getImageProperties().getStorageType()) {
|
||||
case Image.STORAGE_TYPE_BYTEBUFFER:
|
||||
case MPImage.STORAGE_TYPE_BYTEBUFFER:
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
|
||||
"Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
|
||||
+ " supported");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
|
||||
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
|
||||
*
|
||||
* <p>Format conversion spec:
|
||||
*
|
||||
|
@ -70,26 +70,26 @@ public class ByteBufferExtractor {
|
|||
*
|
||||
* @param image the image to extract buffer from.
|
||||
* @param targetFormat the image format of the result bytebuffer.
|
||||
* @return the readonly {@link ByteBuffer} stored in {@link Image}
|
||||
* @return the readonly {@link ByteBuffer} stored in {@link MPImage}
|
||||
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
|
||||
* conversions.
|
||||
*/
|
||||
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
|
||||
ImageContainer container;
|
||||
ImageProperties byteBufferProperties =
|
||||
ImageProperties.builder()
|
||||
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
|
||||
static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
|
||||
MPImageContainer container;
|
||||
MPImageProperties byteBufferProperties =
|
||||
MPImageProperties.builder()
|
||||
.setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
|
||||
.setImageFormat(targetFormat)
|
||||
.build();
|
||||
if ((container = image.getContainer(byteBufferProperties)) != null) {
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
|
||||
@MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
|
||||
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
|
||||
.asReadOnlyBuffer();
|
||||
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
|
||||
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
|
||||
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
|
||||
ByteBuffer byteBuffer =
|
||||
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
|
||||
|
@ -98,85 +98,89 @@ public class ByteBufferExtractor {
|
|||
return byteBuffer;
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Extracting ByteBuffer from an Image created by objects other than Bitmap or"
|
||||
"Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
|
||||
+ " Bytebuffer is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
|
||||
/** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
|
||||
@AutoValue
|
||||
abstract static class Result {
|
||||
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
|
||||
/**
|
||||
* Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
|
||||
*/
|
||||
public abstract ByteBuffer buffer();
|
||||
|
||||
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
|
||||
@ImageFormat
|
||||
/**
|
||||
* Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
|
||||
*/
|
||||
@MPImageFormat
|
||||
public abstract int format();
|
||||
|
||||
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
|
||||
static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
|
||||
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
|
||||
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
|
||||
*
|
||||
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
|
||||
*
|
||||
* @return the readonly {@link ByteBuffer} stored in {@link Image}
|
||||
* @return the readonly {@link ByteBuffer} stored in {@link MPImage}
|
||||
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
|
||||
* given {@code imageFormat}
|
||||
*/
|
||||
static Result extractInRecommendedFormat(Image image) {
|
||||
ImageContainer container;
|
||||
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
|
||||
static Result extractInRecommendedFormat(MPImage image) {
|
||||
MPImageContainer container;
|
||||
if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
|
||||
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
|
||||
@ImageFormat int format = adviseImageFormat(bitmap);
|
||||
@MPImageFormat int format = adviseImageFormat(bitmap);
|
||||
Result result =
|
||||
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
|
||||
|
||||
boolean unused =
|
||||
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
|
||||
return result;
|
||||
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
return Result.create(
|
||||
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
|
||||
byteBufferImageContainer.getImageFormat());
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
|
||||
"Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
|
||||
+ " is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
@ImageFormat
|
||||
@MPImageFormat
|
||||
private static int adviseImageFormat(Bitmap bitmap) {
|
||||
if (bitmap.getConfig() == Config.ARGB_8888) {
|
||||
return Image.IMAGE_FORMAT_RGBA;
|
||||
return MPImage.IMAGE_FORMAT_RGBA;
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
|
||||
"Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
|
||||
+ " supported",
|
||||
bitmap.getConfig()));
|
||||
}
|
||||
}
|
||||
|
||||
private static ByteBuffer extractByteBufferFromBitmap(
|
||||
Bitmap bitmap, @ImageFormat int imageFormat) {
|
||||
Bitmap bitmap, @MPImageFormat int imageFormat) {
|
||||
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
|
||||
"Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
|
||||
+ " supported");
|
||||
}
|
||||
if (bitmap.getConfig() == Config.ARGB_8888) {
|
||||
if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
|
||||
if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
|
||||
bitmap.copyPixelsToBuffer(buffer);
|
||||
buffer.rewind();
|
||||
return buffer;
|
||||
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
|
||||
} else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
|
||||
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
|
||||
int w = bitmap.getWidth();
|
||||
int h = bitmap.getHeight();
|
||||
|
@ -196,14 +200,14 @@ public class ByteBufferExtractor {
|
|||
}
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
|
||||
"Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
|
||||
+ " %d is not supported",
|
||||
bitmap.getConfig(), imageFormat));
|
||||
}
|
||||
|
||||
private static ByteBuffer convertByteBuffer(
|
||||
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
|
||||
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
|
||||
ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
|
||||
if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
|
||||
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
|
||||
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the
|
||||
// array reversely to convert in-place.
|
||||
|
@ -221,7 +225,8 @@ public class ByteBufferExtractor {
|
|||
target.put(array, 0, target.capacity());
|
||||
target.rewind();
|
||||
return target;
|
||||
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
|
||||
} else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
|
||||
&& targetFormat == MPImage.IMAGE_FORMAT_RGB) {
|
||||
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
|
||||
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
|
||||
// array to convert in-place.
|
||||
|
|
|
@ -15,11 +15,11 @@ limitations under the License.
|
|||
|
||||
package com.google.mediapipe.framework.image;
|
||||
|
||||
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
* Builds a {@link Image} from a {@link ByteBuffer}.
|
||||
* Builds a {@link MPImage} from a {@link ByteBuffer}.
|
||||
*
|
||||
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
|
||||
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
|
||||
|
@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
|
|||
private final ByteBuffer buffer;
|
||||
private final int width;
|
||||
private final int height;
|
||||
@ImageFormat private final int imageFormat;
|
||||
@MPImageFormat private final int imageFormat;
|
||||
|
||||
// Optional fields.
|
||||
private long timestamp;
|
||||
|
@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
|
|||
* @param imageFormat how the data encode the image.
|
||||
*/
|
||||
public ByteBufferImageBuilder(
|
||||
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
|
||||
ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
|
||||
this.buffer = byteBuffer;
|
||||
this.width = width;
|
||||
this.height = height;
|
||||
|
@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
|
|||
this.timestamp = 0;
|
||||
}
|
||||
|
||||
/** Sets value for {@link Image#getTimestamp()}. */
|
||||
/** Sets value for {@link MPImage#getTimestamp()}. */
|
||||
ByteBufferImageBuilder setTimestamp(long timestamp) {
|
||||
this.timestamp = timestamp;
|
||||
return this;
|
||||
}
|
||||
|
||||
/** Builds an {@link Image} instance. */
|
||||
public Image build() {
|
||||
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
|
||||
/** Builds a {@link MPImage} instance. */
|
||||
public MPImage build() {
|
||||
return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,21 +15,19 @@ limitations under the License.
|
|||
|
||||
package com.google.mediapipe.framework.image;
|
||||
|
||||
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
class ByteBufferImageContainer implements ImageContainer {
|
||||
class ByteBufferImageContainer implements MPImageContainer {
|
||||
|
||||
private final ByteBuffer buffer;
|
||||
private final ImageProperties properties;
|
||||
private final MPImageProperties properties;
|
||||
|
||||
public ByteBufferImageContainer(
|
||||
ByteBuffer buffer,
|
||||
@ImageFormat int imageFormat) {
|
||||
public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
|
||||
this.buffer = buffer;
|
||||
this.properties =
|
||||
ImageProperties.builder()
|
||||
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
|
||||
MPImageProperties.builder()
|
||||
.setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
|
||||
.setImageFormat(imageFormat)
|
||||
.build();
|
||||
}
|
||||
|
@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
|
|||
}
|
||||
|
||||
@Override
|
||||
public ImageProperties getImageProperties() {
|
||||
public MPImageProperties getImageProperties() {
|
||||
return properties;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the image format.
|
||||
*/
|
||||
@ImageFormat
|
||||
/** Returns the image format. */
|
||||
@MPImageFormat
|
||||
public int getImageFormat() {
|
||||
return properties.getImageFormat();
|
||||
}
|
||||
|
|
|
@ -29,10 +29,10 @@ import java.util.Map.Entry;
|
|||
/**
|
||||
* The wrapper class for image objects.
|
||||
*
|
||||
* <p>{@link Image} is designed to be an immutable image container, which could be shared
|
||||
* <p>{@link MPImage} is designed to be an immutable image container, which could be shared
|
||||
* cross-platforms.
|
||||
*
|
||||
* <p>To construct an {@link Image}, use the provided builders:
|
||||
* <p>To construct a {@link MPImage}, use the provided builders:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link ByteBufferImageBuilder}
|
||||
|
@ -40,7 +40,7 @@ import java.util.Map.Entry;
|
|||
* <li>{@link MediaImageBuilder}
|
||||
* </ul>
|
||||
*
|
||||
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
|
||||
* <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
|
||||
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release
|
||||
* internal storage earlier, otherwise Java garbage collection will release the storage eventually.
|
||||
*
|
||||
|
@ -53,7 +53,7 @@ import java.util.Map.Entry;
|
|||
* <li>{@link MediaImageExtractor}
|
||||
* </ul>
|
||||
*/
|
||||
public class Image implements Closeable {
|
||||
public class MPImage implements Closeable {
|
||||
|
||||
/** Specifies the image format of an image. */
|
||||
@IntDef({
|
||||
|
@ -69,7 +69,7 @@ public class Image implements Closeable {
|
|||
IMAGE_FORMAT_JPEG,
|
||||
})
|
||||
@Retention(RetentionPolicy.SOURCE)
|
||||
public @interface ImageFormat {}
|
||||
public @interface MPImageFormat {}
|
||||
|
||||
public static final int IMAGE_FORMAT_UNKNOWN = 0;
|
||||
public static final int IMAGE_FORMAT_RGBA = 1;
|
||||
|
@ -98,14 +98,14 @@ public class Image implements Closeable {
|
|||
public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
|
||||
|
||||
/**
|
||||
* Returns a list of supported image properties for this {@link Image}.
|
||||
* Returns a list of supported image properties for this {@link MPImage}.
|
||||
*
|
||||
* <p>Currently {@link Image} only support single storage type so the size of return list will
|
||||
* <p>Currently {@link MPImage} only support single storage type so the size of return list will
|
||||
* always be 1.
|
||||
*
|
||||
* @see ImageProperties
|
||||
* @see MPImageProperties
|
||||
*/
|
||||
public List<ImageProperties> getContainedImageProperties() {
|
||||
public List<MPImageProperties> getContainedImageProperties() {
|
||||
return Collections.singletonList(getContainer().getImageProperties());
|
||||
}
|
||||
|
||||
|
@ -124,7 +124,7 @@ public class Image implements Closeable {
|
|||
return height;
|
||||
}
|
||||
|
||||
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
|
||||
/** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
|
||||
private synchronized void acquire() {
|
||||
referenceCount += 1;
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ public class Image implements Closeable {
|
|||
/**
|
||||
* Removes a reference that was previously acquired or init.
|
||||
*
|
||||
* <p>When {@link Image} is created, it has 1 reference count.
|
||||
* <p>When {@link MPImage} is created, it has 1 reference count.
|
||||
*
|
||||
* <p>When the reference count becomes 0, it will release the resource under the hood.
|
||||
*/
|
||||
|
@ -141,24 +141,24 @@ public class Image implements Closeable {
|
|||
public synchronized void close() {
|
||||
referenceCount -= 1;
|
||||
if (referenceCount == 0) {
|
||||
for (ImageContainer imageContainer : containerMap.values()) {
|
||||
for (MPImageContainer imageContainer : containerMap.values()) {
|
||||
imageContainer.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Advanced API access for {@link Image}. */
|
||||
/** Advanced API access for {@link MPImage}. */
|
||||
static final class Internal {
|
||||
|
||||
/**
|
||||
* Acquires a reference on this {@link Image}. This will increase the reference count by 1.
|
||||
* Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
|
||||
*
|
||||
* <p>This method is more useful for image consumer to acquire a reference so image resource
|
||||
* will not be closed accidentally. As image creator, normal developer doesn't need to call this
|
||||
* method.
|
||||
*
|
||||
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
|
||||
* #close()} to indicate it doesn't need this {@link Image} anymore.
|
||||
* <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
|
||||
* #close()} to indicate it doesn't need this {@link MPImage} anymore.
|
||||
*
|
||||
* @see #close()
|
||||
*/
|
||||
|
@ -166,10 +166,10 @@ public class Image implements Closeable {
|
|||
image.acquire();
|
||||
}
|
||||
|
||||
private final Image image;
|
||||
private final MPImage image;
|
||||
|
||||
// Only Image creates the internal helper.
|
||||
private Internal(Image image) {
|
||||
// Only MPImage creates the internal helper.
|
||||
private Internal(MPImage image) {
|
||||
this.image = image;
|
||||
}
|
||||
}
|
||||
|
@ -179,15 +179,15 @@ public class Image implements Closeable {
|
|||
return new Internal(this);
|
||||
}
|
||||
|
||||
private final Map<ImageProperties, ImageContainer> containerMap;
|
||||
private final Map<MPImageProperties, MPImageContainer> containerMap;
|
||||
private final long timestamp;
|
||||
private final int width;
|
||||
private final int height;
|
||||
|
||||
private int referenceCount;
|
||||
|
||||
/** Constructs an {@link Image} with a built container. */
|
||||
Image(ImageContainer container, long timestamp, int width, int height) {
|
||||
/** Constructs a {@link MPImage} with a built container. */
|
||||
MPImage(MPImageContainer container, long timestamp, int width, int height) {
|
||||
this.containerMap = new HashMap<>();
|
||||
containerMap.put(container.getImageProperties(), container);
|
||||
this.timestamp = timestamp;
|
||||
|
@ -201,10 +201,10 @@ public class Image implements Closeable {
|
|||
*
|
||||
* @return the current container.
|
||||
*/
|
||||
ImageContainer getContainer() {
|
||||
MPImageContainer getContainer() {
|
||||
// According to the design, in the future we will support multiple containers in one image.
|
||||
// Currently just return the original container.
|
||||
// TODO: Cache multiple containers in Image.
|
||||
// TODO: Cache multiple containers in MPImage.
|
||||
return containerMap.values().iterator().next();
|
||||
}
|
||||
|
||||
|
@ -214,8 +214,8 @@ public class Image implements Closeable {
|
|||
* <p>If there are multiple containers with required {@code storageType}, returns the first one.
|
||||
*/
|
||||
@Nullable
|
||||
ImageContainer getContainer(@StorageType int storageType) {
|
||||
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
|
||||
MPImageContainer getContainer(@StorageType int storageType) {
|
||||
for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
|
||||
if (entry.getKey().getStorageType() == storageType) {
|
||||
return entry.getValue();
|
||||
}
|
||||
|
@ -225,13 +225,13 @@ public class Image implements Closeable {
|
|||
|
||||
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
|
||||
@Nullable
|
||||
ImageContainer getContainer(ImageProperties imageProperties) {
|
||||
MPImageContainer getContainer(MPImageProperties imageProperties) {
|
||||
return containerMap.get(imageProperties);
|
||||
}
|
||||
|
||||
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
|
||||
boolean addContainer(ImageContainer container) {
|
||||
ImageProperties imageProperties = container.getImageProperties();
|
||||
boolean addContainer(MPImageContainer container) {
|
||||
MPImageProperties imageProperties = container.getImageProperties();
|
||||
if (containerMap.containsKey(imageProperties)) {
|
||||
return false;
|
||||
}
|
|
@ -14,14 +14,14 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
package com.google.mediapipe.framework.image;
|
||||
|
||||
/** Lightweight abstraction for an object that can receive {@link Image} */
|
||||
public interface ImageConsumer {
|
||||
/** Lightweight abstraction for an object that can receive {@link MPImage} */
|
||||
public interface MPImageConsumer {
|
||||
|
||||
/**
|
||||
* Called when an {@link Image} is available.
|
||||
* Called when a {@link MPImage} is available.
|
||||
*
|
||||
* <p>The argument is only guaranteed to be available until this method returns. if you need to
|
||||
* extend its life time, acquire it, then release it when done.
|
||||
*/
|
||||
void onNewImage(Image image);
|
||||
void onNewMPImage(MPImage image);
|
||||
}
|
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||
package com.google.mediapipe.framework.image;
|
||||
|
||||
/** Manages internal image data storage. The interface is package-private. */
|
||||
interface ImageContainer {
|
||||
interface MPImageContainer {
|
||||
/** Returns the properties of the contained image. */
|
||||
ImageProperties getImageProperties();
|
||||
MPImageProperties getImageProperties();
|
||||
|
||||
/** Close the image container and releases the image resource inside. */
|
||||
void close();
|
|
@ -14,9 +14,9 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
package com.google.mediapipe.framework.image;
|
||||
|
||||
/** Lightweight abstraction for an object that produce {@link Image} */
|
||||
public interface ImageProducer {
|
||||
/** Lightweight abstraction for an object that produce {@link MPImage} */
|
||||
public interface MPImageProducer {
|
||||
|
||||
/** Sets the consumer that receives the {@link Image}. */
|
||||
void setImageConsumer(ImageConsumer imageConsumer);
|
||||
/** Sets the consumer that receives the {@link MPImage}. */
|
||||
void setMPImageConsumer(MPImageConsumer imageConsumer);
|
||||
}
|
|
@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;
|
|||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.auto.value.extension.memoized.Memoized;
|
||||
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||
import com.google.mediapipe.framework.image.Image.StorageType;
|
||||
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
|
||||
import com.google.mediapipe.framework.image.MPImage.StorageType;
|
||||
|
||||
/** Groups a set of properties to describe how an image is stored. */
|
||||
@AutoValue
|
||||
public abstract class ImageProperties {
|
||||
public abstract class MPImageProperties {
|
||||
|
||||
/**
|
||||
* Gets the pixel format of the image.
|
||||
*
|
||||
* @see Image.ImageFormat
|
||||
* @see MPImage.MPImageFormat
|
||||
*/
|
||||
@ImageFormat
|
||||
@MPImageFormat
|
||||
public abstract int getImageFormat();
|
||||
|
||||
/**
|
||||
* Gets the storage type of the image.
|
||||
*
|
||||
* @see Image.StorageType
|
||||
* @see MPImage.StorageType
|
||||
*/
|
||||
@StorageType
|
||||
public abstract int getStorageType();
|
||||
|
@ -45,36 +45,36 @@ public abstract class ImageProperties {
|
|||
public abstract int hashCode();
|
||||
|
||||
/**
|
||||
* Creates a builder of {@link ImageProperties}.
|
||||
* Creates a builder of {@link MPImageProperties}.
|
||||
*
|
||||
* @see ImageProperties.Builder
|
||||
* @see MPImageProperties.Builder
|
||||
*/
|
||||
static Builder builder() {
|
||||
return new AutoValue_ImageProperties.Builder();
|
||||
return new AutoValue_MPImageProperties.Builder();
|
||||
}
|
||||
|
||||
/** Builds a {@link ImageProperties}. */
|
||||
/** Builds a {@link MPImageProperties}. */
|
||||
@AutoValue.Builder
|
||||
abstract static class Builder {
|
||||
|
||||
/**
|
||||
* Sets the {@link Image.ImageFormat}.
|
||||
* Sets the {@link MPImage.MPImageFormat}.
|
||||
*
|
||||
* @see ImageProperties#getImageFormat
|
||||
* @see MPImageProperties#getImageFormat
|
||||
*/
|
||||
abstract Builder setImageFormat(@ImageFormat int value);
|
||||
abstract Builder setImageFormat(@MPImageFormat int value);
|
||||
|
||||
/**
|
||||
* Sets the {@link Image.StorageType}.
|
||||
* Sets the {@link MPImage.StorageType}.
|
||||
*
|
||||
* @see ImageProperties#getStorageType
|
||||
* @see MPImageProperties#getStorageType
|
||||
*/
|
||||
abstract Builder setStorageType(@StorageType int value);
|
||||
|
||||
/** Builds the {@link ImageProperties}. */
|
||||
abstract ImageProperties build();
|
||||
/** Builds the {@link MPImageProperties}. */
|
||||
abstract MPImageProperties build();
|
||||
}
|
||||
|
||||
// Hide the constructor.
|
||||
ImageProperties() {}
|
||||
MPImageProperties() {}
|
||||
}
|
|
@ -15,11 +15,12 @@ limitations under the License.
|
|||
|
||||
package com.google.mediapipe.framework.image;
|
||||
|
||||
import android.media.Image;
|
||||
import android.os.Build.VERSION_CODES;
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
/**
|
||||
* Builds {@link Image} from {@link android.media.Image}.
|
||||
* Builds {@link MPImage} from {@link android.media.Image}.
|
||||
*
|
||||
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
|
||||
* content in it.
|
||||
|
@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
|
|||
public class MediaImageBuilder {
|
||||
|
||||
// Mandatory fields.
|
||||
private final android.media.Image mediaImage;
|
||||
private final Image mediaImage;
|
||||
|
||||
// Optional fields.
|
||||
private long timestamp;
|
||||
|
@ -40,20 +41,20 @@ public class MediaImageBuilder {
|
|||
*
|
||||
* @param mediaImage image data object.
|
||||
*/
|
||||
public MediaImageBuilder(android.media.Image mediaImage) {
|
||||
public MediaImageBuilder(Image mediaImage) {
|
||||
this.mediaImage = mediaImage;
|
||||
this.timestamp = 0;
|
||||
}
|
||||
|
||||
/** Sets value for {@link Image#getTimestamp()}. */
|
||||
/** Sets value for {@link MPImage#getTimestamp()}. */
|
||||
MediaImageBuilder setTimestamp(long timestamp) {
|
||||
this.timestamp = timestamp;
|
||||
return this;
|
||||
}
|
||||
|
||||
/** Builds an {@link Image} instance. */
|
||||
public Image build() {
|
||||
return new Image(
|
||||
/** Builds a {@link MPImage} instance. */
|
||||
public MPImage build() {
|
||||
return new MPImage(
|
||||
new MediaImageContainer(mediaImage),
|
||||
timestamp,
|
||||
mediaImage.getWidth(),
|
||||
|
|
|
@ -15,33 +15,34 @@ limitations under the License.
|
|||
|
||||
package com.google.mediapipe.framework.image;
|
||||
|
||||
import android.media.Image;
|
||||
import android.os.Build;
|
||||
import android.os.Build.VERSION;
|
||||
import android.os.Build.VERSION_CODES;
|
||||
import androidx.annotation.RequiresApi;
|
||||
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
|
||||
|
||||
@RequiresApi(VERSION_CODES.KITKAT)
|
||||
class MediaImageContainer implements ImageContainer {
|
||||
class MediaImageContainer implements MPImageContainer {
|
||||
|
||||
private final android.media.Image mediaImage;
|
||||
private final ImageProperties properties;
|
||||
private final Image mediaImage;
|
||||
private final MPImageProperties properties;
|
||||
|
||||
public MediaImageContainer(android.media.Image mediaImage) {
|
||||
public MediaImageContainer(Image mediaImage) {
|
||||
this.mediaImage = mediaImage;
|
||||
this.properties =
|
||||
ImageProperties.builder()
|
||||
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
|
||||
MPImageProperties.builder()
|
||||
.setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
|
||||
.setImageFormat(convertFormatCode(mediaImage.getFormat()))
|
||||
.build();
|
||||
}
|
||||
|
||||
public android.media.Image getImage() {
|
||||
public Image getImage() {
|
||||
return mediaImage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageProperties getImageProperties() {
|
||||
public MPImageProperties getImageProperties() {
|
||||
return properties;
|
||||
}
|
||||
|
||||
|
@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
|
|||
mediaImage.close();
|
||||
}
|
||||
|
||||
@ImageFormat
|
||||
@MPImageFormat
|
||||
static int convertFormatCode(int graphicsFormat) {
|
||||
// We only cover the format mentioned in
|
||||
// https://developer.android.com/reference/android/media/Image#getFormat()
|
||||
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
|
||||
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
|
||||
return Image.IMAGE_FORMAT_RGBA;
|
||||
return MPImage.IMAGE_FORMAT_RGBA;
|
||||
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
|
||||
return Image.IMAGE_FORMAT_RGB;
|
||||
return MPImage.IMAGE_FORMAT_RGB;
|
||||
}
|
||||
}
|
||||
switch (graphicsFormat) {
|
||||
case android.graphics.ImageFormat.JPEG:
|
||||
return Image.IMAGE_FORMAT_JPEG;
|
||||
return MPImage.IMAGE_FORMAT_JPEG;
|
||||
case android.graphics.ImageFormat.YUV_420_888:
|
||||
return Image.IMAGE_FORMAT_YUV_420_888;
|
||||
return MPImage.IMAGE_FORMAT_YUV_420_888;
|
||||
default:
|
||||
return Image.IMAGE_FORMAT_UNKNOWN;
|
||||
return MPImage.IMAGE_FORMAT_UNKNOWN;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,13 +15,14 @@ limitations under the License.
|
|||
|
||||
package com.google.mediapipe.framework.image;
|
||||
|
||||
import android.media.Image;
|
||||
import android.os.Build.VERSION_CODES;
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
/**
|
||||
* Utility for extracting {@link android.media.Image} from {@link Image}.
|
||||
* Utility for extracting {@link android.media.Image} from {@link MPImage}.
|
||||
*
|
||||
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
|
||||
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
|
||||
* otherwise {@link IllegalArgumentException} will be thrown.
|
||||
*/
|
||||
@RequiresApi(VERSION_CODES.KITKAT)
|
||||
|
@ -30,20 +31,20 @@ public class MediaImageExtractor {
|
|||
private MediaImageExtractor() {}
|
||||
|
||||
/**
|
||||
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
|
||||
* {@link Image} that built from {@link MediaImageBuilder}.
|
||||
* Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
|
||||
* {@link MPImage} that built from {@link MediaImageBuilder}.
|
||||
*
|
||||
* @param image the image to extract {@link android.media.Image} from.
|
||||
* @return {@link android.media.Image} that stored in {@link Image}.
|
||||
* @return {@link android.media.Image} that stored in {@link MPImage}.
|
||||
* @throws IllegalArgumentException if the extraction failed.
|
||||
*/
|
||||
public static android.media.Image extract(Image image) {
|
||||
ImageContainer container;
|
||||
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
|
||||
public static Image extract(MPImage image) {
|
||||
MPImageContainer container;
|
||||
if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
|
||||
return ((MediaImageContainer) container).getImage();
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"Extract Media Image from an Image created by objects other than Media Image"
|
||||
"Extract Media Image from a MPImage created by objects other than Media Image"
|
||||
+ " is not supported");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019-2020 The MediaPipe Authors.
|
||||
# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""):
|
|||
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
target = "//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
|
@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""):
|
|||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:rect_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
|
||||
))
|
||||
return proto_src_list
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library compatibility macro.
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -23,15 +24,12 @@ package(
|
|||
py_library(
|
||||
name = "data_util",
|
||||
srcs = ["data_util.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "data_util_test",
|
||||
srcs = ["data_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/data/testdata"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":data_util"],
|
||||
)
|
||||
|
||||
|
@ -44,8 +42,6 @@ py_library(
|
|||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
|
@ -55,14 +51,11 @@ py_test(
|
|||
py_library(
|
||||
name = "classification_dataset",
|
||||
srcs = ["classification_dataset.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "classification_dataset_test",
|
||||
srcs = ["classification_dataset_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":classification_dataset"],
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Common classification dataset library."""
|
||||
|
||||
from typing import Any, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -21,15 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
|
|||
|
||||
|
||||
class ClassificationDataset(ds.Dataset):
|
||||
"""DataLoader for classification models."""
|
||||
"""Dataset Loader for classification models."""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int,
|
||||
label_names: List[str]):
|
||||
super().__init__(dataset, size)
|
||||
self.index_to_label = index_to_label
|
||||
self._label_names = label_names
|
||||
|
||||
@property
|
||||
def num_classes(self: ds._DatasetT) -> int:
|
||||
return len(self.index_to_label)
|
||||
return len(self._label_names)
|
||||
|
||||
@property
|
||||
def label_names(self: ds._DatasetT) -> List[str]:
|
||||
return self._label_names
|
||||
|
||||
def split(self: ds._DatasetT,
|
||||
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
|
||||
|
@ -44,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
|
|||
Returns:
|
||||
The splitted two sub datasets.
|
||||
"""
|
||||
return self._split(fraction, self.index_to_label)
|
||||
return self._split(fraction, self._label_names)
|
||||
|
|
|
@ -12,45 +12,55 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List, Tuple, TypeVar
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
|
||||
_DatasetT = TypeVar(
|
||||
'_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
|
||||
|
||||
class ClassificationDataLoaderTest(tf.test.TestCase):
|
||||
|
||||
class ClassificationDatasetTest(tf.test.TestCase):
|
||||
|
||||
def test_split(self):
|
||||
|
||||
class MagicClassificationDataLoader(
|
||||
class MagicClassificationDataset(
|
||||
classification_dataset.ClassificationDataset):
|
||||
"""A mock classification dataset class for testing purpose.
|
||||
|
||||
def __init__(self, dataset, size, index_to_label, value):
|
||||
super(MagicClassificationDataLoader,
|
||||
self).__init__(dataset, size, index_to_label)
|
||||
Attributes:
|
||||
value: A value variable stored by the mock dataset class for testing.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int,
|
||||
label_names: List[str], value: Any):
|
||||
super().__init__(dataset=dataset, size=size, label_names=label_names)
|
||||
self.value = value
|
||||
|
||||
def split(self, fraction):
|
||||
return self._split(fraction, self.index_to_label, self.value)
|
||||
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
|
||||
return self._split(fraction, self.label_names, self.value)
|
||||
|
||||
# Some dummy inputs.
|
||||
magic_value = 42
|
||||
num_classes = 2
|
||||
index_to_label = (False, True)
|
||||
label_names = ['foo', 'bar']
|
||||
|
||||
# Create data loader from sample data.
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
|
||||
magic_value)
|
||||
data = MagicClassificationDataset(
|
||||
dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
|
||||
|
||||
# Train/Test data split.
|
||||
fraction = .25
|
||||
train_data, test_data = data.split(fraction)
|
||||
train_data, test_data = data.split(fraction=fraction)
|
||||
|
||||
# `split` should return instances of child DataLoader.
|
||||
self.assertIsInstance(train_data, MagicClassificationDataLoader)
|
||||
self.assertIsInstance(test_data, MagicClassificationDataLoader)
|
||||
self.assertIsInstance(train_data, MagicClassificationDataset)
|
||||
self.assertIsInstance(test_data, MagicClassificationDataset)
|
||||
|
||||
# Make sure number of entries are right.
|
||||
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
|
||||
|
@ -59,7 +69,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase):
|
|||
|
||||
# Make sure attributes propagated correctly.
|
||||
self.assertEqual(train_data.num_classes, num_classes)
|
||||
self.assertEqual(test_data.index_to_label, index_to_label)
|
||||
self.assertEqual(test_data.label_names, label_names)
|
||||
self.assertEqual(train_data.value, magic_value)
|
||||
self.assertEqual(test_data.value, magic_value)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
|
@ -23,7 +24,6 @@ licenses(["notice"])
|
|||
py_library(
|
||||
name = "custom_model",
|
||||
srcs = ["custom_model.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
|
@ -34,8 +34,6 @@ py_library(
|
|||
py_test(
|
||||
name = "custom_model_test",
|
||||
srcs = ["custom_model_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":custom_model",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
|
@ -45,7 +43,6 @@ py_test(
|
|||
py_library(
|
||||
name = "classifier",
|
||||
srcs = ["classifier.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":custom_model",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
|
@ -55,8 +52,6 @@ py_library(
|
|||
py_test(
|
||||
name = "classifier_test",
|
||||
srcs = ["classifier_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
|
|
|
@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
|
|||
class Classifier(custom_model.CustomModel):
|
||||
"""An abstract base class that represents a TensorFlow classifier."""
|
||||
|
||||
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
|
||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
|
||||
full_train: bool):
|
||||
"""Initilizes a classifier with its specifications.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
index_to_label: A list that map from index to label class name.
|
||||
label_names: A list of label names for the classes.
|
||||
shuffle: Whether the dataset should be shuffled.
|
||||
full_train: If true, train the model end-to-end including the backbone
|
||||
and the classification layers on top. Otherwise, only train the top
|
||||
classification layers.
|
||||
"""
|
||||
super(Classifier, self).__init__(model_spec, shuffle)
|
||||
self._index_to_label = index_to_label
|
||||
self._label_names = label_names
|
||||
self._full_train = full_train
|
||||
self._num_classes = len(index_to_label)
|
||||
self._num_classes = len(label_names)
|
||||
|
||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||
"""Evaluates the classifier with the provided evaluation dataset.
|
||||
|
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
|
|||
label_filepath = os.path.join(export_dir, label_filename)
|
||||
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
|
||||
with tf.io.gfile.GFile(label_filepath, 'w') as f:
|
||||
f.write('\n'.join(self._index_to_label))
|
||||
f.write('\n'.join(self._label_names))
|
||||
|
|
|
@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super(ClassifierTest, self).setUp()
|
||||
index_to_label = ['cat', 'dog']
|
||||
label_names = ['cat', 'dog']
|
||||
self.model = MockClassifier(
|
||||
model_spec=None,
|
||||
index_to_label=index_to_label,
|
||||
label_names=label_names,
|
||||
shuffle=False,
|
||||
full_train=False)
|
||||
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||
|
|
|
@ -21,8 +21,6 @@ import abc
|
|||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
|
@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
|
|||
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
||||
# TODO: Populate metadata to the exported TFLite model.
|
||||
model_util.export_tflite(
|
||||
self._model,
|
||||
tflite_filepath,
|
||||
quantization_config,
|
||||
model=self._model,
|
||||
tflite_filepath=tflite_filepath,
|
||||
quantization_config=quantization_config,
|
||||
preprocess=preprocess)
|
||||
tf.compat.v1.logging.info(
|
||||
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
|
||||
|
|
|
@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super(CustomModelTest, self).setUp()
|
||||
self.model = MockCustomModel(model_spec=None, shuffle=False)
|
||||
self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||
self._model = MockCustomModel(model_spec=None, shuffle=False)
|
||||
self._model._model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||
|
||||
def _check_nonempty_file(self, filepath):
|
||||
self.assertTrue(os.path.isfile(filepath))
|
||||
|
@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
|
|||
|
||||
def test_export_tflite(self):
|
||||
export_path = os.path.join(self.get_temp_dir(), 'export/')
|
||||
self.model.export_tflite(export_dir=export_path)
|
||||
self._model.export_tflite(export_dir=export_path)
|
||||
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -24,31 +25,15 @@ py_library(
|
|||
name = "test_util",
|
||||
testonly = 1,
|
||||
srcs = ["test_util.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":model_util",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_preprocessing",
|
||||
srcs = ["image_preprocessing.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_preprocessing_test",
|
||||
srcs = ["image_preprocessing_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":image_preprocessing"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_util",
|
||||
srcs = ["model_util.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":quantization",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
|
@ -58,8 +43,6 @@ py_library(
|
|||
py_test(
|
||||
name = "model_util_test",
|
||||
srcs = ["model_util_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":model_util",
|
||||
":quantization",
|
||||
|
@ -76,8 +59,6 @@ py_library(
|
|||
py_test(
|
||||
name = "loss_functions_test",
|
||||
srcs = ["loss_functions_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":loss_functions"],
|
||||
)
|
||||
|
||||
|
@ -91,8 +72,6 @@ py_library(
|
|||
py_test(
|
||||
name = "quantization_test",
|
||||
srcs = ["quantization_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":quantization",
|
||||
":test_util",
|
||||
|
|
|
@ -56,7 +56,7 @@ class FocalLoss(tf.keras.losses.Loss):
|
|||
class_weight: A weight to apply to the loss, one for each class. The
|
||||
weight is applied for each input where the ground truth label matches.
|
||||
"""
|
||||
super(tf.keras.losses.Loss, self).__init__()
|
||||
super().__init__()
|
||||
# Used for clipping min/max values of probability values in y_pred to avoid
|
||||
# NaNs and Infs in computation.
|
||||
self._epsilon = 1e-7
|
||||
|
|
|
@ -104,8 +104,8 @@ def export_tflite(
|
|||
quantization_config: Configuration for post-training quantization.
|
||||
supported_ops: A list of supported ops in the converted TFLite file.
|
||||
preprocess: A callable to preprocess the representative dataset for
|
||||
quantization. The callable takes three arguments in order: feature,
|
||||
label, and is_training.
|
||||
quantization. The callable takes three arguments in order: feature, label,
|
||||
and is_training.
|
||||
"""
|
||||
if tflite_filepath is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.export_tflite(model, tflite_file)
|
||||
self._test_tflite(model, tflite_file, input_dim)
|
||||
test_util.test_tflite(
|
||||
keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
|
@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
input_dim = 16
|
||||
num_classes = 2
|
||||
max_input_value = 5
|
||||
model = test_util.build_model([input_dim], num_classes)
|
||||
model = test_util.build_model(
|
||||
input_shape=[input_dim], num_classes=num_classes)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
||||
|
||||
model_util.export_tflite(model, tflite_file, config)
|
||||
self._test_tflite(
|
||||
model, tflite_file, input_dim, max_input_value, atol=1e-00)
|
||||
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
||||
|
||||
def _test_tflite(self,
|
||||
keras_model: tf.keras.Model,
|
||||
tflite_model_file: str,
|
||||
input_dim: int,
|
||||
max_input_value: int = 1000,
|
||||
atol: float = 1e-04):
|
||||
random_input = test_util.create_random_sample(
|
||||
size=[1, input_dim], high=max_input_value)
|
||||
random_input = tf.convert_to_tensor(random_input)
|
||||
|
||||
model_util.export_tflite(
|
||||
model=model, tflite_filepath=tflite_file, quantization_config=config)
|
||||
self.assertTrue(
|
||||
test_util.is_same_output(
|
||||
tflite_model_file, keras_model, random_input, atol=atol))
|
||||
test_util.test_tflite(
|
||||
keras_model=model,
|
||||
tflite_file=tflite_file,
|
||||
size=[1, input_dim],
|
||||
high=max_input_value,
|
||||
atol=1e-00))
|
||||
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
|
|||
keras_output = keras_model.predict_on_batch(input_tensors)
|
||||
|
||||
return np.allclose(lite_output, keras_output, atol=atol)
|
||||
|
||||
|
||||
def test_tflite(keras_model: tf.keras.Model,
|
||||
tflite_file: str,
|
||||
size: Union[int, List[int]],
|
||||
high: float = 1,
|
||||
atol: float = 1e-04) -> bool:
|
||||
"""Verifies if the output of TFLite model and TF Keras model are identical.
|
||||
|
||||
Args:
|
||||
keras_model: Input TensorFlow Keras model.
|
||||
tflite_file: Input TFLite model file.
|
||||
size: Size of the input tesnor.
|
||||
high: Higher boundary of the values in input tensors.
|
||||
atol: Absolute tolerance of the difference between the outputs of Keras
|
||||
model and TFLite model.
|
||||
|
||||
Returns:
|
||||
True if the output of TFLite model and TF Keras model are identical.
|
||||
Otherwise, False.
|
||||
"""
|
||||
random_input = create_random_sample(size=size, high=high)
|
||||
random_input = tf.convert_to_tensor(random_input)
|
||||
|
||||
return is_same_output(
|
||||
tflite_file=tflite_file,
|
||||
keras_model=keras_model,
|
||||
input_tensors=random_input,
|
||||
atol=atol)
|
||||
|
|
4
mediapipe/model_maker/python/internal/README.md
Normal file
4
mediapipe/model_maker/python/internal/README.md
Normal file
|
@ -0,0 +1,4 @@
|
|||
# MediaPipe Model Maker Internal Library
|
||||
|
||||
This directory contains model maker library for internal users and experimental
|
||||
purposes.
|
1
mediapipe/model_maker/python/internal/__init__.py
Normal file
1
mediapipe/model_maker/python/internal/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
"""Model maker internal library."""
|
33
mediapipe/model_maker/python/vision/core/BUILD
Normal file
33
mediapipe/model_maker/python/vision/core/BUILD
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_preprocessing",
|
||||
srcs = ["image_preprocessing.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_preprocessing_test",
|
||||
srcs = ["image_preprocessing_test.py"],
|
||||
deps = [":image_preprocessing"],
|
||||
)
|
13
mediapipe/model_maker/python/vision/core/__init__.py
Normal file
13
mediapipe/model_maker/python/vision/core/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -13,11 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ImageNet preprocessing."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
import tensorflow as tf
|
||||
|
||||
IMAGE_SIZE = 224
|
|
@ -12,15 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import image_preprocessing
|
||||
from mediapipe.model_maker.python.vision.core import image_preprocessing
|
||||
|
||||
|
||||
def _get_preprocessed_image(preprocessor, is_training=False):
|
|
@ -12,8 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python library rule.
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python library rule.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -78,9 +78,9 @@ py_library(
|
|||
":train_image_classifier_lib",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:image_preprocessing",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
"//mediapipe/model_maker/python/vision/core:image_preprocessing",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
|
@ -84,10 +84,10 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
name for name in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, name)))
|
||||
all_label_size = len(label_names)
|
||||
label_to_index = dict(
|
||||
index_by_label = dict(
|
||||
(name, index) for index, name in enumerate(label_names))
|
||||
all_image_labels = [
|
||||
label_to_index[os.path.basename(os.path.dirname(path))]
|
||||
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||
for path in all_image_paths
|
||||
]
|
||||
|
||||
|
@ -106,33 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
||||
all_label_size, ', '.join(label_names))
|
||||
return Dataset(
|
||||
dataset=image_label_ds, size=all_image_size, index_to_label=label_names)
|
||||
|
||||
@classmethod
|
||||
def load_tf_dataset(
|
||||
cls, name: str
|
||||
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
|
||||
Optional[classification_dataset.ClassificationDataset],
|
||||
Optional[classification_dataset.ClassificationDataset]]:
|
||||
"""Loads data from tensorflow_datasets.
|
||||
|
||||
Args:
|
||||
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
|
||||
documentation of tfds.load for more details.
|
||||
|
||||
Returns:
|
||||
A tuple of Datasets for the train/validation/test.
|
||||
|
||||
Raises:
|
||||
ValueError: if the input tf dataset does not have train/validation/test
|
||||
labels.
|
||||
"""
|
||||
data, info = tfds.load(name, with_info=True)
|
||||
if 'label' not in info.features:
|
||||
raise ValueError('info.features need to contain \'label\' key.')
|
||||
label_names = info.features['label'].names
|
||||
|
||||
train_data = _create_data('train', data, info, label_names)
|
||||
validation_data = _create_data('validation', data, info, label_names)
|
||||
test_data = _create_data('test', data, info, label_names)
|
||||
return train_data, validation_data, test_data
|
||||
dataset=image_label_ds, size=all_image_size, label_names=label_names)
|
||||
|
|
|
@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
|
|||
|
||||
def test_split(self):
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
|
||||
train_data, test_data = data.split(0.5)
|
||||
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
|
||||
train_data, test_data = data.split(fraction=0.5)
|
||||
|
||||
self.assertLen(train_data, 2)
|
||||
for i, elem in enumerate(train_data._dataset):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
||||
self.assertEqual(train_data.num_classes, 2)
|
||||
self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
|
||||
self.assertEqual(train_data.label_names, ['pos', 'neg'])
|
||||
|
||||
self.assertLen(test_data, 2)
|
||||
for i, elem in enumerate(test_data._dataset):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
||||
self.assertEqual(test_data.num_classes, 2)
|
||||
self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
|
||||
self.assertEqual(test_data.label_names, ['pos', 'neg'])
|
||||
|
||||
def test_from_folder(self):
|
||||
data = dataset.Dataset.from_folder(self.image_path)
|
||||
data = dataset.Dataset.from_folder(dirname=self.image_path)
|
||||
|
||||
self.assertLen(data, 2)
|
||||
self.assertEqual(data.num_classes, 2)
|
||||
self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
|
||||
self.assertEqual(data.label_names, ['daisy', 'tulips'])
|
||||
for image, label in data.gen_tf_dataset():
|
||||
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
||||
if label.numpy() == 0:
|
||||
|
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
|
|||
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
|
||||
self.assertLen(train_data, 1034)
|
||||
self.assertEqual(train_data.num_classes, 3)
|
||||
self.assertEqual(train_data.index_to_label,
|
||||
self.assertEqual(train_data.label_names,
|
||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||
|
||||
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
|
||||
self.assertLen(validation_data, 133)
|
||||
self.assertEqual(validation_data.num_classes, 3)
|
||||
self.assertEqual(validation_data.index_to_label,
|
||||
self.assertEqual(validation_data.label_names,
|
||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||
|
||||
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
|
||||
self.assertLen(test_data, 128)
|
||||
self.assertEqual(test_data.num_classes, 3)
|
||||
self.assertEqual(test_data.index_to_label,
|
||||
self.assertEqual(test_data.label_names,
|
||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||
|
||||
|
||||
|
|
|
@ -13,16 +13,16 @@
|
|||
# limitations under the License.
|
||||
"""APIs to train image classifier model."""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import image_preprocessing
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.vision.core import image_preprocessing
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||
|
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
|
|||
class ImageClassifier(classifier.Classifier):
|
||||
"""ImageClassifier for building image classification model."""
|
||||
|
||||
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
|
||||
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
|
||||
hparams: hp.HParams):
|
||||
"""Initializes ImageClassifier class.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
index_to_label: A list that maps from index to label class name.
|
||||
label_names: A list of label names for the classes.
|
||||
hparams: The hyperparameters for training image classifier.
|
||||
"""
|
||||
super(ImageClassifier, self).__init__(
|
||||
super().__init__(
|
||||
model_spec=model_spec,
|
||||
index_to_label=index_to_label,
|
||||
label_names=label_names,
|
||||
shuffle=hparams.shuffle,
|
||||
full_train=hparams.do_fine_tuning)
|
||||
self._hparams = hparams
|
||||
|
@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
|
||||
spec = ms.SupportedModels.get(model_spec)
|
||||
image_classifier = cls(
|
||||
model_spec=spec,
|
||||
index_to_label=train_data.index_to_label,
|
||||
hparams=hparams)
|
||||
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
|
||||
|
||||
image_classifier._create_model()
|
||||
|
||||
|
|
|
@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
|||
return model.fit(
|
||||
x=train_ds,
|
||||
epochs=hparams.train_epochs,
|
||||
steps_per_epoch=hparams.steps_per_epoch,
|
||||
validation_data=validation_ds,
|
||||
callbacks=callbacks)
|
||||
|
|
|
@ -161,7 +161,7 @@ class Texture {
|
|||
|
||||
~Texture() {
|
||||
if (is_owned_) {
|
||||
glDeleteProgram(handle_);
|
||||
glDeleteTextures(1, &handle_);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -87,6 +87,9 @@ cc_library(
|
|||
cc_library(
|
||||
name = "builtin_task_graphs",
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||
],
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
"""The public facing packet getter APIs."""
|
||||
|
||||
from typing import List, Type
|
||||
from typing import List
|
||||
|
||||
from google.protobuf import message
|
||||
from google.protobuf import symbol_database
|
||||
|
@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
|
|||
get_matrix = _packet_getter.get_matrix
|
||||
|
||||
|
||||
def get_proto(packet: mp_packet.Packet) -> Type[message.Message]:
|
||||
def get_proto(packet: mp_packet.Packet) -> message.Message:
|
||||
"""Get the content of a MediaPipe proto Packet as a proto message.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -46,8 +46,10 @@ cc_library(
|
|||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
|
|
@ -44,6 +44,30 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "classification_aggregation_calculator_test",
|
||||
srcs = ["classification_aggregation_calculator_test.cc"],
|
||||
deps = [
|
||||
":classification_aggregation_calculator",
|
||||
":classification_aggregation_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:output_stream_poller",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "score_calibration_calculator_proto",
|
||||
srcs = ["score_calibration_calculator.proto"],
|
||||
|
|
|
@ -31,37 +31,62 @@
|
|||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
|
||||
// Aggregates ClassificationLists into a single ClassificationResult that has
|
||||
// 3 dimensions: (classification head, classification timestamp, classification
|
||||
// category).
|
||||
// Aggregates ClassificationLists into either a ClassificationResult object
|
||||
// representing the classification results aggregated by classifier head, or
|
||||
// into an std::vector<ClassificationResult> representing the classification
|
||||
// results aggregated first by timestamp then by classifier head.
|
||||
//
|
||||
// Inputs:
|
||||
// CLASSIFICATIONS - ClassificationList
|
||||
// CLASSIFICATIONS - ClassificationList @Multiple
|
||||
// ClassificationList per classification head.
|
||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||
// The collection of the timestamps that a single ClassificationResult
|
||||
// should aggragate. This stream is optional, and the timestamp information
|
||||
// will only be populated to the ClassificationResult proto when this stream
|
||||
// is connected.
|
||||
// The collection of the timestamps that this calculator should aggregate.
|
||||
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
|
||||
// output is used for results. Otherwise as no timestamp aggregation is
|
||||
// required the CLASSIFICATIONS output is used for results.
|
||||
//
|
||||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||
// The classification results aggregated by head. Must be connected if the
|
||||
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||
// aggregation is not required.
|
||||
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
|
||||
// The classification result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
// // TODO: remove output once migration is over.
|
||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||
// The aggregated classification result.
|
||||
//
|
||||
// Example:
|
||||
// Example without timestamp aggregation:
|
||||
// node {
|
||||
// calculator: "ClassificationAggregationCalculator"
|
||||
// input_stream: "CLASSIFICATIONS:0:stream_a"
|
||||
// input_stream: "CLASSIFICATIONS:1:stream_b"
|
||||
// input_stream: "CLASSIFICATIONS:2:stream_c"
|
||||
// output_stream: "CLASSIFICATIONS:classifications"
|
||||
// options {
|
||||
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
|
||||
// head_names: "head_name_a"
|
||||
// head_names: "head_name_b"
|
||||
// head_names: "head_name_c"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Example with timestamp aggregation:
|
||||
// node {
|
||||
// calculator: "ClassificationAggregationCalculator"
|
||||
// input_stream: "CLASSIFICATIONS:0:stream_a"
|
||||
// input_stream: "CLASSIFICATIONS:1:stream_b"
|
||||
// input_stream: "CLASSIFICATIONS:2:stream_c"
|
||||
// input_stream: "TIMESTAMPS:timestamps"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result"
|
||||
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
|
||||
// options {
|
||||
// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] {
|
||||
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
|
||||
// head_names: "head_name_a"
|
||||
// head_names: "head_name_b"
|
||||
// head_names: "head_name_c"
|
||||
|
@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
|
|||
"CLASSIFICATIONS"};
|
||||
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
|
||||
"TIMESTAMPS"};
|
||||
static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut);
|
||||
static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
|
||||
"CLASSIFICATIONS"};
|
||||
static constexpr Output<std::vector<ClassificationResult>>::Optional
|
||||
kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
|
||||
static constexpr Output<ClassificationResult>::Optional
|
||||
kClassificationResultOut{"CLASSIFICATION_RESULT"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
|
||||
kClassificationsOut, kTimestampedClassificationsOut,
|
||||
kClassificationResultOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc);
|
||||
|
@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
|
|||
cached_classifications_;
|
||||
|
||||
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
|
||||
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
|
||||
CalculatorContext* cc);
|
||||
// TODO: deprecate this function once migration is over.
|
||||
ClassificationResult LegacyConvertToClassificationResult(
|
||||
CalculatorContext* cc);
|
||||
};
|
||||
|
||||
absl::Status ClassificationAggregationCalculator::UpdateContract(
|
||||
|
@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
|
|||
<< "The size of classifications input streams should match the "
|
||||
"size of head names specified in the calculator options";
|
||||
}
|
||||
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
|
||||
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
|
||||
// not connected. All dependent tasks must be updated to use these outputs
|
||||
// first.
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
|
|||
[](const auto& elem) -> ClassificationList { return elem.Get(); });
|
||||
cached_classifications_[cc->InputTimestamp().Value()] =
|
||||
std::move(classification_lists);
|
||||
if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) {
|
||||
ClassificationResult classification_result;
|
||||
if (time_aggregation_enabled_) {
|
||||
if (kTimestampsIn(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
kOut(cc).Send(ConvertToClassificationResult(cc));
|
||||
classification_result = LegacyConvertToClassificationResult(cc);
|
||||
kTimestampedClassificationsOut(cc).Send(
|
||||
ConvertToTimestampedClassificationResults(cc));
|
||||
} else {
|
||||
classification_result = LegacyConvertToClassificationResult(cc);
|
||||
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
|
||||
}
|
||||
kClassificationResultOut(cc).Send(classification_result);
|
||||
RET_CHECK(cached_classifications_.empty());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -136,6 +186,50 @@ ClassificationResult
|
|||
ClassificationAggregationCalculator::ConvertToClassificationResult(
|
||||
CalculatorContext* cc) {
|
||||
ClassificationResult result;
|
||||
auto& classification_lists =
|
||||
cached_classifications_[cc->InputTimestamp().Value()];
|
||||
for (int i = 0; i < classification_lists.size(); ++i) {
|
||||
auto classifications = result.add_classifications();
|
||||
classifications->set_head_index(i);
|
||||
if (!head_names_.empty()) {
|
||||
classifications->set_head_name(head_names_[i]);
|
||||
}
|
||||
*classifications->mutable_classification_list() =
|
||||
std::move(classification_lists[i]);
|
||||
}
|
||||
cached_classifications_.erase(cc->InputTimestamp().Value());
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<ClassificationResult>
|
||||
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
|
||||
CalculatorContext* cc) {
|
||||
auto timestamps = kTimestampsIn(cc).Get();
|
||||
std::vector<ClassificationResult> results;
|
||||
results.reserve(timestamps.size());
|
||||
for (const auto& timestamp : timestamps) {
|
||||
ClassificationResult result;
|
||||
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000);
|
||||
auto& classification_lists = cached_classifications_[timestamp.Value()];
|
||||
for (int i = 0; i < classification_lists.size(); ++i) {
|
||||
auto classifications = result.add_classifications();
|
||||
classifications->set_head_index(i);
|
||||
if (!head_names_.empty()) {
|
||||
classifications->set_head_name(head_names_[i]);
|
||||
}
|
||||
*classifications->mutable_classification_list() =
|
||||
std::move(classification_lists[i]);
|
||||
}
|
||||
cached_classifications_.erase(timestamp.Value());
|
||||
results.push_back(std::move(result));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
ClassificationResult
|
||||
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
|
||||
CalculatorContext* cc) {
|
||||
ClassificationResult result;
|
||||
Timestamp first_timestamp(0);
|
||||
std::vector<Timestamp> timestamps;
|
||||
if (time_aggregation_enabled_) {
|
||||
|
@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
|
|||
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
|
||||
1000);
|
||||
}
|
||||
cached_classifications_.erase(timestamp.Value());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks;
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
|
|
|
@ -0,0 +1,213 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/output_stream_poller.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::ParseTextProtoOrDie;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::testing::Pointwise;
|
||||
|
||||
constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
|
||||
constexpr char kClassificationInput0Name[] = "classifications_0";
|
||||
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
|
||||
constexpr char kClassificationInput1Name[] = "classifications_1";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
constexpr char kTimestampsName[] = "timestamps";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kClassificationsName[] = "classifications";
|
||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||
constexpr char kTimestampedClassificationsName[] =
|
||||
"timestamped_classifications";
|
||||
|
||||
ClassificationList MakeClassificationList(int class_index) {
|
||||
return ParseTextProtoOrDie<ClassificationList>(absl::StrFormat(
|
||||
R"pb(
|
||||
classification { index: %d }
|
||||
)pb",
|
||||
class_index));
|
||||
}
|
||||
|
||||
class ClassificationAggregationCalculatorTest
|
||||
: public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
bool connect_timestamps = false) {
|
||||
Graph graph;
|
||||
auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
|
||||
calculator
|
||||
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
|
||||
ParseTextProtoOrDie<
|
||||
mediapipe::ClassificationAggregationCalculatorOptions>(
|
||||
R"pb(head_names: "foo" head_names: "bar")pb");
|
||||
graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
|
||||
kClassificationInput0Name) >>
|
||||
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
|
||||
graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
|
||||
kClassificationInput1Name) >>
|
||||
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));
|
||||
if (connect_timestamps) {
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||
kTimestampsName) >>
|
||||
calculator.In(kTimestampsTag);
|
||||
calculator.Out(kTimestampedClassificationsTag)
|
||||
.SetName(kTimestampedClassificationsName) >>
|
||||
graph[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)];
|
||||
} else {
|
||||
calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
|
||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||
if (connect_timestamps) {
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kTimestampedClassificationsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kClassificationsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
|
||||
absl::Status Send(
|
||||
std::vector<ClassificationList> classifications, int timestamp = 0,
|
||||
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kClassificationInput0Name,
|
||||
MakePacket<ClassificationList>(classifications[0])
|
||||
.At(Timestamp(timestamp))));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kClassificationInput1Name,
|
||||
MakePacket<ClassificationList>(classifications[1])
|
||||
.At(Timestamp(timestamp))));
|
||||
if (aggregation_timestamps.has_value()) {
|
||||
auto packet = std::make_unique<std::vector<Timestamp>>();
|
||||
for (const auto& timestamp : *aggregation_timestamps) {
|
||||
packet->emplace_back(Timestamp(timestamp));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||
|
||||
Packet packet;
|
||||
if (!poller.Next(&packet)) {
|
||||
return absl::InternalError("Unable to get output packet");
|
||||
}
|
||||
auto result = packet.Get<T>();
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CalculatorGraph calculator_graph_;
|
||||
};
|
||||
|
||||
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
|
||||
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
|
||||
|
||||
EXPECT_THAT(result,
|
||||
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
|
||||
R"pb(classifications {
|
||||
head_index: 0
|
||||
head_name: "foo"
|
||||
classification_list { classification { index: 0 } }
|
||||
}
|
||||
classifications {
|
||||
head_index: 1
|
||||
head_name: "bar"
|
||||
classification_list { classification { index: 1 } }
|
||||
})pb")));
|
||||
}
|
||||
|
||||
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
|
||||
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
|
||||
MP_ASSERT_OK(Send(
|
||||
{MakeClassificationList(2), MakeClassificationList(3)},
|
||||
/*timestamp=*/1000,
|
||||
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
||||
GetResult<std::vector<ClassificationResult>>(poller));
|
||||
|
||||
EXPECT_THAT(result,
|
||||
Pointwise(EqualsProto(),
|
||||
{ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 0,
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "foo"
|
||||
classification_list { classification { index: 0 } }
|
||||
}
|
||||
classifications {
|
||||
head_index: 1
|
||||
head_name: "bar"
|
||||
classification_list { classification { index: 1 } }
|
||||
}
|
||||
)pb"),
|
||||
ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 1,
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "foo"
|
||||
classification_list { classification { index: 2 } }
|
||||
}
|
||||
classifications {
|
||||
head_index: 1
|
||||
head_name: "bar"
|
||||
classification_list { classification { index: 3 } }
|
||||
}
|
||||
)pb")}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -29,3 +29,23 @@ cc_library(
|
|||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "category",
|
||||
srcs = ["category.cc"],
|
||||
hdrs = ["category.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "classification_result",
|
||||
srcs = ["classification_result.cc"],
|
||||
hdrs = ["classification_result.h"],
|
||||
deps = [
|
||||
":category",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
],
|
||||
)
|
||||
|
|
38
mediapipe/tasks/cc/components/containers/category.cc
Normal file
38
mediapipe/tasks/cc/components/containers/category.cc
Normal file
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
Category ConvertToCategory(const mediapipe::Classification& proto) {
|
||||
Category category;
|
||||
category.index = proto.index();
|
||||
category.score = proto.score();
|
||||
if (proto.has_label()) {
|
||||
category.category_name = proto.label();
|
||||
}
|
||||
if (proto.has_display_name()) {
|
||||
category.display_name = proto.display_name();
|
||||
}
|
||||
return category;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
52
mediapipe/tasks/cc/components/containers/category.h
Normal file
52
mediapipe/tasks/cc/components/containers/category.h
Normal file
|
@ -0,0 +1,52 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
// Defines a single classification result.
|
||||
//
|
||||
// The label maps packed into the TFLite Model Metadata [1] are used to populate
|
||||
// the 'category_name' and 'display_name' fields.
|
||||
//
|
||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||
struct Category {
|
||||
// The index of the category in the classification model output.
|
||||
int index;
|
||||
// The score for this category, e.g. (but not necessarily) a probability in
|
||||
// [0,1].
|
||||
float score;
|
||||
// The optional ID for the category, read from the label map packed in the
|
||||
// TFLite Model Metadata if present. Not necessarily human-readable.
|
||||
std::optional<std::string> category_name = std::nullopt;
|
||||
// The optional human-readable name for the category, read from the label map
|
||||
// packed in the TFLite Model Metadata if present.
|
||||
std::optional<std::string> display_name = std::nullopt;
|
||||
};
|
||||
|
||||
// Utility function to convert from mediapipe::Classification proto to Category
|
||||
// struct.
|
||||
Category ConvertToCategory(const mediapipe::Classification& proto);
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
|
|
@ -0,0 +1,57 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
Classifications ConvertToClassifications(const proto::Classifications& proto) {
|
||||
Classifications classifications;
|
||||
classifications.categories.reserve(
|
||||
proto.classification_list().classification_size());
|
||||
for (const auto& classification :
|
||||
proto.classification_list().classification()) {
|
||||
classifications.categories.push_back(ConvertToCategory(classification));
|
||||
}
|
||||
classifications.head_index = proto.head_index();
|
||||
if (proto.has_head_name()) {
|
||||
classifications.head_name = proto.head_name();
|
||||
}
|
||||
return classifications;
|
||||
}
|
||||
|
||||
ClassificationResult ConvertToClassificationResult(
|
||||
const proto::ClassificationResult& proto) {
|
||||
ClassificationResult classification_result;
|
||||
classification_result.classifications.reserve(proto.classifications_size());
|
||||
for (const auto& classifications : proto.classifications()) {
|
||||
classification_result.classifications.push_back(
|
||||
ConvertToClassifications(classifications));
|
||||
}
|
||||
if (proto.has_timestamp_ms()) {
|
||||
classification_result.timestamp_ms = proto.timestamp_ms();
|
||||
}
|
||||
return classification_result;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
|
@ -0,0 +1,68 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
// Defines classification results for a given classifier head.
|
||||
struct Classifications {
|
||||
// The array of predicted categories, usually sorted by descending scores,
|
||||
// e.g. from high to low probability.
|
||||
std::vector<Category> categories;
|
||||
// The index of the classifier head (i.e. output tensor) these categories
|
||||
// refer to. This is useful for multi-head models.
|
||||
int head_index;
|
||||
// The optional name of the classifier head, as provided in the TFLite Model
|
||||
// Metadata [1] if present. This is useful for multi-head models.
|
||||
//
|
||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||
std::optional<std::string> head_name = std::nullopt;
|
||||
};
|
||||
|
||||
// Defines classification results of a model.
|
||||
struct ClassificationResult {
|
||||
// The classification results for each head of the model.
|
||||
std::vector<Classifications> classifications;
|
||||
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
// corresponding to these results.
|
||||
//
|
||||
// This is only used for classification on time series (e.g. audio
|
||||
// classification). In these use cases, the amount of data to process might
|
||||
// exceed the maximum size that the model can process: to solve this, the
|
||||
// input data is split into multiple chunks starting at different timestamps.
|
||||
std::optional<int64_t> timestamp_ms = std::nullopt;
|
||||
};
|
||||
|
||||
// Utility function to convert from Classifications proto to
|
||||
// Classifications struct.
|
||||
Classifications ConvertToClassifications(const proto::Classifications& proto);
|
||||
|
||||
// Utility function to convert from ClassificationResult proto to
|
||||
// ClassificationResult struct.
|
||||
ClassificationResult ConvertToClassificationResult(
|
||||
const proto::ClassificationResult& proto);
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
|
|
@ -28,6 +28,7 @@ mediapipe_proto_library(
|
|||
srcs = ["classifications.proto"],
|
||||
deps = [
|
||||
":category_proto",
|
||||
"//mediapipe/framework/formats:classification_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -17,9 +17,10 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||
option java_outer_classname = "CategoryProto";
|
||||
|
||||
// TODO: deprecate this message once migration is over.
|
||||
// A single classification result.
|
||||
message Category {
|
||||
// The index of the category in the corresponding label map, usually packed in
|
||||
|
|
|
@ -17,11 +17,13 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
import "mediapipe/framework/formats/classification.proto";
|
||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||
option java_outer_classname = "ClassificationsProto";
|
||||
|
||||
// TODO: deprecate this message once migration is over.
|
||||
// List of predicted categories with an optional timestamp.
|
||||
message ClassificationEntry {
|
||||
// The array of predicted categories, usually sorted by descending scores,
|
||||
|
@ -33,9 +35,12 @@ message ClassificationEntry {
|
|||
optional int64 timestamp_ms = 2;
|
||||
}
|
||||
|
||||
// Classifications for a given classifier head.
|
||||
// Classifications for a given classifier head, i.e. for a given output tensor.
|
||||
message Classifications {
|
||||
// TODO: deprecate this field once migration is over.
|
||||
repeated ClassificationEntry entries = 1;
|
||||
// The classification results for this head.
|
||||
optional mediapipe.ClassificationList classification_list = 4;
|
||||
// The index of the classifier head these categories refer to. This is useful
|
||||
// for multi-head models.
|
||||
optional int32 head_index = 2;
|
||||
|
@ -45,7 +50,17 @@ message Classifications {
|
|||
optional string head_name = 3;
|
||||
}
|
||||
|
||||
// Contains one set of results per classifier head.
|
||||
// Classifications for a given classifier model.
|
||||
message ClassificationResult {
|
||||
// The classification results for each model head, i.e. one for each output
|
||||
// tensor.
|
||||
repeated Classifications classifications = 1;
|
||||
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
// corresponding to these results.
|
||||
//
|
||||
// This is only used for classification on time series (e.g. audio
|
||||
// classification). In these use cases, the amount of data to process might
|
||||
// exceed the maximum size that the model can process: to solve this, the
|
||||
// input data is split into multiple chunks starting at different timestamps.
|
||||
optional int64 timestamp_ms = 2;
|
||||
}
|
||||
|
|
|
@ -17,6 +17,9 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||
option java_outer_classname = "EmbeddingsProto";
|
||||
|
||||
// Defines a dense floating-point embedding.
|
||||
message FloatEmbedding {
|
||||
repeated float values = 1 [packed = true];
|
||||
|
|
|
@ -30,9 +30,11 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
|
@ -128,12 +130,21 @@ absl::Status ConfigureImageToTensorCalculator(
|
|||
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
|
||||
std);
|
||||
}
|
||||
// TODO: need to support different GPU origin on differnt
|
||||
// platforms or applications.
|
||||
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool DetermineImagePreprocessingGpuBackend(
|
||||
const core::proto::Acceleration& acceleration) {
|
||||
return acceleration.has_gpu();
|
||||
}
|
||||
|
||||
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
|
||||
bool use_gpu,
|
||||
ImagePreprocessingOptions* options) {
|
||||
ASSIGN_OR_RETURN(auto image_tensor_specs,
|
||||
BuildImageTensorSpecs(model_resources));
|
||||
|
@ -141,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
|
|||
image_tensor_specs, options->mutable_image_to_tensor_options()));
|
||||
// The GPU backend isn't able to process int data. If the input tensor is
|
||||
// quantized, forces the image preprocessing graph to use CPU backend.
|
||||
if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) {
|
||||
if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
|
||||
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
|
||||
} else {
|
||||
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -19,20 +19,26 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
|
||||
// Configures an ImagePreprocessing subgraph using the provided model resources.
|
||||
// Configures an ImagePreprocessing subgraph using the provided model resources
|
||||
// When use_gpu is true, use GPU as backend to convert image to tensor.
|
||||
// - Accepts CPU input images and outputs CPU tensors.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// auto& preprocessing =
|
||||
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
// core::proto::Acceleration acceleration;
|
||||
// acceleration.mutable_xnnpack();
|
||||
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
|
||||
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
// model_resources,
|
||||
// use_gpu,
|
||||
// &preprocessing.GetOptions<ImagePreprocessingOptions>()));
|
||||
//
|
||||
// The resulting ImagePreprocessing subgraph has the following I/O:
|
||||
|
@ -56,9 +62,14 @@ namespace components {
|
|||
// The image that has the pixel data stored on the target storage (CPU vs
|
||||
// GPU).
|
||||
absl::Status ConfigureImagePreprocessing(
|
||||
const core::ModelResources& model_resources,
|
||||
const core::ModelResources& model_resources, bool use_gpu,
|
||||
ImagePreprocessingOptions* options);
|
||||
|
||||
// Determine if the image preprocessing subgraph should use GPU as the backend
|
||||
// according to the given acceleration setting.
|
||||
bool DetermineImagePreprocessingGpuBackend(
|
||||
const core::proto::Acceleration& acceleration);
|
||||
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
|||
constexpr char kScoresTag[] = "SCORES";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||
|
||||
// Struct holding the different output streams produced by the graph.
|
||||
struct ClassificationPostprocessingOutputStreams {
|
||||
Source<ClassificationResult> classification_result;
|
||||
Source<ClassificationResult> classifications;
|
||||
Source<std::vector<ClassificationResult>> timestamped_classifications;
|
||||
};
|
||||
|
||||
// Performs sanity checks on provided ClassifierOptions.
|
||||
absl::Status SanityCheckClassifierOptions(
|
||||
|
@ -286,7 +294,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
|
|||
|
||||
void ConfigureClassificationAggregationCalculator(
|
||||
const ModelMetadataExtractor& metadata_extractor,
|
||||
ClassificationAggregationCalculatorOptions* options) {
|
||||
mediapipe::ClassificationAggregationCalculatorOptions* options) {
|
||||
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
|
||||
if (output_tensors_metadata == nullptr) {
|
||||
return;
|
||||
|
@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph(
|
|||
// TENSORS - std::vector<Tensor>
|
||||
// The output tensors of an InferenceCalculator.
|
||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||
// The collection of timestamps that a single ClassificationResult should
|
||||
// aggregate. This is mostly useful for classifiers working on time series,
|
||||
// e.g. audio or video classification.
|
||||
// The collection of the timestamps that this calculator should aggregate.
|
||||
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
|
||||
// output is used for results. Otherwise as no timestamp aggregation is
|
||||
// required the CLASSIFICATIONS output is used for results.
|
||||
//
|
||||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The output aggregated classification results.
|
||||
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||
// The classification results aggregated by head. Must be connected if the
|
||||
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||
// aggregation is not required.
|
||||
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
|
||||
// The classification result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
// // TODO: remove output once migration is over.
|
||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||
// The aggregated classification result.
|
||||
//
|
||||
// The recommended way of using this graph is through the GraphBuilder API
|
||||
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
|
||||
|
@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
mediapipe::SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto classification_result_out,
|
||||
auto output_streams,
|
||||
BuildClassificationPostprocessing(
|
||||
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||
classification_result_out >>
|
||||
output_streams.classification_result >>
|
||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
||||
output_streams.classifications >>
|
||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||
output_streams.timestamped_classifications >>
|
||||
graph[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
// Adds an on-device classification postprocessing graph into the provided
|
||||
// builder::Graph instance. The classification postprocessing graph takes
|
||||
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
|
||||
// stream containing the output classification results (ClassificationResult).
|
||||
// tensors (std::vector<mediapipe::Tensor>) and optional timestamps
|
||||
// (std::vector<Timestamp>) as input and returns two output streams:
|
||||
// - classification results aggregated by classifier head as a
|
||||
// ClassificationResult proto, used when no timestamps are passed in
|
||||
// the graph,
|
||||
// - classification results aggregated by timestamp then by classifier head
|
||||
// as a std::vector<ClassificationResult>, used when timestamps are passed
|
||||
// in the graph.
|
||||
//
|
||||
// options: the on-device ClassificationPostprocessingGraphOptions.
|
||||
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
|
||||
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
||||
// timestamps that a single ClassificationResult should aggregate.
|
||||
// timestamps that should be used to aggregate classification results.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<ClassificationResult>>
|
||||
absl::StatusOr<ClassificationPostprocessingOutputStreams>
|
||||
BuildClassificationPostprocessing(
|
||||
const proto::ClassificationPostprocessingGraphOptions& options,
|
||||
Source<std::vector<Tensor>> tensors_in,
|
||||
|
@ -494,7 +524,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
// Aggregates Classifications into a single ClassificationResult.
|
||||
auto& result_aggregation =
|
||||
graph.AddNode("ClassificationAggregationCalculator");
|
||||
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>()
|
||||
result_aggregation
|
||||
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
|
||||
.CopyFrom(options.classification_aggregation_options());
|
||||
for (int i = 0; i < num_heads; ++i) {
|
||||
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
|
||||
|
@ -504,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
timestamps_in >> result_aggregation.In(kTimestampsTag);
|
||||
|
||||
// Connects output.
|
||||
return result_aggregation[Output<ClassificationResult>(
|
||||
kClassificationResultTag)];
|
||||
ClassificationPostprocessingOutputStreams output_streams{
|
||||
/*classification_result=*/result_aggregation
|
||||
[Output<ClassificationResult>(kClassificationResultTag)],
|
||||
/*classifications=*/
|
||||
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
|
||||
/*timestamped_classifications=*/
|
||||
result_aggregation[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)]};
|
||||
return output_streams;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user