Merge branch 'master' into image-segmenter-python-impl

This commit is contained in:
Kinar R 2022-09-29 16:05:36 +05:30 committed by GitHub
commit 1461bcf97d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
238 changed files with 13030 additions and 1923 deletions

View File

@ -157,6 +157,13 @@ http_archive(
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"], urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
) )
http_archive(
name = "pffft",
strip_prefix = "jpommier-pffft-7c3b5a7dc510",
urls = ["https://bitbucket.org/jpommier/pffft/get/7c3b5a7dc510.zip"],
build_file = "@//third_party:pffft.BUILD",
)
# sentencepiece # sentencepiece
http_archive( http_archive(
name = "com_google_sentencepiece", name = "com_google_sentencepiece",

View File

@ -217,7 +217,7 @@ A list of pose landmarks. Each landmark consists of the following:
*Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* | *Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* |
:-----------------------------------------------------------: | :-----------------------------------------------------------: |
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_world_landmarks.mp4" type="video/mp4"></video> | <video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/mobile/pose_world_landmarks.mp4" type="video/mp4"></video> |
Another list of pose landmarks in world coordinates. Each landmark consists of Another list of pose landmarks in world coordinates. Each landmark consists of
the following: the following:
@ -238,7 +238,7 @@ for usage details.
*Fig 6. Example of MediaPipe Pose segmentation mask.* | *Fig 6. Example of MediaPipe Pose segmentation mask.* |
:---------------------------------------------------: | :---------------------------------------------------: |
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_segmentation.mp4" type="video/mp4"></video> | <video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/mobile/pose_segmentation.mp4" type="video/mp4"></video> |
### Python Solution API ### Python Solution API

View File

@ -22,7 +22,7 @@ nav_order: 7
*Fig 1. Example of MediaPipe Selfie Segmentation.* | *Fig 1. Example of MediaPipe Selfie Segmentation.* |
:------------------------------------------------: | :------------------------------------------------: |
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/selfie_segmentation_web.mp4" type="video/mp4"></video> | <video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/selfie_segmentation_web.mp4" type="video/mp4"></video> |
MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can
run in real-time on both smartphones and laptops. The intended use cases include run in real-time on both smartphones and laptops. The intended use cases include

View File

@ -1294,8 +1294,8 @@ cc_library(
deps = [ deps = [
":get_vector_item_calculator_cc_proto", ":get_vector_item_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
@ -1319,6 +1319,32 @@ cc_test(
], ],
) )
cc_library(
name = "vector_indices_calculator",
srcs = ["vector_indices_calculator.cc"],
hdrs = ["vector_indices_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "vector_indices_calculator_test",
srcs = ["vector_indices_calculator_test.cc"],
deps = [
":vector_indices_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)
cc_library( cc_library(
name = "vector_size_calculator", name = "vector_size_calculator",
srcs = ["vector_size_calculator.cc"], srcs = ["vector_size_calculator.cc"],

View File

@ -40,6 +40,9 @@ REGISTER_CALCULATOR(EndLoopNormalizedLandmarkListVectorCalculator);
typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator; typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator;
REGISTER_CALCULATOR(EndLoopBooleanCalculator); REGISTER_CALCULATOR(EndLoopBooleanCalculator);
typedef EndLoopCalculator<std::vector<float>> EndLoopFloatCalculator;
REGISTER_CALCULATOR(EndLoopFloatCalculator);
typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>> typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
EndLoopRenderDataCalculator; EndLoopRenderDataCalculator;
REGISTER_CALCULATOR(EndLoopRenderDataCalculator); REGISTER_CALCULATOR(EndLoopRenderDataCalculator);

View File

@ -24,6 +24,10 @@ using GetLandmarkListVectorItemCalculator =
GetVectorItemCalculator<mediapipe::LandmarkList>; GetVectorItemCalculator<mediapipe::LandmarkList>;
REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator); REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator);
using GetNormalizedLandmarkListVectorItemCalculator =
GetVectorItemCalculator<mediapipe::NormalizedLandmarkList>;
REGISTER_CALCULATOR(GetNormalizedLandmarkListVectorItemCalculator);
using GetClassificationListVectorItemCalculator = using GetClassificationListVectorItemCalculator =
GetVectorItemCalculator<mediapipe::ClassificationList>; GetVectorItemCalculator<mediapipe::ClassificationList>;
REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);

View File

@ -19,6 +19,7 @@
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" #include "mediapipe/calculators/core/get_vector_item_calculator.pb.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -58,7 +59,7 @@ template <typename T>
class GetVectorItemCalculator : public Node { class GetVectorItemCalculator : public Node {
public: public:
static constexpr Input<std::vector<T>> kIn{"VECTOR"}; static constexpr Input<std::vector<T>> kIn{"VECTOR"};
static constexpr Input<int>::Optional kIdx{"INDEX"}; static constexpr Input<OneOf<int, uint64_t>>::Optional kIdx{"INDEX"};
static constexpr Output<T> kOut{"ITEM"}; static constexpr Output<T> kOut{"ITEM"};
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
@ -80,7 +81,9 @@ class GetVectorItemCalculator : public Node {
int idx = 0; int idx = 0;
if (kIdx(cc).IsConnected() && !kIdx(cc).IsEmpty()) { if (kIdx(cc).IsConnected() && !kIdx(cc).IsEmpty()) {
idx = kIdx(cc).Get(); idx = kIdx(cc).Visit(
[](uint64_t idx_uint64_t) { return static_cast<int>(idx_uint64_t); },
[](int idx_int) { return idx_int; });
} else if (options.has_item_index()) { } else if (options.has_item_index()) {
idx = options.item_index(); idx = options.item_index();
} else { } else {

View File

@ -227,4 +227,15 @@ TEST(TestGetIntVectorItemCalculatorTest, IndexOptionsTwoTimestamps) {
testing::ElementsAre(TimestampValue(1), TimestampValue(2))); testing::ElementsAre(TimestampValue(1), TimestampValue(2)));
} }
TEST(TestGetIntVectorItemCalculatorTest, IndexUint64) {
CalculatorRunner runner = MakeRunnerWithStream();
const std::vector<int> inputs = {1, 2, 3};
const uint64_t index = 1;
AddInputVector(runner, inputs, 1);
AddInputIndex(runner, index, 1);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets;
EXPECT_THAT(outputs, testing::ElementsAre(IntPacket(inputs[index])));
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,33 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/vector_indices_calculator.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace api2 {
using IntVectorIndicesCalculator = VectorIndicesCalculator<int>;
REGISTER_CALCULATOR(IntVectorIndicesCalculator);
using Uint64tVectorIndicesCalculator = VectorIndicesCalculator<uint64_t>;
REGISTER_CALCULATOR(Uint64tVectorIndicesCalculator);
using NormalizedLandmarkListVectorIndicesCalculator =
VectorIndicesCalculator<mediapipe::NormalizedLandmarkList>;
REGISTER_CALCULATOR(NormalizedLandmarkListVectorIndicesCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,65 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
#include <optional>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Calculator that takes a vector and constructs an index range vector based on
// the size of the input vector.
//
// Inputs:
// VECTOR - std::vector<T>
// Vector whose range of indices to return.
//
// Outputs:
// INDICES - std::vector<int>
// Indices vector of the input vector.
//
// Example config:
// node {
// calculator: "{SpecificType}VectorIndicesCalculator"
// input_stream: "VECTOR:vector"
// output_stream: "INDICES:indices"
// }
//
template <typename T>
class VectorIndicesCalculator : public Node {
public:
static constexpr Input<std::vector<T>> kVector{"VECTOR"};
static constexpr Output<std::vector<int>> kRange{"INDICES"};
MEDIAPIPE_NODE_CONTRACT(kVector, kRange);
absl::Status Process(CalculatorContext* cc) final {
// Get the size of the input vector.
const int vector_size = kVector(cc).Get().size();
std::vector<int> out_idxs(vector_size);
std::iota(out_idxs.begin(), out_idxs.end(), 0);
kRange(cc).Send(out_idxs);
return absl::OkStatus();
}
};
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_

View File

@ -0,0 +1,87 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/vector_indices_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
template <typename T>
void AddInputVector(CalculatorRunner& runner, const std::vector<T>& inputs,
int timestamp) {
runner.MutableInputs()->Tag("VECTOR").packets.push_back(
MakePacket<std::vector<T>>(inputs).At(Timestamp(timestamp)));
}
template <typename T>
struct TestParams {
const std::string test_name;
const std::vector<T> inputs;
const int timestamp;
const std::vector<int> expected_indices;
};
class IntVectorIndicesCalculatorTest
: public testing::TestWithParam<TestParams<int>> {};
TEST_P(IntVectorIndicesCalculatorTest, Succeeds) {
CalculatorRunner runner = CalculatorRunner(R"(
calculator: "IntVectorIndicesCalculator"
input_stream: "VECTOR:vector_stream"
output_stream: "INDICES:indices_stream"
)");
const std::vector<int>& inputs = GetParam().inputs;
std::vector<int> expected_indices(inputs.size());
AddInputVector(runner, inputs, GetParam().timestamp);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("INDICES").packets;
EXPECT_EQ(1, outputs.size());
EXPECT_THAT(outputs[0].Get<std::vector<int>>(),
testing::ElementsAreArray(GetParam().expected_indices));
}
INSTANTIATE_TEST_SUITE_P(
IntVectorIndicesCalculatorTest, IntVectorIndicesCalculatorTest,
Values(TestParams<int>{
/* test_name= */ "IntVectorIndices",
/* inputs= */ {1, 2, 3},
/* timestamp= */ 1,
/* expected_indices= */ {0, 1, 2},
},
TestParams<int>{
/* test_name= */ "EmptyVector",
/* inputs= */ {},
/* timestamp= */ 1,
/* expected_indices= */ {},
}),
[](const TestParamInfo<IntVectorIndicesCalculatorTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace mediapipe

View File

@ -55,6 +55,14 @@ mediapipe_proto_library(
cc_library( cc_library(
name = "audio_to_tensor_calculator", name = "audio_to_tensor_calculator",
srcs = ["audio_to_tensor_calculator.cc"], srcs = ["audio_to_tensor_calculator.cc"],
copts = select({
# b/215212850
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc",
],
"//conditions:default": [],
}),
visibility = [ visibility = [
"//mediapipe/framework:mediapipe_internal", "//mediapipe/framework:mediapipe_internal",
], ],
@ -67,13 +75,16 @@ cc_library(
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_audio_tools//audio/dsp:resampler_q", "@com_google_audio_tools//audio/dsp:resampler_q",
"@com_google_audio_tools//audio/dsp:window_functions",
"@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/c:common",
"@pffft",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -83,6 +94,7 @@ cc_test(
srcs = ["audio_to_tensor_calculator_test.cc"], srcs = ["audio_to_tensor_calculator_test.cc"],
deps = [ deps = [
":audio_to_tensor_calculator", ":audio_to_tensor_calculator",
":audio_to_tensor_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
@ -97,6 +109,58 @@ cc_test(
], ],
) )
mediapipe_proto_library(
name = "feedback_tensors_calculator_proto",
srcs = ["feedback_tensors_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "feedback_tensors_calculator",
srcs = ["feedback_tensors_calculator.cc"],
copts = select({
# b/215212850
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc",
],
"//conditions:default": [],
}),
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":feedback_tensors_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/status",
],
alwayslink = 1,
)
cc_test(
name = "feedback_tensors_calculator_test",
srcs = ["feedback_tensors_calculator_test.cc"],
deps = [
":feedback_tensors_calculator",
":feedback_tensors_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@org_tensorflow//tensorflow/lite/c:common",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "inference_calculator_proto", name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"], srcs = ["inference_calculator.proto"],
@ -346,6 +410,10 @@ cc_library(
}), }),
) )
# This target provides the InferenceCalculator and a default set of implementations tailored for the
# current build platforms. More implementations can be added as separate dependencies to a client;
# for clients that want a narrower set of implementations than the default should see the comment on
# inference_calculator_interface.
cc_library( cc_library(
name = "inference_calculator", name = "inference_calculator",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],

View File

@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <math.h>
#include <algorithm> #include <algorithm>
#include <cmath>
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <string> #include <string>
@ -26,6 +25,7 @@
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "audio/dsp/resampler_q.h" #include "audio/dsp/resampler_q.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
@ -34,19 +34,60 @@
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/time_series_util.h" #include "mediapipe/util/time_series_util.h"
#include "pffft.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
namespace {
using Options = ::mediapipe::AudioToTensorCalculatorOptions;
using FlushMode = Options::FlushMode;
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
std::vector<float> hann_window(window_size);
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
if (sqrt_hann) {
absl::c_transform(hann_window, hann_window.begin(),
[](double x) { return std::sqrt(x); });
}
return hann_window;
}
// PFFFT only supports transforms for inputs of length N of the form
// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT.
bool IsValidFftSize(int size) {
if (size <= 0) {
return false;
}
constexpr int kFactors[] = {2, 3, 5};
int factorization[] = {0, 0, 0};
int n = static_cast<int>(size);
for (int i = 0; i < 3; ++i) {
while (n % kFactors[i] == 0) {
n = n / kFactors[i];
++factorization[i];
}
}
return factorization[0] >= 5 && n == 1;
}
} // namespace
// Converts audio buffers into tensors, possibly with resampling, buffering // Converts audio buffers into tensors, possibly with resampling, buffering
// and framing, according to specified inputs and options. All input audio // and framing, according to specified inputs and options. All input audio
// buffers will be first resampled from the input sample rate to the target // buffers will be first resampled from the input sample rate to the target
// sample rate if they are not equal. The resampled audio data (with the // sample rate if they are not equal. The resampled audio data (with the
// buffered samples from the previous runs in the streaming mode) will be broken // buffered samples from the previous runs in the streaming mode) will be broken
// into fixed-sized, possibly overlapping frames. Finally, all frames will be // into fixed-sized, possibly overlapping frames. If the calculator is not asked
// converted to and outputted as MediaPipe Tensors. The last output tensor will // to perform fft (the fft_size is not set in the calculator options), all
// be zero-padding if the remaining samples are insufficient. // frames will be converted to and outputted as MediaPipe Tensors. The last
// output tensor will be zero-padding if the remaining samples are insufficient.
// Otherwise, when the fft_size is set and valid, the calculator will perform
// fft on the fixed-sized audio frames, the complex DFT results will be
// converted to and outputted as 2D MediaPipe float Tensors where the first
// rows are the DFT real parts and the second rows are the DFT imagery parts.
// //
// This calculator assumes that the input timestamps refer to the first // This calculator assumes that the input timestamps refer to the first
// sample in each Matrix. The output timestamps follow this same convention. // sample in each Matrix. The output timestamps follow this same convention.
@ -86,11 +127,15 @@ namespace api2 {
// Outputs: // Outputs:
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// Vector containing a single Tensor that represents a fix-sized audio // Vector containing a single Tensor that represents a fix-sized audio
// frame. // frame or the complex DFT results.
// TIMESTAMPS - std::vector<Timestamp> @Optional // TIMESTAMPS - std::vector<Timestamp> @Optional
// Vector containing the output timestamps emitted by the current Process() // Vector containing the output timestamps emitted by the current Process()
// invocation. In the non-streaming mode, the vector contains all of the // invocation. In the non-streaming mode, the vector contains all of the
// output timestamps for an input audio buffer. // output timestamps for an input audio buffer.
// DC_AND_NYQUIST - std::pair<float, float> @Optional.
// A pair of dc component and nyquest component. Only can be connected when
// the calculator performs fft (the fft_size is set in the calculator
// options).
// //
// Example: // Example:
// node { // node {
@ -116,12 +161,14 @@ class AudioToTensorCalculator : public Node {
// such as sample rate. // such as sample rate.
static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"}; static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"}; static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
static constexpr Output<std::pair<float, float>>::Optional kDcAndNyquistOut{
"DC_AND_NYQUIST"};
// A vector of the output timestamps emitted by the current Process() // A vector of the output timestamps emitted by the current Process()
// invocation. The packet timestamp is the last emitted timestamp. // invocation. The packet timestamp is the last emitted timestamp.
static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{ static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{
"TIMESTAMPS"}; "TIMESTAMPS"};
MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut, MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut,
kTimestampsOut); kDcAndNyquistOut, kTimestampsOut);
static absl::Status UpdateContract(CalculatorContract* cc); static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc); absl::Status Open(CalculatorContext* cc);
@ -138,6 +185,9 @@ class AudioToTensorCalculator : public Node {
int frame_step_; int frame_step_;
bool stream_mode_; bool stream_mode_;
bool check_inconsistent_timestamps_; bool check_inconsistent_timestamps_;
int padding_samples_before_;
int padding_samples_after_;
FlushMode flush_mode_;
Timestamp initial_timestamp_ = Timestamp::Unstarted(); Timestamp initial_timestamp_ = Timestamp::Unstarted();
int64 cumulative_input_samples_ = 0; int64 cumulative_input_samples_ = 0;
Timestamp next_output_timestamp_ = Timestamp::Unstarted(); Timestamp next_output_timestamp_ = Timestamp::Unstarted();
@ -151,22 +201,33 @@ class AudioToTensorCalculator : public Node {
Matrix sample_buffer_; Matrix sample_buffer_;
int processed_buffer_cols_ = 0; int processed_buffer_cols_ = 0;
// The internal state of the FFT library.
PFFFT_Setup* fft_state_ = nullptr;
int fft_size_ = 0;
std::vector<float> fft_window_;
std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
// pffft requires memory to work with to avoid using the stack.
std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
absl::Status ProcessStreamingData(CalculatorContext* cc, const Matrix& input); absl::Status ProcessStreamingData(CalculatorContext* cc, const Matrix& input);
absl::Status ProcessNonStreamingData(CalculatorContext* cc, absl::Status ProcessNonStreamingData(CalculatorContext* cc,
const Matrix& input); const Matrix& input);
absl::Status SetupStreamingResampler(double input_sample_rate_); absl::Status SetupStreamingResampler(double input_sample_rate_);
void AppendToSampleBuffer(Matrix buffer_to_append); void AppendToSampleBuffer(Matrix buffer_to_append);
void AppendZerosToSampleBuffer(int num_samples);
absl::StatusOr<std::vector<Tensor>> ConvertToTensor( absl::StatusOr<std::vector<Tensor>> ConvertToTensor(
const Matrix& frame_to_convert); const Matrix& block, std::vector<int> tensor_dims);
absl::Status OutputTensors(const Matrix& buffer, bool should_flush, absl::Status OutputTensor(const Matrix& block, Timestamp timestamp,
CalculatorContext* cc);
absl::Status ProcessBuffer(const Matrix& buffer, bool should_flush,
CalculatorContext* cc); CalculatorContext* cc);
}; };
absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) { absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
const auto& options = const auto& options = cc->Options<Options>();
cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
if (!options.has_num_channels() || !options.has_num_samples() || if (!options.has_num_channels() || !options.has_num_samples() ||
!options.has_target_sample_rate()) { !options.has_target_sample_rate()) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
@ -174,13 +235,21 @@ absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
"`num_channels`, `num_samples`, and `target_sample_rate`."); "`num_channels`, `num_samples`, and `target_sample_rate`.");
} }
if (options.stream_mode()) { if (options.stream_mode()) {
// Explicitly disables tiemstamp offset to disallow the timestamp bound // Explicitly disables timestamp offset to disallow the timestamp bound
// from the input streams to be propagated to the output streams. // from the input streams to be propagated to the output streams.
// In the streaming mode, the output timestamp bound is based on // In the streaming mode, the output timestamp bound is based on
// next_output_timestamp_, which can be smaller than the current input // next_output_timestamp_, which can be smaller than the current input
// timestamps. // timestamps.
cc->SetTimestampOffset(TimestampDiff::Unset()); cc->SetTimestampOffset(TimestampDiff::Unset());
} }
if (options.padding_samples_before() < 0 ||
options.padding_samples_after() < 0) {
return absl::InvalidArgumentError("Negative zero padding unsupported");
}
if (options.flush_mode() != Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX &&
options.flush_mode() != Options::PROCEED_AS_USUAL) {
return absl::InvalidArgumentError("Unsupported flush mode");
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -202,6 +271,9 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
check_inconsistent_timestamps_ = options.check_inconsistent_timestamps(); check_inconsistent_timestamps_ = options.check_inconsistent_timestamps();
sample_buffer_.resize(num_channels_, Eigen::NoChange); sample_buffer_.resize(num_channels_, Eigen::NoChange);
} }
padding_samples_before_ = options.padding_samples_before();
padding_samples_after_ = options.padding_samples_after();
flush_mode_ = options.flush_mode();
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
!kAudioIn(cc).Header().IsEmpty()) !kAudioIn(cc).Header().IsEmpty())
@ -217,6 +289,25 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
source_sample_rate_ = input_header.sample_rate(); source_sample_rate_ = input_header.sample_rate();
} }
} }
AppendZerosToSampleBuffer(padding_samples_before_);
if (options.has_fft_size()) {
RET_CHECK(IsValidFftSize(options.fft_size()))
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
">=0 and c >= 0 and a >= 5, the requested fft size is "
<< options.fft_size();
RET_CHECK_EQ(1, num_channels_)
<< "Currently only support applying FFT on mono channel.";
fft_size_ = options.fft_size();
fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
fft_window_ = HannWindow(fft_size_, /* sqrt_hann = */ false);
fft_input_buffer_.resize(fft_size_);
fft_workplace_.resize(fft_size_);
fft_output_.resize(fft_size_);
} else {
RET_CHECK(!kDcAndNyquistOut(cc).IsConnected())
<< "The DC_AND_NYQUIST output stream can only be connected when the "
"calculator outputs fft tensors";
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -262,7 +353,12 @@ absl::Status AudioToTensorCalculator::Close(CalculatorContext* cc) {
resampler_->Flush(&resampled_buffer); resampler_->Flush(&resampled_buffer);
AppendToSampleBuffer(std::move(resampled_buffer)); AppendToSampleBuffer(std::move(resampled_buffer));
} }
return OutputTensors(sample_buffer_, /*should_flush=*/true, cc); AppendZerosToSampleBuffer(padding_samples_after_);
MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/true, cc));
if (fft_state_) {
pffft_destroy_setup(fft_state_);
}
return absl::OkStatus();
} }
absl::Status AudioToTensorCalculator::ProcessStreamingData( absl::Status AudioToTensorCalculator::ProcessStreamingData(
@ -303,7 +399,7 @@ absl::Status AudioToTensorCalculator::ProcessStreamingData(
} }
} }
MP_RETURN_IF_ERROR(OutputTensors(sample_buffer_, /*should_flush=*/false, cc)); MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/false, cc));
// Removes the processed samples from the global sample buffer. // Removes the processed samples from the global sample buffer.
sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() - sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() -
processed_buffer_cols_ - 1)); processed_buffer_cols_ - 1));
@ -323,9 +419,9 @@ absl::Status AudioToTensorCalculator::ProcessNonStreamingData(
input_frame); input_frame);
Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_, Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_,
resampled.size() / num_channels_); resampled.size() / num_channels_);
return OutputTensors(matrix_mapping, /*should_flush=*/true, cc); return ProcessBuffer(matrix_mapping, /*should_flush=*/true, cc);
} }
return OutputTensors(input_frame, /*should_flush=*/true, cc); return ProcessBuffer(input_frame, /*should_flush=*/true, cc);
} }
absl::Status AudioToTensorCalculator::SetupStreamingResampler( absl::Status AudioToTensorCalculator::SetupStreamingResampler(
@ -344,6 +440,16 @@ absl::Status AudioToTensorCalculator::SetupStreamingResampler(
return absl::OkStatus(); return absl::OkStatus();
} }
void AudioToTensorCalculator::AppendZerosToSampleBuffer(int num_samples) {
CHECK_GE(num_samples, 0); // Ensured by `UpdateContract`.
if (num_samples == 0) {
return;
}
sample_buffer_.conservativeResize(Eigen::NoChange,
sample_buffer_.cols() + num_samples);
sample_buffer_.rightCols(num_samples).setZero();
}
void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) { void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
sample_buffer_.conservativeResize( sample_buffer_.conservativeResize(
Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols()); Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols());
@ -351,49 +457,89 @@ void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
} }
absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor( absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor(
const Matrix& frame_to_convert) { const Matrix& block, std::vector<int> tensor_dims) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape(tensor_dims));
Tensor::Shape({num_channels_, num_samples_}));
auto buffer_view = tensor.GetCpuWriteView(); auto buffer_view = tensor.GetCpuWriteView();
if (frame_to_convert.size() < num_channels_ * num_samples_) { int total_size = 1;
for (int dim : tensor_dims) {
total_size *= dim;
}
if (block.size() < total_size) {
std::memset(buffer_view.buffer<float>(), 0, tensor.bytes()); std::memset(buffer_view.buffer<float>(), 0, tensor.bytes());
} }
std::memcpy(buffer_view.buffer<float>(), frame_to_convert.data(), std::memcpy(buffer_view.buffer<float>(), block.data(),
frame_to_convert.size() * sizeof(float)); block.size() * sizeof(float));
std::vector<Tensor> tensor_vector; std::vector<Tensor> tensor_vector;
tensor_vector.push_back(std::move(tensor)); tensor_vector.push_back(std::move(tensor));
return tensor_vector; return tensor_vector;
} }
absl::Status AudioToTensorCalculator::OutputTensors(const Matrix& buffer, absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block,
Timestamp timestamp,
CalculatorContext* cc) {
std::vector<Tensor> output_tensor;
if (fft_state_) {
Eigen::VectorXf time_series_data =
Eigen::VectorXf::Map(block.data(), block.size());
// Window on input audio prior to FFT.
std::transform(time_series_data.begin(), time_series_data.end(),
fft_window_.begin(), fft_input_buffer_.begin(),
std::multiplies<float>());
pffft_transform_ordered(fft_state_, fft_input_buffer_.data(),
fft_output_.data(), fft_workplace_.data(),
PFFFT_FORWARD);
if (kDcAndNyquistOut(cc).IsConnected()) {
kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]),
timestamp);
}
Matrix fft_output_matrix =
Eigen::Map<const Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2);
fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_);
// The last two elements are the DFT Nyquist values.
fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part
fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part
ASSIGN_OR_RETURN(output_tensor,
ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2}));
} else {
ASSIGN_OR_RETURN(output_tensor,
ConvertToTensor(block, {num_channels_, num_samples_}));
}
kTensorsOut(cc).Send(std::move(output_tensor), timestamp);
return absl::OkStatus();
}
absl::Status AudioToTensorCalculator::ProcessBuffer(const Matrix& buffer,
bool should_flush, bool should_flush,
CalculatorContext* cc) { CalculatorContext* cc) {
const bool should_flush_at_timestamp_max =
stream_mode_ && should_flush &&
flush_mode_ == Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX;
int next_frame_first_col = 0; int next_frame_first_col = 0;
std::vector<Timestamp> timestamps; std::vector<Timestamp> timestamps;
while ((!stream_mode_ || !should_flush) && if (!should_flush_at_timestamp_max) {
next_frame_first_col + num_samples_ <= buffer.cols()) { while (next_frame_first_col + num_samples_ <= buffer.cols()) {
ASSIGN_OR_RETURN(auto output_tensor, ConvertToTensor(buffer.block( MP_RETURN_IF_ERROR(OutputTensor(
0, next_frame_first_col, buffer.block(0, next_frame_first_col, num_channels_, num_samples_),
num_channels_, num_samples_))); next_output_timestamp_, cc));
kTensorsOut(cc).Send(std::move(output_tensor), next_output_timestamp_);
timestamps.push_back(next_output_timestamp_); timestamps.push_back(next_output_timestamp_);
next_output_timestamp_ += round(frame_step_ / target_sample_rate_ * next_output_timestamp_ += round(frame_step_ / target_sample_rate_ *
Timestamp::kTimestampUnitsPerSecond); Timestamp::kTimestampUnitsPerSecond);
next_frame_first_col += frame_step_; next_frame_first_col += frame_step_;
} }
}
if (should_flush && next_frame_first_col < buffer.cols()) { if (should_flush && next_frame_first_col < buffer.cols()) {
ASSIGN_OR_RETURN(auto output_tensor,
ConvertToTensor(buffer.block(
0, next_frame_first_col, num_channels_,
std::min(num_samples_,
(int)buffer.cols() - next_frame_first_col))));
// In the streaming mode, the flush happens in Close() and a packet at // In the streaming mode, the flush happens in Close() and a packet at
// Timestamp::Max() will be emitted. In the non-streaming mode, each // Timestamp::Max() will be emitted. In the non-streaming mode, each
// Process() invocation will process the entire buffer completely. // Process() invocation will process the entire buffer completely.
Timestamp timestamp = Timestamp timestamp = should_flush_at_timestamp_max
stream_mode_ ? Timestamp::Max() : next_output_timestamp_; ? Timestamp::Max()
: next_output_timestamp_;
MP_RETURN_IF_ERROR(OutputTensor(
buffer.block(
0, next_frame_first_col, num_channels_,
std::min(num_samples_, (int)buffer.cols() - next_frame_first_col)),
timestamp, cc));
timestamps.push_back(timestamp); timestamps.push_back(timestamp);
kTensorsOut(cc).Send(std::move(output_tensor), timestamp);
} }
if (kTimestampsOut(cc).IsConnected()) { if (kTimestampsOut(cc).IsConnected()) {
Timestamp timestamp = timestamps.back(); Timestamp timestamp = timestamps.back();

View File

@ -44,4 +44,28 @@ message AudioToTensorCalculatorOptions {
// Set to false to disable checks for jitter in timestamp values. Useful with // Set to false to disable checks for jitter in timestamp values. Useful with
// live audio input. // live audio input.
optional bool check_inconsistent_timestamps = 6 [default = true]; optional bool check_inconsistent_timestamps = 6 [default = true];
// Size of the fft in number of bins. If set, the calculator outputs fft
// tensors.
optional int64 fft_size = 7;
// The amount of padding samples to add before the audio after resampling.
// Note that the timestamps shift. Currently, only zero padding is supported.
optional int64 padding_samples_before = 8;
// The amount of padding samples to add after the audio after resampling.
// Currently, only zero padding is supported.
optional int64 padding_samples_after = 9;
// Determines the "flushing" behavior in stream mode.
enum FlushMode {
// Unspecified (causes an error). Won't be used because of the default.
NONE = 0;
// Emit a packet with the entire remainder at `Timestamp::Max`.
ENTIRE_TAIL_AT_TIMESTAMP_MAX = 1;
// Continue emitting framed packets with relevant timestamps.
PROCEED_AS_USUAL = 2;
}
optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX];
} }

View File

@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cmath>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "audio/dsp/resampler_q.h" #include "audio/dsp/resampler_q.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -32,6 +32,14 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
using ::testing::Not;
using Options = ::mediapipe::AudioToTensorCalculatorOptions;
using FlushMode = Options::FlushMode;
int DivideRoundedUp(int dividend, int divisor) {
return (dividend + divisor - 1) / divisor;
}
std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples, std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples,
int timestamp) { int timestamp) {
auto matrix = std::make_unique<Matrix>(num_channels, num_samples); auto matrix = std::make_unique<Matrix>(num_channels, num_samples);
@ -292,16 +300,17 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
num_iterations_ = num_iterations; num_iterations_ = num_iterations;
} }
int GetExpectedNumOfSamples() { int GetExpectedNumOfSamples() { return output_sample_buffer_->cols(); }
Matrix* expected_matrix =
resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
return expected_matrix->cols();
}
void Run(int num_samples, int num_overlapping_samples, void Run(int num_samples, int num_overlapping_samples,
double resampling_factor) { double resampling_factor, int padding_before = 0,
int padding_after = 0, bool expect_init_error = false) {
double input_sample_rate = 10000; double input_sample_rate = 10000;
double target_sample_rate = input_sample_rate * resampling_factor; double target_sample_rate = input_sample_rate * resampling_factor;
FlushMode flush_mode = (padding_before != 0 || padding_after != 0)
? Options::PROCEED_AS_USUAL
: Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX;
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>( auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"( absl::Substitute(R"(
input_stream: "audio" input_stream: "audio"
@ -319,16 +328,25 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
num_overlapping_samples: $1 num_overlapping_samples: $1
target_sample_rate: $2 target_sample_rate: $2
stream_mode:true stream_mode:true
padding_samples_before: $3
padding_samples_after: $4
flush_mode: $5
} }
} }
} }
)", )",
/*$0=*/num_samples, /*$1=*/num_overlapping_samples, /*$0=*/num_samples, /*$1=*/num_overlapping_samples,
/*$2=*/target_sample_rate)); /*$2=*/target_sample_rate, /*$3=*/padding_before,
/*$4=*/padding_after, /*$5=*/flush_mode));
tool::AddVectorSink("tensors", &graph_config, &tensors_packets_); tool::AddVectorSink("tensors", &graph_config, &tensors_packets_);
// Run the graph. // Run the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config)); const absl::Status init_status = graph_.Initialize(graph_config);
if (expect_init_error) {
EXPECT_THAT(init_status, Not(IsOk()));
return;
}
MP_ASSERT_OK(init_status);
MP_ASSERT_OK(graph_.StartRun({})); MP_ASSERT_OK(graph_.StartRun({}));
for (int i = 0; i < num_iterations_; ++i) { for (int i = 0; i < num_iterations_; ++i) {
Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i); Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i);
@ -345,8 +363,18 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
} }
MP_ASSERT_OK(graph_.CloseAllInputStreams()); MP_ASSERT_OK(graph_.CloseAllInputStreams());
MP_ASSERT_OK(graph_.WaitUntilIdle()); MP_ASSERT_OK(graph_.WaitUntilIdle());
if (resampling_factor != 1) { if (resampling_factor == 1) {
resampled_buffer_ = ResampleBuffer(*sample_buffer_, resampling_factor); output_sample_buffer_ = std::make_unique<Matrix>(*sample_buffer_);
} else {
output_sample_buffer_ =
ResampleBuffer(*sample_buffer_, resampling_factor);
}
if (padding_before != 0 || padding_after != 0) {
Matrix padded = Matrix::Zero(
2, padding_before + output_sample_buffer_->cols() + padding_after);
padded.block(0, padding_before, 2, output_sample_buffer_->cols()) =
*output_sample_buffer_;
output_sample_buffer_->swap(padded);
} }
} }
@ -372,14 +400,12 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
auto buffer = output_tensor.GetCpuReadView().buffer<float>(); auto buffer = output_tensor.GetCpuReadView().buffer<float>();
int num_values = output_tensor.shape().num_elements(); int num_values = output_tensor.shape().num_elements();
std::vector<float> output_floats(buffer, buffer + num_values); std::vector<float> output_floats(buffer, buffer + num_values);
Matrix* expected_matrix =
resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
for (int i = 0; i < num_values; ++i) { for (int i = 0; i < num_values; ++i) {
if (i + sample_offset >= expected_matrix->size()) { if (i + sample_offset >= output_sample_buffer_->size()) {
EXPECT_FLOAT_EQ(output_floats[i], 0); EXPECT_FLOAT_EQ(output_floats[i], 0);
} else { } else {
EXPECT_NEAR(output_floats[i], EXPECT_NEAR(output_floats[i],
expected_matrix->coeff((i + sample_offset) % 2, output_sample_buffer_->coeff((i + sample_offset) % 2,
(i + sample_offset) / 2), (i + sample_offset) / 2),
0.001) 0.001)
<< "i=" << i << ", sample_offset=" << sample_offset << "i=" << i << ", sample_offset=" << sample_offset
@ -391,7 +417,8 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
// Fully close graph at end, otherwise calculator+tensors are destroyed // Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone(). // after calling WaitUntilDone().
void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); } absl::Status TryCloseGraph() { return graph_.WaitUntilDone(); }
void CloseGraph() { MP_EXPECT_OK(TryCloseGraph()); }
private: private:
int input_buffer_num_samples_ = 10; int input_buffer_num_samples_ = 10;
@ -399,7 +426,7 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
CalculatorGraph graph_; CalculatorGraph graph_;
std::vector<Packet> tensors_packets_; std::vector<Packet> tensors_packets_;
std::unique_ptr<Matrix> sample_buffer_; std::unique_ptr<Matrix> sample_buffer_;
std::unique_ptr<Matrix> resampled_buffer_; std::unique_ptr<Matrix> output_sample_buffer_;
}; };
TEST_F(AudioToTensorCalculatorStreamingModeTest, TEST_F(AudioToTensorCalculatorStreamingModeTest,
@ -408,7 +435,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest,
/*resampling_factor=*/1.0f); /*resampling_factor=*/1.0f);
CheckTensorsOutputPackets( CheckTensorsOutputPackets(
/*sample_offset=*/10, /*sample_offset=*/10,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 5), /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 5),
/*timestamp_interval=*/500, /*timestamp_interval=*/500,
/*output_last_at_close=*/false); /*output_last_at_close=*/false);
CloseGraph(); CloseGraph();
@ -419,7 +446,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputRemainingInCloseMethod) {
/*resampling_factor=*/1.0f); /*resampling_factor=*/1.0f);
CheckTensorsOutputPackets( CheckTensorsOutputPackets(
/*sample_offset=*/12, /*sample_offset=*/12,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 6), /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 6),
/*timestamp_interval=*/600, /*timestamp_interval=*/600,
/*output_last_at_close=*/true); /*output_last_at_close=*/true);
CloseGraph(); CloseGraph();
@ -431,7 +458,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputOverlappingFp32Tensors) {
/*resampling_factor=*/1.0f); /*resampling_factor=*/1.0f);
CheckTensorsOutputPackets( CheckTensorsOutputPackets(
/*sample_offset=*/16, /*sample_offset=*/16,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 8), /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 8),
/*timestamp_interval=*/800, /*timestamp_interval=*/800,
/*output_last_at_close=*/true); /*output_last_at_close=*/true);
CloseGraph(); CloseGraph();
@ -443,7 +470,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Downsampling) {
/*resampling_factor=*/0.5f); /*resampling_factor=*/0.5f);
CheckTensorsOutputPackets( CheckTensorsOutputPackets(
/*sample_offset=*/512, /*sample_offset=*/512,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256),
/*timestamp_interval=*/51200, /*timestamp_interval=*/51200,
/*output_last_at_close=*/true); /*output_last_at_close=*/true);
CloseGraph(); CloseGraph();
@ -455,7 +482,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, DownsamplingWithOverlapping) {
/*resampling_factor=*/0.5f); /*resampling_factor=*/0.5f);
CheckTensorsOutputPackets( CheckTensorsOutputPackets(
/*sample_offset=*/384, /*sample_offset=*/384,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192),
/*timestamp_interval=*/38400, /*timestamp_interval=*/38400,
/*output_last_at_close=*/true); /*output_last_at_close=*/true);
CloseGraph(); CloseGraph();
@ -467,7 +494,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Upsampling) {
/*resampling_factor=*/2.0f); /*resampling_factor=*/2.0f);
CheckTensorsOutputPackets( CheckTensorsOutputPackets(
/*sample_offset=*/512, /*sample_offset=*/512,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256),
/*timestamp_interval=*/12800, /*timestamp_interval=*/12800,
/*output_last_at_close=*/true); /*output_last_at_close=*/true);
CloseGraph(); CloseGraph();
@ -479,12 +506,33 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, UpsamplingWithOverlapping) {
/*resampling_factor=*/2.0f); /*resampling_factor=*/2.0f);
CheckTensorsOutputPackets( CheckTensorsOutputPackets(
/*sample_offset=*/384, /*sample_offset=*/384,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192),
/*timestamp_interval=*/9600, /*timestamp_interval=*/9600,
/*output_last_at_close=*/true); /*output_last_at_close=*/true);
CloseGraph(); CloseGraph();
} }
TEST_F(AudioToTensorCalculatorStreamingModeTest,
UpsamplingWithOverlappingAndPadding) {
SetInputBufferNumSamplesPerChannel(1024);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
/*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/999);
CheckTensorsOutputPackets(
/*sample_offset=*/384,
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192),
/*timestamp_interval=*/9600,
/*output_last_at_close=*/false);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, NegativePaddingUnsupported) {
SetInputBufferNumSamplesPerChannel(1024);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
/*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/-3,
/*expect_init_error=*/true);
EXPECT_THAT(TryCloseGraph(), Not(IsOk()));
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, TEST_F(AudioToTensorCalculatorStreamingModeTest,
OnlyOutputInCloseIfNoSufficientSamples) { OnlyOutputInCloseIfNoSufficientSamples) {
SetNumIterations(1); SetNumIterations(1);
@ -498,5 +546,122 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest,
CloseGraph(); CloseGraph();
} }
class AudioToTensorCalculatorFftTest : public ::testing::Test {
protected:
// Creates an audio matrix containing a single sample of 1.0 at a specified
// offset.
std::unique_ptr<Matrix> CreateImpulseSignalData(int64 num_samples,
int impulse_offset_idx) {
Matrix impulse = Matrix::Zero(1, num_samples);
impulse(0, impulse_offset_idx) = 1.0;
return std::make_unique<Matrix>(std::move(impulse));
}
void ConfigGraph(int num_channels, int num_samples,
int num_overlapping_samples, double sample_rate,
int fft_size) {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "audio"
input_stream: "sample_rate"
output_stream: "tensors"
output_stream: "dc_and_nyquist"
node {
calculator: "AudioToTensorCalculator"
input_stream: "AUDIO:audio"
input_stream: "SAMPLE_RATE:sample_rate"
output_stream: "TENSORS:tensors"
output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
options {
[mediapipe.AudioToTensorCalculatorOptions.ext] {
num_channels: $0
num_samples: $1
num_overlapping_samples: $2
target_sample_rate: $3
fft_size: $4
}
}
}
)",
/*$0=*/num_channels,
/*$1=*/num_samples,
/*$2=*/num_overlapping_samples,
/*$3=*/sample_rate, /*$4=*/fft_size));
std::vector<Packet> tensors_packets;
tool::AddVectorSink("tensors", &graph_config_, &tensors_packets_);
std::vector<Packet> dc_and_nyquist_packets;
tool::AddVectorSink("dc_and_nyquist", &graph_config_,
&dc_and_nyquist_packets_);
}
void RunGraph(std::unique_ptr<Matrix> input_data, double sample_rate) {
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"audio", MakePacket<Matrix>(*input_data).At(Timestamp(0))));
MP_ASSERT_OK(graph_.CloseAllInputStreams());
MP_ASSERT_OK(graph_.WaitUntilIdle());
ASSERT_EQ(tensors_packets_.size(), dc_and_nyquist_packets_.size());
}
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); }
std::vector<Packet> tensors_packets_;
std::vector<Packet> dc_and_nyquist_packets_;
CalculatorGraphConfig graph_config_;
CalculatorGraph graph_;
};
TEST_F(AudioToTensorCalculatorFftTest, TestInvalidFftSize) {
ConfigGraph(1, 320, 160, 16000, 103);
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
auto status = graph_.WaitUntilIdle();
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_THAT(status.message(),
::testing::HasSubstr("FFT size must be of the form"));
}
TEST_F(AudioToTensorCalculatorFftTest, TestInvalidNumChannels) {
ConfigGraph(3, 320, 160, 16000, 256);
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
auto status = graph_.WaitUntilIdle();
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_THAT(
status.message(),
::testing::HasSubstr("only support applying FFT on mono channel"));
}
TEST_F(AudioToTensorCalculatorFftTest, TestImpulseSignal) {
constexpr double sample_rate = 16000;
ConfigGraph(1, 320, 160, sample_rate, 320);
RunGraph(CreateImpulseSignalData(320, 160), sample_rate);
for (int i = 0; i < tensors_packets_.size(); ++i) {
const auto& tensors = tensors_packets_[i].Get<std::vector<Tensor>>();
ASSERT_EQ(1, tensors.size());
const Tensor& output_tensor =
tensors_packets_[0].Get<std::vector<Tensor>>()[0];
auto* buffer = output_tensor.GetCpuReadView().buffer<float>();
int num_values = output_tensor.shape().num_elements();
const std::vector<float> output_floats(buffer, buffer + num_values);
// Impulse signal should have (approximately) const power across all
// frequency bins.
const auto& pair =
dc_and_nyquist_packets_[i].Get<std::pair<float, float>>();
EXPECT_FLOAT_EQ(pair.first, 1.0f);
EXPECT_FLOAT_EQ(pair.second, 1.0f);
for (int j = 0; j < num_values / 2; ++j) {
std::complex<float> cf(output_floats[j * 2], output_floats[j * 2 + 1]);
EXPECT_FLOAT_EQ(std::norm(cf), 1.0f);
}
}
CloseGraph();
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,165 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <utility>
#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
namespace api2 {
namespace {
constexpr char kInputTensorsTag[] = "INPUT_TENSORS";
constexpr char kFeedbackTensorsTag[] = "FEEDBACK_TENSORS";
constexpr char kOutputTensorsTag[] = "TENSORS";
using Tensors = std::vector<Tensor>;
} // namespace
// FeedbackTensorsCalculator groups the input and the feedback (typically
// recurrent neural network cell state output tensors from the previous run)
// tensor vectors as the input tensor vector for the next recurrent model cell
// inference. On the first step, the feedback tensor is filled with zeros to
// jumpstart the loop.
class FeedbackTensorsCalculator : public Node {
public:
static constexpr Input<Tensors> kFeedbackTensorsIn{kFeedbackTensorsTag};
static constexpr Input<Tensors> kInputTensorsIn{kInputTensorsTag};
static constexpr Output<Tensors> kTensorsOut{kOutputTensorsTag};
MEDIAPIPE_NODE_CONTRACT(kFeedbackTensorsIn, kInputTensorsIn, kTensorsOut);
static absl::Status GetContract(CalculatorContract* cc) {
cc->SetProcessTimestampBounds(true);
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
const auto& options =
cc->Options<mediapipe::FeedbackTensorsCalculatorOptions>();
const auto& shape_dims = options.feedback_tensor_shape().dims();
feedback_tensor_shape_.dims.assign(shape_dims.begin(), shape_dims.end());
feedback_tensor_size_ = feedback_tensor_shape_.num_elements();
num_feedback_tensors_ = options.num_feedback_tensors();
feedback_tensors_location_ = options.location();
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
if (feedback_tensors_location_ ==
mediapipe::FeedbackTensorsCalculatorOptions::NONE) {
kTensorsOut(cc).Send(kInputTensorsIn(cc).packet().As<Tensors>());
return absl::OkStatus();
}
std::vector<Tensor> outputs;
switch (feedback_tensors_location_) {
case mediapipe::FeedbackTensorsCalculatorOptions::PREPENDED:
MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs));
MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs));
break;
case mediapipe::FeedbackTensorsCalculatorOptions::APPENDED:
MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs));
MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs));
break;
default:
return absl::InvalidArgumentError(
"Unsupported feedback tensors location");
}
kTensorsOut(cc).Send(std::move(outputs));
return absl::OkStatus();
}
private:
absl::Status AddInputTensors(CalculatorContext* cc,
std::vector<Tensor>& outputs) {
absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> input_tensors =
cc->Inputs()
.Tag(kInputTensorsTag)
.Value()
.Consume<std::vector<Tensor>>();
if (!input_tensors.ok()) {
return absl::InternalError("The input tensors packet is not consumable");
}
RET_CHECK(*input_tensors);
std::vector<Tensor>& inputs = **input_tensors;
outputs.insert(outputs.end(), std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
return absl::OkStatus();
}
absl::Status AddFeedbackTensors(CalculatorContext* cc,
std::vector<Tensor>& outputs) {
if (first_run_) {
for (int index = 0; index < num_feedback_tensors_; ++index) {
Tensor initial_feedback_tensor(Tensor::ElementType::kFloat32,
feedback_tensor_shape_);
float* data = initial_feedback_tensor.GetCpuWriteView().buffer<float>();
std::fill_n(data, feedback_tensor_size_, 0.0f);
outputs.push_back(std::move(initial_feedback_tensor));
}
first_run_ = false;
return absl::OkStatus();
}
if (num_feedback_tensors_ != kFeedbackTensorsIn(cc)->size()) {
return absl::InvalidArgumentError(
"The number of tensors fed back differs from the configuration");
}
absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> feedback_tensors =
cc->Inputs()
.Tag(kFeedbackTensorsTag)
.Value()
.Consume<std::vector<Tensor>>();
if (!feedback_tensors.ok()) {
return absl::InternalError(
"The feedback tensors packet is not consumable");
}
RET_CHECK(*feedback_tensors);
std::vector<Tensor>& feedbacks = **feedback_tensors;
for (const auto& feedback : feedbacks) {
if (feedback.shape().dims != feedback_tensor_shape_.dims) {
return absl::InvalidArgumentError(
"The shape of a tensor fed back differs from the configuration");
}
}
outputs.insert(outputs.end(), std::make_move_iterator(feedbacks.begin()),
std::make_move_iterator(feedbacks.end()));
return absl::OkStatus();
}
Tensor::Shape feedback_tensor_shape_;
int num_feedback_tensors_ = 0;
mediapipe::FeedbackTensorsCalculatorOptions::FeedbackTensorsLocation
feedback_tensors_location_;
int feedback_tensor_size_ = 0;
bool first_run_ = true;
};
MEDIAPIPE_REGISTER_NODE(FeedbackTensorsCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,47 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message FeedbackTensorsCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional FeedbackTensorsCalculatorOptions ext = 474496252;
}
// Represents the dimensions of a tensor starting from the outermost size.
message TensorShape {
repeated int32 dims = 1 [packed = true];
}
// The shape of the feedback tensors to add.
optional TensorShape feedback_tensor_shape = 1;
// The number of the feedback tensors to add.
optional int32 num_feedback_tensors = 2 [default = 1];
enum FeedbackTensorsLocation {
// The feedback tensors will not be added.
NONE = 0;
// The feedback tensors will be added before the input tensors.
PREPENDED = 1;
// The feedback tensors will be added after the input tensors.
APPENDED = 2;
}
// Determines the location of the feedback tensor(s) in the output vector.
optional FeedbackTensorsLocation location = 3 [default = APPENDED];
}

