Internal change
PiperOrigin-RevId: 477538515
This commit is contained in:
parent
6cdc6443b6
commit
f8af41b1eb
|
@ -157,6 +157,13 @@ http_archive(
|
||||||
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
|
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_archive(
|
||||||
|
name = "pffft",
|
||||||
|
strip_prefix = "jpommier-pffft-7c3b5a7dc510",
|
||||||
|
urls = ["https://bitbucket.org/jpommier/pffft/get/7c3b5a7dc510.zip"],
|
||||||
|
build_file = "@//third_party:pffft.BUILD",
|
||||||
|
)
|
||||||
|
|
||||||
# sentencepiece
|
# sentencepiece
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "com_google_sentencepiece",
|
name = "com_google_sentencepiece",
|
||||||
|
|
|
@ -217,7 +217,7 @@ A list of pose landmarks. Each landmark consists of the following:
|
||||||
|
|
||||||
*Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* |
|
*Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* |
|
||||||
:-----------------------------------------------------------: |
|
:-----------------------------------------------------------: |
|
||||||
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_world_landmarks.mp4" type="video/mp4"></video> |
|
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/mobile/pose_world_landmarks.mp4" type="video/mp4"></video> |
|
||||||
|
|
||||||
Another list of pose landmarks in world coordinates. Each landmark consists of
|
Another list of pose landmarks in world coordinates. Each landmark consists of
|
||||||
the following:
|
the following:
|
||||||
|
@ -238,7 +238,7 @@ for usage details.
|
||||||
|
|
||||||
*Fig 6. Example of MediaPipe Pose segmentation mask.* |
|
*Fig 6. Example of MediaPipe Pose segmentation mask.* |
|
||||||
:---------------------------------------------------: |
|
:---------------------------------------------------: |
|
||||||
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_segmentation.mp4" type="video/mp4"></video> |
|
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/mobile/pose_segmentation.mp4" type="video/mp4"></video> |
|
||||||
|
|
||||||
### Python Solution API
|
### Python Solution API
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ nav_order: 7
|
||||||
|
|
||||||
*Fig 1. Example of MediaPipe Selfie Segmentation.* |
|
*Fig 1. Example of MediaPipe Selfie Segmentation.* |
|
||||||
:------------------------------------------------: |
|
:------------------------------------------------: |
|
||||||
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/selfie_segmentation_web.mp4" type="video/mp4"></video> |
|
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/selfie_segmentation_web.mp4" type="video/mp4"></video> |
|
||||||
|
|
||||||
MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can
|
MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can
|
||||||
run in real-time on both smartphones and laptops. The intended use cases include
|
run in real-time on both smartphones and laptops. The intended use cases include
|
||||||
|
|
|
@ -1294,8 +1294,8 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":get_vector_item_calculator_cc_proto",
|
":get_vector_item_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:packet",
|
|
||||||
"//mediapipe/framework/api2:node",
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
@ -1319,6 +1319,32 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "vector_indices_calculator",
|
||||||
|
srcs = ["vector_indices_calculator.cc"],
|
||||||
|
hdrs = ["vector_indices_calculator.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "vector_indices_calculator_test",
|
||||||
|
srcs = ["vector_indices_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":vector_indices_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "vector_size_calculator",
|
name = "vector_size_calculator",
|
||||||
srcs = ["vector_size_calculator.cc"],
|
srcs = ["vector_size_calculator.cc"],
|
||||||
|
|
|
@ -40,6 +40,9 @@ REGISTER_CALCULATOR(EndLoopNormalizedLandmarkListVectorCalculator);
|
||||||
typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator;
|
typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator;
|
||||||
REGISTER_CALCULATOR(EndLoopBooleanCalculator);
|
REGISTER_CALCULATOR(EndLoopBooleanCalculator);
|
||||||
|
|
||||||
|
typedef EndLoopCalculator<std::vector<float>> EndLoopFloatCalculator;
|
||||||
|
REGISTER_CALCULATOR(EndLoopFloatCalculator);
|
||||||
|
|
||||||
typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
|
typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
|
||||||
EndLoopRenderDataCalculator;
|
EndLoopRenderDataCalculator;
|
||||||
REGISTER_CALCULATOR(EndLoopRenderDataCalculator);
|
REGISTER_CALCULATOR(EndLoopRenderDataCalculator);
|
||||||
|
|
|
@ -24,6 +24,10 @@ using GetLandmarkListVectorItemCalculator =
|
||||||
GetVectorItemCalculator<mediapipe::LandmarkList>;
|
GetVectorItemCalculator<mediapipe::LandmarkList>;
|
||||||
REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator);
|
REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator);
|
||||||
|
|
||||||
|
using GetNormalizedLandmarkListVectorItemCalculator =
|
||||||
|
GetVectorItemCalculator<mediapipe::NormalizedLandmarkList>;
|
||||||
|
REGISTER_CALCULATOR(GetNormalizedLandmarkListVectorItemCalculator);
|
||||||
|
|
||||||
using GetClassificationListVectorItemCalculator =
|
using GetClassificationListVectorItemCalculator =
|
||||||
GetVectorItemCalculator<mediapipe::ClassificationList>;
|
GetVectorItemCalculator<mediapipe::ClassificationList>;
|
||||||
REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);
|
REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h"
|
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
@ -58,7 +59,7 @@ template <typename T>
|
||||||
class GetVectorItemCalculator : public Node {
|
class GetVectorItemCalculator : public Node {
|
||||||
public:
|
public:
|
||||||
static constexpr Input<std::vector<T>> kIn{"VECTOR"};
|
static constexpr Input<std::vector<T>> kIn{"VECTOR"};
|
||||||
static constexpr Input<int>::Optional kIdx{"INDEX"};
|
static constexpr Input<OneOf<int, uint64_t>>::Optional kIdx{"INDEX"};
|
||||||
static constexpr Output<T> kOut{"ITEM"};
|
static constexpr Output<T> kOut{"ITEM"};
|
||||||
|
|
||||||
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
|
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
|
||||||
|
@ -80,7 +81,9 @@ class GetVectorItemCalculator : public Node {
|
||||||
|
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
if (kIdx(cc).IsConnected() && !kIdx(cc).IsEmpty()) {
|
if (kIdx(cc).IsConnected() && !kIdx(cc).IsEmpty()) {
|
||||||
idx = kIdx(cc).Get();
|
idx = kIdx(cc).Visit(
|
||||||
|
[](uint64_t idx_uint64_t) { return static_cast<int>(idx_uint64_t); },
|
||||||
|
[](int idx_int) { return idx_int; });
|
||||||
} else if (options.has_item_index()) {
|
} else if (options.has_item_index()) {
|
||||||
idx = options.item_index();
|
idx = options.item_index();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -227,4 +227,15 @@ TEST(TestGetIntVectorItemCalculatorTest, IndexOptionsTwoTimestamps) {
|
||||||
testing::ElementsAre(TimestampValue(1), TimestampValue(2)));
|
testing::ElementsAre(TimestampValue(1), TimestampValue(2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TestGetIntVectorItemCalculatorTest, IndexUint64) {
|
||||||
|
CalculatorRunner runner = MakeRunnerWithStream();
|
||||||
|
const std::vector<int> inputs = {1, 2, 3};
|
||||||
|
const uint64_t index = 1;
|
||||||
|
AddInputVector(runner, inputs, 1);
|
||||||
|
AddInputIndex(runner, index, 1);
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets;
|
||||||
|
EXPECT_THAT(outputs, testing::ElementsAre(IntPacket(inputs[index])));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
33
mediapipe/calculators/core/vector_indices_calculator.cc
Normal file
33
mediapipe/calculators/core/vector_indices_calculator.cc
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/core/vector_indices_calculator.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
using IntVectorIndicesCalculator = VectorIndicesCalculator<int>;
|
||||||
|
REGISTER_CALCULATOR(IntVectorIndicesCalculator);
|
||||||
|
|
||||||
|
using Uint64tVectorIndicesCalculator = VectorIndicesCalculator<uint64_t>;
|
||||||
|
REGISTER_CALCULATOR(Uint64tVectorIndicesCalculator);
|
||||||
|
|
||||||
|
using NormalizedLandmarkListVectorIndicesCalculator =
|
||||||
|
VectorIndicesCalculator<mediapipe::NormalizedLandmarkList>;
|
||||||
|
REGISTER_CALCULATOR(NormalizedLandmarkListVectorIndicesCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
65
mediapipe/calculators/core/vector_indices_calculator.h
Normal file
65
mediapipe/calculators/core/vector_indices_calculator.h
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
|
||||||
|
#define MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
// Calculator that takes a vector and consturct an index range vector based on
|
||||||
|
// the size of the input vector.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// VECTOR - std::vector<T>
|
||||||
|
// Vector whose range of indices to return.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// INDICES - std::vector<int>
|
||||||
|
// Indices vector of the input vector.
|
||||||
|
//
|
||||||
|
// Example config:
|
||||||
|
// node {
|
||||||
|
// calculator: "{SpecificType}VectorIndicesCalculator"
|
||||||
|
// input_stream: "VECTOR:vector"
|
||||||
|
// output_stream: "INDICES:indices"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
template <typename T>
|
||||||
|
class VectorIndicesCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<std::vector<T>> kVector{"VECTOR"};
|
||||||
|
static constexpr Output<std::vector<int>> kRange{"INDICES"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kVector, kRange);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
// Get the size of the input vector.
|
||||||
|
const int vector_size = kVector(cc).Get().size();
|
||||||
|
std::vector<int> out_idxs(vector_size);
|
||||||
|
std::iota(out_idxs.begin(), out_idxs.end(), 0);
|
||||||
|
kRange(cc).Send(out_idxs);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
|
87
mediapipe/calculators/core/vector_indices_calculator_test.cc
Normal file
87
mediapipe/calculators/core/vector_indices_calculator_test.cc
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/core/vector_indices_calculator.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::TestParamInfo;
|
||||||
|
using ::testing::TestWithParam;
|
||||||
|
using ::testing::Values;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void AddInputVector(CalculatorRunner& runner, const std::vector<T>& inputs,
|
||||||
|
int timestamp) {
|
||||||
|
runner.MutableInputs()->Tag("VECTOR").packets.push_back(
|
||||||
|
MakePacket<std::vector<T>>(inputs).At(Timestamp(timestamp)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct TestParams {
|
||||||
|
const std::string test_name;
|
||||||
|
const std::vector<T> inputs;
|
||||||
|
const int timestamp;
|
||||||
|
const std::vector<int> expected_indices;
|
||||||
|
};
|
||||||
|
|
||||||
|
class IntVectorIndicesCalculatorTest
|
||||||
|
: public testing::TestWithParam<TestParams<int>> {};
|
||||||
|
|
||||||
|
TEST_P(IntVectorIndicesCalculatorTest, Succeeds) {
|
||||||
|
CalculatorRunner runner = CalculatorRunner(R"(
|
||||||
|
calculator: "IntVectorIndicesCalculator"
|
||||||
|
input_stream: "VECTOR:vector_stream"
|
||||||
|
output_stream: "INDICES:indices_stream"
|
||||||
|
)");
|
||||||
|
const std::vector<int>& inputs = GetParam().inputs;
|
||||||
|
std::vector<int> expected_indices(inputs.size());
|
||||||
|
AddInputVector(runner, inputs, GetParam().timestamp);
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
const std::vector<Packet>& outputs = runner.Outputs().Tag("INDICES").packets;
|
||||||
|
EXPECT_EQ(1, outputs.size());
|
||||||
|
EXPECT_THAT(outputs[0].Get<std::vector<int>>(),
|
||||||
|
testing::ElementsAreArray(GetParam().expected_indices));
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
IntVectorIndicesCalculatorTest, IntVectorIndicesCalculatorTest,
|
||||||
|
Values(TestParams<int>{
|
||||||
|
/* test_name= */ "IntVectorIndices",
|
||||||
|
/* inputs= */ {1, 2, 3},
|
||||||
|
/* timestamp= */ 1,
|
||||||
|
/* expected_indices= */ {0, 1, 2},
|
||||||
|
},
|
||||||
|
TestParams<int>{
|
||||||
|
/* test_name= */ "EmptyVector",
|
||||||
|
/* inputs= */ {},
|
||||||
|
/* timestamp= */ 1,
|
||||||
|
/* expected_indices= */ {},
|
||||||
|
}),
|
||||||
|
[](const TestParamInfo<IntVectorIndicesCalculatorTest::ParamType>& info) {
|
||||||
|
return info.param.test_name;
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -55,6 +55,14 @@ mediapipe_proto_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "audio_to_tensor_calculator",
|
name = "audio_to_tensor_calculator",
|
||||||
srcs = ["audio_to_tensor_calculator.cc"],
|
srcs = ["audio_to_tensor_calculator.cc"],
|
||||||
|
copts = select({
|
||||||
|
# b/215212850
|
||||||
|
"//mediapipe:apple": [
|
||||||
|
"-x objective-c++",
|
||||||
|
"-fobjc-arc",
|
||||||
|
],
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
visibility = [
|
visibility = [
|
||||||
"//mediapipe/framework:mediapipe_internal",
|
"//mediapipe/framework:mediapipe_internal",
|
||||||
],
|
],
|
||||||
|
@ -67,13 +75,16 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/util:time_series_util",
|
"//mediapipe/util:time_series_util",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_audio_tools//audio/dsp:resampler_q",
|
"@com_google_audio_tools//audio/dsp:resampler_q",
|
||||||
|
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||||
"@org_tensorflow//tensorflow/lite/c:common",
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
"@pffft",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -83,6 +94,7 @@ cc_test(
|
||||||
srcs = ["audio_to_tensor_calculator_test.cc"],
|
srcs = ["audio_to_tensor_calculator_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":audio_to_tensor_calculator",
|
":audio_to_tensor_calculator",
|
||||||
|
":audio_to_tensor_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:timestamp",
|
"//mediapipe/framework:timestamp",
|
||||||
|
@ -97,6 +109,58 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "feedback_tensors_calculator_proto",
|
||||||
|
srcs = ["feedback_tensors_calculator.proto"],
|
||||||
|
visibility = [
|
||||||
|
"//mediapipe/framework:mediapipe_internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "feedback_tensors_calculator",
|
||||||
|
srcs = ["feedback_tensors_calculator.cc"],
|
||||||
|
copts = select({
|
||||||
|
# b/215212850
|
||||||
|
"//mediapipe:apple": [
|
||||||
|
"-x objective-c++",
|
||||||
|
"-fobjc-arc",
|
||||||
|
],
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
|
visibility = [
|
||||||
|
"//mediapipe/framework:mediapipe_internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":feedback_tensors_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "feedback_tensors_calculator_test",
|
||||||
|
srcs = ["feedback_tensors_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":feedback_tensors_calculator",
|
||||||
|
":feedback_tensors_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "inference_calculator_proto",
|
name = "inference_calculator_proto",
|
||||||
srcs = ["inference_calculator.proto"],
|
srcs = ["inference_calculator.proto"],
|
||||||
|
@ -346,6 +410,10 @@ cc_library(
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This target provides the InferenceCalculator and a default set of implementations tailored for the
|
||||||
|
# current build platforms. More implementations can be added as separate dependencies to a client;
|
||||||
|
# for clients that want a narrower set of implementations than the default should see the comment on
|
||||||
|
# inference_calculator_interface.
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "inference_calculator",
|
name = "inference_calculator",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
|
|
@ -12,9 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <math.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -26,6 +25,7 @@
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "audio/dsp/resampler_q.h"
|
#include "audio/dsp/resampler_q.h"
|
||||||
|
#include "audio/dsp/window_functions.h"
|
||||||
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
|
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
@ -34,19 +34,60 @@
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/util/time_series_util.h"
|
#include "mediapipe/util/time_series_util.h"
|
||||||
|
#include "pffft.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using Options = ::mediapipe::AudioToTensorCalculatorOptions;
|
||||||
|
using FlushMode = Options::FlushMode;
|
||||||
|
|
||||||
|
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
|
||||||
|
std::vector<float> hann_window(window_size);
|
||||||
|
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
|
||||||
|
if (sqrt_hann) {
|
||||||
|
absl::c_transform(hann_window, hann_window.begin(),
|
||||||
|
[](double x) { return std::sqrt(x); });
|
||||||
|
}
|
||||||
|
return hann_window;
|
||||||
|
}
|
||||||
|
|
||||||
|
// PFFFT only supports transforms for inputs of length N of the form
|
||||||
|
// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT.
|
||||||
|
bool IsValidFftSize(int size) {
|
||||||
|
if (size <= 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
constexpr int kFactors[] = {2, 3, 5};
|
||||||
|
int factorization[] = {0, 0, 0};
|
||||||
|
int n = static_cast<int>(size);
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
while (n % kFactors[i] == 0) {
|
||||||
|
n = n / kFactors[i];
|
||||||
|
++factorization[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return factorization[0] >= 5 && n == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Converts audio buffers into tensors, possibly with resampling, buffering
|
// Converts audio buffers into tensors, possibly with resampling, buffering
|
||||||
// and framing, according to specified inputs and options. All input audio
|
// and framing, according to specified inputs and options. All input audio
|
||||||
// buffers will be first resampled from the input sample rate to the target
|
// buffers will be first resampled from the input sample rate to the target
|
||||||
// sample rate if they are not equal. The resampled audio data (with the
|
// sample rate if they are not equal. The resampled audio data (with the
|
||||||
// buffered samples from the previous runs in the streaming mode) will be broken
|
// buffered samples from the previous runs in the streaming mode) will be broken
|
||||||
// into fixed-sized, possibly overlapping frames. Finally, all frames will be
|
// into fixed-sized, possibly overlapping frames. If the calculator is not asked
|
||||||
// converted to and outputted as MediaPipe Tensors. The last output tensor will
|
// to perform fft (the fft_size is not set in the calculator options), all
|
||||||
// be zero-padding if the remaining samples are insufficient.
|
// frames will be converted to and outputted as MediaPipe Tensors. The last
|
||||||
|
// output tensor will be zero-padding if the remaining samples are insufficient.
|
||||||
|
// Otherwise, when the fft_size is set and valid, the calculator will perform
|
||||||
|
// fft on the fixed-sized audio frames, the complex DFT results will be
|
||||||
|
// converted to and outputted as 2D MediaPipe float Tensors where the first
|
||||||
|
// rows are the DFT real parts and the second rows are the DFT imagery parts.
|
||||||
//
|
//
|
||||||
// This calculator assumes that the input timestamps refer to the first
|
// This calculator assumes that the input timestamps refer to the first
|
||||||
// sample in each Matrix. The output timestamps follow this same convention.
|
// sample in each Matrix. The output timestamps follow this same convention.
|
||||||
|
@ -86,11 +127,15 @@ namespace api2 {
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// TENSORS - std::vector<Tensor>
|
// TENSORS - std::vector<Tensor>
|
||||||
// Vector containing a single Tensor that represents a fix-sized audio
|
// Vector containing a single Tensor that represents a fix-sized audio
|
||||||
// frame.
|
// frame or the complex DFT results.
|
||||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||||
// Vector containing the output timestamps emitted by the current Process()
|
// Vector containing the output timestamps emitted by the current Process()
|
||||||
// invocation. In the non-streaming mode, the vector contains all of the
|
// invocation. In the non-streaming mode, the vector contains all of the
|
||||||
// output timestamps for an input audio buffer.
|
// output timestamps for an input audio buffer.
|
||||||
|
// DC_AND_NYQUIST - std::pair<float, float> @Optional.
|
||||||
|
// A pair of dc component and nyquest component. Only can be connected when
|
||||||
|
// the calculator performs fft (the fft_size is set in the calculator
|
||||||
|
// options).
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
|
@ -116,12 +161,14 @@ class AudioToTensorCalculator : public Node {
|
||||||
// such as sample rate.
|
// such as sample rate.
|
||||||
static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
|
static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
|
||||||
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||||
|
static constexpr Output<std::pair<float, float>>::Optional kDcAndNyquistOut{
|
||||||
|
"DC_AND_NYQUIST"};
|
||||||
// A vector of the output timestamps emitted by the current Process()
|
// A vector of the output timestamps emitted by the current Process()
|
||||||
// invocation. The packet timestamp is the last emitted timestamp.
|
// invocation. The packet timestamp is the last emitted timestamp.
|
||||||
static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{
|
static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{
|
||||||
"TIMESTAMPS"};
|
"TIMESTAMPS"};
|
||||||
MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut,
|
MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut,
|
||||||
kTimestampsOut);
|
kDcAndNyquistOut, kTimestampsOut);
|
||||||
|
|
||||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||||
absl::Status Open(CalculatorContext* cc);
|
absl::Status Open(CalculatorContext* cc);
|
||||||
|
@ -138,6 +185,9 @@ class AudioToTensorCalculator : public Node {
|
||||||
int frame_step_;
|
int frame_step_;
|
||||||
bool stream_mode_;
|
bool stream_mode_;
|
||||||
bool check_inconsistent_timestamps_;
|
bool check_inconsistent_timestamps_;
|
||||||
|
int padding_samples_before_;
|
||||||
|
int padding_samples_after_;
|
||||||
|
FlushMode flush_mode_;
|
||||||
Timestamp initial_timestamp_ = Timestamp::Unstarted();
|
Timestamp initial_timestamp_ = Timestamp::Unstarted();
|
||||||
int64 cumulative_input_samples_ = 0;
|
int64 cumulative_input_samples_ = 0;
|
||||||
Timestamp next_output_timestamp_ = Timestamp::Unstarted();
|
Timestamp next_output_timestamp_ = Timestamp::Unstarted();
|
||||||
|
@ -151,22 +201,33 @@ class AudioToTensorCalculator : public Node {
|
||||||
Matrix sample_buffer_;
|
Matrix sample_buffer_;
|
||||||
int processed_buffer_cols_ = 0;
|
int processed_buffer_cols_ = 0;
|
||||||
|
|
||||||
|
// The internal state of the FFT library.
|
||||||
|
PFFFT_Setup* fft_state_ = nullptr;
|
||||||
|
int fft_size_ = 0;
|
||||||
|
std::vector<float> fft_window_;
|
||||||
|
std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
|
||||||
|
// pffft requires memory to work with to avoid using the stack.
|
||||||
|
std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
|
||||||
|
std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
|
||||||
|
|
||||||
absl::Status ProcessStreamingData(CalculatorContext* cc, const Matrix& input);
|
absl::Status ProcessStreamingData(CalculatorContext* cc, const Matrix& input);
|
||||||
absl::Status ProcessNonStreamingData(CalculatorContext* cc,
|
absl::Status ProcessNonStreamingData(CalculatorContext* cc,
|
||||||
const Matrix& input);
|
const Matrix& input);
|
||||||
|
|
||||||
absl::Status SetupStreamingResampler(double input_sample_rate_);
|
absl::Status SetupStreamingResampler(double input_sample_rate_);
|
||||||
void AppendToSampleBuffer(Matrix buffer_to_append);
|
void AppendToSampleBuffer(Matrix buffer_to_append);
|
||||||
|
void AppendZerosToSampleBuffer(int num_samples);
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Tensor>> ConvertToTensor(
|
absl::StatusOr<std::vector<Tensor>> ConvertToTensor(
|
||||||
const Matrix& frame_to_convert);
|
const Matrix& block, std::vector<int> tensor_dims);
|
||||||
absl::Status OutputTensors(const Matrix& buffer, bool should_flush,
|
absl::Status OutputTensor(const Matrix& block, Timestamp timestamp,
|
||||||
|
CalculatorContext* cc);
|
||||||
|
absl::Status ProcessBuffer(const Matrix& buffer, bool should_flush,
|
||||||
CalculatorContext* cc);
|
CalculatorContext* cc);
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
|
absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
|
||||||
const auto& options =
|
const auto& options = cc->Options<Options>();
|
||||||
cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
|
|
||||||
if (!options.has_num_channels() || !options.has_num_samples() ||
|
if (!options.has_num_channels() || !options.has_num_samples() ||
|
||||||
!options.has_target_sample_rate()) {
|
!options.has_target_sample_rate()) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
|
@ -174,13 +235,21 @@ absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
|
||||||
"`num_channels`, `num_samples`, and `target_sample_rate`.");
|
"`num_channels`, `num_samples`, and `target_sample_rate`.");
|
||||||
}
|
}
|
||||||
if (options.stream_mode()) {
|
if (options.stream_mode()) {
|
||||||
// Explicitly disables tiemstamp offset to disallow the timestamp bound
|
// Explicitly disables timestamp offset to disallow the timestamp bound
|
||||||
// from the input streams to be propagated to the output streams.
|
// from the input streams to be propagated to the output streams.
|
||||||
// In the streaming mode, the output timestamp bound is based on
|
// In the streaming mode, the output timestamp bound is based on
|
||||||
// next_output_timestamp_, which can be smaller than the current input
|
// next_output_timestamp_, which can be smaller than the current input
|
||||||
// timestamps.
|
// timestamps.
|
||||||
cc->SetTimestampOffset(TimestampDiff::Unset());
|
cc->SetTimestampOffset(TimestampDiff::Unset());
|
||||||
}
|
}
|
||||||
|
if (options.padding_samples_before() < 0 ||
|
||||||
|
options.padding_samples_after() < 0) {
|
||||||
|
return absl::InvalidArgumentError("Negative zero padding unsupported");
|
||||||
|
}
|
||||||
|
if (options.flush_mode() != Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX &&
|
||||||
|
options.flush_mode() != Options::PROCEED_AS_USUAL) {
|
||||||
|
return absl::InvalidArgumentError("Unsupported flush mode");
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,6 +271,9 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
|
||||||
check_inconsistent_timestamps_ = options.check_inconsistent_timestamps();
|
check_inconsistent_timestamps_ = options.check_inconsistent_timestamps();
|
||||||
sample_buffer_.resize(num_channels_, Eigen::NoChange);
|
sample_buffer_.resize(num_channels_, Eigen::NoChange);
|
||||||
}
|
}
|
||||||
|
padding_samples_before_ = options.padding_samples_before();
|
||||||
|
padding_samples_after_ = options.padding_samples_after();
|
||||||
|
flush_mode_ = options.flush_mode();
|
||||||
|
|
||||||
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
|
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
|
||||||
!kAudioIn(cc).Header().IsEmpty())
|
!kAudioIn(cc).Header().IsEmpty())
|
||||||
|
@ -217,6 +289,25 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
|
||||||
source_sample_rate_ = input_header.sample_rate();
|
source_sample_rate_ = input_header.sample_rate();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
AppendZerosToSampleBuffer(padding_samples_before_);
|
||||||
|
if (options.has_fft_size()) {
|
||||||
|
RET_CHECK(IsValidFftSize(options.fft_size()))
|
||||||
|
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
|
||||||
|
">=0 and c >= 0 and a >= 5, the requested fft size is "
|
||||||
|
<< options.fft_size();
|
||||||
|
RET_CHECK_EQ(1, num_channels_)
|
||||||
|
<< "Currently only support applying FFT on mono channel.";
|
||||||
|
fft_size_ = options.fft_size();
|
||||||
|
fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
|
||||||
|
fft_window_ = HannWindow(fft_size_, /* sqrt_hann = */ false);
|
||||||
|
fft_input_buffer_.resize(fft_size_);
|
||||||
|
fft_workplace_.resize(fft_size_);
|
||||||
|
fft_output_.resize(fft_size_);
|
||||||
|
} else {
|
||||||
|
RET_CHECK(!kDcAndNyquistOut(cc).IsConnected())
|
||||||
|
<< "The DC_AND_NYQUIST output stream can only be connected when the "
|
||||||
|
"calculator outputs fft tensors";
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,7 +353,12 @@ absl::Status AudioToTensorCalculator::Close(CalculatorContext* cc) {
|
||||||
resampler_->Flush(&resampled_buffer);
|
resampler_->Flush(&resampled_buffer);
|
||||||
AppendToSampleBuffer(std::move(resampled_buffer));
|
AppendToSampleBuffer(std::move(resampled_buffer));
|
||||||
}
|
}
|
||||||
return OutputTensors(sample_buffer_, /*should_flush=*/true, cc);
|
AppendZerosToSampleBuffer(padding_samples_after_);
|
||||||
|
MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/true, cc));
|
||||||
|
if (fft_state_) {
|
||||||
|
pffft_destroy_setup(fft_state_);
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status AudioToTensorCalculator::ProcessStreamingData(
|
absl::Status AudioToTensorCalculator::ProcessStreamingData(
|
||||||
|
@ -303,7 +399,7 @@ absl::Status AudioToTensorCalculator::ProcessStreamingData(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(OutputTensors(sample_buffer_, /*should_flush=*/false, cc));
|
MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/false, cc));
|
||||||
// Removes the processed samples from the global sample buffer.
|
// Removes the processed samples from the global sample buffer.
|
||||||
sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() -
|
sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() -
|
||||||
processed_buffer_cols_ - 1));
|
processed_buffer_cols_ - 1));
|
||||||
|
@ -323,9 +419,9 @@ absl::Status AudioToTensorCalculator::ProcessNonStreamingData(
|
||||||
input_frame);
|
input_frame);
|
||||||
Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_,
|
Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_,
|
||||||
resampled.size() / num_channels_);
|
resampled.size() / num_channels_);
|
||||||
return OutputTensors(matrix_mapping, /*should_flush=*/true, cc);
|
return ProcessBuffer(matrix_mapping, /*should_flush=*/true, cc);
|
||||||
}
|
}
|
||||||
return OutputTensors(input_frame, /*should_flush=*/true, cc);
|
return ProcessBuffer(input_frame, /*should_flush=*/true, cc);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status AudioToTensorCalculator::SetupStreamingResampler(
|
absl::Status AudioToTensorCalculator::SetupStreamingResampler(
|
||||||
|
@ -344,6 +440,16 @@ absl::Status AudioToTensorCalculator::SetupStreamingResampler(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AudioToTensorCalculator::AppendZerosToSampleBuffer(int num_samples) {
|
||||||
|
CHECK_GE(num_samples, 0); // Ensured by `UpdateContract`.
|
||||||
|
if (num_samples == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sample_buffer_.conservativeResize(Eigen::NoChange,
|
||||||
|
sample_buffer_.cols() + num_samples);
|
||||||
|
sample_buffer_.rightCols(num_samples).setZero();
|
||||||
|
}
|
||||||
|
|
||||||
void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
|
void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
|
||||||
sample_buffer_.conservativeResize(
|
sample_buffer_.conservativeResize(
|
||||||
Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols());
|
Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols());
|
||||||
|
@ -351,49 +457,89 @@ void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor(
|
absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor(
|
||||||
const Matrix& frame_to_convert) {
|
const Matrix& block, std::vector<int> tensor_dims) {
|
||||||
Tensor tensor(Tensor::ElementType::kFloat32,
|
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape(tensor_dims));
|
||||||
Tensor::Shape({num_channels_, num_samples_}));
|
|
||||||
auto buffer_view = tensor.GetCpuWriteView();
|
auto buffer_view = tensor.GetCpuWriteView();
|
||||||
if (frame_to_convert.size() < num_channels_ * num_samples_) {
|
int total_size = 1;
|
||||||
|
for (int dim : tensor_dims) {
|
||||||
|
total_size *= dim;
|
||||||
|
}
|
||||||
|
if (block.size() < total_size) {
|
||||||
std::memset(buffer_view.buffer<float>(), 0, tensor.bytes());
|
std::memset(buffer_view.buffer<float>(), 0, tensor.bytes());
|
||||||
}
|
}
|
||||||
std::memcpy(buffer_view.buffer<float>(), frame_to_convert.data(),
|
std::memcpy(buffer_view.buffer<float>(), block.data(),
|
||||||
frame_to_convert.size() * sizeof(float));
|
block.size() * sizeof(float));
|
||||||
std::vector<Tensor> tensor_vector;
|
std::vector<Tensor> tensor_vector;
|
||||||
tensor_vector.push_back(std::move(tensor));
|
tensor_vector.push_back(std::move(tensor));
|
||||||
return tensor_vector;
|
return tensor_vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status AudioToTensorCalculator::OutputTensors(const Matrix& buffer,
|
absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block,
|
||||||
|
Timestamp timestamp,
|
||||||
|
CalculatorContext* cc) {
|
||||||
|
std::vector<Tensor> output_tensor;
|
||||||
|
if (fft_state_) {
|
||||||
|
Eigen::VectorXf time_series_data =
|
||||||
|
Eigen::VectorXf::Map(block.data(), block.size());
|
||||||
|
// Window on input audio prior to FFT.
|
||||||
|
std::transform(time_series_data.begin(), time_series_data.end(),
|
||||||
|
fft_window_.begin(), fft_input_buffer_.begin(),
|
||||||
|
std::multiplies<float>());
|
||||||
|
pffft_transform_ordered(fft_state_, fft_input_buffer_.data(),
|
||||||
|
fft_output_.data(), fft_workplace_.data(),
|
||||||
|
PFFFT_FORWARD);
|
||||||
|
if (kDcAndNyquistOut(cc).IsConnected()) {
|
||||||
|
kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]),
|
||||||
|
timestamp);
|
||||||
|
}
|
||||||
|
Matrix fft_output_matrix =
|
||||||
|
Eigen::Map<const Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2);
|
||||||
|
fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_);
|
||||||
|
// The last two elements are the DFT Nyquist values.
|
||||||
|
fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part
|
||||||
|
fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part
|
||||||
|
ASSIGN_OR_RETURN(output_tensor,
|
||||||
|
ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2}));
|
||||||
|
} else {
|
||||||
|
ASSIGN_OR_RETURN(output_tensor,
|
||||||
|
ConvertToTensor(block, {num_channels_, num_samples_}));
|
||||||
|
}
|
||||||
|
kTensorsOut(cc).Send(std::move(output_tensor), timestamp);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status AudioToTensorCalculator::ProcessBuffer(const Matrix& buffer,
|
||||||
bool should_flush,
|
bool should_flush,
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
|
const bool should_flush_at_timestamp_max =
|
||||||
|
stream_mode_ && should_flush &&
|
||||||
|
flush_mode_ == Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX;
|
||||||
int next_frame_first_col = 0;
|
int next_frame_first_col = 0;
|
||||||
std::vector<Timestamp> timestamps;
|
std::vector<Timestamp> timestamps;
|
||||||
while ((!stream_mode_ || !should_flush) &&
|
if (!should_flush_at_timestamp_max) {
|
||||||
next_frame_first_col + num_samples_ <= buffer.cols()) {
|
while (next_frame_first_col + num_samples_ <= buffer.cols()) {
|
||||||
ASSIGN_OR_RETURN(auto output_tensor, ConvertToTensor(buffer.block(
|
MP_RETURN_IF_ERROR(OutputTensor(
|
||||||
0, next_frame_first_col,
|
buffer.block(0, next_frame_first_col, num_channels_, num_samples_),
|
||||||
num_channels_, num_samples_)));
|
next_output_timestamp_, cc));
|
||||||
kTensorsOut(cc).Send(std::move(output_tensor), next_output_timestamp_);
|
|
||||||
timestamps.push_back(next_output_timestamp_);
|
timestamps.push_back(next_output_timestamp_);
|
||||||
next_output_timestamp_ += round(frame_step_ / target_sample_rate_ *
|
next_output_timestamp_ += round(frame_step_ / target_sample_rate_ *
|
||||||
Timestamp::kTimestampUnitsPerSecond);
|
Timestamp::kTimestampUnitsPerSecond);
|
||||||
next_frame_first_col += frame_step_;
|
next_frame_first_col += frame_step_;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if (should_flush && next_frame_first_col < buffer.cols()) {
|
if (should_flush && next_frame_first_col < buffer.cols()) {
|
||||||
ASSIGN_OR_RETURN(auto output_tensor,
|
|
||||||
ConvertToTensor(buffer.block(
|
|
||||||
0, next_frame_first_col, num_channels_,
|
|
||||||
std::min(num_samples_,
|
|
||||||
(int)buffer.cols() - next_frame_first_col))));
|
|
||||||
// In the streaming mode, the flush happens in Close() and a packet at
|
// In the streaming mode, the flush happens in Close() and a packet at
|
||||||
// Timestamp::Max() will be emitted. In the non-streaming mode, each
|
// Timestamp::Max() will be emitted. In the non-streaming mode, each
|
||||||
// Process() invocation will process the entire buffer completely.
|
// Process() invocation will process the entire buffer completely.
|
||||||
Timestamp timestamp =
|
Timestamp timestamp = should_flush_at_timestamp_max
|
||||||
stream_mode_ ? Timestamp::Max() : next_output_timestamp_;
|
? Timestamp::Max()
|
||||||
|
: next_output_timestamp_;
|
||||||
|
MP_RETURN_IF_ERROR(OutputTensor(
|
||||||
|
buffer.block(
|
||||||
|
0, next_frame_first_col, num_channels_,
|
||||||
|
std::min(num_samples_, (int)buffer.cols() - next_frame_first_col)),
|
||||||
|
timestamp, cc));
|
||||||
timestamps.push_back(timestamp);
|
timestamps.push_back(timestamp);
|
||||||
kTensorsOut(cc).Send(std::move(output_tensor), timestamp);
|
|
||||||
}
|
}
|
||||||
if (kTimestampsOut(cc).IsConnected()) {
|
if (kTimestampsOut(cc).IsConnected()) {
|
||||||
Timestamp timestamp = timestamps.back();
|
Timestamp timestamp = timestamps.back();
|
||||||
|
|
|
@ -44,4 +44,28 @@ message AudioToTensorCalculatorOptions {
|
||||||
// Set to false to disable checks for jitter in timestamp values. Useful with
|
// Set to false to disable checks for jitter in timestamp values. Useful with
|
||||||
// live audio input.
|
// live audio input.
|
||||||
optional bool check_inconsistent_timestamps = 6 [default = true];
|
optional bool check_inconsistent_timestamps = 6 [default = true];
|
||||||
|
|
||||||
|
// Size of the fft in number of bins. If set, the calculator outputs fft
|
||||||
|
// tensors.
|
||||||
|
optional int64 fft_size = 7;
|
||||||
|
|
||||||
|
// The amount of padding samples to add before the audio after resampling.
|
||||||
|
// Note that the timestamps shift. Currently, only zero padding is supported.
|
||||||
|
optional int64 padding_samples_before = 8;
|
||||||
|
|
||||||
|
// The amount of padding samples to add after the audio after resampling.
|
||||||
|
// Currently, only zero padding is supported.
|
||||||
|
optional int64 padding_samples_after = 9;
|
||||||
|
|
||||||
|
// Determines the "flushing" behavior in stream mode.
|
||||||
|
enum FlushMode {
|
||||||
|
// Unspecified (causes an error). Won't be used because of the default.
|
||||||
|
NONE = 0;
|
||||||
|
// Emit a packet with the entire remainder at `Timestamp::Max`.
|
||||||
|
ENTIRE_TAIL_AT_TIMESTAMP_MAX = 1;
|
||||||
|
// Continue emitting framed packets with relevant timestamps.
|
||||||
|
PROCEED_AS_USUAL = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX];
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,13 +12,13 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
#include "audio/dsp/resampler_q.h"
|
#include "audio/dsp/resampler_q.h"
|
||||||
|
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
@ -32,6 +32,14 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::Not;
|
||||||
|
using Options = ::mediapipe::AudioToTensorCalculatorOptions;
|
||||||
|
using FlushMode = Options::FlushMode;
|
||||||
|
|
||||||
|
int DivideRoundedUp(int dividend, int divisor) {
|
||||||
|
return (dividend + divisor - 1) / divisor;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples,
|
std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples,
|
||||||
int timestamp) {
|
int timestamp) {
|
||||||
auto matrix = std::make_unique<Matrix>(num_channels, num_samples);
|
auto matrix = std::make_unique<Matrix>(num_channels, num_samples);
|
||||||
|
@ -292,16 +300,17 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
|
||||||
num_iterations_ = num_iterations;
|
num_iterations_ = num_iterations;
|
||||||
}
|
}
|
||||||
|
|
||||||
int GetExpectedNumOfSamples() {
|
int GetExpectedNumOfSamples() { return output_sample_buffer_->cols(); }
|
||||||
Matrix* expected_matrix =
|
|
||||||
resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
|
|
||||||
return expected_matrix->cols();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Run(int num_samples, int num_overlapping_samples,
|
void Run(int num_samples, int num_overlapping_samples,
|
||||||
double resampling_factor) {
|
double resampling_factor, int padding_before = 0,
|
||||||
|
int padding_after = 0, bool expect_init_error = false) {
|
||||||
double input_sample_rate = 10000;
|
double input_sample_rate = 10000;
|
||||||
double target_sample_rate = input_sample_rate * resampling_factor;
|
double target_sample_rate = input_sample_rate * resampling_factor;
|
||||||
|
FlushMode flush_mode = (padding_before != 0 || padding_after != 0)
|
||||||
|
? Options::PROCEED_AS_USUAL
|
||||||
|
: Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX;
|
||||||
|
|
||||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
absl::Substitute(R"(
|
absl::Substitute(R"(
|
||||||
input_stream: "audio"
|
input_stream: "audio"
|
||||||
|
@ -319,16 +328,25 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
|
||||||
num_overlapping_samples: $1
|
num_overlapping_samples: $1
|
||||||
target_sample_rate: $2
|
target_sample_rate: $2
|
||||||
stream_mode:true
|
stream_mode:true
|
||||||
|
padding_samples_before: $3
|
||||||
|
padding_samples_after: $4
|
||||||
|
flush_mode: $5
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
/*$0=*/num_samples, /*$1=*/num_overlapping_samples,
|
/*$0=*/num_samples, /*$1=*/num_overlapping_samples,
|
||||||
/*$2=*/target_sample_rate));
|
/*$2=*/target_sample_rate, /*$3=*/padding_before,
|
||||||
|
/*$4=*/padding_after, /*$5=*/flush_mode));
|
||||||
tool::AddVectorSink("tensors", &graph_config, &tensors_packets_);
|
tool::AddVectorSink("tensors", &graph_config, &tensors_packets_);
|
||||||
|
|
||||||
// Run the graph.
|
// Run the graph.
|
||||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
const absl::Status init_status = graph_.Initialize(graph_config);
|
||||||
|
if (expect_init_error) {
|
||||||
|
EXPECT_THAT(init_status, Not(IsOk()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MP_ASSERT_OK(init_status);
|
||||||
MP_ASSERT_OK(graph_.StartRun({}));
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
for (int i = 0; i < num_iterations_; ++i) {
|
for (int i = 0; i < num_iterations_; ++i) {
|
||||||
Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i);
|
Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i);
|
||||||
|
@ -345,8 +363,18 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
MP_ASSERT_OK(graph_.CloseAllInputStreams());
|
MP_ASSERT_OK(graph_.CloseAllInputStreams());
|
||||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
if (resampling_factor != 1) {
|
if (resampling_factor == 1) {
|
||||||
resampled_buffer_ = ResampleBuffer(*sample_buffer_, resampling_factor);
|
output_sample_buffer_ = std::make_unique<Matrix>(*sample_buffer_);
|
||||||
|
} else {
|
||||||
|
output_sample_buffer_ =
|
||||||
|
ResampleBuffer(*sample_buffer_, resampling_factor);
|
||||||
|
}
|
||||||
|
if (padding_before != 0 || padding_after != 0) {
|
||||||
|
Matrix padded = Matrix::Zero(
|
||||||
|
2, padding_before + output_sample_buffer_->cols() + padding_after);
|
||||||
|
padded.block(0, padding_before, 2, output_sample_buffer_->cols()) =
|
||||||
|
*output_sample_buffer_;
|
||||||
|
output_sample_buffer_->swap(padded);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -372,14 +400,12 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
|
||||||
auto buffer = output_tensor.GetCpuReadView().buffer<float>();
|
auto buffer = output_tensor.GetCpuReadView().buffer<float>();
|
||||||
int num_values = output_tensor.shape().num_elements();
|
int num_values = output_tensor.shape().num_elements();
|
||||||
std::vector<float> output_floats(buffer, buffer + num_values);
|
std::vector<float> output_floats(buffer, buffer + num_values);
|
||||||
Matrix* expected_matrix =
|
|
||||||
resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
|
|
||||||
for (int i = 0; i < num_values; ++i) {
|
for (int i = 0; i < num_values; ++i) {
|
||||||
if (i + sample_offset >= expected_matrix->size()) {
|
if (i + sample_offset >= output_sample_buffer_->size()) {
|
||||||
EXPECT_FLOAT_EQ(output_floats[i], 0);
|
EXPECT_FLOAT_EQ(output_floats[i], 0);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_NEAR(output_floats[i],
|
EXPECT_NEAR(output_floats[i],
|
||||||
expected_matrix->coeff((i + sample_offset) % 2,
|
output_sample_buffer_->coeff((i + sample_offset) % 2,
|
||||||
(i + sample_offset) / 2),
|
(i + sample_offset) / 2),
|
||||||
0.001)
|
0.001)
|
||||||
<< "i=" << i << ", sample_offset=" << sample_offset
|
<< "i=" << i << ", sample_offset=" << sample_offset
|
||||||
|
@ -391,7 +417,8 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
|
||||||
|
|
||||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||||
// after calling WaitUntilDone().
|
// after calling WaitUntilDone().
|
||||||
void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); }
|
absl::Status TryCloseGraph() { return graph_.WaitUntilDone(); }
|
||||||
|
void CloseGraph() { MP_EXPECT_OK(TryCloseGraph()); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int input_buffer_num_samples_ = 10;
|
int input_buffer_num_samples_ = 10;
|
||||||
|
@ -399,7 +426,7 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
|
||||||
CalculatorGraph graph_;
|
CalculatorGraph graph_;
|
||||||
std::vector<Packet> tensors_packets_;
|
std::vector<Packet> tensors_packets_;
|
||||||
std::unique_ptr<Matrix> sample_buffer_;
|
std::unique_ptr<Matrix> sample_buffer_;
|
||||||
std::unique_ptr<Matrix> resampled_buffer_;
|
std::unique_ptr<Matrix> output_sample_buffer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(AudioToTensorCalculatorStreamingModeTest,
|
TEST_F(AudioToTensorCalculatorStreamingModeTest,
|
||||||
|
@ -408,7 +435,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest,
|
||||||
/*resampling_factor=*/1.0f);
|
/*resampling_factor=*/1.0f);
|
||||||
CheckTensorsOutputPackets(
|
CheckTensorsOutputPackets(
|
||||||
/*sample_offset=*/10,
|
/*sample_offset=*/10,
|
||||||
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 5),
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 5),
|
||||||
/*timestamp_interval=*/500,
|
/*timestamp_interval=*/500,
|
||||||
/*output_last_at_close=*/false);
|
/*output_last_at_close=*/false);
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
|
@ -419,7 +446,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputRemainingInCloseMethod) {
|
||||||
/*resampling_factor=*/1.0f);
|
/*resampling_factor=*/1.0f);
|
||||||
CheckTensorsOutputPackets(
|
CheckTensorsOutputPackets(
|
||||||
/*sample_offset=*/12,
|
/*sample_offset=*/12,
|
||||||
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 6),
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 6),
|
||||||
/*timestamp_interval=*/600,
|
/*timestamp_interval=*/600,
|
||||||
/*output_last_at_close=*/true);
|
/*output_last_at_close=*/true);
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
|
@ -431,7 +458,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputOverlappingFp32Tensors) {
|
||||||
/*resampling_factor=*/1.0f);
|
/*resampling_factor=*/1.0f);
|
||||||
CheckTensorsOutputPackets(
|
CheckTensorsOutputPackets(
|
||||||
/*sample_offset=*/16,
|
/*sample_offset=*/16,
|
||||||
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 8),
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 8),
|
||||||
/*timestamp_interval=*/800,
|
/*timestamp_interval=*/800,
|
||||||
/*output_last_at_close=*/true);
|
/*output_last_at_close=*/true);
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
|
@ -443,7 +470,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Downsampling) {
|
||||||
/*resampling_factor=*/0.5f);
|
/*resampling_factor=*/0.5f);
|
||||||
CheckTensorsOutputPackets(
|
CheckTensorsOutputPackets(
|
||||||
/*sample_offset=*/512,
|
/*sample_offset=*/512,
|
||||||
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256),
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256),
|
||||||
/*timestamp_interval=*/51200,
|
/*timestamp_interval=*/51200,
|
||||||
/*output_last_at_close=*/true);
|
/*output_last_at_close=*/true);
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
|
@ -455,7 +482,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, DownsamplingWithOverlapping) {
|
||||||
/*resampling_factor=*/0.5f);
|
/*resampling_factor=*/0.5f);
|
||||||
CheckTensorsOutputPackets(
|
CheckTensorsOutputPackets(
|
||||||
/*sample_offset=*/384,
|
/*sample_offset=*/384,
|
||||||
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192),
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192),
|
||||||
/*timestamp_interval=*/38400,
|
/*timestamp_interval=*/38400,
|
||||||
/*output_last_at_close=*/true);
|
/*output_last_at_close=*/true);
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
|
@ -467,7 +494,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Upsampling) {
|
||||||
/*resampling_factor=*/2.0f);
|
/*resampling_factor=*/2.0f);
|
||||||
CheckTensorsOutputPackets(
|
CheckTensorsOutputPackets(
|
||||||
/*sample_offset=*/512,
|
/*sample_offset=*/512,
|
||||||
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256),
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256),
|
||||||
/*timestamp_interval=*/12800,
|
/*timestamp_interval=*/12800,
|
||||||
/*output_last_at_close=*/true);
|
/*output_last_at_close=*/true);
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
|
@ -479,12 +506,33 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, UpsamplingWithOverlapping) {
|
||||||
/*resampling_factor=*/2.0f);
|
/*resampling_factor=*/2.0f);
|
||||||
CheckTensorsOutputPackets(
|
CheckTensorsOutputPackets(
|
||||||
/*sample_offset=*/384,
|
/*sample_offset=*/384,
|
||||||
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192),
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192),
|
||||||
/*timestamp_interval=*/9600,
|
/*timestamp_interval=*/9600,
|
||||||
/*output_last_at_close=*/true);
|
/*output_last_at_close=*/true);
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AudioToTensorCalculatorStreamingModeTest,
|
||||||
|
UpsamplingWithOverlappingAndPadding) {
|
||||||
|
SetInputBufferNumSamplesPerChannel(1024);
|
||||||
|
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
|
||||||
|
/*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/999);
|
||||||
|
CheckTensorsOutputPackets(
|
||||||
|
/*sample_offset=*/384,
|
||||||
|
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192),
|
||||||
|
/*timestamp_interval=*/9600,
|
||||||
|
/*output_last_at_close=*/false);
|
||||||
|
CloseGraph();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AudioToTensorCalculatorStreamingModeTest, NegativePaddingUnsupported) {
|
||||||
|
SetInputBufferNumSamplesPerChannel(1024);
|
||||||
|
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
|
||||||
|
/*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/-3,
|
||||||
|
/*expect_init_error=*/true);
|
||||||
|
EXPECT_THAT(TryCloseGraph(), Not(IsOk()));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(AudioToTensorCalculatorStreamingModeTest,
|
TEST_F(AudioToTensorCalculatorStreamingModeTest,
|
||||||
OnlyOutputInCloseIfNoSufficientSamples) {
|
OnlyOutputInCloseIfNoSufficientSamples) {
|
||||||
SetNumIterations(1);
|
SetNumIterations(1);
|
||||||
|
@ -498,5 +546,122 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest,
|
||||||
CloseGraph();
|
CloseGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class AudioToTensorCalculatorFftTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
// Creates an audio matrix containing a single sample of 1.0 at a specified
|
||||||
|
// offset.
|
||||||
|
std::unique_ptr<Matrix> CreateImpulseSignalData(int64 num_samples,
|
||||||
|
int impulse_offset_idx) {
|
||||||
|
Matrix impulse = Matrix::Zero(1, num_samples);
|
||||||
|
impulse(0, impulse_offset_idx) = 1.0;
|
||||||
|
return std::make_unique<Matrix>(std::move(impulse));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ConfigGraph(int num_channels, int num_samples,
|
||||||
|
int num_overlapping_samples, double sample_rate,
|
||||||
|
int fft_size) {
|
||||||
|
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
absl::Substitute(R"(
|
||||||
|
input_stream: "audio"
|
||||||
|
input_stream: "sample_rate"
|
||||||
|
output_stream: "tensors"
|
||||||
|
output_stream: "dc_and_nyquist"
|
||||||
|
node {
|
||||||
|
calculator: "AudioToTensorCalculator"
|
||||||
|
input_stream: "AUDIO:audio"
|
||||||
|
input_stream: "SAMPLE_RATE:sample_rate"
|
||||||
|
output_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
|
||||||
|
options {
|
||||||
|
[mediapipe.AudioToTensorCalculatorOptions.ext] {
|
||||||
|
num_channels: $0
|
||||||
|
num_samples: $1
|
||||||
|
num_overlapping_samples: $2
|
||||||
|
target_sample_rate: $3
|
||||||
|
fft_size: $4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)",
|
||||||
|
/*$0=*/num_channels,
|
||||||
|
/*$1=*/num_samples,
|
||||||
|
/*$2=*/num_overlapping_samples,
|
||||||
|
/*$3=*/sample_rate, /*$4=*/fft_size));
|
||||||
|
std::vector<Packet> tensors_packets;
|
||||||
|
tool::AddVectorSink("tensors", &graph_config_, &tensors_packets_);
|
||||||
|
std::vector<Packet> dc_and_nyquist_packets;
|
||||||
|
tool::AddVectorSink("dc_and_nyquist", &graph_config_,
|
||||||
|
&dc_and_nyquist_packets_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunGraph(std::unique_ptr<Matrix> input_data, double sample_rate) {
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config_));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"audio", MakePacket<Matrix>(*input_data).At(Timestamp(0))));
|
||||||
|
MP_ASSERT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
ASSERT_EQ(tensors_packets_.size(), dc_and_nyquist_packets_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||||
|
// after calling WaitUntilDone().
|
||||||
|
void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); }
|
||||||
|
|
||||||
|
std::vector<Packet> tensors_packets_;
|
||||||
|
std::vector<Packet> dc_and_nyquist_packets_;
|
||||||
|
CalculatorGraphConfig graph_config_;
|
||||||
|
CalculatorGraph graph_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(AudioToTensorCalculatorFftTest, TestInvalidFftSize) {
|
||||||
|
ConfigGraph(1, 320, 160, 16000, 103);
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config_));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
|
auto status = graph_.WaitUntilIdle();
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
::testing::HasSubstr("FFT size must be of the form"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AudioToTensorCalculatorFftTest, TestInvalidNumChannels) {
|
||||||
|
ConfigGraph(3, 320, 160, 16000, 256);
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config_));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
|
auto status = graph_.WaitUntilIdle();
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.message(),
|
||||||
|
::testing::HasSubstr("only support applying FFT on mono channel"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AudioToTensorCalculatorFftTest, TestImpulseSignal) {
|
||||||
|
constexpr double sample_rate = 16000;
|
||||||
|
ConfigGraph(1, 320, 160, sample_rate, 320);
|
||||||
|
RunGraph(CreateImpulseSignalData(320, 160), sample_rate);
|
||||||
|
for (int i = 0; i < tensors_packets_.size(); ++i) {
|
||||||
|
const auto& tensors = tensors_packets_[i].Get<std::vector<Tensor>>();
|
||||||
|
ASSERT_EQ(1, tensors.size());
|
||||||
|
const Tensor& output_tensor =
|
||||||
|
tensors_packets_[0].Get<std::vector<Tensor>>()[0];
|
||||||
|
auto* buffer = output_tensor.GetCpuReadView().buffer<float>();
|
||||||
|
int num_values = output_tensor.shape().num_elements();
|
||||||
|
const std::vector<float> output_floats(buffer, buffer + num_values);
|
||||||
|
// Impulse signal should have (approximately) const power across all
|
||||||
|
// frequency bins.
|
||||||
|
const auto& pair =
|
||||||
|
dc_and_nyquist_packets_[i].Get<std::pair<float, float>>();
|
||||||
|
EXPECT_FLOAT_EQ(pair.first, 1.0f);
|
||||||
|
EXPECT_FLOAT_EQ(pair.second, 1.0f);
|
||||||
|
for (int j = 0; j < num_values / 2; ++j) {
|
||||||
|
std::complex<float> cf(output_floats[j * 2], output_floats[j * 2 + 1]);
|
||||||
|
EXPECT_FLOAT_EQ(std::norm(cf), 1.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CloseGraph();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
165
mediapipe/calculators/tensor/feedback_tensors_calculator.cc
Normal file
165
mediapipe/calculators/tensor/feedback_tensors_calculator.cc
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
constexpr char kInputTensorsTag[] = "INPUT_TENSORS";
|
||||||
|
constexpr char kFeedbackTensorsTag[] = "FEEDBACK_TENSORS";
|
||||||
|
constexpr char kOutputTensorsTag[] = "TENSORS";
|
||||||
|
|
||||||
|
using Tensors = std::vector<Tensor>;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// FeedbackTensorsCalculator groups the input and the feedback (typically
|
||||||
|
// recurrent neural network cell state output tensors from the previous run)
|
||||||
|
// tensor vectors as the input tensor vector for the next recurrent model cell
|
||||||
|
// inference. On the first step, the feedback tensor is filled with zeros to
|
||||||
|
// jumpstart the loop.
|
||||||
|
class FeedbackTensorsCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<Tensors> kFeedbackTensorsIn{kFeedbackTensorsTag};
|
||||||
|
static constexpr Input<Tensors> kInputTensorsIn{kInputTensorsTag};
|
||||||
|
static constexpr Output<Tensors> kTensorsOut{kOutputTensorsTag};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kFeedbackTensorsIn, kInputTensorsIn, kTensorsOut);
|
||||||
|
|
||||||
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
|
cc->SetProcessTimestampBounds(true);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
const auto& options =
|
||||||
|
cc->Options<mediapipe::FeedbackTensorsCalculatorOptions>();
|
||||||
|
|
||||||
|
const auto& shape_dims = options.feedback_tensor_shape().dims();
|
||||||
|
feedback_tensor_shape_.dims.assign(shape_dims.begin(), shape_dims.end());
|
||||||
|
feedback_tensor_size_ = feedback_tensor_shape_.num_elements();
|
||||||
|
|
||||||
|
num_feedback_tensors_ = options.num_feedback_tensors();
|
||||||
|
|
||||||
|
feedback_tensors_location_ = options.location();
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
if (feedback_tensors_location_ ==
|
||||||
|
mediapipe::FeedbackTensorsCalculatorOptions::NONE) {
|
||||||
|
kTensorsOut(cc).Send(kInputTensorsIn(cc).packet().As<Tensors>());
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
switch (feedback_tensors_location_) {
|
||||||
|
case mediapipe::FeedbackTensorsCalculatorOptions::PREPENDED:
|
||||||
|
MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs));
|
||||||
|
MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs));
|
||||||
|
break;
|
||||||
|
case mediapipe::FeedbackTensorsCalculatorOptions::APPENDED:
|
||||||
|
MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs));
|
||||||
|
MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"Unsupported feedback tensors location");
|
||||||
|
}
|
||||||
|
kTensorsOut(cc).Send(std::move(outputs));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
absl::Status AddInputTensors(CalculatorContext* cc,
|
||||||
|
std::vector<Tensor>& outputs) {
|
||||||
|
absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> input_tensors =
|
||||||
|
cc->Inputs()
|
||||||
|
.Tag(kInputTensorsTag)
|
||||||
|
.Value()
|
||||||
|
.Consume<std::vector<Tensor>>();
|
||||||
|
if (!input_tensors.ok()) {
|
||||||
|
return absl::InternalError("The input tensors packet is not consumable");
|
||||||
|
}
|
||||||
|
RET_CHECK(*input_tensors);
|
||||||
|
std::vector<Tensor>& inputs = **input_tensors;
|
||||||
|
outputs.insert(outputs.end(), std::make_move_iterator(inputs.begin()),
|
||||||
|
std::make_move_iterator(inputs.end()));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status AddFeedbackTensors(CalculatorContext* cc,
|
||||||
|
std::vector<Tensor>& outputs) {
|
||||||
|
if (first_run_) {
|
||||||
|
for (int index = 0; index < num_feedback_tensors_; ++index) {
|
||||||
|
Tensor initial_feedback_tensor(Tensor::ElementType::kFloat32,
|
||||||
|
feedback_tensor_shape_);
|
||||||
|
float* data = initial_feedback_tensor.GetCpuWriteView().buffer<float>();
|
||||||
|
std::fill_n(data, feedback_tensor_size_, 0.0f);
|
||||||
|
outputs.push_back(std::move(initial_feedback_tensor));
|
||||||
|
}
|
||||||
|
first_run_ = false;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_feedback_tensors_ != kFeedbackTensorsIn(cc)->size()) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"The number of tensors fed back differs from the configuration");
|
||||||
|
}
|
||||||
|
absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> feedback_tensors =
|
||||||
|
cc->Inputs()
|
||||||
|
.Tag(kFeedbackTensorsTag)
|
||||||
|
.Value()
|
||||||
|
.Consume<std::vector<Tensor>>();
|
||||||
|
if (!feedback_tensors.ok()) {
|
||||||
|
return absl::InternalError(
|
||||||
|
"The feedback tensors packet is not consumable");
|
||||||
|
}
|
||||||
|
RET_CHECK(*feedback_tensors);
|
||||||
|
std::vector<Tensor>& feedbacks = **feedback_tensors;
|
||||||
|
for (const auto& feedback : feedbacks) {
|
||||||
|
if (feedback.shape().dims != feedback_tensor_shape_.dims) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"The shape of a tensor fed back differs from the configuration");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
outputs.insert(outputs.end(), std::make_move_iterator(feedbacks.begin()),
|
||||||
|
std::make_move_iterator(feedbacks.end()));
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::Shape feedback_tensor_shape_;
|
||||||
|
int num_feedback_tensors_ = 0;
|
||||||
|
mediapipe::FeedbackTensorsCalculatorOptions::FeedbackTensorsLocation
|
||||||
|
feedback_tensors_location_;
|
||||||
|
|
||||||
|
int feedback_tensor_size_ = 0;
|
||||||
|
bool first_run_ = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(FeedbackTensorsCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,47 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message FeedbackTensorsCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional FeedbackTensorsCalculatorOptions ext = 474496252;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Represents the dimensions of a tensor starting from the outermost size.
|
||||||
|
message TensorShape {
|
||||||
|
repeated int32 dims = 1 [packed = true];
|
||||||
|
}
|
||||||
|
|
||||||
|
// The shape of the feedback tensors to add.
|
||||||
|
optional TensorShape feedback_tensor_shape = 1;
|
||||||
|
// The number of the feedback tensors to add.
|
||||||
|
optional int32 num_feedback_tensors = 2 [default = 1];
|
||||||
|
|
||||||
|
enum FeedbackTensorsLocation {
|
||||||
|
// The feedback tensors will not be added.
|
||||||
|
NONE = 0;
|
||||||
|
// The feedback tensors will be added before the input tensors.
|
||||||
|
PREPENDED = 1;
|
||||||
|
// The feedback tensors will be added after the input tensors.
|
||||||
|
APPENDED = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determines the location of the feedback tensor(s) in the output vector.
|
||||||
|
optional FeedbackTensorsLocation location = 3 [default = APPENDED];
|
||||||
|
}
|
389
mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc
Normal file
389
mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc
Normal file
|
@ -0,0 +1,389 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::CalculatorGraphConfig;
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
using ::testing::Not;
|
||||||
|
using Tensors = std::vector<Tensor>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct TensorElementType {
|
||||||
|
static constexpr Tensor::ElementType value = Tensor::ElementType::kNone;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TensorElementType<float> {
|
||||||
|
static constexpr Tensor::ElementType value = Tensor::ElementType::kFloat32;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TensorElementType<std::int8_t> {
|
||||||
|
static constexpr Tensor::ElementType value = Tensor::ElementType::kInt8;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TensorElementType<std::uint8_t> {
|
||||||
|
static constexpr Tensor::ElementType value = Tensor::ElementType::kUInt8;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TensorElementType<std::int32_t> {
|
||||||
|
static constexpr Tensor::ElementType value = Tensor::ElementType::kInt32;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Tensor MakeTensor(std::initializer_list<int> shape,
|
||||||
|
std::initializer_list<T> values) {
|
||||||
|
Tensor tensor(TensorElementType<T>::value, shape);
|
||||||
|
CHECK_EQ(values.size(), tensor.shape().num_elements())
|
||||||
|
<< "The size of `values` is incompatible with `shape`";
|
||||||
|
absl::c_copy(values, tensor.GetCpuWriteView().buffer<T>());
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ValidateTensor(const Tensor& tensor,
|
||||||
|
const std::vector<int>& expected_shape,
|
||||||
|
const std::vector<T>& expected_values) {
|
||||||
|
ASSERT_EQ(tensor.element_type(), TensorElementType<T>::value);
|
||||||
|
EXPECT_EQ(tensor.shape().dims, expected_shape);
|
||||||
|
EXPECT_EQ(tensor.shape().num_elements(), expected_values.size());
|
||||||
|
|
||||||
|
auto* tensor_buffer = tensor.GetCpuReadView().buffer<T>();
|
||||||
|
const std::vector<T> tensor_values(
|
||||||
|
tensor_buffer, tensor_buffer + tensor.shape().num_elements());
|
||||||
|
EXPECT_THAT(tensor_values, ElementsAreArray(expected_values));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FeedbackTensorsCalculatorTest, AppendsFeedback) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
input_stream: "feedback"
|
||||||
|
node {
|
||||||
|
calculator: "FeedbackTensorsCalculator"
|
||||||
|
input_stream: "INPUT_TENSORS:input"
|
||||||
|
input_stream: "FEEDBACK_TENSORS:feedback"
|
||||||
|
output_stream: "TENSORS:output"
|
||||||
|
options: {
|
||||||
|
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
|
||||||
|
feedback_tensor_shape: { dims: 2 dims: 3 }
|
||||||
|
location: APPENDED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
auto initial_input_tensors = std::make_unique<Tensors>();
|
||||||
|
initial_input_tensors->push_back(
|
||||||
|
MakeTensor<std::int32_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
|
||||||
|
// At the beginning, the loopback packet with the model feedback is missing.
|
||||||
|
// The calculator has to assume it's all-zero with the shape from the options.
|
||||||
|
|
||||||
|
auto later_input_tensors = std::make_unique<Tensors>();
|
||||||
|
later_input_tensors->push_back(
|
||||||
|
MakeTensor<std::int32_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
|
||||||
|
auto later_feedback_tensors = std::make_unique<Tensors>();
|
||||||
|
later_feedback_tensors->push_back(
|
||||||
|
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams())
|
||||||
|
<< "Couldn't close the graph inputs";
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run";
|
||||||
|
|
||||||
|
ASSERT_EQ(output_packets.size(), 2);
|
||||||
|
|
||||||
|
const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>();
|
||||||
|
ASSERT_EQ(initial_combined_tensors.size(), 2);
|
||||||
|
ValidateTensor<std::int32_t>(initial_combined_tensors[0],
|
||||||
|
/*expected_shape=*/{2, 4},
|
||||||
|
/*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8});
|
||||||
|
// The initial feedback is zero.
|
||||||
|
ValidateTensor<float>(initial_combined_tensors[1], /*expected_shape=*/{2, 3},
|
||||||
|
/*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
|
const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>();
|
||||||
|
ASSERT_EQ(later_combined_tensors.size(), 2);
|
||||||
|
ValidateTensor<std::int32_t>(later_combined_tensors[0],
|
||||||
|
/*expected_shape=*/{2, 4},
|
||||||
|
/*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1});
|
||||||
|
// Afterwards, the provided feedback is passed through.
|
||||||
|
ValidateTensor<float>(
|
||||||
|
later_combined_tensors[1], /*expected_shape=*/{2, 3},
|
||||||
|
/*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FeedbackTensorsCalculatorTest, PrependsFeedback) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
input_stream: "feedback"
|
||||||
|
node {
|
||||||
|
calculator: "FeedbackTensorsCalculator"
|
||||||
|
input_stream: "INPUT_TENSORS:input"
|
||||||
|
input_stream: "FEEDBACK_TENSORS:feedback"
|
||||||
|
output_stream: "TENSORS:output"
|
||||||
|
options: {
|
||||||
|
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
|
||||||
|
feedback_tensor_shape: { dims: 3 dims: 2 }
|
||||||
|
location: PREPENDED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
auto initial_input_tensors = std::make_unique<Tensors>();
|
||||||
|
initial_input_tensors->push_back(
|
||||||
|
MakeTensor<std::int8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
|
||||||
|
// At the beginning, the loopback packet with the model feedback is missing.
|
||||||
|
// The calculator has to assume it's all-zero with the shape from the options.
|
||||||
|
|
||||||
|
auto later_input_tensors = std::make_unique<Tensors>();
|
||||||
|
later_input_tensors->push_back(
|
||||||
|
MakeTensor<std::int8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
|
||||||
|
auto later_feedback_tensors = std::make_unique<Tensors>();
|
||||||
|
later_feedback_tensors->push_back(
|
||||||
|
MakeTensor({3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams())
|
||||||
|
<< "Couldn't close the graph inputs";
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run";
|
||||||
|
|
||||||
|
ASSERT_EQ(output_packets.size(), 2);
|
||||||
|
|
||||||
|
const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>();
|
||||||
|
ASSERT_EQ(initial_combined_tensors.size(), 2);
|
||||||
|
// The initial feedback is zero.
|
||||||
|
ValidateTensor<float>(initial_combined_tensors[0], /*expected_shape=*/{3, 2},
|
||||||
|
/*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
ValidateTensor<std::int8_t>(initial_combined_tensors[1],
|
||||||
|
/*expected_shape=*/{2, 4},
|
||||||
|
/*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8});
|
||||||
|
|
||||||
|
const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>();
|
||||||
|
ASSERT_EQ(later_combined_tensors.size(), 2);
|
||||||
|
// Afterwards, the provided feedback is passed through.
|
||||||
|
ValidateTensor<float>(
|
||||||
|
later_combined_tensors[0], /*expected_shape=*/{3, 2},
|
||||||
|
/*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f});
|
||||||
|
ValidateTensor<std::int8_t>(later_combined_tensors[1],
|
||||||
|
/*expected_shape=*/{2, 4},
|
||||||
|
/*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FeedbackTensorsCalculatorTest, NoFeedback) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
input_stream: "feedback"
|
||||||
|
node {
|
||||||
|
calculator: "FeedbackTensorsCalculator"
|
||||||
|
input_stream: "INPUT_TENSORS:input"
|
||||||
|
input_stream: "FEEDBACK_TENSORS:feedback"
|
||||||
|
output_stream: "TENSORS:output"
|
||||||
|
options: {
|
||||||
|
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
|
||||||
|
feedback_tensor_shape: { dims: 3 dims: 4 }
|
||||||
|
location: NONE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
auto initial_input_tensors = std::make_unique<Tensors>();
|
||||||
|
initial_input_tensors->push_back(
|
||||||
|
MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
|
||||||
|
// At the beginning, the loopback packet with the model feedback is missing.
|
||||||
|
|
||||||
|
auto later_input_tensors = std::make_unique<Tensors>();
|
||||||
|
later_input_tensors->push_back(
|
||||||
|
MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
|
||||||
|
// This feedback should be ignored due to `location: NONE`.
|
||||||
|
auto later_feedback_tensors = std::make_unique<Tensors>();
|
||||||
|
later_feedback_tensors->push_back(
|
||||||
|
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams())
|
||||||
|
<< "Couldn't close the graph inputs";
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run";
|
||||||
|
|
||||||
|
ASSERT_EQ(output_packets.size(), 2);
|
||||||
|
|
||||||
|
const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>();
|
||||||
|
ASSERT_EQ(initial_combined_tensors.size(), 1);
|
||||||
|
ValidateTensor<std::uint8_t>(initial_combined_tensors[0],
|
||||||
|
/*expected_shape=*/{2, 4},
|
||||||
|
/*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8});
|
||||||
|
// No feedback due to `location: NONE`.
|
||||||
|
|
||||||
|
const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>();
|
||||||
|
ASSERT_EQ(later_combined_tensors.size(), 1);
|
||||||
|
ValidateTensor<std::uint8_t>(later_combined_tensors[0],
|
||||||
|
/*expected_shape=*/{2, 4},
|
||||||
|
/*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FeedbackTensorsCalculatorTest, ChecksTensorNumber) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
input_stream: "feedback"
|
||||||
|
node {
|
||||||
|
calculator: "FeedbackTensorsCalculator"
|
||||||
|
input_stream: "INPUT_TENSORS:input"
|
||||||
|
input_stream: "FEEDBACK_TENSORS:feedback"
|
||||||
|
output_stream: "TENSORS:output"
|
||||||
|
options: {
|
||||||
|
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
|
||||||
|
num_feedback_tensors: 2
|
||||||
|
feedback_tensor_shape: { dims: 2 dims: 3 }
|
||||||
|
location: PREPENDED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
auto initial_input_tensors = std::make_unique<Tensors>();
|
||||||
|
initial_input_tensors->push_back(
|
||||||
|
MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
|
||||||
|
// At the beginning, the loopback packet with the model feedback is missing.
|
||||||
|
|
||||||
|
auto later_input_tensors = std::make_unique<Tensors>();
|
||||||
|
later_input_tensors->push_back(
|
||||||
|
MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
|
||||||
|
// This feedback should be ignored due to `location: NONE`.
|
||||||
|
auto later_feedback_tensors = std::make_unique<Tensors>();
|
||||||
|
later_feedback_tensors->push_back(
|
||||||
|
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams())
|
||||||
|
<< "Couldn't close the graph inputs";
|
||||||
|
EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk()))
|
||||||
|
<< "Tensor number mismatch missed";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FeedbackTensorsCalculatorTest, ChecksShape) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
input_stream: "feedback"
|
||||||
|
node {
|
||||||
|
calculator: "FeedbackTensorsCalculator"
|
||||||
|
input_stream: "INPUT_TENSORS:input"
|
||||||
|
input_stream: "FEEDBACK_TENSORS:feedback"
|
||||||
|
output_stream: "TENSORS:output"
|
||||||
|
options: {
|
||||||
|
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
|
||||||
|
feedback_tensor_shape: { dims: 3 dims: 4 }
|
||||||
|
location: APPENDED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
auto initial_input_tensors = std::make_unique<Tensors>();
|
||||||
|
initial_input_tensors->push_back(
|
||||||
|
MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
|
||||||
|
// At the beginning, the loopback packet with the model feedback is missing.
|
||||||
|
|
||||||
|
auto later_input_tensors = std::make_unique<Tensors>();
|
||||||
|
later_input_tensors->push_back(
|
||||||
|
MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
|
||||||
|
// This feedback should be ignored due to `location: NONE`.
|
||||||
|
auto later_feedback_tensors = std::make_unique<Tensors>();
|
||||||
|
later_feedback_tensors->push_back(
|
||||||
|
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams())
|
||||||
|
<< "Couldn't close the graph inputs";
|
||||||
|
EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk()))
|
||||||
|
<< "Tensor shape mismatch missed";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -231,7 +231,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
std::vector<tensorflow::DeviceAttributes> devices;
|
std::vector<tensorflow::DeviceAttributes> devices;
|
||||||
ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK());
|
ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus());
|
||||||
EXPECT_THAT(devices.size(), 10);
|
EXPECT_THAT(devices.size(), 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -220,7 +220,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
std::vector<tensorflow::DeviceAttributes> devices;
|
std::vector<tensorflow::DeviceAttributes> devices;
|
||||||
ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK());
|
ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus());
|
||||||
EXPECT_THAT(devices.size(), 10);
|
EXPECT_THAT(devices.size(), 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -135,6 +135,7 @@ filegroup(
|
||||||
srcs = [
|
srcs = [
|
||||||
"testdata/anchor_golden_file_0.txt",
|
"testdata/anchor_golden_file_0.txt",
|
||||||
"testdata/anchor_golden_file_1.txt",
|
"testdata/anchor_golden_file_1.txt",
|
||||||
|
"testdata/anchor_golden_file_2.txt",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
|
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
|
||||||
|
@ -24,6 +25,19 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
struct MultiScaleAnchorInfo {
|
||||||
|
int32 level;
|
||||||
|
std::vector<float> aspect_ratios;
|
||||||
|
std::vector<float> scales;
|
||||||
|
std::pair<float, float> base_anchor_size;
|
||||||
|
std::pair<float, float> anchor_stride;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FeatureMapDim {
|
||||||
|
int height;
|
||||||
|
int width;
|
||||||
|
};
|
||||||
|
|
||||||
float CalculateScale(float min_scale, float max_scale, int stride_index,
|
float CalculateScale(float min_scale, float max_scale, int stride_index,
|
||||||
int num_strides) {
|
int num_strides) {
|
||||||
if (num_strides == 1) {
|
if (num_strides == 1) {
|
||||||
|
@ -34,6 +48,71 @@ float CalculateScale(float min_scale, float max_scale, int stride_index,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GetNumLayers(const SsdAnchorsCalculatorOptions& options) {
|
||||||
|
if (options.multiscale_anchor_generation()) {
|
||||||
|
return (options.max_level() - options.min_level() + 1);
|
||||||
|
}
|
||||||
|
return options.num_layers();
|
||||||
|
}
|
||||||
|
|
||||||
|
FeatureMapDim GetFeatureMapDimensions(
|
||||||
|
const SsdAnchorsCalculatorOptions& options, int index) {
|
||||||
|
FeatureMapDim feature_map_dims;
|
||||||
|
if (options.feature_map_height_size()) {
|
||||||
|
feature_map_dims.height = options.feature_map_height(index);
|
||||||
|
feature_map_dims.width = options.feature_map_width(index);
|
||||||
|
} else {
|
||||||
|
const int stride = options.strides(index);
|
||||||
|
feature_map_dims.height =
|
||||||
|
std::ceil(1.0f * options.input_size_height() / stride);
|
||||||
|
feature_map_dims.width =
|
||||||
|
std::ceil(1.0f * options.input_size_width() / stride);
|
||||||
|
}
|
||||||
|
return feature_map_dims;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Although we have stride for both x and y, only one value is used for offset
|
||||||
|
// calculation. See
|
||||||
|
// tensorflow_models/object_detection/anchor_generators/multiscale_grid_anchor_generator.py;l=121
|
||||||
|
std::pair<float, float> GetMultiScaleAnchorOffset(
|
||||||
|
const SsdAnchorsCalculatorOptions& options, const float stride,
|
||||||
|
const int level) {
|
||||||
|
std::pair<float, float> result(0., 0.);
|
||||||
|
int denominator = std::pow(2, level);
|
||||||
|
if (options.input_size_height() % denominator == 0 ||
|
||||||
|
options.input_size_height() == 1) {
|
||||||
|
result.first = stride / 2.0;
|
||||||
|
}
|
||||||
|
if (options.input_size_width() % denominator == 0 ||
|
||||||
|
options.input_size_width() == 1) {
|
||||||
|
result.second = stride / 2.0;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NormalizeAnchor(const int input_height, const int input_width,
|
||||||
|
Anchor* anchor) {
|
||||||
|
anchor->set_h(anchor->h() / (float)input_height);
|
||||||
|
anchor->set_w(anchor->w() / (float)input_width);
|
||||||
|
anchor->set_y_center(anchor->y_center() / (float)input_height);
|
||||||
|
anchor->set_x_center(anchor->x_center() / (float)input_width);
|
||||||
|
}
|
||||||
|
|
||||||
|
Anchor CalculateAnchorBox(const int y_center, const int x_center,
|
||||||
|
const float scale, const float aspect_ratio,
|
||||||
|
const std::pair<float, float> base_anchor_size,
|
||||||
|
// y-height first
|
||||||
|
const std::pair<float, float> anchor_stride,
|
||||||
|
const std::pair<float, float> anchor_offset) {
|
||||||
|
Anchor result;
|
||||||
|
float ratio_sqrt = std::sqrt(aspect_ratio);
|
||||||
|
result.set_h(scale * base_anchor_size.first / ratio_sqrt);
|
||||||
|
result.set_w(scale * ratio_sqrt * base_anchor_size.second);
|
||||||
|
result.set_y_center(y_center * anchor_stride.first + anchor_offset.first);
|
||||||
|
result.set_x_center(x_center * anchor_stride.second + anchor_offset.second);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Generate anchors for SSD object detection model.
|
// Generate anchors for SSD object detection model.
|
||||||
|
@ -95,9 +174,77 @@ class SsdAnchorsCalculator : public CalculatorBase {
|
||||||
private:
|
private:
|
||||||
static absl::Status GenerateAnchors(
|
static absl::Status GenerateAnchors(
|
||||||
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options);
|
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options);
|
||||||
|
|
||||||
|
static absl::Status GenerateMultiScaleAnchors(
|
||||||
|
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options);
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(SsdAnchorsCalculator);
|
REGISTER_CALCULATOR(SsdAnchorsCalculator);
|
||||||
|
|
||||||
|
// Generates grid anchors on the fly corresponding to multiple CNN layers as
|
||||||
|
// described in:
|
||||||
|
// "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002)
|
||||||
|
// T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
|
||||||
|
absl::Status SsdAnchorsCalculator::GenerateMultiScaleAnchors(
|
||||||
|
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) {
|
||||||
|
std::vector<MultiScaleAnchorInfo> anchor_infos;
|
||||||
|
for (int i = options.min_level(); i <= options.max_level(); ++i) {
|
||||||
|
MultiScaleAnchorInfo current_anchor_info;
|
||||||
|
// level
|
||||||
|
current_anchor_info.level = i;
|
||||||
|
// aspect_ratios
|
||||||
|
for (const float aspect_ratio : options.aspect_ratios()) {
|
||||||
|
current_anchor_info.aspect_ratios.push_back(aspect_ratio);
|
||||||
|
}
|
||||||
|
|
||||||
|
// scale
|
||||||
|
for (int i = 0; i < options.scales_per_octave(); ++i) {
|
||||||
|
current_anchor_info.scales.push_back(
|
||||||
|
std::pow(2.0, (double)i / (double)options.scales_per_octave()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// anchor stride
|
||||||
|
float anchor_stride = std::pow(2.0, i);
|
||||||
|
current_anchor_info.anchor_stride =
|
||||||
|
std::make_pair(anchor_stride, anchor_stride);
|
||||||
|
|
||||||
|
// base_anchor_size
|
||||||
|
current_anchor_info.base_anchor_size =
|
||||||
|
std::make_pair(anchor_stride * options.anchor_scale(),
|
||||||
|
anchor_stride * options.anchor_scale());
|
||||||
|
anchor_infos.push_back(current_anchor_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned int i = 0; i < anchor_infos.size(); ++i) {
|
||||||
|
FeatureMapDim dimensions = GetFeatureMapDimensions(options, i);
|
||||||
|
for (int y = 0; y < dimensions.height; ++y) {
|
||||||
|
for (int x = 0; x < dimensions.width; ++x) {
|
||||||
|
// loop over combination of scale and aspect ratio
|
||||||
|
for (unsigned int j = 0; j < anchor_infos[i].aspect_ratios.size();
|
||||||
|
++j) {
|
||||||
|
for (unsigned int k = 0; k < anchor_infos[i].scales.size(); ++k) {
|
||||||
|
Anchor anchor = CalculateAnchorBox(
|
||||||
|
/*y_center=*/y, /*x_center=*/x, anchor_infos[i].scales[k],
|
||||||
|
anchor_infos[i].aspect_ratios[j],
|
||||||
|
anchor_infos[i].base_anchor_size,
|
||||||
|
/*anchor_stride=*/anchor_infos[i].anchor_stride,
|
||||||
|
/*anchor_offset=*/
|
||||||
|
GetMultiScaleAnchorOffset(options,
|
||||||
|
anchor_infos[i].anchor_stride.first,
|
||||||
|
anchor_infos[i].level));
|
||||||
|
if (options.normalize_coordinates()) {
|
||||||
|
NormalizeAnchor(options.input_size_height(),
|
||||||
|
options.input_size_width(), &anchor);
|
||||||
|
}
|
||||||
|
anchors->push_back(anchor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status SsdAnchorsCalculator::GenerateAnchors(
|
absl::Status SsdAnchorsCalculator::GenerateAnchors(
|
||||||
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) {
|
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) {
|
||||||
// Verify the options.
|
// Verify the options.
|
||||||
|
@ -106,15 +253,21 @@ absl::Status SsdAnchorsCalculator::GenerateAnchors(
|
||||||
"Both feature map shape and strides are missing. Must provide either "
|
"Both feature map shape and strides are missing. Must provide either "
|
||||||
"one.");
|
"one.");
|
||||||
}
|
}
|
||||||
|
const int kNumLayers = GetNumLayers(options);
|
||||||
|
|
||||||
if (options.feature_map_height_size()) {
|
if (options.feature_map_height_size()) {
|
||||||
if (options.strides_size()) {
|
if (options.strides_size()) {
|
||||||
LOG(ERROR) << "Found feature map shapes. Strides will be ignored.";
|
LOG(ERROR) << "Found feature map shapes. Strides will be ignored.";
|
||||||
}
|
}
|
||||||
CHECK_EQ(options.feature_map_height_size(), options.num_layers());
|
CHECK_EQ(options.feature_map_height_size(), kNumLayers);
|
||||||
CHECK_EQ(options.feature_map_height_size(),
|
CHECK_EQ(options.feature_map_height_size(),
|
||||||
options.feature_map_width_size());
|
options.feature_map_width_size());
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(options.strides_size(), options.num_layers());
|
CHECK_EQ(options.strides_size(), kNumLayers);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.multiscale_anchor_generation()) {
|
||||||
|
return GenerateMultiScaleAnchors(anchors, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
int layer_id = 0;
|
int layer_id = 0;
|
||||||
|
|
|
@ -60,4 +60,30 @@ message SsdAnchorsCalculatorOptions {
|
||||||
// This option can be used when the predicted anchor width and height are in
|
// This option can be used when the predicted anchor width and height are in
|
||||||
// pixels.
|
// pixels.
|
||||||
optional bool fixed_anchor_size = 14 [default = false];
|
optional bool fixed_anchor_size = 14 [default = false];
|
||||||
|
|
||||||
|
// Generates grid anchors on the fly corresponding to multiple CNN layers as
|
||||||
|
// described in:
|
||||||
|
// "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002)
|
||||||
|
// T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
|
||||||
|
optional bool multiscale_anchor_generation = 15 [default = false];
|
||||||
|
|
||||||
|
// minimum level in feature pyramid
|
||||||
|
// for multiscale_anchor_generation only!
|
||||||
|
optional int32 min_level = 16 [default = 3];
|
||||||
|
|
||||||
|
// maximum level in feature pyramid
|
||||||
|
// for multiscale_anchor_generation only!
|
||||||
|
optional int32 max_level = 17 [default = 7];
|
||||||
|
|
||||||
|
// Scale of anchor to feature stride
|
||||||
|
// for multiscale_anchor_generation only!
|
||||||
|
optional float anchor_scale = 18 [default = 4.0];
|
||||||
|
|
||||||
|
// Number of intermediate scale each scale octave
|
||||||
|
// for multiscale_anchor_generation only!
|
||||||
|
optional int32 scales_per_octave = 19 [default = 2];
|
||||||
|
|
||||||
|
// Whether to produce anchors in normalized coordinates.
|
||||||
|
// for multiscale_anchor_generation only!
|
||||||
|
optional bool normalize_coordinates = 20 [default = true];
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,9 +33,6 @@ std::string GetGoldenFilePath(const std::string& filename) {
|
||||||
|
|
||||||
void ParseAnchorsFromText(const std::string& text,
|
void ParseAnchorsFromText(const std::string& text,
|
||||||
std::vector<Anchor>* anchors) {
|
std::vector<Anchor>* anchors) {
|
||||||
const std::string line_delimiter = "\n";
|
|
||||||
const std::string number_delimiter = ",";
|
|
||||||
|
|
||||||
std::istringstream stream(text);
|
std::istringstream stream(text);
|
||||||
std::string line;
|
std::string line;
|
||||||
while (std::getline(stream, line)) {
|
while (std::getline(stream, line)) {
|
||||||
|
@ -64,6 +61,8 @@ void CompareAnchors(const std::vector<Anchor>& anchors_0,
|
||||||
testing::FloatNear(anchor_1.x_center(), 1e-5));
|
testing::FloatNear(anchor_1.x_center(), 1e-5));
|
||||||
EXPECT_THAT(anchor_0.y_center(),
|
EXPECT_THAT(anchor_0.y_center(),
|
||||||
testing::FloatNear(anchor_1.y_center(), 1e-5));
|
testing::FloatNear(anchor_1.y_center(), 1e-5));
|
||||||
|
EXPECT_THAT(anchor_0.h(), testing::FloatNear(anchor_1.h(), 1e-5));
|
||||||
|
EXPECT_THAT(anchor_0.w(), testing::FloatNear(anchor_1.w(), 1e-5));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,4 +147,40 @@ TEST(SsdAnchorCalculatorTest, MobileSSDConfig) {
|
||||||
CompareAnchors(anchors, anchors_golden);
|
CompareAnchors(anchors, anchors_golden);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SsdAnchorCalculatorTest, RetinaNetSSDConfig) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "SsdAnchorsCalculator"
|
||||||
|
output_side_packet: "anchors"
|
||||||
|
options {
|
||||||
|
[mediapipe.SsdAnchorsCalculatorOptions.ext] {
|
||||||
|
input_size_height: 640
|
||||||
|
input_size_width: 640
|
||||||
|
strides: 64
|
||||||
|
strides: 128
|
||||||
|
aspect_ratios: 1.0
|
||||||
|
aspect_ratios: 2.0
|
||||||
|
aspect_ratios: 0.5
|
||||||
|
multiscale_anchor_generation: true
|
||||||
|
min_level: 6
|
||||||
|
max_level: 7
|
||||||
|
anchor_scale: 3.0
|
||||||
|
scales_per_octave: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
const auto& anchors =
|
||||||
|
runner.OutputSidePackets().Index(0).Get<std::vector<Anchor>>();
|
||||||
|
|
||||||
|
std::string anchors_string;
|
||||||
|
MP_EXPECT_OK(mediapipe::file::GetContents(
|
||||||
|
GetGoldenFilePath("anchor_golden_file_2.txt"), &anchors_string));
|
||||||
|
|
||||||
|
std::vector<Anchor> anchors_golden;
|
||||||
|
ParseAnchorsFromText(anchors_string, &anchors_golden);
|
||||||
|
|
||||||
|
CompareAnchors(anchors, anchors_golden);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
1125
mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt
vendored
Normal file
1125
mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -19,6 +19,7 @@
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "tensorflow/lite/allocation.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -32,6 +33,8 @@ namespace mediapipe {
|
||||||
// it to the graph as input side packet or you can use some of
|
// it to the graph as input side packet or you can use some of
|
||||||
// calculators like LocalFileContentsCalculator to get model
|
// calculators like LocalFileContentsCalculator to get model
|
||||||
// blob and use it as input here.
|
// blob and use it as input here.
|
||||||
|
// MODEL_FD - Tflite model file descriptor std::tuple<int, size_t, size_t>
|
||||||
|
// containing (fd, offset, size).
|
||||||
//
|
//
|
||||||
// Output side packets:
|
// Output side packets:
|
||||||
// MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
|
// MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
|
||||||
|
@ -52,17 +55,42 @@ class TfLiteModelCalculator : public CalculatorBase {
|
||||||
std::function<void(tflite::FlatBufferModel*)>>;
|
std::function<void(tflite::FlatBufferModel*)>>;
|
||||||
|
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
|
if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
|
||||||
cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
|
cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
|
||||||
|
cc->InputSidePackets()
|
||||||
|
.Tag("MODEL_FD")
|
||||||
|
.Set<std::tuple<int, size_t, size_t>>();
|
||||||
|
}
|
||||||
|
|
||||||
cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
|
cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
|
Packet model_packet;
|
||||||
|
std::unique_ptr<tflite::FlatBufferModel> model;
|
||||||
|
|
||||||
|
if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
|
||||||
|
model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
|
||||||
const std::string& model_blob = model_packet.Get<std::string>();
|
const std::string& model_blob = model_packet.Get<std::string>();
|
||||||
std::unique_ptr<tflite::FlatBufferModel> model =
|
model = tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
|
||||||
tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
|
|
||||||
model_blob.size());
|
model_blob.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
|
||||||
|
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
|
||||||
|
const auto& model_fd =
|
||||||
|
model_packet.Get<std::tuple<int, size_t, size_t>>();
|
||||||
|
auto model_allocation = std::make_unique<tflite::MMAPAllocation>(
|
||||||
|
std::get<0>(model_fd), std::get<1>(model_fd), std::get<2>(model_fd),
|
||||||
|
tflite::DefaultErrorReporter());
|
||||||
|
model = tflite::FlatBufferModel::BuildFromAllocation(
|
||||||
|
std::move(model_allocation), tflite::DefaultErrorReporter());
|
||||||
|
}
|
||||||
|
|
||||||
RET_CHECK(model) << "Failed to load TfLite model from blob.";
|
RET_CHECK(model) << "Failed to load TfLite model from blob.";
|
||||||
|
|
||||||
cc->OutputSidePackets().Tag("MODEL").Set(
|
cc->OutputSidePackets().Tag("MODEL").Set(
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 9.3 KiB |
|
@ -90,6 +90,7 @@
|
||||||
{
|
{
|
||||||
"idiom" : "ipad",
|
"idiom" : "ipad",
|
||||||
"size" : "83.5x83.5",
|
"size" : "83.5x83.5",
|
||||||
|
"filename" : "83.5_c_Ipad_2x.png",
|
||||||
"scale" : "2x"
|
"scale" : "2x"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -21,7 +21,9 @@ cc_library(
|
||||||
":port",
|
":port",
|
||||||
"//mediapipe/framework:calculator_base",
|
"//mediapipe/framework:calculator_base",
|
||||||
"//mediapipe/framework:calculator_contract",
|
"//mediapipe/framework:calculator_contract",
|
||||||
|
"@com_google_absl//absl/container:btree",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,9 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "absl/container/btree_map.h"
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/framework/api2/const_str.h"
|
#include "mediapipe/framework/api2/const_str.h"
|
||||||
#include "mediapipe/framework/api2/contract.h"
|
#include "mediapipe/framework/api2/contract.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
@ -46,7 +48,7 @@ struct TagIndexLocation {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class TagIndexMap {
|
class TagIndexMap {
|
||||||
public:
|
public:
|
||||||
std::vector<std::unique_ptr<T>>& operator[](const std::string& tag) {
|
std::vector<std::unique_ptr<T>>& operator[](absl::string_view tag) {
|
||||||
return map_[tag];
|
return map_[tag];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +74,7 @@ class TagIndexMap {
|
||||||
|
|
||||||
// Note: entries are held by a unique_ptr to ensure pointers remain valid.
|
// Note: entries are held by a unique_ptr to ensure pointers remain valid.
|
||||||
// Should use absl::flat_hash_map but ordering keys for now.
|
// Should use absl::flat_hash_map but ordering keys for now.
|
||||||
std::map<std::string, std::vector<std::unique_ptr<T>>> map_;
|
absl::btree_map<std::string, std::vector<std::unique_ptr<T>>> map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Graph;
|
class Graph;
|
||||||
|
@ -169,6 +171,16 @@ class SourceImpl {
|
||||||
return AddTarget(dest);
|
return AddTarget(dest);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct AllowCast
|
||||||
|
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
|
||||||
|
!std::is_same_v<T, U>> {};
|
||||||
|
|
||||||
|
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
|
||||||
|
SourceImpl<IsSide, U> Cast() {
|
||||||
|
return SourceImpl<IsSide, U>(base_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Never null.
|
// Never null.
|
||||||
SourceBase* base_;
|
SourceBase* base_;
|
||||||
|
@ -212,19 +224,19 @@ class NodeBase {
|
||||||
// of its entries by index. However, for nodes without visible contracts we
|
// of its entries by index. However, for nodes without visible contracts we
|
||||||
// can't know whether a tag is indexable or not, so we would need the
|
// can't know whether a tag is indexable or not, so we would need the
|
||||||
// multi-port to also be usable as a port directly (representing index 0).
|
// multi-port to also be usable as a port directly (representing index 0).
|
||||||
MultiSource<> Out(const std::string& tag) {
|
MultiSource<> Out(absl::string_view tag) {
|
||||||
return MultiSource<>(&out_streams_[tag]);
|
return MultiSource<>(&out_streams_[tag]);
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiDestination<> In(const std::string& tag) {
|
MultiDestination<> In(absl::string_view tag) {
|
||||||
return MultiDestination<>(&in_streams_[tag]);
|
return MultiDestination<>(&in_streams_[tag]);
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiSideSource<> SideOut(const std::string& tag) {
|
MultiSideSource<> SideOut(absl::string_view tag) {
|
||||||
return MultiSideSource<>(&out_sides_[tag]);
|
return MultiSideSource<>(&out_sides_[tag]);
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiSideDestination<> SideIn(const std::string& tag) {
|
MultiSideDestination<> SideIn(absl::string_view tag) {
|
||||||
return MultiSideDestination<>(&in_sides_[tag]);
|
return MultiSideDestination<>(&in_sides_[tag]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -359,11 +371,11 @@ class PacketGenerator {
|
||||||
public:
|
public:
|
||||||
PacketGenerator(std::string type) : type_(std::move(type)) {}
|
PacketGenerator(std::string type) : type_(std::move(type)) {}
|
||||||
|
|
||||||
MultiSideSource<> SideOut(const std::string& tag) {
|
MultiSideSource<> SideOut(absl::string_view tag) {
|
||||||
return MultiSideSource<>(&out_sides_[tag]);
|
return MultiSideSource<>(&out_sides_[tag]);
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiSideDestination<> SideIn(const std::string& tag) {
|
MultiSideDestination<> SideIn(absl::string_view tag) {
|
||||||
return MultiSideDestination<>(&in_sides_[tag]);
|
return MultiSideDestination<>(&in_sides_[tag]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -452,19 +464,19 @@ class Graph {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Graph ports, non-typed.
|
// Graph ports, non-typed.
|
||||||
MultiSource<> In(const std::string& graph_input) {
|
MultiSource<> In(absl::string_view graph_input) {
|
||||||
return graph_boundary_.Out(graph_input);
|
return graph_boundary_.Out(graph_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiDestination<> Out(const std::string& graph_output) {
|
MultiDestination<> Out(absl::string_view graph_output) {
|
||||||
return graph_boundary_.In(graph_output);
|
return graph_boundary_.In(graph_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiSideSource<> SideIn(const std::string& graph_input) {
|
MultiSideSource<> SideIn(absl::string_view graph_input) {
|
||||||
return graph_boundary_.SideOut(graph_input);
|
return graph_boundary_.SideOut(graph_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiSideDestination<> SideOut(const std::string& graph_output) {
|
MultiSideDestination<> SideOut(absl::string_view graph_output) {
|
||||||
return graph_boundary_.SideIn(graph_output);
|
return graph_boundary_.SideIn(graph_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
@ -296,6 +297,32 @@ TEST(BuilderTest, EmptyTag) {
|
||||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(BuilderTest, StringLikeTags) {
|
||||||
|
const char kA[] = "A";
|
||||||
|
const std::string kB = "B";
|
||||||
|
constexpr absl::string_view kC = "C";
|
||||||
|
|
||||||
|
builder::Graph graph;
|
||||||
|
auto& foo = graph.AddNode("Foo");
|
||||||
|
graph.In(kA).SetName("a") >> foo.In(kA);
|
||||||
|
graph.In(kB).SetName("b") >> foo.In(kB);
|
||||||
|
foo.Out(kC).SetName("c") >> graph.Out(kC);
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "A:a"
|
||||||
|
input_stream: "B:b"
|
||||||
|
output_stream: "C:c"
|
||||||
|
node {
|
||||||
|
calculator: "Foo"
|
||||||
|
input_stream: "A:a"
|
||||||
|
input_stream: "B:b"
|
||||||
|
output_stream: "C:c"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(BuilderTest, GraphIndexes) {
|
TEST(BuilderTest, GraphIndexes) {
|
||||||
builder::Graph graph;
|
builder::Graph graph;
|
||||||
auto& foo = graph.AddNode("Foo");
|
auto& foo = graph.AddNode("Foo");
|
||||||
|
@ -326,52 +353,63 @@ TEST(BuilderTest, GraphIndexes) {
|
||||||
|
|
||||||
class AnyAndSameTypeCalculator : public NodeIntf {
|
class AnyAndSameTypeCalculator : public NodeIntf {
|
||||||
public:
|
public:
|
||||||
static constexpr Input<AnyType> kAnyTypeInput{"INPUT"};
|
static constexpr Input<AnyType>::Optional kAnyTypeInput{"INPUT"};
|
||||||
static constexpr Output<AnyType> kAnyTypeOutput{"ANY_OUTPUT"};
|
static constexpr Output<AnyType>::Optional kAnyTypeOutput{"ANY_OUTPUT"};
|
||||||
static constexpr Output<SameType<kAnyTypeInput>> kSameTypeOutput{
|
static constexpr Output<SameType<kAnyTypeInput>>::Optional kSameTypeOutput{
|
||||||
"SAME_OUTPUT"};
|
"SAME_OUTPUT"};
|
||||||
|
static constexpr Output<SameType<kSameTypeOutput>> kRecursiveSameTypeOutput{
|
||||||
|
"RECURSIVE_SAME_OUTPUT"};
|
||||||
|
|
||||||
static constexpr Input<int> kIntInput{"INT_INPUT"};
|
static constexpr Input<int>::Optional kIntInput{"INT_INPUT"};
|
||||||
// `SameType` usage for this output is only for testing purposes.
|
// `SameType` usage for this output is only for testing purposes.
|
||||||
//
|
//
|
||||||
// `SameType` is designed to work with inputs of `AnyType` and, normally, you
|
// `SameType` is designed to work with inputs of `AnyType` and, normally, you
|
||||||
// would not use `Output<SameType<kIntInput>>` in a real calculator. You
|
// would not use `Output<SameType<kIntInput>>` in a real calculator. You
|
||||||
// should write `Output<int>` instead, since the type is known.
|
// should write `Output<int>` instead, since the type is known.
|
||||||
static constexpr Output<SameType<kIntInput>> kSameIntOutput{
|
static constexpr Output<SameType<kIntInput>>::Optional kSameIntOutput{
|
||||||
"SAME_INT_OUTPUT"};
|
"SAME_INT_OUTPUT"};
|
||||||
|
static constexpr Output<SameType<kSameIntOutput>> kRecursiveSameIntOutput{
|
||||||
|
"RECURSIVE_SAME_INT_OUTPUT"};
|
||||||
|
|
||||||
MEDIAPIPE_NODE_INTERFACE(AnyTypeCalculator, kAnyTypeInput, kAnyTypeOutput,
|
MEDIAPIPE_NODE_INTERFACE(AnyAndSameTypeCalculator, kAnyTypeInput,
|
||||||
kSameTypeOutput);
|
kAnyTypeOutput, kSameTypeOutput);
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
|
TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
|
||||||
builder::Graph graph;
|
builder::Graph graph;
|
||||||
builder::Source<internal::Generic> any_input =
|
builder::Source<AnyType> any_input = graph[Input<AnyType>{"GRAPH_ANY_INPUT"}];
|
||||||
graph[Input<AnyType>{"GRAPH_ANY_INPUT"}];
|
|
||||||
builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}];
|
builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}];
|
||||||
|
|
||||||
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
|
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
|
||||||
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
|
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
|
||||||
int_input >> node[AnyAndSameTypeCalculator::kIntInput];
|
int_input >> node[AnyAndSameTypeCalculator::kIntInput];
|
||||||
|
|
||||||
builder::Source<internal::Generic> any_type_output =
|
builder::Source<AnyType> any_type_output =
|
||||||
node[AnyAndSameTypeCalculator::kAnyTypeOutput];
|
node[AnyAndSameTypeCalculator::kAnyTypeOutput];
|
||||||
any_type_output.SetName("any_type_output");
|
any_type_output.SetName("any_type_output");
|
||||||
|
|
||||||
builder::Source<internal::Generic> same_type_output =
|
builder::Source<AnyType> same_type_output =
|
||||||
node[AnyAndSameTypeCalculator::kSameTypeOutput];
|
node[AnyAndSameTypeCalculator::kSameTypeOutput];
|
||||||
same_type_output.SetName("same_type_output");
|
same_type_output.SetName("same_type_output");
|
||||||
builder::Source<internal::Generic> same_int_output =
|
builder::Source<AnyType> recursive_same_type_output =
|
||||||
|
node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput];
|
||||||
|
recursive_same_type_output.SetName("recursive_same_type_output");
|
||||||
|
builder::Source<int> same_int_output =
|
||||||
node[AnyAndSameTypeCalculator::kSameIntOutput];
|
node[AnyAndSameTypeCalculator::kSameIntOutput];
|
||||||
same_int_output.SetName("same_int_output");
|
same_int_output.SetName("same_int_output");
|
||||||
|
builder::Source<int> recursive_same_int_type_output =
|
||||||
|
node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput];
|
||||||
|
recursive_same_int_type_output.SetName("recursive_same_int_type_output");
|
||||||
|
|
||||||
CalculatorGraphConfig expected =
|
CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie<
|
||||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
CalculatorGraphConfig>(R"pb(
|
||||||
node {
|
node {
|
||||||
calculator: "AnyAndSameTypeCalculator"
|
calculator: "AnyAndSameTypeCalculator"
|
||||||
input_stream: "INPUT:__stream_0"
|
input_stream: "INPUT:__stream_0"
|
||||||
input_stream: "INT_INPUT:__stream_1"
|
input_stream: "INT_INPUT:__stream_1"
|
||||||
output_stream: "ANY_OUTPUT:any_type_output"
|
output_stream: "ANY_OUTPUT:any_type_output"
|
||||||
|
output_stream: "RECURSIVE_SAME_INT_OUTPUT:recursive_same_int_type_output"
|
||||||
|
output_stream: "RECURSIVE_SAME_OUTPUT:recursive_same_type_output"
|
||||||
output_stream: "SAME_INT_OUTPUT:same_int_output"
|
output_stream: "SAME_INT_OUTPUT:same_int_output"
|
||||||
output_stream: "SAME_OUTPUT:same_type_output"
|
output_stream: "SAME_OUTPUT:same_type_output"
|
||||||
}
|
}
|
||||||
|
@ -381,6 +419,29 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
|
||||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(BuilderTest, AnyTypeCanBeCast) {
|
||||||
|
builder::Graph graph;
|
||||||
|
builder::Source<std::string> any_input =
|
||||||
|
graph.In("GRAPH_ANY_INPUT").Cast<std::string>();
|
||||||
|
|
||||||
|
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
|
||||||
|
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
|
||||||
|
builder::Source<double> any_type_output =
|
||||||
|
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
|
||||||
|
any_type_output.SetName("any_type_output");
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "AnyAndSameTypeCalculator"
|
||||||
|
input_stream: "INPUT:__stream_0"
|
||||||
|
output_stream: "ANY_OUTPUT:any_type_output"
|
||||||
|
}
|
||||||
|
input_stream: "GRAPH_ANY_INPUT:__stream_0"
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -27,9 +27,7 @@ using HolderBase = mediapipe::packet_internal::HolderBase;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class Packet;
|
class Packet;
|
||||||
|
|
||||||
struct DynamicType {};
|
struct AnyType {
|
||||||
|
|
||||||
struct AnyType : public DynamicType {
|
|
||||||
AnyType() = delete;
|
AnyType() = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -73,14 +73,12 @@ class SideOutputBase : public PortBase {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NoneType {
|
struct NoneType {
|
||||||
private:
|
|
||||||
NoneType() = delete;
|
NoneType() = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <auto& P>
|
template <auto& kP>
|
||||||
class SameType : public DynamicType {
|
struct SameType {
|
||||||
public:
|
static constexpr const decltype(kP)& kPort = kP;
|
||||||
static constexpr const decltype(P)& kPort = P;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class PacketTypeAccess;
|
class PacketTypeAccess;
|
||||||
|
@ -137,21 +135,28 @@ struct IsOneOf : std::false_type {};
|
||||||
template <class... T>
|
template <class... T>
|
||||||
struct IsOneOf<OneOf<T...>> : std::true_type {};
|
struct IsOneOf<OneOf<T...>> : std::true_type {};
|
||||||
|
|
||||||
template <typename T, typename std::enable_if<
|
template <class T>
|
||||||
!std::is_base_of<DynamicType, T>{} && !IsOneOf<T>{},
|
struct IsSameType : std::false_type {};
|
||||||
|
|
||||||
|
template <class P, P& kP>
|
||||||
|
struct IsSameType<SameType<kP>> : std::true_type {};
|
||||||
|
|
||||||
|
template <typename T,
|
||||||
|
typename std::enable_if<!std::is_same<T, AnyType>{} &&
|
||||||
|
!IsOneOf<T>{} && !IsSameType<T>{},
|
||||||
int>::type = 0>
|
int>::type = 0>
|
||||||
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
||||||
pt.Set<T>();
|
pt.Set<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename std::enable_if<std::is_base_of<DynamicType, T>{},
|
template <typename T, typename std::enable_if<IsSameType<T>{}, int>::type = 0>
|
||||||
int>::type = 0>
|
|
||||||
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
||||||
pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag()));
|
pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag()));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <typename T,
|
||||||
inline void SetType<AnyType>(CalculatorContract* cc, PacketType& pt) {
|
typename std::enable_if<std::is_same<T, AnyType>{}, int>::type = 0>
|
||||||
|
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
||||||
pt.SetAny();
|
pt.SetAny();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,15 +294,15 @@ struct SideBase<InputBase> {
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
|
// TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
|
||||||
template <typename T, class = void>
|
template <typename T, typename = void>
|
||||||
struct ActualPayloadType {
|
struct ActualPayloadType {
|
||||||
using type = T;
|
using type = T;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct ActualPayloadType<
|
struct ActualPayloadType<T, std::enable_if_t<IsSameType<T>{}, void>> {
|
||||||
T, std::enable_if_t<std::is_base_of<DynamicType, T>{}, void>> {
|
using type = typename ActualPayloadType<
|
||||||
using type = internal::Generic;
|
typename std::decay_t<decltype(T::kPort)>::value_t>::type;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2019 The MediaPipe Authors.
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
@ -85,14 +85,8 @@ cv::Mat MatView(const ImageFrame* image) {
|
||||||
const size_t steps[] = {static_cast<size_t>(image->WidthStep()),
|
const size_t steps[] = {static_cast<size_t>(image->WidthStep()),
|
||||||
static_cast<size_t>(image->ByteDepth())};
|
static_cast<size_t>(image->ByteDepth())};
|
||||||
// Use ImageFrame to initialize in-place. ImageFrame still owns memory.
|
// Use ImageFrame to initialize in-place. ImageFrame still owns memory.
|
||||||
if (steps[0] == sizes[1] * image->NumberOfChannels() * image->ByteDepth()) {
|
return cv::Mat(dims, sizes, type, const_cast<uint8_t*>(image->PixelData()),
|
||||||
// Contiguous memory optimization. See b/78570764
|
|
||||||
return cv::Mat(dims, sizes, type, const_cast<uint8*>(image->PixelData()));
|
|
||||||
} else {
|
|
||||||
// Custom width step.
|
|
||||||
return cv::Mat(dims, sizes, type, const_cast<uint8*>(image->PixelData()),
|
|
||||||
steps);
|
steps);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace formats
|
} // namespace formats
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2019 The MediaPipe Authors.
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2019 The MediaPipe Authors.
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
@ -21,7 +21,6 @@
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Set image_frame to a constant per-channel pix_value.
|
// Set image_frame to a constant per-channel pix_value.
|
||||||
|
@ -50,8 +49,8 @@ TEST(ImageFrameOpencvTest, ConvertToMat) {
|
||||||
ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height);
|
ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height);
|
||||||
|
|
||||||
// Check adding constant images.
|
// Check adding constant images.
|
||||||
const uint8 frame1_val = 12;
|
const uint8_t frame1_val = 12;
|
||||||
const uint8 frame2_val = 34;
|
const uint8_t frame2_val = 34;
|
||||||
SetToColor<uint8>(&frame1_val, &frame1);
|
SetToColor<uint8>(&frame1_val, &frame1);
|
||||||
SetToColor<uint8>(&frame2_val, &frame2);
|
SetToColor<uint8>(&frame2_val, &frame2);
|
||||||
// Get Mat wrapper around ImageFrame memory (zero copy).
|
// Get Mat wrapper around ImageFrame memory (zero copy).
|
||||||
|
@ -77,6 +76,37 @@ TEST(ImageFrameOpencvTest, ConvertToMat) {
|
||||||
EXPECT_EQ(max_loc.y, i_height - 6);
|
EXPECT_EQ(max_loc.y, i_height - 6);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ImageFrameOpencvTest, ConvertToIpl) {
|
||||||
|
const int i_width = 123, i_height = 45;
|
||||||
|
ImageFrame frame1(ImageFormat::GRAY8, i_width, i_height);
|
||||||
|
ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height);
|
||||||
|
|
||||||
|
// Check adding constant images.
|
||||||
|
const uint8_t frame1_val = 12;
|
||||||
|
const uint8_t frame2_val = 34;
|
||||||
|
SetToColor<uint8>(&frame1_val, &frame1);
|
||||||
|
SetToColor<uint8>(&frame2_val, &frame2);
|
||||||
|
const cv::Mat frame1_mat = formats::MatView(&frame1);
|
||||||
|
const cv::Mat frame2_mat = formats::MatView(&frame2);
|
||||||
|
const cv::Mat frame_sum = frame1_mat + frame2_mat;
|
||||||
|
const auto frame_avg = static_cast<int>(cv::mean(frame_sum).val[0]);
|
||||||
|
EXPECT_EQ(frame_avg, frame1_val + frame2_val);
|
||||||
|
|
||||||
|
// Check setting min/max pixels.
|
||||||
|
uint8* frame1_ptr = frame1.MutablePixelData();
|
||||||
|
frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1;
|
||||||
|
frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100;
|
||||||
|
double min, max;
|
||||||
|
cv::Point min_loc, max_loc;
|
||||||
|
cv::minMaxLoc(frame1_mat, &min, &max, &min_loc, &max_loc);
|
||||||
|
EXPECT_EQ(min, 1);
|
||||||
|
EXPECT_EQ(min_loc.x, i_width - 5);
|
||||||
|
EXPECT_EQ(min_loc.y, i_height - 5);
|
||||||
|
EXPECT_EQ(max, 100);
|
||||||
|
EXPECT_EQ(max_loc.x, i_width - 6);
|
||||||
|
EXPECT_EQ(max_loc.y, i_height - 6);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ImageFrameOpencvTest, ImageFormats) {
|
TEST(ImageFrameOpencvTest, ImageFormats) {
|
||||||
const int i_width = 123, i_height = 45;
|
const int i_width = 123, i_height = 45;
|
||||||
ImageFrame frame_g8(ImageFormat::GRAY8, i_width, i_height);
|
ImageFrame frame_g8(ImageFormat::GRAY8, i_width, i_height);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2019 The MediaPipe Authors.
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2019-2020 The MediaPipe Authors.
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -37,26 +37,21 @@ namespace mediapipe {
|
||||||
bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; }
|
bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; }
|
||||||
|
|
||||||
int BhwcBatchFromShape(const Tensor::Shape& shape) {
|
int BhwcBatchFromShape(const Tensor::Shape& shape) {
|
||||||
LOG_IF(FATAL, shape.dims.empty())
|
if (shape.dims.empty()) {
|
||||||
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
|
return 1;
|
||||||
|
}
|
||||||
return shape.dims[0];
|
return shape.dims[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
int BhwcHeightFromShape(const Tensor::Shape& shape) {
|
int BhwcHeightFromShape(const Tensor::Shape& shape) {
|
||||||
LOG_IF(FATAL, shape.dims.empty())
|
|
||||||
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
|
|
||||||
return shape.dims.size() < 4 ? 1 : shape.dims[shape.dims.size() - 3];
|
return shape.dims.size() < 4 ? 1 : shape.dims[shape.dims.size() - 3];
|
||||||
}
|
}
|
||||||
|
|
||||||
int BhwcWidthFromShape(const Tensor::Shape& shape) {
|
int BhwcWidthFromShape(const Tensor::Shape& shape) {
|
||||||
LOG_IF(FATAL, shape.dims.empty())
|
|
||||||
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
|
|
||||||
return shape.dims.size() < 3 ? 1 : shape.dims[shape.dims.size() - 2];
|
return shape.dims.size() < 3 ? 1 : shape.dims[shape.dims.size() - 2];
|
||||||
}
|
}
|
||||||
|
|
||||||
int BhwcDepthFromShape(const Tensor::Shape& shape) {
|
int BhwcDepthFromShape(const Tensor::Shape& shape) {
|
||||||
LOG_IF(FATAL, shape.dims.empty())
|
|
||||||
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
|
|
||||||
return shape.dims.size() < 2 ? 1 : shape.dims[shape.dims.size() - 1];
|
return shape.dims.size() < 2 ? 1 : shape.dims[shape.dims.size() - 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -424,6 +419,11 @@ Tensor::Tensor(ElementType element_type, const Shape& shape,
|
||||||
|
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
void Tensor::Invalidate() {
|
void Tensor::Invalidate() {
|
||||||
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
GLuint cleanup_gl_tex = GL_INVALID_INDEX;
|
||||||
|
GLuint cleanup_gl_fb = GL_INVALID_INDEX;
|
||||||
|
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
{
|
||||||
absl::MutexLock lock(&view_mutex_);
|
absl::MutexLock lock(&view_mutex_);
|
||||||
// If memory is allocated and not owned by the metal buffer.
|
// If memory is allocated and not owned by the metal buffer.
|
||||||
// TODO: Re-design cpu buffer memory management.
|
// TODO: Re-design cpu buffer memory management.
|
||||||
|
@ -432,6 +432,23 @@ void Tensor::Invalidate() {
|
||||||
}
|
}
|
||||||
metal_buffer_ = nil;
|
metal_buffer_ = nil;
|
||||||
cpu_buffer_ = nullptr;
|
cpu_buffer_ = nullptr;
|
||||||
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
// Don't need to wait for the resource to be deleted bacause if will be
|
||||||
|
// released on last reference deletion inside the OpenGL driver.
|
||||||
|
std::swap(cleanup_gl_tex, opengl_texture2d_);
|
||||||
|
std::swap(cleanup_gl_fb, frame_buffer_);
|
||||||
|
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
}
|
||||||
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
// Do not hold the view mutex while invoking GlContext::RunWithoutWaiting,
|
||||||
|
// since that method may acquire the context's own lock.
|
||||||
|
if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX) {
|
||||||
|
gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb]() {
|
||||||
|
glDeleteTextures(1, &cleanup_gl_tex);
|
||||||
|
glDeleteFramebuffers(1, &cleanup_gl_fb);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
|
#include <numeric>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -95,9 +96,8 @@ class Tensor {
|
||||||
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
||||||
Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
|
Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
|
||||||
int num_elements() const {
|
int num_elements() const {
|
||||||
int res = dims.empty() ? 0 : 1;
|
return std::accumulate(dims.begin(), dims.end(), 1,
|
||||||
std::for_each(dims.begin(), dims.end(), [&res](int i) { res *= i; });
|
std::multiplies<int>());
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
std::vector<int> dims;
|
std::vector<int> dims;
|
||||||
};
|
};
|
||||||
|
|
|
@ -15,7 +15,7 @@ def mediapipe_cc_test(
|
||||||
platforms = ["linux", "android", "ios", "wasm"],
|
platforms = ["linux", "android", "ios", "wasm"],
|
||||||
exclude_platforms = None,
|
exclude_platforms = None,
|
||||||
# ios_unit_test arguments
|
# ios_unit_test arguments
|
||||||
ios_minimum_os_version = "9.0",
|
ios_minimum_os_version = "11.0",
|
||||||
# android_cc_test arguments
|
# android_cc_test arguments
|
||||||
open_gl_driver = None,
|
open_gl_driver = None,
|
||||||
emulator_mini_boot = True,
|
emulator_mini_boot = True,
|
||||||
|
|
|
@ -108,6 +108,7 @@ cc_library(
|
||||||
":sharded_map",
|
":sharded_map",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_profile_cc_proto",
|
"//mediapipe/framework:calculator_profile_cc_proto",
|
||||||
|
"//mediapipe/framework/port:file_helpers",
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "absl/time/time.h"
|
#include "absl/time/time.h"
|
||||||
#include "mediapipe/framework/port/advanced_proto_lite_inc.h"
|
#include "mediapipe/framework/port/advanced_proto_lite_inc.h"
|
||||||
#include "mediapipe/framework/port/canonical_errors.h"
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
|
#include "mediapipe/framework/port/file_helpers.h"
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
#include "mediapipe/framework/port/proto_ns.h"
|
#include "mediapipe/framework/port/proto_ns.h"
|
||||||
#include "mediapipe/framework/port/re2.h"
|
#include "mediapipe/framework/port/re2.h"
|
||||||
|
@ -244,7 +245,16 @@ absl::Status GraphProfiler::Start(mediapipe::Executor* executor) {
|
||||||
executor != nullptr) {
|
executor != nullptr) {
|
||||||
// Inform the user via logging the path to the trace logs.
|
// Inform the user via logging the path to the trace logs.
|
||||||
ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath());
|
ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath());
|
||||||
|
// Check that we can actually write to it.
|
||||||
|
auto status =
|
||||||
|
file::SetContents(absl::StrCat(trace_log_path, "trace_writing_check"),
|
||||||
|
"can write trace logs to this location");
|
||||||
|
if (status.ok()) {
|
||||||
LOG(INFO) << "trace_log_path: " << trace_log_path;
|
LOG(INFO) << "trace_log_path: " << trace_log_path;
|
||||||
|
} else {
|
||||||
|
LOG(ERROR) << "cannot write to trace_log_path: " << trace_log_path << ": "
|
||||||
|
<< status;
|
||||||
|
}
|
||||||
|
|
||||||
is_running_ = true;
|
is_running_ = true;
|
||||||
executor->Schedule([this] {
|
executor->Schedule([this] {
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
buffers: {
|
buffers: {
|
||||||
size_kb: 150000
|
size_kb: 150000
|
||||||
fill_policy: DISCARD
|
fill_policy: RING_BUFFER
|
||||||
}
|
}
|
||||||
|
|
||||||
data_sources: {
|
data_sources: {
|
||||||
|
@ -30,3 +30,5 @@ data_sources: {
|
||||||
}
|
}
|
||||||
write_into_file: true
|
write_into_file: true
|
||||||
file_write_period_ms: 500
|
file_write_period_ms: 500
|
||||||
|
# b/243571696 Added to remove Perfetto timeouts when running benchmarks remotely.
|
||||||
|
duration_ms: 60000
|
||||||
|
|
|
@ -821,6 +821,19 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_cc_test(
|
||||||
|
name = "switch_demux_calculator_test",
|
||||||
|
srcs = ["switch_demux_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":container_util",
|
||||||
|
":switch_demux_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:logging",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "switch_mux_calculator",
|
name = "switch_mux_calculator",
|
||||||
srcs = ["switch_mux_calculator.cc"],
|
srcs = ["switch_mux_calculator.cc"],
|
||||||
|
|
|
@ -129,12 +129,12 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
||||||
// Relay side packets to all channels.
|
// Relay side packets to all channels.
|
||||||
// Note: This is necessary because Calculator::Open only proceeds when every
|
// Note: This is necessary because Calculator::Open only proceeds when every
|
||||||
// anticipated side-packet arrives.
|
// anticipated side-packet arrives.
|
||||||
int channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap());
|
int side_channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap());
|
||||||
for (const std::string& tag : ChannelTags(cc->OutputSidePackets().TagMap())) {
|
for (const std::string& tag : ChannelTags(cc->OutputSidePackets().TagMap())) {
|
||||||
int num_entries = cc->InputSidePackets().NumEntries(tag);
|
int num_entries = cc->InputSidePackets().NumEntries(tag);
|
||||||
for (int index = 0; index < num_entries; ++index) {
|
for (int index = 0; index < num_entries; ++index) {
|
||||||
Packet input = cc->InputSidePackets().Get(tag, index);
|
Packet input = cc->InputSidePackets().Get(tag, index);
|
||||||
for (int channel = 0; channel < channel_count; ++channel) {
|
for (int channel = 0; channel < side_channel_count; ++channel) {
|
||||||
std::string output_tag = tool::ChannelTag(tag, channel);
|
std::string output_tag = tool::ChannelTag(tag, channel);
|
||||||
auto output_id = cc->OutputSidePackets().GetId(output_tag, index);
|
auto output_id = cc->OutputSidePackets().GetId(output_tag, index);
|
||||||
if (output_id.IsValid()) {
|
if (output_id.IsValid()) {
|
||||||
|
@ -143,6 +143,23 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Relay headers to all channels.
|
||||||
|
int output_channel_count = tool::ChannelCount(cc->Outputs().TagMap());
|
||||||
|
for (const std::string& tag : ChannelTags(cc->Outputs().TagMap())) {
|
||||||
|
int num_entries = cc->Inputs().NumEntries(tag);
|
||||||
|
for (int index = 0; index < num_entries; ++index) {
|
||||||
|
auto& input = cc->Inputs().Get(tag, index);
|
||||||
|
if (input.Header().IsEmpty()) continue;
|
||||||
|
for (int channel = 0; channel < output_channel_count; ++channel) {
|
||||||
|
std::string output_tag = tool::ChannelTag(tag, channel);
|
||||||
|
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||||
|
if (output_id.IsValid()) {
|
||||||
|
cc->Outputs().Get(output_tag, index).SetHeader(input.Header());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
135
mediapipe/framework/tool/switch_demux_calculator_test.cc
Normal file
135
mediapipe/framework/tool/switch_demux_calculator_test.cc
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/logging.h"
|
||||||
|
#include "mediapipe/framework/tool/container_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Returns a CalculatorGraph to run a single calculator.
|
||||||
|
CalculatorGraph BuildCalculatorGraph(CalculatorGraphConfig::Node node_config) {
|
||||||
|
CalculatorGraphConfig config;
|
||||||
|
*config.add_node() = node_config;
|
||||||
|
*config.mutable_input_stream() = node_config.input_stream();
|
||||||
|
*config.mutable_output_stream() = node_config.output_stream();
|
||||||
|
*config.mutable_input_side_packet() = node_config.input_side_packet();
|
||||||
|
*config.mutable_output_side_packet() = node_config.output_side_packet();
|
||||||
|
return CalculatorGraph(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a string packet.
|
||||||
|
Packet pack(std::string data, int timestamp) {
|
||||||
|
return MakePacket<std::string>(data).At(Timestamp(timestamp));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates an int packet.
|
||||||
|
Packet pack(int data, int timestamp) {
|
||||||
|
return MakePacket<int>(data).At(Timestamp(timestamp));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests showing packet channel synchronization through SwitchDemuxCalculator.
|
||||||
|
class SwitchDemuxCalculatorTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
SwitchDemuxCalculatorTest() {}
|
||||||
|
~SwitchDemuxCalculatorTest() override {}
|
||||||
|
void SetUp() override {}
|
||||||
|
void TearDown() override {}
|
||||||
|
|
||||||
|
// Defines a SwitchDemuxCalculator CalculatorGraphConfig::Node.
|
||||||
|
CalculatorGraphConfig::Node BuildNodeConfig() {
|
||||||
|
CalculatorGraphConfig::Node result;
|
||||||
|
*result.mutable_calculator() = "SwitchDemuxCalculator";
|
||||||
|
*result.add_input_stream() = "SELECT:select";
|
||||||
|
for (int c = 0; c < 2; ++c) {
|
||||||
|
*result.add_output_stream() =
|
||||||
|
absl::StrCat(tool::ChannelTag("FRAME", c), ":frame_", c);
|
||||||
|
*result.add_output_stream() =
|
||||||
|
absl::StrCat(tool::ChannelTag("MASK", c), ":mask_", c);
|
||||||
|
}
|
||||||
|
*result.add_input_stream() = "FRAME:frame";
|
||||||
|
*result.add_input_stream() = "MASK:mask";
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Shows the SwitchMuxCalculator is available.
|
||||||
|
TEST_F(SwitchDemuxCalculatorTest, IsRegistered) {
|
||||||
|
EXPECT_TRUE(CalculatorBaseRegistry::IsRegistered("SwitchDemuxCalculator"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SwitchDemuxCalculatorTest, BasicDataFlow) {
|
||||||
|
CalculatorGraphConfig::Node node_config = BuildNodeConfig();
|
||||||
|
CalculatorGraph graph = BuildCalculatorGraph(node_config);
|
||||||
|
std::vector<Packet> output_frames0;
|
||||||
|
EXPECT_TRUE(graph
|
||||||
|
.ObserveOutputStream("frame_0",
|
||||||
|
[&](const Packet& p) {
|
||||||
|
output_frames0.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
})
|
||||||
|
.ok());
|
||||||
|
std::vector<Packet> output_frames1;
|
||||||
|
EXPECT_TRUE(graph
|
||||||
|
.ObserveOutputStream("frame_1",
|
||||||
|
[&](const Packet& p) {
|
||||||
|
output_frames1.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
})
|
||||||
|
.ok());
|
||||||
|
EXPECT_TRUE(
|
||||||
|
graph.StartRun({}, {{"frame", MakePacket<std::string>("frame_header")}})
|
||||||
|
.ok());
|
||||||
|
|
||||||
|
// Finalize input for the "mask" input stream.
|
||||||
|
EXPECT_TRUE(graph.CloseInputStream("mask").ok());
|
||||||
|
|
||||||
|
// Channel 0 is selected just before corresponding packets arrive.
|
||||||
|
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 1)).ok());
|
||||||
|
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 10)).ok());
|
||||||
|
EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p0_t10", 10)).ok());
|
||||||
|
EXPECT_TRUE(graph.WaitUntilIdle().ok());
|
||||||
|
EXPECT_EQ(output_frames0.size(), 1);
|
||||||
|
EXPECT_EQ(output_frames1.size(), 0);
|
||||||
|
EXPECT_EQ(output_frames0[0].Get<std::string>(), "p0_t10");
|
||||||
|
|
||||||
|
// Channel 1 is selected just before corresponding packets arrive.
|
||||||
|
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 11)).ok());
|
||||||
|
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 20)).ok());
|
||||||
|
EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p1_t20", 20)).ok());
|
||||||
|
EXPECT_TRUE(graph.WaitUntilIdle().ok());
|
||||||
|
EXPECT_EQ(output_frames0.size(), 1);
|
||||||
|
EXPECT_EQ(output_frames1.size(), 1);
|
||||||
|
EXPECT_EQ(output_frames1[0].Get<std::string>(), "p1_t20");
|
||||||
|
|
||||||
|
EXPECT_EQ(
|
||||||
|
graph.FindOutputStreamManager("frame_0")->Header().Get<std::string>(),
|
||||||
|
"frame_header");
|
||||||
|
EXPECT_EQ(
|
||||||
|
graph.FindOutputStreamManager("frame_1")->Header().Get<std::string>(),
|
||||||
|
"frame_header");
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph.CloseAllPacketSources().ok());
|
||||||
|
EXPECT_TRUE(graph.WaitUntilDone().ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -271,6 +271,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":gpu_buffer_format",
|
":gpu_buffer_format",
|
||||||
":gpu_buffer_storage",
|
":gpu_buffer_storage",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
":gpu_buffer_storage_image_frame",
|
":gpu_buffer_storage_image_frame",
|
||||||
|
@ -366,6 +367,23 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gpu_buffer_storage_ahwb",
|
||||||
|
srcs = ["gpu_buffer_storage_ahwb.cc"],
|
||||||
|
hdrs = ["gpu_buffer_storage_ahwb.h"],
|
||||||
|
linkopts = select({
|
||||||
|
"//conditions:default": [],
|
||||||
|
"//mediapipe:android": [
|
||||||
|
"-landroid",
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
deps = [
|
||||||
|
":gpu_buffer_format",
|
||||||
|
":gpu_buffer_storage",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "gpu_origin_proto",
|
name = "gpu_origin_proto",
|
||||||
srcs = ["gpu_origin.proto"],
|
srcs = ["gpu_origin.proto"],
|
||||||
|
@ -1087,3 +1105,19 @@ ios_unit_test(
|
||||||
],
|
],
|
||||||
deps = [":gl_ios_test_lib"],
|
deps = [":gl_ios_test_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_cc_test(
|
||||||
|
name = "gpu_buffer_storage_ahwb_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["gpu_buffer_storage_ahwb_test.cc"],
|
||||||
|
exclude_platforms = [
|
||||||
|
"ios",
|
||||||
|
"wasm",
|
||||||
|
],
|
||||||
|
requires_full_emulation = True,
|
||||||
|
deps = [
|
||||||
|
":gpu_buffer_format",
|
||||||
|
":gpu_buffer_storage_ahwb",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -620,7 +620,9 @@ class GlSyncWrapper {
|
||||||
#endif
|
#endif
|
||||||
GLenum result = glClientWaitSync(sync_, flags, timeout);
|
GLenum result = glClientWaitSync(sync_, flags, timeout);
|
||||||
if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) {
|
if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) {
|
||||||
Clear();
|
// TODO: we could clear at this point so later calls are faster,
|
||||||
|
// but we need to do so in a thread-safe way.
|
||||||
|
// Clear();
|
||||||
}
|
}
|
||||||
// TODO: do something if the wait fails?
|
// TODO: do something if the wait fails?
|
||||||
}
|
}
|
||||||
|
@ -646,7 +648,9 @@ class GlSyncWrapper {
|
||||||
#endif
|
#endif
|
||||||
GLenum result = glClientWaitSync(sync_, flags, 0);
|
GLenum result = glClientWaitSync(sync_, flags, 0);
|
||||||
if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) {
|
if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) {
|
||||||
Clear();
|
// TODO: we could clear at this point so later calls are faster,
|
||||||
|
// but we need to do so in a thread-safe way.
|
||||||
|
// Clear();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -822,10 +826,17 @@ std::shared_ptr<GlSyncPoint> GlContext::CreateSyncToken() {
|
||||||
return token;
|
return token;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GlContext::IsAnyContextCurrent() {
|
||||||
|
ContextBinding ctx;
|
||||||
|
GetCurrentContextBinding(&ctx);
|
||||||
|
return ctx.context != kPlatformGlContextNone;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<GlSyncPoint>
|
std::shared_ptr<GlSyncPoint>
|
||||||
GlContext::CreateSyncTokenForCurrentExternalContext(
|
GlContext::CreateSyncTokenForCurrentExternalContext(
|
||||||
const std::shared_ptr<GlContext>& delegate_graph_context) {
|
const std::shared_ptr<GlContext>& delegate_graph_context) {
|
||||||
CHECK(delegate_graph_context);
|
CHECK(delegate_graph_context);
|
||||||
|
if (!IsAnyContextCurrent()) return nullptr;
|
||||||
if (delegate_graph_context->ShouldUseFenceSync()) {
|
if (delegate_graph_context->ShouldUseFenceSync()) {
|
||||||
return std::shared_ptr<GlSyncPoint>(
|
return std::shared_ptr<GlSyncPoint>(
|
||||||
new GlExternalFenceSyncPoint(delegate_graph_context));
|
new GlExternalFenceSyncPoint(delegate_graph_context));
|
||||||
|
|
|
@ -303,6 +303,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
|
||||||
return *static_cast<T*>(entry.get());
|
return *static_cast<T*>(entry.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true if any GL context, including external contexts not managed by
|
||||||
|
// the GlContext class, is current.
|
||||||
|
static bool IsAnyContextCurrent();
|
||||||
|
|
||||||
// Creates a synchronization token for the current, non-GlContext-owned
|
// Creates a synchronization token for the current, non-GlContext-owned
|
||||||
// context. This can be passed to MediaPipe so it can synchronize with the
|
// context. This can be passed to MediaPipe so it can synchronize with the
|
||||||
// commands issued in the external context up to this point.
|
// commands issued in the external context up to this point.
|
||||||
|
|
|
@ -145,9 +145,13 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
|
||||||
CHECK_NE(name_, 0);
|
CHECK_NE(name_, 0);
|
||||||
GLuint name_to_delete = name_;
|
GLuint name_to_delete = name_;
|
||||||
context->RunWithoutWaiting([name_to_delete, sync_token]() {
|
context->RunWithoutWaiting([name_to_delete, sync_token]() {
|
||||||
|
if (sync_token) {
|
||||||
// TODO: maybe we do not actually have to wait for the
|
// TODO: maybe we do not actually have to wait for the
|
||||||
// consumer sync here. Check docs.
|
// consumer sync here. Check docs.
|
||||||
sync_token->WaitOnGpu();
|
sync_token->WaitOnGpu();
|
||||||
|
} else {
|
||||||
|
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback";
|
||||||
|
}
|
||||||
DLOG_IF(ERROR, !glIsTexture(name_to_delete))
|
DLOG_IF(ERROR, !glIsTexture(name_to_delete))
|
||||||
<< "Deleting invalid texture id: " << name_to_delete;
|
<< "Deleting invalid texture id: " << name_to_delete;
|
||||||
glDeleteTextures(1, &name_to_delete);
|
glDeleteTextures(1, &name_to_delete);
|
||||||
|
@ -179,13 +183,19 @@ void GlTextureBuffer::Reuse() {
|
||||||
void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
|
void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
|
||||||
CHECK(!producer_sync_)
|
CHECK(!producer_sync_)
|
||||||
<< "Updated existing texture which had not been marked for reuse!";
|
<< "Updated existing texture which had not been marked for reuse!";
|
||||||
|
CHECK(prod_token);
|
||||||
producer_sync_ = std::move(prod_token);
|
producer_sync_ = std::move(prod_token);
|
||||||
producer_context_ = producer_sync_->GetContext();
|
producer_context_ = producer_sync_->GetContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
|
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
|
||||||
absl::MutexLock lock(&consumer_sync_mutex_);
|
absl::MutexLock lock(&consumer_sync_mutex_);
|
||||||
|
if (cons_token) {
|
||||||
consumer_multi_sync_->Add(std::move(cons_token));
|
consumer_multi_sync_->Add(std::move(cons_token));
|
||||||
|
} else {
|
||||||
|
// TODO: change to a CHECK.
|
||||||
|
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in DidRead";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GlTextureBuffer::~GlTextureBuffer() {
|
GlTextureBuffer::~GlTextureBuffer() {
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
|
|
||||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
@ -10,6 +12,23 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct StorageTypeFormatter {
|
||||||
|
void operator()(std::string* out,
|
||||||
|
const std::shared_ptr<internal::GpuBufferStorage>& s) const {
|
||||||
|
absl::StrAppend(out, s->storage_type().name());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::string GpuBuffer::DebugString() const {
|
||||||
|
return absl::StrCat("GpuBuffer[",
|
||||||
|
absl::StrJoin(storages_, ", ", StorageTypeFormatter()),
|
||||||
|
"]");
|
||||||
|
}
|
||||||
|
|
||||||
internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
|
internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
|
||||||
TypeId view_provider_type, bool for_writing) const {
|
TypeId view_provider_type, bool for_writing) const {
|
||||||
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
|
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
|
||||||
|
@ -52,7 +71,10 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(chosen_storage) << "no view provider found";
|
CHECK(chosen_storage) << "no view provider found for requested view "
|
||||||
|
<< view_provider_type.name() << "; storages available: "
|
||||||
|
<< absl::StrJoin(storages_, ", ",
|
||||||
|
StorageTypeFormatter());
|
||||||
DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type));
|
DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type));
|
||||||
return **chosen_storage;
|
return **chosen_storage;
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,6 +129,8 @@ class GpuBuffer {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string DebugString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class PlaceholderGpuBufferStorage
|
class PlaceholderGpuBufferStorage
|
||||||
: public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> {
|
: public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> {
|
||||||
|
|
|
@ -21,18 +21,29 @@ licenses(["notice"])
|
||||||
|
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
mediapipe_simple_subgraph(
|
||||||
|
name = "pose_landmarks_to_render_data",
|
||||||
|
graph = "pose_landmarks_to_render_data.pbtxt",
|
||||||
|
register_as = "PoseLandmarksToRenderData",
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/core:concatenate_vector_calculator",
|
||||||
|
"//mediapipe/calculators/core:split_proto_list_calculator",
|
||||||
|
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
|
||||||
|
"//mediapipe/calculators/util:rect_to_render_scale_calculator",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_simple_subgraph(
|
mediapipe_simple_subgraph(
|
||||||
name = "pose_renderer_gpu",
|
name = "pose_renderer_gpu",
|
||||||
graph = "pose_renderer_gpu.pbtxt",
|
graph = "pose_renderer_gpu.pbtxt",
|
||||||
register_as = "PoseRendererGpu",
|
register_as = "PoseRendererGpu",
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:split_proto_list_calculator",
|
":pose_landmarks_to_render_data",
|
||||||
|
"//mediapipe/calculators/image:image_properties_calculator",
|
||||||
"//mediapipe/calculators/image:recolor_calculator",
|
"//mediapipe/calculators/image:recolor_calculator",
|
||||||
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
||||||
"//mediapipe/calculators/util:detections_to_render_data_calculator",
|
"//mediapipe/calculators/util:detections_to_render_data_calculator",
|
||||||
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
|
|
||||||
"//mediapipe/calculators/util:rect_to_render_data_calculator",
|
"//mediapipe/calculators/util:rect_to_render_data_calculator",
|
||||||
"//mediapipe/calculators/util:rect_to_render_scale_calculator",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,12 +52,11 @@ mediapipe_simple_subgraph(
|
||||||
graph = "pose_renderer_cpu.pbtxt",
|
graph = "pose_renderer_cpu.pbtxt",
|
||||||
register_as = "PoseRendererCpu",
|
register_as = "PoseRendererCpu",
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:split_proto_list_calculator",
|
":pose_landmarks_to_render_data",
|
||||||
|
"//mediapipe/calculators/image:image_properties_calculator",
|
||||||
"//mediapipe/calculators/image:recolor_calculator",
|
"//mediapipe/calculators/image:recolor_calculator",
|
||||||
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
||||||
"//mediapipe/calculators/util:detections_to_render_data_calculator",
|
"//mediapipe/calculators/util:detections_to_render_data_calculator",
|
||||||
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
|
|
||||||
"//mediapipe/calculators/util:rect_to_render_data_calculator",
|
"//mediapipe/calculators/util:rect_to_render_data_calculator",
|
||||||
"//mediapipe/calculators/util:rect_to_render_scale_calculator",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,236 @@
|
||||||
|
# MediaPipe pose landmarks to render data subgraph.
|
||||||
|
|
||||||
|
type: "PoseLandmarksToRenderData"
|
||||||
|
|
||||||
|
# Pose landmarks. (NormalizedLandmarkList)
|
||||||
|
input_stream: "LANDMARKS:pose_landmarks"
|
||||||
|
# Region of interest calculated based on landmarks. (NormalizedRect)
|
||||||
|
input_stream: "ROI:roi"
|
||||||
|
# Image size. (pair<int, int>)
|
||||||
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
|
|
||||||
|
# The resulting render data. (vector<RenderData>)
|
||||||
|
output_stream: "RENDER_DATA:merged_render_data"
|
||||||
|
|
||||||
|
# Calculates rendering scale based on the pose roi.
|
||||||
|
node {
|
||||||
|
calculator: "RectToRenderScaleCalculator"
|
||||||
|
input_stream: "NORM_RECT:roi"
|
||||||
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
|
output_stream: "RENDER_SCALE:render_scale"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] {
|
||||||
|
multiplier: 0.0012
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
node {
|
||||||
|
calculator: "SplitNormalizedLandmarkListCalculator"
|
||||||
|
input_stream: "pose_landmarks"
|
||||||
|
output_stream: "visible_pose_landmarks"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
||||||
|
ranges: { begin: 0 end: 25 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Converts landmarks to drawing primitives for annotation overlay.
|
||||||
|
node {
|
||||||
|
calculator: "LandmarksToRenderDataCalculator"
|
||||||
|
input_stream: "NORM_LANDMARKS:pose_landmarks"
|
||||||
|
input_stream: "RENDER_SCALE:render_scale"
|
||||||
|
output_stream: "RENDER_DATA:landmarks_render_data"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
||||||
|
landmark_connections: 0
|
||||||
|
landmark_connections: 1
|
||||||
|
landmark_connections: 1
|
||||||
|
landmark_connections: 2
|
||||||
|
landmark_connections: 2
|
||||||
|
landmark_connections: 3
|
||||||
|
landmark_connections: 3
|
||||||
|
landmark_connections: 7
|
||||||
|
landmark_connections: 0
|
||||||
|
landmark_connections: 4
|
||||||
|
landmark_connections: 4
|
||||||
|
landmark_connections: 5
|
||||||
|
landmark_connections: 5
|
||||||
|
landmark_connections: 6
|
||||||
|
landmark_connections: 6
|
||||||
|
landmark_connections: 8
|
||||||
|
landmark_connections: 9
|
||||||
|
landmark_connections: 10
|
||||||
|
landmark_connections: 11
|
||||||
|
landmark_connections: 12
|
||||||
|
landmark_connections: 11
|
||||||
|
landmark_connections: 13
|
||||||
|
landmark_connections: 13
|
||||||
|
landmark_connections: 15
|
||||||
|
landmark_connections: 15
|
||||||
|
landmark_connections: 17
|
||||||
|
landmark_connections: 15
|
||||||
|
landmark_connections: 19
|
||||||
|
landmark_connections: 15
|
||||||
|
landmark_connections: 21
|
||||||
|
landmark_connections: 17
|
||||||
|
landmark_connections: 19
|
||||||
|
landmark_connections: 12
|
||||||
|
landmark_connections: 14
|
||||||
|
landmark_connections: 14
|
||||||
|
landmark_connections: 16
|
||||||
|
landmark_connections: 16
|
||||||
|
landmark_connections: 18
|
||||||
|
landmark_connections: 16
|
||||||
|
landmark_connections: 20
|
||||||
|
landmark_connections: 16
|
||||||
|
landmark_connections: 22
|
||||||
|
landmark_connections: 18
|
||||||
|
landmark_connections: 20
|
||||||
|
landmark_connections: 11
|
||||||
|
landmark_connections: 23
|
||||||
|
landmark_connections: 12
|
||||||
|
landmark_connections: 24
|
||||||
|
landmark_connections: 23
|
||||||
|
landmark_connections: 24
|
||||||
|
landmark_connections: 23
|
||||||
|
landmark_connections: 25
|
||||||
|
landmark_connections: 24
|
||||||
|
landmark_connections: 26
|
||||||
|
landmark_connections: 25
|
||||||
|
landmark_connections: 27
|
||||||
|
landmark_connections: 26
|
||||||
|
landmark_connections: 28
|
||||||
|
landmark_connections: 27
|
||||||
|
landmark_connections: 29
|
||||||
|
landmark_connections: 28
|
||||||
|
landmark_connections: 30
|
||||||
|
landmark_connections: 29
|
||||||
|
landmark_connections: 31
|
||||||
|
landmark_connections: 30
|
||||||
|
landmark_connections: 32
|
||||||
|
landmark_connections: 27
|
||||||
|
landmark_connections: 31
|
||||||
|
landmark_connections: 28
|
||||||
|
landmark_connections: 32
|
||||||
|
|
||||||
|
landmark_color { r: 255 g: 255 b: 255 }
|
||||||
|
connection_color { r: 255 g: 255 b: 255 }
|
||||||
|
thickness: 3.0
|
||||||
|
visualize_landmark_depth: false
|
||||||
|
utilize_visibility: true
|
||||||
|
visibility_threshold: 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Take left pose landmarks.
|
||||||
|
node {
|
||||||
|
calculator: "SplitNormalizedLandmarkListCalculator"
|
||||||
|
input_stream: "pose_landmarks"
|
||||||
|
output_stream: "landmarks_left_side"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
||||||
|
ranges: { begin: 1 end: 4 }
|
||||||
|
ranges: { begin: 7 end: 8 }
|
||||||
|
ranges: { begin: 9 end: 10 }
|
||||||
|
ranges: { begin: 11 end: 12 }
|
||||||
|
ranges: { begin: 13 end: 14 }
|
||||||
|
ranges: { begin: 15 end: 16 }
|
||||||
|
ranges: { begin: 17 end: 18 }
|
||||||
|
ranges: { begin: 19 end: 20 }
|
||||||
|
ranges: { begin: 21 end: 22 }
|
||||||
|
ranges: { begin: 23 end: 24 }
|
||||||
|
|
||||||
|
combine_outputs: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Take right pose landmarks.
|
||||||
|
node {
|
||||||
|
calculator: "SplitNormalizedLandmarkListCalculator"
|
||||||
|
input_stream: "pose_landmarks"
|
||||||
|
output_stream: "landmarks_right_side"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
||||||
|
ranges: { begin: 4 end: 7 }
|
||||||
|
ranges: { begin: 8 end: 9 }
|
||||||
|
ranges: { begin: 10 end: 11 }
|
||||||
|
ranges: { begin: 12 end: 13 }
|
||||||
|
ranges: { begin: 14 end: 15 }
|
||||||
|
ranges: { begin: 16 end: 17 }
|
||||||
|
ranges: { begin: 18 end: 19 }
|
||||||
|
ranges: { begin: 20 end: 21 }
|
||||||
|
ranges: { begin: 22 end: 23 }
|
||||||
|
ranges: { begin: 24 end: 25 }
|
||||||
|
|
||||||
|
combine_outputs: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Render pose joints as big white circles.
|
||||||
|
node {
|
||||||
|
calculator: "LandmarksToRenderDataCalculator"
|
||||||
|
input_stream: "NORM_LANDMARKS:visible_pose_landmarks"
|
||||||
|
input_stream: "RENDER_SCALE:render_scale"
|
||||||
|
output_stream: "RENDER_DATA:landmarks_background_joints_render_data"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
||||||
|
landmark_color { r: 255 g: 255 b: 255 }
|
||||||
|
connection_color { r: 255 g: 255 b: 255 }
|
||||||
|
thickness: 5.0
|
||||||
|
visualize_landmark_depth: false
|
||||||
|
utilize_visibility: true
|
||||||
|
visibility_threshold: 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Render pose left side joints as orange circles (inside white ones).
|
||||||
|
node {
|
||||||
|
calculator: "LandmarksToRenderDataCalculator"
|
||||||
|
input_stream: "NORM_LANDMARKS:landmarks_left_side"
|
||||||
|
input_stream: "RENDER_SCALE:render_scale"
|
||||||
|
output_stream: "RENDER_DATA:landmarks_left_joints_render_data"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
||||||
|
landmark_color { r: 255 g: 138 b: 0 }
|
||||||
|
connection_color { r: 255 g: 138 b: 0 }
|
||||||
|
thickness: 3.0
|
||||||
|
visualize_landmark_depth: false
|
||||||
|
utilize_visibility: true
|
||||||
|
visibility_threshold: 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Render pose right side joints as cyan circles (inside white ones).
|
||||||
|
node {
|
||||||
|
calculator: "LandmarksToRenderDataCalculator"
|
||||||
|
input_stream: "NORM_LANDMARKS:landmarks_right_side"
|
||||||
|
input_stream: "RENDER_SCALE:render_scale"
|
||||||
|
output_stream: "RENDER_DATA:landmarks_right_joints_render_data"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
||||||
|
landmark_color { r: 0 g: 217 b: 231 }
|
||||||
|
connection_color { r: 0 g: 217 b: 231 }
|
||||||
|
thickness: 3.0
|
||||||
|
visualize_landmark_depth: false
|
||||||
|
utilize_visibility: true
|
||||||
|
visibility_threshold: 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Merges annotations into one result.
|
||||||
|
node {
|
||||||
|
calculator: "ConcatenateRenderDataVectorCalculator"
|
||||||
|
input_stream: "landmarks_render_data"
|
||||||
|
input_stream: "landmarks_background_joints_render_data"
|
||||||
|
input_stream: "landmarks_left_joints_render_data"
|
||||||
|
input_stream: "landmarks_right_joints_render_data"
|
||||||
|
output_stream: "merged_render_data"
|
||||||
|
}
|
|
@ -22,19 +22,6 @@ node {
|
||||||
output_stream: "SIZE:image_size"
|
output_stream: "SIZE:image_size"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Calculates rendering scale based on the pose roi.
|
|
||||||
node {
|
|
||||||
calculator: "RectToRenderScaleCalculator"
|
|
||||||
input_stream: "NORM_RECT:roi"
|
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
|
||||||
output_stream: "RENDER_SCALE:render_scale"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] {
|
|
||||||
multiplier: 0.0012
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Converts detections to drawing primitives for annotation overlay.
|
# Converts detections to drawing primitives for annotation overlay.
|
||||||
node {
|
node {
|
||||||
calculator: "DetectionsToRenderDataCalculator"
|
calculator: "DetectionsToRenderDataCalculator"
|
||||||
|
@ -48,204 +35,13 @@ node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Computes render data for landmarks.
|
||||||
node {
|
node {
|
||||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
calculator: "PoseLandmarksToRenderData"
|
||||||
input_stream: "pose_landmarks"
|
input_stream: "LANDMARKS:pose_landmarks"
|
||||||
output_stream: "visible_pose_landmarks"
|
input_stream: "ROI:roi"
|
||||||
node_options: {
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
|
||||||
ranges: { begin: 0 end: 25 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Converts landmarks to drawing primitives for annotation overlay.
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:pose_landmarks"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_render_data"
|
output_stream: "RENDER_DATA:landmarks_render_data"
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_connections: 0
|
|
||||||
landmark_connections: 1
|
|
||||||
landmark_connections: 1
|
|
||||||
landmark_connections: 2
|
|
||||||
landmark_connections: 2
|
|
||||||
landmark_connections: 3
|
|
||||||
landmark_connections: 3
|
|
||||||
landmark_connections: 7
|
|
||||||
landmark_connections: 0
|
|
||||||
landmark_connections: 4
|
|
||||||
landmark_connections: 4
|
|
||||||
landmark_connections: 5
|
|
||||||
landmark_connections: 5
|
|
||||||
landmark_connections: 6
|
|
||||||
landmark_connections: 6
|
|
||||||
landmark_connections: 8
|
|
||||||
landmark_connections: 9
|
|
||||||
landmark_connections: 10
|
|
||||||
landmark_connections: 11
|
|
||||||
landmark_connections: 12
|
|
||||||
landmark_connections: 11
|
|
||||||
landmark_connections: 13
|
|
||||||
landmark_connections: 13
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 17
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 19
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 21
|
|
||||||
landmark_connections: 17
|
|
||||||
landmark_connections: 19
|
|
||||||
landmark_connections: 12
|
|
||||||
landmark_connections: 14
|
|
||||||
landmark_connections: 14
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 18
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 20
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 22
|
|
||||||
landmark_connections: 18
|
|
||||||
landmark_connections: 20
|
|
||||||
landmark_connections: 11
|
|
||||||
landmark_connections: 23
|
|
||||||
landmark_connections: 12
|
|
||||||
landmark_connections: 24
|
|
||||||
landmark_connections: 23
|
|
||||||
landmark_connections: 24
|
|
||||||
landmark_connections: 23
|
|
||||||
landmark_connections: 25
|
|
||||||
landmark_connections: 24
|
|
||||||
landmark_connections: 26
|
|
||||||
landmark_connections: 25
|
|
||||||
landmark_connections: 27
|
|
||||||
landmark_connections: 26
|
|
||||||
landmark_connections: 28
|
|
||||||
landmark_connections: 27
|
|
||||||
landmark_connections: 29
|
|
||||||
landmark_connections: 28
|
|
||||||
landmark_connections: 30
|
|
||||||
landmark_connections: 29
|
|
||||||
landmark_connections: 31
|
|
||||||
landmark_connections: 30
|
|
||||||
landmark_connections: 32
|
|
||||||
landmark_connections: 27
|
|
||||||
landmark_connections: 31
|
|
||||||
landmark_connections: 28
|
|
||||||
landmark_connections: 32
|
|
||||||
|
|
||||||
landmark_color { r: 255 g: 255 b: 255 }
|
|
||||||
connection_color { r: 255 g: 255 b: 255 }
|
|
||||||
thickness: 3.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Take left pose landmarks.
|
|
||||||
node {
|
|
||||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
|
||||||
input_stream: "pose_landmarks"
|
|
||||||
output_stream: "landmarks_left_side"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
|
||||||
ranges: { begin: 1 end: 4 }
|
|
||||||
ranges: { begin: 7 end: 8 }
|
|
||||||
ranges: { begin: 9 end: 10 }
|
|
||||||
ranges: { begin: 11 end: 12 }
|
|
||||||
ranges: { begin: 13 end: 14 }
|
|
||||||
ranges: { begin: 15 end: 16 }
|
|
||||||
ranges: { begin: 17 end: 18 }
|
|
||||||
ranges: { begin: 19 end: 20 }
|
|
||||||
ranges: { begin: 21 end: 22 }
|
|
||||||
ranges: { begin: 23 end: 24 }
|
|
||||||
|
|
||||||
combine_outputs: true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Take right pose landmarks.
|
|
||||||
node {
|
|
||||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
|
||||||
input_stream: "pose_landmarks"
|
|
||||||
output_stream: "landmarks_right_side"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
|
||||||
ranges: { begin: 4 end: 7 }
|
|
||||||
ranges: { begin: 8 end: 9 }
|
|
||||||
ranges: { begin: 10 end: 11 }
|
|
||||||
ranges: { begin: 12 end: 13 }
|
|
||||||
ranges: { begin: 14 end: 15 }
|
|
||||||
ranges: { begin: 16 end: 17 }
|
|
||||||
ranges: { begin: 18 end: 19 }
|
|
||||||
ranges: { begin: 20 end: 21 }
|
|
||||||
ranges: { begin: 22 end: 23 }
|
|
||||||
ranges: { begin: 24 end: 25 }
|
|
||||||
|
|
||||||
combine_outputs: true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Render pose joints as big white circles.
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:visible_pose_landmarks"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_background_joints_render_data"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_color { r: 255 g: 255 b: 255 }
|
|
||||||
connection_color { r: 255 g: 255 b: 255 }
|
|
||||||
thickness: 5.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Render pose left side joints as orange circles (inside white ones).
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:landmarks_left_side"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_left_joints_render_data"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_color { r: 255 g: 138 b: 0 }
|
|
||||||
connection_color { r: 255 g: 138 b: 0 }
|
|
||||||
thickness: 3.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Render pose right side joints as cyan circles (inside white ones).
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:landmarks_right_side"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_right_joints_render_data"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_color { r: 0 g: 217 b: 231 }
|
|
||||||
connection_color { r: 0 g: 217 b: 231 }
|
|
||||||
thickness: 3.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Converts normalized rects to drawing primitives for annotation overlay.
|
# Converts normalized rects to drawing primitives for annotation overlay.
|
||||||
|
@ -283,10 +79,7 @@ node {
|
||||||
calculator: "AnnotationOverlayCalculator"
|
calculator: "AnnotationOverlayCalculator"
|
||||||
input_stream: "IMAGE:segmented_image"
|
input_stream: "IMAGE:segmented_image"
|
||||||
input_stream: "detection_render_data"
|
input_stream: "detection_render_data"
|
||||||
input_stream: "landmarks_render_data"
|
input_stream: "VECTOR:landmarks_render_data"
|
||||||
input_stream: "landmarks_background_joints_render_data"
|
|
||||||
input_stream: "landmarks_left_joints_render_data"
|
|
||||||
input_stream: "landmarks_right_joints_render_data"
|
|
||||||
input_stream: "roi_render_data"
|
input_stream: "roi_render_data"
|
||||||
output_stream: "IMAGE:output_image"
|
output_stream: "IMAGE:output_image"
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,19 +22,6 @@ node {
|
||||||
output_stream: "SIZE:image_size"
|
output_stream: "SIZE:image_size"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Calculates rendering scale based on the pose roi.
|
|
||||||
node {
|
|
||||||
calculator: "RectToRenderScaleCalculator"
|
|
||||||
input_stream: "NORM_RECT:roi"
|
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
|
||||||
output_stream: "RENDER_SCALE:render_scale"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] {
|
|
||||||
multiplier: 0.0012
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Converts detections to drawing primitives for annotation overlay.
|
# Converts detections to drawing primitives for annotation overlay.
|
||||||
node {
|
node {
|
||||||
calculator: "DetectionsToRenderDataCalculator"
|
calculator: "DetectionsToRenderDataCalculator"
|
||||||
|
@ -48,204 +35,13 @@ node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Computes render data for landmarks.
|
||||||
node {
|
node {
|
||||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
calculator: "PoseLandmarksToRenderData"
|
||||||
input_stream: "pose_landmarks"
|
input_stream: "LANDMARKS:pose_landmarks"
|
||||||
output_stream: "visible_pose_landmarks"
|
input_stream: "ROI:roi"
|
||||||
node_options: {
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
|
||||||
ranges: { begin: 0 end: 25 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Converts landmarks to drawing primitives for annotation overlay.
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:pose_landmarks"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_render_data"
|
output_stream: "RENDER_DATA:landmarks_render_data"
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_connections: 0
|
|
||||||
landmark_connections: 1
|
|
||||||
landmark_connections: 1
|
|
||||||
landmark_connections: 2
|
|
||||||
landmark_connections: 2
|
|
||||||
landmark_connections: 3
|
|
||||||
landmark_connections: 3
|
|
||||||
landmark_connections: 7
|
|
||||||
landmark_connections: 0
|
|
||||||
landmark_connections: 4
|
|
||||||
landmark_connections: 4
|
|
||||||
landmark_connections: 5
|
|
||||||
landmark_connections: 5
|
|
||||||
landmark_connections: 6
|
|
||||||
landmark_connections: 6
|
|
||||||
landmark_connections: 8
|
|
||||||
landmark_connections: 9
|
|
||||||
landmark_connections: 10
|
|
||||||
landmark_connections: 11
|
|
||||||
landmark_connections: 12
|
|
||||||
landmark_connections: 11
|
|
||||||
landmark_connections: 13
|
|
||||||
landmark_connections: 13
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 17
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 19
|
|
||||||
landmark_connections: 15
|
|
||||||
landmark_connections: 21
|
|
||||||
landmark_connections: 17
|
|
||||||
landmark_connections: 19
|
|
||||||
landmark_connections: 12
|
|
||||||
landmark_connections: 14
|
|
||||||
landmark_connections: 14
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 18
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 20
|
|
||||||
landmark_connections: 16
|
|
||||||
landmark_connections: 22
|
|
||||||
landmark_connections: 18
|
|
||||||
landmark_connections: 20
|
|
||||||
landmark_connections: 11
|
|
||||||
landmark_connections: 23
|
|
||||||
landmark_connections: 12
|
|
||||||
landmark_connections: 24
|
|
||||||
landmark_connections: 23
|
|
||||||
landmark_connections: 24
|
|
||||||
landmark_connections: 23
|
|
||||||
landmark_connections: 25
|
|
||||||
landmark_connections: 24
|
|
||||||
landmark_connections: 26
|
|
||||||
landmark_connections: 25
|
|
||||||
landmark_connections: 27
|
|
||||||
landmark_connections: 26
|
|
||||||
landmark_connections: 28
|
|
||||||
landmark_connections: 27
|
|
||||||
landmark_connections: 29
|
|
||||||
landmark_connections: 28
|
|
||||||
landmark_connections: 30
|
|
||||||
landmark_connections: 29
|
|
||||||
landmark_connections: 31
|
|
||||||
landmark_connections: 30
|
|
||||||
landmark_connections: 32
|
|
||||||
landmark_connections: 27
|
|
||||||
landmark_connections: 31
|
|
||||||
landmark_connections: 28
|
|
||||||
landmark_connections: 32
|
|
||||||
|
|
||||||
landmark_color { r: 255 g: 255 b: 255 }
|
|
||||||
connection_color { r: 255 g: 255 b: 255 }
|
|
||||||
thickness: 3.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Take left pose landmarks.
|
|
||||||
node {
|
|
||||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
|
||||||
input_stream: "pose_landmarks"
|
|
||||||
output_stream: "landmarks_left_side"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
|
||||||
ranges: { begin: 1 end: 4 }
|
|
||||||
ranges: { begin: 7 end: 8 }
|
|
||||||
ranges: { begin: 9 end: 10 }
|
|
||||||
ranges: { begin: 11 end: 12 }
|
|
||||||
ranges: { begin: 13 end: 14 }
|
|
||||||
ranges: { begin: 15 end: 16 }
|
|
||||||
ranges: { begin: 17 end: 18 }
|
|
||||||
ranges: { begin: 19 end: 20 }
|
|
||||||
ranges: { begin: 21 end: 22 }
|
|
||||||
ranges: { begin: 23 end: 24 }
|
|
||||||
|
|
||||||
combine_outputs: true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Take right pose landmarks.
|
|
||||||
node {
|
|
||||||
calculator: "SplitNormalizedLandmarkListCalculator"
|
|
||||||
input_stream: "pose_landmarks"
|
|
||||||
output_stream: "landmarks_right_side"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
|
|
||||||
ranges: { begin: 4 end: 7 }
|
|
||||||
ranges: { begin: 8 end: 9 }
|
|
||||||
ranges: { begin: 10 end: 11 }
|
|
||||||
ranges: { begin: 12 end: 13 }
|
|
||||||
ranges: { begin: 14 end: 15 }
|
|
||||||
ranges: { begin: 16 end: 17 }
|
|
||||||
ranges: { begin: 18 end: 19 }
|
|
||||||
ranges: { begin: 20 end: 21 }
|
|
||||||
ranges: { begin: 22 end: 23 }
|
|
||||||
ranges: { begin: 24 end: 25 }
|
|
||||||
|
|
||||||
combine_outputs: true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Render pose joints as big white circles.
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:visible_pose_landmarks"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_background_joints_render_data"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_color { r: 255 g: 255 b: 255 }
|
|
||||||
connection_color { r: 255 g: 255 b: 255 }
|
|
||||||
thickness: 5.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Render pose left side joints as orange circles (inside white ones).
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:landmarks_left_side"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_left_joints_render_data"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_color { r: 255 g: 138 b: 0 }
|
|
||||||
connection_color { r: 255 g: 138 b: 0 }
|
|
||||||
thickness: 3.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Render pose right side joints as cyan circles (inside white ones).
|
|
||||||
node {
|
|
||||||
calculator: "LandmarksToRenderDataCalculator"
|
|
||||||
input_stream: "NORM_LANDMARKS:landmarks_right_side"
|
|
||||||
input_stream: "RENDER_SCALE:render_scale"
|
|
||||||
output_stream: "RENDER_DATA:landmarks_right_joints_render_data"
|
|
||||||
node_options: {
|
|
||||||
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
|
|
||||||
landmark_color { r: 0 g: 217 b: 231 }
|
|
||||||
connection_color { r: 0 g: 217 b: 231 }
|
|
||||||
thickness: 3.0
|
|
||||||
visualize_landmark_depth: false
|
|
||||||
utilize_visibility: true
|
|
||||||
visibility_threshold: 0.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Converts normalized rects to drawing primitives for annotation overlay.
|
# Converts normalized rects to drawing primitives for annotation overlay.
|
||||||
|
@ -283,10 +79,7 @@ node {
|
||||||
calculator: "AnnotationOverlayCalculator"
|
calculator: "AnnotationOverlayCalculator"
|
||||||
input_stream: "IMAGE_GPU:segmented_image"
|
input_stream: "IMAGE_GPU:segmented_image"
|
||||||
input_stream: "detection_render_data"
|
input_stream: "detection_render_data"
|
||||||
input_stream: "landmarks_render_data"
|
input_stream: "VECTOR:landmarks_render_data"
|
||||||
input_stream: "landmarks_background_joints_render_data"
|
|
||||||
input_stream: "landmarks_left_joints_render_data"
|
|
||||||
input_stream: "landmarks_right_joints_render_data"
|
|
||||||
input_stream: "roi_render_data"
|
input_stream: "roi_render_data"
|
||||||
output_stream: "IMAGE_GPU:output_image"
|
output_stream: "IMAGE_GPU:output_image"
|
||||||
}
|
}
|
||||||
|
|
|
@ -174,6 +174,14 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
||||||
thread.setRotation(rotation);
|
thread.setRotation(rotation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets whether the timestamps of each frame should be adjusted to be always monotonically
|
||||||
|
* increasing. The default behavior is that this is {@code true}.
|
||||||
|
*/
|
||||||
|
public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) {
|
||||||
|
thread.setShouldAdjustTimestamps(shouldAdjustTimestamps);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets an offset that can be used to adjust the timestamps on the camera frames, for example to
|
* Sets an offset that can be used to adjust the timestamps on the camera frames, for example to
|
||||||
* conform to a preferred time-base or to account for a known device latency. The offset is added
|
* conform to a preferred time-base or to account for a known device latency. The offset is added
|
||||||
|
@ -298,6 +306,7 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
||||||
private int bufferPoolMaxSize;
|
private int bufferPoolMaxSize;
|
||||||
|
|
||||||
private ExternalTextureRenderer renderer = null;
|
private ExternalTextureRenderer renderer = null;
|
||||||
|
private boolean shouldAdjustTimestamps = true;
|
||||||
private long nextFrameTimestampOffset = 0;
|
private long nextFrameTimestampOffset = 0;
|
||||||
private long timestampOffsetNanos = 0;
|
private long timestampOffsetNanos = 0;
|
||||||
private long previousTimestamp = 0;
|
private long previousTimestamp = 0;
|
||||||
|
@ -433,6 +442,10 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
||||||
super.releaseGl(); // This releases the EGL context, so must do it after any GL calls.
|
super.releaseGl(); // This releases the EGL context, so must do it after any GL calls.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) {
|
||||||
|
this.shouldAdjustTimestamps = shouldAdjustTimestamps;
|
||||||
|
}
|
||||||
|
|
||||||
public void setTimestampOffsetNanos(long offsetInNanos) {
|
public void setTimestampOffsetNanos(long offsetInNanos) {
|
||||||
timestampOffsetNanos = offsetInNanos;
|
timestampOffsetNanos = offsetInNanos;
|
||||||
}
|
}
|
||||||
|
@ -565,7 +578,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
|
||||||
// |nextFrameTimestampOffset| to ensure that timestamps increase monotonically.)
|
// |nextFrameTimestampOffset| to ensure that timestamps increase monotonically.)
|
||||||
long textureTimestamp =
|
long textureTimestamp =
|
||||||
(surfaceTexture.getTimestamp() + timestampOffsetNanos) / NANOS_PER_MICRO;
|
(surfaceTexture.getTimestamp() + timestampOffsetNanos) / NANOS_PER_MICRO;
|
||||||
if (previousTimestampValid
|
if (shouldAdjustTimestamps
|
||||||
|
&& previousTimestampValid
|
||||||
&& textureTimestamp + nextFrameTimestampOffset <= previousTimestamp) {
|
&& textureTimestamp + nextFrameTimestampOffset <= previousTimestamp) {
|
||||||
nextFrameTimestampOffset = previousTimestamp + 1 - textureTimestamp;
|
nextFrameTimestampOffset = previousTimestamp + 1 - textureTimestamp;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,10 @@
|
||||||
package com.google.mediapipe.framework;
|
package com.google.mediapipe.framework;
|
||||||
|
|
||||||
import android.graphics.Bitmap;
|
import android.graphics.Bitmap;
|
||||||
|
import com.google.mediapipe.framework.image.BitmapExtractor;
|
||||||
|
import com.google.mediapipe.framework.image.ByteBufferExtractor;
|
||||||
|
import com.google.mediapipe.framework.image.Image;
|
||||||
|
import com.google.mediapipe.framework.image.ImageProperties;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
// TODO: use Preconditions in this file.
|
// TODO: use Preconditions in this file.
|
||||||
|
@ -55,6 +59,50 @@ public class AndroidPacketCreator extends PacketCreator {
|
||||||
return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
|
return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an Image packet from an {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
|
||||||
|
*/
|
||||||
|
public Packet createImage(Image image) {
|
||||||
|
// TODO: Choose the best storage from multiple containers.
|
||||||
|
ImageProperties properties = image.getContainedImageProperties().get(0);
|
||||||
|
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
|
||||||
|
ByteBuffer buffer = ByteBufferExtractor.extract(image);
|
||||||
|
int numChannels = 0;
|
||||||
|
switch (properties.getImageFormat()) {
|
||||||
|
case Image.IMAGE_FORMAT_RGBA:
|
||||||
|
numChannels = 4;
|
||||||
|
break;
|
||||||
|
case Image.IMAGE_FORMAT_RGB:
|
||||||
|
numChannels = 3;
|
||||||
|
break;
|
||||||
|
case Image.IMAGE_FORMAT_ALPHA:
|
||||||
|
numChannels = 1;
|
||||||
|
break;
|
||||||
|
default: // fall out
|
||||||
|
}
|
||||||
|
if (numChannels == 0) {
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"Unsupported MediaPipe Image image format: " + properties.getImageFormat());
|
||||||
|
}
|
||||||
|
int width = image.getWidth();
|
||||||
|
int height = image.getHeight();
|
||||||
|
return createImage(buffer, width, height, numChannels);
|
||||||
|
}
|
||||||
|
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
|
||||||
|
Bitmap bitmap = BitmapExtractor.extract(image);
|
||||||
|
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
|
||||||
|
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
|
||||||
|
}
|
||||||
|
return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsupported type.
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"Unsupported Image container type: " + properties.getImageFormat());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on
|
* Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on
|
||||||
* failure.
|
* failure.
|
||||||
|
|
|
@ -57,6 +57,7 @@ android_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":android_core",
|
":android_core",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//third_party:androidx_annotation",
|
"//third_party:androidx_annotation",
|
||||||
"//third_party:androidx_legacy_support_v4",
|
"//third_party:androidx_legacy_support_v4",
|
||||||
"@maven//:com_google_code_findbugs_jsr305",
|
"@maven//:com_google_code_findbugs_jsr305",
|
||||||
|
@ -75,6 +76,7 @@ android_library(
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
["**/*.java"],
|
["**/*.java"],
|
||||||
exclude = [
|
exclude = [
|
||||||
|
"image/**",
|
||||||
"Android*",
|
"Android*",
|
||||||
"AssetCache.java",
|
"AssetCache.java",
|
||||||
"AssetCacheDbHelper.java",
|
"AssetCacheDbHelper.java",
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.framework.image">
|
||||||
|
<uses-sdk android:minSdkVersion="16" />
|
||||||
|
<application />
|
||||||
|
</manifest>
|
32
mediapipe/java/com/google/mediapipe/framework/image/BUILD
Normal file
32
mediapipe/java/com/google/mediapipe/framework/image/BUILD
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "image",
|
||||||
|
srcs = glob(["*.java"]),
|
||||||
|
manifest = "AndroidManifest.xml",
|
||||||
|
visibility = [
|
||||||
|
"//mediapipe:__subpackages__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//third_party:androidx_legacy_support_v4",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:androidx_annotation_annotation",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,49 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import android.graphics.Bitmap;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
|
||||||
|
* {@link IllegalArgumentException} will be thrown.
|
||||||
|
*/
|
||||||
|
public final class BitmapExtractor {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
|
||||||
|
*
|
||||||
|
* @param image the image to extract {@link android.graphics.Bitmap} from.
|
||||||
|
* @return the {@link android.graphics.Bitmap} stored in {@link Image}
|
||||||
|
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
|
||||||
|
* conversions.
|
||||||
|
*/
|
||||||
|
public static Bitmap extract(Image image) {
|
||||||
|
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
|
||||||
|
if (imageContainer != null) {
|
||||||
|
return ((BitmapImageContainer) imageContainer).getBitmap();
|
||||||
|
} else {
|
||||||
|
// TODO: Support ByteBuffer -> Bitmap conversion.
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Extracting Bitmap from an Image created by objects other than Bitmap is not"
|
||||||
|
+ " supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private BitmapExtractor() {}
|
||||||
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.graphics.Bitmap;
|
||||||
|
import android.net.Uri;
|
||||||
|
import android.provider.MediaStore;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds {@link Image} from {@link android.graphics.Bitmap}.
|
||||||
|
*
|
||||||
|
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
|
||||||
|
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
|
||||||
|
* in it.
|
||||||
|
*
|
||||||
|
* <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in.
|
||||||
|
*/
|
||||||
|
public class BitmapImageBuilder {
|
||||||
|
|
||||||
|
// Mandatory fields.
|
||||||
|
private final Bitmap bitmap;
|
||||||
|
|
||||||
|
// Optional fields.
|
||||||
|
private long timestamp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the builder with a mandatory {@link android.graphics.Bitmap}.
|
||||||
|
*
|
||||||
|
* @param bitmap image data object.
|
||||||
|
*/
|
||||||
|
public BitmapImageBuilder(Bitmap bitmap) {
|
||||||
|
this.bitmap = bitmap;
|
||||||
|
timestamp = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the builder to build {@link Image} from a file.
|
||||||
|
*
|
||||||
|
* @param context the application context.
|
||||||
|
* @param uri the path to the resource file.
|
||||||
|
*/
|
||||||
|
public BitmapImageBuilder(Context context, Uri uri) throws IOException {
|
||||||
|
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sets value for {@link Image#getTimestamp()}. */
|
||||||
|
BitmapImageBuilder setTimestamp(long timestamp) {
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Builds an {@link Image} instance. */
|
||||||
|
public Image build() {
|
||||||
|
return new Image(
|
||||||
|
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import android.graphics.Bitmap;
|
||||||
|
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||||
|
|
||||||
|
class BitmapImageContainer implements ImageContainer {
|
||||||
|
|
||||||
|
private final Bitmap bitmap;
|
||||||
|
private final ImageProperties properties;
|
||||||
|
|
||||||
|
public BitmapImageContainer(Bitmap bitmap) {
|
||||||
|
this.bitmap = bitmap;
|
||||||
|
this.properties =
|
||||||
|
ImageProperties.builder()
|
||||||
|
.setImageFormat(convertFormatCode(bitmap.getConfig()))
|
||||||
|
.setStorageType(Image.STORAGE_TYPE_BITMAP)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Bitmap getBitmap() {
|
||||||
|
return bitmap;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ImageProperties getImageProperties() {
|
||||||
|
return properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
bitmap.recycle();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ImageFormat
|
||||||
|
static int convertFormatCode(Bitmap.Config config) {
|
||||||
|
switch (config) {
|
||||||
|
case ALPHA_8:
|
||||||
|
return Image.IMAGE_FORMAT_ALPHA;
|
||||||
|
case ARGB_8888:
|
||||||
|
return Image.IMAGE_FORMAT_RGBA;
|
||||||
|
default:
|
||||||
|
return Image.IMAGE_FORMAT_UNKNOWN;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,254 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import android.annotation.SuppressLint;
|
||||||
|
import android.graphics.Bitmap;
|
||||||
|
import android.graphics.Bitmap.Config;
|
||||||
|
import android.os.Build.VERSION;
|
||||||
|
import android.os.Build.VERSION_CODES;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility for extracting {@link ByteBuffer} from {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
|
||||||
|
* {@link IllegalArgumentException} will be thrown.
|
||||||
|
*/
|
||||||
|
public class ByteBufferExtractor {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts a {@link ByteBuffer} from an {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
|
||||||
|
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
|
||||||
|
*
|
||||||
|
* @see Image#getContainedImageProperties()
|
||||||
|
* @return A read-only {@link ByteBuffer}.
|
||||||
|
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
|
||||||
|
*/
|
||||||
|
@SuppressLint("SwitchIntDef")
|
||||||
|
public static ByteBuffer extract(Image image) {
|
||||||
|
ImageContainer container = image.getContainer();
|
||||||
|
switch (container.getImageProperties().getStorageType()) {
|
||||||
|
case Image.STORAGE_TYPE_BYTEBUFFER:
|
||||||
|
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||||
|
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
|
||||||
|
+ " supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>Format conversion spec:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>When extracting RGB images to RGBA format, A channel will always set to 255.
|
||||||
|
* <li>When extracting RGBA images to RGB format, A channel will be dropped.
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image the image to extract buffer from.
|
||||||
|
* @param targetFormat the image format of the result bytebuffer.
|
||||||
|
* @return the readonly {@link ByteBuffer} stored in {@link Image}
|
||||||
|
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
|
||||||
|
* conversions.
|
||||||
|
*/
|
||||||
|
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
|
||||||
|
ImageContainer container;
|
||||||
|
ImageProperties byteBufferProperties =
|
||||||
|
ImageProperties.builder()
|
||||||
|
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
|
||||||
|
.setImageFormat(targetFormat)
|
||||||
|
.build();
|
||||||
|
if ((container = image.getContainer(byteBufferProperties)) != null) {
|
||||||
|
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||||
|
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||||
|
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||||
|
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||||
|
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
|
||||||
|
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
|
||||||
|
.asReadOnlyBuffer();
|
||||||
|
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
|
||||||
|
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
|
||||||
|
ByteBuffer byteBuffer =
|
||||||
|
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
|
||||||
|
.asReadOnlyBuffer();
|
||||||
|
boolean unused = image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
|
||||||
|
return byteBuffer;
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Extracting ByteBuffer from an Image created by objects other than Bitmap or"
|
||||||
|
+ " Bytebuffer is not supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
|
||||||
|
@AutoValue
|
||||||
|
abstract static class Result {
|
||||||
|
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
|
||||||
|
public abstract ByteBuffer buffer();
|
||||||
|
|
||||||
|
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
|
||||||
|
@ImageFormat
|
||||||
|
public abstract int format();
|
||||||
|
|
||||||
|
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
|
||||||
|
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
|
||||||
|
*
|
||||||
|
* @return the readonly {@link ByteBuffer} stored in {@link Image}
|
||||||
|
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
|
||||||
|
* given {@code imageFormat}
|
||||||
|
*/
|
||||||
|
static Result extractInRecommendedFormat(Image image) {
|
||||||
|
ImageContainer container;
|
||||||
|
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
|
||||||
|
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
|
||||||
|
@ImageFormat int format = adviseImageFormat(bitmap);
|
||||||
|
Result result =
|
||||||
|
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
|
||||||
|
|
||||||
|
boolean unused =
|
||||||
|
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
|
||||||
|
return result;
|
||||||
|
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||||
|
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||||
|
return Result.create(
|
||||||
|
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
|
||||||
|
byteBufferImageContainer.getImageFormat());
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
|
||||||
|
+ " is not supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ImageFormat
|
||||||
|
private static int adviseImageFormat(Bitmap bitmap) {
|
||||||
|
if (bitmap.getConfig() == Config.ARGB_8888) {
|
||||||
|
return Image.IMAGE_FORMAT_RGBA;
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
|
||||||
|
+ " supported",
|
||||||
|
bitmap.getConfig()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ByteBuffer extractByteBufferFromBitmap(
|
||||||
|
Bitmap bitmap, @ImageFormat int imageFormat) {
|
||||||
|
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
|
||||||
|
+ " supported");
|
||||||
|
}
|
||||||
|
if (bitmap.getConfig() == Config.ARGB_8888) {
|
||||||
|
if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
|
||||||
|
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
|
||||||
|
bitmap.copyPixelsToBuffer(buffer);
|
||||||
|
buffer.rewind();
|
||||||
|
return buffer;
|
||||||
|
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
|
||||||
|
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
|
||||||
|
int w = bitmap.getWidth();
|
||||||
|
int h = bitmap.getHeight();
|
||||||
|
int[] pixels = new int[w * h];
|
||||||
|
bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
|
||||||
|
ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
|
||||||
|
buffer.order(ByteOrder.nativeOrder());
|
||||||
|
for (int pixel : pixels) {
|
||||||
|
// getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA
|
||||||
|
buffer.put((byte) ((pixel >> 16) & 0xff));
|
||||||
|
buffer.put((byte) ((pixel >> 8) & 0xff));
|
||||||
|
buffer.put((byte) (pixel & 0xff));
|
||||||
|
}
|
||||||
|
buffer.rewind();
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
|
||||||
|
+ " %d is not supported",
|
||||||
|
bitmap.getConfig(), imageFormat));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ByteBuffer convertByteBuffer(
|
||||||
|
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
|
||||||
|
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
|
||||||
|
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
|
||||||
|
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the
|
||||||
|
// array reversely to convert in-place.
|
||||||
|
byte[] array = new byte[target.capacity()];
|
||||||
|
source.get(array, 0, source.capacity());
|
||||||
|
source.rewind();
|
||||||
|
int rgbCursor = source.capacity();
|
||||||
|
int rgbaCursor = target.capacity();
|
||||||
|
while (rgbCursor != rgbaCursor) {
|
||||||
|
array[--rgbaCursor] = (byte) 0xff; // A
|
||||||
|
array[--rgbaCursor] = array[--rgbCursor]; // B
|
||||||
|
array[--rgbaCursor] = array[--rgbCursor]; // G
|
||||||
|
array[--rgbaCursor] = array[--rgbCursor]; // R
|
||||||
|
}
|
||||||
|
target.put(array, 0, target.capacity());
|
||||||
|
target.rewind();
|
||||||
|
return target;
|
||||||
|
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
|
||||||
|
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
|
||||||
|
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
|
||||||
|
// array to convert in-place.
|
||||||
|
byte[] array = new byte[source.capacity()];
|
||||||
|
source.get(array, 0, source.capacity());
|
||||||
|
source.rewind();
|
||||||
|
int rgbaCursor = 0;
|
||||||
|
int rgbCursor = 0;
|
||||||
|
while (rgbaCursor < array.length) {
|
||||||
|
array[rgbCursor++] = array[rgbaCursor++]; // R
|
||||||
|
array[rgbCursor++] = array[rgbaCursor++]; // G
|
||||||
|
array[rgbCursor++] = array[rgbaCursor++]; // B
|
||||||
|
rgbaCursor++;
|
||||||
|
}
|
||||||
|
target.put(array, 0, target.capacity());
|
||||||
|
target.rewind();
|
||||||
|
return target;
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
Locale.ENGLISH,
|
||||||
|
"Convert bytebuffer image format from %d to %d is not supported",
|
||||||
|
sourceFormat,
|
||||||
|
targetFormat));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByteBuffer is not able to be instantiated.
|
||||||
|
private ByteBufferExtractor() {}
|
||||||
|
}
|
|
@ -0,0 +1,71 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds a {@link Image} from a {@link ByteBuffer}.
|
||||||
|
*
|
||||||
|
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
|
||||||
|
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
|
||||||
|
*
|
||||||
|
* <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in.
|
||||||
|
*/
|
||||||
|
public class ByteBufferImageBuilder {
|
||||||
|
|
||||||
|
// Mandatory fields.
|
||||||
|
private final ByteBuffer buffer;
|
||||||
|
private final int width;
|
||||||
|
private final int height;
|
||||||
|
@ImageFormat private final int imageFormat;
|
||||||
|
|
||||||
|
// Optional fields.
|
||||||
|
private long timestamp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the builder with mandatory {@link ByteBuffer} and the represented image.
|
||||||
|
*
|
||||||
|
* <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height}
|
||||||
|
* and {@code imageFormat}.
|
||||||
|
*
|
||||||
|
* @param byteBuffer image data object.
|
||||||
|
* @param width the width of the represented image.
|
||||||
|
* @param height the height of the represented image.
|
||||||
|
* @param imageFormat how the data encode the image.
|
||||||
|
*/
|
||||||
|
public ByteBufferImageBuilder(
|
||||||
|
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
|
||||||
|
this.buffer = byteBuffer;
|
||||||
|
this.width = width;
|
||||||
|
this.height = height;
|
||||||
|
this.imageFormat = imageFormat;
|
||||||
|
// TODO: Validate bytebuffer size with width, height and image format
|
||||||
|
this.timestamp = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sets value for {@link Image#getTimestamp()}. */
|
||||||
|
ByteBufferImageBuilder setTimestamp(long timestamp) {
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Builds an {@link Image} instance. */
|
||||||
|
public Image build() {
|
||||||
|
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
|
class ByteBufferImageContainer implements ImageContainer {
|
||||||
|
|
||||||
|
private final ByteBuffer buffer;
|
||||||
|
private final ImageProperties properties;
|
||||||
|
|
||||||
|
public ByteBufferImageContainer(
|
||||||
|
ByteBuffer buffer,
|
||||||
|
@ImageFormat int imageFormat) {
|
||||||
|
this.buffer = buffer;
|
||||||
|
this.properties =
|
||||||
|
ImageProperties.builder()
|
||||||
|
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
|
||||||
|
.setImageFormat(imageFormat)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public ByteBuffer getByteBuffer() {
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ImageProperties getImageProperties() {
|
||||||
|
return properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the image format.
|
||||||
|
*/
|
||||||
|
@ImageFormat
|
||||||
|
public int getImageFormat() {
|
||||||
|
return properties.getImageFormat();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
// No op for ByteBuffer.
|
||||||
|
}
|
||||||
|
}
|
241
mediapipe/java/com/google/mediapipe/framework/image/Image.java
Normal file
241
mediapipe/java/com/google/mediapipe/framework/image/Image.java
Normal file
|
@ -0,0 +1,241 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import androidx.annotation.IntDef;
|
||||||
|
import androidx.annotation.Nullable;
|
||||||
|
import java.io.Closeable;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The wrapper class for image objects.
|
||||||
|
*
|
||||||
|
* <p>{@link Image} is designed to be an immutable image container, which could be shared
|
||||||
|
* cross-platforms.
|
||||||
|
*
|
||||||
|
* <p>To construct an {@link Image}, use the provided builders:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link ByteBufferImageBuilder}
|
||||||
|
* <li>{@link BitmapImageBuilder}
|
||||||
|
* <li>{@link MediaImageBuilder}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
|
||||||
|
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release
|
||||||
|
* internal storage earlier, otherwise Java garbage collection will release the storage eventually.
|
||||||
|
*
|
||||||
|
* <p>To extract concrete image, first check {@link StorageType} and then use the provided
|
||||||
|
* extractors:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link ByteBufferExtractor}
|
||||||
|
* <li>{@link BitmapExtractor}
|
||||||
|
* <li>{@link MediaImageExtractor}
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public class Image implements Closeable {
|
||||||
|
|
||||||
|
/** Specifies the image format of an image. */
|
||||||
|
@IntDef({
|
||||||
|
IMAGE_FORMAT_UNKNOWN,
|
||||||
|
IMAGE_FORMAT_RGBA,
|
||||||
|
IMAGE_FORMAT_RGB,
|
||||||
|
IMAGE_FORMAT_NV12,
|
||||||
|
IMAGE_FORMAT_NV21,
|
||||||
|
IMAGE_FORMAT_YV12,
|
||||||
|
IMAGE_FORMAT_YV21,
|
||||||
|
IMAGE_FORMAT_YUV_420_888,
|
||||||
|
IMAGE_FORMAT_ALPHA,
|
||||||
|
IMAGE_FORMAT_JPEG,
|
||||||
|
})
|
||||||
|
@Retention(RetentionPolicy.SOURCE)
|
||||||
|
public @interface ImageFormat {}
|
||||||
|
|
||||||
|
public static final int IMAGE_FORMAT_UNKNOWN = 0;
|
||||||
|
public static final int IMAGE_FORMAT_RGBA = 1;
|
||||||
|
public static final int IMAGE_FORMAT_RGB = 2;
|
||||||
|
public static final int IMAGE_FORMAT_NV12 = 3;
|
||||||
|
public static final int IMAGE_FORMAT_NV21 = 4;
|
||||||
|
public static final int IMAGE_FORMAT_YV12 = 5;
|
||||||
|
public static final int IMAGE_FORMAT_YV21 = 6;
|
||||||
|
public static final int IMAGE_FORMAT_YUV_420_888 = 7;
|
||||||
|
public static final int IMAGE_FORMAT_ALPHA = 8;
|
||||||
|
public static final int IMAGE_FORMAT_JPEG = 9;
|
||||||
|
|
||||||
|
/** Specifies the image container type. Would be useful for choosing extractors. */
|
||||||
|
@IntDef({
|
||||||
|
STORAGE_TYPE_BITMAP,
|
||||||
|
STORAGE_TYPE_BYTEBUFFER,
|
||||||
|
STORAGE_TYPE_MEDIA_IMAGE,
|
||||||
|
STORAGE_TYPE_IMAGE_PROXY,
|
||||||
|
})
|
||||||
|
@Retention(RetentionPolicy.SOURCE)
|
||||||
|
public @interface StorageType {}
|
||||||
|
|
||||||
|
public static final int STORAGE_TYPE_BITMAP = 1;
|
||||||
|
public static final int STORAGE_TYPE_BYTEBUFFER = 2;
|
||||||
|
public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
|
||||||
|
public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a list of supported image properties for this {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>Currently {@link Image} only support single storage type so the size of return list will
|
||||||
|
* always be 1.
|
||||||
|
*
|
||||||
|
* @see ImageProperties
|
||||||
|
*/
|
||||||
|
public List<ImageProperties> getContainedImageProperties() {
|
||||||
|
return Collections.singletonList(getContainer().getImageProperties());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the timestamp attached to the image. */
|
||||||
|
long getTimestamp() {
|
||||||
|
return timestamp;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the width of the image. */
|
||||||
|
public int getWidth() {
|
||||||
|
return width;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the height of the image. */
|
||||||
|
public int getHeight() {
|
||||||
|
return height;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
|
||||||
|
private synchronized void acquire() {
|
||||||
|
referenceCount += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes a reference that was previously acquired or init.
|
||||||
|
*
|
||||||
|
* <p>When {@link Image} is created, it has 1 reference count.
|
||||||
|
*
|
||||||
|
* <p>When the reference count becomes 0, it will release the resource under the hood.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
// TODO: Create an internal flag to indicate image is closed, or use referenceCount
|
||||||
|
public synchronized void close() {
|
||||||
|
referenceCount -= 1;
|
||||||
|
if (referenceCount == 0) {
|
||||||
|
for (ImageContainer imageContainer : containerMap.values()) {
|
||||||
|
imageContainer.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Advanced API access for {@link Image}. */
|
||||||
|
static final class Internal {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Acquires a reference on this {@link Image}. This will increase the reference count by 1.
|
||||||
|
*
|
||||||
|
* <p>This method is more useful for image consumer to acquire a reference so image resource
|
||||||
|
* will not be closed accidentally. As image creator, normal developer doesn't need to call this
|
||||||
|
* method.
|
||||||
|
*
|
||||||
|
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
|
||||||
|
* #close()} to indicate it doesn't need this {@link Image} anymore.
|
||||||
|
*
|
||||||
|
* @see #close()
|
||||||
|
*/
|
||||||
|
void acquire() {
|
||||||
|
image.acquire();
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Image image;
|
||||||
|
|
||||||
|
// Only Image creates the internal helper.
|
||||||
|
private Internal(Image image) {
|
||||||
|
this.image = image;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets {@link Internal} object which contains internal APIs. */
|
||||||
|
Internal getInternal() {
|
||||||
|
return new Internal(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Map<ImageProperties, ImageContainer> containerMap;
|
||||||
|
private final long timestamp;
|
||||||
|
private final int width;
|
||||||
|
private final int height;
|
||||||
|
|
||||||
|
private int referenceCount;
|
||||||
|
|
||||||
|
/** Constructs an {@link Image} with a built container. */
|
||||||
|
Image(ImageContainer container, long timestamp, int width, int height) {
|
||||||
|
this.containerMap = new HashMap<>();
|
||||||
|
containerMap.put(container.getImageProperties(), container);
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
this.width = width;
|
||||||
|
this.height = height;
|
||||||
|
this.referenceCount = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets one available container.
|
||||||
|
*
|
||||||
|
* @return the current container.
|
||||||
|
*/
|
||||||
|
ImageContainer getContainer() {
|
||||||
|
// According to the design, in the future we will support multiple containers in one image.
|
||||||
|
// Currently just return the original container.
|
||||||
|
// TODO: Cache multiple containers in Image.
|
||||||
|
return containerMap.values().iterator().next();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets container from required {@code storageType}. Returns {@code null} if not existed.
|
||||||
|
*
|
||||||
|
* <p>If there are multiple containers with required {@code storageType}, returns the first one.
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
ImageContainer getContainer(@StorageType int storageType) {
|
||||||
|
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
|
||||||
|
if (entry.getKey().getStorageType() == storageType) {
|
||||||
|
return entry.getValue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
|
||||||
|
@Nullable
|
||||||
|
ImageContainer getContainer(ImageProperties imageProperties) {
|
||||||
|
return containerMap.get(imageProperties);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
|
||||||
|
boolean addContainer(ImageContainer container) {
|
||||||
|
ImageProperties imageProperties = container.getImageProperties();
|
||||||
|
if (containerMap.containsKey(imageProperties)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
containerMap.put(imageProperties, container);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
/** Lightweight abstraction for an object that can receive {@link Image} */
|
||||||
|
public interface ImageConsumer {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when an {@link Image} is available.
|
||||||
|
*
|
||||||
|
* <p>The argument is only guaranteed to be available until this method returns. if you need to
|
||||||
|
* extend its life time, acquire it, then release it when done.
|
||||||
|
*/
|
||||||
|
void onNewImage(Image image);
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
/** Manages internal image data storage. The interface is package-private. */
|
||||||
|
interface ImageContainer {
|
||||||
|
/** Returns the properties of the contained image. */
|
||||||
|
ImageProperties getImageProperties();
|
||||||
|
|
||||||
|
/** Close the image container and releases the image resource inside. */
|
||||||
|
void close();
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
/** Lightweight abstraction for an object that produce {@link Image} */
|
||||||
|
public interface ImageProducer {
|
||||||
|
|
||||||
|
/** Sets the consumer that receives the {@link Image}. */
|
||||||
|
void setImageConsumer(ImageConsumer imageConsumer);
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.auto.value.extension.memoized.Memoized;
|
||||||
|
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||||
|
import com.google.mediapipe.framework.image.Image.StorageType;
|
||||||
|
|
||||||
|
/** Groups a set of properties to describe how an image is stored. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class ImageProperties {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the pixel format of the image.
|
||||||
|
*
|
||||||
|
* @see Image.ImageFormat
|
||||||
|
*/
|
||||||
|
@ImageFormat
|
||||||
|
public abstract int getImageFormat();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the storage type of the image.
|
||||||
|
*
|
||||||
|
* @see Image.StorageType
|
||||||
|
*/
|
||||||
|
@StorageType
|
||||||
|
public abstract int getStorageType();
|
||||||
|
|
||||||
|
@Memoized
|
||||||
|
@Override
|
||||||
|
public abstract int hashCode();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a builder of {@link ImageProperties}.
|
||||||
|
*
|
||||||
|
* @see ImageProperties.Builder
|
||||||
|
*/
|
||||||
|
static Builder builder() {
|
||||||
|
return new AutoValue_ImageProperties.Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Builds a {@link ImageProperties}. */
|
||||||
|
@AutoValue.Builder
|
||||||
|
abstract static class Builder {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link Image.ImageFormat}.
|
||||||
|
*
|
||||||
|
* @see ImageProperties#getImageFormat
|
||||||
|
*/
|
||||||
|
abstract Builder setImageFormat(@ImageFormat int value);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link Image.StorageType}.
|
||||||
|
*
|
||||||
|
* @see ImageProperties#getStorageType
|
||||||
|
*/
|
||||||
|
abstract Builder setStorageType(@StorageType int value);
|
||||||
|
|
||||||
|
/** Builds the {@link ImageProperties}. */
|
||||||
|
abstract ImageProperties build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hide the constructor.
|
||||||
|
ImageProperties() {}
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import android.os.Build.VERSION_CODES;
|
||||||
|
import androidx.annotation.RequiresApi;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds {@link Image} from {@link android.media.Image}.
|
||||||
|
*
|
||||||
|
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
|
||||||
|
* content in it.
|
||||||
|
*
|
||||||
|
* <p>Use {@link MediaImageExtractor} to get {@link android.media.Image} you passed in.
|
||||||
|
*/
|
||||||
|
@RequiresApi(VERSION_CODES.KITKAT)
|
||||||
|
public class MediaImageBuilder {
|
||||||
|
|
||||||
|
// Mandatory fields.
|
||||||
|
private final android.media.Image mediaImage;
|
||||||
|
|
||||||
|
// Optional fields.
|
||||||
|
private long timestamp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the builder with a mandatory {@link android.media.Image}.
|
||||||
|
*
|
||||||
|
* @param mediaImage image data object.
|
||||||
|
*/
|
||||||
|
public MediaImageBuilder(android.media.Image mediaImage) {
|
||||||
|
this.mediaImage = mediaImage;
|
||||||
|
this.timestamp = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sets value for {@link Image#getTimestamp()}. */
|
||||||
|
MediaImageBuilder setTimestamp(long timestamp) {
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Builds an {@link Image} instance. */
|
||||||
|
public Image build() {
|
||||||
|
return new Image(
|
||||||
|
new MediaImageContainer(mediaImage),
|
||||||
|
timestamp,
|
||||||
|
mediaImage.getWidth(),
|
||||||
|
mediaImage.getHeight());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,73 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import android.os.Build;
|
||||||
|
import android.os.Build.VERSION;
|
||||||
|
import android.os.Build.VERSION_CODES;
|
||||||
|
import androidx.annotation.RequiresApi;
|
||||||
|
import com.google.mediapipe.framework.image.Image.ImageFormat;
|
||||||
|
|
||||||
|
@RequiresApi(VERSION_CODES.KITKAT)
|
||||||
|
class MediaImageContainer implements ImageContainer {
|
||||||
|
|
||||||
|
private final android.media.Image mediaImage;
|
||||||
|
private final ImageProperties properties;
|
||||||
|
|
||||||
|
public MediaImageContainer(android.media.Image mediaImage) {
|
||||||
|
this.mediaImage = mediaImage;
|
||||||
|
this.properties =
|
||||||
|
ImageProperties.builder()
|
||||||
|
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
|
||||||
|
.setImageFormat(convertFormatCode(mediaImage.getFormat()))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public android.media.Image getImage() {
|
||||||
|
return mediaImage;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ImageProperties getImageProperties() {
|
||||||
|
return properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
mediaImage.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ImageFormat
|
||||||
|
static int convertFormatCode(int graphicsFormat) {
|
||||||
|
// We only cover the format mentioned in
|
||||||
|
// https://developer.android.com/reference/android/media/Image#getFormat()
|
||||||
|
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
|
||||||
|
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
|
||||||
|
return Image.IMAGE_FORMAT_RGBA;
|
||||||
|
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
|
||||||
|
return Image.IMAGE_FORMAT_RGB;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch (graphicsFormat) {
|
||||||
|
case android.graphics.ImageFormat.JPEG:
|
||||||
|
return Image.IMAGE_FORMAT_JPEG;
|
||||||
|
case android.graphics.ImageFormat.YUV_420_888:
|
||||||
|
return Image.IMAGE_FORMAT_YUV_420_888;
|
||||||
|
default:
|
||||||
|
return Image.IMAGE_FORMAT_UNKNOWN;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package com.google.mediapipe.framework.image;
|
||||||
|
|
||||||
|
import android.os.Build.VERSION_CODES;
|
||||||
|
import androidx.annotation.RequiresApi;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility for extracting {@link android.media.Image} from {@link Image}.
|
||||||
|
*
|
||||||
|
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
|
||||||
|
* otherwise {@link IllegalArgumentException} will be thrown.
|
||||||
|
*/
|
||||||
|
@RequiresApi(VERSION_CODES.KITKAT)
|
||||||
|
public class MediaImageExtractor {
|
||||||
|
|
||||||
|
private MediaImageExtractor() {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
|
||||||
|
* {@link Image} that built from {@link MediaImageBuilder}.
|
||||||
|
*
|
||||||
|
* @param image the image to extract {@link android.media.Image} from.
|
||||||
|
* @return {@link android.media.Image} that stored in {@link Image}.
|
||||||
|
* @throws IllegalArgumentException if the extraction failed.
|
||||||
|
*/
|
||||||
|
public static android.media.Image extract(Image image) {
|
||||||
|
ImageContainer container;
|
||||||
|
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
|
||||||
|
return ((MediaImageContainer) container).getImage();
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Extract Media Image from an Image created by objects other than Media Image"
|
||||||
|
+ " is not supported");
|
||||||
|
}
|
||||||
|
}
|
|
@ -73,9 +73,14 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
|
||||||
// TODO: get the graph's main context from the packet context?
|
// TODO: get the graph's main context from the packet context?
|
||||||
// Or clean up in some other way?
|
// Or clean up in some other way?
|
||||||
if (context_for_deletion) {
|
if (context_for_deletion) {
|
||||||
token = new mediapipe::GlSyncToken(
|
auto sync = mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext(
|
||||||
mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext(
|
context_for_deletion);
|
||||||
context_for_deletion));
|
// A Java handle to a token is a raw pointer to a std::shared_ptr on the
|
||||||
|
// heap, cast to a long. If the shared_ptr itself is null, leave the token
|
||||||
|
// null too.
|
||||||
|
if (sync) {
|
||||||
|
token = new mediapipe::GlSyncToken(std::move(sync));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return reinterpret_cast<jlong>(token);
|
return reinterpret_cast<jlong>(token);
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,6 +145,7 @@ EOF
|
||||||
"//mediapipe/java/com/google/mediapipe/components:android_components",
|
"//mediapipe/java/com/google/mediapipe/components:android_components",
|
||||||
"//mediapipe/java/com/google/mediapipe/components:android_camerax_helper",
|
"//mediapipe/java/com/google/mediapipe/components:android_camerax_helper",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/java/com/google/mediapipe/glutil",
|
"//mediapipe/java/com/google/mediapipe/glutil",
|
||||||
"//third_party:androidx_annotation",
|
"//third_party:androidx_annotation",
|
||||||
"//third_party:androidx_appcompat",
|
"//third_party:androidx_appcompat",
|
||||||
|
|
|
@ -76,7 +76,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
||||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||||
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
||||||
auto classifier_options_proto = std::make_unique<tasks::ClassifierOptions>(
|
auto classifier_options_proto =
|
||||||
|
std::make_unique<tasks::components::proto::ClassifierOptions>(
|
||||||
components::ConvertClassifierOptionsToProto(
|
components::ConvertClassifierOptionsToProto(
|
||||||
&(options->classifier_options)));
|
&(options->classifier_options)));
|
||||||
options_proto->mutable_classifier_options()->Swap(
|
options_proto->mutable_classifier_options()->Swap(
|
||||||
|
|
|
@ -136,6 +136,11 @@ void ConfigureAudioToTensorCalculator(
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
|
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
|
||||||
// {
|
// {
|
||||||
|
// base_options {
|
||||||
|
// model_asset {
|
||||||
|
// file_name: "/path/to/model.tflite"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
// max_results: 4
|
// max_results: 4
|
||||||
// score_threshold: 0.5
|
// score_threshold: 0.5
|
||||||
// category_allowlist: "foo"
|
// category_allowlist: "foo"
|
||||||
|
@ -225,15 +230,17 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
||||||
|
|
||||||
// Adds inference subgraph and connects its input stream to the output
|
// Adds inference subgraph and connects its input stream to the output
|
||||||
// tensors produced by the AudioToTensorCalculator.
|
// tensors produced by the AudioToTensorCalculator.
|
||||||
auto& inference = AddInference(model_resources, graph);
|
auto& inference = AddInference(
|
||||||
|
model_resources, task_options.base_options().acceleration(), graph);
|
||||||
audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag);
|
audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag);
|
||||||
|
|
||||||
// Adds postprocessing calculators and connects them to the graph output.
|
// Adds postprocessing calculators and connects them to the graph output.
|
||||||
auto& postprocessing =
|
auto& postprocessing = graph.AddNode(
|
||||||
graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph");
|
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||||
model_resources, task_options.classifier_options(),
|
model_resources, task_options.classifier_options(),
|
||||||
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
&postprocessing.GetOptions<
|
||||||
|
tasks::components::ClassificationPostprocessingOptions>()));
|
||||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
|
||||||
// Time aggregation is only needed for performing audio classification on
|
// Time aggregation is only needed for performing audio classification on
|
||||||
|
|
|
@ -37,7 +37,6 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
@ -168,7 +167,7 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||||
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->classifier_options.max_results = 3;
|
options->classifier_options.max_results = 3;
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -192,7 +191,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
|
TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->classifier_options.max_results = 0;
|
options->classifier_options.max_results = 0;
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
|
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
|
||||||
AudioClassifier::Create(std::move(options));
|
AudioClassifier::Create(std::move(options));
|
||||||
|
@ -208,7 +207,7 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) {
|
TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
options->classifier_options.category_allowlist.push_back("foo");
|
options->classifier_options.category_allowlist.push_back("foo");
|
||||||
options->classifier_options.category_denylist.push_back("bar");
|
options->classifier_options.category_denylist.push_back("bar");
|
||||||
|
@ -226,7 +225,7 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) {
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) {
|
TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
||||||
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
|
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
|
||||||
AudioClassifier::Create(std::move(options));
|
AudioClassifier::Create(std::move(options));
|
||||||
|
@ -242,7 +241,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) {
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
|
TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
||||||
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||||
options->sample_rate = 16000;
|
options->sample_rate = 16000;
|
||||||
|
@ -260,7 +259,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
|
TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
||||||
options->result_callback =
|
options->result_callback =
|
||||||
[](absl::StatusOr<ClassificationResult> status_or_result) {};
|
[](absl::StatusOr<ClassificationResult> status_or_result) {};
|
||||||
|
@ -279,7 +278,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
|
TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
||||||
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||||
options->result_callback =
|
options->result_callback =
|
||||||
|
@ -301,7 +300,7 @@ class ClassifyTest : public tflite_shims::testing::Test {};
|
||||||
TEST_F(ClassifyTest, Succeeds) {
|
TEST_F(ClassifyTest, Succeeds) {
|
||||||
auto audio_buffer = GetAudioData(k16kTestWavFilename);
|
auto audio_buffer = GetAudioData(k16kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -315,7 +314,7 @@ TEST_F(ClassifyTest, Succeeds) {
|
||||||
TEST_F(ClassifyTest, SucceedsWithResampling) {
|
TEST_F(ClassifyTest, SucceedsWithResampling) {
|
||||||
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -330,7 +329,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
|
||||||
auto audio_buffer_16k_hz = GetAudioData(k16kTestWavFilename);
|
auto audio_buffer_16k_hz = GetAudioData(k16kTestWavFilename);
|
||||||
auto audio_buffer_48k_hz = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer_48k_hz = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -349,7 +348,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
|
||||||
|
|
||||||
TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
|
TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -374,7 +373,7 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
|
||||||
TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
|
TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
|
||||||
auto audio_buffer = GetAudioData(k16kTestWavForTwoHeadsFilename);
|
auto audio_buffer = GetAudioData(k16kTestWavForTwoHeadsFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -388,7 +387,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
|
||||||
TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
|
TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
|
||||||
auto audio_buffer = GetAudioData(k44kTestWavForTwoHeadsFilename);
|
auto audio_buffer = GetAudioData(k44kTestWavForTwoHeadsFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -404,7 +403,7 @@ TEST_F(ClassifyTest,
|
||||||
auto audio_buffer_44k_hz = GetAudioData(k44kTestWavForTwoHeadsFilename);
|
auto audio_buffer_44k_hz = GetAudioData(k44kTestWavForTwoHeadsFilename);
|
||||||
auto audio_buffer_16k_hz = GetAudioData(k16kTestWavForTwoHeadsFilename);
|
auto audio_buffer_16k_hz = GetAudioData(k16kTestWavForTwoHeadsFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
AudioClassifier::Create(std::move(options)));
|
AudioClassifier::Create(std::move(options)));
|
||||||
|
@ -424,7 +423,7 @@ TEST_F(ClassifyTest,
|
||||||
TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
|
TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
|
||||||
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
options->classifier_options.max_results = 1;
|
options->classifier_options.max_results = 1;
|
||||||
options->classifier_options.score_threshold = 0.35f;
|
options->classifier_options.score_threshold = 0.35f;
|
||||||
|
@ -440,7 +439,7 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
|
||||||
TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
|
TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
|
||||||
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
options->classifier_options.score_threshold = 0.35f;
|
options->classifier_options.score_threshold = 0.35f;
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||||
|
@ -455,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
|
||||||
TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
|
TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
|
||||||
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
options->classifier_options.score_threshold = 0.1f;
|
options->classifier_options.score_threshold = 0.1f;
|
||||||
options->classifier_options.category_allowlist.push_back("Speech");
|
options->classifier_options.category_allowlist.push_back("Speech");
|
||||||
|
@ -471,7 +470,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
|
||||||
TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
|
TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
|
||||||
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
options->classifier_options.score_threshold = 0.9f;
|
options->classifier_options.score_threshold = 0.9f;
|
||||||
options->classifier_options.category_denylist.push_back("Speech");
|
options->classifier_options.category_denylist.push_back("Speech");
|
||||||
|
@ -499,7 +498,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
|
||||||
constexpr int kSampleRateHz = 48000;
|
constexpr int kSampleRateHz = 48000;
|
||||||
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
options->classifier_options.max_results = 1;
|
options->classifier_options.max_results = 1;
|
||||||
options->classifier_options.score_threshold = 0.3f;
|
options->classifier_options.score_threshold = 0.3f;
|
||||||
|
@ -529,7 +528,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
|
||||||
constexpr int kSampleRateHz = 48000;
|
constexpr int kSampleRateHz = 48000;
|
||||||
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
auto audio_buffer = GetAudioData(k48kTestWavFilename);
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
options->base_options.model_file_name =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||||
options->classifier_options.max_results = 1;
|
options->classifier_options.max_results = 1;
|
||||||
options->classifier_options.score_threshold = 0.3f;
|
options->classifier_options.score_threshold = 0.3f;
|
||||||
|
|
|
@ -24,7 +24,7 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components:classifier_options_proto",
|
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.audio.audio_classifier.proto;
|
package mediapipe.tasks.audio.audio_classifier.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
message AudioClassifierOptions {
|
message AudioClassifierOptions {
|
||||||
|
@ -31,7 +31,7 @@ message AudioClassifierOptions {
|
||||||
|
|
||||||
// Options for configuring the classifier behavior, such as score threshold,
|
// Options for configuring the classifier behavior, such as score threshold,
|
||||||
// number of results, etc.
|
// number of results, etc.
|
||||||
optional ClassifierOptions classifier_options = 2;
|
optional components.proto.ClassifierOptions classifier_options = 2;
|
||||||
|
|
||||||
// The default sample rate of the input audio. Must be set when the
|
// The default sample rate of the input audio. Must be set when the
|
||||||
// AudioClassifier is configured to process audio stream data.
|
// AudioClassifier is configured to process audio stream data.
|
||||||
|
|
|
@ -35,6 +35,8 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":image_preprocessing_options_cc_proto",
|
":image_preprocessing_options_cc_proto",
|
||||||
"//mediapipe/calculators/core:pass_through_calculator",
|
"//mediapipe/calculators/core:pass_through_calculator",
|
||||||
|
"//mediapipe/calculators/image:image_clone_calculator",
|
||||||
|
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/image:image_properties_calculator",
|
"//mediapipe/calculators/image:image_properties_calculator",
|
||||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||||
|
@ -56,21 +58,11 @@ cc_library(
|
||||||
|
|
||||||
# TODO: Enable this test
|
# TODO: Enable this test
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "segmenter_options_proto",
|
|
||||||
srcs = ["segmenter_options.proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "classifier_options",
|
name = "classifier_options",
|
||||||
srcs = ["classifier_options.cc"],
|
srcs = ["classifier_options.cc"],
|
||||||
hdrs = ["classifier_options.h"],
|
hdrs = ["classifier_options.h"],
|
||||||
deps = [":classifier_options_cc_proto"],
|
deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "classifier_options_proto",
|
|
||||||
srcs = ["classifier_options.proto"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
|
@ -81,6 +73,7 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -90,7 +83,6 @@ cc_library(
|
||||||
hdrs = ["classification_postprocessing.h"],
|
hdrs = ["classification_postprocessing.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":classification_postprocessing_options_cc_proto",
|
":classification_postprocessing_options_cc_proto",
|
||||||
":classifier_options_cc_proto",
|
|
||||||
"//mediapipe/calculators/core:split_vector_calculator",
|
"//mediapipe/calculators/core:split_vector_calculator",
|
||||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||||
|
@ -104,7 +96,12 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
||||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
|
@ -119,3 +116,38 @@ cc_library(
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedder_options",
|
||||||
|
srcs = ["embedder_options.cc"],
|
||||||
|
hdrs = ["embedder_options.h"],
|
||||||
|
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedding_postprocessing_graph",
|
||||||
|
srcs = ["embedding_postprocessing_graph.cc"],
|
||||||
|
hdrs = ["embedding_postprocessing_graph.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/tool:options_map",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||||
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
|
@ -113,3 +113,66 @@ cc_test(
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "end_loop_calculator",
|
||||||
|
srcs = ["end_loop_calculator.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/core:end_loop_calculator",
|
||||||
|
"//mediapipe/framework:calculator_context",
|
||||||
|
"//mediapipe/framework:calculator_contract",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:collection_item_id",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/port:integral_types",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "tensors_to_embeddings_calculator_proto",
|
||||||
|
srcs = ["tensors_to_embeddings_calculator.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensors_to_embeddings_calculator",
|
||||||
|
srcs = ["tensors_to_embeddings_calculator.cc"],
|
||||||
|
deps = [
|
||||||
|
":tensors_to_embeddings_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "tensors_to_embeddings_calculator_test",
|
||||||
|
srcs = ["tensors_to_embeddings_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":tensors_to_embeddings_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/core/end_loop_calculator.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||||
|
|
||||||
|
// Specialized EndLoopCalculator for Tasks specific types.
|
||||||
|
namespace mediapipe::tasks {
|
||||||
|
|
||||||
|
typedef EndLoopCalculator<std::vector<ClassificationResult>>
|
||||||
|
EndLoopClassificationResultCalculator;
|
||||||
|
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks
|
|
@ -25,7 +25,7 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/framework/formats:image_format_proto",
|
"//mediapipe/framework/formats:image_format_proto",
|
||||||
"//mediapipe/tasks/cc/components:segmenter_options_proto",
|
"//mediapipe/tasks/cc/components/proto:segmenter_options_proto",
|
||||||
"//mediapipe/util:label_map_proto",
|
"//mediapipe/util:label_map_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -45,7 +45,7 @@ cc_library(
|
||||||
"//mediapipe/framework/port:opencv_core",
|
"//mediapipe/framework/port:opencv_core",
|
||||||
"//mediapipe/framework/port:opencv_imgproc",
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc/components:segmenter_options_cc_proto",
|
"//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_utils",
|
"//mediapipe/tasks/cc/vision/utils:image_utils",
|
||||||
"//mediapipe/util:label_map_cc_proto",
|
"//mediapipe/util:label_map_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
|
|
|
@ -36,19 +36,22 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace tasks {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::Image;
|
using ::mediapipe::Image;
|
||||||
using ::mediapipe::ImageFrameSharedPtr;
|
using ::mediapipe::ImageFrameSharedPtr;
|
||||||
using ::mediapipe::tasks::SegmenterOptions;
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Node;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions;
|
using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions;
|
||||||
|
using ::mediapipe::tasks::components::proto::SegmenterOptions;
|
||||||
using ::mediapipe::tasks::vision::GetImageLikeTensorShape;
|
using ::mediapipe::tasks::vision::GetImageLikeTensorShape;
|
||||||
using ::mediapipe::tasks::vision::Shape;
|
using ::mediapipe::tasks::vision::Shape;
|
||||||
|
|
||||||
|
@ -254,7 +257,7 @@ std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResult(
|
||||||
return segmented_masks;
|
return segmented_masks;
|
||||||
}
|
}
|
||||||
|
|
||||||
MEDIAPIPE_REGISTER_NODE(TensorsToSegmentationCalculator);
|
MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator);
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks;
|
package mediapipe.tasks;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/segmenter_options.proto";
|
import "mediapipe/tasks/cc/components/proto/segmenter_options.proto";
|
||||||
import "mediapipe/util/label_map.proto";
|
import "mediapipe/util/label_map.proto";
|
||||||
|
|
||||||
message TensorsToSegmentationCalculatorOptions {
|
message TensorsToSegmentationCalculatorOptions {
|
||||||
|
@ -26,7 +26,7 @@ message TensorsToSegmentationCalculatorOptions {
|
||||||
optional TensorsToSegmentationCalculatorOptions ext = 458105876;
|
optional TensorsToSegmentationCalculatorOptions ext = 458105876;
|
||||||
}
|
}
|
||||||
|
|
||||||
optional SegmenterOptions segmenter_options = 1;
|
optional components.proto.SegmenterOptions segmenter_options = 1;
|
||||||
|
|
||||||
// Identifying information for each classification label.
|
// Identifying information for each classification label.
|
||||||
map<int64, mediapipe.LabelMapItem> label_items = 2;
|
map<int64, mediapipe.LabelMapItem> label_items = 2;
|
||||||
|
|
|
@ -117,7 +117,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) {
|
||||||
CalculatorRunner runner(
|
CalculatorRunner runner(
|
||||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
||||||
R"pb(
|
R"pb(
|
||||||
calculator: "TensorsToSegmentationCalculator"
|
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "SEGMENTATION:segmentation"
|
output_stream: "SEGMENTATION:segmentation"
|
||||||
options {
|
options {
|
||||||
|
@ -144,7 +144,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) {
|
||||||
CalculatorRunner runner(
|
CalculatorRunner runner(
|
||||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
||||||
R"pb(
|
R"pb(
|
||||||
calculator: "TensorsToSegmentationCalculator"
|
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "SEGMENTATION:segmentation"
|
output_stream: "SEGMENTATION:segmentation"
|
||||||
options {
|
options {
|
||||||
|
@ -172,7 +172,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) {
|
||||||
CalculatorRunner runner(
|
CalculatorRunner runner(
|
||||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
||||||
R"pb(
|
R"pb(
|
||||||
calculator: "TensorsToSegmentationCalculator"
|
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "SEGMENTATION:0:segmented_mask_0"
|
output_stream: "SEGMENTATION:0:segmented_mask_0"
|
||||||
output_stream: "SEGMENTATION:1:segmented_mask_1"
|
output_stream: "SEGMENTATION:1:segmented_mask_1"
|
||||||
|
@ -217,7 +217,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) {
|
||||||
CalculatorRunner runner(
|
CalculatorRunner runner(
|
||||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
||||||
R"pb(
|
R"pb(
|
||||||
calculator: "TensorsToSegmentationCalculator"
|
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "SEGMENTATION:0:segmented_mask_0"
|
output_stream: "SEGMENTATION:0:segmented_mask_0"
|
||||||
output_stream: "SEGMENTATION:1:segmented_mask_1"
|
output_stream: "SEGMENTATION:1:segmented_mask_1"
|
||||||
|
@ -258,7 +258,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) {
|
||||||
CalculatorRunner runner(
|
CalculatorRunner runner(
|
||||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
||||||
R"pb(
|
R"pb(
|
||||||
calculator: "TensorsToSegmentationCalculator"
|
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "SEGMENTATION:0:segmented_mask_0"
|
output_stream: "SEGMENTATION:0:segmented_mask_0"
|
||||||
output_stream: "SEGMENTATION:1:segmented_mask_1"
|
output_stream: "SEGMENTATION:1:segmented_mask_1"
|
||||||
|
@ -300,7 +300,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
|
||||||
CalculatorRunner runner(
|
CalculatorRunner runner(
|
||||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
||||||
R"pb(
|
R"pb(
|
||||||
calculator: "TensorsToSegmentationCalculator"
|
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "SEGMENTATION:segmentation"
|
output_stream: "SEGMENTATION:segmentation"
|
||||||
options {
|
options {
|
||||||
|
@ -333,7 +333,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
|
||||||
CalculatorRunner runner(
|
CalculatorRunner runner(
|
||||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
|
||||||
R"pb(
|
R"pb(
|
||||||
calculator: "TensorsToSegmentationCalculator"
|
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
input_stream: "OUTPUT_SIZE:size"
|
input_stream: "OUTPUT_SIZE:size"
|
||||||
output_stream: "SEGMENTATION:segmentation"
|
output_stream: "SEGMENTATION:segmentation"
|
||||||
|
|
|
@ -0,0 +1,158 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
|
||||||
|
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
|
||||||
|
// case all values are 0.
|
||||||
|
float GetInverseL2Norm(const float* values, int size) {
|
||||||
|
float squared_l2_norm = 0.0f;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
squared_l2_norm += values[i] * values[i];
|
||||||
|
}
|
||||||
|
float inv_l2_norm = 1.0f;
|
||||||
|
if (squared_l2_norm > 0.0f) {
|
||||||
|
inv_l2_norm = 1.0f / std::sqrt(squared_l2_norm);
|
||||||
|
}
|
||||||
|
return inv_l2_norm;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Converts tensors into an EmbeddingResult object, performing optional
|
||||||
|
// L2-normalization and scalar-quantization on-the-fly if required through the
|
||||||
|
// options.
|
||||||
|
//
|
||||||
|
// Input:
|
||||||
|
// TENSORS - std::vector<Tensor>
|
||||||
|
// A vector of one or more Tensors of type kFloat32.
|
||||||
|
// Output:
|
||||||
|
// EMBEDDINGS - EmbeddingResult
|
||||||
|
// The contents of the input tensors converted into an EmbeddingResult
|
||||||
|
// proto.
|
||||||
|
class TensorsToEmbeddingsCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
||||||
|
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool l2_normalize_;
|
||||||
|
bool quantize_;
|
||||||
|
std::vector<std::string> head_names_;
|
||||||
|
|
||||||
|
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
||||||
|
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
||||||
|
};
|
||||||
|
|
||||||
|
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
|
||||||
|
auto options = cc->Options<mediapipe::TensorsToEmbeddingsCalculatorOptions>();
|
||||||
|
l2_normalize_ = options.embedder_options().l2_normalize();
|
||||||
|
quantize_ = options.embedder_options().quantize();
|
||||||
|
if (!options.head_names().empty()) {
|
||||||
|
head_names_.assign(options.head_names().begin(),
|
||||||
|
options.head_names().end());
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
|
||||||
|
EmbeddingResult result;
|
||||||
|
const auto& tensors = *kTensorsIn(cc);
|
||||||
|
if (!head_names_.empty() && tensors.size() != head_names_.size()) {
|
||||||
|
return absl::InvalidArgumentError(absl::StrFormat(
|
||||||
|
"Mismatch between number of provided head names (%d) and number "
|
||||||
|
"of input tensors (%d).",
|
||||||
|
head_names_.size(), tensors.size()));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < tensors.size(); ++i) {
|
||||||
|
const auto& tensor = tensors[i];
|
||||||
|
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
|
||||||
|
auto* embeddings = result.add_embeddings();
|
||||||
|
embeddings->set_head_index(i);
|
||||||
|
if (!head_names_.empty()) {
|
||||||
|
embeddings->set_head_name(head_names_[i]);
|
||||||
|
}
|
||||||
|
if (quantize_) {
|
||||||
|
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries());
|
||||||
|
} else {
|
||||||
|
FillFloatEmbeddingEntry(tensor, embeddings->add_entries());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kEmbeddingsOut(cc).Send(result);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry(
|
||||||
|
const Tensor& tensor, EmbeddingEntry* entry) {
|
||||||
|
int size = tensor.shape().num_elements();
|
||||||
|
auto tensor_view = tensor.GetCpuReadView();
|
||||||
|
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||||
|
float inv_l2_norm =
|
||||||
|
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||||
|
auto* float_embedding = entry->mutable_float_embedding();
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry(
|
||||||
|
const Tensor& tensor, EmbeddingEntry* entry) {
|
||||||
|
int size = tensor.shape().num_elements();
|
||||||
|
auto tensor_view = tensor.GetCpuReadView();
|
||||||
|
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||||
|
float inv_l2_norm =
|
||||||
|
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||||
|
auto* values = entry->mutable_quantized_embedding()->mutable_values();
|
||||||
|
values->resize(size);
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
// Normalize.
|
||||||
|
float normalized = tensor_buffer[i] * inv_l2_norm;
|
||||||
|
// Quantize.
|
||||||
|
int unclamped_value = static_cast<int>(roundf(normalized * 128));
|
||||||
|
// Clamp and assign.
|
||||||
|
(*values)[i] =
|
||||||
|
static_cast<char>(std::max(-128, std::min(unclamped_value, 127)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(TensorsToEmbeddingsCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
|
||||||
|
|
||||||
|
message TensorsToEmbeddingsCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional TensorsToEmbeddingsCalculatorOptions ext = 474762326;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The embedder options defining whether to L2-normalize or scalar-quantize
|
||||||
|
// the outputs.
|
||||||
|
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options =
|
||||||
|
1;
|
||||||
|
|
||||||
|
// The embedder head names.
|
||||||
|
repeated string head_names = 2;
|
||||||
|
}
|
|
@ -0,0 +1,249 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
using Node = ::mediapipe::CalculatorGraphConfig::Node;
|
||||||
|
|
||||||
|
// Builds the graph and feeds inputs.
|
||||||
|
void BuildGraph(CalculatorRunner* runner,
|
||||||
|
std::vector<std::vector<float>> tensors) {
|
||||||
|
auto inputs = std::make_unique<std::vector<Tensor>>();
|
||||||
|
for (const auto& tensor : tensors) {
|
||||||
|
inputs->emplace_back(Tensor::ElementType::kFloat32,
|
||||||
|
Tensor::Shape{1, static_cast<int>(tensor.size())});
|
||||||
|
auto view = inputs->back().GetCpuWriteView();
|
||||||
|
float* buffer = view.buffer<float>();
|
||||||
|
ASSERT_NE(buffer, nullptr);
|
||||||
|
for (int i = 0; i < tensor.size(); ++i) {
|
||||||
|
buffer[i] = tensor[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto& input_packets = runner->MutableInputs()->Tag("TENSORS").packets;
|
||||||
|
input_packets.push_back(Adopt(inputs.release()).At(Timestamp(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {{0.1, 0.2}, {0.2, 0.3}});
|
||||||
|
auto status = runner.Run();
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
HasSubstr("Mismatch between number of provided head names"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
|
embedder_options { l2_normalize: false quantize: false }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const EmbeddingResult& result = runner.Outputs()
|
||||||
|
.Get("EMBEDDING_RESULT", 0)
|
||||||
|
.packets[0]
|
||||||
|
.Get<EmbeddingResult>();
|
||||||
|
EXPECT_THAT(
|
||||||
|
result,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings {
|
||||||
|
entries { float_embedding { values: 0.1 values: 0.2 } }
|
||||||
|
head_index: 0
|
||||||
|
}
|
||||||
|
embeddings {
|
||||||
|
entries { float_embedding { values: -0.2 values: -0.3 } }
|
||||||
|
head_index: 1
|
||||||
|
})pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
|
embedder_options { l2_normalize: false quantize: false }
|
||||||
|
head_names: "foo"
|
||||||
|
head_names: "bar"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const EmbeddingResult& result = runner.Outputs()
|
||||||
|
.Get("EMBEDDING_RESULT", 0)
|
||||||
|
.packets[0]
|
||||||
|
.Get<EmbeddingResult>();
|
||||||
|
EXPECT_THAT(
|
||||||
|
result,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings {
|
||||||
|
entries { float_embedding { values: 0.1 values: 0.2 } }
|
||||||
|
head_index: 0
|
||||||
|
head_name: "foo"
|
||||||
|
}
|
||||||
|
embeddings {
|
||||||
|
entries { float_embedding { values: -0.2 values: -0.3 } }
|
||||||
|
head_index: 1
|
||||||
|
head_name: "bar"
|
||||||
|
})pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
|
embedder_options { l2_normalize: true quantize: false }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const EmbeddingResult& result = runner.Outputs()
|
||||||
|
.Get("EMBEDDING_RESULT", 0)
|
||||||
|
.packets[0]
|
||||||
|
.Get<EmbeddingResult>();
|
||||||
|
EXPECT_THAT(
|
||||||
|
result,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings {
|
||||||
|
entries {
|
||||||
|
float_embedding { values: 0.44721356 values: 0.8944271 }
|
||||||
|
}
|
||||||
|
head_index: 0
|
||||||
|
}
|
||||||
|
embeddings {
|
||||||
|
entries {
|
||||||
|
float_embedding { values: -0.5547002 values: -0.8320503 }
|
||||||
|
}
|
||||||
|
head_index: 1
|
||||||
|
})pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
|
embedder_options { l2_normalize: false quantize: true }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const EmbeddingResult& result = runner.Outputs()
|
||||||
|
.Get("EMBEDDING_RESULT", 0)
|
||||||
|
.packets[0]
|
||||||
|
.Get<EmbeddingResult>();
|
||||||
|
EXPECT_THAT(result,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings {
|
||||||
|
entries {
|
||||||
|
quantized_embedding { values: "\x0d\x1a" } # 13,26
|
||||||
|
}
|
||||||
|
head_index: 0
|
||||||
|
}
|
||||||
|
embeddings {
|
||||||
|
entries {
|
||||||
|
quantized_embedding { values: "\xe6\xda" } # -26,-38
|
||||||
|
}
|
||||||
|
head_index: 1
|
||||||
|
})pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TensorsToEmbeddingsCalculatorTest,
|
||||||
|
SucceedsWithNormalizationAndQuantization) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
|
embedder_options { l2_normalize: true quantize: true }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const EmbeddingResult& result = runner.Outputs()
|
||||||
|
.Get("EMBEDDING_RESULT", 0)
|
||||||
|
.packets[0]
|
||||||
|
.Get<EmbeddingResult>();
|
||||||
|
EXPECT_THAT(
|
||||||
|
result,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings {
|
||||||
|
entries {
|
||||||
|
quantized_embedding { values: "\x39\x72" } # 57,114
|
||||||
|
}
|
||||||
|
head_index: 0
|
||||||
|
}
|
||||||
|
embeddings {
|
||||||
|
entries {
|
||||||
|
quantized_embedding { values: "\xb9\x95" } # -71,-107
|
||||||
|
}
|
||||||
|
head_index: 1
|
||||||
|
})pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -35,9 +35,12 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
|
@ -47,6 +50,7 @@ limitations under the License.
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
namespace components {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -57,18 +61,21 @@ using ::mediapipe::api2::Timestamp;
|
||||||
using ::mediapipe::api2::builder::GenericNode;
|
using ::mediapipe::api2::builder::GenericNode;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
using ::tflite::ProcessUnit;
|
using ::tflite::ProcessUnit;
|
||||||
using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
|
|
||||||
using ::tflite::TensorMetadata;
|
using ::tflite::TensorMetadata;
|
||||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||||
|
using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
|
||||||
|
|
||||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||||
|
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
||||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
|
constexpr char kScoresTag[] = "SCORES";
|
||||||
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
|
|
||||||
// Performs sanity checks on provided ClassifierOptions.
|
// Performs sanity checks on provided ClassifierOptions.
|
||||||
|
@ -183,10 +190,10 @@ absl::StatusOr<LabelItems> GetLabelItemsIfAny(
|
||||||
absl::StatusOr<float> GetScoreThreshold(
|
absl::StatusOr<float> GetScoreThreshold(
|
||||||
const ModelMetadataExtractor& metadata_extractor,
|
const ModelMetadataExtractor& metadata_extractor,
|
||||||
const TensorMetadata& tensor_metadata) {
|
const TensorMetadata& tensor_metadata) {
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(const ProcessUnit* score_thresholding_process_unit,
|
||||||
const ProcessUnit* score_thresholding_process_unit,
|
|
||||||
metadata_extractor.FindFirstProcessUnit(
|
metadata_extractor.FindFirstProcessUnit(
|
||||||
tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
|
tensor_metadata,
|
||||||
|
tflite::ProcessUnitOptions_ScoreThresholdingOptions));
|
||||||
if (score_thresholding_process_unit == nullptr) {
|
if (score_thresholding_process_unit == nullptr) {
|
||||||
return kDefaultScoreThreshold;
|
return kDefaultScoreThreshold;
|
||||||
}
|
}
|
||||||
|
@ -230,8 +237,51 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
||||||
return category_indices;
|
return category_indices;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fills in the TensorsToClassificationCalculatorOptions based on the classifier
|
absl::Status ConfigureScoreCalibrationIfAny(
|
||||||
// options and the (optional) output tensor metadata.
|
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||||
|
ClassificationPostprocessingOptions* options) {
|
||||||
|
const auto* tensor_metadata =
|
||||||
|
metadata_extractor.GetOutputTensorMetadata(tensor_index);
|
||||||
|
if (tensor_metadata == nullptr) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
// Get ScoreCalibrationOptions, if any.
|
||||||
|
ASSIGN_OR_RETURN(const ProcessUnit* score_calibration_process_unit,
|
||||||
|
metadata_extractor.FindFirstProcessUnit(
|
||||||
|
*tensor_metadata,
|
||||||
|
tflite::ProcessUnitOptions_ScoreCalibrationOptions));
|
||||||
|
if (score_calibration_process_unit == nullptr) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
auto* score_calibration_options =
|
||||||
|
score_calibration_process_unit->options_as_ScoreCalibrationOptions();
|
||||||
|
// Get corresponding AssociatedFile.
|
||||||
|
auto score_calibration_filename =
|
||||||
|
metadata_extractor.FindFirstAssociatedFileName(
|
||||||
|
*tensor_metadata,
|
||||||
|
tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
|
||||||
|
if (score_calibration_filename.empty()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kNotFound,
|
||||||
|
"Found ScoreCalibrationOptions but missing required associated "
|
||||||
|
"parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
|
||||||
|
MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
absl::string_view score_calibration_file,
|
||||||
|
metadata_extractor.GetAssociatedFile(score_calibration_filename));
|
||||||
|
ScoreCalibrationCalculatorOptions calculator_options;
|
||||||
|
MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
|
||||||
|
score_calibration_options->score_transformation(),
|
||||||
|
score_calibration_options->default_score(), score_calibration_file,
|
||||||
|
&calculator_options));
|
||||||
|
(*options->mutable_score_calibration_options())[tensor_index] =
|
||||||
|
calculator_options;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fills in the TensorsToClassificationCalculatorOptions based on the
|
||||||
|
// classifier options and the (optional) output tensor metadata.
|
||||||
absl::Status ConfigureTensorsToClassificationCalculator(
|
absl::Status ConfigureTensorsToClassificationCalculator(
|
||||||
const ClassifierOptions& options,
|
const ClassifierOptions& options,
|
||||||
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||||
|
@ -303,6 +353,8 @@ absl::Status ConfigureClassificationPostprocessing(
|
||||||
ASSIGN_OR_RETURN(const auto heads_properties,
|
ASSIGN_OR_RETURN(const auto heads_properties,
|
||||||
GetClassificationHeadsProperties(model_resources));
|
GetClassificationHeadsProperties(model_resources));
|
||||||
for (int i = 0; i < heads_properties.num_heads; ++i) {
|
for (int i = 0; i < heads_properties.num_heads; ++i) {
|
||||||
|
MP_RETURN_IF_ERROR(ConfigureScoreCalibrationIfAny(
|
||||||
|
*model_resources.GetMetadataExtractor(), i, options));
|
||||||
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
|
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
|
||||||
classifier_options, *model_resources.GetMetadataExtractor(), i,
|
classifier_options, *model_resources.GetMetadataExtractor(), i,
|
||||||
options->add_tensors_to_classifications_options()));
|
options->add_tensors_to_classifications_options()));
|
||||||
|
@ -314,8 +366,8 @@ absl::Status ConfigureClassificationPostprocessing(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
// A "mediapipe.tasks.ClassificationPostprocessingSubgraph" converts raw
|
// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
|
||||||
// tensors into ClassificationResult objects.
|
// raw tensors into ClassificationResult objects.
|
||||||
// - Accepts CPU input tensors.
|
// - Accepts CPU input tensors.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
|
@ -376,18 +428,21 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
// If output tensors are quantized, they must be dequantized first.
|
// If output tensors are quantized, they must be dequantized first.
|
||||||
GenericNode* tensors_dequantization_node;
|
TensorsSource dequantized_tensors(&tensors_in);
|
||||||
if (options.has_quantized_outputs()) {
|
if (options.has_quantized_outputs()) {
|
||||||
tensors_dequantization_node =
|
GenericNode* tensors_dequantization_node =
|
||||||
&graph.AddNode("TensorsDequantizationCalculator");
|
&graph.AddNode("TensorsDequantizationCalculator");
|
||||||
tensors_in >> tensors_dequantization_node->In(kTensorsTag);
|
tensors_in >> tensors_dequantization_node->In(kTensorsTag);
|
||||||
|
dequantized_tensors = {tensors_dequantization_node, kTensorsTag};
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are multiple classification heads, the output tensors need to be
|
// If there are multiple classification heads, the output tensors need to be
|
||||||
// split.
|
// split.
|
||||||
GenericNode* split_tensor_vector_node;
|
std::vector<TensorsSource> split_tensors;
|
||||||
|
split_tensors.reserve(num_heads);
|
||||||
if (num_heads > 1) {
|
if (num_heads > 1) {
|
||||||
split_tensor_vector_node = &graph.AddNode("SplitTensorVectorCalculator");
|
GenericNode* split_tensor_vector_node =
|
||||||
|
&graph.AddNode("SplitTensorVectorCalculator");
|
||||||
auto& split_tensor_vector_options =
|
auto& split_tensor_vector_options =
|
||||||
split_tensor_vector_node
|
split_tensor_vector_node
|
||||||
->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
|
->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
|
||||||
|
@ -395,12 +450,27 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||||
auto* range = split_tensor_vector_options.add_ranges();
|
auto* range = split_tensor_vector_options.add_ranges();
|
||||||
range->set_begin(i);
|
range->set_begin(i);
|
||||||
range->set_end(i + 1);
|
range->set_end(i + 1);
|
||||||
|
split_tensors.emplace_back(split_tensor_vector_node, i);
|
||||||
}
|
}
|
||||||
if (options.has_quantized_outputs()) {
|
dequantized_tensors >> split_tensor_vector_node->In(0);
|
||||||
tensors_dequantization_node->Out(kTensorsTag) >>
|
|
||||||
split_tensor_vector_node->In(0);
|
|
||||||
} else {
|
} else {
|
||||||
tensors_in >> split_tensor_vector_node->In(0);
|
split_tensors.emplace_back(dequantized_tensors);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds score calibration for heads that specify it, if any.
|
||||||
|
std::vector<TensorsSource> calibrated_tensors;
|
||||||
|
calibrated_tensors.reserve(num_heads);
|
||||||
|
for (int i = 0; i < num_heads; ++i) {
|
||||||
|
if (options.score_calibration_options().contains(i)) {
|
||||||
|
GenericNode* score_calibration_node =
|
||||||
|
&graph.AddNode("ScoreCalibrationCalculator");
|
||||||
|
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
|
||||||
|
.CopyFrom(options.score_calibration_options().at(i));
|
||||||
|
split_tensors[i] >> score_calibration_node->In(kScoresTag);
|
||||||
|
calibrated_tensors.emplace_back(score_calibration_node,
|
||||||
|
kCalibratedScoresTag);
|
||||||
|
} else {
|
||||||
|
calibrated_tensors.emplace_back(split_tensors[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -413,17 +483,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||||
tensors_to_classification_nodes.back()
|
tensors_to_classification_nodes.back()
|
||||||
->GetOptions<TensorsToClassificationCalculatorOptions>()
|
->GetOptions<TensorsToClassificationCalculatorOptions>()
|
||||||
.CopyFrom(options.tensors_to_classifications_options(i));
|
.CopyFrom(options.tensors_to_classifications_options(i));
|
||||||
if (num_heads == 1) {
|
calibrated_tensors[i] >>
|
||||||
if (options.has_quantized_outputs()) {
|
|
||||||
tensors_dequantization_node->Out(kTensorsTag) >>
|
|
||||||
tensors_to_classification_nodes.back()->In(kTensorsTag);
|
tensors_to_classification_nodes.back()->In(kTensorsTag);
|
||||||
} else {
|
|
||||||
tensors_in >> tensors_to_classification_nodes.back()->In(kTensorsTag);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
split_tensor_vector_node->Out(i) >>
|
|
||||||
tensors_to_classification_nodes.back()->In(kTensorsTag);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Aggregates Classifications into a single ClassificationResult.
|
// Aggregates Classifications into a single ClassificationResult.
|
||||||
|
@ -444,7 +505,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_MEDIAPIPE_GRAPH(
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
::mediapipe::tasks::ClassificationPostprocessingSubgraph);
|
::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
|
||||||
|
|
||||||
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -18,11 +18,12 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
namespace components {
|
||||||
|
|
||||||
// Configures a ClassificationPostprocessing subgraph using the provided model
|
// Configures a ClassificationPostprocessing subgraph using the provided model
|
||||||
// resources and ClassifierOptions.
|
// resources and ClassifierOptions.
|
||||||
|
@ -31,7 +32,7 @@ namespace tasks {
|
||||||
// Example usage:
|
// Example usage:
|
||||||
//
|
//
|
||||||
// auto& postprocessing =
|
// auto& postprocessing =
|
||||||
// graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph");
|
// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||||
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||||
// model_resources,
|
// model_resources,
|
||||||
// classifier_options,
|
// classifier_options,
|
||||||
|
@ -49,10 +50,11 @@ namespace tasks {
|
||||||
// CLASSIFICATION_RESULT - ClassificationResult
|
// CLASSIFICATION_RESULT - ClassificationResult
|
||||||
// The output aggregated classification results.
|
// The output aggregated classification results.
|
||||||
absl::Status ConfigureClassificationPostprocessing(
|
absl::Status ConfigureClassificationPostprocessing(
|
||||||
const core::ModelResources& model_resources,
|
const tasks::core::ModelResources& model_resources,
|
||||||
const ClassifierOptions& classifier_options,
|
const tasks::components::proto::ClassifierOptions& classifier_options,
|
||||||
ClassificationPostprocessingOptions* options);
|
ClassificationPostprocessingOptions* options);
|
||||||
|
|
||||||
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -15,17 +15,22 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks;
|
package mediapipe.tasks.components;
|
||||||
|
|
||||||
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
|
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
|
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
|
||||||
|
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
|
||||||
|
|
||||||
message ClassificationPostprocessingOptions {
|
message ClassificationPostprocessingOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ClassificationPostprocessingOptions ext = 460416950;
|
optional ClassificationPostprocessingOptions ext = 460416950;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Optional mapping between output tensor index and corresponding score
|
||||||
|
// calibration options.
|
||||||
|
map<int32, ScoreCalibrationCalculatorOptions> score_calibration_options = 4;
|
||||||
|
|
||||||
// Options for the TensorsToClassification calculators (one per classification
|
// Options for the TensorsToClassification calculators (one per classification
|
||||||
// head) encapsulated by the ClassificationPostprocessing subgraph.
|
// head) encapsulated by the ClassificationPostprocessing subgraph.
|
||||||
repeated mediapipe.TensorsToClassificationCalculatorOptions
|
repeated mediapipe.TensorsToClassificationCalculatorOptions
|
||||||
|
|
|
@ -41,9 +41,10 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
|
@ -51,6 +52,7 @@ limitations under the License.
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
namespace components {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::api2::Input;
|
using ::mediapipe::api2::Input;
|
||||||
|
@ -58,6 +60,7 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::proto::Approximately;
|
using ::testing::proto::Approximately;
|
||||||
|
@ -65,6 +68,8 @@ using ::testing::proto::Approximately;
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
||||||
constexpr char kQuantizedImageClassifierWithMetadata[] =
|
constexpr char kQuantizedImageClassifierWithMetadata[] =
|
||||||
"vision/mobilenet_v1_0.25_224_quant.tflite";
|
"vision/mobilenet_v1_0.25_224_quant.tflite";
|
||||||
|
constexpr char kQuantizedImageClassifierWithDummyScoreCalibration[] =
|
||||||
|
"vision/mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
|
||||||
constexpr char kQuantizedImageClassifierWithoutMetadata[] =
|
constexpr char kQuantizedImageClassifierWithoutMetadata[] =
|
||||||
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
|
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
|
||||||
constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] =
|
constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] =
|
||||||
|
@ -147,11 +152,12 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
|
||||||
ClassifierOptions options_in;
|
ClassifierOptions options_in;
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
ClassificationPostprocessingOptions options_out;
|
||||||
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
options_in, &options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(tensors_to_classifications_options {
|
R"pb(score_calibration_options: []
|
||||||
|
tensors_to_classifications_options {
|
||||||
min_score_threshold: -3.4028235e+38
|
min_score_threshold: -3.4028235e+38
|
||||||
top_k: -1
|
top_k: -1
|
||||||
sort_by_descending_score: true
|
sort_by_descending_score: true
|
||||||
|
@ -169,11 +175,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
|
||||||
options_in.set_max_results(3);
|
options_in.set_max_results(3);
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
ClassificationPostprocessingOptions options_out;
|
||||||
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
options_in, &options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(tensors_to_classifications_options {
|
R"pb(score_calibration_options: []
|
||||||
|
tensors_to_classifications_options {
|
||||||
min_score_threshold: -3.4028235e+38
|
min_score_threshold: -3.4028235e+38
|
||||||
top_k: 3
|
top_k: 3
|
||||||
sort_by_descending_score: true
|
sort_by_descending_score: true
|
||||||
|
@ -191,11 +198,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
|
||||||
options_in.set_score_threshold(0.5);
|
options_in.set_score_threshold(0.5);
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
ClassificationPostprocessingOptions options_out;
|
||||||
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
options_in, &options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(tensors_to_classifications_options {
|
R"pb(score_calibration_options: []
|
||||||
|
tensors_to_classifications_options {
|
||||||
min_score_threshold: 0.5
|
min_score_threshold: 0.5
|
||||||
top_k: -1
|
top_k: -1
|
||||||
sort_by_descending_score: true
|
sort_by_descending_score: true
|
||||||
|
@ -212,7 +220,7 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
|
||||||
ClassifierOptions options_in;
|
ClassifierOptions options_in;
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
ClassificationPostprocessingOptions options_out;
|
||||||
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
options_in, &options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
// Check label map size and two first elements.
|
// Check label map size and two first elements.
|
||||||
|
@ -229,7 +237,8 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
|
||||||
options_out.mutable_tensors_to_classifications_options(0)
|
options_out.mutable_tensors_to_classifications_options(0)
|
||||||
->clear_label_items();
|
->clear_label_items();
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(tensors_to_classifications_options {
|
R"pb(score_calibration_options: []
|
||||||
|
tensors_to_classifications_options {
|
||||||
min_score_threshold: -3.4028235e+38
|
min_score_threshold: -3.4028235e+38
|
||||||
top_k: -1
|
top_k: -1
|
||||||
sort_by_descending_score: true
|
sort_by_descending_score: true
|
||||||
|
@ -249,14 +258,15 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
|
||||||
options_in.add_category_allowlist("tench");
|
options_in.add_category_allowlist("tench");
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
ClassificationPostprocessingOptions options_out;
|
||||||
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
options_in, &options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
// Clear label map and compare the rest of the options.
|
// Clear label map and compare the rest of the options.
|
||||||
options_out.mutable_tensors_to_classifications_options(0)
|
options_out.mutable_tensors_to_classifications_options(0)
|
||||||
->clear_label_items();
|
->clear_label_items();
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(tensors_to_classifications_options {
|
R"pb(score_calibration_options: []
|
||||||
|
tensors_to_classifications_options {
|
||||||
min_score_threshold: -3.4028235e+38
|
min_score_threshold: -3.4028235e+38
|
||||||
top_k: -1
|
top_k: -1
|
||||||
sort_by_descending_score: true
|
sort_by_descending_score: true
|
||||||
|
@ -277,14 +287,15 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
|
||||||
options_in.add_category_denylist("background");
|
options_in.add_category_denylist("background");
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
ClassificationPostprocessingOptions options_out;
|
||||||
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
options_in, &options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
// Clear label map and compare the rest of the options.
|
// Clear label map and compare the rest of the options.
|
||||||
options_out.mutable_tensors_to_classifications_options(0)
|
options_out.mutable_tensors_to_classifications_options(0)
|
||||||
->clear_label_items();
|
->clear_label_items();
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(tensors_to_classifications_options {
|
R"pb(score_calibration_options: []
|
||||||
|
tensors_to_classifications_options {
|
||||||
min_score_threshold: -3.4028235e+38
|
min_score_threshold: -3.4028235e+38
|
||||||
top_k: -1
|
top_k: -1
|
||||||
sort_by_descending_score: true
|
sort_by_descending_score: true
|
||||||
|
@ -297,6 +308,56 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
|
||||||
)pb")));
|
)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto model_resources,
|
||||||
|
CreateModelResourcesForModel(
|
||||||
|
kQuantizedImageClassifierWithDummyScoreCalibration));
|
||||||
|
ClassifierOptions options_in;
|
||||||
|
|
||||||
|
ClassificationPostprocessingOptions options_out;
|
||||||
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
|
options_in, &options_out));
|
||||||
|
|
||||||
|
// Check label map size and two first elements.
|
||||||
|
EXPECT_EQ(
|
||||||
|
options_out.tensors_to_classifications_options(0).label_items_size(),
|
||||||
|
kMobileNetNumClasses);
|
||||||
|
EXPECT_THAT(
|
||||||
|
options_out.tensors_to_classifications_options(0).label_items().at(0),
|
||||||
|
EqualsProto(R"pb(name: "background")pb"));
|
||||||
|
EXPECT_THAT(
|
||||||
|
options_out.tensors_to_classifications_options(0).label_items().at(1),
|
||||||
|
EqualsProto(R"pb(name: "tench")pb"));
|
||||||
|
// Clear label map.
|
||||||
|
options_out.mutable_tensors_to_classifications_options(0)
|
||||||
|
->clear_label_items();
|
||||||
|
// Check sigmoids size and first element.
|
||||||
|
EXPECT_EQ(options_out.score_calibration_options_size(), 1);
|
||||||
|
auto score_calibration_options =
|
||||||
|
options_out.score_calibration_options().at(0);
|
||||||
|
EXPECT_EQ(score_calibration_options.sigmoids_size(), kMobileNetNumClasses);
|
||||||
|
EXPECT_THAT(score_calibration_options.sigmoids(0),
|
||||||
|
EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb"));
|
||||||
|
options_out.mutable_score_calibration_options()->at(0).clear_sigmoids();
|
||||||
|
// Compare the rest of the options.
|
||||||
|
EXPECT_THAT(
|
||||||
|
options_out,
|
||||||
|
Approximately(EqualsProto(
|
||||||
|
R"pb(score_calibration_options {
|
||||||
|
key: 0
|
||||||
|
value { score_transformation: IDENTITY default_score: 0.5 }
|
||||||
|
}
|
||||||
|
tensors_to_classifications_options {
|
||||||
|
min_score_threshold: -3.4028235e+38
|
||||||
|
top_k: -1
|
||||||
|
sort_by_descending_score: true
|
||||||
|
}
|
||||||
|
classification_aggregation_options { head_names: "probability" }
|
||||||
|
has_quantized_outputs: true
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
|
@ -304,7 +365,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
||||||
ClassifierOptions options_in;
|
ClassifierOptions options_in;
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
ClassificationPostprocessingOptions options_out;
|
||||||
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||||
options_in, &options_out));
|
options_in, &options_out));
|
||||||
// Check label maps sizes and first two elements.
|
// Check label maps sizes and first two elements.
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
|
@ -331,7 +392,8 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
||||||
options_out.mutable_tensors_to_classifications_options(1)
|
options_out.mutable_tensors_to_classifications_options(1)
|
||||||
->clear_label_items();
|
->clear_label_items();
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(tensors_to_classifications_options {
|
R"pb(score_calibration_options: []
|
||||||
|
tensors_to_classifications_options {
|
||||||
min_score_threshold: -3.4028235e+38
|
min_score_threshold: -3.4028235e+38
|
||||||
top_k: -1
|
top_k: -1
|
||||||
sort_by_descending_score: true
|
sort_by_descending_score: true
|
||||||
|
@ -358,8 +420,8 @@ class PostprocessingTest : public tflite_shims::testing::Test {
|
||||||
CreateModelResourcesForModel(model_name));
|
CreateModelResourcesForModel(model_name));
|
||||||
|
|
||||||
Graph graph;
|
Graph graph;
|
||||||
auto& postprocessing =
|
auto& postprocessing = graph.AddNode(
|
||||||
graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph");
|
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||||
*model_resources, options,
|
*model_resources, options,
|
||||||
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
||||||
|
@ -503,6 +565,52 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||||
})pb"));
|
})pb"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||||
|
// Build graph.
|
||||||
|
ClassifierOptions options;
|
||||||
|
options.set_max_results(3);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto poller,
|
||||||
|
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
|
||||||
|
// Build input tensors.
|
||||||
|
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
||||||
|
tensor[1] = 12;
|
||||||
|
tensor[2] = 14;
|
||||||
|
tensor[3] = 16;
|
||||||
|
tensor[4] = 18;
|
||||||
|
|
||||||
|
// Send tensors and get results.
|
||||||
|
AddTensor(tensor, Tensor::ElementType::kUInt8,
|
||||||
|
/*quantization_parameters=*/{0.1, 10});
|
||||||
|
MP_ASSERT_OK(Run());
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
|
||||||
|
|
||||||
|
// Validate results.
|
||||||
|
EXPECT_THAT(results, EqualsProto(
|
||||||
|
R"pb(classifications {
|
||||||
|
entries {
|
||||||
|
categories {
|
||||||
|
index: 4
|
||||||
|
score: 0.6899744811
|
||||||
|
category_name: "tiger shark"
|
||||||
|
}
|
||||||
|
categories {
|
||||||
|
index: 3
|
||||||
|
score: 0.6456563062
|
||||||
|
category_name: "great white shark"
|
||||||
|
}
|
||||||
|
categories {
|
||||||
|
index: 2
|
||||||
|
score: 0.5986876601
|
||||||
|
category_name: "goldfish"
|
||||||
|
}
|
||||||
|
timestamp_ms: 0
|
||||||
|
}
|
||||||
|
head_index: 0
|
||||||
|
head_name: "probability"
|
||||||
|
})pb"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
||||||
// Build graph.
|
// Build graph.
|
||||||
ClassifierOptions options;
|
ClassifierOptions options;
|
||||||
|
@ -621,5 +729,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -15,15 +15,15 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
|
||||||
tasks::ClassifierOptions ConvertClassifierOptionsToProto(
|
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||||
ClassifierOptions* options) {
|
ClassifierOptions* options) {
|
||||||
tasks::ClassifierOptions options_proto;
|
tasks::components::proto::ClassifierOptions options_proto;
|
||||||
options_proto.set_display_names_locale(options->display_names_locale);
|
options_proto.set_display_names_locale(options->display_names_locale);
|
||||||
options_proto.set_max_results(options->max_results);
|
options_proto.set_max_results(options->max_results);
|
||||||
options_proto.set_score_threshold(options->score_threshold);
|
options_proto.set_score_threshold(options->score_threshold);
|
||||||
|
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -49,7 +49,7 @@ struct ClassifierOptions {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converts a ClassifierOptions to a ClassifierOptionsProto.
|
// Converts a ClassifierOptions to a ClassifierOptionsProto.
|
||||||
tasks::ClassifierOptions ConvertClassifierOptionsToProto(
|
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||||
ClassifierOptions* classifier_options);
|
ClassifierOptions* classifier_options);
|
||||||
|
|
||||||
} // namespace components
|
} // namespace components
|
||||||
|
|
|
@ -29,3 +29,8 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/framework/formats:rect_proto",
|
"//mediapipe/framework/formats:rect_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "embeddings_proto",
|
||||||
|
srcs = ["embeddings.proto"],
|
||||||
|
)
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user