Merge branch 'master' into image-embedder-python

commit 5a68ba84b6
@@ -30,6 +30,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
        git \
        wget \
        unzip \
        nodejs \
        npm \
        python3-dev \
        python3-opencv \
        python3-pip \
@@ -172,6 +172,10 @@ http_archive(
    urls = [
        "https://github.com/google/sentencepiece/archive/1.0.0.zip",
    ],
    patches = [
        "//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
    ],
    patch_args = ["-p1"],
    repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
)
docs/BUILD (new file, +14 lines)
@@ -0,0 +1,14 @@
# Placeholder for internal Python strict binary compatibility macro.

py_binary(
    name = "build_py_api_docs",
    srcs = ["build_py_api_docs.py"],
    deps = [
        "//mediapipe",
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/tensorflow_docs",
        "//third_party/py/tensorflow_docs/api_generator:generate_lib",
        "//third_party/py/tensorflow_docs/api_generator:public_api",
    ],
)
docs/build_py_api_docs.py (new file, +85 lines)
@@ -0,0 +1,85 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""MediaPipe reference docs generation script.

This script generates API reference docs for the `mediapipe` PIP package.

$> pip install -U git+https://github.com/tensorflow/docs mediapipe
$> python build_py_api_docs.py
"""

import os

from absl import app
from absl import flags

from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api

try:
  # mediapipe has not been set up to work with bazel yet, so catch & report.
  import mediapipe  # pytype: disable=import-error
except ImportError as e:
  raise ImportError('Please `pip install mediapipe`.') from e


PROJECT_SHORT_NAME = 'mp'
PROJECT_FULL_NAME = 'MediaPipe'

_OUTPUT_DIR = flags.DEFINE_string(
    'output_dir',
    default='/tmp/generated_docs',
    help='Where to write the resulting docs.')

_URL_PREFIX = flags.DEFINE_string(
    'code_url_prefix',
    'https://github.com/google/mediapipe/tree/master/mediapipe',
    'The url prefix for links to code.')

_SEARCH_HINTS = flags.DEFINE_bool(
    'search_hints', True,
    'Include metadata search hints in the generated files')

_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
                                 'Path prefix in the _toc.yaml')


def gen_api_docs():
  """Generates API docs for the mediapipe package."""

  doc_generator = generate_lib.DocGenerator(
      root_title=PROJECT_FULL_NAME,
      py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
      base_dir=os.path.dirname(mediapipe.__file__),
      code_url_prefix=_URL_PREFIX.value,
      search_hints=_SEARCH_HINTS.value,
      site_path=_SITE_PATH.value,
      # This callback ensures that docs are only generated for objects that
      # are explicitly imported in your __init__.py files. There are other
      # options but this is a good starting point.
      callbacks=[public_api.explicit_package_contents_filter],
  )

  doc_generator.build(_OUTPUT_DIR.value)

  print('Docs output to:', _OUTPUT_DIR.value)


def main(_):
  gen_api_docs()


if __name__ == '__main__':
  app.run(main)
@@ -222,10 +222,10 @@ cc_library(
        "//mediapipe/framework:calculator_contract",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:ret_check",
@@ -328,6 +328,7 @@ cc_library(
        ":concatenate_vector_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
@@ -344,6 +345,7 @@ cc_test(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
@@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
    output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        async_selection: true
        contained_node: { calculator: "AppearancesPassThroughSubgraph" }
      }
    }
@@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
    output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        async_selection: true
        contained_node: {
          calculator: "BypassCalculator"
          node_options: {
@@ -18,6 +18,7 @@
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
@@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
};
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);

class ConcatenateClassificationListCalculator
    : public ConcatenateListsCalculator<Classification, ClassificationList> {
 protected:
  int ListSize(const ClassificationList& list) const override {
    return list.classification_size();
  }
  const Classification GetItem(const ClassificationList& list,
                               int idx) const override {
    return list.classification(idx);
  }
  Classification* AddItem(ClassificationList& list) const override {
    return list.add_classification();
  }
};
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);

}  // namespace api2
}  // namespace mediapipe
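A side note on the hunk above: those three overrides are the entire cost of supporting another list proto in ConcatenateListsCalculator. A minimal sketch (not part of this commit; it assumes the existing mediapipe::DetectionList proto with its repeated `detection` field) of a detection variant following the identical pattern:

    // Hypothetical instantiation, for illustration only.
    #include "mediapipe/framework/formats/detection.pb.h"

    class ConcatenateDetectionListCalculator
        : public ConcatenateListsCalculator<Detection, DetectionList> {
     protected:
      int ListSize(const DetectionList& list) const override {
        return list.detection_size();
      }
      const Detection GetItem(const DetectionList& list,
                              int idx) const override {
        return list.detection(idx);
      }
      Detection* AddItem(DetectionList& list) const override {
        return list.add_detection();
      }
    };
    MEDIAPIPE_REGISTER_NODE(ConcatenateDetectionListCalculator);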
@@ -18,6 +18,7 @@

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@@ -70,6 +71,16 @@ void AddInputLandmarkLists(
  }
}

void AddInputClassificationLists(
    const std::vector<ClassificationList>& input_classifications_vec,
    int64 timestamp, CalculatorRunner* runner) {
  for (int i = 0; i < input_classifications_vec.size(); ++i) {
    runner->MutableInputs()->Index(i).packets.push_back(
        MakePacket<ClassificationList>(input_classifications_vec[i])
            .At(Timestamp(timestamp)));
  }
}

TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
  CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
@@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
  EXPECT_EQ(0, outputs.size());
}

TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
  CalculatorRunner runner("ConcatenateClassificationListCalculator",
                          /*options_string=*/
                          "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
                          "{only_emit_if_all_present: true}",
                          /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
    classification: { index: 0 score: 0.2 label: "test_0" }
    classification: { index: 1 score: 0.3 label: "test_1" }
    classification: { index: 2 score: 0.4 label: "test_2" }
  )pb");
  auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
    classification: { index: 3 score: 0.2 label: "test_3" }
    classification: { index: 4 score: 0.3 label: "test_4" }
  )pb");
  std::vector<ClassificationList> inputs = {input_0, input_1};
  AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  auto result = outputs[0].Get<ClassificationList>();
  EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
                classification: { index: 0 score: 0.2 label: "test_0" }
                classification: { index: 1 score: 0.3 label: "test_1" }
                classification: { index: 2 score: 0.4 label: "test_2" }
                classification: { index: 3 score: 0.2 label: "test_3" }
                classification: { index: 4 score: 0.3 label: "test_4" }
              )pb"),
              EqualsProto(result));
}

}  // namespace mediapipe
@@ -19,6 +19,7 @@
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "tensorflow/lite/interpreter.h"
@@ -58,4 +59,7 @@ typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
    EndLoopDetectionCalculator;
REGISTER_CALCULATOR(EndLoopDetectionCalculator);

typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
REGISTER_CALCULATOR(EndLoopMatrixCalculator);

}  // namespace mediapipe
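Likewise, EndLoopCalculator is closed over its iterable type by a typedef plus a registration macro, so each additional payload costs two lines. A hypothetical sketch (the name and element type below are illustrative, not part of this commit):

    typedef EndLoopCalculator<std::vector<float>> EndLoopFloatCalculator;
    REGISTER_CALCULATOR(EndLoopFloatCalculator);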
@@ -50,7 +50,7 @@ namespace mediapipe {
//   calculator:    "EndLoopWithOutputCalculator"
//   input_stream:  "ITEM:output_of_loop_body"     # ItemU     @loop_internal_ts
//   input_stream:  "BATCH_END:ext_ts"             # Timestamp @loop_internal_ts
//   output_stream: "OUTPUT:aggregated_result"     # IterableU @ext_ts
//   output_stream: "ITERABLE:aggregated_result"   # IterableU @ext_ts
// }
template <typename IterableT>
class EndLoopCalculator : public CalculatorBase {
@@ -109,6 +109,56 @@ cc_test(
    ],
)

mediapipe_proto_library(
    name = "tensors_to_audio_calculator_proto",
    srcs = ["tensors_to_audio_calculator.proto"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "tensors_to_audio_calculator",
    srcs = ["tensors_to_audio_calculator.cc"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        ":tensors_to_audio_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:ret_check",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_audio_tools//audio/dsp:window_functions",
        "@pffft",
    ],
    alwayslink = 1,
)

cc_test(
    name = "tensors_to_audio_calculator_test",
    srcs = ["tensors_to_audio_calculator_test.cc"],
    deps = [
        ":audio_to_tensor_calculator",
        ":audio_to_tensor_calculator_cc_proto",
        ":tensors_to_audio_calculator",
        ":tensors_to_audio_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
)

mediapipe_proto_library(
    name = "feedback_tensors_calculator_proto",
    srcs = ["feedback_tensors_calculator.proto"],
@@ -253,6 +303,26 @@ cc_library(
    alwayslink = 1,
)

cc_test(
    name = "regex_preprocessor_calculator_test",
    srcs = ["regex_preprocessor_calculator_test.cc"],
    data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
    linkopts = ["-ldl"],
    deps = [
        ":regex_preprocessor_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:sink",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "text_to_tensor_calculator",
    srcs = ["text_to_tensor_calculator.cc"],
@@ -304,6 +374,28 @@ cc_library(
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_test(
    name = "universal_sentence_encoder_preprocessor_calculator_test",
    srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
    data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
    deps = [
        ":universal_sentence_encoder_preprocessor_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:options_map",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

mediapipe_proto_library(
@@ -438,6 +530,7 @@ cc_library(
    }),
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework/formats:tensor",
        "@com_google_absl//absl/status:statusor",
    ],
@@ -458,6 +551,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":inference_runner",
        "//mediapipe/framework:mediapipe_profiling",
        "//mediapipe/framework/api2:packet",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:ret_check",
@@ -1200,13 +1294,30 @@ cc_library(
    name = "image_to_tensor_utils",
    srcs = ["image_to_tensor_utils.cc"],
    hdrs = ["image_to_tensor_utils.h"],
    copts = select({
        "//mediapipe:apple": [
            "-x objective-c++",
            "-fobjc-arc",  # enable reference-counting
        ],
        "//conditions:default": [],
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":image_to_tensor_calculator_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/types:optional",
        "//mediapipe/framework/api2:packet",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:statusor",
        "//mediapipe/gpu:gpu_origin_cc_proto",
    ] + select({
        "//mediapipe/gpu:disable_gpu": [],
        "//conditions:default": ["//mediapipe/gpu:gpu_buffer"],
    }),
)
@@ -1216,6 +1327,8 @@ cc_test(
        ":image_to_tensor_utils",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
    ],
)
@@ -133,7 +133,7 @@ bool IsValidFftSize(int size) {
//     invocation. In the non-streaming mode, the vector contains all of the
//     output timestamps for an input audio buffer.
//   DC_AND_NYQUIST - std::pair<float, float> @Optional.
//     A pair of dc component and nyquest component. Only can be connected when
//     A pair of dc component and nyquist component. Only can be connected when
//     the calculator performs fft (the fft_size is set in the calculator
//     options).
//
@@ -54,13 +54,6 @@
namespace mediapipe {
namespace api2 {

#if MEDIAPIPE_DISABLE_GPU
// Just a placeholder to not have to depend on mediapipe::GpuBuffer.
using GpuBuffer = AnyType;
#else
using GpuBuffer = mediapipe::GpuBuffer;
#endif  // MEDIAPIPE_DISABLE_GPU

// Converts image into Tensor, possibly with cropping, resizing and
// normalization, according to specified inputs and options.
//
@@ -141,42 +134,7 @@ class ImageToTensorCalculator : public Node {
    const auto& options =
        cc->Options<mediapipe::ImageToTensorCalculatorOptions>();

    RET_CHECK(options.has_output_tensor_float_range() ||
              options.has_output_tensor_int_range() ||
              options.has_output_tensor_uint_range())
        << "Output tensor range is required.";
    if (options.has_output_tensor_float_range()) {
      RET_CHECK_LT(options.output_tensor_float_range().min(),
                   options.output_tensor_float_range().max())
          << "Valid output float tensor range is required.";
    }
    if (options.has_output_tensor_uint_range()) {
      RET_CHECK_LT(options.output_tensor_uint_range().min(),
                   options.output_tensor_uint_range().max())
          << "Valid output uint tensor range is required.";
      RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
          << "The minimum of the output uint tensor range must be "
             "non-negative.";
      RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
          << "The maximum of the output uint tensor range must be less than or "
             "equal to 255.";
    }
    if (options.has_output_tensor_int_range()) {
      RET_CHECK_LT(options.output_tensor_int_range().min(),
                   options.output_tensor_int_range().max())
          << "Valid output int tensor range is required.";
      RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
          << "The minimum of the output int tensor range must be greater than "
             "or equal to -128.";
      RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
          << "The maximum of the output int tensor range must be less than or "
             "equal to 127.";
    }
    RET_CHECK_GT(options.output_tensor_width(), 0)
        << "Valid output tensor width is required.";
    RET_CHECK_GT(options.output_tensor_height(), 0)
        << "Valid output tensor height is required.";

    RET_CHECK_OK(ValidateOptionOutputDims(options));
    RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected())
        << "One and only one of IMAGE and IMAGE_GPU input is expected.";
@@ -198,21 +156,7 @@ class ImageToTensorCalculator : public Node {

  absl::Status Open(CalculatorContext* cc) {
    options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
    output_width_ = options_.output_tensor_width();
    output_height_ = options_.output_tensor_height();
    is_float_output_ = options_.has_output_tensor_float_range();
    if (options_.has_output_tensor_uint_range()) {
      range_min_ =
          static_cast<float>(options_.output_tensor_uint_range().min());
      range_max_ =
          static_cast<float>(options_.output_tensor_uint_range().max());
    } else if (options_.has_output_tensor_int_range()) {
      range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
      range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
    } else {
      range_min_ = options_.output_tensor_float_range().min();
      range_max_ = options_.output_tensor_float_range().max();
    }
    params_ = GetOutputTensorParams(options_);
    return absl::OkStatus();
  }
@@ -242,7 +186,13 @@ class ImageToTensorCalculator : public Node {
      }
    }

    ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
#if MEDIAPIPE_DISABLE_GPU
    ASSIGN_OR_RETURN(auto image, GetInputImage(kIn(cc)));
#else
    const bool is_input_gpu = kInGpu(cc).IsConnected();
    ASSIGN_OR_RETURN(auto image, is_input_gpu ? GetInputImage(kInGpu(cc))
                                              : GetInputImage(kIn(cc)));
#endif  // MEDIAPIPE_DISABLE_GPU

    RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
    ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
@@ -263,11 +213,13 @@ class ImageToTensorCalculator : public Node {
    MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));

    Tensor::ElementType output_tensor_type =
        GetOutputTensorType(image->UsesGpu());
    Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
                                       GetNumOutputChannels(*image)});
        GetOutputTensorType(image->UsesGpu(), params_);
    Tensor tensor(output_tensor_type,
                  {1, params_.output_height, params_.output_width,
                   GetNumOutputChannels(*image)});
    MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
                           ->Convert(*image, roi, range_min_, range_max_,
                           ->Convert(*image, roi, params_.range_min,
                                     params_.range_max,
                                     /*tensor_buffer_offset=*/0, tensor));

    auto result = std::make_unique<std::vector<Tensor>>();
@@ -278,81 +230,11 @@ class ImageToTensorCalculator : public Node {
  }

 private:
  bool DoesGpuInputStartAtBottom() {
    return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
  }

  BorderMode GetBorderMode() {
    switch (options_.border_mode()) {
      case mediapipe::
          ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
        return BorderMode::kReplicate;
      case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
        return BorderMode::kZero;
      case mediapipe::
          ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
        return BorderMode::kReplicate;
    }
  }

  Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
    if (!uses_gpu) {
      if (is_float_output_) {
        return Tensor::ElementType::kFloat32;
      }
      if (range_min_ < 0) {
        return Tensor::ElementType::kInt8;
      } else {
        return Tensor::ElementType::kUInt8;
      }
    }
    // Always use float32 when GPU is enabled.
    return Tensor::ElementType::kFloat32;
  }

  int GetNumOutputChannels(const Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
    if (image.UsesGpu()) {
      return 4;
    }
#endif  // MEDIAPIPE_METAL_ENABLED
#endif  // !MEDIAPIPE_DISABLE_GPU
    // All of the processors except for Metal expect 3 channels.
    return 3;
  }

  absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
      CalculatorContext* cc) {
    if (kIn(cc).IsConnected()) {
      const auto& packet = kIn(cc).packet();
      return kIn(cc).Visit(
          [&packet](const mediapipe::Image&) {
            return SharedPtrWithPacket<mediapipe::Image>(packet);
          },
          [&packet](const mediapipe::ImageFrame&) {
            return std::make_shared<const mediapipe::Image>(
                std::const_pointer_cast<mediapipe::ImageFrame>(
                    SharedPtrWithPacket<mediapipe::ImageFrame>(packet)));
          });
    } else {  // if (kInGpu(cc).IsConnected())
#if !MEDIAPIPE_DISABLE_GPU
      const GpuBuffer& input = *kInGpu(cc);
      // A shallow copy is okay since the resulting 'image' object is local in
      // Process(), and thus never outlives 'input'.
      return std::make_shared<const mediapipe::Image>(input);
#else
      return absl::UnimplementedError(
          "GPU processing is disabled in build flags");
#endif  // !MEDIAPIPE_DISABLE_GPU
    }
  }

  absl::Status InitConverterIfNecessary(CalculatorContext* cc,
                                        const Image& image) {
    // Lazy initialization of the GPU or CPU converter.
    if (image.UsesGpu()) {
      if (!is_float_output_) {
      if (!params_.is_float_output) {
        return absl::UnimplementedError(
            "ImageToTensorConverter for the input GPU image currently doesn't "
            "support quantization.");
@@ -360,18 +242,20 @@ class ImageToTensorCalculator : public Node {
      if (!gpu_converter_) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
        ASSIGN_OR_RETURN(gpu_converter_,
                         CreateMetalConverter(cc, GetBorderMode()));
        ASSIGN_OR_RETURN(
            gpu_converter_,
            CreateMetalConverter(cc, GetBorderMode(options_.border_mode())));
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
        ASSIGN_OR_RETURN(gpu_converter_,
                         CreateImageToGlBufferTensorConverter(
                             cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
                             cc, DoesGpuInputStartAtBottom(options_),
                             GetBorderMode(options_.border_mode())));
#else
        if (!gpu_converter_) {
          ASSIGN_OR_RETURN(
              gpu_converter_,
              CreateImageToGlTextureTensorConverter(
                  cc, DoesGpuInputStartAtBottom(), GetBorderMode()));
          ASSIGN_OR_RETURN(gpu_converter_,
                           CreateImageToGlTextureTensorConverter(
                               cc, DoesGpuInputStartAtBottom(options_),
                               GetBorderMode(options_.border_mode())));
        }
        if (!gpu_converter_) {
          return absl::UnimplementedError(
@@ -383,10 +267,10 @@ class ImageToTensorCalculator : public Node {
    } else {
      if (!cpu_converter_) {
#if !MEDIAPIPE_DISABLE_OPENCV
        ASSIGN_OR_RETURN(
            cpu_converter_,
            CreateOpenCvConverter(cc, GetBorderMode(),
                                  GetOutputTensorType(/*uses_gpu=*/false)));
        ASSIGN_OR_RETURN(cpu_converter_,
                         CreateOpenCvConverter(
                             cc, GetBorderMode(options_.border_mode()),
                             GetOutputTensorType(/*uses_gpu=*/false, params_)));
#else
        LOG(FATAL) << "Cannot create image to tensor opencv converter since "
                      "MEDIAPIPE_DISABLE_OPENCV is defined.";
@@ -399,11 +283,7 @@ class ImageToTensorCalculator : public Node {
  std::unique_ptr<ImageToTensorConverter> gpu_converter_;
  std::unique_ptr<ImageToTensorConverter> cpu_converter_;
  mediapipe::ImageToTensorCalculatorOptions options_;
  int output_width_ = 0;
  int output_height_ = 0;
  bool is_float_output_ = false;
  float range_min_ = 0.0f;
  float range_max_ = 1.0f;
  OutputTensorParams params_;
};

MEDIAPIPE_REGISTER_NODE(ImageToTensorCalculator);
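The net effect of the calculator hunks above: the per-instance fields (output_width_, output_height_, is_float_output_, range_min_, range_max_) collapse into the shared OutputTensorParams struct, and the helper logic moves into image_to_tensor_utils (shown later in this diff). A self-contained sketch of the element-type rule that moved, with MediaPipe types replaced by plain C++ for illustration:

    #include <cstdio>

    // Mirrors GetOutputTensorType(): on GPU the tensor is always float32;
    // on CPU, a float range selects kFloat32, a negative range minimum
    // selects kInt8, and a non-negative minimum selects kUInt8.
    enum class ElementType { kFloat32, kInt8, kUInt8 };

    ElementType SelectType(bool uses_gpu, bool is_float_output,
                           float range_min) {
      if (uses_gpu || is_float_output) return ElementType::kFloat32;
      return range_min < 0 ? ElementType::kInt8 : ElementType::kUInt8;
    }

    int main() {
      // int range [-128, 127] on CPU -> kInt8 (prints 1)
      std::printf("%d\n", static_cast<int>(SelectType(false, false, -128.0f)));
      // the same range on GPU -> kFloat32 (prints 0)
      std::printf("%d\n", static_cast<int>(SelectType(true, false, -128.0f)));
      return 0;
    }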
@@ -27,12 +27,6 @@ struct Size {
  int height;
};

// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
enum class BorderMode { kZero, kReplicate };

// Converts image to tensor.
class ImageToTensorConverter {
 public:
@@ -270,10 +270,10 @@ class GlProcessor : public ImageToTensorConverter {
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
          "Only 4-channel texture input formats are supported, passed format: ",
          static_cast<uint32_t>(input.format())));
          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@@ -281,12 +281,13 @@ class GlProcessor : public ImageToTensorConverter {
    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
        [this, &output_tensor, &input, &roi, &output_shape, range_min,
         range_max, tensor_buffer_offset]() -> absl::Status {
          constexpr int kRgbaNumChannels = 4;
          const int input_num_channels = input.channels();
          auto source_texture = gl_helper_.CreateSourceTexture(input);
          tflite::gpu::gl::GlTexture input_texture(
              GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
              GL_TEXTURE_2D, source_texture.name(),
              input_num_channels == 4 ? GL_RGB : GL_RGBA,
              source_texture.width() * source_texture.height() *
                  kRgbaNumChannels * sizeof(uint8_t),
                  input_num_channels * sizeof(uint8_t),
              /*layer=*/0,
              /*owned=*/false);
@@ -174,10 +174,10 @@ class GlProcessor : public ImageToTensorConverter {
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
          "Only 4-channel texture input formats are supported, passed format: ",
          static_cast<uint32_t>(input.format())));
          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    // TODO: support tensor_buffer_offset > 0 scenario.
    RET_CHECK_EQ(tensor_buffer_offset, 0)
@@ -16,7 +16,9 @@

#include <array>

#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
@@ -214,4 +216,68 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
  matrix[15] = 1.0f;
}

BorderMode GetBorderMode(
    const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode) {
  switch (mode) {
    case mediapipe::
        ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED:
      return BorderMode::kReplicate;
    case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO:
      return BorderMode::kZero;
    case mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_REPLICATE:
      return BorderMode::kReplicate;
  }
}

Tensor::ElementType GetOutputTensorType(bool uses_gpu,
                                        const OutputTensorParams& params) {
  if (!uses_gpu) {
    if (params.is_float_output) {
      return Tensor::ElementType::kFloat32;
    }
    if (params.range_min < 0) {
      return Tensor::ElementType::kInt8;
    } else {
      return Tensor::ElementType::kUInt8;
    }
  }
  // Always use float32 when GPU is enabled.
  return Tensor::ElementType::kFloat32;
}

int GetNumOutputChannels(const mediapipe::Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
  if (image.UsesGpu()) {
    return 4;
  }
#endif  // MEDIAPIPE_METAL_ENABLED
#endif  // !MEDIAPIPE_DISABLE_GPU
  // All of the processors except for Metal expect 3 channels.
  return 3;
}

absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
    const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
        image_packet) {
  return image_packet.Visit(
      [&image_packet](const mediapipe::Image&) {
        return SharedPtrWithPacket<mediapipe::Image>(image_packet);
      },
      [&image_packet](const mediapipe::ImageFrame&) {
        return std::make_shared<const mediapipe::Image>(
            std::const_pointer_cast<mediapipe::ImageFrame>(
                SharedPtrWithPacket<mediapipe::ImageFrame>(image_packet)));
      });
}

#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
    const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet) {
  // A shallow copy is okay since the resulting 'image' object is local in
  // Process(), and thus never outlives 'input'.
  return std::make_shared<const mediapipe::Image>(image_gpu_packet.Get());
}
#endif  // !MEDIAPIPE_DISABLE_GPU

}  // namespace mediapipe
@@ -18,8 +18,18 @@
#include <array>

#include "absl/types/optional.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_buffer.h"
#endif  // !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_origin.pb.h"

namespace mediapipe {
@@ -31,6 +41,24 @@ struct RotatedRect {
  float rotation;
};

// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such pixels
// will be calculated.
// TODO: Consider moving this to a separate border_mode.h file.
enum class BorderMode { kZero, kReplicate };

// Struct that host commonly accessed parameters used in the
// ImageTo[Batch]TensorCalculator.
struct OutputTensorParams {
  int output_height;
  int output_width;
  int output_batch;
  bool is_float_output;
  float range_min;
  float range_max;
};

// Generates a new ROI or converts it from normalized rect.
RotatedRect GetRoi(int input_width, int input_height,
                   absl::optional<mediapipe::NormalizedRect> norm_rect);
@@ -95,6 +123,103 @@ void GetTransposedRotatedSubRectToRectTransformMatrix(
    const RotatedRect& sub_rect, int rect_width, int rect_height,
    bool flip_horizontaly, std::array<float, 16>* matrix);

// Validates the output dimensions set in the option proto. The input option
// proto is expected to have to following fields:
//  output_tensor_float_range, output_tensor_int_range, output_tensor_uint_range
//  output_tensor_width, output_tensor_height.
// See ImageToTensorCalculatorOptions for the description of each field.
template <typename T>
absl::Status ValidateOptionOutputDims(const T& options) {
  RET_CHECK(options.has_output_tensor_float_range() ||
            options.has_output_tensor_int_range() ||
            options.has_output_tensor_uint_range())
      << "Output tensor range is required.";
  if (options.has_output_tensor_float_range()) {
    RET_CHECK_LT(options.output_tensor_float_range().min(),
                 options.output_tensor_float_range().max())
        << "Valid output float tensor range is required.";
  }
  if (options.has_output_tensor_uint_range()) {
    RET_CHECK_LT(options.output_tensor_uint_range().min(),
                 options.output_tensor_uint_range().max())
        << "Valid output uint tensor range is required.";
    RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
        << "The minimum of the output uint tensor range must be "
           "non-negative.";
    RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
        << "The maximum of the output uint tensor range must be less than or "
           "equal to 255.";
  }
  if (options.has_output_tensor_int_range()) {
    RET_CHECK_LT(options.output_tensor_int_range().min(),
                 options.output_tensor_int_range().max())
        << "Valid output int tensor range is required.";
    RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
        << "The minimum of the output int tensor range must be greater than "
           "or equal to -128.";
    RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
        << "The maximum of the output int tensor range must be less than or "
           "equal to 127.";
  }
  RET_CHECK_GT(options.output_tensor_width(), 0)
      << "Valid output tensor width is required.";
  RET_CHECK_GT(options.output_tensor_height(), 0)
      << "Valid output tensor height is required.";
  return absl::OkStatus();
}

template <typename T>
OutputTensorParams GetOutputTensorParams(const T& options) {
  OutputTensorParams params;
  if (options.has_output_tensor_uint_range()) {
    params.range_min =
        static_cast<float>(options.output_tensor_uint_range().min());
    params.range_max =
        static_cast<float>(options.output_tensor_uint_range().max());
  } else if (options.has_output_tensor_int_range()) {
    params.range_min =
        static_cast<float>(options.output_tensor_int_range().min());
    params.range_max =
        static_cast<float>(options.output_tensor_int_range().max());
  } else {
    params.range_min = options.output_tensor_float_range().min();
    params.range_max = options.output_tensor_float_range().max();
  }
  params.output_width = options.output_tensor_width();
  params.output_height = options.output_tensor_height();
  params.is_float_output = options.has_output_tensor_float_range();
  params.output_batch = 1;
  return params;
}

// Returns whether the GPU input format starts at the bottom.
template <typename T>
bool DoesGpuInputStartAtBottom(const T& options) {
  return options.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}

// Converts the BorderMode proto into struct.
BorderMode GetBorderMode(
    const mediapipe::ImageToTensorCalculatorOptions::BorderMode& mode);

// Gets the output tensor type.
Tensor::ElementType GetOutputTensorType(bool uses_gpu,
                                        const OutputTensorParams& params);

// Gets the number of output channels from the input Image format.
int GetNumOutputChannels(const mediapipe::Image& image);

// Converts the packet that hosts different format (Image, ImageFrame,
// GpuBuffer) into the mediapipe::Image format.
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
    const api2::Packet<api2::OneOf<Image, mediapipe::ImageFrame>>&
        image_packet);

#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
    const api2::Packet<mediapipe::GpuBuffer>& image_gpu_packet);
#endif  // !MEDIAPIPE_DISABLE_GPU

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_
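Because ValidateOptionOutputDims and GetOutputTensorParams are templates over the options type, any image-to-tensor style calculator can share them. A minimal sketch of the call pattern, mirroring the ImageToTensorCalculator hunks earlier in this diff (the surrounding Node boilerplate and member declarations are elided):

    #include "mediapipe/calculators/tensor/image_to_tensor_utils.h"

    // Inside the calculator class; options_ and params_ are members.
    static absl::Status UpdateContract(CalculatorContract* cc) {
      const auto& options =
          cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
      RET_CHECK_OK(ValidateOptionOutputDims(options));
      return absl::OkStatus();
    }

    absl::Status Open(CalculatorContext* cc) {
      options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
      params_ = GetOutputTensorParams(options_);
      // params_ now carries output_width/output_height/range_min/range_max.
      return absl::OkStatus();
    }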
@@ -16,6 +16,8 @@

#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {
			@ -23,6 +25,7 @@ namespace {
 | 
			
		|||
 | 
			
		||||
using ::testing::ElementsAre;
 | 
			
		||||
using ::testing::ElementsAreArray;
 | 
			
		||||
using ::testing::HasSubstr;
 | 
			
		||||
 | 
			
		||||
testing::Matcher<RotatedRect> EqRotatedRect(float width, float height,
 | 
			
		||||
                                            float center_x, float center_y,
 | 
			
		||||
| 
						 | 
				
			
			@ -157,5 +160,95 @@ TEST(GetValueRangeTransformation, FloatToPixel) {
 | 
			
		|||
              EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
constexpr char kValidFloatProto[] = R"(
 | 
			
		||||
  output_tensor_float_range { min: 0.0 max: 1.0 }
 | 
			
		||||
  output_tensor_width: 100
 | 
			
		||||
  output_tensor_height: 200
 | 
			
		||||
)";
 | 
			
		||||
 | 
			
		||||
constexpr char kValidIntProto[] = R"(
 | 
			
		||||
  output_tensor_float_range { min: 0 max: 255 }
 | 
			
		||||
  output_tensor_width: 100
 | 
			
		||||
  output_tensor_height: 200
 | 
			
		||||
)";
 | 
			
		||||
 | 
			
		||||
TEST(ValidateOptionOutputDims, ValidProtos) {
 | 
			
		||||
  const auto float_options =
 | 
			
		||||
      mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
 | 
			
		||||
          kValidFloatProto);
 | 
			
		||||
  MP_EXPECT_OK(ValidateOptionOutputDims(float_options));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(ValidateOptionOutputDims, EmptyProto) {
 | 
			
		||||
  mediapipe::ImageToTensorCalculatorOptions options;
 | 
			
		||||
  // No output tensor range set.
 | 
			
		||||
  EXPECT_THAT(ValidateOptionOutputDims(options),
 | 
			
		||||
              StatusIs(absl::StatusCode::kInternal,
 | 
			
		||||
                       HasSubstr("Output tensor range is required")));
 | 
			
		||||
 | 
			
		||||
  // Invalid output float tensor range.
 | 
			
		||||
  options.mutable_output_tensor_float_range()->set_min(1.0);
 | 
			
		||||
  options.mutable_output_tensor_float_range()->set_max(0.0);
 | 
			
		||||
  EXPECT_THAT(
 | 
			
		||||
      ValidateOptionOutputDims(options),
 | 
			
		||||
      StatusIs(absl::StatusCode::kInternal,
 | 
			
		||||
               HasSubstr("Valid output float tensor range is required")));
 | 
			
		||||
 | 
			
		||||
  // Output width/height is not set.
 | 
			
		||||
  options.mutable_output_tensor_float_range()->set_min(0.0);
 | 
			
		||||
  options.mutable_output_tensor_float_range()->set_max(1.0);
 | 
			
		||||
  EXPECT_THAT(ValidateOptionOutputDims(options),
 | 
			
		||||
              StatusIs(absl::StatusCode::kInternal,
 | 
			
		||||
                       HasSubstr("Valid output tensor width is required")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(GetOutputTensorParams, SetValues) {
 | 
			
		||||
  // Test int range with ImageToTensorCalculatorOptions.
 | 
			
		||||
  const auto int_options =
 | 
			
		||||
      mediapipe::ParseTextProtoOrDie<mediapipe::ImageToTensorCalculatorOptions>(
 | 
			
		||||
          kValidIntProto);
 | 
			
		||||
  const auto params2 = GetOutputTensorParams(int_options);
 | 
			
		||||
  EXPECT_EQ(params2.range_min, 0.0f);
 | 
			
		||||
  EXPECT_EQ(params2.range_max, 255.0f);
 | 
			
		||||
  EXPECT_EQ(params2.output_batch, 1);
 | 
			
		||||
  EXPECT_EQ(params2.output_width, 100);
 | 
			
		||||
  EXPECT_EQ(params2.output_height, 200);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(GetBorderMode, GetBorderMode) {
 | 
			
		||||
  // Default to REPLICATE.
 | 
			
		||||
  auto border_mode =
 | 
			
		||||
      mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_UNSPECIFIED;
 | 
			
		||||
  EXPECT_EQ(BorderMode::kReplicate, GetBorderMode(border_mode));
 | 
			
		||||
 | 
			
		||||
  // Set to ZERO.
 | 
			
		||||
  border_mode =
 | 
			
		||||
      mediapipe::ImageToTensorCalculatorOptions_BorderMode_BORDER_ZERO;
 | 
			
		||||
  EXPECT_EQ(BorderMode::kZero, GetBorderMode(border_mode));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(GetOutputTensorType, GetOutputTensorType) {
 | 
			
		||||
  OutputTensorParams params;
 | 
			
		||||
  // Return float32 when GPU is enabled.
 | 
			
		||||
  EXPECT_EQ(Tensor::ElementType::kFloat32,
 | 
			
		||||
            GetOutputTensorType(/*uses_gpu=*/true, params));
 | 
			
		||||
 | 
			
		||||
  // Return float32 when is_float_output is set to true.
 | 
			
		||||
  params.is_float_output = true;
 | 
			
		||||
  EXPECT_EQ(Tensor::ElementType::kFloat32,
 | 
			
		||||
            GetOutputTensorType(/*uses_gpu=*/false, params));
 | 
			
		||||
 | 
			
		||||
  // Return int8 when range_min is negative.
 | 
			
		||||
  params.is_float_output = false;
 | 
			
		||||
  params.range_min = -255.0f;
 | 
			
		||||
  EXPECT_EQ(Tensor::ElementType::kInt8,
 | 
			
		||||
            GetOutputTensorType(/*uses_gpu=*/false, params));
 | 
			
		||||
 | 
			
		||||
  // Return 8int8 when range_min is non-negative.
 | 
			
		||||
  params.range_min = 0.0f;
 | 
			
		||||
  EXPECT_EQ(Tensor::ElementType::kUInt8,
 | 
			
		||||
            GetOutputTensorType(/*uses_gpu=*/false, params));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace mediapipe

@ -72,7 +72,7 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
  RET_CHECK(!input_tensors.empty());

  ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
                   inference_runner_->Run(input_tensors));
                   inference_runner_->Run(cc, input_tensors));
  kOutTensors(cc).Send(std::move(output_tensors));
  return absl::OkStatus();
}

@ -26,6 +26,8 @@
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"

#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe

namespace mediapipe {
namespace api2 {

@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
    std::vector<Tensor>& output_tensors) {
  return gpu_helper_.RunInGlContext(
      [this, &input_tensors, &output_tensors]() -> absl::Status {
      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
        // Explicitly copy input.
        for (int i = 0; i < input_tensors.size(); ++i) {
          glBindBuffer(GL_COPY_READ_BUFFER,
@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
        }

        // Run inference.
        RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
        {
          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
          RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
        }

        output_tensors.reserve(output_size_);
        for (int i = 0; i < output_size_; ++i) {

@ -32,6 +32,8 @@
#include "mediapipe/util/android/file/base/helpers.h"
#endif  // MEDIAPIPE_ANDROID

#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe

namespace mediapipe {
namespace api2 {

@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
        const mediapipe::InferenceCalculatorOptions::Delegate& delegate);

    absl::StatusOr<std::vector<Tensor>> Process(
        const std::vector<Tensor>& input_tensors);
        CalculatorContext* cc, const std::vector<Tensor>& input_tensors);

    absl::Status Close();

@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(

absl::StatusOr<std::vector<Tensor>>
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
    const std::vector<Tensor>& input_tensors) {
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
  std::vector<Tensor> output_tensors;

  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &input_tensors, &output_tensors]() -> absl::Status {
      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
        for (int i = 0; i < input_tensors.size(); ++i) {
          MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
              input_tensors[i].GetOpenGlBufferReadView().name(), i));
@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
              output_tensors.back().GetOpenGlBufferWriteView().name(), i));
        }
        // Run inference.
        return tflite_gpu_runner_->Invoke();
        {
          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
          return tflite_gpu_runner_->Invoke();
        }
      }));

  return output_tensors;
@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
  auto output_tensors = absl::make_unique<std::vector<Tensor>>();

  ASSIGN_OR_RETURN(*output_tensors,
                   gpu_inference_runner_->Process(input_tensors));
                   gpu_inference_runner_->Process(cc, input_tensors));

  kOutTensors(cc).Send(std::move(output_tensors));
  return absl::OkStatus();

@ -70,7 +70,7 @@ absl::Status InferenceCalculatorXnnpackImpl::Process(CalculatorContext* cc) {
  RET_CHECK(!input_tensors.empty());

  ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
                   inference_runner_->Run(input_tensors));
                   inference_runner_->Run(cc, input_tensors));
  kOutTensors(cc).Send(std::move(output_tensors));
  return absl::OkStatus();
}

@ -20,12 +20,15 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"

#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe

namespace mediapipe {

namespace {
@ -79,7 +82,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
        delegate_(std::move(delegate)) {}

  absl::StatusOr<std::vector<Tensor>> Run(
      const std::vector<Tensor>& input_tensors) override;
      CalculatorContext* cc, const std::vector<Tensor>& input_tensors) override;

 private:
  api2::Packet<TfLiteModelPtr> model_;
@ -88,7 +91,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
};

absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
    const std::vector<Tensor>& input_tensors) {
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
  // Read CPU input into tensors.
  RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
  for (int i = 0; i < input_tensors.size(); ++i) {
@ -131,8 +134,10 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
  }

  // Run inference.
  RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);

  {
    MEDIAPIPE_PROFILING(CPU_TASK_INVOKE, cc);
    RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
  }
  // Output result tensors (CPU).
  const auto& tensor_indexes = interpreter_->outputs();
  std::vector<Tensor> output_tensors;
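
The scoped block around Invoke() is the heart of this change: MEDIAPIPE_PROFILING opens a trace event that ends with the enclosing scope, so the braces bound the traced region to the inference call alone. A minimal sketch of the same pattern follows; RunTraced and its arguments are illustrative, not part of the diff.

#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h"

absl::Status RunTraced(tflite::Interpreter* interpreter,
                       mediapipe::CalculatorContext* cc) {
  {
    // The CPU_TASK_INVOKE trace event covers exactly this block.
    MEDIAPIPE_PROFILING(CPU_TASK_INVOKE, cc);
    RET_CHECK_EQ(interpreter->Invoke(), kTfLiteOk);
  }
  return absl::OkStatus();
}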

@ -2,6 +2,7 @@
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_

#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe {
@ -11,7 +12,7 @@ class InferenceRunner {
 public:
  virtual ~InferenceRunner() = default;
  virtual absl::StatusOr<std::vector<Tensor>> Run(
      const std::vector<Tensor>& inputs) = 0;
      CalculatorContext* cc, const std::vector<Tensor>& inputs) = 0;
};

}  // namespace mediapipe
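
A hedged sketch of what an implementation of the updated interface looks like; EchoRunner is a hypothetical name. The new CalculatorContext* parameter exists so runners can attach trace events (GPU_TASK_INVOKE, CPU_TASK_INVOKE) to the calling calculator, as the implementations above do.

class EchoRunner : public mediapipe::InferenceRunner {
 public:
  absl::StatusOr<std::vector<mediapipe::Tensor>> Run(
      mediapipe::CalculatorContext* cc,
      const std::vector<mediapipe::Tensor>& inputs) override {
    // A real runner would invoke its delegate here, typically inside a
    // MEDIAPIPE_PROFILING scope keyed to cc.
    return absl::UnimplementedError("sketch only");
  }
};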

mediapipe/calculators/tensor/tensors_to_audio_calculator.cc (new file, 197 lines)
@ -0,0 +1,197 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cmath>
#include <cstring>
#include <new>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "pffft.h"

namespace mediapipe {
namespace api2 {
namespace {

std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
  std::vector<float> hann_window(window_size);
  audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
  if (sqrt_hann) {
    absl::c_transform(hann_window, hann_window.begin(),
                      [](double x) { return std::sqrt(x); });
  }
  return hann_window;
}

// Note that the InvHannWindow function may only work for the 50% overlap case.
std::vector<float> InvHannWindow(int window_size, bool sqrt_hann) {
  std::vector<float> window = HannWindow(window_size, sqrt_hann);
  std::vector<float> inv_window(window.size());
  if (sqrt_hann) {
    absl::c_copy(window, inv_window.begin());
  } else {
    const int kHalfWindowSize = window.size() / 2;
    absl::c_transform(window, inv_window.begin(),
                      [](double x) { return x * x; });
    for (int i = 0; i < kHalfWindowSize; ++i) {
      double sum = inv_window[i] + inv_window[kHalfWindowSize + i];
      inv_window[i] = window[i] / sum;
      inv_window[kHalfWindowSize + i] = window[kHalfWindowSize + i] / sum;
    }
  }
  return inv_window;
}

// PFFFT only supports transforms for inputs of length N of the form
// N = (2^a)*(3^b)*(5^c) where b >= 0 and c >= 0 and a >= 5 for the real FFT.
bool IsValidFftSize(int size) {
  if (size <= 0) {
    return false;
  }
  constexpr int kFactors[] = {2, 3, 5};
  int factorization[] = {0, 0, 0};
  int n = static_cast<int>(size);
  for (int i = 0; i < 3; ++i) {
    while (n % kFactors[i] == 0) {
      n = n / kFactors[i];
      ++factorization[i];
    }
  }
  return factorization[0] >= 5 && n == 1;
}
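
// Illustration (not part of the original file): 320 = 2^6 * 5 gives a = 6,
// which satisfies a >= 5, so IsValidFftSize(320) is true; 103 is prime, so
// the factor loop leaves n == 103 and IsValidFftSize(103) is false. The
// accompanying unit tests exercise exactly these two sizes.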

}  // namespace

// Converts 2D MediaPipe float Tensors to audio buffers.
// The calculator will perform ifft on the complex DFT and apply the window
// function (Inverse Hann) afterwards. The input 2D MediaPipe Tensor must
// have the DFT real parts in its first row and the DFT imaginary parts in its
// second row. A valid "fft_size" must be set in the CalculatorOptions.
//
// Inputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing a single Tensor that represents the audio's complex
//     DFT results.
//   DC_AND_NYQUIST - std::pair<float, float>
//     A pair of the DC component and the Nyquist component.
//
// Outputs:
//   AUDIO - mediapipe::Matrix
//     The audio data represented as mediapipe::Matrix.
//
// Example:
// node {
//   calculator: "TensorsToAudioCalculator"
//   input_stream: "TENSORS:tensors"
//   input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
//   output_stream: "AUDIO:audio"
//   options {
//     [mediapipe.TensorsToAudioCalculatorOptions.ext] {
//       fft_size: 256
//     }
//   }
// }
class TensorsToAudioCalculator : public Node {
 public:
  static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
  static constexpr Input<std::pair<float, float>> kDcAndNyquistIn{
      "DC_AND_NYQUIST"};
  static constexpr Output<Matrix> kAudioOut{"AUDIO"};
  MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kDcAndNyquistIn, kAudioOut);

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;

 private:
  // The internal state of the FFT library.
  PFFFT_Setup* fft_state_ = nullptr;
  int fft_size_ = 0;
  float inverse_fft_size_ = 0;
  std::vector<float, Eigen::aligned_allocator<float>> input_dft_;
  std::vector<float> inv_fft_window_;
  std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
  // pffft requires scratch memory to work with, to avoid using the stack.
  std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
  std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
};

absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
  const auto& options =
      cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
  RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
  RET_CHECK(IsValidFftSize(options.fft_size()))
      << "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
         ">= 0 and c >= 0 and a >= 5, the requested fft size is "
      << options.fft_size();
  fft_size_ = options.fft_size();
  inverse_fft_size_ = 1.0f / fft_size_;
  fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
  input_dft_.resize(fft_size_);
  inv_fft_window_ = InvHannWindow(fft_size_, /* sqrt_hann = */ false);
  fft_input_buffer_.resize(fft_size_);
  fft_workplace_.resize(fft_size_);
  fft_output_.resize(fft_size_);
  return absl::OkStatus();
}

absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
  if (kTensorsIn(cc).IsEmpty() || kDcAndNyquistIn(cc).IsEmpty()) {
    return absl::OkStatus();
  }
  const auto& input_tensors = *kTensorsIn(cc);
  RET_CHECK_EQ(input_tensors.size(), 1);
  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
  auto view = input_tensors[0].GetCpuReadView();
  // DC's real part.
  input_dft_[0] = kDcAndNyquistIn(cc)->first;
  // Nyquist's real part is the penultimate element of the tensor buffer.
  // pffft ignores the Nyquist's imaginary part. No need to fetch the last
  // value from the tensor buffer.
  input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
  std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
              (fft_size_ - 2) * sizeof(float));
  pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
                          fft_workplace_.data(), PFFFT_BACKWARD);
  // Applies the inverse window function.
  std::transform(
      fft_output_.begin(), fft_output_.end(), inv_fft_window_.begin(),
      fft_output_.begin(),
      [this](float a, float b) { return a * b * inverse_fft_size_; });
  Matrix matrix = Eigen::Map<Matrix>(fft_output_.data(), 1, fft_output_.size());
  kAudioOut(cc).Send(std::move(matrix));
  return absl::OkStatus();
}

absl::Status TensorsToAudioCalculator::Close(CalculatorContext* cc) {
  if (fft_state_) {
    pffft_destroy_setup(fft_state_);
  }
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(TensorsToAudioCalculator);

}  // namespace api2
}  // namespace mediapipe

mediapipe/calculators/tensor/tensors_to_audio_calculator.proto (new file, 29 lines)
@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message TensorsToAudioCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional TensorsToAudioCalculatorOptions ext = 484297136;
  }

  // Size of the fft in number of bins. If set, the calculator will do ifft
  // on the input tensor.
  optional int64 fft_size = 1;
}

mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc (new file, 149 lines)
@ -0,0 +1,149 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <new>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {
namespace {

class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 protected:
  // Creates an audio matrix containing a single sample of 1.0 at a specified
  // offset.
  Matrix CreateImpulseSignalData(int64 num_samples, int impulse_offset_idx) {
    Matrix impulse = Matrix::Zero(1, num_samples);
    impulse(0, impulse_offset_idx) = 1.0;
    return impulse;
  }

  void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
        absl::Substitute(R"(
        input_stream: "audio_in"
        input_stream: "sample_rate"
        output_stream: "audio_out"
        node {
          calculator: "AudioToTensorCalculator"
          input_stream: "AUDIO:audio_in"
          input_stream: "SAMPLE_RATE:sample_rate"
          output_stream: "TENSORS:tensors"
          output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
          options {
            [mediapipe.AudioToTensorCalculatorOptions.ext] {
              num_channels: 1
              num_samples: $0
              num_overlapping_samples: 0
              target_sample_rate: $1
              fft_size: $2
            }
          }
        }
        node {
          calculator: "TensorsToAudioCalculator"
          input_stream: "TENSORS:tensors"
          input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
          output_stream: "AUDIO:audio_out"
          options {
            [mediapipe.TensorsToAudioCalculatorOptions.ext] {
              fft_size: $2
            }
          }
        }
        )",
                         /*$0=*/num_samples,
                         /*$1=*/sample_rate,
                         /*$2=*/fft_size));
    tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
  }

  void RunGraph(const Matrix& input_data, double sample_rate) {
    MP_ASSERT_OK(graph_.Initialize(graph_config_));
    MP_ASSERT_OK(graph_.StartRun({}));
    MP_ASSERT_OK(graph_.AddPacketToInputStream(
        "sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
    MP_ASSERT_OK(graph_.AddPacketToInputStream(
        "audio_in", MakePacket<Matrix>(input_data).At(Timestamp(0))));
    MP_ASSERT_OK(graph_.CloseAllInputStreams());
    MP_ASSERT_OK(graph_.WaitUntilDone());
  }

  std::vector<Packet> audio_out_packets_;
  CalculatorGraphConfig graph_config_;
  CalculatorGraph graph_;
};

TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
  ConfigGraph(320, 16000, 103);
  MP_ASSERT_OK(graph_.Initialize(graph_config_));
  MP_ASSERT_OK(graph_.StartRun({}));
  auto status = graph_.WaitUntilIdle();
  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
  EXPECT_THAT(status.message(),
              ::testing::HasSubstr("FFT size must be of the form"));
}

TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
  constexpr int sample_size = 320;
  constexpr double sample_rate = 16000;
  ConfigGraph(sample_size, sample_rate, 320);

  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
  RunGraph(impulse_data, sample_rate);
  ASSERT_EQ(1, audio_out_packets_.size());
  MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
  // The impulse signal at the center is not affected by the window function.
  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
}

TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
  constexpr int sample_size = 320;
  constexpr double sample_rate = 16000;
  ConfigGraph(sample_size, sample_rate, 320);
  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
  RunGraph(impulse_data, sample_rate);
  ASSERT_EQ(1, audio_out_packets_.size());
  MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
  // As the impulse signal sits at 1/4 of the Hann window, the inverse
  // window function reduces it by half.
  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data / 2);
}

TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
  constexpr int sample_size = 320;
  constexpr double sample_rate = 16000;
  ConfigGraph(sample_size, sample_rate, 320);
  Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
  RunGraph(impulse_data, sample_rate);
  ASSERT_EQ(1, audio_out_packets_.size());
  MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
  // As the impulse signal sits at the beginning of the Hann window, the
  // inverse window function completely removes it.
  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
}

}  // namespace
}  // namespace mediapipe

@ -289,8 +289,15 @@ class NodeBase {

  template <typename T>
  T& GetOptions() {
    return GetOptions(T::ext);
  }

  // Use this API when the proto extension does not follow the "ext" naming
  // convention.
  template <typename E>
  auto& GetOptions(const E& extension) {
    options_used_ = true;
    return *options_.MutableExtension(T::ext);
    return *options_.MutableExtension(extension);
  }

 protected:

@ -386,8 +393,15 @@ class PacketGenerator {

  template <typename T>
  T& GetOptions() {
    return GetOptions(T::ext);
  }

  // Use this API when the proto extension does not follow the "ext" naming
  // convention.
  template <typename E>
  auto& GetOptions(const E& extension) {
    options_used_ = true;
    return *options_.MutableExtension(T::ext);
    return *options_.MutableExtension(extension);
  }

  template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
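
A hypothetical usage sketch of the two overloads added above; MyOptions and its custom_ext extension are illustrative stand-ins, not real MediaPipe symbols, and the builder node type is assumed from the api2 builder API.

void Configure(mediapipe::api2::builder::GenericNode& node) {
  node.GetOptions<MyOptions>();            // implicitly resolves MyOptions::ext
  node.GetOptions(MyOptions::custom_ext);  // extension with a non-"ext" name
}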

@ -185,7 +185,7 @@ class CalculatorBaseFactory {
// Functions for checking that the calculator has the required GetContract.
template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
  typedef absl::Status (*GetContractType)(CalculatorContract * cc);
  typedef absl::Status (*GetContractType)(CalculatorContract* cc);
  return std::is_same<decltype(&T::GetContract), GetContractType>::value;
}
template <class T>

@ -133,7 +133,13 @@ message GraphTrace {
    TPU_TASK = 13;
    GPU_CALIBRATION = 14;
    PACKET_QUEUED = 15;
    GPU_TASK_INVOKE = 16;
    TPU_TASK_INVOKE = 17;
    CPU_TASK_INVOKE = 18;
  }
  //  //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
  //  //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
  // )

  // The timing for one packet set being processed at one calculator node.
  message CalculatorTrace {

@ -293,7 +293,6 @@ mediapipe_proto_library(
    name = "rect_proto",
    srcs = ["rect.proto"],
    visibility = ["//visibility:public"],
    deps = ["//mediapipe/framework/formats:location_data_proto"],
)

mediapipe_register_type(

@ -109,6 +109,13 @@ struct TraceEvent {
  static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
  static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
  static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
  static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
  static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
  static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;

  //  //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
  //  //depot/mediapipe/framework/calculator_profile.proto:event_type,
  // )
};

// Packet trace log buffer.

@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(

// Returns a PacketSequencerCalculator node.
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
                                                bool synchronize_io) {
                                                bool async_selection) {
  CalculatorGraphConfig::Node* result = config->add_node();
  *result->mutable_calculator() = "PacketSequencerCalculator";
  if (synchronize_io) {
  if (!async_selection) {
    *result->mutable_input_stream_handler()->mutable_input_stream_handler() =
        "DefaultInputStreamHandler";
  }
@ -239,6 +239,15 @@ bool HasTag(const proto_ns::RepeatedPtrField<std::string>& streams,
  return tags.count({tag, 0}) > 0;
}

// Returns true if a set of "TAG:index" entries includes a TagIndex.
bool ContainsTag(const proto_ns::RepeatedPtrField<std::string>& tags,
                 TagIndex item) {
  for (const std::string& t : tags) {
    if (ParseTagIndex(t) == item) return true;
  }
  return false;
}
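
// Illustration (not part of the change): given tags = ["FRAME:0"], ContainsTag
// returns true for TagIndex{"FRAME", 0} and false for TagIndex{"FRAME", 1},
// assuming ParseTagIndex parses the usual "TAG:index" form.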

absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
    const Subgraph::SubgraphOptions& options) {
  CalculatorGraphConfig config;

@ -263,17 +272,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
  std::string enable_stream = "ENABLE:gate_enable";

  // Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
  bool synchronize_io =
      Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
          .synchronize_io();
  const auto& switch_options =
      Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options);
  bool async_selection = switch_options.async_selection();
  if (HasTag(container_node.input_stream(), "SELECT")) {
    select_node = BuildTimestampNode(&config, synchronize_io);
    select_node = BuildTimestampNode(&config, async_selection);
    select_node->add_input_stream("INPUT:gate_select");
    select_node->add_output_stream("OUTPUT:gate_select_timed");
    select_stream = "SELECT:gate_select_timed";
  }
  if (HasTag(container_node.input_stream(), "ENABLE")) {
    enable_node = BuildTimestampNode(&config, async_selection);
    enable_node->add_input_stream("INPUT:gate_enable");
    enable_node->add_output_stream("OUTPUT:gate_enable_timed");
    enable_stream = "ENABLE:gate_enable_timed";

@ -296,7 +305,7 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
  mux->add_input_side_packet("SELECT:gate_select");
  mux->add_input_side_packet("ENABLE:gate_enable");

  // Add input streams for graph and demux and the timestamper.
  // Add input streams for graph and demux.
  config.add_input_stream("SELECT:gate_select");
  config.add_input_stream("ENABLE:gate_enable");
  config.add_input_side_packet("SELECT:gate_select");

@ -306,6 +315,12 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
    std::string stream = CatStream(p.first, p.second);
    config.add_input_stream(stream);
    demux->add_input_stream(stream);
  }

  // Add input streams for the timestamper.
  auto& tick_streams = switch_options.tick_input_stream();
  for (const auto& p : input_tags) {
    if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue;
    TagIndex tick_tag{"TICK", tick_index++};
    if (select_node) {
      select_node->add_input_stream(CatStream(tick_tag, p.second));

@ -25,6 +25,14 @@ message SwitchContainerOptions {
  // Activates channel 1 for enable = true, channel 0 otherwise.
  optional bool enable = 4;

  // Use DefaultInputStreamHandler for muxing & demuxing.
  // Use DefaultInputStreamHandler for demuxing.
  optional bool synchronize_io = 5;

  // Use ImmediateInputStreamHandler for channel selection.
  optional bool async_selection = 6;

  // Specifies an input stream, "TAG:index", that defines the processed
  // timestamps.  SwitchContainer awaits output at the last processed
  // timestamp before advancing from one selected channel to the next.
  repeated string tick_input_stream = 7;
}
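
A hedged sketch of a graph config exercising the two new fields; every name other than SwitchContainer and the option fields is illustrative, and the contained_node field is assumed from the existing SwitchContainerOptions schema. async_selection selects the ImmediateInputStreamHandler path in BuildTimestampNode, and tick_input_stream restricts which data inputs feed the timestamper.

auto config = mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
    R"pb(
      input_stream: "SELECT:select"
      input_stream: "FRAME:frame_in"
      node {
        calculator: "SwitchContainer"
        input_stream: "SELECT:select"
        input_stream: "FRAME:frame_in"
        output_stream: "OUT:out"
        options {
          [mediapipe.SwitchContainerOptions.ext] {
            async_selection: true
            tick_input_stream: "FRAME:0"
            contained_node: { calculator: "PassThroughCalculator" }
            contained_node: { calculator: "PassThroughCalculator" }
          }
        }
      }
    )pb");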

@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
          input_stream: "INPUT:enable"
          input_stream: "TICK:foo"
          output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
          input_stream_handler {
            input_stream_handler: "DefaultInputStreamHandler"
          }
        }
        node {
          name: "switchcontainer__SwitchDemuxCalculator"

@ -306,7 +309,8 @@
// Shows that the SwitchContainer runs with a pair of simple subnodes.
TEST(SwitchContainerTest, RunsWithSubnodes) {
  EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
  CalculatorGraphConfig supergraph = SubnodeContainerExample();
  CalculatorGraphConfig supergraph =
      SubnodeContainerExample("async_selection: true");
  MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
  RunTestContainer(supergraph);
}

@ -14,6 +14,7 @@

#include <algorithm>
#include <memory>
#include <queue>
#include <set>
#include <string>
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -54,21 +55,47 @@ namespace mediapipe {
 | 
			
		|||
// contained subgraph or calculator nodes.
 | 
			
		||||
//
 | 
			
		||||
class SwitchDemuxCalculator : public CalculatorBase {
 | 
			
		||||
  static constexpr char kSelectTag[] = "SELECT";
 | 
			
		||||
  static constexpr char kEnableTag[] = "ENABLE";
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  static absl::Status GetContract(CalculatorContract* cc);
 | 
			
		||||
 | 
			
		||||
  absl::Status Open(CalculatorContext* cc) override;
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) override;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  absl::Status RecordPackets(CalculatorContext* cc);
 | 
			
		||||
  int ChannelIndex(Timestamp timestamp);
 | 
			
		||||
  absl::Status SendActivePackets(CalculatorContext* cc);
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  int channel_index_;
 | 
			
		||||
  std::set<std::string> channel_tags_;
 | 
			
		||||
  using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
 | 
			
		||||
  PacketQueue input_queue_;
 | 
			
		||||
  std::map<Timestamp, int> channel_history_;
 | 
			
		||||
};
 | 
			
		||||
REGISTER_CALCULATOR(SwitchDemuxCalculator);
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
static constexpr char kSelectTag[] = "SELECT";
 | 
			
		||||
static constexpr char kEnableTag[] = "ENABLE";
 | 
			
		||||
 | 
			
		||||
// Returns the last received timestamp for an input stream.
 | 
			
		||||
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
 | 
			
		||||
  return input.Value().Timestamp();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns the last received timestamp for channel selection.
 | 
			
		||||
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
 | 
			
		||||
  Timestamp result = Timestamp::Done();
 | 
			
		||||
  if (cc->Inputs().HasTag(kEnableTag)) {
 | 
			
		||||
    result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
 | 
			
		||||
  } else if (cc->Inputs().HasTag(kSelectTag)) {
 | 
			
		||||
    result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
 | 
			
		||||
  }
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
 | 
			
		||||
  // Allow any one of kSelectTag, kEnableTag.
 | 
			
		||||
  cc->Inputs().Tag(kSelectTag).Set<int>().Optional();

@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
  channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
  channel_tags_ = ChannelTags(cc->Outputs().TagMap());
  channel_history_[Timestamp::Unstarted()] = channel_index_;

  // Relay side packets to all channels.
  // Note: This is necessary because Calculator::Open only proceeds when every

@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
}

absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
  // Update the input channel index if specified.
  channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
  MP_RETURN_IF_ERROR(RecordPackets(cc));
  MP_RETURN_IF_ERROR(SendActivePackets(cc));
  return absl::OkStatus();
}

  // Relay packets and timestamps only to channel_index_.
// Enqueues all arriving packets and bounds.
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
  // Enqueue any newly arriving packets.
  for (const std::string& tag : channel_tags_) {
    for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
      auto& input = cc->Inputs().Get(tag, index);
      std::string output_tag = tool::ChannelTag(tag, channel_index_);
      auto output_id = cc->Outputs().GetId(output_tag, index);
      if (output_id.IsValid()) {
        auto& output = cc->Outputs().Get(output_tag, index);
        tool::Relay(input, &output);
      auto input_id = cc->Inputs().GetId(tag, index);
      Packet packet = cc->Inputs().Get(input_id).Value();
      if (packet.Timestamp() == cc->InputTimestamp()) {
        input_queue_[input_id].push(packet);
      }
    }
  }

  // Enqueue any new input channel and its activation timestamp.
  Timestamp channel_settled = ChannelSettledTimestamp(cc);
  int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
  if (channel_settled == cc->InputTimestamp() &&
      new_channel_index != channel_index_) {
    channel_index_ = new_channel_index;
    channel_history_[channel_settled] = channel_index_;
  }
  return absl::OkStatus();
}

// Returns the channel index for a Timestamp.
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
  auto it = std::prev(channel_history_.upper_bound(timestamp));
  return it->second;
}
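
// Worked example (not part of the change): if channel_history_ holds
// {Timestamp::Unstarted(): 0, t5: 1}, then for any timestamp >= t5,
// upper_bound() points past {t5, 1} and std::prev() selects channel 1; for
// earlier timestamps std::prev() lands on the Unstarted() entry, selecting
// channel 0.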

// Dispatches all queued input packets with known channels.
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
  // Dispatch any queued input packets with a defined channel_index.
  Timestamp channel_settled = ChannelSettledTimestamp(cc);
  for (const std::string& tag : channel_tags_) {
    for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
      auto input_id = cc->Inputs().GetId(tag, index);
      auto& queue = input_queue_[input_id];
      while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
        int channel_index = ChannelIndex(queue.front().Timestamp());
        std::string output_tag = tool::ChannelTag(tag, channel_index);
        auto output_id = cc->Outputs().GetId(output_tag, index);
        if (output_id.IsValid()) {
          cc->Outputs().Get(output_id).AddPacket(queue.front());
        }
        queue.pop();
      }
    }
  }

  // Discard all select packets not needed for any remaining input packets.
  Timestamp input_settled = Timestamp::Done();
  for (const std::string& tag : channel_tags_) {
    for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
      auto input_id = cc->Inputs().GetId(tag, index);
      Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
      if (!input_queue_[input_id].empty()) {
        Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
        stream_settled =
            std::min(stream_settled, stream_bound.PreviousAllowedInStream());
      }
      // Fold this stream's settled timestamp into the overall settled bound;
      // without this statement stream_settled would be computed and discarded.
      input_settled = std::min(input_settled, stream_settled);
    }
  }
  Timestamp input_bound = input_settled.NextAllowedInStream();
  auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
  channel_history_.erase(channel_history_.begin(), history_bound);
  return absl::OkStatus();
}

@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
  options_ = cc->Options<mediapipe::SwitchContainerOptions>();
  channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
  channel_tags_ = ChannelTags(cc->Inputs().TagMap());
  channel_history_[Timestamp::Unset()] = channel_index_;
  channel_history_[Timestamp::Unstarted()] = channel_index_;

  // Relay side packets only from channel_index_.
  for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {

@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;

static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__)
  eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
                 EGL_NO_CONTEXT);
#else
  // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid
  // display parameter for eglMakeCurrent. This behavior is not portable to
  // all EGL implementations, and should be considered as an undocumented
  // vendor extension.
  // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
  //
  // NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
  eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
                 EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
  eglReleaseThread();
}
 | 
			
		||||
 | 
			
		||||

@@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.framework.image.ImageProperties;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer;

// TODO: use Preconditions in this file.
@@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
  }

  /**
   * Creates an Image packet from an {@link Image}.
   * Creates a MediaPipe Image packet from a {@link MPImage}.
   *
   * <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
   */
  public Packet createImage(Image image) {
  public Packet createImage(MPImage image) {
    // TODO: Choose the best storage from multiple containers.
    ImageProperties properties = image.getContainedImageProperties().get(0);
    if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
    MPImageProperties properties = image.getContainedImageProperties().get(0);
    if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
      ByteBuffer buffer = ByteBufferExtractor.extract(image);
      int numChannels = 0;
      switch (properties.getImageFormat()) {
        case Image.IMAGE_FORMAT_RGBA:
        case MPImage.IMAGE_FORMAT_RGBA:
          numChannels = 4;
          break;
        case Image.IMAGE_FORMAT_RGB:
        case MPImage.IMAGE_FORMAT_RGB:
          numChannels = 3;
          break;
        case Image.IMAGE_FORMAT_ALPHA:
        case MPImage.IMAGE_FORMAT_ALPHA:
          numChannels = 1;
          break;
        default: // fall out
@@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
      int height = image.getHeight();
      return createImage(buffer, width, height, numChannels);
    }
    if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
    if (properties.getStorageType() == MPImage.STORAGE_TYPE_BITMAP) {
      Bitmap bitmap = BitmapExtractor.extract(image);
      if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
        throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
@@ -100,7 +100,7 @@ public class AndroidPacketCreator extends PacketCreator {

    // Unsupported type.
    throw new UnsupportedOperationException(
        "Unsupported Image container type: " + properties.getImageFormat());
        "Unsupported Image container type: " + properties.getStorageType());
  }

  /**
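To illustrate the renamed entry point, here is a minimal usage sketch; the class name, the incoming Bitmap, and the assumption that AndroidPacketCreator is built from an existing Graph are illustrative, not part of this commit:

    import android.graphics.Bitmap;
    import com.google.mediapipe.framework.AndroidPacketCreator;
    import com.google.mediapipe.framework.Graph;
    import com.google.mediapipe.framework.Packet;
    import com.google.mediapipe.framework.image.BitmapImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;

    class CreateImagePacketExample {
      static Packet makeImagePacket(Graph graph, Bitmap bitmap) {
        // Wrap the Bitmap in an MPImage; its storage type is STORAGE_TYPE_BITMAP,
        // so createImage() takes the Bitmap branch above. The bitmap must use
        // ARGB_8888, otherwise createImage() throws UnsupportedOperationException.
        MPImage image = new BitmapImageBuilder(bitmap).build();
        AndroidPacketCreator packetCreator = new AndroidPacketCreator(graph);
        return packetCreator.createImage(image);
      }
    }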

@@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;

/**
 * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
 * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
 *
 * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
 * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
 * {@link IllegalArgumentException} will be thrown.
 */
public final class BitmapExtractor {

  /**
   * Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
   * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
   *
   * @param image the image to extract {@link android.graphics.Bitmap} from.
   * @return the {@link android.graphics.Bitmap} stored in {@link Image}
   * @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
   * @throws IllegalArgumentException when the extraction requires unsupported format or data type
   *     conversions.
   */
  public static Bitmap extract(Image image) {
    ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
  public static Bitmap extract(MPImage image) {
    MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
    if (imageContainer != null) {
      return ((BitmapImageContainer) imageContainer).getBitmap();
    } else {
      // TODO: Support ByteBuffer -> Bitmap conversion.
      throw new IllegalArgumentException(
          "Extracting Bitmap from an Image created by objects other than Bitmap is not"
          "Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
              + " supported");
    }
  }

@@ -22,7 +22,7 @@ import android.provider.MediaStore;
import java.io.IOException;

/**
 * Builds {@link Image} from {@link android.graphics.Bitmap}.
 * Builds {@link MPImage} from {@link android.graphics.Bitmap}.
 *
 * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
 * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
@@ -49,7 +49,7 @@ public class BitmapImageBuilder {
  }

  /**
   * Creates the builder to build {@link Image} from a file.
   * Creates the builder to build {@link MPImage} from a file.
   *
   * @param context the application context.
   * @param uri the path to the resource file.
@@ -58,15 +58,15 @@ public class BitmapImageBuilder {
    this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
  }

  /** Sets value for {@link Image#getTimestamp()}. */
  /** Sets value for {@link MPImage#getTimestamp()}. */
  BitmapImageBuilder setTimestamp(long timestamp) {
    this.timestamp = timestamp;
    return this;
  }

  /** Builds an {@link Image} instance. */
  public Image build() {
    return new Image(
  /** Builds a {@link MPImage} instance. */
  public MPImage build() {
    return new MPImage(
        new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
  }
}
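The builder and extractor above round-trip a Bitmap without copying pixel data. A minimal usage sketch; the class and method names are illustrative:

    import android.graphics.Bitmap;
    import com.google.mediapipe.framework.image.BitmapExtractor;
    import com.google.mediapipe.framework.image.BitmapImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;

    class BitmapRoundTripExample {
      static Bitmap demo(Bitmap bitmap) {
        MPImage image = new BitmapImageBuilder(bitmap).build();
        // extract() succeeds because the image was created with STORAGE_TYPE_BITMAP.
        return BitmapExtractor.extract(image);
      }
    }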

@@ -16,19 +16,19 @@ limitations under the License.
package com.google.mediapipe.framework.image;

import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;

class BitmapImageContainer implements ImageContainer {
class BitmapImageContainer implements MPImageContainer {

  private final Bitmap bitmap;
  private final ImageProperties properties;
  private final MPImageProperties properties;

  public BitmapImageContainer(Bitmap bitmap) {
    this.bitmap = bitmap;
    this.properties =
        ImageProperties.builder()
        MPImageProperties.builder()
            .setImageFormat(convertFormatCode(bitmap.getConfig()))
            .setStorageType(Image.STORAGE_TYPE_BITMAP)
            .setStorageType(MPImage.STORAGE_TYPE_BITMAP)
            .build();
  }

@@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
  }

  @Override
  public ImageProperties getImageProperties() {
  public MPImageProperties getImageProperties() {
    return properties;
  }

@@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
    bitmap.recycle();
  }

  @ImageFormat
  @MPImageFormat
  static int convertFormatCode(Bitmap.Config config) {
    switch (config) {
      case ALPHA_8:
        return Image.IMAGE_FORMAT_ALPHA;
        return MPImage.IMAGE_FORMAT_ALPHA;
      case ARGB_8888:
        return Image.IMAGE_FORMAT_RGBA;
        return MPImage.IMAGE_FORMAT_RGBA;
      default:
        return Image.IMAGE_FORMAT_UNKNOWN;
        return MPImage.IMAGE_FORMAT_UNKNOWN;
    }
  }
}

@@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Locale;

/**
 * Utility for extracting {@link ByteBuffer} from {@link Image}.
 * Utility for extracting {@link ByteBuffer} from {@link MPImage}.
 *
 * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
 * {@link IllegalArgumentException} will be thrown.
 * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
 * otherwise {@link IllegalArgumentException} will be thrown.
 */
public class ByteBufferExtractor {

  /**
   * Extracts a {@link ByteBuffer} from an {@link Image}.
   * Extracts a {@link ByteBuffer} from a {@link MPImage}.
   *
   * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
   * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
   * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
   *
   * @see Image#getContainedImageProperties()
   * @see MPImage#getContainedImageProperties()
   * @return A read-only {@link ByteBuffer}.
   * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
   */
  @SuppressLint("SwitchIntDef")
  public static ByteBuffer extract(Image image) {
    ImageContainer container = image.getContainer();
  public static ByteBuffer extract(MPImage image) {
    MPImageContainer container = image.getContainer();
    switch (container.getImageProperties().getStorageType()) {
      case Image.STORAGE_TYPE_BYTEBUFFER:
      case MPImage.STORAGE_TYPE_BYTEBUFFER:
        ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
        return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
      default:
        throw new IllegalArgumentException(
            "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
            "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
                + " supported");
    }
  }

  /**
   * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
   * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
   *
   * <p>Format conversion spec:
   *
@@ -70,26 +70,26 @@ public class ByteBufferExtractor {
   *
   * @param image the image to extract buffer from.
   * @param targetFormat the image format of the result bytebuffer.
   * @return the readonly {@link ByteBuffer} stored in {@link Image}
   * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
   * @throws IllegalArgumentException when the extraction requires unsupported format or data type
   *     conversions.
   */
  static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
    ImageContainer container;
    ImageProperties byteBufferProperties =
        ImageProperties.builder()
            .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
  static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
    MPImageContainer container;
    MPImageProperties byteBufferProperties =
        MPImageProperties.builder()
            .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
            .setImageFormat(targetFormat)
            .build();
    if ((container = image.getContainer(byteBufferProperties)) != null) {
      ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
      return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
      ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
      @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
      @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
      return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
          .asReadOnlyBuffer();
    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
      BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
      ByteBuffer byteBuffer =
          extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
@@ -98,85 +98,89 @@ public class ByteBufferExtractor {
      return byteBuffer;
    } else {
      throw new IllegalArgumentException(
          "Extracting ByteBuffer from an Image created by objects other than Bitmap or"
          "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
              + " Bytebuffer is not supported");
    }
  }

  /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
  /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
  @AutoValue
  abstract static class Result {
    /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
    /**
     * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
     */
    public abstract ByteBuffer buffer();

    /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
    @ImageFormat
    /**
     * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
     */
    @MPImageFormat
    public abstract int format();

    static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
    static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
      return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
    }
  }

  /**
   * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
   * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
   *
   * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
   *
   * @return the readonly {@link ByteBuffer} stored in {@link Image}
   * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
   * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
   *     given {@code imageFormat}
   */
  static Result extractInRecommendedFormat(Image image) {
    ImageContainer container;
    if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
  static Result extractInRecommendedFormat(MPImage image) {
    MPImageContainer container;
    if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
      Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
      @ImageFormat int format = adviseImageFormat(bitmap);
      @MPImageFormat int format = adviseImageFormat(bitmap);
      Result result =
          Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);

      boolean unused =
          image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
      return result;
    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
      ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
      return Result.create(
          byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
          byteBufferImageContainer.getImageFormat());
    } else {
      throw new IllegalArgumentException(
          "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
          "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
              + " is not supported");
    }
  }

  @ImageFormat
  @MPImageFormat
  private static int adviseImageFormat(Bitmap bitmap) {
    if (bitmap.getConfig() == Config.ARGB_8888) {
      return Image.IMAGE_FORMAT_RGBA;
      return MPImage.IMAGE_FORMAT_RGBA;
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
              "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
                  + " supported",
              bitmap.getConfig()));
    }
  }

  private static ByteBuffer extractByteBufferFromBitmap(
      Bitmap bitmap, @ImageFormat int imageFormat) {
      Bitmap bitmap, @MPImageFormat int imageFormat) {
    if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
      throw new IllegalArgumentException(
          "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
          "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
              + " supported");
    }
    if (bitmap.getConfig() == Config.ARGB_8888) {
      if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
      if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
        bitmap.copyPixelsToBuffer(buffer);
        buffer.rewind();
        return buffer;
      } else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
      } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
        // TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
        int w = bitmap.getWidth();
        int h = bitmap.getHeight();
@@ -196,14 +200,14 @@ public class ByteBufferExtractor {
    }
    throw new IllegalArgumentException(
        String.format(
            "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
            "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
                + " %d is not supported",
            bitmap.getConfig(), imageFormat));
  }

  private static ByteBuffer convertByteBuffer(
      ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
    if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
      ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
    if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
      ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
      // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
      // array reversely to convert in-place.
@@ -221,7 +225,8 @@ public class ByteBufferExtractor {
      target.put(array, 0, target.capacity());
      target.rewind();
      return target;
    } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
    } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
        && targetFormat == MPImage.IMAGE_FORMAT_RGB) {
      ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
      // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
      // array to convert in-place.
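The conversion helpers above size the target buffer from the source capacity (capacity / 3 * 4 for RGB to RGBA, capacity / 4 * 3 for the reverse) and sweep with two cursors so the pixel bytes can be rewritten in place. A minimal sketch of calling the public entry point; the class and method names are illustrative:

    import com.google.mediapipe.framework.image.ByteBufferExtractor;
    import com.google.mediapipe.framework.image.MPImage;
    import java.nio.ByteBuffer;

    class ByteBufferExtractExample {
      static ByteBuffer readPixels(MPImage image) {
        // Returns a read-only view; throws IllegalArgumentException unless the
        // MPImage holds STORAGE_TYPE_BYTEBUFFER storage.
        return ByteBufferExtractor.extract(image);
      }
    }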

@@ -15,11 +15,11 @@ limitations under the License.

package com.google.mediapipe.framework.image;

import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;

/**
 * Builds a {@link Image} from a {@link ByteBuffer}.
 * Builds a {@link MPImage} from a {@link ByteBuffer}.
 *
 * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
 * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
@@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
  private final ByteBuffer buffer;
  private final int width;
  private final int height;
  @ImageFormat private final int imageFormat;
  @MPImageFormat private final int imageFormat;

  // Optional fields.
  private long timestamp;
@@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
   * @param imageFormat how the data encode the image.
   */
  public ByteBufferImageBuilder(
      ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
      ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
    this.buffer = byteBuffer;
    this.width = width;
    this.height = height;
@@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
    this.timestamp = 0;
  }

  /** Sets value for {@link Image#getTimestamp()}. */
  /** Sets value for {@link MPImage#getTimestamp()}. */
  ByteBufferImageBuilder setTimestamp(long timestamp) {
    this.timestamp = timestamp;
    return this;
  }

  /** Builds an {@link Image} instance. */
  public Image build() {
    return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
  /** Builds a {@link MPImage} instance. */
  public MPImage build() {
    return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
  }
}
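A minimal sketch of wrapping an existing pixel buffer in an MPImage; the method name and parameters are illustrative:

    import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;
    import java.nio.ByteBuffer;

    class ByteBufferBuildExample {
      static MPImage fromRgba(ByteBuffer pixels, int width, int height) {
        // The buffer is wrapped, not copied; per the class Javadoc the caller
        // should not modify its contents afterwards.
        return new ByteBufferImageBuilder(
                pixels, width, height, MPImage.IMAGE_FORMAT_RGBA)
            .build();
      }
    }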

@@ -15,21 +15,19 @@ limitations under the License.

package com.google.mediapipe.framework.image;

import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;

class ByteBufferImageContainer implements ImageContainer {
class ByteBufferImageContainer implements MPImageContainer {

  private final ByteBuffer buffer;
  private final ImageProperties properties;
  private final MPImageProperties properties;

  public ByteBufferImageContainer(
      ByteBuffer buffer,
      @ImageFormat int imageFormat) {
  public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
    this.buffer = buffer;
    this.properties =
        ImageProperties.builder()
            .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
        MPImageProperties.builder()
            .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
            .setImageFormat(imageFormat)
            .build();
  }
@@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
  }

  @Override
  public ImageProperties getImageProperties() {
  public MPImageProperties getImageProperties() {
    return properties;
  }

  /**
   * Returns the image format.
   */
  @ImageFormat
  /** Returns the image format. */
  @MPImageFormat
  public int getImageFormat() {
    return properties.getImageFormat();
  }

@@ -29,10 +29,10 @@ import java.util.Map.Entry;
/**
 * The wrapper class for image objects.
 *
 * <p>{@link Image} is designed to be an immutable image container, which could be shared
 * <p>{@link MPImage} is designed to be an immutable image container, which could be shared
 * cross-platforms.
 *
 * <p>To construct an {@link Image}, use the provided builders:
 * <p>To construct a {@link MPImage}, use the provided builders:
 *
 * <ul>
 *   <li>{@link ByteBufferImageBuilder}
@@ -40,7 +40,7 @@ import java.util.Map.Entry;
 *   <li>{@link MediaImageBuilder}
 * </ul>
 *
 * <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
 * <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
 * reference count is 1. Developer can call {@link #close()} to reduce reference count to release
 * internal storage earlier, otherwise Java garbage collection will release the storage eventually.
 *
@@ -53,7 +53,7 @@ import java.util.Map.Entry;
 *   <li>{@link MediaImageExtractor}
 * </ul>
 */
public class Image implements Closeable {
public class MPImage implements Closeable {

  /** Specifies the image format of an image. */
  @IntDef({
@@ -69,7 +69,7 @@ public class Image implements Closeable {
    IMAGE_FORMAT_JPEG,
  })
  @Retention(RetentionPolicy.SOURCE)
  public @interface ImageFormat {}
  public @interface MPImageFormat {}

  public static final int IMAGE_FORMAT_UNKNOWN = 0;
  public static final int IMAGE_FORMAT_RGBA = 1;
@@ -98,14 +98,14 @@ public class Image implements Closeable {
  public static final int STORAGE_TYPE_IMAGE_PROXY = 4;

  /**
   * Returns a list of supported image properties for this {@link Image}.
   * Returns a list of supported image properties for this {@link MPImage}.
   *
   * <p>Currently {@link Image} only support single storage type so the size of return list will
   * <p>Currently {@link MPImage} only support single storage type so the size of return list will
   * always be 1.
   *
   * @see ImageProperties
   * @see MPImageProperties
   */
  public List<ImageProperties> getContainedImageProperties() {
  public List<MPImageProperties> getContainedImageProperties() {
    return Collections.singletonList(getContainer().getImageProperties());
  }

@@ -124,7 +124,7 @@ public class Image implements Closeable {
    return height;
  }

  /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
  /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
  private synchronized void acquire() {
    referenceCount += 1;
  }
@@ -132,7 +132,7 @@ public class Image implements Closeable {
  /**
   * Removes a reference that was previously acquired or init.
   *
   * <p>When {@link Image} is created, it has 1 reference count.
   * <p>When {@link MPImage} is created, it has 1 reference count.
   *
   * <p>When the reference count becomes 0, it will release the resource under the hood.
   */
@@ -141,24 +141,24 @@ public class Image implements Closeable {
  public synchronized void close() {
    referenceCount -= 1;
    if (referenceCount == 0) {
      for (ImageContainer imageContainer : containerMap.values()) {
      for (MPImageContainer imageContainer : containerMap.values()) {
        imageContainer.close();
      }
    }
  }

  /** Advanced API access for {@link Image}. */
  /** Advanced API access for {@link MPImage}. */
  static final class Internal {

    /**
     * Acquires a reference on this {@link Image}. This will increase the reference count by 1.
     * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
     *
     * <p>This method is more useful for image consumer to acquire a reference so image resource
     * will not be closed accidentally. As image creator, normal developer doesn't need to call this
     * method.
     *
     * <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
     * #close()} to indicate it doesn't need this {@link Image} anymore.
     * <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
     * #close()} to indicate it doesn't need this {@link MPImage} anymore.
     *
     * @see #close()
     */
@@ -166,10 +166,10 @@ public class Image implements Closeable {
      image.acquire();
    }

    private final Image image;
    private final MPImage image;

    // Only Image creates the internal helper.
    private Internal(Image image) {
    // Only MPImage creates the internal helper.
    private Internal(MPImage image) {
      this.image = image;
    }
  }
@@ -179,15 +179,15 @@ public class Image implements Closeable {
    return new Internal(this);
  }

  private final Map<ImageProperties, ImageContainer> containerMap;
  private final Map<MPImageProperties, MPImageContainer> containerMap;
  private final long timestamp;
  private final int width;
  private final int height;

  private int referenceCount;

  /** Constructs an {@link Image} with a built container. */
  Image(ImageContainer container, long timestamp, int width, int height) {
  /** Constructs a {@link MPImage} with a built container. */
  MPImage(MPImageContainer container, long timestamp, int width, int height) {
    this.containerMap = new HashMap<>();
    containerMap.put(container.getImageProperties(), container);
    this.timestamp = timestamp;
@@ -201,10 +201,10 @@ public class Image implements Closeable {
   *
   * @return the current container.
   */
  ImageContainer getContainer() {
  MPImageContainer getContainer() {
    // According to the design, in the future we will support multiple containers in one image.
    // Currently just return the original container.
    // TODO: Cache multiple containers in Image.
    // TODO: Cache multiple containers in MPImage.
    return containerMap.values().iterator().next();
  }

@@ -214,8 +214,8 @@ public class Image implements Closeable {
   * <p>If there are multiple containers with required {@code storageType}, returns the first one.
   */
  @Nullable
  ImageContainer getContainer(@StorageType int storageType) {
    for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
  MPImageContainer getContainer(@StorageType int storageType) {
    for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
      if (entry.getKey().getStorageType() == storageType) {
        return entry.getValue();
      }
@@ -225,13 +225,13 @@ public class Image implements Closeable {

  /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
  @Nullable
  ImageContainer getContainer(ImageProperties imageProperties) {
  MPImageContainer getContainer(MPImageProperties imageProperties) {
    return containerMap.get(imageProperties);
  }

  /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
  boolean addContainer(ImageContainer container) {
    ImageProperties imageProperties = container.getImageProperties();
  boolean addContainer(MPImageContainer container) {
    MPImageProperties imageProperties = container.getImageProperties();
    if (containerMap.containsKey(imageProperties)) {
      return false;
    }
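Since MPImage implements Closeable and starts life with a reference count of 1, the usual lifecycle fits a try-with-resources block. A minimal sketch; the class name and Bitmap source are illustrative:

    import android.graphics.Bitmap;
    import com.google.mediapipe.framework.image.BitmapImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;

    class ReferenceCountExample {
      static void demo(Bitmap bitmap) {
        // Reference count is 1 after build(); close() drops it to 0, which
        // closes every contained MPImageContainer.
        try (MPImage image = new BitmapImageBuilder(bitmap).build()) {
          // ... hand the image to consumers here ...
        }
      }
    }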

@@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;

/** Lightweight abstraction for an object that can receive {@link Image} */
public interface ImageConsumer {
/** Lightweight abstraction for an object that can receive {@link MPImage} */
public interface MPImageConsumer {

  /**
   * Called when an {@link Image} is available.
   * Called when a {@link MPImage} is available.
   *
   * <p>The argument is only guaranteed to be available until this method returns. if you need to
   * extend its life time, acquire it, then release it when done.
   */
  void onNewImage(Image image);
  void onNewMPImage(MPImage image);
}

@@ -16,9 +16,9 @@ limitations under the License.
package com.google.mediapipe.framework.image;

/** Manages internal image data storage. The interface is package-private. */
interface ImageContainer {
interface MPImageContainer {
  /** Returns the properties of the contained image. */
  ImageProperties getImageProperties();
  MPImageProperties getImageProperties();

  /** Close the image container and releases the image resource inside. */
  void close();

@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;

/** Lightweight abstraction for an object that produce {@link Image} */
public interface ImageProducer {
/** Lightweight abstraction for an object that produce {@link MPImage} */
public interface MPImageProducer {

  /** Sets the consumer that receives the {@link Image}. */
  void setImageConsumer(ImageConsumer imageConsumer);
  /** Sets the consumer that receives the {@link MPImage}. */
  void setMPImageConsumer(MPImageConsumer imageConsumer);
}
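A minimal consumer implementation under the renamed interfaces; the logging body is an illustrative assumption:

    import android.util.Log;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.framework.image.MPImageConsumer;

    class LoggingConsumer implements MPImageConsumer {
      @Override
      public void onNewMPImage(MPImage image) {
        // The image is only guaranteed valid until this method returns;
        // acquire a reference before retaining it beyond this call.
        Log.d("LoggingConsumer", "frame " + image.getWidth() + "x" + image.getHeight());
      }
    }

A producer would then be wired up with producer.setMPImageConsumer(new LoggingConsumer()).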

@@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;

import com.google.auto.value.AutoValue;
import com.google.auto.value.extension.memoized.Memoized;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.Image.StorageType;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import com.google.mediapipe.framework.image.MPImage.StorageType;

/** Groups a set of properties to describe how an image is stored. */
@AutoValue
public abstract class ImageProperties {
public abstract class MPImageProperties {

  /**
   * Gets the pixel format of the image.
   *
   * @see Image.ImageFormat
   * @see MPImage.MPImageFormat
   */
  @ImageFormat
  @MPImageFormat
  public abstract int getImageFormat();

  /**
   * Gets the storage type of the image.
   *
   * @see Image.StorageType
   * @see MPImage.StorageType
   */
  @StorageType
  public abstract int getStorageType();
@@ -45,36 +45,36 @@ public abstract class ImageProperties {
  public abstract int hashCode();

  /**
   * Creates a builder of {@link ImageProperties}.
   * Creates a builder of {@link MPImageProperties}.
   *
   * @see ImageProperties.Builder
   * @see MPImageProperties.Builder
   */
  static Builder builder() {
    return new AutoValue_ImageProperties.Builder();
    return new AutoValue_MPImageProperties.Builder();
  }

  /** Builds a {@link ImageProperties}. */
  /** Builds a {@link MPImageProperties}. */
  @AutoValue.Builder
  abstract static class Builder {

    /**
     * Sets the {@link Image.ImageFormat}.
     * Sets the {@link MPImage.MPImageFormat}.
     *
     * @see ImageProperties#getImageFormat
     * @see MPImageProperties#getImageFormat
     */
    abstract Builder setImageFormat(@ImageFormat int value);
    abstract Builder setImageFormat(@MPImageFormat int value);

    /**
     * Sets the {@link Image.StorageType}.
     * Sets the {@link MPImage.StorageType}.
     *
     * @see ImageProperties#getStorageType
     * @see MPImageProperties#getStorageType
     */
    abstract Builder setStorageType(@StorageType int value);

    /** Builds the {@link ImageProperties}. */
    abstract ImageProperties build();
    /** Builds the {@link MPImageProperties}. */
    abstract MPImageProperties build();
  }

  // Hide the constructor.
  ImageProperties() {}
  MPImageProperties() {}
}

@@ -15,11 +15,12 @@ limitations under the License.

package com.google.mediapipe.framework.image;

import android.media.Image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;

/**
 * Builds {@link Image} from {@link android.media.Image}.
 * Builds {@link MPImage} from {@link android.media.Image}.
 *
 * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
 * content in it.
@@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
public class MediaImageBuilder {

  // Mandatory fields.
  private final android.media.Image mediaImage;
  private final Image mediaImage;

  // Optional fields.
  private long timestamp;
@@ -40,20 +41,20 @@ public class MediaImageBuilder {
   *
   * @param mediaImage image data object.
   */
  public MediaImageBuilder(android.media.Image mediaImage) {
  public MediaImageBuilder(Image mediaImage) {
    this.mediaImage = mediaImage;
    this.timestamp = 0;
  }

  /** Sets value for {@link Image#getTimestamp()}. */
  /** Sets value for {@link MPImage#getTimestamp()}. */
  MediaImageBuilder setTimestamp(long timestamp) {
    this.timestamp = timestamp;
    return this;
  }

  /** Builds an {@link Image} instance. */
  public Image build() {
    return new Image(
  /** Builds a {@link MPImage} instance. */
  public MPImage build() {
    return new MPImage(
        new MediaImageContainer(mediaImage),
        timestamp,
        mediaImage.getWidth(),

@@ -15,33 +15,34 @@ limitations under the License.

package com.google.mediapipe.framework.image;

import android.media.Image;
import android.os.Build;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;

@RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer {
class MediaImageContainer implements MPImageContainer {

  private final android.media.Image mediaImage;
  private final ImageProperties properties;
  private final Image mediaImage;
  private final MPImageProperties properties;

  public MediaImageContainer(android.media.Image mediaImage) {
  public MediaImageContainer(Image mediaImage) {
    this.mediaImage = mediaImage;
    this.properties =
        ImageProperties.builder()
            .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
        MPImageProperties.builder()
            .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
            .setImageFormat(convertFormatCode(mediaImage.getFormat()))
            .build();
  }

  public android.media.Image getImage() {
  public Image getImage() {
    return mediaImage;
  }

  @Override
  public ImageProperties getImageProperties() {
  public MPImageProperties getImageProperties() {
    return properties;
  }

@@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
    mediaImage.close();
  }

  @ImageFormat
  @MPImageFormat
  static int convertFormatCode(int graphicsFormat) {
    // We only cover the format mentioned in
    // https://developer.android.com/reference/android/media/Image#getFormat()
    if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
      if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
        return Image.IMAGE_FORMAT_RGBA;
        return MPImage.IMAGE_FORMAT_RGBA;
      } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
        return Image.IMAGE_FORMAT_RGB;
        return MPImage.IMAGE_FORMAT_RGB;
      }
    }
    switch (graphicsFormat) {
      case android.graphics.ImageFormat.JPEG:
        return Image.IMAGE_FORMAT_JPEG;
        return MPImage.IMAGE_FORMAT_JPEG;
      case android.graphics.ImageFormat.YUV_420_888:
        return Image.IMAGE_FORMAT_YUV_420_888;
        return MPImage.IMAGE_FORMAT_YUV_420_888;
      default:
        return Image.IMAGE_FORMAT_UNKNOWN;
        return MPImage.IMAGE_FORMAT_UNKNOWN;
    }
  }
}

@@ -15,13 +15,14 @@ limitations under the License.

package com.google.mediapipe.framework.image;

import android.media.Image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;

/**
 * Utility for extracting {@link android.media.Image} from {@link Image}.
 * Utility for extracting {@link android.media.Image} from {@link MPImage}.
 *
 * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
 * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
 * otherwise {@link IllegalArgumentException} will be thrown.
 */
@RequiresApi(VERSION_CODES.KITKAT)
@@ -30,20 +31,20 @@ public class MediaImageExtractor {
  private MediaImageExtractor() {}

  /**
   * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
   * {@link Image} that built from {@link MediaImageBuilder}.
   * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
   * {@link MPImage} that built from {@link MediaImageBuilder}.
   *
   * @param image the image to extract {@link android.media.Image} from.
   * @return {@link android.media.Image} that stored in {@link Image}.
   * @return {@link android.media.Image} that stored in {@link MPImage}.
   * @throws IllegalArgumentException if the extraction failed.
   */
  public static android.media.Image extract(Image image) {
    ImageContainer container;
    if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
  public static Image extract(MPImage image) {
    MPImageContainer container;
    if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
      return ((MediaImageContainer) container).getImage();
    }
    throw new IllegalArgumentException(
        "Extract Media Image from an Image created by objects other than Media Image"
        "Extract Media Image from a MPImage created by objects other than Media Image"
            + " is not supported");
  }
}
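A round-trip sketch for the android.media.Image path, e.g. with frames coming from an ImageReader; the class and method names are illustrative:

    import android.media.Image;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.framework.image.MediaImageBuilder;
    import com.google.mediapipe.framework.image.MediaImageExtractor;

    class MediaImageRoundTripExample {
      static Image demo(Image cameraFrame) {
        MPImage image = new MediaImageBuilder(cameraFrame).build();
        // Succeeds only because the MPImage was built by MediaImageBuilder
        // (STORAGE_TYPE_MEDIA_IMAGE).
        return MediaImageExtractor.extract(image);
      }
    }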
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,4 @@
 | 
			
		|||
# Copyright 2019-2020 The MediaPipe Authors.
 | 
			
		||||
# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
@@ -209,9 +209,9 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
 def mediapipe_build_aar_with_jni(name, android_library):
     """Builds MediaPipe AAR with jni.
 
-      Args:
-        name: The bazel target name.
-        android_library: the android library that contains jni.
+    Args:
+      name: The bazel target name.
+      android_library: the android library that contains jni.
     """
 
     # Generates dummy AndroidManifest.xml for dummy apk usage

@@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""):
         src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
     ))
 
-    proto_src_list.append(mediapipe_java_proto_src_extractor(
-        target = "//mediapipe/framework/formats:landmark_java_proto_lite",
-        src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
-    ))
-
     proto_src_list.append(mediapipe_java_proto_src_extractor(
         target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
         src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
     ))
 
     proto_src_list.append(mediapipe_java_proto_src_extractor(
-        target = "//mediapipe/framework/formats:location_data_java_proto_lite",
-        src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
+        target = "//mediapipe/framework/formats:classification_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
     ))
 
     proto_src_list.append(mediapipe_java_proto_src_extractor(

@@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""):
     ))
 
     proto_src_list.append(mediapipe_java_proto_src_extractor(
-        target = "//mediapipe/framework/formats:classification_java_proto_lite",
-        src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
+        target = "//mediapipe/framework/formats:landmark_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats:location_data_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
     ))
 
     proto_src_list.append(mediapipe_java_proto_src_extractor(
         target = "//mediapipe/framework/formats:rect_java_proto_lite",
         src_out = "com/google/mediapipe/formats/proto/RectProto.java",
     ))
     return proto_src_list

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Placeholder for internal Python strict library compatibility macro.
+# Placeholder for internal Python strict library and test compatibility macro.
 
 package(
     default_visibility = ["//mediapipe:__subpackages__"],

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.
 
 licenses(["notice"])
 
@@ -23,15 +24,12 @@ package(
 py_library(
     name = "data_util",
     srcs = ["data_util.py"],
-    srcs_version = "PY3",
 )
 
 py_test(
     name = "data_util_test",
     srcs = ["data_util_test.py"],
     data = ["//mediapipe/model_maker/python/core/data/testdata"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [":data_util"],
 )
 
@@ -44,8 +42,6 @@ py_library(
 py_test(
     name = "dataset_test",
     srcs = ["dataset_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":dataset",
         "//mediapipe/model_maker/python/core/utils:test_util",
@@ -55,14 +51,11 @@ py_test(
 py_library(
     name = "classification_dataset",
     srcs = ["classification_dataset.py"],
-    srcs_version = "PY3",
     deps = [":dataset"],
 )
 
 py_test(
     name = "classification_dataset_test",
     srcs = ["classification_dataset_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [":classification_dataset"],
 )

@@ -13,7 +13,7 @@
 # limitations under the License.
 """Common classification dataset library."""
 
-from typing import Any, Tuple
+from typing import List, Tuple
 
 import tensorflow as tf
 
@@ -21,15 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
 
 
 class ClassificationDataset(ds.Dataset):
-  """DataLoader for classification models."""
+  """Dataset Loader for classification models."""
 
-  def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
+  def __init__(self, dataset: tf.data.Dataset, size: int,
+               label_names: List[str]):
     super().__init__(dataset, size)
-    self.index_to_label = index_to_label
+    self._label_names = label_names
 
   @property
   def num_classes(self: ds._DatasetT) -> int:
-    return len(self.index_to_label)
+    return len(self._label_names)
+
+  @property
+  def label_names(self: ds._DatasetT) -> List[str]:
+    return self._label_names
 
   def split(self: ds._DatasetT,
             fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@@ -44,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
     Returns:
       The splitted two sub datasets.
     """
-    return self._split(fraction, self.index_to_label)
+    return self._split(fraction, self._label_names)

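The label_names refactor above replaces the loosely typed index_to_label
attribute with a read-only label_names property. A minimal sketch of the
resulting API, assuming this commit's mediapipe.model_maker package layout;
the sample tensors are illustrative only:

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Four dummy examples; class index 0 maps to 'pos' and index 1 to 'neg'.
features = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = classification_dataset.ClassificationDataset(
    dataset=features, size=4, label_names=['pos', 'neg'])

print(data.num_classes)  # 2, computed from len(label_names)
print(data.label_names)  # ['pos', 'neg'], exposed via the new property
train, test = data.split(fraction=0.5)  # label_names propagate to both halves
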
@@ -12,45 +12,55 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, List, Tuple, TypeVar
 
-# Dependency imports
-
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import classification_dataset
 
+_DatasetT = TypeVar(
+    '_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
 
-class ClassificationDataLoaderTest(tf.test.TestCase):
+
+class ClassificationDatasetTest(tf.test.TestCase):
 
   def test_split(self):
 
-    class MagicClassificationDataLoader(
+    class MagicClassificationDataset(
         classification_dataset.ClassificationDataset):
+      """A mock classification dataset class for testing purpose.
 
-      def __init__(self, dataset, size, index_to_label, value):
-        super(MagicClassificationDataLoader,
-              self).__init__(dataset, size, index_to_label)
+      Attributes:
+        value: A value variable stored by the mock dataset class for testing.
+      """
+
+      def __init__(self, dataset: tf.data.Dataset, size: int,
+                   label_names: List[str], value: Any):
+        super().__init__(dataset=dataset, size=size, label_names=label_names)
         self.value = value
 
-      def split(self, fraction):
-        return self._split(fraction, self.index_to_label, self.value)
+      def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
+        return self._split(fraction, self.label_names, self.value)
 
     # Some dummy inputs.
     magic_value = 42
     num_classes = 2
-    index_to_label = (False, True)
+    label_names = ['foo', 'bar']
 
     # Create data loader from sample data.
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
-                                         magic_value)
+    data = MagicClassificationDataset(
+        dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
 
     # Train/Test data split.
     fraction = .25
-    train_data, test_data = data.split(fraction)
+    train_data, test_data = data.split(fraction=fraction)
 
     # `split` should return instances of child DataLoader.
-    self.assertIsInstance(train_data, MagicClassificationDataLoader)
-    self.assertIsInstance(test_data, MagicClassificationDataLoader)
+    self.assertIsInstance(train_data, MagicClassificationDataset)
+    self.assertIsInstance(test_data, MagicClassificationDataset)
 
     # Make sure number of entries are right.
     self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
@@ -59,7 +69,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase):
 
     # Make sure attributes propagated correctly.
     self.assertEqual(train_data.num_classes, num_classes)
-    self.assertEqual(test_data.index_to_label, index_to_label)
+    self.assertEqual(test_data.label_names, label_names)
     self.assertEqual(train_data.value, magic_value)
     self.assertEqual(test_data.value, magic_value)

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.
 
 package(
     default_visibility = ["//mediapipe:__subpackages__"],
@@ -23,7 +24,6 @@ licenses(["notice"])
 py_library(
     name = "custom_model",
     srcs = ["custom_model.py"],
-    srcs_version = "PY3",
     deps = [
         "//mediapipe/model_maker/python/core/data:dataset",
         "//mediapipe/model_maker/python/core/utils:model_util",
@@ -34,8 +34,6 @@ py_library(
 py_test(
     name = "custom_model_test",
     srcs = ["custom_model_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":custom_model",
         "//mediapipe/model_maker/python/core/utils:test_util",
@@ -45,7 +43,6 @@ py_test(
 py_library(
     name = "classifier",
     srcs = ["classifier.py"],
-    srcs_version = "PY3",
     deps = [
         ":custom_model",
         "//mediapipe/model_maker/python/core/data:dataset",
@@ -55,8 +52,6 @@ py_library(
 py_test(
     name = "classifier_test",
     srcs = ["classifier_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":classifier",
         "//mediapipe/model_maker/python/core/utils:test_util",

@@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
 class Classifier(custom_model.CustomModel):
   """An abstract base class that represents a TensorFlow classifier."""
 
-  def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
+  def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
                full_train: bool):
     """Initilizes a classifier with its specifications.
 
     Args:
         model_spec: Specification for the model.
-        index_to_label: A list that map from index to label class name.
+        label_names: A list of label names for the classes.
         shuffle: Whether the dataset should be shuffled.
         full_train: If true, train the model end-to-end including the backbone
          and the classification layers on top. Otherwise, only train the top
          classification layers.
     """
     super(Classifier, self).__init__(model_spec, shuffle)
-    self._index_to_label = index_to_label
+    self._label_names = label_names
     self._full_train = full_train
-    self._num_classes = len(index_to_label)
+    self._num_classes = len(label_names)
 
   def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
     """Evaluates the classifier with the provided evaluation dataset.
@@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
     label_filepath = os.path.join(export_dir, label_filename)
     tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
     with tf.io.gfile.GFile(label_filepath, 'w') as f:
-      f.write('\n'.join(self._index_to_label))
+      f.write('\n'.join(self._label_names))

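Classifier now derives num_classes from label_names at construction time. A
short sketch of the renamed argument flow, assuming this commit's package
layout; BinaryClassifier is a hypothetical subclass, not part of the change:

from typing import Any, List

from mediapipe.model_maker.python.core.tasks import classifier


class BinaryClassifier(classifier.Classifier):
  """Toy concrete classifier that forwards the renamed arguments."""

  def __init__(self, model_spec: Any, label_names: List[str]):
    super().__init__(
        model_spec=model_spec,
        label_names=label_names,
        shuffle=False,
        full_train=False)


model = BinaryClassifier(model_spec=None, label_names=['cat', 'dog'])
print(model._num_classes)  # 2, derived from len(label_names) in __init__
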
@@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
 
   def setUp(self):
     super(ClassifierTest, self).setUp()
-    index_to_label = ['cat', 'dog']
+    label_names = ['cat', 'dog']
     self.model = MockClassifier(
         model_spec=None,
-        index_to_label=index_to_label,
+        label_names=label_names,
         shuffle=False,
         full_train=False)
     self.model.model = test_util.build_model(input_shape=[4], num_classes=2)

@@ -21,8 +21,6 @@ import abc
 import os
 from typing import Any, Callable, Optional
 
-# Dependency imports
-
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import dataset
@@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
     tflite_filepath = os.path.join(export_dir, tflite_filename)
     # TODO: Populate metadata to the exported TFLite model.
     model_util.export_tflite(
-        self._model,
-        tflite_filepath,
-        quantization_config,
+        model=self._model,
+        tflite_filepath=tflite_filepath,
+        quantization_config=quantization_config,
         preprocess=preprocess)
     tf.compat.v1.logging.info(
         'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

@@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):
 
   def setUp(self):
     super(CustomModelTest, self).setUp()
-    self.model = MockCustomModel(model_spec=None, shuffle=False)
-    self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
+    self._model = MockCustomModel(model_spec=None, shuffle=False)
+    self._model._model = test_util.build_model(input_shape=[4], num_classes=2)
 
   def _check_nonempty_file(self, filepath):
     self.assertTrue(os.path.isfile(filepath))
@@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
 
   def test_export_tflite(self):
     export_path = os.path.join(self.get_temp_dir(), 'export/')
-    self.model.export_tflite(export_dir=export_path)
+    self._model.export_tflite(export_dir=export_path)
     self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
 
 if __name__ == '__main__':

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.
 
 licenses(["notice"])
 
@@ -24,31 +25,15 @@ py_library(
     name = "test_util",
     testonly = 1,
     srcs = ["test_util.py"],
-    srcs_version = "PY3",
     deps = [
         ":model_util",
         "//mediapipe/model_maker/python/core/data:dataset",
     ],
 )
 
-py_library(
-    name = "image_preprocessing",
-    srcs = ["image_preprocessing.py"],
-    srcs_version = "PY3",
-)
-
-py_test(
-    name = "image_preprocessing_test",
-    srcs = ["image_preprocessing_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
-    deps = [":image_preprocessing"],
-)
-
 py_library(
     name = "model_util",
     srcs = ["model_util.py"],
-    srcs_version = "PY3",
     deps = [
         ":quantization",
         "//mediapipe/model_maker/python/core/data:dataset",
@@ -58,8 +43,6 @@ py_library(
 py_test(
     name = "model_util_test",
     srcs = ["model_util_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":model_util",
         ":quantization",
@@ -76,8 +59,6 @@ py_library(
 py_test(
     name = "loss_functions_test",
     srcs = ["loss_functions_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [":loss_functions"],
 )
 
@@ -91,8 +72,6 @@ py_library(
 py_test(
     name = "quantization_test",
     srcs = ["quantization_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":quantization",
         ":test_util",

@@ -56,7 +56,7 @@ class FocalLoss(tf.keras.losses.Loss):
       class_weight: A weight to apply to the loss, one for each class. The
         weight is applied for each input where the ground truth label matches.
     """
-    super(tf.keras.losses.Loss, self).__init__()
+    super().__init__()
     # Used for clipping min/max values of probability values in y_pred to avoid
     # NaNs and Infs in computation.
     self._epsilon = 1e-7

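The one-line fix above addresses a classic Python pitfall:
super(tf.keras.losses.Loss, self) starts the method-resolution-order lookup
*after* tf.keras.losses.Loss, so the Loss initializer itself was skipped. A
self-contained sketch of the same pitfall, with toy classes standing in for
the Keras ones:

class Base:
  def __init__(self):
    self.initialized = True


class BrokenChild(Base):
  def __init__(self):
    # Lookup starts after Base in the MRO, so Base.__init__ never runs and
    # object.__init__ is called instead.
    super(Base, self).__init__()


class FixedChild(Base):
  def __init__(self):
    # Zero-argument super() starts after FixedChild, so Base.__init__ runs.
    super().__init__()


print(hasattr(BrokenChild(), 'initialized'))  # False
print(hasattr(FixedChild(), 'initialized'))   # True
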
@@ -104,8 +104,8 @@ def export_tflite(
     quantization_config: Configuration for post-training quantization.
     supported_ops: A list of supported ops in the converted TFLite file.
     preprocess: A callable to preprocess the representative dataset for
-        quantization. The callable takes three arguments in order: feature,
-        label, and is_training.
+      quantization. The callable takes three arguments in order: feature, label,
+      and is_training.
   """
   if tflite_filepath is None:
     raise ValueError(

@@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
     model = test_util.build_model(input_shape=[input_dim], num_classes=2)
     tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
     model_util.export_tflite(model, tflite_file)
-    self._test_tflite(model, tflite_file, input_dim)
+    test_util.test_tflite(
+        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
 
   @parameterized.named_parameters(
       dict(
@@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
     input_dim = 16
     num_classes = 2
     max_input_value = 5
-    model = test_util.build_model([input_dim], num_classes)
+    model = test_util.build_model(
+        input_shape=[input_dim], num_classes=num_classes)
     tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
 
-    model_util.export_tflite(model, tflite_file, config)
-    self._test_tflite(
-        model, tflite_file, input_dim, max_input_value, atol=1e-00)
-    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
-
-  def _test_tflite(self,
-                   keras_model: tf.keras.Model,
-                   tflite_model_file: str,
-                   input_dim: int,
-                   max_input_value: int = 1000,
-                   atol: float = 1e-04):
-    random_input = test_util.create_random_sample(
-        size=[1, input_dim], high=max_input_value)
-    random_input = tf.convert_to_tensor(random_input)
-
+    model_util.export_tflite(
+        model=model, tflite_filepath=tflite_file, quantization_config=config)
     self.assertTrue(
-        test_util.is_same_output(
-            tflite_model_file, keras_model, random_input, atol=atol))
+        test_util.test_tflite(
+            keras_model=model,
+            tflite_file=tflite_file,
+            size=[1, input_dim],
+            high=max_input_value,
+            atol=1e-00))
+    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
 
 
 if __name__ == '__main__':

@@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
   keras_output = keras_model.predict_on_batch(input_tensors)
 
   return np.allclose(lite_output, keras_output, atol=atol)
+
+
+def test_tflite(keras_model: tf.keras.Model,
+                tflite_file: str,
+                size: Union[int, List[int]],
+                high: float = 1,
+                atol: float = 1e-04) -> bool:
+  """Verifies if the output of TFLite model and TF Keras model are identical.
+
+  Args:
+    keras_model: Input TensorFlow Keras model.
+    tflite_file: Input TFLite model file.
+    size: Size of the input tensor.
+    high: Higher boundary of the values in input tensors.
+    atol: Absolute tolerance of the difference between the outputs of Keras
+      model and TFLite model.
+
+  Returns:
+    True if the output of TFLite model and TF Keras model are identical.
+    Otherwise, False.
+  """
+  random_input = create_random_sample(size=size, high=high)
+  random_input = tf.convert_to_tensor(random_input)
+
+  return is_same_output(
+      tflite_file=tflite_file,
+      keras_model=keras_model,
+      input_tensors=random_input,
+      atol=atol)

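The new test_util.test_tflite helper folds random-input generation and output
comparison into a single call, replacing the per-test _test_tflite helper
deleted above. A sketch of a typical use, combined with the keyword-argument
style of export_tflite adopted in this commit; the temp-dir handling is
illustrative:

import os
import tempfile

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import test_util

# Build a tiny Keras model and convert it to TFLite.
model = test_util.build_model(input_shape=[4], num_classes=2)
tflite_file = os.path.join(tempfile.mkdtemp(), 'model.tflite')
model_util.export_tflite(model=model, tflite_filepath=tflite_file)

# True when the Keras and TFLite outputs agree within the default tolerance.
assert test_util.test_tflite(
    keras_model=model, tflite_file=tflite_file, size=[1, 4])
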
							
								
								
									
mediapipe/model_maker/python/internal/README.md (new file)
@@ -0,0 +1,4 @@
+# MediaPipe Model Maker Internal Library
+
+This directory contains the model maker library for internal users and
+experimental purposes.

							
								
								
									
mediapipe/model_maker/python/internal/__init__.py (new file)
@@ -0,0 +1 @@
+"""Model maker internal library."""

							
								
								
									
mediapipe/model_maker/python/vision/core/BUILD (new file)
@@ -0,0 +1,33 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.
+
+licenses(["notice"])
+
+package(
+    default_visibility = ["//mediapipe:__subpackages__"],
+)
+
+py_library(
+    name = "image_preprocessing",
+    srcs = ["image_preprocessing.py"],
+)
+
+py_test(
+    name = "image_preprocessing_test",
+    srcs = ["image_preprocessing_test.py"],
+    deps = [":image_preprocessing"],
+)

							
								
								
									
mediapipe/model_maker/python/vision/core/__init__.py (new file)
@@ -0,0 +1,13 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

@@ -13,11 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """ImageNet preprocessing."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
 
-# Dependency imports
 import tensorflow as tf
 
 IMAGE_SIZE = 224

@@ -12,15 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Dependency imports
 import numpy as np
 import tensorflow as tf
 
-from mediapipe.model_maker.python.core.utils import image_preprocessing
+from mediapipe.model_maker.python.vision.core import image_preprocessing
 
 
 def _get_preprocessed_image(preprocessor, is_training=False):

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Placeholder for internal Python library rule.
+# Placeholder for internal Python strict library and test compatibility macro.
 # Placeholder for internal Python library rule.
 
 licenses(["notice"])
 
@@ -78,9 +78,9 @@ py_library(
         ":train_image_classifier_lib",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
-        "//mediapipe/model_maker/python/core/utils:image_preprocessing",
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
+        "//mediapipe/model_maker/python/vision/core:image_preprocessing",
     ],
 )

@@ -16,7 +16,7 @@
 import os
 import random
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
@@ -84,10 +84,10 @@ class Dataset(classification_dataset.ClassificationDataset):
         name for name in os.listdir(data_root)
         if os.path.isdir(os.path.join(data_root, name)))
     all_label_size = len(label_names)
-    label_to_index = dict(
+    index_by_label = dict(
         (name, index) for index, name in enumerate(label_names))
     all_image_labels = [
-        label_to_index[os.path.basename(os.path.dirname(path))]
+        index_by_label[os.path.basename(os.path.dirname(path))]
         for path in all_image_paths
     ]
 
@@ -106,33 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
         'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
         all_label_size, ', '.join(label_names))
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, index_to_label=label_names)
-
-  @classmethod
-  def load_tf_dataset(
-      cls, name: str
-  ) -> Tuple[Optional[classification_dataset.ClassificationDataset],
-             Optional[classification_dataset.ClassificationDataset],
-             Optional[classification_dataset.ClassificationDataset]]:
-    """Loads data from tensorflow_datasets.
-
-    Args:
-      name: the registered name of the tfds.core.DatasetBuilder. Refer to the
-        documentation of tfds.load for more details.
-
-    Returns:
-      A tuple of Datasets for the train/validation/test.
-
-    Raises:
-      ValueError: if the input tf dataset does not have train/validation/test
-      labels.
-    """
-    data, info = tfds.load(name, with_info=True)
-    if 'label' not in info.features:
-      raise ValueError('info.features need to contain \'label\' key.')
-    label_names = info.features['label'].names
-
-    train_data = _create_data('train', data, info, label_names)
-    validation_data = _create_data('validation', data, info, label_names)
-    test_data = _create_data('test', data, info, label_names)
-    return train_data, validation_data, test_data
+        dataset=image_label_ds, size=all_image_size, label_names=label_names)

@@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
 
   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
-    train_data, test_data = data.split(0.5)
+    data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
+    train_data, test_data = data.split(fraction=0.5)
 
     self.assertLen(train_data, 2)
     for i, elem in enumerate(train_data._dataset):
       self.assertTrue((elem.numpy() == np.array([i, 1])).all())
     self.assertEqual(train_data.num_classes, 2)
-    self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
+    self.assertEqual(train_data.label_names, ['pos', 'neg'])
 
     self.assertLen(test_data, 2)
     for i, elem in enumerate(test_data._dataset):
       self.assertTrue((elem.numpy() == np.array([i, 0])).all())
     self.assertEqual(test_data.num_classes, 2)
-    self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
+    self.assertEqual(test_data.label_names, ['pos', 'neg'])
 
   def test_from_folder(self):
-    data = dataset.Dataset.from_folder(self.image_path)
+    data = dataset.Dataset.from_folder(dirname=self.image_path)
 
     self.assertLen(data, 2)
     self.assertEqual(data.num_classes, 2)
-    self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
+    self.assertEqual(data.label_names, ['daisy', 'tulips'])
     for image, label in data.gen_tf_dataset():
       self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
       if label.numpy() == 0:
@@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
     self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
     self.assertLen(train_data, 1034)
     self.assertEqual(train_data.num_classes, 3)
-    self.assertEqual(train_data.index_to_label,
+    self.assertEqual(train_data.label_names,
                      ['angular_leaf_spot', 'bean_rust', 'healthy'])
 
     self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
     self.assertLen(validation_data, 133)
     self.assertEqual(validation_data.num_classes, 3)
-    self.assertEqual(validation_data.index_to_label,
+    self.assertEqual(validation_data.label_names,
                      ['angular_leaf_spot', 'bean_rust', 'healthy'])
 
     self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
     self.assertLen(test_data, 128)
     self.assertEqual(test_data.num_classes, 3)
-    self.assertEqual(test_data.index_to_label,
+    self.assertEqual(test_data.label_names,
                      ['angular_leaf_spot', 'bean_rust', 'healthy'])

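The updated assertions above exercise the renamed label_names accessor on a
folder-backed dataset. A short sketch, assuming an image directory with one
subdirectory per class; the path is hypothetical:

from mediapipe.model_maker.python.vision.image_classifier import dataset

# /tmp/flowers is assumed to contain daisy/ and tulips/ image subdirectories.
data = dataset.Dataset.from_folder(dirname='/tmp/flowers')
print(data.label_names)   # ['daisy', 'tulips']: sorted subdirectory names
print(data.num_classes)   # 2
train_data, test_data = data.split(fraction=0.9)
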
@@ -13,16 +13,16 @@
 # limitations under the License.
 """APIs to train image classifier model."""
 
-from typing import Any, List, Optional
+from typing import List, Optional
 
 import tensorflow as tf
 import tensorflow_hub as hub
 
 from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
 from mediapipe.model_maker.python.core.tasks import classifier
-from mediapipe.model_maker.python.core.utils import image_preprocessing
 from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.core.utils import quantization
+from mediapipe.model_maker.python.vision.core import image_preprocessing
 from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
 from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
 from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
@@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
 class ImageClassifier(classifier.Classifier):
   """ImageClassifier for building image classification model."""
 
-  def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
+  def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
                hparams: hp.HParams):
     """Initializes ImageClassifier class.
 
     Args:
       model_spec: Specification for the model.
-      index_to_label: A list that maps from index to label class name.
+      label_names: A list of label names for the classes.
       hparams: The hyperparameters for training image classifier.
     """
-    super(ImageClassifier, self).__init__(
+    super().__init__(
         model_spec=model_spec,
-        index_to_label=index_to_label,
+        label_names=label_names,
         shuffle=hparams.shuffle,
         full_train=hparams.do_fine_tuning)
     self._hparams = hparams
@@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
 
     spec = ms.SupportedModels.get(model_spec)
     image_classifier = cls(
-        model_spec=spec,
-        index_to_label=train_data.index_to_label,
-        hparams=hparams)
+        model_spec=spec, label_names=train_data.label_names, hparams=hparams)
 
     image_classifier._create_model()
 
@@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
   return model.fit(
       x=train_ds,
       epochs=hparams.train_epochs,
-      steps_per_epoch=hparams.steps_per_epoch,
       validation_data=validation_ds,
       callbacks=callbacks)

@@ -161,7 +161,7 @@ class Texture {
 
   ~Texture() {
     if (is_owned_) {
-      glDeleteProgram(handle_);
+      glDeleteTextures(1, &handle_);
     }
   }
 
@@ -87,6 +87,9 @@ cc_library(
 cc_library(
     name = "builtin_task_graphs",
     deps = [
+        "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
         "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
         "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
+        "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
+        "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
     ],

@@ -14,7 +14,7 @@
 
 """The public facing packet getter APIs."""
 
-from typing import List, Type
+from typing import List
 
 from google.protobuf import message
 from google.protobuf import symbol_database
@@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
 get_matrix = _packet_getter.get_matrix
 
 
-def get_proto(packet: mp_packet.Packet) -> Type[message.Message]:
+def get_proto(packet: mp_packet.Packet) -> message.Message:
   """Get the content of a MediaPipe proto Packet as a proto message.
 
   Args:

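The annotation fix above matters to type checkers: get_proto returns a proto
message instance, not a message class, so Type[message.Message] was wrong. A
sketch of the round trip, assuming the released mediapipe Python package with
its packet_creator/packet_getter pair:

import mediapipe as mp
from mediapipe.framework.formats import landmark_pb2

landmarks = landmark_pb2.NormalizedLandmarkList()
packet = mp.packet_creator.create_proto(landmarks)

# An instance comes back, matching the corrected -> message.Message annotation.
restored = mp.packet_getter.get_proto(packet)
print(type(restored).__name__)  # NormalizedLandmarkList
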
@@ -46,8 +46,10 @@ cc_library(
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/formats:tensor",
+        "//mediapipe/gpu:gpu_origin_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
         "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",

@@ -44,6 +44,30 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_test(
+    name = "classification_aggregation_calculator_test",
+    srcs = ["classification_aggregation_calculator_test.cc"],
+    deps = [
+        ":classification_aggregation_calculator",
+        ":classification_aggregation_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:output_stream_poller",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
+        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
+    ],
+)
+
 mediapipe_proto_library(
     name = "score_calibration_calculator_proto",
     srcs = ["score_calibration_calculator.proto"],

@@ -31,37 +31,62 @@
 namespace mediapipe {
 namespace api2 {
 
+using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
 using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::tasks::components::containers::proto::Classifications;
 
-// Aggregates ClassificationLists into a single ClassificationResult that has
-// 3 dimensions: (classification head, classification timestamp, classification
-// category).
+// Aggregates ClassificationLists into either a ClassificationResult object
+// representing the classification results aggregated by classifier head, or
+// into an std::vector<ClassificationResult> representing the classification
+// results aggregated first by timestamp then by classifier head.
 //
 // Inputs:
-//   CLASSIFICATIONS - ClassificationList
+//   CLASSIFICATIONS - ClassificationList @Multiple
 //     ClassificationList per classification head.
 //   TIMESTAMPS - std::vector<Timestamp> @Optional
-//     The collection of the timestamps that a single ClassificationResult
-//     should aggragate. This stream is optional, and the timestamp information
-//     will only be populated to the ClassificationResult proto when this stream
-//     is connected.
+//     The collection of the timestamps that this calculator should aggregate.
+//     This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
+//     output is used for results. Otherwise as no timestamp aggregation is
+//     required the CLASSIFICATIONS output is used for results.
 //
 // Outputs:
-//   CLASSIFICATION_RESULT - ClassificationResult
+//   CLASSIFICATIONS - ClassificationResult @Optional
+//     The classification results aggregated by head. Must be connected if the
+//     TIMESTAMPS input is not connected, as it signals that timestamp
+//     aggregation is not required.
+//   TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
+//     The classification result aggregated by timestamp, then by head. Must be
+//     connected if the TIMESTAMPS input is connected, as it signals that
+//     timestamp aggregation is required.
+//   // TODO: remove output once migration is over.
+//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
 //     The aggregated classification result.
 //
-// Example:
+// Example without timestamp aggregation:
+// node {
+//   calculator: "ClassificationAggregationCalculator"
+//   input_stream: "CLASSIFICATIONS:0:stream_a"
+//   input_stream: "CLASSIFICATIONS:1:stream_b"
+//   input_stream: "CLASSIFICATIONS:2:stream_c"
+//   output_stream: "CLASSIFICATIONS:classifications"
+//   options {
+//    [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
+//      head_names: "head_name_a"
+//      head_names: "head_name_b"
+//      head_names: "head_name_c"
+//    }
+//  }
+// }
+//
+// Example with timestamp aggregation:
 // node {
 //   calculator: "ClassificationAggregationCalculator"
 //   input_stream: "CLASSIFICATIONS:0:stream_a"
 //   input_stream: "CLASSIFICATIONS:1:stream_b"
 //   input_stream: "CLASSIFICATIONS:2:stream_c"
 //   input_stream: "TIMESTAMPS:timestamps"
-//   output_stream: "CLASSIFICATION_RESULT:classification_result"
+//   output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
 //   options {
-//    [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] {
+//    [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
 //      head_names: "head_name_a"
 //      head_names: "head_name_b"
 //      head_names: "head_name_c"
@@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
       "CLASSIFICATIONS"};
   static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
       "TIMESTAMPS"};
-  static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"};
-  MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut);
+  static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
+      "CLASSIFICATIONS"};
+  static constexpr Output<std::vector<ClassificationResult>>::Optional
+      kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
+  static constexpr Output<ClassificationResult>::Optional
+      kClassificationResultOut{"CLASSIFICATION_RESULT"};
+  MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
+                          kClassificationsOut, kTimestampedClassificationsOut,
+                          kClassificationResultOut);
 
   static absl::Status UpdateContract(CalculatorContract* cc);
   absl::Status Open(CalculatorContext* cc);
@@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
       cached_classifications_;
 
   ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
+  std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
+      CalculatorContext* cc);
+  // TODO: deprecate this function once migration is over.
+  ClassificationResult LegacyConvertToClassificationResult(
+      CalculatorContext* cc);
 };
 
 absl::Status ClassificationAggregationCalculator::UpdateContract(
@@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
         << "The size of classifications input streams should match the "
            "size of head names specified in the calculator options";
   }
+  // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
+  // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
+  // not connected. All dependent tasks must be updated to use these outputs
+  // first.
   return absl::OkStatus();
 }
 
@@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
       [](const auto& elem) -> ClassificationList { return elem.Get(); });
   cached_classifications_[cc->InputTimestamp().Value()] =
       std::move(classification_lists);
-  if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) {
-    return absl::OkStatus();
+  ClassificationResult classification_result;
+  if (time_aggregation_enabled_) {
+    if (kTimestampsIn(cc).IsEmpty()) {
+      return absl::OkStatus();
+    }
+    classification_result = LegacyConvertToClassificationResult(cc);
+    kTimestampedClassificationsOut(cc).Send(
+        ConvertToTimestampedClassificationResults(cc));
+  } else {
+    classification_result = LegacyConvertToClassificationResult(cc);
+    kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
   }
-  kOut(cc).Send(ConvertToClassificationResult(cc));
+  kClassificationResultOut(cc).Send(classification_result);
   RET_CHECK(cached_classifications_.empty());
   return absl::OkStatus();
 }
@ -136,6 +186,50 @@ ClassificationResult
ClassificationAggregationCalculator::ConvertToClassificationResult(
    CalculatorContext* cc) {
  ClassificationResult result;
  auto& classification_lists =
      cached_classifications_[cc->InputTimestamp().Value()];
  for (int i = 0; i < classification_lists.size(); ++i) {
    auto classifications = result.add_classifications();
    classifications->set_head_index(i);
    if (!head_names_.empty()) {
      classifications->set_head_name(head_names_[i]);
    }
    *classifications->mutable_classification_list() =
        std::move(classification_lists[i]);
  }
  cached_classifications_.erase(cc->InputTimestamp().Value());
  return result;
}

std::vector<ClassificationResult>
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
    CalculatorContext* cc) {
  auto timestamps = kTimestampsIn(cc).Get();
  std::vector<ClassificationResult> results;
  results.reserve(timestamps.size());
  for (const auto& timestamp : timestamps) {
    ClassificationResult result;
    result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000);
    auto& classification_lists = cached_classifications_[timestamp.Value()];
    for (int i = 0; i < classification_lists.size(); ++i) {
      auto classifications = result.add_classifications();
      classifications->set_head_index(i);
      if (!head_names_.empty()) {
        classifications->set_head_name(head_names_[i]);
      }
      *classifications->mutable_classification_list() =
          std::move(classification_lists[i]);
    }
    cached_classifications_.erase(timestamp.Value());
    results.push_back(std::move(result));
  }
  return results;
}

ClassificationResult
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
    CalculatorContext* cc) {
  ClassificationResult result;
  Timestamp first_timestamp(0);
  std::vector<Timestamp> timestamps;
  if (time_aggregation_enabled_) {

@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
      entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
                              1000);
    }
    cached_classifications_.erase(timestamp.Value());
  }
  return result;
}
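Editorial note (not part of the diff): MediaPipe Timestamp values are expressed in microseconds, which is why ConvertToTimestampedClassificationResults divides the offset from the first aggregated timestamp by 1000 to populate the millisecond-based timestamp_ms field. A minimal sketch of that arithmetic, assuming microsecond inputs as in the calculator:

// Hypothetical restatement of the timestamp_ms computation above,
// assuming MediaPipe Timestamp values are in microseconds.
#include <cstdint>
#include <vector>

int64_t RelativeMs(const std::vector<int64_t>& timestamps_us, size_t i) {
  // Offset from the first aggregated timestamp, converted to milliseconds.
  return (timestamps_us[i] - timestamps_us[0]) / 1000;
}

// RelativeMs({0, 1000}, 1) == 1, matching the `timestamp_ms: 1` expectation
// in the calculator test further below in this diff.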
@ -15,7 +15,7 @@ limitations under the License.

syntax = "proto2";

package mediapipe.tasks;
package mediapipe;

import "mediapipe/framework/calculator.proto";
@ -0,0 +1,213 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <optional>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

namespace mediapipe {
namespace {

using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::Pointwise;

constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
constexpr char kClassificationInput0Name[] = "classifications_0";
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
constexpr char kClassificationInput1Name[] = "classifications_1";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
    "timestamped_classifications";

ClassificationList MakeClassificationList(int class_index) {
  return ParseTextProtoOrDie<ClassificationList>(absl::StrFormat(
      R"pb(
        classification { index: %d }
      )pb",
      class_index));
}

class ClassificationAggregationCalculatorTest
    : public tflite_shims::testing::Test {
 protected:
  absl::StatusOr<OutputStreamPoller> BuildGraph(
      bool connect_timestamps = false) {
    Graph graph;
    auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
    calculator
        .GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
        ParseTextProtoOrDie<
            mediapipe::ClassificationAggregationCalculatorOptions>(
            R"pb(head_names: "foo" head_names: "bar")pb");
    graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
        kClassificationInput0Name) >>
        calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
    graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
        kClassificationInput1Name) >>
        calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));
    if (connect_timestamps) {
      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
          kTimestampsName) >>
          calculator.In(kTimestampsTag);
      calculator.Out(kTimestampedClassificationsTag)
              .SetName(kTimestampedClassificationsName) >>
          graph[Output<std::vector<ClassificationResult>>(
              kTimestampedClassificationsTag)];
    } else {
      calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
          graph[Output<ClassificationResult>(kClassificationsTag)];
    }

    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
    if (connect_timestamps) {
      ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                        kTimestampedClassificationsName));
      MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
      return poller;
    }
    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                      kClassificationsName));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  absl::Status Send(
      std::vector<ClassificationList> classifications, int timestamp = 0,
      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kClassificationInput0Name,
        MakePacket<ClassificationList>(classifications[0])
            .At(Timestamp(timestamp))));
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kClassificationInput1Name,
        MakePacket<ClassificationList>(classifications[1])
            .At(Timestamp(timestamp))));
    if (aggregation_timestamps.has_value()) {
      auto packet = std::make_unique<std::vector<Timestamp>>();
      for (const auto& timestamp : *aggregation_timestamps) {
        packet->emplace_back(Timestamp(timestamp));
      }
      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
    }
    return absl::OkStatus();
  }

  template <typename T>
  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());

    Packet packet;
    if (!poller.Next(&packet)) {
      return absl::InternalError("Unable to get output packet");
    }
    auto result = packet.Get<T>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return result;
  }

 private:
  CalculatorGraph calculator_graph_;
};

TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
  MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
  MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));

  EXPECT_THAT(result,
              EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
                  R"pb(classifications {
                         head_index: 0
                         head_name: "foo"
                         classification_list { classification { index: 0 } }
                       }
                       classifications {
                         head_index: 1
                         head_name: "bar"
                         classification_list { classification { index: 1 } }
                       })pb")));
}

TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
  MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
  MP_ASSERT_OK(Send(
      {MakeClassificationList(2), MakeClassificationList(3)},
      /*timestamp=*/1000,
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
  MP_ASSERT_OK_AND_ASSIGN(auto result,
                          GetResult<std::vector<ClassificationResult>>(poller));

  EXPECT_THAT(result,
              Pointwise(EqualsProto(),
                        {ParseTextProtoOrDie<ClassificationResult>(R"pb(
                           timestamp_ms: 0,
                           classifications {
                             head_index: 0
                             head_name: "foo"
                             classification_list { classification { index: 0 } }
                           }
                           classifications {
                             head_index: 1
                             head_name: "bar"
                             classification_list { classification { index: 1 } }
                           }
                         )pb"),
                         ParseTextProtoOrDie<ClassificationResult>(R"pb(
                           timestamp_ms: 1,
                           classifications {
                             head_index: 0
                             head_name: "foo"
                             classification_list { classification { index: 2 } }
                           }
                           classifications {
                             head_index: 1
                             head_name: "bar"
                             classification_list { classification { index: 3 } }
                           }
                         )pb")}));
}

}  // namespace
}  // namespace mediapipe
@ -29,3 +29,23 @@ cc_library(
        "//mediapipe/framework/formats:landmark_cc_proto",
    ],
)

cc_library(
    name = "category",
    srcs = ["category.cc"],
    hdrs = ["category.h"],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
    ],
)

cc_library(
    name = "classification_result",
    srcs = ["classification_result.cc"],
    hdrs = ["classification_result.h"],
    deps = [
        ":category",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
    ],
)
38  mediapipe/tasks/cc/components/containers/category.cc  Normal file
@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/components/containers/category.h"

#include <optional>
#include <string>

#include "mediapipe/framework/formats/classification.pb.h"

namespace mediapipe::tasks::components::containers {

Category ConvertToCategory(const mediapipe::Classification& proto) {
  Category category;
  category.index = proto.index();
  category.score = proto.score();
  if (proto.has_label()) {
    category.category_name = proto.label();
  }
  if (proto.has_display_name()) {
    category.display_name = proto.display_name();
  }
  return category;
}

}  // namespace mediapipe::tasks::components::containers
52  mediapipe/tasks/cc/components/containers/category.h  Normal file
@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_

#include <optional>
#include <string>

#include "mediapipe/framework/formats/classification.pb.h"

namespace mediapipe::tasks::components::containers {

// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
  // The index of the category in the classification model output.
  int index;
  // The score for this category, e.g. (but not necessarily) a probability in
  // [0,1].
  float score;
  // The optional ID for the category, read from the label map packed in the
  // TFLite Model Metadata if present. Not necessarily human-readable.
  std::optional<std::string> category_name = std::nullopt;
  // The optional human-readable name for the category, read from the label map
  // packed in the TFLite Model Metadata if present.
  std::optional<std::string> display_name = std::nullopt;
};

// Utility function to convert from mediapipe::Classification proto to Category
// struct.
Category ConvertToCategory(const mediapipe::Classification& proto);

}  // namespace mediapipe::tasks::components::containers

#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
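As an illustration (not part of the diff), converting a mediapipe::Classification proto into the Category struct above could look like the following sketch; the field values here are made up:

#include <iostream>

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"

int main() {
  mediapipe::Classification proto;
  proto.set_index(42);
  proto.set_score(0.9f);
  proto.set_label("cat");  // becomes Category::category_name
  // display_name is left unset, so Category::display_name stays nullopt.

  auto category =
      mediapipe::tasks::components::containers::ConvertToCategory(proto);
  std::cout << category.index << " "
            << category.category_name.value_or("<none>") << "\n";
  return 0;
}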
@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/components/containers/classification_result.h"

#include <optional>
#include <string>
#include <vector>

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

namespace mediapipe::tasks::components::containers {

Classifications ConvertToClassifications(const proto::Classifications& proto) {
  Classifications classifications;
  classifications.categories.reserve(
      proto.classification_list().classification_size());
  for (const auto& classification :
       proto.classification_list().classification()) {
    classifications.categories.push_back(ConvertToCategory(classification));
  }
  classifications.head_index = proto.head_index();
  if (proto.has_head_name()) {
    classifications.head_name = proto.head_name();
  }
  return classifications;
}

ClassificationResult ConvertToClassificationResult(
    const proto::ClassificationResult& proto) {
  ClassificationResult classification_result;
  classification_result.classifications.reserve(proto.classifications_size());
  for (const auto& classifications : proto.classifications()) {
    classification_result.classifications.push_back(
        ConvertToClassifications(classifications));
  }
  if (proto.has_timestamp_ms()) {
    classification_result.timestamp_ms = proto.timestamp_ms();
  }
  return classification_result;
}

}  // namespace mediapipe::tasks::components::containers
@ -0,0 +1,68 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

#include <optional>
#include <string>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

namespace mediapipe::tasks::components::containers {

// Defines classification results for a given classifier head.
struct Classifications {
  // The array of predicted categories, usually sorted by descending scores,
  // e.g. from high to low probability.
  std::vector<Category> categories;
  // The index of the classifier head (i.e. output tensor) these categories
  // refer to. This is useful for multi-head models.
  int head_index;
  // The optional name of the classifier head, as provided in the TFLite Model
  // Metadata [1] if present. This is useful for multi-head models.
  //
  // [1]: https://www.tensorflow.org/lite/convert/metadata
  std::optional<std::string> head_name = std::nullopt;
};

// Defines classification results of a model.
struct ClassificationResult {
  // The classification results for each head of the model.
  std::vector<Classifications> classifications;
  // The optional timestamp (in milliseconds) of the start of the chunk of data
  // corresponding to these results.
  //
  // This is only used for classification on time series (e.g. audio
  // classification). In these use cases, the amount of data to process might
  // exceed the maximum size that the model can process: to solve this, the
  // input data is split into multiple chunks starting at different timestamps.
  std::optional<int64_t> timestamp_ms = std::nullopt;
};

// Utility function to convert from Classifications proto to
// Classifications struct.
Classifications ConvertToClassifications(const proto::Classifications& proto);

// Utility function to convert from ClassificationResult proto to
// ClassificationResult struct.
ClassificationResult ConvertToClassificationResult(
    const proto::ClassificationResult& proto);

}  // namespace mediapipe::tasks::components::containers

#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
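For illustration (not part of the diff), a sketch of converting the proto into the C++ struct defined above, using ParseTextProtoOrDie as the calculator test does; the payload is made up:

#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

namespace containers = mediapipe::tasks::components::containers;

void Example() {
  auto proto =
      mediapipe::ParseTextProtoOrDie<containers::proto::ClassificationResult>(
          R"pb(
            classifications {
              head_index: 0
              head_name: "foo"
              classification_list { classification { index: 0 score: 0.8 } }
            }
            timestamp_ms: 0
          )pb");
  containers::ClassificationResult result =
      containers::ConvertToClassificationResult(proto);
  // result.classifications[0].head_name == "foo"
  // result.classifications[0].categories[0].score == 0.8f
}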
@ -28,6 +28,7 @@ mediapipe_proto_library(
    srcs = ["classifications.proto"],
    deps = [
        ":category_proto",
        "//mediapipe/framework/formats:classification_proto",
    ],
)
@ -17,9 +17,10 @@ syntax = "proto2";

package mediapipe.tasks.components.containers.proto;

option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto";

// TODO: deprecate this message once migration is over.
// A single classification result.
message Category {
  // The index of the category in the corresponding label map, usually packed in
@ -17,11 +17,13 @@ syntax = "proto2";

package mediapipe.tasks.components.containers.proto;

import "mediapipe/framework/formats/classification.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto";

option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto";

// TODO: deprecate this message once migration is over.
// List of predicted categories with an optional timestamp.
message ClassificationEntry {
  // The array of predicted categories, usually sorted by descending scores,

@ -33,9 +35,12 @@ message ClassificationEntry {
  optional int64 timestamp_ms = 2;
}

// Classifications for a given classifier head.
// Classifications for a given classifier head, i.e. for a given output tensor.
message Classifications {
  // TODO: deprecate this field once migration is over.
  repeated ClassificationEntry entries = 1;
  // The classification results for this head.
  optional mediapipe.ClassificationList classification_list = 4;
  // The index of the classifier head these categories refer to. This is useful
  // for multi-head models.
  optional int32 head_index = 2;

@ -45,7 +50,17 @@ message Classifications {
  optional string head_name = 3;
}

// Contains one set of results per classifier head.
// Classifications for a given classifier model.
message ClassificationResult {
  // The classification results for each model head, i.e. one for each output
  // tensor.
  repeated Classifications classifications = 1;
  // The optional timestamp (in milliseconds) of the start of the chunk of data
  // corresponding to these results.
  //
  // This is only used for classification on time series (e.g. audio
  // classification). In these use cases, the amount of data to process might
  // exceed the maximum size that the model can process: to solve this, the
  // input data is split into multiple chunks starting at different timestamps.
  optional int64 timestamp_ms = 2;
}
@ -17,6 +17,9 @@ syntax = "proto2";

package mediapipe.tasks.components.containers.proto;

option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "EmbeddingsProto";

// Defines a dense floating-point embedding.
message FloatEmbedding {
  repeated float values = 1 [packed = true];
@ -30,9 +30,11 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "tensorflow/lite/schema/schema_generated.h"

@ -128,12 +130,21 @@ absl::Status ConfigureImageToTensorCalculator(
    options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
                                                          std);
  }
  // TODO: need to support different GPU origin on different
  // platforms or applications.
  options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
  return absl::OkStatus();
}

}  // namespace

bool DetermineImagePreprocessingGpuBackend(
    const core::proto::Acceleration& acceleration) {
  return acceleration.has_gpu();
}

absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
                                         bool use_gpu,
                                         ImagePreprocessingOptions* options) {
  ASSIGN_OR_RETURN(auto image_tensor_specs,
                   BuildImageTensorSpecs(model_resources));

@ -141,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
      image_tensor_specs, options->mutable_image_to_tensor_options()));
  // The GPU backend isn't able to process int data. If the input tensor is
  // quantized, forces the image preprocessing graph to use CPU backend.
  if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) {
  if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
    options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
  } else {
    options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
  }
  return absl::OkStatus();
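For clarity (not part of the diff), the backend selection above boils down to a single predicate; a hypothetical restatement:

// The GPU image-to-tensor path is only taken when GPU inference is requested
// and the input tensor is not quantized (UINT8), since the GPU backend
// cannot process int data.
bool UseGpuImageToTensor(bool use_gpu, tflite::TensorType tensor_type) {
  return use_gpu && tensor_type != tflite::TensorType_UINT8;
}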
@ -19,20 +19,26 @@ limitations under the License.
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"

namespace mediapipe {
namespace tasks {
namespace components {

// Configures an ImagePreprocessing subgraph using the provided model resources.
// When use_gpu is true, the GPU is used as the backend for the image-to-tensor
// conversion.
// - Accepts CPU input images and outputs CPU tensors.
//
// Example usage:
//
//   auto& preprocessing =
//       graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
//   core::proto::Acceleration acceleration;
//   acceleration.mutable_xnnpack();
//   bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
//   MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
//       model_resources,
//       use_gpu,
//       &preprocessing.GetOptions<ImagePreprocessingOptions>()));
//
// The resulting ImagePreprocessing subgraph has the following I/O:

@ -56,9 +62,14 @@ namespace components {
//     The image that has the pixel data stored on the target storage (CPU vs
//     GPU).
absl::Status ConfigureImagePreprocessing(
    const core::ModelResources& model_resources,
    const core::ModelResources& model_resources, bool use_gpu,
    ImagePreprocessingOptions* options);

// Determines whether the image preprocessing subgraph should use GPU as the
// backend according to the given acceleration setting.
bool DetermineImagePreprocessingGpuBackend(
    const core::proto::Acceleration& acceleration);

}  // namespace components
}  // namespace tasks
}  // namespace mediapipe
@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";

// Struct holding the different output streams produced by the graph.
struct ClassificationPostprocessingOutputStreams {
  Source<ClassificationResult> classification_result;
  Source<ClassificationResult> classifications;
  Source<std::vector<ClassificationResult>> timestamped_classifications;
};

// Performs sanity checks on provided ClassifierOptions.
absl::Status SanityCheckClassifierOptions(
@ -286,7 +294,7 @@ absl::Status ConfigureScoreCalibrationIfAny(

void ConfigureClassificationAggregationCalculator(
    const ModelMetadataExtractor& metadata_extractor,
    ClassificationAggregationCalculatorOptions* options) {
    mediapipe::ClassificationAggregationCalculatorOptions* options) {
  auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr) {
    return;
@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph(
//   TENSORS - std::vector<Tensor>
//     The output tensors of an InferenceCalculator.
//   TIMESTAMPS - std::vector<Timestamp> @Optional
//     The collection of timestamps that a single ClassificationResult should
//     aggregate. This is mostly useful for classifiers working on time series,
//     e.g. audio or video classification.
//     The collection of the timestamps that this graph should aggregate.
//     This stream is optional: if provided, then the TIMESTAMPED_CLASSIFICATIONS
//     output is used for results. Otherwise, as no timestamp aggregation is
//     required, the CLASSIFICATIONS output is used for results.
//
// Outputs:
//   CLASSIFICATION_RESULT - ClassificationResult
//     The output aggregated classification results.
//   CLASSIFICATIONS - ClassificationResult @Optional
//     The classification results aggregated by head. Must be connected if the
//     TIMESTAMPS input is not connected, as it signals that timestamp
//     aggregation is not required.
//   TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
//     The classification results aggregated by timestamp, then by head. Must be
//     connected if the TIMESTAMPS input is connected, as it signals that
//     timestamp aggregation is required.
//   // TODO: remove output once migration is over.
//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
//     The aggregated classification result.
//
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
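For orientation (not part of the diff), wiring the new optional outputs from a client graph might look like the sketch below, in the style of the calculator test earlier in this diff. The subgraph registration name used here is an assumption, by analogy with the ImagePreprocessing example above:

// Hypothetical wiring sketch; the subgraph registration name is assumed.
Graph graph;
auto& postprocessing = graph.AddNode(
    "mediapipe.tasks.components.ClassificationPostprocessingGraph");
graph[Input<std::vector<Tensor>>(kTensorsTag)] >>
    postprocessing.In(kTensorsTag);
graph[Input<std::vector<Timestamp>>(kTimestampsTag)] >>
    postprocessing.In(kTimestampsTag);
// With TIMESTAMPS connected, results are read from the
// TIMESTAMPED_CLASSIFICATIONS output.
postprocessing.Out(kTimestampedClassificationsTag) >>
    graph[Output<std::vector<ClassificationResult>>(
        kTimestampedClassificationsTag)];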
@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
      mediapipe::SubgraphContext* sc) override {
    Graph graph;
    ASSIGN_OR_RETURN(
        auto classification_result_out,
        auto output_streams,
        BuildClassificationPostprocessing(
            sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
            graph[Input<std::vector<Tensor>>(kTensorsTag)],
            graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
    classification_result_out >>
    output_streams.classification_result >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    output_streams.classifications >>
        graph[Output<ClassificationResult>(kClassificationsTag)];
    output_streams.timestamped_classifications >>
        graph[Output<std::vector<ClassificationResult>>(
            kTimestampedClassificationsTag)];
    return graph.GetConfig();
  }

 private:
  // Adds an on-device classification postprocessing graph into the provided
  // builder::Graph instance. The classification postprocessing graph takes
  // tensors (std::vector<mediapipe::Tensor>) as input and returns one output
  // stream containing the output classification results (ClassificationResult).
  // tensors (std::vector<mediapipe::Tensor>) and optional timestamps
  // (std::vector<Timestamp>) as input and returns two output streams:
  //  - classification results aggregated by classifier head as a
  //    ClassificationResult proto, used when no timestamps are passed in
  //    the graph,
  //  - classification results aggregated by timestamp then by classifier head
  //    as a std::vector<ClassificationResult>, used when timestamps are passed
  //    in the graph.
  //
  // options: the on-device ClassificationPostprocessingGraphOptions.
  // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
  // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
  //   timestamps that a single ClassificationResult should aggregate.
  //   timestamps that should be used to aggregate classification results.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<Source<ClassificationResult>>
  absl::StatusOr<ClassificationPostprocessingOutputStreams>
  BuildClassificationPostprocessing(
      const proto::ClassificationPostprocessingGraphOptions& options,
      Source<std::vector<Tensor>> tensors_in,
@ -494,7 +524,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
    // Aggregates Classifications into a single ClassificationResult.
    auto& result_aggregation =
        graph.AddNode("ClassificationAggregationCalculator");
    result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>()
    result_aggregation
        .GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
        .CopyFrom(options.classification_aggregation_options());
    for (int i = 0; i < num_heads; ++i) {
      tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
@ -504,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
    timestamps_in >> result_aggregation.In(kTimestampsTag);

    // Connects output.
    return result_aggregation[Output<ClassificationResult>(
        kClassificationResultTag)];
    ClassificationPostprocessingOutputStreams output_streams{
        /*classification_result=*/result_aggregation
            [Output<ClassificationResult>(kClassificationResultTag)],
        /*classifications=*/
        result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
        /*timestamped_classifications=*/
        result_aggregation[Output<std::vector<ClassificationResult>>(
            kTimestampedClassificationsTag)]};
    return output_streams;
  }
};
Some files were not shown because too many files have changed in this diff.