View File

@ -0,0 +1,389 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>
#include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::testing::ElementsAreArray;
using ::testing::Not;
using Tensors = std::vector<Tensor>;
template <typename T>
struct TensorElementType {
static constexpr Tensor::ElementType value = Tensor::ElementType::kNone;
};
template <>
struct TensorElementType<float> {
static constexpr Tensor::ElementType value = Tensor::ElementType::kFloat32;
};
template <>
struct TensorElementType<std::int8_t> {
static constexpr Tensor::ElementType value = Tensor::ElementType::kInt8;
};
template <>
struct TensorElementType<std::uint8_t> {
static constexpr Tensor::ElementType value = Tensor::ElementType::kUInt8;
};
template <>
struct TensorElementType<std::int32_t> {
static constexpr Tensor::ElementType value = Tensor::ElementType::kInt32;
};
template <typename T>
Tensor MakeTensor(std::initializer_list<int> shape,
std::initializer_list<T> values) {
Tensor tensor(TensorElementType<T>::value, shape);
CHECK_EQ(values.size(), tensor.shape().num_elements())
<< "The size of `values` is incompatible with `shape`";
absl::c_copy(values, tensor.GetCpuWriteView().buffer<T>());
return tensor;
}
template <typename T>
void ValidateTensor(const Tensor& tensor,
const std::vector<int>& expected_shape,
const std::vector<T>& expected_values) {
ASSERT_EQ(tensor.element_type(), TensorElementType<T>::value);
EXPECT_EQ(tensor.shape().dims, expected_shape);
EXPECT_EQ(tensor.shape().num_elements(), expected_values.size());
auto* tensor_buffer = tensor.GetCpuReadView().buffer<T>();
const std::vector<T> tensor_values(
tensor_buffer, tensor_buffer + tensor.shape().num_elements());
EXPECT_THAT(tensor_values, ElementsAreArray(expected_values));
}
TEST(FeedbackTensorsCalculatorTest, AppendsFeedback) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
input_stream: "feedback"
node {
calculator: "FeedbackTensorsCalculator"
input_stream: "INPUT_TENSORS:input"
input_stream: "FEEDBACK_TENSORS:feedback"
output_stream: "TENSORS:output"
options: {
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
feedback_tensor_shape: { dims: 2 dims: 3 }
location: APPENDED
}
}
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
auto initial_input_tensors = std::make_unique<Tensors>();
initial_input_tensors->push_back(
MakeTensor<std::int32_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
// At the beginning, the loopback packet with the model feedback is missing.
// The calculator has to assume it's all-zero with the shape from the options.
auto later_input_tensors = std::make_unique<Tensors>();
later_input_tensors->push_back(
MakeTensor<std::int32_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
auto later_feedback_tensors = std::make_unique<Tensors>();
later_feedback_tensors->push_back(
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
MP_ASSERT_OK(graph.CloseAllInputStreams())
<< "Couldn't close the graph inputs";
MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run";
ASSERT_EQ(output_packets.size(), 2);
const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>();
ASSERT_EQ(initial_combined_tensors.size(), 2);
ValidateTensor<std::int32_t>(initial_combined_tensors[0],
/*expected_shape=*/{2, 4},
/*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8});
// The initial feedback is zero.
ValidateTensor<float>(initial_combined_tensors[1], /*expected_shape=*/{2, 3},
/*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>();
ASSERT_EQ(later_combined_tensors.size(), 2);
ValidateTensor<std::int32_t>(later_combined_tensors[0],
/*expected_shape=*/{2, 4},
/*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1});
// Afterwards, the provided feedback is passed through.
ValidateTensor<float>(
later_combined_tensors[1], /*expected_shape=*/{2, 3},
/*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f});
}
TEST(FeedbackTensorsCalculatorTest, PrependsFeedback) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
input_stream: "feedback"
node {
calculator: "FeedbackTensorsCalculator"
input_stream: "INPUT_TENSORS:input"
input_stream: "FEEDBACK_TENSORS:feedback"
output_stream: "TENSORS:output"
options: {
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
feedback_tensor_shape: { dims: 3 dims: 2 }
location: PREPENDED
}
}
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
auto initial_input_tensors = std::make_unique<Tensors>();
initial_input_tensors->push_back(
MakeTensor<std::int8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
// At the beginning, the loopback packet with the model feedback is missing.
// The calculator has to assume it's all-zero with the shape from the options.
auto later_input_tensors = std::make_unique<Tensors>();
later_input_tensors->push_back(
MakeTensor<std::int8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
auto later_feedback_tensors = std::make_unique<Tensors>();
later_feedback_tensors->push_back(
MakeTensor({3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
MP_ASSERT_OK(graph.CloseAllInputStreams())
<< "Couldn't close the graph inputs";
MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run";
ASSERT_EQ(output_packets.size(), 2);
const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>();
ASSERT_EQ(initial_combined_tensors.size(), 2);
// The initial feedback is zero.
ValidateTensor<float>(initial_combined_tensors[0], /*expected_shape=*/{3, 2},
/*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
ValidateTensor<std::int8_t>(initial_combined_tensors[1],
/*expected_shape=*/{2, 4},
/*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8});
const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>();
ASSERT_EQ(later_combined_tensors.size(), 2);
// Afterwards, the provided feedback is passed through.
ValidateTensor<float>(
later_combined_tensors[0], /*expected_shape=*/{3, 2},
/*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f});
ValidateTensor<std::int8_t>(later_combined_tensors[1],
/*expected_shape=*/{2, 4},
/*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1});
}
TEST(FeedbackTensorsCalculatorTest, NoFeedback) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
input_stream: "feedback"
node {
calculator: "FeedbackTensorsCalculator"
input_stream: "INPUT_TENSORS:input"
input_stream: "FEEDBACK_TENSORS:feedback"
output_stream: "TENSORS:output"
options: {
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
feedback_tensor_shape: { dims: 3 dims: 4 }
location: NONE
}
}
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
auto initial_input_tensors = std::make_unique<Tensors>();
initial_input_tensors->push_back(
MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
// At the beginning, the loopback packet with the model feedback is missing.
auto later_input_tensors = std::make_unique<Tensors>();
later_input_tensors->push_back(
MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
// This feedback should be ignored due to `location: NONE`.
auto later_feedback_tensors = std::make_unique<Tensors>();
later_feedback_tensors->push_back(
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
MP_ASSERT_OK(graph.CloseAllInputStreams())
<< "Couldn't close the graph inputs";
MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run";
ASSERT_EQ(output_packets.size(), 2);
const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>();
ASSERT_EQ(initial_combined_tensors.size(), 1);
ValidateTensor<std::uint8_t>(initial_combined_tensors[0],
/*expected_shape=*/{2, 4},
/*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8});
// No feedback due to `location: NONE`.
const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>();
ASSERT_EQ(later_combined_tensors.size(), 1);
ValidateTensor<std::uint8_t>(later_combined_tensors[0],
/*expected_shape=*/{2, 4},
/*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1});
}
TEST(FeedbackTensorsCalculatorTest, ChecksTensorNumber) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
input_stream: "feedback"
node {
calculator: "FeedbackTensorsCalculator"
input_stream: "INPUT_TENSORS:input"
input_stream: "FEEDBACK_TENSORS:feedback"
output_stream: "TENSORS:output"
options: {
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
num_feedback_tensors: 2
feedback_tensor_shape: { dims: 2 dims: 3 }
location: PREPENDED
}
}
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
auto initial_input_tensors = std::make_unique<Tensors>();
initial_input_tensors->push_back(
MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
// At the beginning, the loopback packet with the model feedback is missing.
auto later_input_tensors = std::make_unique<Tensors>();
later_input_tensors->push_back(
MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
// This feedback should be ignored due to `location: NONE`.
auto later_feedback_tensors = std::make_unique<Tensors>();
later_feedback_tensors->push_back(
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
MP_ASSERT_OK(graph.CloseAllInputStreams())
<< "Couldn't close the graph inputs";
EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk()))
<< "Tensor number mismatch missed";
}
TEST(FeedbackTensorsCalculatorTest, ChecksShape) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
input_stream: "feedback"
node {
calculator: "FeedbackTensorsCalculator"
input_stream: "INPUT_TENSORS:input"
input_stream: "FEEDBACK_TENSORS:feedback"
output_stream: "TENSORS:output"
options: {
[mediapipe.FeedbackTensorsCalculatorOptions.ext] {
feedback_tensor_shape: { dims: 3 dims: 4 }
location: APPENDED
}
}
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
auto initial_input_tensors = std::make_unique<Tensors>();
initial_input_tensors->push_back(
MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(initial_input_tensors.release()).At(Timestamp(1))));
// At the beginning, the loopback packet with the model feedback is missing.
auto later_input_tensors = std::make_unique<Tensors>();
later_input_tensors->push_back(
MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", Adopt(later_input_tensors.release()).At(Timestamp(2))));
// This feedback should be ignored due to `location: NONE`.
auto later_feedback_tensors = std::make_unique<Tensors>();
later_feedback_tensors->push_back(
MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2))));
MP_ASSERT_OK(graph.CloseAllInputStreams())
<< "Couldn't close the graph inputs";
EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk()))
<< "Tensor shape mismatch missed";
}
} // namespace
} // namespace mediapipe

View File

@ -0,0 +1,74 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
namespace api2 {
// Trivially converts an input string into a Tensor that stores a copy of
// the string.
//
// Inputs:
// TEXT - std::string
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor storing a copy of the input string.
// Note that the underlying buffer of the Tensor is not necessarily
// null-terminated. It is the graph writer's responsibility to copy the
// correct number of characters when copying from this Tensor's buffer.
//
// Example:
// node {
// calculator: "TextToTensorCalculator"
// input_stream: "TEXT:text"
// output_stream: "TENSORS:tensors"
// }
class TextToTensorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut);
absl::Status Process(CalculatorContext* cc) override;
};
absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) {
absl::string_view text = kTextIn(cc).Get();
int input_len = static_cast<int>(text.length());
std::vector<Tensor> result;
result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len})});
std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(),
input_len * sizeof(char));
kTensorsOut(cc).Send(std::move(result));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,88 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
namespace mediapipe {
namespace {
using ::testing::StrEq;
absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "TextToTensorCalculator"
input_stream: "TEXT:text"
output_stream: "TENSORS:tensors"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected 1", tensor_vec.size()));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
Tensor::ElementType::kChar));
}
const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
return std::string(buffer, text.length());
}
TEST(TextToTensorCalculatorTest, FooBarBaz) {
EXPECT_THAT(RunTextToTensorCalculator("Foo. Bar? Baz!"),
IsOkAndHolds(StrEq("Foo. Bar? Baz!")));
}
TEST(TextToTensorCalculatorTest, Empty) {
EXPECT_THAT(RunTextToTensorCalculator(""), IsOkAndHolds(StrEq("")));
}
} // namespace
} // namespace mediapipe

View File

@ -231,7 +231,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
std::vector<tensorflow::DeviceAttributes> devices; std::vector<tensorflow::DeviceAttributes> devices;
ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus());
EXPECT_THAT(devices.size(), 10); EXPECT_THAT(devices.size(), 10);
} }

View File

@ -220,7 +220,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
std::vector<tensorflow::DeviceAttributes> devices; std::vector<tensorflow::DeviceAttributes> devices;
ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus());
EXPECT_THAT(devices.size(), 10); EXPECT_THAT(devices.size(), 10);
} }

View File

@ -135,6 +135,7 @@ filegroup(
srcs = [ srcs = [
"testdata/anchor_golden_file_0.txt", "testdata/anchor_golden_file_0.txt",
"testdata/anchor_golden_file_1.txt", "testdata/anchor_golden_file_1.txt",
"testdata/anchor_golden_file_2.txt",
], ],
) )

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <cmath> #include <cmath>
#include <utility>
#include <vector> #include <vector>
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" #include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
@ -24,6 +25,19 @@ namespace mediapipe {
namespace { namespace {
struct MultiScaleAnchorInfo {
int32 level;
std::vector<float> aspect_ratios;
std::vector<float> scales;
std::pair<float, float> base_anchor_size;
std::pair<float, float> anchor_stride;
};
struct FeatureMapDim {
int height;
int width;
};
float CalculateScale(float min_scale, float max_scale, int stride_index, float CalculateScale(float min_scale, float max_scale, int stride_index,
int num_strides) { int num_strides) {
if (num_strides == 1) { if (num_strides == 1) {
@ -34,6 +48,71 @@ float CalculateScale(float min_scale, float max_scale, int stride_index,
} }
} }
int GetNumLayers(const SsdAnchorsCalculatorOptions& options) {
if (options.multiscale_anchor_generation()) {
return (options.max_level() - options.min_level() + 1);
}
return options.num_layers();
}
FeatureMapDim GetFeatureMapDimensions(
const SsdAnchorsCalculatorOptions& options, int index) {
FeatureMapDim feature_map_dims;
if (options.feature_map_height_size()) {
feature_map_dims.height = options.feature_map_height(index);
feature_map_dims.width = options.feature_map_width(index);
} else {
const int stride = options.strides(index);
feature_map_dims.height =
std::ceil(1.0f * options.input_size_height() / stride);
feature_map_dims.width =
std::ceil(1.0f * options.input_size_width() / stride);
}
return feature_map_dims;
}
// Although we have stride for both x and y, only one value is used for offset
// calculation. See
// tensorflow_models/object_detection/anchor_generators/multiscale_grid_anchor_generator.py;l=121
std::pair<float, float> GetMultiScaleAnchorOffset(
const SsdAnchorsCalculatorOptions& options, const float stride,
const int level) {
std::pair<float, float> result(0., 0.);
int denominator = std::pow(2, level);
if (options.input_size_height() % denominator == 0 ||
options.input_size_height() == 1) {
result.first = stride / 2.0;
}
if (options.input_size_width() % denominator == 0 ||
options.input_size_width() == 1) {
result.second = stride / 2.0;
}
return result;
}
void NormalizeAnchor(const int input_height, const int input_width,
Anchor* anchor) {
anchor->set_h(anchor->h() / (float)input_height);
anchor->set_w(anchor->w() / (float)input_width);
anchor->set_y_center(anchor->y_center() / (float)input_height);
anchor->set_x_center(anchor->x_center() / (float)input_width);
}
Anchor CalculateAnchorBox(const int y_center, const int x_center,
const float scale, const float aspect_ratio,
const std::pair<float, float> base_anchor_size,
// y-height first
const std::pair<float, float> anchor_stride,
const std::pair<float, float> anchor_offset) {
Anchor result;
float ratio_sqrt = std::sqrt(aspect_ratio);
result.set_h(scale * base_anchor_size.first / ratio_sqrt);
result.set_w(scale * ratio_sqrt * base_anchor_size.second);
result.set_y_center(y_center * anchor_stride.first + anchor_offset.first);
result.set_x_center(x_center * anchor_stride.second + anchor_offset.second);
return result;
}
} // namespace } // namespace
// Generate anchors for SSD object detection model. // Generate anchors for SSD object detection model.
@ -95,9 +174,77 @@ class SsdAnchorsCalculator : public CalculatorBase {
private: private:
static absl::Status GenerateAnchors( static absl::Status GenerateAnchors(
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options); std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options);
static absl::Status GenerateMultiScaleAnchors(
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options);
}; };
REGISTER_CALCULATOR(SsdAnchorsCalculator); REGISTER_CALCULATOR(SsdAnchorsCalculator);
// Generates grid anchors on the fly corresponding to multiple CNN layers as
// described in:
// "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002)
// T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
absl::Status SsdAnchorsCalculator::GenerateMultiScaleAnchors(
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) {
std::vector<MultiScaleAnchorInfo> anchor_infos;
for (int i = options.min_level(); i <= options.max_level(); ++i) {
MultiScaleAnchorInfo current_anchor_info;
// level
current_anchor_info.level = i;
// aspect_ratios
for (const float aspect_ratio : options.aspect_ratios()) {
current_anchor_info.aspect_ratios.push_back(aspect_ratio);
}
// scale
for (int i = 0; i < options.scales_per_octave(); ++i) {
current_anchor_info.scales.push_back(
std::pow(2.0, (double)i / (double)options.scales_per_octave()));
}
// anchor stride
float anchor_stride = std::pow(2.0, i);
current_anchor_info.anchor_stride =
std::make_pair(anchor_stride, anchor_stride);
// base_anchor_size
current_anchor_info.base_anchor_size =
std::make_pair(anchor_stride * options.anchor_scale(),
anchor_stride * options.anchor_scale());
anchor_infos.push_back(current_anchor_info);
}
for (unsigned int i = 0; i < anchor_infos.size(); ++i) {
FeatureMapDim dimensions = GetFeatureMapDimensions(options, i);
for (int y = 0; y < dimensions.height; ++y) {
for (int x = 0; x < dimensions.width; ++x) {
// loop over combination of scale and aspect ratio
for (unsigned int j = 0; j < anchor_infos[i].aspect_ratios.size();
++j) {
for (unsigned int k = 0; k < anchor_infos[i].scales.size(); ++k) {
Anchor anchor = CalculateAnchorBox(
/*y_center=*/y, /*x_center=*/x, anchor_infos[i].scales[k],
anchor_infos[i].aspect_ratios[j],
anchor_infos[i].base_anchor_size,
/*anchor_stride=*/anchor_infos[i].anchor_stride,
/*anchor_offset=*/
GetMultiScaleAnchorOffset(options,
anchor_infos[i].anchor_stride.first,
anchor_infos[i].level));
if (options.normalize_coordinates()) {
NormalizeAnchor(options.input_size_height(),
options.input_size_width(), &anchor);
}
anchors->push_back(anchor);
}
}
}
}
}
return absl::OkStatus();
}
absl::Status SsdAnchorsCalculator::GenerateAnchors( absl::Status SsdAnchorsCalculator::GenerateAnchors(
std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) { std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) {
// Verify the options. // Verify the options.
@ -106,15 +253,21 @@ absl::Status SsdAnchorsCalculator::GenerateAnchors(
"Both feature map shape and strides are missing. Must provide either " "Both feature map shape and strides are missing. Must provide either "
"one."); "one.");
} }
const int kNumLayers = GetNumLayers(options);
if (options.feature_map_height_size()) { if (options.feature_map_height_size()) {
if (options.strides_size()) { if (options.strides_size()) {
LOG(ERROR) << "Found feature map shapes. Strides will be ignored."; LOG(ERROR) << "Found feature map shapes. Strides will be ignored.";
} }
CHECK_EQ(options.feature_map_height_size(), options.num_layers()); CHECK_EQ(options.feature_map_height_size(), kNumLayers);
CHECK_EQ(options.feature_map_height_size(), CHECK_EQ(options.feature_map_height_size(),
options.feature_map_width_size()); options.feature_map_width_size());
} else { } else {
CHECK_EQ(options.strides_size(), options.num_layers()); CHECK_EQ(options.strides_size(), kNumLayers);
}
if (options.multiscale_anchor_generation()) {
return GenerateMultiScaleAnchors(anchors, options);
} }
int layer_id = 0; int layer_id = 0;

View File

@ -60,4 +60,30 @@ message SsdAnchorsCalculatorOptions {
// This option can be used when the predicted anchor width and height are in // This option can be used when the predicted anchor width and height are in
// pixels. // pixels.
optional bool fixed_anchor_size = 14 [default = false]; optional bool fixed_anchor_size = 14 [default = false];
// Generates grid anchors on the fly corresponding to multiple CNN layers as
// described in:
// "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002)
// T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
optional bool multiscale_anchor_generation = 15 [default = false];
// minimum level in feature pyramid
// for multiscale_anchor_generation only!
optional int32 min_level = 16 [default = 3];
// maximum level in feature pyramid
// for multiscale_anchor_generation only!
optional int32 max_level = 17 [default = 7];
// Scale of anchor to feature stride
// for multiscale_anchor_generation only!
optional float anchor_scale = 18 [default = 4.0];
// Number of intermediate scale each scale octave
// for multiscale_anchor_generation only!
optional int32 scales_per_octave = 19 [default = 2];
// Whether to produce anchors in normalized coordinates.
// for multiscale_anchor_generation only!
optional bool normalize_coordinates = 20 [default = true];
} }

View File

@ -33,9 +33,6 @@ std::string GetGoldenFilePath(const std::string& filename) {
void ParseAnchorsFromText(const std::string& text, void ParseAnchorsFromText(const std::string& text,
std::vector<Anchor>* anchors) { std::vector<Anchor>* anchors) {
const std::string line_delimiter = "\n";
const std::string number_delimiter = ",";
std::istringstream stream(text); std::istringstream stream(text);
std::string line; std::string line;
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
@ -64,6 +61,8 @@ void CompareAnchors(const std::vector<Anchor>& anchors_0,
testing::FloatNear(anchor_1.x_center(), 1e-5)); testing::FloatNear(anchor_1.x_center(), 1e-5));
EXPECT_THAT(anchor_0.y_center(), EXPECT_THAT(anchor_0.y_center(),
testing::FloatNear(anchor_1.y_center(), 1e-5)); testing::FloatNear(anchor_1.y_center(), 1e-5));
EXPECT_THAT(anchor_0.h(), testing::FloatNear(anchor_1.h(), 1e-5));
EXPECT_THAT(anchor_0.w(), testing::FloatNear(anchor_1.w(), 1e-5));
} }
} }
@ -148,4 +147,40 @@ TEST(SsdAnchorCalculatorTest, MobileSSDConfig) {
CompareAnchors(anchors, anchors_golden); CompareAnchors(anchors, anchors_golden);
} }
TEST(SsdAnchorCalculatorTest, RetinaNetSSDConfig) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "SsdAnchorsCalculator"
output_side_packet: "anchors"
options {
[mediapipe.SsdAnchorsCalculatorOptions.ext] {
input_size_height: 640
input_size_width: 640
strides: 64
strides: 128
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
multiscale_anchor_generation: true
min_level: 6
max_level: 7
anchor_scale: 3.0
scales_per_octave: 3
}
}
)pb"));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const auto& anchors =
runner.OutputSidePackets().Index(0).Get<std::vector<Anchor>>();
std::string anchors_string;
MP_EXPECT_OK(mediapipe::file::GetContents(
GetGoldenFilePath("anchor_golden_file_2.txt"), &anchors_string));
std::vector<Anchor> anchors_golden;
ParseAnchorsFromText(anchors_string, &anchors_golden);
CompareAnchors(anchors, anchors_golden);
}
} // namespace mediapipe } // namespace mediapipe

File diff suppressed because it is too large Load Diff

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
namespace mediapipe { namespace mediapipe {
@ -32,6 +33,8 @@ namespace mediapipe {
// it to the graph as input side packet or you can use some of // it to the graph as input side packet or you can use some of
// calculators like LocalFileContentsCalculator to get model // calculators like LocalFileContentsCalculator to get model
// blob and use it as input here. // blob and use it as input here.
// MODEL_FD - Tflite model file descriptor std::tuple<int, size_t, size_t>
// containing (fd, offset, size).
// //
// Output side packets: // Output side packets:
// MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel, // MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
@ -52,17 +55,42 @@ class TfLiteModelCalculator : public CalculatorBase {
std::function<void(tflite::FlatBufferModel*)>>; std::function<void(tflite::FlatBufferModel*)>>;
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>(); cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
}
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
cc->InputSidePackets()
.Tag("MODEL_FD")
.Set<std::tuple<int, size_t, size_t>>();
}
cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>(); cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB"); Packet model_packet;
std::unique_ptr<tflite::FlatBufferModel> model;
if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
const std::string& model_blob = model_packet.Get<std::string>(); const std::string& model_blob = model_packet.Get<std::string>();
std::unique_ptr<tflite::FlatBufferModel> model = model = tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
model_blob.size()); model_blob.size());
}
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
const auto& model_fd =
model_packet.Get<std::tuple<int, size_t, size_t>>();
auto model_allocation = std::make_unique<tflite::MMAPAllocation>(
std::get<0>(model_fd), std::get<1>(model_fd), std::get<2>(model_fd),
tflite::DefaultErrorReporter());
model = tflite::FlatBufferModel::BuildFromAllocation(
std::move(model_allocation), tflite::DefaultErrorReporter());
}
RET_CHECK(model) << "Failed to load TfLite model from blob."; RET_CHECK(model) << "Failed to load TfLite model from blob.";
cc->OutputSidePackets().Tag("MODEL").Set( cc->OutputSidePackets().Tag("MODEL").Set(

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.3 KiB

View File

@ -90,6 +90,7 @@
{ {
"idiom" : "ipad", "idiom" : "ipad",
"size" : "83.5x83.5", "size" : "83.5x83.5",
"filename" : "83.5_c_Ipad_2x.png",
"scale" : "2x" "scale" : "2x"
}, },
{ {

View File

@ -21,7 +21,9 @@ cc_library(
":port", ":port",
"//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -4,7 +4,9 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/const_str.h" #include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h" #include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
@ -46,7 +48,7 @@ struct TagIndexLocation {
template <typename T> template <typename T>
class TagIndexMap { class TagIndexMap {
public: public:
std::vector<std::unique_ptr<T>>& operator[](const std::string& tag) { std::vector<std::unique_ptr<T>>& operator[](absl::string_view tag) {
return map_[tag]; return map_[tag];
} }
@ -72,7 +74,7 @@ class TagIndexMap {
// Note: entries are held by a unique_ptr to ensure pointers remain valid. // Note: entries are held by a unique_ptr to ensure pointers remain valid.
// Should use absl::flat_hash_map but ordering keys for now. // Should use absl::flat_hash_map but ordering keys for now.
std::map<std::string, std::vector<std::unique_ptr<T>>> map_; absl::btree_map<std::string, std::vector<std::unique_ptr<T>>> map_;
}; };
class Graph; class Graph;
@ -169,6 +171,16 @@ class SourceImpl {
return AddTarget(dest); return AddTarget(dest);
} }
template <typename U>
struct AllowCast
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>> {};
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
SourceImpl<IsSide, U> Cast() {
return SourceImpl<IsSide, U>(base_);
}
private: private:
// Never null. // Never null.
SourceBase* base_; SourceBase* base_;
@ -212,19 +224,19 @@ class NodeBase {
// of its entries by index. However, for nodes without visible contracts we // of its entries by index. However, for nodes without visible contracts we
// can't know whether a tag is indexable or not, so we would need the // can't know whether a tag is indexable or not, so we would need the
// multi-port to also be usable as a port directly (representing index 0). // multi-port to also be usable as a port directly (representing index 0).
MultiSource<> Out(const std::string& tag) { MultiSource<> Out(absl::string_view tag) {
return MultiSource<>(&out_streams_[tag]); return MultiSource<>(&out_streams_[tag]);
} }
MultiDestination<> In(const std::string& tag) { MultiDestination<> In(absl::string_view tag) {
return MultiDestination<>(&in_streams_[tag]); return MultiDestination<>(&in_streams_[tag]);
} }
MultiSideSource<> SideOut(const std::string& tag) { MultiSideSource<> SideOut(absl::string_view tag) {
return MultiSideSource<>(&out_sides_[tag]); return MultiSideSource<>(&out_sides_[tag]);
} }
MultiSideDestination<> SideIn(const std::string& tag) { MultiSideDestination<> SideIn(absl::string_view tag) {
return MultiSideDestination<>(&in_sides_[tag]); return MultiSideDestination<>(&in_sides_[tag]);
} }
@ -359,11 +371,11 @@ class PacketGenerator {
public: public:
PacketGenerator(std::string type) : type_(std::move(type)) {} PacketGenerator(std::string type) : type_(std::move(type)) {}
MultiSideSource<> SideOut(const std::string& tag) { MultiSideSource<> SideOut(absl::string_view tag) {
return MultiSideSource<>(&out_sides_[tag]); return MultiSideSource<>(&out_sides_[tag]);
} }
MultiSideDestination<> SideIn(const std::string& tag) { MultiSideDestination<> SideIn(absl::string_view tag) {
return MultiSideDestination<>(&in_sides_[tag]); return MultiSideDestination<>(&in_sides_[tag]);
} }
@ -452,19 +464,19 @@ class Graph {
} }
// Graph ports, non-typed. // Graph ports, non-typed.
MultiSource<> In(const std::string& graph_input) { MultiSource<> In(absl::string_view graph_input) {
return graph_boundary_.Out(graph_input); return graph_boundary_.Out(graph_input);
} }
MultiDestination<> Out(const std::string& graph_output) { MultiDestination<> Out(absl::string_view graph_output) {
return graph_boundary_.In(graph_output); return graph_boundary_.In(graph_output);
} }
MultiSideSource<> SideIn(const std::string& graph_input) { MultiSideSource<> SideIn(absl::string_view graph_input) {
return graph_boundary_.SideOut(graph_input); return graph_boundary_.SideOut(graph_input);
} }
MultiSideDestination<> SideOut(const std::string& graph_output) { MultiSideDestination<> SideOut(absl::string_view graph_output) {
return graph_boundary_.SideIn(graph_output); return graph_boundary_.SideIn(graph_output);
} }

View File

@ -2,6 +2,7 @@
#include <functional> #include <functional>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
@ -296,6 +297,32 @@ TEST(BuilderTest, EmptyTag) {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
} }
TEST(BuilderTest, StringLikeTags) {
const char kA[] = "A";
const std::string kB = "B";
constexpr absl::string_view kC = "C";
builder::Graph graph;
auto& foo = graph.AddNode("Foo");
graph.In(kA).SetName("a") >> foo.In(kA);
graph.In(kB).SetName("b") >> foo.In(kB);
foo.Out(kC).SetName("c") >> graph.Out(kC);
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "A:a"
input_stream: "B:b"
output_stream: "C:c"
node {
calculator: "Foo"
input_stream: "A:a"
input_stream: "B:b"
output_stream: "C:c"
}
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, GraphIndexes) { TEST(BuilderTest, GraphIndexes) {
builder::Graph graph; builder::Graph graph;
auto& foo = graph.AddNode("Foo"); auto& foo = graph.AddNode("Foo");
@ -326,52 +353,63 @@ TEST(BuilderTest, GraphIndexes) {
class AnyAndSameTypeCalculator : public NodeIntf { class AnyAndSameTypeCalculator : public NodeIntf {
public: public:
static constexpr Input<AnyType> kAnyTypeInput{"INPUT"}; static constexpr Input<AnyType>::Optional kAnyTypeInput{"INPUT"};
static constexpr Output<AnyType> kAnyTypeOutput{"ANY_OUTPUT"}; static constexpr Output<AnyType>::Optional kAnyTypeOutput{"ANY_OUTPUT"};
static constexpr Output<SameType<kAnyTypeInput>> kSameTypeOutput{ static constexpr Output<SameType<kAnyTypeInput>>::Optional kSameTypeOutput{
"SAME_OUTPUT"}; "SAME_OUTPUT"};
static constexpr Output<SameType<kSameTypeOutput>> kRecursiveSameTypeOutput{
"RECURSIVE_SAME_OUTPUT"};
static constexpr Input<int> kIntInput{"INT_INPUT"}; static constexpr Input<int>::Optional kIntInput{"INT_INPUT"};
// `SameType` usage for this output is only for testing purposes. // `SameType` usage for this output is only for testing purposes.
// //
// `SameType` is designed to work with inputs of `AnyType` and, normally, you // `SameType` is designed to work with inputs of `AnyType` and, normally, you
// would not use `Output<SameType<kIntInput>>` in a real calculator. You // would not use `Output<SameType<kIntInput>>` in a real calculator. You
// should write `Output<int>` instead, since the type is known. // should write `Output<int>` instead, since the type is known.
static constexpr Output<SameType<kIntInput>> kSameIntOutput{ static constexpr Output<SameType<kIntInput>>::Optional kSameIntOutput{
"SAME_INT_OUTPUT"}; "SAME_INT_OUTPUT"};
static constexpr Output<SameType<kSameIntOutput>> kRecursiveSameIntOutput{
"RECURSIVE_SAME_INT_OUTPUT"};
MEDIAPIPE_NODE_INTERFACE(AnyTypeCalculator, kAnyTypeInput, kAnyTypeOutput, MEDIAPIPE_NODE_INTERFACE(AnyAndSameTypeCalculator, kAnyTypeInput,
kSameTypeOutput); kAnyTypeOutput, kSameTypeOutput);
}; };
TEST(BuilderTest, AnyAndSameTypeHandledProperly) { TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
builder::Graph graph; builder::Graph graph;
builder::Source<internal::Generic> any_input = builder::Source<AnyType> any_input = graph[Input<AnyType>{"GRAPH_ANY_INPUT"}];
graph[Input<AnyType>{"GRAPH_ANY_INPUT"}];
builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}]; builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}];
auto& node = graph.AddNode("AnyAndSameTypeCalculator"); auto& node = graph.AddNode("AnyAndSameTypeCalculator");
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
int_input >> node[AnyAndSameTypeCalculator::kIntInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput];
builder::Source<internal::Generic> any_type_output = builder::Source<AnyType> any_type_output =
node[AnyAndSameTypeCalculator::kAnyTypeOutput]; node[AnyAndSameTypeCalculator::kAnyTypeOutput];
any_type_output.SetName("any_type_output"); any_type_output.SetName("any_type_output");
builder::Source<internal::Generic> same_type_output = builder::Source<AnyType> same_type_output =
node[AnyAndSameTypeCalculator::kSameTypeOutput]; node[AnyAndSameTypeCalculator::kSameTypeOutput];
same_type_output.SetName("same_type_output"); same_type_output.SetName("same_type_output");
builder::Source<internal::Generic> same_int_output = builder::Source<AnyType> recursive_same_type_output =
node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput];
recursive_same_type_output.SetName("recursive_same_type_output");
builder::Source<int> same_int_output =
node[AnyAndSameTypeCalculator::kSameIntOutput]; node[AnyAndSameTypeCalculator::kSameIntOutput];
same_int_output.SetName("same_int_output"); same_int_output.SetName("same_int_output");
builder::Source<int> recursive_same_int_type_output =
node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput];
recursive_same_int_type_output.SetName("recursive_same_int_type_output");
CalculatorGraphConfig expected = CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie<
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( CalculatorGraphConfig>(R"pb(
node { node {
calculator: "AnyAndSameTypeCalculator" calculator: "AnyAndSameTypeCalculator"
input_stream: "INPUT:__stream_0" input_stream: "INPUT:__stream_0"
input_stream: "INT_INPUT:__stream_1" input_stream: "INT_INPUT:__stream_1"
output_stream: "ANY_OUTPUT:any_type_output" output_stream: "ANY_OUTPUT:any_type_output"
output_stream: "RECURSIVE_SAME_INT_OUTPUT:recursive_same_int_type_output"
output_stream: "RECURSIVE_SAME_OUTPUT:recursive_same_type_output"
output_stream: "SAME_INT_OUTPUT:same_int_output" output_stream: "SAME_INT_OUTPUT:same_int_output"
output_stream: "SAME_OUTPUT:same_type_output" output_stream: "SAME_OUTPUT:same_type_output"
} }
@ -381,6 +419,29 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
} }
TEST(BuilderTest, AnyTypeCanBeCast) {
builder::Graph graph;
builder::Source<std::string> any_input =
graph.In("GRAPH_ANY_INPUT").Cast<std::string>();
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
builder::Source<double> any_type_output =
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output.SetName("any_type_output");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "AnyAndSameTypeCalculator"
input_stream: "INPUT:__stream_0"
output_stream: "ANY_OUTPUT:any_type_output"
}
input_stream: "GRAPH_ANY_INPUT:__stream_0"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
} // namespace test } // namespace test
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -27,9 +27,7 @@ using HolderBase = mediapipe::packet_internal::HolderBase;
template <typename T> template <typename T>
class Packet; class Packet;
struct DynamicType {}; struct AnyType {
struct AnyType : public DynamicType {
AnyType() = delete; AnyType() = delete;
}; };

View File

@ -73,14 +73,12 @@ class SideOutputBase : public PortBase {
}; };
struct NoneType { struct NoneType {
private:
NoneType() = delete; NoneType() = delete;
}; };
template <auto& P> template <auto& kP>
class SameType : public DynamicType { struct SameType {
public: static constexpr const decltype(kP)& kPort = kP;
static constexpr const decltype(P)& kPort = P;
}; };
class PacketTypeAccess; class PacketTypeAccess;
@ -137,21 +135,28 @@ struct IsOneOf : std::false_type {};
template <class... T> template <class... T>
struct IsOneOf<OneOf<T...>> : std::true_type {}; struct IsOneOf<OneOf<T...>> : std::true_type {};
template <typename T, typename std::enable_if< template <class T>
!std::is_base_of<DynamicType, T>{} && !IsOneOf<T>{}, struct IsSameType : std::false_type {};
template <class P, P& kP>
struct IsSameType<SameType<kP>> : std::true_type {};
template <typename T,
typename std::enable_if<!std::is_same<T, AnyType>{} &&
!IsOneOf<T>{} && !IsSameType<T>{},
int>::type = 0> int>::type = 0>
inline void SetType(CalculatorContract* cc, PacketType& pt) { inline void SetType(CalculatorContract* cc, PacketType& pt) {
pt.Set<T>(); pt.Set<T>();
} }
template <typename T, typename std::enable_if<std::is_base_of<DynamicType, T>{}, template <typename T, typename std::enable_if<IsSameType<T>{}, int>::type = 0>
int>::type = 0>
inline void SetType(CalculatorContract* cc, PacketType& pt) { inline void SetType(CalculatorContract* cc, PacketType& pt) {
pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag())); pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag()));
} }
template <> template <typename T,
inline void SetType<AnyType>(CalculatorContract* cc, PacketType& pt) { typename std::enable_if<std::is_same<T, AnyType>{}, int>::type = 0>
inline void SetType(CalculatorContract* cc, PacketType& pt) {
pt.SetAny(); pt.SetAny();
} }
@ -289,15 +294,15 @@ struct SideBase<InputBase> {
}; };
// TODO: maybe return a PacketBase instead of a Packet<internal::Generic>? // TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
template <typename T, class = void> template <typename T, typename = void>
struct ActualPayloadType { struct ActualPayloadType {
using type = T; using type = T;
}; };
template <typename T> template <typename T>
struct ActualPayloadType< struct ActualPayloadType<T, std::enable_if_t<IsSameType<T>{}, void>> {
T, std::enable_if_t<std::is_base_of<DynamicType, T>{}, void>> { using type = typename ActualPayloadType<
using type = internal::Generic; typename std::decay_t<decltype(T::kPort)>::value_t>::type;
}; };
} // namespace internal } // namespace internal

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2022 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -85,14 +85,8 @@ cv::Mat MatView(const ImageFrame* image) {
const size_t steps[] = {static_cast<size_t>(image->WidthStep()), const size_t steps[] = {static_cast<size_t>(image->WidthStep()),
static_cast<size_t>(image->ByteDepth())}; static_cast<size_t>(image->ByteDepth())};
// Use ImageFrame to initialize in-place. ImageFrame still owns memory. // Use ImageFrame to initialize in-place. ImageFrame still owns memory.
if (steps[0] == sizes[1] * image->NumberOfChannels() * image->ByteDepth()) { return cv::Mat(dims, sizes, type, const_cast<uint8_t*>(image->PixelData()),
// Contiguous memory optimization. See b/78570764
return cv::Mat(dims, sizes, type, const_cast<uint8*>(image->PixelData()));
} else {
// Custom width step.
return cv::Mat(dims, sizes, type, const_cast<uint8*>(image->PixelData()),
steps); steps);
}
} }
} // namespace formats } // namespace formats

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2022 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2022 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -21,7 +21,6 @@
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
// Set image_frame to a constant per-channel pix_value. // Set image_frame to a constant per-channel pix_value.
@ -50,8 +49,8 @@ TEST(ImageFrameOpencvTest, ConvertToMat) {
ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height); ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height);
// Check adding constant images. // Check adding constant images.
const uint8 frame1_val = 12; const uint8_t frame1_val = 12;
const uint8 frame2_val = 34; const uint8_t frame2_val = 34;
SetToColor<uint8>(&frame1_val, &frame1); SetToColor<uint8>(&frame1_val, &frame1);
SetToColor<uint8>(&frame2_val, &frame2); SetToColor<uint8>(&frame2_val, &frame2);
// Get Mat wrapper around ImageFrame memory (zero copy). // Get Mat wrapper around ImageFrame memory (zero copy).
@ -77,6 +76,37 @@ TEST(ImageFrameOpencvTest, ConvertToMat) {
EXPECT_EQ(max_loc.y, i_height - 6); EXPECT_EQ(max_loc.y, i_height - 6);
} }
TEST(ImageFrameOpencvTest, ConvertToIpl) {
const int i_width = 123, i_height = 45;
ImageFrame frame1(ImageFormat::GRAY8, i_width, i_height);
ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height);
// Check adding constant images.
const uint8_t frame1_val = 12;
const uint8_t frame2_val = 34;
SetToColor<uint8>(&frame1_val, &frame1);
SetToColor<uint8>(&frame2_val, &frame2);
const cv::Mat frame1_mat = formats::MatView(&frame1);
const cv::Mat frame2_mat = formats::MatView(&frame2);
const cv::Mat frame_sum = frame1_mat + frame2_mat;
const auto frame_avg = static_cast<int>(cv::mean(frame_sum).val[0]);
EXPECT_EQ(frame_avg, frame1_val + frame2_val);
// Check setting min/max pixels.
uint8* frame1_ptr = frame1.MutablePixelData();
frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1;
frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100;
double min, max;
cv::Point min_loc, max_loc;
cv::minMaxLoc(frame1_mat, &min, &max, &min_loc, &max_loc);
EXPECT_EQ(min, 1);
EXPECT_EQ(min_loc.x, i_width - 5);
EXPECT_EQ(min_loc.y, i_height - 5);
EXPECT_EQ(max, 100);
EXPECT_EQ(max_loc.x, i_width - 6);
EXPECT_EQ(max_loc.y, i_height - 6);
}
TEST(ImageFrameOpencvTest, ImageFormats) { TEST(ImageFrameOpencvTest, ImageFormats) {
const int i_width = 123, i_height = 45; const int i_width = 123, i_height = 45;
ImageFrame frame_g8(ImageFormat::GRAY8, i_width, i_height); ImageFrame frame_g8(ImageFormat::GRAY8, i_width, i_height);

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2022 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
// Copyright 2019-2020 The MediaPipe Authors. // Copyright 2022 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.

View File

@ -37,26 +37,21 @@ namespace mediapipe {
bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; } bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; }
int BhwcBatchFromShape(const Tensor::Shape& shape) { int BhwcBatchFromShape(const Tensor::Shape& shape) {
LOG_IF(FATAL, shape.dims.empty()) if (shape.dims.empty()) {
<< "Tensor::Shape must be non-empty to retrieve a named dimension"; return 1;
}
return shape.dims[0]; return shape.dims[0];
} }
int BhwcHeightFromShape(const Tensor::Shape& shape) { int BhwcHeightFromShape(const Tensor::Shape& shape) {
LOG_IF(FATAL, shape.dims.empty())
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
return shape.dims.size() < 4 ? 1 : shape.dims[shape.dims.size() - 3]; return shape.dims.size() < 4 ? 1 : shape.dims[shape.dims.size() - 3];
} }
int BhwcWidthFromShape(const Tensor::Shape& shape) { int BhwcWidthFromShape(const Tensor::Shape& shape) {
LOG_IF(FATAL, shape.dims.empty())
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
return shape.dims.size() < 3 ? 1 : shape.dims[shape.dims.size() - 2]; return shape.dims.size() < 3 ? 1 : shape.dims[shape.dims.size() - 2];
} }
int BhwcDepthFromShape(const Tensor::Shape& shape) { int BhwcDepthFromShape(const Tensor::Shape& shape) {
LOG_IF(FATAL, shape.dims.empty())
<< "Tensor::Shape must be non-empty to retrieve a named dimension";
return shape.dims.size() < 2 ? 1 : shape.dims[shape.dims.size() - 1]; return shape.dims.size() < 2 ? 1 : shape.dims[shape.dims.size() - 1];
} }
@ -424,6 +419,11 @@ Tensor::Tensor(ElementType element_type, const Shape& shape,
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
void Tensor::Invalidate() { void Tensor::Invalidate() {
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
GLuint cleanup_gl_tex = GL_INVALID_INDEX;
GLuint cleanup_gl_fb = GL_INVALID_INDEX;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
{
absl::MutexLock lock(&view_mutex_); absl::MutexLock lock(&view_mutex_);
// If memory is allocated and not owned by the metal buffer. // If memory is allocated and not owned by the metal buffer.
// TODO: Re-design cpu buffer memory management. // TODO: Re-design cpu buffer memory management.
@ -432,6 +432,23 @@ void Tensor::Invalidate() {
} }
metal_buffer_ = nil; metal_buffer_ = nil;
cpu_buffer_ = nullptr; cpu_buffer_ = nullptr;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
// Don't need to wait for the resource to be deleted bacause if will be
// released on last reference deletion inside the OpenGL driver.
std::swap(cleanup_gl_tex, opengl_texture2d_);
std::swap(cleanup_gl_fb, frame_buffer_);
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
}
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
// Do not hold the view mutex while invoking GlContext::RunWithoutWaiting,
// since that method may acquire the context's own lock.
if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX) {
gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb]() {
glDeleteTextures(1, &cleanup_gl_tex);
glDeleteFramebuffers(1, &cleanup_gl_fb);
});
}
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
} }
#else #else

View File

@ -18,6 +18,7 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <initializer_list> #include <initializer_list>
#include <numeric>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
@ -89,15 +90,23 @@ class Tensor {
public: public:
// No resources are allocated here. // No resources are allocated here.
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8, kInt32 }; enum class ElementType {
kNone,
kFloat16,
kFloat32,
kUInt8,
kInt8,
kInt32,
// TODO: Update the inference runner to handle kTfLiteString.
kChar
};
struct Shape { struct Shape {
Shape() = default; Shape() = default;
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {} Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
Shape(const std::vector<int>& dimensions) : dims(dimensions) {} Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
int num_elements() const { int num_elements() const {
int res = dims.empty() ? 0 : 1; return std::accumulate(dims.begin(), dims.end(), 1,
std::for_each(dims.begin(), dims.end(), [&res](int i) { res *= i; }); std::multiplies<int>());
return res;
} }
std::vector<int> dims; std::vector<int> dims;
}; };
@ -319,6 +328,8 @@ class Tensor {
return 1; return 1;
case ElementType::kInt32: case ElementType::kInt32:
return sizeof(int32_t); return sizeof(int32_t);
case ElementType::kChar:
return sizeof(char);
} }
} }
int bytes() const { return shape_.num_elements() * element_size(); } int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -1,5 +1,8 @@
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include <cstring>
#include <string>
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
@ -23,6 +26,9 @@ TEST(General, TestDataTypes) {
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3}); Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2); EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
} }
TEST(Cpu, TestMemoryAllocation) { TEST(Cpu, TestMemoryAllocation) {

View File

@ -15,7 +15,7 @@ def mediapipe_cc_test(
platforms = ["linux", "android", "ios", "wasm"], platforms = ["linux", "android", "ios", "wasm"],
exclude_platforms = None, exclude_platforms = None,
# ios_unit_test arguments # ios_unit_test arguments
ios_minimum_os_version = "9.0", ios_minimum_os_version = "11.0",
# android_cc_test arguments # android_cc_test arguments
open_gl_driver = None, open_gl_driver = None,
emulator_mini_boot = True, emulator_mini_boot = True,

View File

@ -108,6 +108,7 @@ cc_library(
":sharded_map", ":sharded_map",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_profile_cc_proto", "//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",

View File

@ -22,6 +22,7 @@
#include "absl/time/time.h" #include "absl/time/time.h"
#include "mediapipe/framework/port/advanced_proto_lite_inc.h" #include "mediapipe/framework/port/advanced_proto_lite_inc.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/re2.h" #include "mediapipe/framework/port/re2.h"
@ -244,7 +245,16 @@ absl::Status GraphProfiler::Start(mediapipe::Executor* executor) {
executor != nullptr) { executor != nullptr) {
// Inform the user via logging the path to the trace logs. // Inform the user via logging the path to the trace logs.
ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath()); ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath());
// Check that we can actually write to it.
auto status =
file::SetContents(absl::StrCat(trace_log_path, "trace_writing_check"),
"can write trace logs to this location");
if (status.ok()) {
LOG(INFO) << "trace_log_path: " << trace_log_path; LOG(INFO) << "trace_log_path: " << trace_log_path;
} else {
LOG(ERROR) << "cannot write to trace_log_path: " << trace_log_path << ": "
<< status;
}
is_running_ = true; is_running_ = true;
executor->Schedule([this] { executor->Schedule([this] {

View File

@ -5,7 +5,7 @@
buffers: { buffers: {
size_kb: 150000 size_kb: 150000
fill_policy: DISCARD fill_policy: RING_BUFFER
} }
data_sources: { data_sources: {
@ -21,12 +21,17 @@ data_sources: {
# - what is happening on each CPU at each moment # - what is happening on each CPU at each moment
ftrace_events: "power/cpu_frequency" ftrace_events: "power/cpu_frequency"
ftrace_events: "power/cpu_idle" ftrace_events: "power/cpu_idle"
# TODO: CPU frequency does not show up without scheduling
ftrace_events: "sched/sched_switch" ftrace_events: "sched/sched_switch"
compact_sched { compact_sched {
enabled: true enabled: true
} }
# GPU
ftrace_events: "power/gpu_frequency"
} }
} }
} }
write_into_file: true write_into_file: true
file_write_period_ms: 500 file_write_period_ms: 500
# b/243571696 Added to remove Perfetto timeouts when running benchmarks remotely.
duration_ms: 60000

View File

@ -821,6 +821,19 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
mediapipe_cc_test(
name = "switch_demux_calculator_test",
srcs = ["switch_demux_calculator_test.cc"],
deps = [
":container_util",
":switch_demux_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "switch_mux_calculator", name = "switch_mux_calculator",
srcs = ["switch_mux_calculator.cc"], srcs = ["switch_mux_calculator.cc"],

View File

@ -129,12 +129,12 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
// Relay side packets to all channels. // Relay side packets to all channels.
// Note: This is necessary because Calculator::Open only proceeds when every // Note: This is necessary because Calculator::Open only proceeds when every
// anticipated side-packet arrives. // anticipated side-packet arrives.
int channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap()); int side_channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap());
for (const std::string& tag : ChannelTags(cc->OutputSidePackets().TagMap())) { for (const std::string& tag : ChannelTags(cc->OutputSidePackets().TagMap())) {
int num_entries = cc->InputSidePackets().NumEntries(tag); int num_entries = cc->InputSidePackets().NumEntries(tag);
for (int index = 0; index < num_entries; ++index) { for (int index = 0; index < num_entries; ++index) {
Packet input = cc->InputSidePackets().Get(tag, index); Packet input = cc->InputSidePackets().Get(tag, index);
for (int channel = 0; channel < channel_count; ++channel) { for (int channel = 0; channel < side_channel_count; ++channel) {
std::string output_tag = tool::ChannelTag(tag, channel); std::string output_tag = tool::ChannelTag(tag, channel);
auto output_id = cc->OutputSidePackets().GetId(output_tag, index); auto output_id = cc->OutputSidePackets().GetId(output_tag, index);
if (output_id.IsValid()) { if (output_id.IsValid()) {
@ -143,6 +143,23 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
} }
} }
} }
// Relay headers to all channels.
int output_channel_count = tool::ChannelCount(cc->Outputs().TagMap());
for (const std::string& tag : ChannelTags(cc->Outputs().TagMap())) {
int num_entries = cc->Inputs().NumEntries(tag);
for (int index = 0; index < num_entries; ++index) {
auto& input = cc->Inputs().Get(tag, index);
if (input.Header().IsEmpty()) continue;
for (int channel = 0; channel < output_channel_count; ++channel) {
std::string output_tag = tool::ChannelTag(tag, channel);
auto output_id = cc->Outputs().GetId(output_tag, index);
if (output_id.IsValid()) {
cc->Outputs().Get(output_tag, index).SetHeader(input.Header());
}
}
}
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -0,0 +1,135 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/tool/container_util.h"
namespace mediapipe {
namespace {
// Returns a CalculatorGraph to run a single calculator.
CalculatorGraph BuildCalculatorGraph(CalculatorGraphConfig::Node node_config) {
CalculatorGraphConfig config;
*config.add_node() = node_config;
*config.mutable_input_stream() = node_config.input_stream();
*config.mutable_output_stream() = node_config.output_stream();
*config.mutable_input_side_packet() = node_config.input_side_packet();
*config.mutable_output_side_packet() = node_config.output_side_packet();
return CalculatorGraph(config);
}
// Creates a string packet.
Packet pack(std::string data, int timestamp) {
return MakePacket<std::string>(data).At(Timestamp(timestamp));
}
// Creates an int packet.
Packet pack(int data, int timestamp) {
return MakePacket<int>(data).At(Timestamp(timestamp));
}
// Tests showing packet channel synchronization through SwitchDemuxCalculator.
class SwitchDemuxCalculatorTest : public ::testing::Test {
protected:
SwitchDemuxCalculatorTest() {}
~SwitchDemuxCalculatorTest() override {}
void SetUp() override {}
void TearDown() override {}
// Defines a SwitchDemuxCalculator CalculatorGraphConfig::Node.
CalculatorGraphConfig::Node BuildNodeConfig() {
CalculatorGraphConfig::Node result;
*result.mutable_calculator() = "SwitchDemuxCalculator";
*result.add_input_stream() = "SELECT:select";
for (int c = 0; c < 2; ++c) {
*result.add_output_stream() =
absl::StrCat(tool::ChannelTag("FRAME", c), ":frame_", c);
*result.add_output_stream() =
absl::StrCat(tool::ChannelTag("MASK", c), ":mask_", c);
}
*result.add_input_stream() = "FRAME:frame";
*result.add_input_stream() = "MASK:mask";
return result;
}
};
// Shows the SwitchMuxCalculator is available.
TEST_F(SwitchDemuxCalculatorTest, IsRegistered) {
EXPECT_TRUE(CalculatorBaseRegistry::IsRegistered("SwitchDemuxCalculator"));
}
TEST_F(SwitchDemuxCalculatorTest, BasicDataFlow) {
CalculatorGraphConfig::Node node_config = BuildNodeConfig();
CalculatorGraph graph = BuildCalculatorGraph(node_config);
std::vector<Packet> output_frames0;
EXPECT_TRUE(graph
.ObserveOutputStream("frame_0",
[&](const Packet& p) {
output_frames0.push_back(p);
return absl::OkStatus();
})
.ok());
std::vector<Packet> output_frames1;
EXPECT_TRUE(graph
.ObserveOutputStream("frame_1",
[&](const Packet& p) {
output_frames1.push_back(p);
return absl::OkStatus();
})
.ok());
EXPECT_TRUE(
graph.StartRun({}, {{"frame", MakePacket<std::string>("frame_header")}})
.ok());
// Finalize input for the "mask" input stream.
EXPECT_TRUE(graph.CloseInputStream("mask").ok());
// Channel 0 is selected just before corresponding packets arrive.
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 1)).ok());
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 10)).ok());
EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p0_t10", 10)).ok());
EXPECT_TRUE(graph.WaitUntilIdle().ok());
EXPECT_EQ(output_frames0.size(), 1);
EXPECT_EQ(output_frames1.size(), 0);
EXPECT_EQ(output_frames0[0].Get<std::string>(), "p0_t10");
// Channel 1 is selected just before corresponding packets arrive.
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 11)).ok());
EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 20)).ok());
EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p1_t20", 20)).ok());
EXPECT_TRUE(graph.WaitUntilIdle().ok());
EXPECT_EQ(output_frames0.size(), 1);
EXPECT_EQ(output_frames1.size(), 1);
EXPECT_EQ(output_frames1[0].Get<std::string>(), "p1_t20");
EXPECT_EQ(
graph.FindOutputStreamManager("frame_0")->Header().Get<std::string>(),
"frame_header");
EXPECT_EQ(
graph.FindOutputStreamManager("frame_1")->Header().Get<std::string>(),
"frame_header");
EXPECT_TRUE(graph.CloseAllPacketSources().ok());
EXPECT_TRUE(graph.WaitUntilDone().ok());
}
} // namespace
} // namespace mediapipe

View File

@ -271,6 +271,7 @@ cc_library(
deps = [ deps = [
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_storage", ":gpu_buffer_storage",
"@com_google_absl//absl/strings",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
":gpu_buffer_storage_image_frame", ":gpu_buffer_storage_image_frame",
@ -366,6 +367,23 @@ cc_library(
], ],
) )
cc_library(
name = "gpu_buffer_storage_ahwb",
srcs = ["gpu_buffer_storage_ahwb.cc"],
hdrs = ["gpu_buffer_storage_ahwb.h"],
linkopts = select({
"//conditions:default": [],
"//mediapipe:android": [
"-landroid",
],
}),
deps = [
":gpu_buffer_format",
":gpu_buffer_storage",
"@com_google_absl//absl/strings:str_format",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "gpu_origin_proto", name = "gpu_origin_proto",
srcs = ["gpu_origin.proto"], srcs = ["gpu_origin.proto"],
@ -1087,3 +1105,19 @@ ios_unit_test(
], ],
deps = [":gl_ios_test_lib"], deps = [":gl_ios_test_lib"],
) )
mediapipe_cc_test(
name = "gpu_buffer_storage_ahwb_test",
size = "small",
srcs = ["gpu_buffer_storage_ahwb_test.cc"],
exclude_platforms = [
"ios",
"wasm",
],
requires_full_emulation = True,
deps = [
":gpu_buffer_format",
":gpu_buffer_storage_ahwb",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -620,7 +620,9 @@ class GlSyncWrapper {
#endif #endif
GLenum result = glClientWaitSync(sync_, flags, timeout); GLenum result = glClientWaitSync(sync_, flags, timeout);
if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) {
Clear(); // TODO: we could clear at this point so later calls are faster,
// but we need to do so in a thread-safe way.
// Clear();
} }
// TODO: do something if the wait fails? // TODO: do something if the wait fails?
} }
@ -646,7 +648,9 @@ class GlSyncWrapper {
#endif #endif
GLenum result = glClientWaitSync(sync_, flags, 0); GLenum result = glClientWaitSync(sync_, flags, 0);
if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) {
Clear(); // TODO: we could clear at this point so later calls are faster,
// but we need to do so in a thread-safe way.
// Clear();
return true; return true;
} }
return false; return false;
@ -822,10 +826,17 @@ std::shared_ptr<GlSyncPoint> GlContext::CreateSyncToken() {
return token; return token;
} }
bool GlContext::IsAnyContextCurrent() {
ContextBinding ctx;
GetCurrentContextBinding(&ctx);
return ctx.context != kPlatformGlContextNone;
}
std::shared_ptr<GlSyncPoint> std::shared_ptr<GlSyncPoint>
GlContext::CreateSyncTokenForCurrentExternalContext( GlContext::CreateSyncTokenForCurrentExternalContext(
const std::shared_ptr<GlContext>& delegate_graph_context) { const std::shared_ptr<GlContext>& delegate_graph_context) {
CHECK(delegate_graph_context); CHECK(delegate_graph_context);
if (!IsAnyContextCurrent()) return nullptr;
if (delegate_graph_context->ShouldUseFenceSync()) { if (delegate_graph_context->ShouldUseFenceSync()) {
return std::shared_ptr<GlSyncPoint>( return std::shared_ptr<GlSyncPoint>(
new GlExternalFenceSyncPoint(delegate_graph_context)); new GlExternalFenceSyncPoint(delegate_graph_context));

View File

@ -303,6 +303,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
return *static_cast<T*>(entry.get()); return *static_cast<T*>(entry.get());
} }
// Returns true if any GL context, including external contexts not managed by
// the GlContext class, is current.
static bool IsAnyContextCurrent();
// Creates a synchronization token for the current, non-GlContext-owned // Creates a synchronization token for the current, non-GlContext-owned
// context. This can be passed to MediaPipe so it can synchronize with the // context. This can be passed to MediaPipe so it can synchronize with the
// commands issued in the external context up to this point. // commands issued in the external context up to this point.

View File

@ -145,9 +145,13 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
CHECK_NE(name_, 0); CHECK_NE(name_, 0);
GLuint name_to_delete = name_; GLuint name_to_delete = name_;
context->RunWithoutWaiting([name_to_delete, sync_token]() { context->RunWithoutWaiting([name_to_delete, sync_token]() {
if (sync_token) {
// TODO: maybe we do not actually have to wait for the // TODO: maybe we do not actually have to wait for the
// consumer sync here. Check docs. // consumer sync here. Check docs.
sync_token->WaitOnGpu(); sync_token->WaitOnGpu();
} else {
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback";
}
DLOG_IF(ERROR, !glIsTexture(name_to_delete)) DLOG_IF(ERROR, !glIsTexture(name_to_delete))
<< "Deleting invalid texture id: " << name_to_delete; << "Deleting invalid texture id: " << name_to_delete;
glDeleteTextures(1, &name_to_delete); glDeleteTextures(1, &name_to_delete);
@ -179,13 +183,19 @@ void GlTextureBuffer::Reuse() {
void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) { void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
CHECK(!producer_sync_) CHECK(!producer_sync_)
<< "Updated existing texture which had not been marked for reuse!"; << "Updated existing texture which had not been marked for reuse!";
CHECK(prod_token);
producer_sync_ = std::move(prod_token); producer_sync_ = std::move(prod_token);
producer_context_ = producer_sync_->GetContext(); producer_context_ = producer_sync_->GetContext();
} }
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const { void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
absl::MutexLock lock(&consumer_sync_mutex_); absl::MutexLock lock(&consumer_sync_mutex_);
if (cons_token) {
consumer_multi_sync_->Add(std::move(cons_token)); consumer_multi_sync_->Add(std::move(cons_token));
} else {
// TODO: change to a CHECK.
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in DidRead";
}
} }
GlTextureBuffer::~GlTextureBuffer() { GlTextureBuffer::~GlTextureBuffer() {

View File

@ -2,6 +2,8 @@
#include <memory> #include <memory>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
@ -10,6 +12,23 @@
namespace mediapipe { namespace mediapipe {
namespace {
struct StorageTypeFormatter {
void operator()(std::string* out,
const std::shared_ptr<internal::GpuBufferStorage>& s) const {
absl::StrAppend(out, s->storage_type().name());
}
};
} // namespace
std::string GpuBuffer::DebugString() const {
return absl::StrCat("GpuBuffer[",
absl::StrJoin(storages_, ", ", StorageTypeFormatter()),
"]");
}
internal::GpuBufferStorage& GpuBuffer::GetStorageForView( internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
TypeId view_provider_type, bool for_writing) const { TypeId view_provider_type, bool for_writing) const {
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr; const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
@ -52,7 +71,10 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
} }
} }
CHECK(chosen_storage) << "no view provider found"; CHECK(chosen_storage) << "no view provider found for requested view "
<< view_provider_type.name() << "; storages available: "
<< absl::StrJoin(storages_, ", ",
StorageTypeFormatter());
DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type));
return **chosen_storage; return **chosen_storage;
} }

View File

@ -129,6 +129,8 @@ class GpuBuffer {
return nullptr; return nullptr;
} }
std::string DebugString() const;
private: private:
class PlaceholderGpuBufferStorage class PlaceholderGpuBufferStorage
: public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> { : public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> {

View File

@ -21,18 +21,29 @@ licenses(["notice"])
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
mediapipe_simple_subgraph(
name = "pose_landmarks_to_render_data",
graph = "pose_landmarks_to_render_data.pbtxt",
register_as = "PoseLandmarksToRenderData",
deps = [
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:split_proto_list_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_scale_calculator",
],
)
mediapipe_simple_subgraph( mediapipe_simple_subgraph(
name = "pose_renderer_gpu", name = "pose_renderer_gpu",
graph = "pose_renderer_gpu.pbtxt", graph = "pose_renderer_gpu.pbtxt",
register_as = "PoseRendererGpu", register_as = "PoseRendererGpu",
deps = [ deps = [
"//mediapipe/calculators/core:split_proto_list_calculator", ":pose_landmarks_to_render_data",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:recolor_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_data_calculator", "//mediapipe/calculators/util:rect_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_scale_calculator",
], ],
) )
@ -41,12 +52,11 @@ mediapipe_simple_subgraph(
graph = "pose_renderer_cpu.pbtxt", graph = "pose_renderer_cpu.pbtxt",
register_as = "PoseRendererCpu", register_as = "PoseRendererCpu",
deps = [ deps = [
"//mediapipe/calculators/core:split_proto_list_calculator", ":pose_landmarks_to_render_data",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:recolor_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_data_calculator", "//mediapipe/calculators/util:rect_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_scale_calculator",
], ],
) )

View File

@ -0,0 +1,236 @@
# MediaPipe pose landmarks to render data subgraph.
type: "PoseLandmarksToRenderData"
# Pose landmarks. (NormalizedLandmarkList)
input_stream: "LANDMARKS:pose_landmarks"
# Region of interest calculated based on landmarks. (NormalizedRect)
input_stream: "ROI:roi"
# Image size. (pair<int, int>)
input_stream: "IMAGE_SIZE:image_size"
# The resulting render data. (vector<RenderData>)
output_stream: "RENDER_DATA:merged_render_data"
# Calculates rendering scale based on the pose roi.
node {
calculator: "RectToRenderScaleCalculator"
input_stream: "NORM_RECT:roi"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "RENDER_SCALE:render_scale"
node_options: {
[type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] {
multiplier: 0.0012
}
}
}
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "pose_landmarks"
output_stream: "visible_pose_landmarks"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 0 end: 25 }
}
}
}
# Converts landmarks to drawing primitives for annotation overlay.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:pose_landmarks"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_connections: 0
landmark_connections: 1
landmark_connections: 1
landmark_connections: 2
landmark_connections: 2
landmark_connections: 3
landmark_connections: 3
landmark_connections: 7
landmark_connections: 0
landmark_connections: 4
landmark_connections: 4
landmark_connections: 5
landmark_connections: 5
landmark_connections: 6
landmark_connections: 6
landmark_connections: 8
landmark_connections: 9
landmark_connections: 10
landmark_connections: 11
landmark_connections: 12
landmark_connections: 11
landmark_connections: 13
landmark_connections: 13
landmark_connections: 15
landmark_connections: 15
landmark_connections: 17
landmark_connections: 15
landmark_connections: 19
landmark_connections: 15
landmark_connections: 21
landmark_connections: 17
landmark_connections: 19
landmark_connections: 12
landmark_connections: 14
landmark_connections: 14
landmark_connections: 16
landmark_connections: 16
landmark_connections: 18
landmark_connections: 16
landmark_connections: 20
landmark_connections: 16
landmark_connections: 22
landmark_connections: 18
landmark_connections: 20
landmark_connections: 11
landmark_connections: 23
landmark_connections: 12
landmark_connections: 24
landmark_connections: 23
landmark_connections: 24
landmark_connections: 23
landmark_connections: 25
landmark_connections: 24
landmark_connections: 26
landmark_connections: 25
landmark_connections: 27
landmark_connections: 26
landmark_connections: 28
landmark_connections: 27
landmark_connections: 29
landmark_connections: 28
landmark_connections: 30
landmark_connections: 29
landmark_connections: 31
landmark_connections: 30
landmark_connections: 32
landmark_connections: 27
landmark_connections: 31
landmark_connections: 28
landmark_connections: 32
landmark_color { r: 255 g: 255 b: 255 }
connection_color { r: 255 g: 255 b: 255 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Take left pose landmarks.
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "pose_landmarks"
output_stream: "landmarks_left_side"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 1 end: 4 }
ranges: { begin: 7 end: 8 }
ranges: { begin: 9 end: 10 }
ranges: { begin: 11 end: 12 }
ranges: { begin: 13 end: 14 }
ranges: { begin: 15 end: 16 }
ranges: { begin: 17 end: 18 }
ranges: { begin: 19 end: 20 }
ranges: { begin: 21 end: 22 }
ranges: { begin: 23 end: 24 }
combine_outputs: true
}
}
}
# Take right pose landmarks.
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "pose_landmarks"
output_stream: "landmarks_right_side"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 4 end: 7 }
ranges: { begin: 8 end: 9 }
ranges: { begin: 10 end: 11 }
ranges: { begin: 12 end: 13 }
ranges: { begin: 14 end: 15 }
ranges: { begin: 16 end: 17 }
ranges: { begin: 18 end: 19 }
ranges: { begin: 20 end: 21 }
ranges: { begin: 22 end: 23 }
ranges: { begin: 24 end: 25 }
combine_outputs: true
}
}
}
# Render pose joints as big white circles.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:visible_pose_landmarks"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_background_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 255 g: 255 b: 255 }
connection_color { r: 255 g: 255 b: 255 }
thickness: 5.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Render pose left side joints as orange circles (inside white ones).
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:landmarks_left_side"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_left_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 255 g: 138 b: 0 }
connection_color { r: 255 g: 138 b: 0 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Render pose right side joints as cyan circles (inside white ones).
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:landmarks_right_side"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_right_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 0 g: 217 b: 231 }
connection_color { r: 0 g: 217 b: 231 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Merges annotations into one result.
node {
calculator: "ConcatenateRenderDataVectorCalculator"
input_stream: "landmarks_render_data"
input_stream: "landmarks_background_joints_render_data"
input_stream: "landmarks_left_joints_render_data"
input_stream: "landmarks_right_joints_render_data"
output_stream: "merged_render_data"
}

View File

@ -22,19 +22,6 @@ node {
output_stream: "SIZE:image_size" output_stream: "SIZE:image_size"
} }
# Calculates rendering scale based on the pose roi.
node {
calculator: "RectToRenderScaleCalculator"
input_stream: "NORM_RECT:roi"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "RENDER_SCALE:render_scale"
node_options: {
[type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] {
multiplier: 0.0012
}
}
}
# Converts detections to drawing primitives for annotation overlay. # Converts detections to drawing primitives for annotation overlay.
node { node {
calculator: "DetectionsToRenderDataCalculator" calculator: "DetectionsToRenderDataCalculator"
@ -48,204 +35,13 @@ node {
} }
} }
# Computes render data for landmarks.
node { node {
calculator: "SplitNormalizedLandmarkListCalculator" calculator: "PoseLandmarksToRenderData"
input_stream: "pose_landmarks" input_stream: "LANDMARKS:pose_landmarks"
output_stream: "visible_pose_landmarks" input_stream: "ROI:roi"
node_options: { input_stream: "IMAGE_SIZE:image_size"
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 0 end: 25 }
}
}
}
# Converts landmarks to drawing primitives for annotation overlay.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:pose_landmarks"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_render_data" output_stream: "RENDER_DATA:landmarks_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_connections: 0
landmark_connections: 1
landmark_connections: 1
landmark_connections: 2
landmark_connections: 2
landmark_connections: 3
landmark_connections: 3
landmark_connections: 7
landmark_connections: 0
landmark_connections: 4
landmark_connections: 4
landmark_connections: 5
landmark_connections: 5
landmark_connections: 6
landmark_connections: 6
landmark_connections: 8
landmark_connections: 9
landmark_connections: 10
landmark_connections: 11
landmark_connections: 12
landmark_connections: 11
landmark_connections: 13
landmark_connections: 13
landmark_connections: 15
landmark_connections: 15
landmark_connections: 17
landmark_connections: 15
landmark_connections: 19
landmark_connections: 15
landmark_connections: 21
landmark_connections: 17
landmark_connections: 19
landmark_connections: 12
landmark_connections: 14
landmark_connections: 14
landmark_connections: 16
landmark_connections: 16
landmark_connections: 18
landmark_connections: 16
landmark_connections: 20
landmark_connections: 16
landmark_connections: 22
landmark_connections: 18
landmark_connections: 20
landmark_connections: 11
landmark_connections: 23
landmark_connections: 12
landmark_connections: 24
landmark_connections: 23
landmark_connections: 24
landmark_connections: 23
landmark_connections: 25
landmark_connections: 24
landmark_connections: 26
landmark_connections: 25
landmark_connections: 27
landmark_connections: 26
landmark_connections: 28
landmark_connections: 27
landmark_connections: 29
landmark_connections: 28
landmark_connections: 30
landmark_connections: 29
landmark_connections: 31
landmark_connections: 30
landmark_connections: 32
landmark_connections: 27
landmark_connections: 31
landmark_connections: 28
landmark_connections: 32
landmark_color { r: 255 g: 255 b: 255 }
connection_color { r: 255 g: 255 b: 255 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Take left pose landmarks.
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "pose_landmarks"
output_stream: "landmarks_left_side"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 1 end: 4 }
ranges: { begin: 7 end: 8 }
ranges: { begin: 9 end: 10 }
ranges: { begin: 11 end: 12 }
ranges: { begin: 13 end: 14 }
ranges: { begin: 15 end: 16 }
ranges: { begin: 17 end: 18 }
ranges: { begin: 19 end: 20 }
ranges: { begin: 21 end: 22 }
ranges: { begin: 23 end: 24 }
combine_outputs: true
}
}
}
# Take right pose landmarks.
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "pose_landmarks"
output_stream: "landmarks_right_side"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 4 end: 7 }
ranges: { begin: 8 end: 9 }
ranges: { begin: 10 end: 11 }
ranges: { begin: 12 end: 13 }
ranges: { begin: 14 end: 15 }
ranges: { begin: 16 end: 17 }
ranges: { begin: 18 end: 19 }
ranges: { begin: 20 end: 21 }
ranges: { begin: 22 end: 23 }
ranges: { begin: 24 end: 25 }
combine_outputs: true
}
}
}
# Render pose joints as big white circles.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:visible_pose_landmarks"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_background_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 255 g: 255 b: 255 }
connection_color { r: 255 g: 255 b: 255 }
thickness: 5.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Render pose left side joints as orange circles (inside white ones).
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:landmarks_left_side"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_left_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 255 g: 138 b: 0 }
connection_color { r: 255 g: 138 b: 0 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Render pose right side joints as cyan circles (inside white ones).
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:landmarks_right_side"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_right_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 0 g: 217 b: 231 }
connection_color { r: 0 g: 217 b: 231 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
} }
# Converts normalized rects to drawing primitives for annotation overlay. # Converts normalized rects to drawing primitives for annotation overlay.
@ -283,10 +79,7 @@ node {
calculator: "AnnotationOverlayCalculator" calculator: "AnnotationOverlayCalculator"
input_stream: "IMAGE:segmented_image" input_stream: "IMAGE:segmented_image"
input_stream: "detection_render_data" input_stream: "detection_render_data"
input_stream: "landmarks_render_data" input_stream: "VECTOR:landmarks_render_data"
input_stream: "landmarks_background_joints_render_data"
input_stream: "landmarks_left_joints_render_data"
input_stream: "landmarks_right_joints_render_data"
input_stream: "roi_render_data" input_stream: "roi_render_data"
output_stream: "IMAGE:output_image" output_stream: "IMAGE:output_image"
} }

View File

@ -22,19 +22,6 @@ node {
output_stream: "SIZE:image_size" output_stream: "SIZE:image_size"
} }
# Calculates rendering scale based on the pose roi.
node {
calculator: "RectToRenderScaleCalculator"
input_stream: "NORM_RECT:roi"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "RENDER_SCALE:render_scale"
node_options: {
[type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] {
multiplier: 0.0012
}
}
}
# Converts detections to drawing primitives for annotation overlay. # Converts detections to drawing primitives for annotation overlay.
node { node {
calculator: "DetectionsToRenderDataCalculator" calculator: "DetectionsToRenderDataCalculator"
@ -48,204 +35,13 @@ node {
} }
} }
# Computes render data for landmarks.
node { node {
calculator: "SplitNormalizedLandmarkListCalculator" calculator: "PoseLandmarksToRenderData"
input_stream: "pose_landmarks" input_stream: "LANDMARKS:pose_landmarks"
output_stream: "visible_pose_landmarks" input_stream: "ROI:roi"
node_options: { input_stream: "IMAGE_SIZE:image_size"
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 0 end: 25 }
}
}
}
# Converts landmarks to drawing primitives for annotation overlay.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:pose_landmarks"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_render_data" output_stream: "RENDER_DATA:landmarks_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_connections: 0
landmark_connections: 1
landmark_connections: 1
landmark_connections: 2
landmark_connections: 2
landmark_connections: 3
landmark_connections: 3
landmark_connections: 7
landmark_connections: 0
landmark_connections: 4
landmark_connections: 4
landmark_connections: 5
landmark_connections: 5
landmark_connections: 6
landmark_connections: 6
landmark_connections: 8
landmark_connections: 9
landmark_connections: 10
landmark_connections: 11
landmark_connections: 12
landmark_connections: 11
landmark_connections: 13
landmark_connections: 13
landmark_connections: 15
landmark_connections: 15
landmark_connections: 17
landmark_connections: 15
landmark_connections: 19
landmark_connections: 15
landmark_connections: 21
landmark_connections: 17
landmark_connections: 19
landmark_connections: 12
landmark_connections: 14
landmark_connections: 14
landmark_connections: 16
landmark_connections: 16
landmark_connections: 18
landmark_connections: 16
landmark_connections: 20
landmark_connections: 16
landmark_connections: 22
landmark_connections: 18
landmark_connections: 20
landmark_connections: 11
landmark_connections: 23
landmark_connections: 12
landmark_connections: 24
landmark_connections: 23
landmark_connections: 24
landmark_connections: 23
landmark_connections: 25
landmark_connections: 24
landmark_connections: 26
landmark_connections: 25
landmark_connections: 27
landmark_connections: 26
landmark_connections: 28
landmark_connections: 27
landmark_connections: 29
landmark_connections: 28
landmark_connections: 30
landmark_connections: 29
landmark_connections: 31
landmark_connections: 30
landmark_connections: 32
landmark_connections: 27
landmark_connections: 31
landmark_connections: 28
landmark_connections: 32
landmark_color { r: 255 g: 255 b: 255 }
connection_color { r: 255 g: 255 b: 255 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Take left pose landmarks.
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "pose_landmarks"
output_stream: "landmarks_left_side"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 1 end: 4 }
ranges: { begin: 7 end: 8 }
ranges: { begin: 9 end: 10 }
ranges: { begin: 11 end: 12 }
ranges: { begin: 13 end: 14 }
ranges: { begin: 15 end: 16 }
ranges: { begin: 17 end: 18 }
ranges: { begin: 19 end: 20 }
ranges: { begin: 21 end: 22 }
ranges: { begin: 23 end: 24 }
combine_outputs: true
}
}
}
# Take right pose landmarks.
node {
calculator: "SplitNormalizedLandmarkListCalculator"
input_stream: "pose_landmarks"
output_stream: "landmarks_right_side"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 4 end: 7 }
ranges: { begin: 8 end: 9 }
ranges: { begin: 10 end: 11 }
ranges: { begin: 12 end: 13 }
ranges: { begin: 14 end: 15 }
ranges: { begin: 16 end: 17 }
ranges: { begin: 18 end: 19 }
ranges: { begin: 20 end: 21 }
ranges: { begin: 22 end: 23 }
ranges: { begin: 24 end: 25 }
combine_outputs: true
}
}
}
# Render pose joints as big white circles.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:visible_pose_landmarks"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_background_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 255 g: 255 b: 255 }
connection_color { r: 255 g: 255 b: 255 }
thickness: 5.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Render pose left side joints as orange circles (inside white ones).
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:landmarks_left_side"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_left_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 255 g: 138 b: 0 }
connection_color { r: 255 g: 138 b: 0 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
}
# Render pose right side joints as cyan circles (inside white ones).
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:landmarks_right_side"
input_stream: "RENDER_SCALE:render_scale"
output_stream: "RENDER_DATA:landmarks_right_joints_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_color { r: 0 g: 217 b: 231 }
connection_color { r: 0 g: 217 b: 231 }
thickness: 3.0
visualize_landmark_depth: false
utilize_visibility: true
visibility_threshold: 0.5
}
}
} }
# Converts normalized rects to drawing primitives for annotation overlay. # Converts normalized rects to drawing primitives for annotation overlay.
@ -283,10 +79,7 @@ node {
calculator: "AnnotationOverlayCalculator" calculator: "AnnotationOverlayCalculator"
input_stream: "IMAGE_GPU:segmented_image" input_stream: "IMAGE_GPU:segmented_image"
input_stream: "detection_render_data" input_stream: "detection_render_data"
input_stream: "landmarks_render_data" input_stream: "VECTOR:landmarks_render_data"
input_stream: "landmarks_background_joints_render_data"
input_stream: "landmarks_left_joints_render_data"
input_stream: "landmarks_right_joints_render_data"
input_stream: "roi_render_data" input_stream: "roi_render_data"
output_stream: "IMAGE_GPU:output_image" output_stream: "IMAGE_GPU:output_image"
} }

View File

@ -174,6 +174,14 @@ public class ExternalTextureConverter implements TextureFrameProducer {
thread.setRotation(rotation); thread.setRotation(rotation);
} }
/**
* Sets whether the timestamps of each frame should be adjusted to be always monotonically
* increasing. The default behavior is that this is {@code true}.
*/
public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) {
thread.setShouldAdjustTimestamps(shouldAdjustTimestamps);
}
/** /**
* Sets an offset that can be used to adjust the timestamps on the camera frames, for example to * Sets an offset that can be used to adjust the timestamps on the camera frames, for example to
* conform to a preferred time-base or to account for a known device latency. The offset is added * conform to a preferred time-base or to account for a known device latency. The offset is added
@ -298,6 +306,7 @@ public class ExternalTextureConverter implements TextureFrameProducer {
private int bufferPoolMaxSize; private int bufferPoolMaxSize;
private ExternalTextureRenderer renderer = null; private ExternalTextureRenderer renderer = null;
private boolean shouldAdjustTimestamps = true;
private long nextFrameTimestampOffset = 0; private long nextFrameTimestampOffset = 0;
private long timestampOffsetNanos = 0; private long timestampOffsetNanos = 0;
private long previousTimestamp = 0; private long previousTimestamp = 0;
@ -433,6 +442,10 @@ public class ExternalTextureConverter implements TextureFrameProducer {
super.releaseGl(); // This releases the EGL context, so must do it after any GL calls. super.releaseGl(); // This releases the EGL context, so must do it after any GL calls.
} }
public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) {
this.shouldAdjustTimestamps = shouldAdjustTimestamps;
}
public void setTimestampOffsetNanos(long offsetInNanos) { public void setTimestampOffsetNanos(long offsetInNanos) {
timestampOffsetNanos = offsetInNanos; timestampOffsetNanos = offsetInNanos;
} }
@ -565,7 +578,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
// |nextFrameTimestampOffset| to ensure that timestamps increase monotonically.) // |nextFrameTimestampOffset| to ensure that timestamps increase monotonically.)
long textureTimestamp = long textureTimestamp =
(surfaceTexture.getTimestamp() + timestampOffsetNanos) / NANOS_PER_MICRO; (surfaceTexture.getTimestamp() + timestampOffsetNanos) / NANOS_PER_MICRO;
if (previousTimestampValid if (shouldAdjustTimestamps
&& previousTimestampValid
&& textureTimestamp + nextFrameTimestampOffset <= previousTimestamp) { && textureTimestamp + nextFrameTimestampOffset <= previousTimestamp) {
nextFrameTimestampOffset = previousTimestamp + 1 - textureTimestamp; nextFrameTimestampOffset = previousTimestamp + 1 - textureTimestamp;
} }

View File

@ -15,6 +15,10 @@
package com.google.mediapipe.framework; package com.google.mediapipe.framework;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.framework.image.ImageProperties;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
// TODO: use Preconditions in this file. // TODO: use Preconditions in this file.
@ -55,6 +59,50 @@ public class AndroidPacketCreator extends PacketCreator {
return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
} }
/**
* Creates an Image packet from an {@link Image}.
*
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
*/
public Packet createImage(Image image) {
// TODO: Choose the best storage from multiple containers.
ImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
ByteBuffer buffer = ByteBufferExtractor.extract(image);
int numChannels = 0;
switch (properties.getImageFormat()) {
case Image.IMAGE_FORMAT_RGBA:
numChannels = 4;
break;
case Image.IMAGE_FORMAT_RGB:
numChannels = 3;
break;
case Image.IMAGE_FORMAT_ALPHA:
numChannels = 1;
break;
default: // fall out
}
if (numChannels == 0) {
throw new UnsupportedOperationException(
"Unsupported MediaPipe Image image format: " + properties.getImageFormat());
}
int width = image.getWidth();
int height = image.getHeight();
return createImage(buffer, width, height, numChannels);
}
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
Bitmap bitmap = BitmapExtractor.extract(image);
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
}
return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
}
// Unsupported type.
throw new UnsupportedOperationException(
"Unsupported Image container type: " + properties.getImageFormat());
}
/** /**
* Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on * Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on
* failure. * failure.

View File

@ -57,6 +57,7 @@ android_library(
], ],
deps = [ deps = [
":android_core", ":android_core",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//third_party:androidx_annotation", "//third_party:androidx_annotation",
"//third_party:androidx_legacy_support_v4", "//third_party:androidx_legacy_support_v4",
"@maven//:com_google_code_findbugs_jsr305", "@maven//:com_google_code_findbugs_jsr305",
@ -75,6 +76,7 @@ android_library(
srcs = glob( srcs = glob(
["**/*.java"], ["**/*.java"],
exclude = [ exclude = [
"image/**",
"Android*", "Android*",
"AssetCache.java", "AssetCache.java",
"AssetCacheDbHelper.java", "AssetCacheDbHelper.java",

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.framework.image">
<uses-sdk android:minSdkVersion="16" />
<application />
</manifest>

View File

@ -0,0 +1,32 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
licenses(["notice"])
android_library(
name = "image",
srcs = glob(["*.java"]),
manifest = "AndroidManifest.xml",
visibility = [
"//mediapipe:__subpackages__",
],
deps = [
"//third_party:androidx_legacy_support_v4",
"//third_party:autovalue",
"@maven//:androidx_annotation_annotation",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,49 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;
/**
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
* {@link IllegalArgumentException} will be thrown.
*/
public final class BitmapExtractor {
/**
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
*
* @param image the image to extract {@link android.graphics.Bitmap} from.
* @return the {@link android.graphics.Bitmap} stored in {@link Image}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions.
*/
public static Bitmap extract(Image image) {
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
if (imageContainer != null) {
return ((BitmapImageContainer) imageContainer).getBitmap();
} else {
// TODO: Support ByteBuffer -> Bitmap conversion.
throw new IllegalArgumentException(
"Extracting Bitmap from an Image created by objects other than Bitmap is not"
+ " supported");
}
}
private BitmapExtractor() {}
}

View File

@ -0,0 +1,72 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import android.content.Context;
import android.graphics.Bitmap;
import android.net.Uri;
import android.provider.MediaStore;
import java.io.IOException;
/**
* Builds {@link Image} from {@link android.graphics.Bitmap}.
*
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
* in it.
*
* <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in.
*/
public class BitmapImageBuilder {
// Mandatory fields.
private final Bitmap bitmap;
// Optional fields.
private long timestamp;
/**
* Creates the builder with a mandatory {@link android.graphics.Bitmap}.
*
* @param bitmap image data object.
*/
public BitmapImageBuilder(Bitmap bitmap) {
this.bitmap = bitmap;
timestamp = 0;
}
/**
* Creates the builder to build {@link Image} from a file.
*
* @param context the application context.
* @param uri the path to the resource file.
*/
public BitmapImageBuilder(Context context, Uri uri) throws IOException {
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
}
/** Sets value for {@link Image#getTimestamp()}. */
BitmapImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
}
}

View File

@ -0,0 +1,60 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat;
class BitmapImageContainer implements ImageContainer {
private final Bitmap bitmap;
private final ImageProperties properties;
public BitmapImageContainer(Bitmap bitmap) {
this.bitmap = bitmap;
this.properties =
ImageProperties.builder()
.setImageFormat(convertFormatCode(bitmap.getConfig()))
.setStorageType(Image.STORAGE_TYPE_BITMAP)
.build();
}
public Bitmap getBitmap() {
return bitmap;
}
@Override
public ImageProperties getImageProperties() {
return properties;
}
@Override
public void close() {
bitmap.recycle();
}
@ImageFormat
static int convertFormatCode(Bitmap.Config config) {
switch (config) {
case ALPHA_8:
return Image.IMAGE_FORMAT_ALPHA;
case ARGB_8888:
return Image.IMAGE_FORMAT_RGBA;
default:
return Image.IMAGE_FORMAT_UNKNOWN;
}
}
}

View File

@ -0,0 +1,254 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import android.annotation.SuppressLint;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Locale;
/**
* Utility for extracting {@link ByteBuffer} from {@link Image}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
* {@link IllegalArgumentException} will be thrown.
*/
public class ByteBufferExtractor {
/**
* Extracts a {@link ByteBuffer} from an {@link Image}.
*
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
*
* @see Image#getContainedImageProperties()
* @return A read-only {@link ByteBuffer}.
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
*/
@SuppressLint("SwitchIntDef")
public static ByteBuffer extract(Image image) {
ImageContainer container = image.getContainer();
switch (container.getImageProperties().getStorageType()) {
case Image.STORAGE_TYPE_BYTEBUFFER:
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
default:
throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
+ " supported");
}
}
/**
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
*
* <p>Format conversion spec:
*
* <ul>
* <li>When extracting RGB images to RGBA format, A channel will always set to 255.
* <li>When extracting RGBA images to RGB format, A channel will be dropped.
* </ul>
*
* @param image the image to extract buffer from.
* @param targetFormat the image format of the result bytebuffer.
* @return the readonly {@link ByteBuffer} stored in {@link Image}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions.
*/
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
ImageContainer container;
ImageProperties byteBufferProperties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(targetFormat)
.build();
if ((container = image.getContainer(byteBufferProperties)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
.asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
ByteBuffer byteBuffer =
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
.asReadOnlyBuffer();
boolean unused = image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
return byteBuffer;
} else {
throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by objects other than Bitmap or"
+ " Bytebuffer is not supported");
}
}
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
@AutoValue
abstract static class Result {
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
public abstract ByteBuffer buffer();
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
@ImageFormat
public abstract int format();
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
}
}
/**
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
*
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
*
* @return the readonly {@link ByteBuffer} stored in {@link Image}
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
* given {@code imageFormat}
*/
static Result extractInRecommendedFormat(Image image) {
ImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
@ImageFormat int format = adviseImageFormat(bitmap);
Result result =
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
boolean unused =
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
return result;
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return Result.create(
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
byteBufferImageContainer.getImageFormat());
} else {
throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
+ " is not supported");
}
}
@ImageFormat
private static int adviseImageFormat(Bitmap bitmap) {
if (bitmap.getConfig() == Config.ARGB_8888) {
return Image.IMAGE_FORMAT_RGBA;
} else {
throw new IllegalArgumentException(
String.format(
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
+ " supported",
bitmap.getConfig()));
}
}
private static ByteBuffer extractByteBufferFromBitmap(
Bitmap bitmap, @ImageFormat int imageFormat) {
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
+ " supported");
}
if (bitmap.getConfig() == Config.ARGB_8888) {
if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
bitmap.copyPixelsToBuffer(buffer);
buffer.rewind();
return buffer;
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
int w = bitmap.getWidth();
int h = bitmap.getHeight();
int[] pixels = new int[w * h];
bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
buffer.order(ByteOrder.nativeOrder());
for (int pixel : pixels) {
// getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA
buffer.put((byte) ((pixel >> 16) & 0xff));
buffer.put((byte) ((pixel >> 8) & 0xff));
buffer.put((byte) (pixel & 0xff));
}
buffer.rewind();
return buffer;
}
}
throw new IllegalArgumentException(
String.format(
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
+ " %d is not supported",
bitmap.getConfig(), imageFormat));
}
private static ByteBuffer convertByteBuffer(
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the
// array reversely to convert in-place.
byte[] array = new byte[target.capacity()];
source.get(array, 0, source.capacity());
source.rewind();
int rgbCursor = source.capacity();
int rgbaCursor = target.capacity();
while (rgbCursor != rgbaCursor) {
array[--rgbaCursor] = (byte) 0xff; // A
array[--rgbaCursor] = array[--rgbCursor]; // B
array[--rgbaCursor] = array[--rgbCursor]; // G
array[--rgbaCursor] = array[--rgbCursor]; // R
}
target.put(array, 0, target.capacity());
target.rewind();
return target;
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
// array to convert in-place.
byte[] array = new byte[source.capacity()];
source.get(array, 0, source.capacity());
source.rewind();
int rgbaCursor = 0;
int rgbCursor = 0;
while (rgbaCursor < array.length) {
array[rgbCursor++] = array[rgbaCursor++]; // R
array[rgbCursor++] = array[rgbaCursor++]; // G
array[rgbCursor++] = array[rgbaCursor++]; // B
rgbaCursor++;
}
target.put(array, 0, target.capacity());
target.rewind();
return target;
} else {
throw new IllegalArgumentException(
String.format(
Locale.ENGLISH,
"Convert bytebuffer image format from %d to %d is not supported",
sourceFormat,
targetFormat));
}
}
// ByteBuffer is not able to be instantiated.
private ByteBufferExtractor() {}
}

View File

@ -0,0 +1,71 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import java.nio.ByteBuffer;
/**
* Builds a {@link Image} from a {@link ByteBuffer}.
*
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
*
* <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in.
*/
public class ByteBufferImageBuilder {
// Mandatory fields.
private final ByteBuffer buffer;
private final int width;
private final int height;
@ImageFormat private final int imageFormat;
// Optional fields.
private long timestamp;
/**
* Creates the builder with mandatory {@link ByteBuffer} and the represented image.
*
* <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height}
* and {@code imageFormat}.
*
* @param byteBuffer image data object.
* @param width the width of the represented image.
* @param height the height of the represented image.
* @param imageFormat how the data encode the image.
*/
public ByteBufferImageBuilder(
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
this.buffer = byteBuffer;
this.width = width;
this.height = height;
this.imageFormat = imageFormat;
// TODO: Validate bytebuffer size with width, height and image format
this.timestamp = 0;
}
/** Sets value for {@link Image#getTimestamp()}. */
ByteBufferImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
}
}

View File

@ -0,0 +1,58 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import java.nio.ByteBuffer;
class ByteBufferImageContainer implements ImageContainer {
private final ByteBuffer buffer;
private final ImageProperties properties;
public ByteBufferImageContainer(
ByteBuffer buffer,
@ImageFormat int imageFormat) {
this.buffer = buffer;
this.properties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(imageFormat)
.build();
}
public ByteBuffer getByteBuffer() {
return buffer;
}
@Override
public ImageProperties getImageProperties() {
return properties;
}
/**
* Returns the image format.
*/
@ImageFormat
public int getImageFormat() {
return properties.getImageFormat();
}
@Override
public void close() {
// No op for ByteBuffer.
}
}

View File

@ -0,0 +1,241 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import androidx.annotation.IntDef;
import androidx.annotation.Nullable;
import java.io.Closeable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
/**
* The wrapper class for image objects.
*
* <p>{@link Image} is designed to be an immutable image container, which could be shared
* cross-platforms.
*
* <p>To construct an {@link Image}, use the provided builders:
*
* <ul>
* <li>{@link ByteBufferImageBuilder}
* <li>{@link BitmapImageBuilder}
* <li>{@link MediaImageBuilder}
* </ul>
*
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release
* internal storage earlier, otherwise Java garbage collection will release the storage eventually.
*
* <p>To extract concrete image, first check {@link StorageType} and then use the provided
* extractors:
*
* <ul>
* <li>{@link ByteBufferExtractor}
* <li>{@link BitmapExtractor}
* <li>{@link MediaImageExtractor}
* </ul>
*/
public class Image implements Closeable {
/** Specifies the image format of an image. */
@IntDef({
IMAGE_FORMAT_UNKNOWN,
IMAGE_FORMAT_RGBA,
IMAGE_FORMAT_RGB,
IMAGE_FORMAT_NV12,
IMAGE_FORMAT_NV21,
IMAGE_FORMAT_YV12,
IMAGE_FORMAT_YV21,
IMAGE_FORMAT_YUV_420_888,
IMAGE_FORMAT_ALPHA,
IMAGE_FORMAT_JPEG,
})
@Retention(RetentionPolicy.SOURCE)
public @interface ImageFormat {}
public static final int IMAGE_FORMAT_UNKNOWN = 0;
public static final int IMAGE_FORMAT_RGBA = 1;
public static final int IMAGE_FORMAT_RGB = 2;
public static final int IMAGE_FORMAT_NV12 = 3;
public static final int IMAGE_FORMAT_NV21 = 4;
public static final int IMAGE_FORMAT_YV12 = 5;
public static final int IMAGE_FORMAT_YV21 = 6;
public static final int IMAGE_FORMAT_YUV_420_888 = 7;
public static final int IMAGE_FORMAT_ALPHA = 8;
public static final int IMAGE_FORMAT_JPEG = 9;
/** Specifies the image container type. Would be useful for choosing extractors. */
@IntDef({
STORAGE_TYPE_BITMAP,
STORAGE_TYPE_BYTEBUFFER,
STORAGE_TYPE_MEDIA_IMAGE,
STORAGE_TYPE_IMAGE_PROXY,
})
@Retention(RetentionPolicy.SOURCE)
public @interface StorageType {}
public static final int STORAGE_TYPE_BITMAP = 1;
public static final int STORAGE_TYPE_BYTEBUFFER = 2;
public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
/**
* Returns a list of supported image properties for this {@link Image}.
*
* <p>Currently {@link Image} only support single storage type so the size of return list will
* always be 1.
*
* @see ImageProperties
*/
public List<ImageProperties> getContainedImageProperties() {
return Collections.singletonList(getContainer().getImageProperties());
}
/** Returns the timestamp attached to the image. */
long getTimestamp() {
return timestamp;
}
/** Returns the width of the image. */
public int getWidth() {
return width;
}
/** Returns the height of the image. */
public int getHeight() {
return height;
}
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
private synchronized void acquire() {
referenceCount += 1;
}
/**
* Removes a reference that was previously acquired or init.
*
* <p>When {@link Image} is created, it has 1 reference count.
*
* <p>When the reference count becomes 0, it will release the resource under the hood.
*/
@Override
// TODO: Create an internal flag to indicate image is closed, or use referenceCount
public synchronized void close() {
referenceCount -= 1;
if (referenceCount == 0) {
for (ImageContainer imageContainer : containerMap.values()) {
imageContainer.close();
}
}
}
/** Advanced API access for {@link Image}. */
static final class Internal {
/**
* Acquires a reference on this {@link Image}. This will increase the reference count by 1.
*
* <p>This method is more useful for image consumer to acquire a reference so image resource
* will not be closed accidentally. As image creator, normal developer doesn't need to call this
* method.
*
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link Image} anymore.
*
* @see #close()
*/
void acquire() {
image.acquire();
}
private final Image image;
// Only Image creates the internal helper.
private Internal(Image image) {
this.image = image;
}
}
/** Gets {@link Internal} object which contains internal APIs. */
Internal getInternal() {
return new Internal(this);
}
private final Map<ImageProperties, ImageContainer> containerMap;
private final long timestamp;
private final int width;
private final int height;
private int referenceCount;
/** Constructs an {@link Image} with a built container. */
Image(ImageContainer container, long timestamp, int width, int height) {
this.containerMap = new HashMap<>();
containerMap.put(container.getImageProperties(), container);
this.timestamp = timestamp;
this.width = width;
this.height = height;
this.referenceCount = 1;
}
/**
* Gets one available container.
*
* @return the current container.
*/
ImageContainer getContainer() {
// According to the design, in the future we will support multiple containers in one image.
// Currently just return the original container.
// TODO: Cache multiple containers in Image.
return containerMap.values().iterator().next();
}
/**
* Gets container from required {@code storageType}. Returns {@code null} if not existed.
*
* <p>If there are multiple containers with required {@code storageType}, returns the first one.
*/
@Nullable
ImageContainer getContainer(@StorageType int storageType) {
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
if (entry.getKey().getStorageType() == storageType) {
return entry.getValue();
}
}
return null;
}
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
@Nullable
ImageContainer getContainer(ImageProperties imageProperties) {
return containerMap.get(imageProperties);
}
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
boolean addContainer(ImageContainer container) {
ImageProperties imageProperties = container.getImageProperties();
if (containerMap.containsKey(imageProperties)) {
return false;
}
containerMap.put(imageProperties, container);
return true;
}
}

View File

@ -0,0 +1,27 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that can receive {@link Image} */
public interface ImageConsumer {
/**
* Called when an {@link Image} is available.
*
* <p>The argument is only guaranteed to be available until this method returns. if you need to
* extend its life time, acquire it, then release it when done.
*/
void onNewImage(Image image);
}

View File

@ -0,0 +1,25 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
/** Manages internal image data storage. The interface is package-private. */
interface ImageContainer {
/** Returns the properties of the contained image. */
ImageProperties getImageProperties();
/** Close the image container and releases the image resource inside. */
void close();
}

View File

@ -0,0 +1,22 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that produce {@link Image} */
public interface ImageProducer {
/** Sets the consumer that receives the {@link Image}. */
void setImageConsumer(ImageConsumer imageConsumer);
}

View File

@ -0,0 +1,80 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import com.google.auto.value.AutoValue;
import com.google.auto.value.extension.memoized.Memoized;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.Image.StorageType;
/** Groups a set of properties to describe how an image is stored. */
@AutoValue
public abstract class ImageProperties {
/**
* Gets the pixel format of the image.
*
* @see Image.ImageFormat
*/
@ImageFormat
public abstract int getImageFormat();
/**
* Gets the storage type of the image.
*
* @see Image.StorageType
*/
@StorageType
public abstract int getStorageType();
@Memoized
@Override
public abstract int hashCode();
/**
* Creates a builder of {@link ImageProperties}.
*
* @see ImageProperties.Builder
*/
static Builder builder() {
return new AutoValue_ImageProperties.Builder();
}
/** Builds a {@link ImageProperties}. */
@AutoValue.Builder
abstract static class Builder {
/**
* Sets the {@link Image.ImageFormat}.
*
* @see ImageProperties#getImageFormat
*/
abstract Builder setImageFormat(@ImageFormat int value);
/**
* Sets the {@link Image.StorageType}.
*
* @see ImageProperties#getStorageType
*/
abstract Builder setStorageType(@StorageType int value);
/** Builds the {@link ImageProperties}. */
abstract ImageProperties build();
}
// Hide the constructor.
ImageProperties() {}
}

View File

@ -0,0 +1,62 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
/**
* Builds {@link Image} from {@link android.media.Image}.
*
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
* content in it.
*
* <p>Use {@link MediaImageExtractor} to get {@link android.media.Image} you passed in.
*/
@RequiresApi(VERSION_CODES.KITKAT)
public class MediaImageBuilder {
// Mandatory fields.
private final android.media.Image mediaImage;
// Optional fields.
private long timestamp;
/**
* Creates the builder with a mandatory {@link android.media.Image}.
*
* @param mediaImage image data object.
*/
public MediaImageBuilder(android.media.Image mediaImage) {
this.mediaImage = mediaImage;
this.timestamp = 0;
}
/** Sets value for {@link Image#getTimestamp()}. */
MediaImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(
new MediaImageContainer(mediaImage),
timestamp,
mediaImage.getWidth(),
mediaImage.getHeight());
}
}

View File

@ -0,0 +1,73 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import android.os.Build;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
import com.google.mediapipe.framework.image.Image.ImageFormat;
@RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer {
private final android.media.Image mediaImage;
private final ImageProperties properties;
public MediaImageContainer(android.media.Image mediaImage) {
this.mediaImage = mediaImage;
this.properties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
.setImageFormat(convertFormatCode(mediaImage.getFormat()))
.build();
}
public android.media.Image getImage() {
return mediaImage;
}
@Override
public ImageProperties getImageProperties() {
return properties;
}
@Override
public void close() {
mediaImage.close();
}
@ImageFormat
static int convertFormatCode(int graphicsFormat) {
// We only cover the format mentioned in
// https://developer.android.com/reference/android/media/Image#getFormat()
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
return Image.IMAGE_FORMAT_RGBA;
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
return Image.IMAGE_FORMAT_RGB;
}
}
switch (graphicsFormat) {
case android.graphics.ImageFormat.JPEG:
return Image.IMAGE_FORMAT_JPEG;
case android.graphics.ImageFormat.YUV_420_888:
return Image.IMAGE_FORMAT_YUV_420_888;
default:
return Image.IMAGE_FORMAT_UNKNOWN;
}
}
}

View File

@ -0,0 +1,49 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
/**
* Utility for extracting {@link android.media.Image} from {@link Image}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
* otherwise {@link IllegalArgumentException} will be thrown.
*/
@RequiresApi(VERSION_CODES.KITKAT)
public class MediaImageExtractor {
private MediaImageExtractor() {}
/**
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
* {@link Image} that built from {@link MediaImageBuilder}.
*
* @param image the image to extract {@link android.media.Image} from.
* @return {@link android.media.Image} that stored in {@link Image}.
* @throws IllegalArgumentException if the extraction failed.
*/
public static android.media.Image extract(Image image) {
ImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
return ((MediaImageContainer) container).getImage();
}
throw new IllegalArgumentException(
"Extract Media Image from an Image created by objects other than Media Image"
+ " is not supported");
}
}

View File

@ -73,9 +73,14 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
// TODO: get the graph's main context from the packet context? // TODO: get the graph's main context from the packet context?
// Or clean up in some other way? // Or clean up in some other way?
if (context_for_deletion) { if (context_for_deletion) {
token = new mediapipe::GlSyncToken( auto sync = mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext(
mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext( context_for_deletion);
context_for_deletion)); // A Java handle to a token is a raw pointer to a std::shared_ptr on the
// heap, cast to a long. If the shared_ptr itself is null, leave the token
// null too.
if (sync) {
token = new mediapipe::GlSyncToken(std::move(sync));
}
} }
return reinterpret_cast<jlong>(token); return reinterpret_cast<jlong>(token);
} }

View File

@ -145,6 +145,7 @@ EOF
"//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/components:android_components",
"//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper",
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/java/com/google/mediapipe/glutil", "//mediapipe/java/com/google/mediapipe/glutil",
"//third_party:androidx_annotation", "//third_party:androidx_annotation",
"//third_party:androidx_appcompat", "//third_party:androidx_appcompat",

View File

@ -76,7 +76,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->Swap(base_options_proto.get());
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode == core::RunningMode::AUDIO_STREAM); options->running_mode == core::RunningMode::AUDIO_STREAM);
auto classifier_options_proto = std::make_unique<tasks::ClassifierOptions>( auto classifier_options_proto =
std::make_unique<tasks::components::proto::ClassifierOptions>(
components::ConvertClassifierOptionsToProto( components::ConvertClassifierOptionsToProto(
&(options->classifier_options))); &(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap( options_proto->mutable_classifier_options()->Swap(

View File

@ -136,6 +136,11 @@ void ConfigureAudioToTensorCalculator(
// options { // options {
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext] // [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
// { // {
// base_options {
// model_asset {
// file_name: "/path/to/model.tflite"
// }
// }
// max_results: 4 // max_results: 4
// score_threshold: 0.5 // score_threshold: 0.5
// category_allowlist: "foo" // category_allowlist: "foo"
@ -225,15 +230,17 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// Adds inference subgraph and connects its input stream to the output // Adds inference subgraph and connects its input stream to the output
// tensors produced by the AudioToTensorCalculator. // tensors produced by the AudioToTensorCalculator.
auto& inference = AddInference(model_resources, graph); auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag); audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag);
// Adds postprocessing calculators and connects them to the graph output. // Adds postprocessing calculators and connects them to the graph output.
auto& postprocessing = auto& postprocessing = graph.AddNode(
graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
model_resources, task_options.classifier_options(), model_resources, task_options.classifier_options(),
&postprocessing.GetOptions<ClassificationPostprocessingOptions>())); &postprocessing.GetOptions<
tasks::components::ClassificationPostprocessingOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio classification on // Time aggregation is only needed for performing audio classification on

View File

@ -37,7 +37,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h" #include "mediapipe/tasks/cc/components/containers/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -168,7 +167,7 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) { TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->classifier_options.max_results = 3; options->classifier_options.max_results = 3;
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -192,7 +191,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->classifier_options.max_results = 0; options->classifier_options.max_results = 0;
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options)); AudioClassifier::Create(std::move(options));
@ -208,7 +207,7 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) { TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.category_allowlist.push_back("foo"); options->classifier_options.category_allowlist.push_back("foo");
options->classifier_options.category_denylist.push_back("bar"); options->classifier_options.category_denylist.push_back("bar");
@ -226,7 +225,7 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) {
TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) { TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options)); AudioClassifier::Create(std::move(options));
@ -242,7 +241,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) {
TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) { TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = 16000; options->sample_rate = 16000;
@ -260,7 +259,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->result_callback = options->result_callback =
[](absl::StatusOr<ClassificationResult> status_or_result) {}; [](absl::StatusOr<ClassificationResult> status_or_result) {};
@ -279,7 +278,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) { TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->result_callback = options->result_callback =
@ -301,7 +300,7 @@ class ClassifyTest : public tflite_shims::testing::Test {};
TEST_F(ClassifyTest, Succeeds) { TEST_F(ClassifyTest, Succeeds) {
auto audio_buffer = GetAudioData(k16kTestWavFilename); auto audio_buffer = GetAudioData(k16kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -315,7 +314,7 @@ TEST_F(ClassifyTest, Succeeds) {
TEST_F(ClassifyTest, SucceedsWithResampling) { TEST_F(ClassifyTest, SucceedsWithResampling) {
auto audio_buffer = GetAudioData(k48kTestWavFilename); auto audio_buffer = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -330,7 +329,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
auto audio_buffer_16k_hz = GetAudioData(k16kTestWavFilename); auto audio_buffer_16k_hz = GetAudioData(k16kTestWavFilename);
auto audio_buffer_48k_hz = GetAudioData(k48kTestWavFilename); auto audio_buffer_48k_hz = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -349,7 +348,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
TEST_F(ClassifyTest, SucceedsWithInsufficientData) { TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -374,7 +373,7 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
auto audio_buffer = GetAudioData(k16kTestWavForTwoHeadsFilename); auto audio_buffer = GetAudioData(k16kTestWavForTwoHeadsFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -388,7 +387,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
auto audio_buffer = GetAudioData(k44kTestWavForTwoHeadsFilename); auto audio_buffer = GetAudioData(k44kTestWavForTwoHeadsFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -404,7 +403,7 @@ TEST_F(ClassifyTest,
auto audio_buffer_44k_hz = GetAudioData(k44kTestWavForTwoHeadsFilename); auto audio_buffer_44k_hz = GetAudioData(k44kTestWavForTwoHeadsFilename);
auto audio_buffer_16k_hz = GetAudioData(k16kTestWavForTwoHeadsFilename); auto audio_buffer_16k_hz = GetAudioData(k16kTestWavForTwoHeadsFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
@ -424,7 +423,7 @@ TEST_F(ClassifyTest,
TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
auto audio_buffer = GetAudioData(k48kTestWavFilename); auto audio_buffer = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.35f; options->classifier_options.score_threshold = 0.35f;
@ -440,7 +439,7 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
auto audio_buffer = GetAudioData(k48kTestWavFilename); auto audio_buffer = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.score_threshold = 0.35f; options->classifier_options.score_threshold = 0.35f;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@ -455,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
auto audio_buffer = GetAudioData(k48kTestWavFilename); auto audio_buffer = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.score_threshold = 0.1f; options->classifier_options.score_threshold = 0.1f;
options->classifier_options.category_allowlist.push_back("Speech"); options->classifier_options.category_allowlist.push_back("Speech");
@ -471,7 +470,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
auto audio_buffer = GetAudioData(k48kTestWavFilename); auto audio_buffer = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.score_threshold = 0.9f; options->classifier_options.score_threshold = 0.9f;
options->classifier_options.category_denylist.push_back("Speech"); options->classifier_options.category_denylist.push_back("Speech");
@ -499,7 +498,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
constexpr int kSampleRateHz = 48000; constexpr int kSampleRateHz = 48000;
auto audio_buffer = GetAudioData(k48kTestWavFilename); auto audio_buffer = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.3f; options->classifier_options.score_threshold = 0.3f;
@ -529,7 +528,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
constexpr int kSampleRateHz = 48000; constexpr int kSampleRateHz = 48000;
auto audio_buffer = GetAudioData(k48kTestWavFilename); auto audio_buffer = GetAudioData(k48kTestWavFilename);
auto options = std::make_unique<AudioClassifierOptions>(); auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_file_name = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.3f; options->classifier_options.score_threshold = 0.3f;

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components:classifier_options_proto", "//mediapipe/tasks/cc/components/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.audio.audio_classifier.proto; package mediapipe.tasks.audio.audio_classifier.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/classifier_options.proto"; import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message AudioClassifierOptions { message AudioClassifierOptions {
@ -31,7 +31,7 @@ message AudioClassifierOptions {
// Options for configuring the classifier behavior, such as score threshold, // Options for configuring the classifier behavior, such as score threshold,
// number of results, etc. // number of results, etc.
optional ClassifierOptions classifier_options = 2; optional components.proto.ClassifierOptions classifier_options = 2;
// The default sample rate of the input audio. Must be set when the // The default sample rate of the input audio. Must be set when the
// AudioClassifier is configured to process audio stream data. // AudioClassifier is configured to process audio stream data.

View File

@ -35,6 +35,8 @@ cc_library(
deps = [ deps = [
":image_preprocessing_options_cc_proto", ":image_preprocessing_options_cc_proto",
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/image:image_clone_calculator",
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
"//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
@ -56,21 +58,11 @@ cc_library(
# TODO: Enable this test # TODO: Enable this test
mediapipe_proto_library(
name = "segmenter_options_proto",
srcs = ["segmenter_options.proto"],
)
cc_library( cc_library(
name = "classifier_options", name = "classifier_options",
srcs = ["classifier_options.cc"], srcs = ["classifier_options.cc"],
hdrs = ["classifier_options.h"], hdrs = ["classifier_options.h"],
deps = [":classifier_options_cc_proto"], deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
)
mediapipe_proto_library(
name = "classifier_options_proto",
srcs = ["classifier_options.proto"],
) )
mediapipe_proto_library( mediapipe_proto_library(
@ -81,6 +73,7 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
], ],
) )
@ -90,7 +83,6 @@ cc_library(
hdrs = ["classification_postprocessing.h"], hdrs = ["classification_postprocessing.h"],
deps = [ deps = [
":classification_postprocessing_options_cc_proto", ":classification_postprocessing_options_cc_proto",
":classifier_options_cc_proto",
"//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_dequantization_calculator", "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
@ -104,7 +96,12 @@ cc_library(
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
@ -119,3 +116,38 @@ cc_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)

View File

@ -113,3 +113,66 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
cc_library(
name = "end_loop_calculator",
srcs = ["end_loop_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
],
alwayslink = 1,
)
mediapipe_proto_library(
name = "tensors_to_embeddings_calculator_proto",
srcs = ["tensors_to_embeddings_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
],
)
cc_library(
name = "tensors_to_embeddings_calculator",
srcs = ["tensors_to_embeddings_calculator.cc"],
deps = [
":tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
alwayslink = 1,
)
cc_test(
name = "tensors_to_embeddings_calculator_test",
srcs = ["tensors_to_embeddings_calculator_test.cc"],
deps = [
":tensors_to_embeddings_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"@com_google_absl//absl/status",
],
)

View File

@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include <vector>
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
// Specialized EndLoopCalculator for Tasks specific types.
namespace mediapipe::tasks {
typedef EndLoopCalculator<std::vector<ClassificationResult>>
EndLoopClassificationResultCalculator;
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);
} // namespace mediapipe::tasks

View File

@ -25,7 +25,7 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/framework/formats:image_format_proto", "//mediapipe/framework/formats:image_format_proto",
"//mediapipe/tasks/cc/components:segmenter_options_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_proto",
"//mediapipe/util:label_map_proto", "//mediapipe/util:label_map_proto",
], ],
) )
@ -45,7 +45,7 @@ cc_library(
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components:segmenter_options_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/tasks/cc/vision/utils:image_utils",
"//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",

View File

@ -36,19 +36,22 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/segmenter_options.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace tasks {
namespace { namespace {
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::ImageFrameSharedPtr; using ::mediapipe::ImageFrameSharedPtr;
using ::mediapipe::tasks::SegmenterOptions; using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions;
using ::mediapipe::tasks::components::proto::SegmenterOptions;
using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::GetImageLikeTensorShape;
using ::mediapipe::tasks::vision::Shape; using ::mediapipe::tasks::vision::Shape;
@ -254,7 +257,7 @@ std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResult(
return segmented_masks; return segmented_masks;
} }
MEDIAPIPE_REGISTER_NODE(TensorsToSegmentationCalculator); MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator);
} // namespace api2 } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks; package mediapipe.tasks;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/segmenter_options.proto"; import "mediapipe/tasks/cc/components/proto/segmenter_options.proto";
import "mediapipe/util/label_map.proto"; import "mediapipe/util/label_map.proto";
message TensorsToSegmentationCalculatorOptions { message TensorsToSegmentationCalculatorOptions {
@ -26,7 +26,7 @@ message TensorsToSegmentationCalculatorOptions {
optional TensorsToSegmentationCalculatorOptions ext = 458105876; optional TensorsToSegmentationCalculatorOptions ext = 458105876;
} }
optional SegmenterOptions segmenter_options = 1; optional components.proto.SegmenterOptions segmenter_options = 1;
// Identifying information for each classification label. // Identifying information for each classification label.
map<int64, mediapipe.LabelMapItem> label_items = 2; map<int64, mediapipe.LabelMapItem> label_items = 2;

View File

@ -117,7 +117,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) {
CalculatorRunner runner( CalculatorRunner runner(
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
R"pb( R"pb(
calculator: "TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation" output_stream: "SEGMENTATION:segmentation"
options { options {
@ -144,7 +144,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) {
CalculatorRunner runner( CalculatorRunner runner(
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
R"pb( R"pb(
calculator: "TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation" output_stream: "SEGMENTATION:segmentation"
options { options {
@ -172,7 +172,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) {
CalculatorRunner runner( CalculatorRunner runner(
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
R"pb( R"pb(
calculator: "TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "SEGMENTATION:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1" output_stream: "SEGMENTATION:1:segmented_mask_1"
@ -217,7 +217,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) {
CalculatorRunner runner( CalculatorRunner runner(
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
R"pb( R"pb(
calculator: "TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "SEGMENTATION:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1" output_stream: "SEGMENTATION:1:segmented_mask_1"
@ -258,7 +258,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) {
CalculatorRunner runner( CalculatorRunner runner(
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
R"pb( R"pb(
calculator: "TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "SEGMENTATION:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1" output_stream: "SEGMENTATION:1:segmented_mask_1"
@ -300,7 +300,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
CalculatorRunner runner( CalculatorRunner runner(
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
R"pb( R"pb(
calculator: "TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation" output_stream: "SEGMENTATION:segmentation"
options { options {
@ -333,7 +333,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
CalculatorRunner runner( CalculatorRunner runner(
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
R"pb( R"pb(
calculator: "TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
input_stream: "OUTPUT_SIZE:size" input_stream: "OUTPUT_SIZE:size"
output_stream: "SEGMENTATION:segmentation" output_stream: "SEGMENTATION:segmentation"

View File

@ -0,0 +1,158 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <algorithm>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
// case all values are 0.
float GetInverseL2Norm(const float* values, int size) {
float squared_l2_norm = 0.0f;
for (int i = 0; i < size; ++i) {
squared_l2_norm += values[i] * values[i];
}
float inv_l2_norm = 1.0f;
if (squared_l2_norm > 0.0f) {
inv_l2_norm = 1.0f / std::sqrt(squared_l2_norm);
}
return inv_l2_norm;
}
} // namespace
// Converts tensors into an EmbeddingResult object, performing optional
// L2-normalization and scalar-quantization on-the-fly if required through the
// options.
//
// Input:
// TENSORS - std::vector<Tensor>
// A vector of one or more Tensors of type kFloat32.
// Output:
// EMBEDDINGS - EmbeddingResult
// The contents of the input tensors converted into an EmbeddingResult
// proto.
class TensorsToEmbeddingsCalculator : public Node {
public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
bool l2_normalize_;
bool quantize_;
std::vector<std::string> head_names_;
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
};
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
auto options = cc->Options<mediapipe::TensorsToEmbeddingsCalculatorOptions>();
l2_normalize_ = options.embedder_options().l2_normalize();
quantize_ = options.embedder_options().quantize();
if (!options.head_names().empty()) {
head_names_.assign(options.head_names().begin(),
options.head_names().end());
}
return absl::OkStatus();
}
absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
EmbeddingResult result;
const auto& tensors = *kTensorsIn(cc);
if (!head_names_.empty() && tensors.size() != head_names_.size()) {
return absl::InvalidArgumentError(absl::StrFormat(
"Mismatch between number of provided head names (%d) and number "
"of input tensors (%d).",
head_names_.size(), tensors.size()));
}
for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor = tensors[i];
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
auto* embeddings = result.add_embeddings();
embeddings->set_head_index(i);
if (!head_names_.empty()) {
embeddings->set_head_name(head_names_[i]);
}
if (quantize_) {
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries());
} else {
FillFloatEmbeddingEntry(tensor, embeddings->add_entries());
}
}
kEmbeddingsOut(cc).Send(result);
return absl::OkStatus();
}
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry(
const Tensor& tensor, EmbeddingEntry* entry) {
int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* float_embedding = entry->mutable_float_embedding();
for (int i = 0; i < size; ++i) {
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
}
}
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry(
const Tensor& tensor, EmbeddingEntry* entry) {
int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* values = entry->mutable_quantized_embedding()->mutable_values();
values->resize(size);
for (int i = 0; i < size; ++i) {
// Normalize.
float normalized = tensor_buffer[i] * inv_l2_norm;
// Quantize.
int unclamped_value = static_cast<int>(roundf(normalized * 128));
// Clamp and assign.
(*values)[i] =
static_cast<char>(std::max(-128, std::min(unclamped_value, 127)));
}
}
MEDIAPIPE_REGISTER_NODE(TensorsToEmbeddingsCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,35 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
message TensorsToEmbeddingsCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorsToEmbeddingsCalculatorOptions ext = 474762326;
}
// The embedder options defining whether to L2-normalize or scalar-quantize
// the outputs.
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options =
1;
// The embedder head names.
repeated string head_names = 2;
}

View File

@ -0,0 +1,249 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe {
namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::testing::HasSubstr;
using Node = ::mediapipe::CalculatorGraphConfig::Node;
// Builds the graph and feeds inputs.
void BuildGraph(CalculatorRunner* runner,
std::vector<std::vector<float>> tensors) {
auto inputs = std::make_unique<std::vector<Tensor>>();
for (const auto& tensor : tensors) {
inputs->emplace_back(Tensor::ElementType::kFloat32,
Tensor::Shape{1, static_cast<int>(tensor.size())});
auto view = inputs->back().GetCpuWriteView();
float* buffer = view.buffer<float>();
ASSERT_NE(buffer, nullptr);
for (int i = 0; i < tensor.size(); ++i) {
buffer[i] = tensor[i];
}
}
auto& input_packets = runner->MutableInputs()->Tag("TENSORS").packets;
input_packets.push_back(Adopt(inputs.release()).At(Timestamp(0)));
}
TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
}
)pb"));
BuildGraph(&runner, {{0.1, 0.2}, {0.2, 0.3}});
auto status = runner.Run();
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(),
HasSubstr("Mismatch between number of provided head names"));
}
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false }
}
}
)pb"));
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries { float_embedding { values: 0.1 values: 0.2 } }
head_index: 0
}
embeddings {
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
})pb")));
}
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false }
head_names: "foo"
head_names: "bar"
}
}
)pb"));
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries { float_embedding { values: 0.1 values: 0.2 } }
head_index: 0
head_name: "foo"
}
embeddings {
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
head_name: "bar"
})pb")));
}
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: false }
}
}
)pb"));
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries {
float_embedding { values: 0.44721356 values: 0.8944271 }
}
head_index: 0
}
embeddings {
entries {
float_embedding { values: -0.5547002 values: -0.8320503 }
}
head_index: 1
})pb")));
}
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: true }
}
}
)pb"));
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries {
quantized_embedding { values: "\x0d\x1a" } # 13,26
}
head_index: 0
}
embeddings {
entries {
quantized_embedding { values: "\xe6\xda" } # -26,-38
}
head_index: 1
})pb")));
}
TEST(TensorsToEmbeddingsCalculatorTest,
SucceedsWithNormalizationAndQuantization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: true }
}
}
)pb"));
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries {
quantized_embedding { values: "\x39\x72" } # 57,114
}
head_index: 0
}
embeddings {
entries {
quantized_embedding { values: "\xb9\x95" } # -71,-107
}
head_index: 1
})pb")));
}
} // namespace
} // namespace mediapipe

View File

@ -35,9 +35,12 @@ limitations under the License.
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
@ -47,6 +50,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components {
namespace { namespace {
@ -57,18 +61,21 @@ using ::mediapipe::api2::Timestamp;
using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::ClassifierOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::ProcessUnit; using ::tflite::ProcessUnit;
using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
using ::tflite::TensorMetadata; using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Performs sanity checks on provided ClassifierOptions. // Performs sanity checks on provided ClassifierOptions.
@ -183,10 +190,10 @@ absl::StatusOr<LabelItems> GetLabelItemsIfAny(
absl::StatusOr<float> GetScoreThreshold( absl::StatusOr<float> GetScoreThreshold(
const ModelMetadataExtractor& metadata_extractor, const ModelMetadataExtractor& metadata_extractor,
const TensorMetadata& tensor_metadata) { const TensorMetadata& tensor_metadata) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(const ProcessUnit* score_thresholding_process_unit,
const ProcessUnit* score_thresholding_process_unit,
metadata_extractor.FindFirstProcessUnit( metadata_extractor.FindFirstProcessUnit(
tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); tensor_metadata,
tflite::ProcessUnitOptions_ScoreThresholdingOptions));
if (score_thresholding_process_unit == nullptr) { if (score_thresholding_process_unit == nullptr) {
return kDefaultScoreThreshold; return kDefaultScoreThreshold;
} }
@ -230,8 +237,51 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
return category_indices; return category_indices;
} }
// Fills in the TensorsToClassificationCalculatorOptions based on the classifier absl::Status ConfigureScoreCalibrationIfAny(
// options and the (optional) output tensor metadata. const ModelMetadataExtractor& metadata_extractor, int tensor_index,
ClassificationPostprocessingOptions* options) {
const auto* tensor_metadata =
metadata_extractor.GetOutputTensorMetadata(tensor_index);
if (tensor_metadata == nullptr) {
return absl::OkStatus();
}
// Get ScoreCalibrationOptions, if any.
ASSIGN_OR_RETURN(const ProcessUnit* score_calibration_process_unit,
metadata_extractor.FindFirstProcessUnit(
*tensor_metadata,
tflite::ProcessUnitOptions_ScoreCalibrationOptions));
if (score_calibration_process_unit == nullptr) {
return absl::OkStatus();
}
auto* score_calibration_options =
score_calibration_process_unit->options_as_ScoreCalibrationOptions();
// Get corresponding AssociatedFile.
auto score_calibration_filename =
metadata_extractor.FindFirstAssociatedFileName(
*tensor_metadata,
tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
if (score_calibration_filename.empty()) {
return CreateStatusWithPayload(
absl::StatusCode::kNotFound,
"Found ScoreCalibrationOptions but missing required associated "
"parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
}
ASSIGN_OR_RETURN(
absl::string_view score_calibration_file,
metadata_extractor.GetAssociatedFile(score_calibration_filename));
ScoreCalibrationCalculatorOptions calculator_options;
MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
score_calibration_options->score_transformation(),
score_calibration_options->default_score(), score_calibration_file,
&calculator_options));
(*options->mutable_score_calibration_options())[tensor_index] =
calculator_options;
return absl::OkStatus();
}
// Fills in the TensorsToClassificationCalculatorOptions based on the
// classifier options and the (optional) output tensor metadata.
absl::Status ConfigureTensorsToClassificationCalculator( absl::Status ConfigureTensorsToClassificationCalculator(
const ClassifierOptions& options, const ClassifierOptions& options,
const ModelMetadataExtractor& metadata_extractor, int tensor_index, const ModelMetadataExtractor& metadata_extractor, int tensor_index,
@ -303,6 +353,8 @@ absl::Status ConfigureClassificationPostprocessing(
ASSIGN_OR_RETURN(const auto heads_properties, ASSIGN_OR_RETURN(const auto heads_properties,
GetClassificationHeadsProperties(model_resources)); GetClassificationHeadsProperties(model_resources));
for (int i = 0; i < heads_properties.num_heads; ++i) { for (int i = 0; i < heads_properties.num_heads; ++i) {
MP_RETURN_IF_ERROR(ConfigureScoreCalibrationIfAny(
*model_resources.GetMetadataExtractor(), i, options));
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
classifier_options, *model_resources.GetMetadataExtractor(), i, classifier_options, *model_resources.GetMetadataExtractor(), i,
options->add_tensors_to_classifications_options())); options->add_tensors_to_classifications_options()));
@ -314,8 +366,8 @@ absl::Status ConfigureClassificationPostprocessing(
return absl::OkStatus(); return absl::OkStatus();
} }
// A "mediapipe.tasks.ClassificationPostprocessingSubgraph" converts raw // A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
// tensors into ClassificationResult objects. // raw tensors into ClassificationResult objects.
// - Accepts CPU input tensors. // - Accepts CPU input tensors.
// //
// Inputs: // Inputs:
@ -376,18 +428,21 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
} }
// If output tensors are quantized, they must be dequantized first. // If output tensors are quantized, they must be dequantized first.
GenericNode* tensors_dequantization_node; TensorsSource dequantized_tensors(&tensors_in);
if (options.has_quantized_outputs()) { if (options.has_quantized_outputs()) {
tensors_dequantization_node = GenericNode* tensors_dequantization_node =
&graph.AddNode("TensorsDequantizationCalculator"); &graph.AddNode("TensorsDequantizationCalculator");
tensors_in >> tensors_dequantization_node->In(kTensorsTag); tensors_in >> tensors_dequantization_node->In(kTensorsTag);
dequantized_tensors = {tensors_dequantization_node, kTensorsTag};
} }
// If there are multiple classification heads, the output tensors need to be // If there are multiple classification heads, the output tensors need to be
// split. // split.
GenericNode* split_tensor_vector_node; std::vector<TensorsSource> split_tensors;
split_tensors.reserve(num_heads);
if (num_heads > 1) { if (num_heads > 1) {
split_tensor_vector_node = &graph.AddNode("SplitTensorVectorCalculator"); GenericNode* split_tensor_vector_node =
&graph.AddNode("SplitTensorVectorCalculator");
auto& split_tensor_vector_options = auto& split_tensor_vector_options =
split_tensor_vector_node split_tensor_vector_node
->GetOptions<mediapipe::SplitVectorCalculatorOptions>(); ->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
@ -395,12 +450,27 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
auto* range = split_tensor_vector_options.add_ranges(); auto* range = split_tensor_vector_options.add_ranges();
range->set_begin(i); range->set_begin(i);
range->set_end(i + 1); range->set_end(i + 1);
split_tensors.emplace_back(split_tensor_vector_node, i);
} }
if (options.has_quantized_outputs()) { dequantized_tensors >> split_tensor_vector_node->In(0);
tensors_dequantization_node->Out(kTensorsTag) >>
split_tensor_vector_node->In(0);
} else { } else {
tensors_in >> split_tensor_vector_node->In(0); split_tensors.emplace_back(dequantized_tensors);
}
// Adds score calibration for heads that specify it, if any.
std::vector<TensorsSource> calibrated_tensors;
calibrated_tensors.reserve(num_heads);
for (int i = 0; i < num_heads; ++i) {
if (options.score_calibration_options().contains(i)) {
GenericNode* score_calibration_node =
&graph.AddNode("ScoreCalibrationCalculator");
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
.CopyFrom(options.score_calibration_options().at(i));
split_tensors[i] >> score_calibration_node->In(kScoresTag);
calibrated_tensors.emplace_back(score_calibration_node,
kCalibratedScoresTag);
} else {
calibrated_tensors.emplace_back(split_tensors[i]);
} }
} }
@ -413,17 +483,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
tensors_to_classification_nodes.back() tensors_to_classification_nodes.back()
->GetOptions<TensorsToClassificationCalculatorOptions>() ->GetOptions<TensorsToClassificationCalculatorOptions>()
.CopyFrom(options.tensors_to_classifications_options(i)); .CopyFrom(options.tensors_to_classifications_options(i));
if (num_heads == 1) { calibrated_tensors[i] >>
if (options.has_quantized_outputs()) {
tensors_dequantization_node->Out(kTensorsTag) >>
tensors_to_classification_nodes.back()->In(kTensorsTag); tensors_to_classification_nodes.back()->In(kTensorsTag);
} else {
tensors_in >> tensors_to_classification_nodes.back()->In(kTensorsTag);
}
} else {
split_tensor_vector_node->Out(i) >>
tensors_to_classification_nodes.back()->In(kTensorsTag);
}
} }
// Aggregates Classifications into a single ClassificationResult. // Aggregates Classifications into a single ClassificationResult.
@ -444,7 +505,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
} }
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::ClassificationPostprocessingSubgraph); ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
} // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -18,11 +18,12 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components {
// Configures a ClassificationPostprocessing subgraph using the provided model // Configures a ClassificationPostprocessing subgraph using the provided model
// resources and ClassifierOptions. // resources and ClassifierOptions.
@ -31,7 +32,7 @@ namespace tasks {
// Example usage: // Example usage:
// //
// auto& postprocessing = // auto& postprocessing =
// graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); // graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( // MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
// model_resources, // model_resources,
// classifier_options, // classifier_options,
@ -49,10 +50,11 @@ namespace tasks {
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results. // The output aggregated classification results.
absl::Status ConfigureClassificationPostprocessing( absl::Status ConfigureClassificationPostprocessing(
const core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const ClassifierOptions& classifier_options, const tasks::components::proto::ClassifierOptions& classifier_options,
ClassificationPostprocessingOptions* options); ClassificationPostprocessingOptions* options);
} // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,17 +15,22 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks; package mediapipe.tasks.components;
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
message ClassificationPostprocessingOptions { message ClassificationPostprocessingOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ClassificationPostprocessingOptions ext = 460416950; optional ClassificationPostprocessingOptions ext = 460416950;
} }
// Optional mapping between output tensor index and corresponding score
// calibration options.
map<int32, ScoreCalibrationCalculatorOptions> score_calibration_options = 4;
// Options for the TensorsToClassification calculators (one per classification // Options for the TensorsToClassification calculators (one per classification
// head) encapsulated by the ClassificationPostprocessing subgraph. // head) encapsulated by the ClassificationPostprocessing subgraph.
repeated mediapipe.TensorsToClassificationCalculatorOptions repeated mediapipe.TensorsToClassificationCalculatorOptions

View File

@ -41,9 +41,10 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
@ -51,6 +52,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components {
namespace { namespace {
using ::mediapipe::api2::Input; using ::mediapipe::api2::Input;
@ -58,6 +60,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::ClassifierOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::proto::Approximately; using ::testing::proto::Approximately;
@ -65,6 +68,8 @@ using ::testing::proto::Approximately;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
constexpr char kQuantizedImageClassifierWithMetadata[] = constexpr char kQuantizedImageClassifierWithMetadata[] =
"vision/mobilenet_v1_0.25_224_quant.tflite"; "vision/mobilenet_v1_0.25_224_quant.tflite";
constexpr char kQuantizedImageClassifierWithDummyScoreCalibration[] =
"vision/mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
constexpr char kQuantizedImageClassifierWithoutMetadata[] = constexpr char kQuantizedImageClassifierWithoutMetadata[] =
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] = constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] =
@ -147,11 +152,12 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
ClassifierOptions options_in; ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; ClassificationPostprocessingOptions options_out;
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out)); options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(tensors_to_classifications_options { R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38 min_score_threshold: -3.4028235e+38
top_k: -1 top_k: -1
sort_by_descending_score: true sort_by_descending_score: true
@ -169,11 +175,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
options_in.set_max_results(3); options_in.set_max_results(3);
ClassificationPostprocessingOptions options_out; ClassificationPostprocessingOptions options_out;
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out)); options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(tensors_to_classifications_options { R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38 min_score_threshold: -3.4028235e+38
top_k: 3 top_k: 3
sort_by_descending_score: true sort_by_descending_score: true
@ -191,11 +198,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
options_in.set_score_threshold(0.5); options_in.set_score_threshold(0.5);
ClassificationPostprocessingOptions options_out; ClassificationPostprocessingOptions options_out;
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out)); options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(tensors_to_classifications_options { R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: 0.5 min_score_threshold: 0.5
top_k: -1 top_k: -1
sort_by_descending_score: true sort_by_descending_score: true
@ -212,7 +220,7 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
ClassifierOptions options_in; ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; ClassificationPostprocessingOptions options_out;
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out)); options_in, &options_out));
// Check label map size and two first elements. // Check label map size and two first elements.
@ -229,7 +237,8 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
options_out.mutable_tensors_to_classifications_options(0) options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items(); ->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(tensors_to_classifications_options { R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38 min_score_threshold: -3.4028235e+38
top_k: -1 top_k: -1
sort_by_descending_score: true sort_by_descending_score: true
@ -249,14 +258,15 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
options_in.add_category_allowlist("tench"); options_in.add_category_allowlist("tench");
ClassificationPostprocessingOptions options_out; ClassificationPostprocessingOptions options_out;
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out)); options_in, &options_out));
// Clear label map and compare the rest of the options. // Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0) options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items(); ->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(tensors_to_classifications_options { R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38 min_score_threshold: -3.4028235e+38
top_k: -1 top_k: -1
sort_by_descending_score: true sort_by_descending_score: true
@ -277,14 +287,15 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
options_in.add_category_denylist("background"); options_in.add_category_denylist("background");
ClassificationPostprocessingOptions options_out; ClassificationPostprocessingOptions options_out;
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out)); options_in, &options_out));
// Clear label map and compare the rest of the options. // Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0) options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items(); ->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(tensors_to_classifications_options { R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38 min_score_threshold: -3.4028235e+38
top_k: -1 top_k: -1
sort_by_descending_score: true sort_by_descending_score: true
@ -297,6 +308,56 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
)pb"))); )pb")));
} }
TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(
kQuantizedImageClassifierWithDummyScoreCalibration));
ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
// Check label map size and two first elements.
EXPECT_EQ(
options_out.tensors_to_classifications_options(0).label_items_size(),
kMobileNetNumClasses);
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(0),
EqualsProto(R"pb(name: "background")pb"));
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(1),
EqualsProto(R"pb(name: "tench")pb"));
// Clear label map.
options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items();
// Check sigmoids size and first element.
EXPECT_EQ(options_out.score_calibration_options_size(), 1);
auto score_calibration_options =
options_out.score_calibration_options().at(0);
EXPECT_EQ(score_calibration_options.sigmoids_size(), kMobileNetNumClasses);
EXPECT_THAT(score_calibration_options.sigmoids(0),
EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb"));
options_out.mutable_score_calibration_options()->at(0).clear_sigmoids();
// Compare the rest of the options.
EXPECT_THAT(
options_out,
Approximately(EqualsProto(
R"pb(score_calibration_options {
key: 0
value { score_transformation: IDENTITY default_score: 0.5 }
}
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
}
classification_aggregation_options { head_names: "probability" }
has_quantized_outputs: true
)pb")));
}
TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
@ -304,7 +365,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
ClassifierOptions options_in; ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; ClassificationPostprocessingOptions options_out;
MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out)); options_in, &options_out));
// Check label maps sizes and first two elements. // Check label maps sizes and first two elements.
EXPECT_EQ( EXPECT_EQ(
@ -331,7 +392,8 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
options_out.mutable_tensors_to_classifications_options(1) options_out.mutable_tensors_to_classifications_options(1)
->clear_label_items(); ->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(tensors_to_classifications_options { R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38 min_score_threshold: -3.4028235e+38
top_k: -1 top_k: -1
sort_by_descending_score: true sort_by_descending_score: true
@ -358,8 +420,8 @@ class PostprocessingTest : public tflite_shims::testing::Test {
CreateModelResourcesForModel(model_name)); CreateModelResourcesForModel(model_name));
Graph graph; Graph graph;
auto& postprocessing = auto& postprocessing = graph.AddNode(
graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
*model_resources, options, *model_resources, options,
&postprocessing.GetOptions<ClassificationPostprocessingOptions>())); &postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
@ -503,6 +565,52 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
})pb")); })pb"));
} }
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
// Build graph.
ClassifierOptions options;
options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
tensor[1] = 12;
tensor[2] = 14;
tensor[3] = 16;
tensor[4] = 18;
// Send tensors and get results.
AddTensor(tensor, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(results, EqualsProto(
R"pb(classifications {
entries {
categories {
index: 4
score: 0.6899744811
category_name: "tiger shark"
}
categories {
index: 3
score: 0.6456563062
category_name: "great white shark"
}
categories {
index: 2
score: 0.5986876601
category_name: "goldfish"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
// Build graph. // Build graph.
ClassifierOptions options; ClassifierOptions options;
@ -621,5 +729,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
} }
} // namespace } // namespace
} // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

Some files were not shown because too many files have changed in this diff Show More