Project import generated by Copybara.
GitOrigin-RevId: bbbbcb4f5174dea33525729ede47c770069157cd
This commit is contained in:
parent
33d683c671
commit
1faeaae7e5
|
@ -120,7 +120,7 @@ just 86.22%.
|
||||||
### Hand Landmark Model
|
### Hand Landmark Model
|
||||||
|
|
||||||
After the palm detection over the whole image our subsequent hand landmark
|
After the palm detection over the whole image our subsequent hand landmark
|
||||||
[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite)
|
[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite)
|
||||||
performs precise keypoint localization of 21 3D hand-knuckle coordinates inside
|
performs precise keypoint localization of 21 3D hand-knuckle coordinates inside
|
||||||
the detected hand regions via regression, that is direct coordinate prediction.
|
the detected hand regions via regression, that is direct coordinate prediction.
|
||||||
The model learns a consistent internal hand pose representation and is robust
|
The model learns a consistent internal hand pose representation and is robust
|
||||||
|
@ -163,6 +163,11 @@ unrelated, images. Default to `false`.
|
||||||
|
|
||||||
Maximum number of hands to detect. Default to `2`.
|
Maximum number of hands to detect. Default to `2`.
|
||||||
|
|
||||||
|
#### model_complexity
|
||||||
|
|
||||||
|
Complexity of the hand landmark model: `0` or `1`. Landmark accuracy as well as
|
||||||
|
inference latency generally go up with the model complexity. Default to `1`.
|
||||||
|
|
||||||
#### min_detection_confidence
|
#### min_detection_confidence
|
||||||
|
|
||||||
Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the
|
Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the
|
||||||
|
@ -212,6 +217,7 @@ Supported configuration options:
|
||||||
|
|
||||||
* [static_image_mode](#static_image_mode)
|
* [static_image_mode](#static_image_mode)
|
||||||
* [max_num_hands](#max_num_hands)
|
* [max_num_hands](#max_num_hands)
|
||||||
|
* [model_complexity](#model_complexity)
|
||||||
* [min_detection_confidence](#min_detection_confidence)
|
* [min_detection_confidence](#min_detection_confidence)
|
||||||
* [min_tracking_confidence](#min_tracking_confidence)
|
* [min_tracking_confidence](#min_tracking_confidence)
|
||||||
|
|
||||||
|
@ -260,6 +266,7 @@ with mp_hands.Hands(
|
||||||
# For webcam input:
|
# For webcam input:
|
||||||
cap = cv2.VideoCapture(0)
|
cap = cv2.VideoCapture(0)
|
||||||
with mp_hands.Hands(
|
with mp_hands.Hands(
|
||||||
|
model_complexity=0,
|
||||||
min_detection_confidence=0.5,
|
min_detection_confidence=0.5,
|
||||||
min_tracking_confidence=0.5) as hands:
|
min_tracking_confidence=0.5) as hands:
|
||||||
while cap.isOpened():
|
while cap.isOpened():
|
||||||
|
@ -302,6 +309,7 @@ and a [fun application], and the following usage example.
|
||||||
Supported configuration options:
|
Supported configuration options:
|
||||||
|
|
||||||
* [maxNumHands](#max_num_hands)
|
* [maxNumHands](#max_num_hands)
|
||||||
|
* [modelComplexity](#model_complexity)
|
||||||
* [minDetectionConfidence](#min_detection_confidence)
|
* [minDetectionConfidence](#min_detection_confidence)
|
||||||
* [minTrackingConfidence](#min_tracking_confidence)
|
* [minTrackingConfidence](#min_tracking_confidence)
|
||||||
|
|
||||||
|
@ -351,6 +359,7 @@ const hands = new Hands({locateFile: (file) => {
|
||||||
}});
|
}});
|
||||||
hands.setOptions({
|
hands.setOptions({
|
||||||
maxNumHands: 2,
|
maxNumHands: 2,
|
||||||
|
modelComplexity: 1,
|
||||||
minDetectionConfidence: 0.5,
|
minDetectionConfidence: 0.5,
|
||||||
minTrackingConfidence: 0.5
|
minTrackingConfidence: 0.5
|
||||||
});
|
});
|
||||||
|
|
|
@ -58,10 +58,12 @@ one over the other.
|
||||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite),
|
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite),
|
||||||
[TF.js model](https://tfhub.dev/mediapipe/handdetector/1)
|
[TF.js model](https://tfhub.dev/mediapipe/handdetector/1)
|
||||||
* Hand landmark model:
|
* Hand landmark model:
|
||||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite),
|
[TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_lite.tflite),
|
||||||
|
[TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite),
|
||||||
[TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite),
|
[TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite),
|
||||||
[TF.js model](https://tfhub.dev/mediapipe/handskeleton/1)
|
[TF.js model](https://tfhub.dev/mediapipe/handskeleton/1)
|
||||||
* [Model card](https://mediapipe.page.link/handmc), [Model card (sparse)](https://mediapipe.page.link/handmc-sparse)
|
* [Model card](https://mediapipe.page.link/handmc),
|
||||||
|
[Model card (sparse)](https://mediapipe.page.link/handmc-sparse)
|
||||||
|
|
||||||
### [Pose](https://google.github.io/mediapipe/solutions/pose)
|
### [Pose](https://google.github.io/mediapipe/solutions/pose)
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ hip midpoints.
|
||||||
:----------------------------------------------------------------------------------------------------: |
|
:----------------------------------------------------------------------------------------------------: |
|
||||||
*Fig 3. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* |
|
*Fig 3. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* |
|
||||||
|
|
||||||
### Pose Landmark Model (BlazePose GHUM 3D)
|
### Pose Landmark Model (BlazePose [GHUM](https://github.com/google-research/google-research/tree/master/ghum) 3D)
|
||||||
|
|
||||||
The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
|
The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
|
||||||
(see figure below).
|
(see figure below).
|
||||||
|
@ -486,6 +486,7 @@ on how to build MediaPipe examples.
|
||||||
[BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204)
|
[BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204)
|
||||||
([presentation](https://youtu.be/YPpUOTRn5tA))
|
([presentation](https://youtu.be/YPpUOTRn5tA))
|
||||||
* [Models and model cards](./models.md#pose)
|
* [Models and model cards](./models.md#pose)
|
||||||
|
* [GHUM & GHUML: Generative 3D Human Shape and Articulated Pose Models](https://github.com/google-research/google-research/tree/master/ghum)
|
||||||
* [Web demo](https://code.mediapipe.dev/codepen/pose)
|
* [Web demo](https://code.mediapipe.dev/codepen/pose)
|
||||||
* [Python Colab](https://mediapipe.page.link/pose_py_colab)
|
* [Python Colab](https://mediapipe.page.link/pose_py_colab)
|
||||||
|
|
||||||
|
|
|
@ -531,9 +531,13 @@ cc_test(
|
||||||
":split_vector_calculator",
|
":split_vector_calculator",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -47,4 +47,8 @@ typedef BeginLoopCalculator<std::vector<std::vector<Matrix>>>
|
||||||
BeginLoopMatrixVectorCalculator;
|
BeginLoopMatrixVectorCalculator;
|
||||||
REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
|
REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
|
||||||
|
|
||||||
|
// A calculator to process std::vector<uint64_t>.
|
||||||
|
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
|
||||||
|
REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -14,7 +14,11 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "mediapipe/calculators/core/split_vector_calculator.h"
|
#include "mediapipe/calculators/core/split_vector_calculator.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -301,4 +305,99 @@ TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
class PassThroughAndTsBoundUpdateNode : public mediapipe::api2::Node {
|
||||||
|
public:
|
||||||
|
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
|
||||||
|
static constexpr mediapipe::api2::Output<int> kOutValue{"VALUE"};
|
||||||
|
static constexpr mediapipe::api2::Output<int> kOutTsBoundUpdate{
|
||||||
|
"TS_BOUND_UPDATE"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kInValue, kOutValue, kOutTsBoundUpdate);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
kOutValue(cc).Send(kInValue(cc));
|
||||||
|
kOutTsBoundUpdate(cc).SetNextTimestampBound(
|
||||||
|
cc->InputTimestamp().NextAllowedInStream());
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(PassThroughAndTsBoundUpdateNode);
|
||||||
|
|
||||||
|
class ToOptionalNode : public mediapipe::api2::Node {
|
||||||
|
public:
|
||||||
|
static constexpr mediapipe::api2::Input<int> kTick{"TICK"};
|
||||||
|
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
|
||||||
|
static constexpr mediapipe::api2::Output<absl::optional<int>> kOutValue{
|
||||||
|
"OUTPUT"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kTick, kInValue, kOutValue);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
if (kInValue(cc).IsEmpty()) {
|
||||||
|
kOutValue(cc).Send(absl::nullopt);
|
||||||
|
} else {
|
||||||
|
kOutValue(cc).Send({kInValue(cc).Get()});
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(ToOptionalNode);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
R"pb(
|
||||||
|
input_stream: "select"
|
||||||
|
node {
|
||||||
|
calculator: "PassThroughAndTsBoundUpdateNode"
|
||||||
|
input_stream: "VALUE:select"
|
||||||
|
output_stream: "VALUE:select_ps"
|
||||||
|
output_stream: "TS_BOUND_UPDATE:ts_bound_update"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "MuxCalculator"
|
||||||
|
input_stream: "INPUT:0:select_ps"
|
||||||
|
input_stream: "INPUT:1:ts_bound_update"
|
||||||
|
input_stream: "SELECT:select"
|
||||||
|
output_stream: "OUTPUT:select_or_ts_bound_update"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "ToOptionalNode"
|
||||||
|
input_stream: "TICK:select"
|
||||||
|
input_stream: "VALUE:select_or_ts_bound_update"
|
||||||
|
output_stream: "OUTPUT:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
auto send_value_fn = [&](int value, Timestamp ts) -> absl::Status {
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
graph.AddPacketToInputStream("select", MakePacket<int>(value).At(ts)));
|
||||||
|
return graph.WaitUntilIdle();
|
||||||
|
};
|
||||||
|
|
||||||
|
MP_ASSERT_OK(send_value_fn(0, Timestamp(1)));
|
||||||
|
ASSERT_EQ(output_packets.size(), 1);
|
||||||
|
EXPECT_EQ(output_packets[0].Get<absl::optional<int>>(), 0);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(send_value_fn(1, Timestamp(2)));
|
||||||
|
ASSERT_EQ(output_packets.size(), 2);
|
||||||
|
EXPECT_EQ(output_packets[1].Get<absl::optional<int>>(), absl::nullopt);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(send_value_fn(0, Timestamp(3)));
|
||||||
|
ASSERT_EQ(output_packets.size(), 3);
|
||||||
|
EXPECT_EQ(output_packets[2].Get<absl::optional<int>>(), 0);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -34,7 +34,6 @@ option java_outer_classname = "InferenceCalculatorProto";
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
message InferenceCalculatorOptions {
|
message InferenceCalculatorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional InferenceCalculatorOptions ext = 336783863;
|
optional InferenceCalculatorOptions ext = 336783863;
|
||||||
|
@ -69,8 +68,30 @@ message InferenceCalculatorOptions {
|
||||||
// Load pre-compiled serialized binary cache to accelerate init process.
|
// Load pre-compiled serialized binary cache to accelerate init process.
|
||||||
// Only available for OpenCL delegate on Android.
|
// Only available for OpenCL delegate on Android.
|
||||||
// Kernel caching will only be enabled if this path is set.
|
// Kernel caching will only be enabled if this path is set.
|
||||||
|
//
|
||||||
|
// NOTE: binary cache usage may be skipped if valid serialized model,
|
||||||
|
// specified by "serialized_model_dir", exists.
|
||||||
|
//
|
||||||
|
// TODO: update to cached_kernel_dir
|
||||||
optional string cached_kernel_path = 2;
|
optional string cached_kernel_path = 2;
|
||||||
|
|
||||||
|
// A dir to load from and save to a pre-compiled serialized model used to
|
||||||
|
// accelerate init process.
|
||||||
|
//
|
||||||
|
// NOTE: available for OpenCL delegate on Android only when
|
||||||
|
// "use_advanced_gpu_api" is set to true and "model_token" is set
|
||||||
|
// properly.
|
||||||
|
//
|
||||||
|
// NOTE: serialized model takes precedence over binary cache
|
||||||
|
// specified by "cached_kernel_path", which still can be used if
|
||||||
|
// serialized model is invalid or missing.
|
||||||
|
optional string serialized_model_dir = 7;
|
||||||
|
|
||||||
|
// Unique token identifying the model. Used in conjunction with
|
||||||
|
// "serialized_model_dir". It is the caller's responsibility to ensure
|
||||||
|
// there is no clash of the tokens.
|
||||||
|
optional string model_token = 8;
|
||||||
|
|
||||||
// Encapsulated compilation/runtime tradeoffs.
|
// Encapsulated compilation/runtime tradeoffs.
|
||||||
enum InferenceUsage {
|
enum InferenceUsage {
|
||||||
UNSPECIFIED = 0;
|
UNSPECIFIED = 0;
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
#include "mediapipe/util/tflite/config.h"
|
#include "mediapipe/util/tflite/config.h"
|
||||||
|
|
||||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||||
|
@ -49,8 +50,8 @@ class InferenceCalculatorGlImpl
|
||||||
absl::Status Close(CalculatorContext* cc) override;
|
absl::Status Close(CalculatorContext* cc) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::Status ReadKernelsFromFile();
|
absl::Status ReadGpuCaches();
|
||||||
absl::Status WriteKernelsToFile();
|
absl::Status SaveGpuCaches();
|
||||||
absl::Status LoadModel(CalculatorContext* cc);
|
absl::Status LoadModel(CalculatorContext* cc);
|
||||||
absl::Status LoadDelegate(CalculatorContext* cc);
|
absl::Status LoadDelegate(CalculatorContext* cc);
|
||||||
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
|
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
|
||||||
|
@ -82,6 +83,8 @@ class InferenceCalculatorGlImpl
|
||||||
|
|
||||||
bool use_kernel_caching_ = false;
|
bool use_kernel_caching_ = false;
|
||||||
std::string cached_kernel_filename_;
|
std::string cached_kernel_filename_;
|
||||||
|
bool use_serialized_model_ = false;
|
||||||
|
std::string serialized_model_path_;
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
|
absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
|
||||||
|
@ -114,6 +117,9 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
||||||
tflite_gpu_runner_usage_ = delegate.gpu().usage();
|
tflite_gpu_runner_usage_ = delegate.gpu().usage();
|
||||||
use_kernel_caching_ =
|
use_kernel_caching_ =
|
||||||
use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
|
use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
|
||||||
|
use_serialized_model_ = use_advanced_gpu_api_ &&
|
||||||
|
delegate.gpu().has_serialized_model_dir() &&
|
||||||
|
delegate.gpu().has_model_token();
|
||||||
use_gpu_delegate_ = !use_advanced_gpu_api_;
|
use_gpu_delegate_ = !use_advanced_gpu_api_;
|
||||||
|
|
||||||
if (use_kernel_caching_) {
|
if (use_kernel_caching_) {
|
||||||
|
@ -123,6 +129,12 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
||||||
".ker";
|
".ker";
|
||||||
#endif // MEDIAPIPE_ANDROID
|
#endif // MEDIAPIPE_ANDROID
|
||||||
}
|
}
|
||||||
|
if (use_serialized_model_) {
|
||||||
|
#ifdef MEDIAPIPE_ANDROID
|
||||||
|
serialized_model_path_ = mediapipe::file::JoinPath(
|
||||||
|
delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
|
||||||
|
#endif // MEDIAPIPE_ANDROID
|
||||||
|
}
|
||||||
|
|
||||||
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
|
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
|
||||||
// for everything.
|
// for everything.
|
||||||
|
@ -210,7 +222,7 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() {
|
absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() {
|
||||||
#ifdef MEDIAPIPE_ANDROID
|
#ifdef MEDIAPIPE_ANDROID
|
||||||
if (use_kernel_caching_) {
|
if (use_kernel_caching_) {
|
||||||
// Save kernel file.
|
// Save kernel file.
|
||||||
|
@ -220,12 +232,22 @@ absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() {
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
|
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
|
||||||
}
|
}
|
||||||
|
if (use_serialized_model_) {
|
||||||
|
// Save serialized model file.
|
||||||
|
ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
|
||||||
|
tflite_gpu_runner_->GetSerializedModel());
|
||||||
|
absl::string_view serialized_model(
|
||||||
|
reinterpret_cast<char*>(serialized_model_vec.data()),
|
||||||
|
serialized_model_vec.size());
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
mediapipe::file::SetContents(serialized_model_path_, serialized_model));
|
||||||
|
}
|
||||||
#endif // MEDIAPIPE_ANDROID
|
#endif // MEDIAPIPE_ANDROID
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
|
absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
|
||||||
MP_RETURN_IF_ERROR(WriteKernelsToFile());
|
MP_RETURN_IF_ERROR(SaveGpuCaches());
|
||||||
if (use_gpu_delegate_) {
|
if (use_gpu_delegate_) {
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
|
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
|
||||||
gpu_buffers_in_.clear();
|
gpu_buffers_in_.clear();
|
||||||
|
@ -239,17 +261,24 @@ absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::ReadKernelsFromFile() {
|
absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
|
||||||
#ifdef MEDIAPIPE_ANDROID
|
#ifdef MEDIAPIPE_ANDROID
|
||||||
if (use_kernel_caching_) {
|
if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
|
||||||
// Load pre-compiled kernel file.
|
// Load pre-compiled kernel file.
|
||||||
if (mediapipe::File::Exists(cached_kernel_filename_)) {
|
|
||||||
std::string cache_str;
|
std::string cache_str;
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
|
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
|
||||||
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
|
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
|
||||||
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
|
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
|
||||||
}
|
}
|
||||||
|
if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
|
||||||
|
// Load serialized model file.
|
||||||
|
std::string serialized_model_str;
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
file::GetContents(serialized_model_path_, &serialized_model_str));
|
||||||
|
std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
|
||||||
|
serialized_model_str.end());
|
||||||
|
tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
|
||||||
}
|
}
|
||||||
#endif // MEDIAPIPE_ANDROID
|
#endif // MEDIAPIPE_ANDROID
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -313,7 +342,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
||||||
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
||||||
}
|
}
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(ReadKernelsFromFile());
|
MP_RETURN_IF_ERROR(ReadGpuCaches());
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
|
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
|
||||||
|
|
||||||
|
|
|
@ -24,20 +24,20 @@ message SsdAnchorsCalculatorOptions {
|
||||||
optional SsdAnchorsCalculatorOptions ext = 247258239;
|
optional SsdAnchorsCalculatorOptions ext = 247258239;
|
||||||
}
|
}
|
||||||
// Size of input images.
|
// Size of input images.
|
||||||
required int32 input_size_width = 1;
|
optional int32 input_size_width = 1; // required
|
||||||
required int32 input_size_height = 2;
|
optional int32 input_size_height = 2; // required
|
||||||
|
|
||||||
// Min and max scales for generating anchor boxes on feature maps.
|
// Min and max scales for generating anchor boxes on feature maps.
|
||||||
required float min_scale = 3;
|
optional float min_scale = 3; // required
|
||||||
required float max_scale = 4;
|
optional float max_scale = 4; // required
|
||||||
|
|
||||||
// The offset for the center of anchors. The value is in the scale of stride.
|
// The offset for the center of anchors. The value is in the scale of stride.
|
||||||
// E.g. 0.5 meaning 0.5 * |current_stride| in pixels.
|
// E.g. 0.5 meaning 0.5 * |current_stride| in pixels.
|
||||||
required float anchor_offset_x = 5 [default = 0.5];
|
optional float anchor_offset_x = 5 [default = 0.5]; // required
|
||||||
required float anchor_offset_y = 6 [default = 0.5];
|
optional float anchor_offset_y = 6 [default = 0.5]; // required
|
||||||
|
|
||||||
// Number of output feature maps to generate the anchors on.
|
// Number of output feature maps to generate the anchors on.
|
||||||
required int32 num_layers = 7;
|
optional int32 num_layers = 7; // required
|
||||||
// Sizes of output feature maps to create anchors. Either feature_map size or
|
// Sizes of output feature maps to create anchors. Either feature_map size or
|
||||||
// stride should be provided.
|
// stride should be provided.
|
||||||
repeated int32 feature_map_width = 8;
|
repeated int32 feature_map_width = 8;
|
||||||
|
|
|
@ -26,12 +26,12 @@ message TfLiteTensorsToDetectionsCalculatorOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
// The number of output classes predicted by the detection model.
|
// The number of output classes predicted by the detection model.
|
||||||
required int32 num_classes = 1;
|
optional int32 num_classes = 1; // required
|
||||||
// The number of output boxes predicted by the detection model.
|
// The number of output boxes predicted by the detection model.
|
||||||
required int32 num_boxes = 2;
|
optional int32 num_boxes = 2; // required
|
||||||
// The number of output values per boxes predicted by the detection model. The
|
// The number of output values per boxes predicted by the detection model. The
|
||||||
// values contain bounding boxes, keypoints, etc.
|
// values contain bounding boxes, keypoints, etc.
|
||||||
required int32 num_coords = 3;
|
optional int32 num_coords = 3; // required
|
||||||
|
|
||||||
// The offset of keypoint coordinates in the location tensor.
|
// The offset of keypoint coordinates in the location tensor.
|
||||||
optional int32 keypoint_coord_offset = 9;
|
optional int32 keypoint_coord_offset = 9;
|
||||||
|
|
|
@ -31,7 +31,7 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Number of landmarks from the output of the model.
|
// Number of landmarks from the output of the model.
|
||||||
required int32 num_landmarks = 1;
|
optional int32 num_landmarks = 1; // required
|
||||||
|
|
||||||
// Size of the input image for the model. These options are used only when
|
// Size of the input image for the model. These options are used only when
|
||||||
// normalized landmarks are needed. Z coordinate is scaled as X assuming
|
// normalized landmarks are needed. Z coordinate is scaled as X assuming
|
||||||
|
|
|
@ -24,9 +24,9 @@ message TfLiteTensorsToSegmentationCalculatorOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dimensions of input segmentation tensor to process.
|
// Dimensions of input segmentation tensor to process.
|
||||||
required int32 tensor_width = 1;
|
optional int32 tensor_width = 1; // required
|
||||||
required int32 tensor_height = 2;
|
optional int32 tensor_height = 2; // required
|
||||||
required int32 tensor_channels = 3;
|
optional int32 tensor_channels = 3; // required
|
||||||
|
|
||||||
// How much to use previous mask when computing current one; range [0-1].
|
// How much to use previous mask when computing current one; range [0-1].
|
||||||
// This is a tradeoff between responsiveness (0.0) and accuracy (1.0).
|
// This is a tradeoff between responsiveness (0.0) and accuracy (1.0).
|
||||||
|
|
|
@ -98,32 +98,25 @@ public class MainActivity extends AppCompatActivity {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Sets up the UI components for the static image demo. */
|
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
|
||||||
private void setupStaticImageDemoUiComponents() {
|
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
|
||||||
// The Intent to access gallery and read images as bitmap.
|
int width = imageView.getWidth();
|
||||||
imageGetter =
|
int height = imageView.getHeight();
|
||||||
registerForActivityResult(
|
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
|
||||||
new ActivityResultContracts.StartActivityForResult(),
|
width = (int) (height * aspectRatio);
|
||||||
result -> {
|
} else {
|
||||||
Intent resultIntent = result.getData();
|
height = (int) (width / aspectRatio);
|
||||||
if (resultIntent != null) {
|
|
||||||
if (result.getResultCode() == RESULT_OK) {
|
|
||||||
Bitmap bitmap = null;
|
|
||||||
try {
|
|
||||||
bitmap =
|
|
||||||
MediaStore.Images.Media.getBitmap(
|
|
||||||
this.getContentResolver(), resultIntent.getData());
|
|
||||||
} catch (IOException e) {
|
|
||||||
Log.e(TAG, "Bitmap reading error:" + e);
|
|
||||||
}
|
}
|
||||||
try {
|
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||||
InputStream imageData =
|
}
|
||||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
|
||||||
|
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
||||||
int orientation =
|
int orientation =
|
||||||
new ExifInterface(imageData)
|
new ExifInterface(imageData)
|
||||||
.getAttributeInt(
|
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||||
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
||||||
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
|
return inputBitmap;
|
||||||
|
}
|
||||||
Matrix matrix = new Matrix();
|
Matrix matrix = new Matrix();
|
||||||
switch (orientation) {
|
switch (orientation) {
|
||||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||||
|
@ -138,10 +131,33 @@ public class MainActivity extends AppCompatActivity {
|
||||||
default:
|
default:
|
||||||
matrix.postRotate(0);
|
matrix.postRotate(0);
|
||||||
}
|
}
|
||||||
bitmap =
|
return Bitmap.createBitmap(
|
||||||
Bitmap.createBitmap(
|
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
||||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Sets up the UI components for the static image demo. */
|
||||||
|
private void setupStaticImageDemoUiComponents() {
|
||||||
|
// The Intent to access gallery and read images as bitmap.
|
||||||
|
imageGetter =
|
||||||
|
registerForActivityResult(
|
||||||
|
new ActivityResultContracts.StartActivityForResult(),
|
||||||
|
result -> {
|
||||||
|
Intent resultIntent = result.getData();
|
||||||
|
if (resultIntent != null) {
|
||||||
|
if (result.getResultCode() == RESULT_OK) {
|
||||||
|
Bitmap bitmap = null;
|
||||||
|
try {
|
||||||
|
bitmap =
|
||||||
|
downscaleBitmap(
|
||||||
|
MediaStore.Images.Media.getBitmap(
|
||||||
|
this.getContentResolver(), resultIntent.getData()));
|
||||||
|
} catch (IOException e) {
|
||||||
|
Log.e(TAG, "Bitmap reading error:" + e);
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
InputStream imageData =
|
||||||
|
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||||
|
bitmap = rotateBitmap(bitmap, imageData);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ android {
|
||||||
buildToolsVersion "30.0.3"
|
buildToolsVersion "30.0.3"
|
||||||
|
|
||||||
defaultConfig {
|
defaultConfig {
|
||||||
applicationId "com.google.mediapipe.apps.hands"
|
applicationId "com.google.mediapipe.apps.facemesh"
|
||||||
minSdkVersion 21
|
minSdkVersion 21
|
||||||
targetSdkVersion 30
|
targetSdkVersion 30
|
||||||
versionCode 1
|
versionCode 1
|
||||||
|
|
|
@ -99,32 +99,25 @@ public class MainActivity extends AppCompatActivity {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Sets up the UI components for the static image demo. */
|
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
|
||||||
private void setupStaticImageDemoUiComponents() {
|
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
|
||||||
// The Intent to access gallery and read images as bitmap.
|
int width = imageView.getWidth();
|
||||||
imageGetter =
|
int height = imageView.getHeight();
|
||||||
registerForActivityResult(
|
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
|
||||||
new ActivityResultContracts.StartActivityForResult(),
|
width = (int) (height * aspectRatio);
|
||||||
result -> {
|
} else {
|
||||||
Intent resultIntent = result.getData();
|
height = (int) (width / aspectRatio);
|
||||||
if (resultIntent != null) {
|
|
||||||
if (result.getResultCode() == RESULT_OK) {
|
|
||||||
Bitmap bitmap = null;
|
|
||||||
try {
|
|
||||||
bitmap =
|
|
||||||
MediaStore.Images.Media.getBitmap(
|
|
||||||
this.getContentResolver(), resultIntent.getData());
|
|
||||||
} catch (IOException e) {
|
|
||||||
Log.e(TAG, "Bitmap reading error:" + e);
|
|
||||||
}
|
}
|
||||||
try {
|
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||||
InputStream imageData =
|
}
|
||||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
|
||||||
|
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
||||||
int orientation =
|
int orientation =
|
||||||
new ExifInterface(imageData)
|
new ExifInterface(imageData)
|
||||||
.getAttributeInt(
|
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||||
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
||||||
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
|
return inputBitmap;
|
||||||
|
}
|
||||||
Matrix matrix = new Matrix();
|
Matrix matrix = new Matrix();
|
||||||
switch (orientation) {
|
switch (orientation) {
|
||||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||||
|
@ -139,10 +132,33 @@ public class MainActivity extends AppCompatActivity {
|
||||||
default:
|
default:
|
||||||
matrix.postRotate(0);
|
matrix.postRotate(0);
|
||||||
}
|
}
|
||||||
bitmap =
|
return Bitmap.createBitmap(
|
||||||
Bitmap.createBitmap(
|
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
||||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Sets up the UI components for the static image demo. */
|
||||||
|
private void setupStaticImageDemoUiComponents() {
|
||||||
|
// The Intent to access gallery and read images as bitmap.
|
||||||
|
imageGetter =
|
||||||
|
registerForActivityResult(
|
||||||
|
new ActivityResultContracts.StartActivityForResult(),
|
||||||
|
result -> {
|
||||||
|
Intent resultIntent = result.getData();
|
||||||
|
if (resultIntent != null) {
|
||||||
|
if (result.getResultCode() == RESULT_OK) {
|
||||||
|
Bitmap bitmap = null;
|
||||||
|
try {
|
||||||
|
bitmap =
|
||||||
|
downscaleBitmap(
|
||||||
|
MediaStore.Images.Media.getBitmap(
|
||||||
|
this.getContentResolver(), resultIntent.getData()));
|
||||||
|
} catch (IOException e) {
|
||||||
|
Log.e(TAG, "Bitmap reading error:" + e);
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
InputStream imageData =
|
||||||
|
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||||
|
bitmap = rotateBitmap(bitmap, imageData);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,32 +100,25 @@ public class MainActivity extends AppCompatActivity {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Sets up the UI components for the static image demo. */
|
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
|
||||||
private void setupStaticImageDemoUiComponents() {
|
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
|
||||||
// The Intent to access gallery and read images as bitmap.
|
int width = imageView.getWidth();
|
||||||
imageGetter =
|
int height = imageView.getHeight();
|
||||||
registerForActivityResult(
|
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
|
||||||
new ActivityResultContracts.StartActivityForResult(),
|
width = (int) (height * aspectRatio);
|
||||||
result -> {
|
} else {
|
||||||
Intent resultIntent = result.getData();
|
height = (int) (width / aspectRatio);
|
||||||
if (resultIntent != null) {
|
|
||||||
if (result.getResultCode() == RESULT_OK) {
|
|
||||||
Bitmap bitmap = null;
|
|
||||||
try {
|
|
||||||
bitmap =
|
|
||||||
MediaStore.Images.Media.getBitmap(
|
|
||||||
this.getContentResolver(), resultIntent.getData());
|
|
||||||
} catch (IOException e) {
|
|
||||||
Log.e(TAG, "Bitmap reading error:" + e);
|
|
||||||
}
|
}
|
||||||
try {
|
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||||
InputStream imageData =
|
}
|
||||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
|
||||||
|
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
||||||
int orientation =
|
int orientation =
|
||||||
new ExifInterface(imageData)
|
new ExifInterface(imageData)
|
||||||
.getAttributeInt(
|
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||||
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
||||||
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
|
return inputBitmap;
|
||||||
|
}
|
||||||
Matrix matrix = new Matrix();
|
Matrix matrix = new Matrix();
|
||||||
switch (orientation) {
|
switch (orientation) {
|
||||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||||
|
@ -140,10 +133,33 @@ public class MainActivity extends AppCompatActivity {
|
||||||
default:
|
default:
|
||||||
matrix.postRotate(0);
|
matrix.postRotate(0);
|
||||||
}
|
}
|
||||||
bitmap =
|
return Bitmap.createBitmap(
|
||||||
Bitmap.createBitmap(
|
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
||||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Sets up the UI components for the static image demo. */
|
||||||
|
private void setupStaticImageDemoUiComponents() {
|
||||||
|
// The Intent to access gallery and read images as bitmap.
|
||||||
|
imageGetter =
|
||||||
|
registerForActivityResult(
|
||||||
|
new ActivityResultContracts.StartActivityForResult(),
|
||||||
|
result -> {
|
||||||
|
Intent resultIntent = result.getData();
|
||||||
|
if (resultIntent != null) {
|
||||||
|
if (result.getResultCode() == RESULT_OK) {
|
||||||
|
Bitmap bitmap = null;
|
||||||
|
try {
|
||||||
|
bitmap =
|
||||||
|
downscaleBitmap(
|
||||||
|
MediaStore.Images.Media.getBitmap(
|
||||||
|
this.getContentResolver(), resultIntent.getData()));
|
||||||
|
} catch (IOException e) {
|
||||||
|
Log.e(TAG, "Bitmap reading error:" + e);
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
InputStream imageData =
|
||||||
|
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||||
|
bitmap = rotateBitmap(bitmap, imageData);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ android_binary(
|
||||||
assets = [
|
assets = [
|
||||||
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
|
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
|
||||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||||
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
||||||
],
|
],
|
||||||
assets_dir = "",
|
assets_dir = "",
|
||||||
|
|
|
@ -39,7 +39,7 @@ android_binary(
|
||||||
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
|
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
|
||||||
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
|
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
|
||||||
"//mediapipe/modules/face_landmark:face_landmark.tflite",
|
"//mediapipe/modules/face_landmark:face_landmark.tflite",
|
||||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||||
|
|
|
@ -62,7 +62,7 @@ objc_library(
|
||||||
copts = ["-std=c++17"],
|
copts = ["-std=c++17"],
|
||||||
data = [
|
data = [
|
||||||
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
|
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
|
||||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
||||||
],
|
],
|
||||||
|
|
|
@ -57,7 +57,7 @@ objc_library(
|
||||||
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
|
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
|
||||||
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
|
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
|
||||||
"//mediapipe/modules/face_landmark:face_landmark.tflite",
|
"//mediapipe/modules/face_landmark:face_landmark.tflite",
|
||||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||||
|
|
|
@ -427,7 +427,8 @@ absl::Status CalculatorGraph::Initialize(
|
||||||
const std::map<std::string, Packet>& side_packets) {
|
const std::map<std::string, Packet>& side_packets) {
|
||||||
auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
|
auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
|
||||||
MP_RETURN_IF_ERROR(validated_graph->Initialize(
|
MP_RETURN_IF_ERROR(validated_graph->Initialize(
|
||||||
input_config, /*graph_registry=*/nullptr, &service_manager_));
|
input_config, /*graph_registry=*/nullptr, /*graph_options=*/nullptr,
|
||||||
|
&service_manager_));
|
||||||
return Initialize(std::move(validated_graph), side_packets);
|
return Initialize(std::move(validated_graph), side_packets);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -122,7 +122,7 @@ class CalculatorRunner {
|
||||||
const StreamContentsSet& Outputs() const { return *outputs_; }
|
const StreamContentsSet& Outputs() const { return *outputs_; }
|
||||||
|
|
||||||
// Returns the access to the output side packets.
|
// Returns the access to the output side packets.
|
||||||
const PacketSet& OutputSidePackets() { return *output_side_packets_.get(); }
|
const PacketSet& OutputSidePackets() { return *output_side_packets_; }
|
||||||
|
|
||||||
// Returns a graph counter.
|
// Returns a graph counter.
|
||||||
mediapipe::Counter* GetCounter(const std::string& name);
|
mediapipe::Counter* GetCounter(const std::string& name);
|
||||||
|
|
|
@ -77,13 +77,6 @@ bool Image::ConvertToGpu() const {
|
||||||
#else
|
#else
|
||||||
// GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture)
|
// GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture)
|
||||||
auto buffer = mediapipe::GlTextureBuffer::Create(*image_frame_);
|
auto buffer = mediapipe::GlTextureBuffer::Create(*image_frame_);
|
||||||
glBindTexture(GL_TEXTURE_2D, buffer->name());
|
|
||||||
// See GlCalculatorHelperImpl::SetStandardTextureParams
|
|
||||||
glTexParameteri(buffer->target(), GL_TEXTURE_MIN_FILTER, GL_LINEAR);
|
|
||||||
glTexParameteri(buffer->target(), GL_TEXTURE_MAG_FILTER, GL_LINEAR);
|
|
||||||
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
|
|
||||||
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
|
|
||||||
glBindTexture(GL_TEXTURE_2D, 0);
|
|
||||||
glFlush();
|
glFlush();
|
||||||
gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer));
|
gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer));
|
||||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
|
|
@ -244,6 +244,8 @@ cc_test(
|
||||||
srcs = ["mux_input_stream_handler_test.cc"],
|
srcs = ["mux_input_stream_handler_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":mux_input_stream_handler",
|
":mux_input_stream_handler",
|
||||||
|
"//mediapipe/calculators/core:gate_calculator",
|
||||||
|
"//mediapipe/calculators/core:make_pair_calculator",
|
||||||
"//mediapipe/calculators/core:mux_calculator",
|
"//mediapipe/calculators/core:mux_calculator",
|
||||||
"//mediapipe/calculators/core:pass_through_calculator",
|
"//mediapipe/calculators/core:pass_through_calculator",
|
||||||
"//mediapipe/calculators/core:round_robin_demux_calculator",
|
"//mediapipe/calculators/core:round_robin_demux_calculator",
|
||||||
|
|
|
@ -75,14 +75,31 @@ class MuxInputStreamHandler : public InputStreamHandler {
|
||||||
int control_value = control_packet.Get<int>();
|
int control_value = control_packet.Get<int>();
|
||||||
CHECK_LE(0, control_value);
|
CHECK_LE(0, control_value);
|
||||||
CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1);
|
CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1);
|
||||||
|
|
||||||
const auto& data_stream = input_stream_managers_.Get(
|
const auto& data_stream = input_stream_managers_.Get(
|
||||||
input_stream_managers_.BeginId() + control_value);
|
input_stream_managers_.BeginId() + control_value);
|
||||||
|
|
||||||
|
// Data stream may contain some outdated packets which failed to be popped
|
||||||
|
// out during "FillInputSet". (This handler doesn't sync input streams,
|
||||||
|
// hence "FillInputSet" can be triggerred before every input stream is
|
||||||
|
// filled with packets corresponding to the same timestamp.)
|
||||||
|
data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
|
||||||
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
|
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
|
||||||
if (empty) {
|
if (empty) {
|
||||||
CHECK_LE(stream_timestamp, *min_stream_timestamp);
|
if (stream_timestamp <= *min_stream_timestamp) {
|
||||||
|
// "data_stream" didn't receive a packet corresponding to the current
|
||||||
|
// "control_stream" packet yet.
|
||||||
return NodeReadiness::kNotReady;
|
return NodeReadiness::kNotReady;
|
||||||
}
|
}
|
||||||
|
// "data_stream" timestamp bound update detected.
|
||||||
|
return NodeReadiness::kReadyForProcess;
|
||||||
|
}
|
||||||
|
if (stream_timestamp > *min_stream_timestamp) {
|
||||||
|
// The earliest packet "data_stream" holds corresponds to a control packet
|
||||||
|
// yet to arrive, which means there won't be a "data_stream" packet
|
||||||
|
// corresponding to the current "control_stream" packet, which should be
|
||||||
|
// indicated as timestamp boun update.
|
||||||
|
return NodeReadiness::kReadyForProcess;
|
||||||
|
}
|
||||||
CHECK_EQ(stream_timestamp, *min_stream_timestamp);
|
CHECK_EQ(stream_timestamp, *min_stream_timestamp);
|
||||||
return NodeReadiness::kReadyForProcess;
|
return NodeReadiness::kReadyForProcess;
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -19,9 +20,10 @@
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
|
||||||
// A regression test for b/31620439. MuxInputStreamHandler's accesses to the
|
// A regression test for b/31620439. MuxInputStreamHandler's accesses to the
|
||||||
// control and data streams should be atomic so that it has a consistent view
|
// control and data streams should be atomic so that it has a consistent view
|
||||||
// of the two streams. None of the CHECKs in the GetNodeReadiness() method of
|
// of the two streams. None of the CHECKs in the GetNodeReadiness() method of
|
||||||
|
@ -87,5 +89,561 @@ TEST(MuxInputStreamHandlerTest, AtomicAccessToControlAndDataStreams) {
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MATCHER_P2(IntPacket, value, ts, "") {
|
||||||
|
return arg.template Get<int>() == value && arg.Timestamp() == ts;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GateAndMuxGraphInput {
|
||||||
|
int input0;
|
||||||
|
int input1;
|
||||||
|
int input2;
|
||||||
|
int select;
|
||||||
|
bool allow0;
|
||||||
|
bool allow1;
|
||||||
|
bool allow2;
|
||||||
|
Timestamp at;
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr char kGateAndMuxGraph[] = R"pb(
|
||||||
|
input_stream: "input0"
|
||||||
|
input_stream: "input1"
|
||||||
|
input_stream: "input2"
|
||||||
|
input_stream: "select"
|
||||||
|
input_stream: "allow0"
|
||||||
|
input_stream: "allow1"
|
||||||
|
input_stream: "allow2"
|
||||||
|
node {
|
||||||
|
calculator: "GateCalculator"
|
||||||
|
input_stream: "ALLOW:allow0"
|
||||||
|
input_stream: "input0"
|
||||||
|
output_stream: "output0"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "GateCalculator"
|
||||||
|
input_stream: "ALLOW:allow1"
|
||||||
|
input_stream: "input1"
|
||||||
|
output_stream: "output1"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "GateCalculator"
|
||||||
|
input_stream: "ALLOW:allow2"
|
||||||
|
input_stream: "input2"
|
||||||
|
output_stream: "output2"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "MuxCalculator"
|
||||||
|
input_stream: "INPUT:0:output0"
|
||||||
|
input_stream: "INPUT:1:output1"
|
||||||
|
input_stream: "INPUT:2:output2"
|
||||||
|
input_stream: "SELECT:select"
|
||||||
|
output_stream: "OUTPUT:output"
|
||||||
|
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||||
|
})pb";
|
||||||
|
|
||||||
|
absl::Status SendInput(GateAndMuxGraphInput in, CalculatorGraph& graph) {
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"input0", MakePacket<int>(in.input0).At(in.at)));
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"input1", MakePacket<int>(in.input1).At(in.at)));
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"input2", MakePacket<int>(in.input2).At(in.at)));
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(in.select).At(in.at)));
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"allow0", MakePacket<bool>(in.allow0).At(in.at)));
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"allow1", MakePacket<bool>(in.allow1).At(in.at)));
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"allow2", MakePacket<bool>(in.allow2).At(in.at)));
|
||||||
|
return graph.WaitUntilIdle();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest, BasicMuxing) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 0,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(1)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 1,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(2)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||||
|
IntPacket(900, Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 2,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = true,
|
||||||
|
.at = Timestamp(3)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||||
|
IntPacket(900, Timestamp(2)),
|
||||||
|
IntPacket(800, Timestamp(3))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest, MuxingNonEmptyInputs) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 0,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = true,
|
||||||
|
.at = Timestamp(1)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 1,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = true,
|
||||||
|
.at = Timestamp(2)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||||
|
IntPacket(900, Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 2,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = true,
|
||||||
|
.at = Timestamp(3)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||||
|
IntPacket(900, Timestamp(2)),
|
||||||
|
IntPacket(800, Timestamp(3))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest, MuxingAllTimestampBoundUpdates) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 0,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(1)},
|
||||||
|
graph));
|
||||||
|
EXPECT_TRUE(output_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 1,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(2)},
|
||||||
|
graph));
|
||||||
|
EXPECT_TRUE(output_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 2,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(3)},
|
||||||
|
graph));
|
||||||
|
EXPECT_TRUE(output_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest, MuxingSlectedTimestampBoundUpdates) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 0,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = true,
|
||||||
|
.at = Timestamp(1)},
|
||||||
|
graph));
|
||||||
|
EXPECT_TRUE(output_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 1,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = true,
|
||||||
|
.at = Timestamp(2)},
|
||||||
|
graph));
|
||||||
|
EXPECT_TRUE(output_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 2,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(3)},
|
||||||
|
graph));
|
||||||
|
EXPECT_TRUE(output_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest, MuxingSometimesTimestampBoundUpdates) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 0,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(1)},
|
||||||
|
graph));
|
||||||
|
EXPECT_TRUE(output_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 1,
|
||||||
|
.allow0 = false,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(2)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 2,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(3)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||||
|
.input1 = 900,
|
||||||
|
.input2 = 800,
|
||||||
|
.select = 2,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = true,
|
||||||
|
.allow2 = true,
|
||||||
|
.at = Timestamp(4)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
|
||||||
|
IntPacket(800, Timestamp(4))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendInput({.input0 = 700,
|
||||||
|
.input1 = 600,
|
||||||
|
.input2 = 500,
|
||||||
|
.select = 0,
|
||||||
|
.allow0 = true,
|
||||||
|
.allow1 = false,
|
||||||
|
.allow2 = false,
|
||||||
|
.at = Timestamp(5)},
|
||||||
|
graph));
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
|
||||||
|
IntPacket(800, Timestamp(4)),
|
||||||
|
IntPacket(700, Timestamp(5))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
MATCHER_P(EmptyPacket, ts, "") {
|
||||||
|
return arg.IsEmpty() && arg.Timestamp() == ts;
|
||||||
|
}
|
||||||
|
|
||||||
|
MATCHER_P2(Pair, m1, m2, "") {
|
||||||
|
const auto& p = arg.template Get<std::pair<Packet, Packet>>();
|
||||||
|
return testing::Matches(m1)(p.first) && testing::Matches(m2)(p.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest,
|
||||||
|
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacket) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input0"
|
||||||
|
input_stream: "input1"
|
||||||
|
input_stream: "select"
|
||||||
|
node {
|
||||||
|
calculator: "MuxCalculator"
|
||||||
|
input_stream: "INPUT:0:input0"
|
||||||
|
input_stream: "INPUT:1:input1"
|
||||||
|
input_stream: "SELECT:select"
|
||||||
|
output_stream: "OUTPUT:output"
|
||||||
|
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "MakePairCalculator"
|
||||||
|
input_stream: "select"
|
||||||
|
input_stream: "output"
|
||||||
|
output_stream: "pair"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> pair_packets;
|
||||||
|
tool::AddVectorSink("pair", &config, &pair_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(0).At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_TRUE(pair_packets.empty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input0", MakePacket<int>(1000).At(Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
|
||||||
|
EmptyPacket(Timestamp(1)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input1", MakePacket<int>(900).At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input1", MakePacket<int>(800).At(Timestamp(4))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
|
||||||
|
EmptyPacket(Timestamp(1)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(0).At(Timestamp(2))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
pair_packets,
|
||||||
|
ElementsAre(
|
||||||
|
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||||
|
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(1).At(Timestamp(3))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
pair_packets,
|
||||||
|
ElementsAre(
|
||||||
|
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||||
|
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
|
||||||
|
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(1).At(Timestamp(4))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
pair_packets,
|
||||||
|
ElementsAre(
|
||||||
|
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||||
|
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
|
||||||
|
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
|
||||||
|
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest,
|
||||||
|
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacketPacketsAtOnce) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input0"
|
||||||
|
input_stream: "input1"
|
||||||
|
input_stream: "select"
|
||||||
|
node {
|
||||||
|
calculator: "MuxCalculator"
|
||||||
|
input_stream: "INPUT:0:input0"
|
||||||
|
input_stream: "INPUT:1:input1"
|
||||||
|
input_stream: "SELECT:select"
|
||||||
|
output_stream: "OUTPUT:output"
|
||||||
|
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "MakePairCalculator"
|
||||||
|
input_stream: "select"
|
||||||
|
input_stream: "output"
|
||||||
|
output_stream: "pair"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> pair_packets;
|
||||||
|
tool::AddVectorSink("pair", &config, &pair_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(0).At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input0", MakePacket<int>(1000).At(Timestamp(2))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input1", MakePacket<int>(900).At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input1", MakePacket<int>(800).At(Timestamp(4))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(0).At(Timestamp(2))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(1).At(Timestamp(3))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(1).At(Timestamp(4))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
pair_packets,
|
||||||
|
ElementsAre(
|
||||||
|
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||||
|
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
|
||||||
|
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
|
||||||
|
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MuxInputStreamHandlerTest,
|
||||||
|
TimestampBoundUpdateTriggersTimestampBoundUpdate) {
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input0"
|
||||||
|
input_stream: "input1"
|
||||||
|
input_stream: "select"
|
||||||
|
input_stream: "allow0"
|
||||||
|
input_stream: "allow1"
|
||||||
|
node {
|
||||||
|
calculator: "GateCalculator"
|
||||||
|
input_stream: "ALLOW:allow0"
|
||||||
|
input_stream: "input0"
|
||||||
|
output_stream: "output0"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "GateCalculator"
|
||||||
|
input_stream: "ALLOW:allow1"
|
||||||
|
input_stream: "input1"
|
||||||
|
output_stream: "output1"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "MuxCalculator"
|
||||||
|
input_stream: "INPUT:0:output0"
|
||||||
|
input_stream: "INPUT:1:output1"
|
||||||
|
input_stream: "SELECT:select"
|
||||||
|
output_stream: "OUTPUT:output"
|
||||||
|
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "MakePairCalculator"
|
||||||
|
input_stream: "select"
|
||||||
|
input_stream: "output"
|
||||||
|
output_stream: "pair"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> pair_packets;
|
||||||
|
tool::AddVectorSink("pair", &config, &pair_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(0).At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input0", MakePacket<int>(1000).At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"allow0", MakePacket<bool>(false).At(Timestamp(1))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
|
||||||
|
EmptyPacket(Timestamp(1)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"select", MakePacket<int>(0).At(Timestamp(2))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input0", MakePacket<int>(900).At(Timestamp(2))));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"allow0", MakePacket<bool>(true).At(Timestamp(2))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
pair_packets,
|
||||||
|
ElementsAre(
|
||||||
|
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||||
|
Pair(IntPacket(0, Timestamp(2)), IntPacket(900, Timestamp(2)))));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -56,7 +56,9 @@ class SubgraphContext {
|
||||||
return options_map_.Get<T>();
|
return options_map_.Get<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
const CalculatorGraphConfig::Node& OriginalNode() { return original_node_; }
|
const CalculatorGraphConfig::Node& OriginalNode() const {
|
||||||
|
return original_node_;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ServiceBinding<T> Service(const GraphService<T>& service) const {
|
ServiceBinding<T> Service(const GraphService<T>& service) const {
|
||||||
|
|
|
@ -724,6 +724,7 @@ cc_test(
|
||||||
srcs = ["subgraph_expansion_test.cc"],
|
srcs = ["subgraph_expansion_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":node_chain_subgraph_cc_proto",
|
":node_chain_subgraph_cc_proto",
|
||||||
|
":node_chain_subgraph_options_lib",
|
||||||
":subgraph_expansion",
|
":subgraph_expansion",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
|
|
@ -23,6 +23,8 @@
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/type_map.h"
|
#include "mediapipe/framework/type_map.h"
|
||||||
|
|
||||||
|
#define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging()
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tool {
|
namespace tool {
|
||||||
|
|
||||||
|
@ -47,13 +49,13 @@ absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in,
|
||||||
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
|
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
|
||||||
if (IsLengthDelimited(wire_type)) {
|
if (IsLengthDelimited(wire_type)) {
|
||||||
uint32 length;
|
uint32 length;
|
||||||
RET_CHECK(in->ReadVarint32(&length));
|
RET_CHECK_NO_LOG(in->ReadVarint32(&length));
|
||||||
RET_CHECK(in->ReadString(result, length));
|
RET_CHECK_NO_LOG(in->ReadString(result, length));
|
||||||
} else {
|
} else {
|
||||||
std::string field_data;
|
std::string field_data;
|
||||||
StringOutputStream sos(&field_data);
|
StringOutputStream sos(&field_data);
|
||||||
CodedOutputStream cos(&sos);
|
CodedOutputStream cos(&sos);
|
||||||
RET_CHECK(WireFormatLite::SkipField(in, tag, &cos));
|
RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, &cos));
|
||||||
// Skip the tag written by SkipField.
|
// Skip the tag written by SkipField.
|
||||||
int tag_size = CodedOutputStream::VarintSize32(tag);
|
int tag_size = CodedOutputStream::VarintSize32(tag);
|
||||||
cos.Trim();
|
cos.Trim();
|
||||||
|
@ -67,13 +69,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
|
||||||
CodedInputStream* in,
|
CodedInputStream* in,
|
||||||
std::vector<std::string>* field_values) {
|
std::vector<std::string>* field_values) {
|
||||||
uint32 data_size;
|
uint32 data_size;
|
||||||
RET_CHECK(in->ReadVarint32(&data_size));
|
RET_CHECK_NO_LOG(in->ReadVarint32(&data_size));
|
||||||
// fake_tag encodes the wire-type for calls to WireFormatLite::SkipField.
|
// fake_tag encodes the wire-type for calls to WireFormatLite::SkipField.
|
||||||
uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type);
|
uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type);
|
||||||
while (data_size > 0) {
|
while (data_size > 0) {
|
||||||
std::string number;
|
std::string number;
|
||||||
MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number));
|
MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number));
|
||||||
RET_CHECK_LE(number.size(), data_size);
|
RET_CHECK_NO_LOG(number.size() <= data_size);
|
||||||
field_values->push_back(number);
|
field_values->push_back(number);
|
||||||
data_size -= number.size();
|
data_size -= number.size();
|
||||||
}
|
}
|
||||||
|
@ -98,7 +100,7 @@ absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
|
||||||
field_values->push_back(value);
|
field_values->push_back(value);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK(WireFormatLite::SkipField(in, tag, out));
|
RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, out));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -157,12 +159,12 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
|
||||||
MP_RETURN_IF_ERROR(access.SetMessage(*message));
|
MP_RETURN_IF_ERROR(access.SetMessage(*message));
|
||||||
std::vector<std::string>& v = *access.mutable_field_values();
|
std::vector<std::string>& v = *access.mutable_field_values();
|
||||||
if (!proto_path.empty()) {
|
if (!proto_path.empty()) {
|
||||||
RET_CHECK(index >= 0 && index < v.size());
|
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||||
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
|
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
|
||||||
field_type, field_values));
|
field_type, field_values));
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK(index >= 0 && index <= v.size());
|
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
|
||||||
RET_CHECK(index + length >= 0 && index + length <= v.size());
|
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
|
||||||
v.erase(v.begin() + index, v.begin() + index + length);
|
v.erase(v.begin() + index, v.begin() + index + length);
|
||||||
v.insert(v.begin() + index, field_values.begin(), field_values.end());
|
v.insert(v.begin() + index, field_values.begin(), field_values.end());
|
||||||
}
|
}
|
||||||
|
@ -184,12 +186,12 @@ absl::Status ProtoUtilLite::GetFieldRange(
|
||||||
MP_RETURN_IF_ERROR(access.SetMessage(message));
|
MP_RETURN_IF_ERROR(access.SetMessage(message));
|
||||||
std::vector<std::string>& v = *access.mutable_field_values();
|
std::vector<std::string>& v = *access.mutable_field_values();
|
||||||
if (!proto_path.empty()) {
|
if (!proto_path.empty()) {
|
||||||
RET_CHECK(index >= 0 && index < v.size());
|
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
GetFieldRange(v[index], proto_path, length, field_type, field_values));
|
GetFieldRange(v[index], proto_path, length, field_type, field_values));
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK(index >= 0 && index <= v.size());
|
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
|
||||||
RET_CHECK(index + length >= 0 && index + length <= v.size());
|
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
|
||||||
field_values->insert(field_values->begin(), v.begin() + index,
|
field_values->insert(field_values->begin(), v.begin() + index,
|
||||||
v.begin() + index + length);
|
v.begin() + index + length);
|
||||||
}
|
}
|
||||||
|
|
|
@ -274,12 +274,14 @@ absl::Status ConnectSubgraphStreams(
|
||||||
|
|
||||||
absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
|
absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
|
||||||
const GraphRegistry* graph_registry,
|
const GraphRegistry* graph_registry,
|
||||||
|
const Subgraph::SubgraphOptions* graph_options,
|
||||||
const GraphServiceManager* service_manager) {
|
const GraphServiceManager* service_manager) {
|
||||||
graph_registry =
|
graph_registry =
|
||||||
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
|
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
|
||||||
RET_CHECK(config);
|
RET_CHECK(config);
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(
|
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(
|
||||||
CalculatorGraphConfig::Node(), config));
|
graph_options ? *graph_options : CalculatorGraphConfig::Node(), config));
|
||||||
auto* nodes = config->mutable_node();
|
auto* nodes = config->mutable_node();
|
||||||
while (1) {
|
while (1) {
|
||||||
auto subgraph_nodes_start = std::stable_partition(
|
auto subgraph_nodes_start = std::stable_partition(
|
||||||
|
|
|
@ -72,6 +72,7 @@ absl::Status ConnectSubgraphStreams(
|
||||||
absl::Status ExpandSubgraphs(
|
absl::Status ExpandSubgraphs(
|
||||||
CalculatorGraphConfig* config,
|
CalculatorGraphConfig* config,
|
||||||
const GraphRegistry* graph_registry = nullptr,
|
const GraphRegistry* graph_registry = nullptr,
|
||||||
|
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||||
const GraphServiceManager* service_manager = nullptr);
|
const GraphServiceManager* service_manager = nullptr);
|
||||||
|
|
||||||
// Creates a graph wrapping the provided node and exposing all of its
|
// Creates a graph wrapping the provided node and exposing all of its
|
||||||
|
|
|
@ -560,9 +560,111 @@ TEST(SubgraphExpansionTest, GraphServicesUsage) {
|
||||||
MP_ASSERT_OK(service_manager.SetServiceObject(
|
MP_ASSERT_OK(service_manager.SetServiceObject(
|
||||||
kStringTestService, std::make_shared<std::string>("ExpectedNode")));
|
kStringTestService, std::make_shared<std::string>("ExpectedNode")));
|
||||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr,
|
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr,
|
||||||
|
/*graph_options=*/nullptr,
|
||||||
&service_manager));
|
&service_manager));
|
||||||
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shows SubgraphOptions consumed by GraphRegistry::CreateByName.
|
||||||
|
TEST(SubgraphExpansionTest, SubgraphOptionsUsage) {
|
||||||
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
|
||||||
|
GraphRegistry graph_registry;
|
||||||
|
|
||||||
|
// CalculatorGraph::Initialize passes the SubgraphOptions into:
|
||||||
|
// (1) GraphRegistry::CreateByName("NodeChainSubgraph", options)
|
||||||
|
// (2) tool::ExpandSubgraphs(&config, options)
|
||||||
|
auto graph_options =
|
||||||
|
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
|
||||||
|
options {
|
||||||
|
[mediapipe.NodeChainSubgraphOptions.ext] {
|
||||||
|
node_type: "DoubleIntCalculator"
|
||||||
|
chain_length: 3
|
||||||
|
}
|
||||||
|
})pb");
|
||||||
|
SubgraphContext context(&graph_options, /*service_manager=*/nullptr);
|
||||||
|
|
||||||
|
// "NodeChainSubgraph" consumes graph_options only in CreateByName.
|
||||||
|
auto subgraph_status =
|
||||||
|
graph_registry.CreateByName("", "NodeChainSubgraph", &context);
|
||||||
|
MP_ASSERT_OK(subgraph_status);
|
||||||
|
auto subgraph = std::move(subgraph_status).value();
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
tool::ExpandSubgraphs(&subgraph, &graph_registry, &graph_options));
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected_graph =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "DoubleIntCalculator"
|
||||||
|
input_stream: "stream_0"
|
||||||
|
output_stream: "stream_1"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "DoubleIntCalculator"
|
||||||
|
input_stream: "stream_1"
|
||||||
|
output_stream: "stream_2"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "DoubleIntCalculator"
|
||||||
|
input_stream: "stream_2"
|
||||||
|
output_stream: "stream_3"
|
||||||
|
}
|
||||||
|
input_stream: "INPUT:stream_0"
|
||||||
|
output_stream: "OUTPUT:stream_3"
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
EXPECT_THAT(subgraph, mediapipe::EqualsProto(expected_graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shows SubgraphOptions consumed by tool::ExpandSubgraphs.
|
||||||
|
TEST(SubgraphExpansionTest, SimpleSubgraphOptionsUsage) {
|
||||||
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
|
||||||
|
GraphRegistry graph_registry;
|
||||||
|
auto moon_options =
|
||||||
|
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
|
||||||
|
options {
|
||||||
|
[mediapipe.NodeChainSubgraphOptions.ext] {
|
||||||
|
node_type: "DoubleIntCalculator"
|
||||||
|
chain_length: 3
|
||||||
|
}
|
||||||
|
})pb");
|
||||||
|
auto moon_subgraph =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
type: "MoonSubgraph"
|
||||||
|
graph_options: {
|
||||||
|
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
|
||||||
|
}
|
||||||
|
node: {
|
||||||
|
calculator: "MoonCalculator"
|
||||||
|
node_options: {
|
||||||
|
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
|
||||||
|
}
|
||||||
|
option_value: "chain_length:options/chain_length"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
// The moon_options are copied into the graph_options of moon_subgraph.
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
tool::ExpandSubgraphs(&moon_subgraph, &graph_registry, &moon_options));
|
||||||
|
|
||||||
|
// The field chain_length is copied from moon_options into MoonCalculator.
|
||||||
|
CalculatorGraphConfig expected_graph =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "MoonCalculator"
|
||||||
|
node_options {
|
||||||
|
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {
|
||||||
|
chain_length: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
option_value: "chain_length:options/chain_length"
|
||||||
|
}
|
||||||
|
type: "MoonSubgraph"
|
||||||
|
graph_options {
|
||||||
|
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(moon_subgraph, mediapipe::EqualsProto(expected_graph));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -150,12 +150,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
|
||||||
const int packet_count = 10;
|
const int packet_count = 10;
|
||||||
// Send int value packets at {10K, 20K, 30K, ..., 100K}.
|
// Send int value packets at {10K, 20K, 30K, ..., 100K}.
|
||||||
for (uint64 t = 1; t <= packet_count; ++t) {
|
for (uint64 t = 1; t <= packet_count; ++t) {
|
||||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
|
||||||
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
|
||||||
if (send_bounds) {
|
if (send_bounds) {
|
||||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||||
"enable", MakePacket<bool>(true).At(Timestamp(t * 10000))));
|
"enable", MakePacket<bool>(true).At(Timestamp(t * 10000))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
}
|
}
|
||||||
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||||
|
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
||||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
// The inputs are sent to the input stream "foo", they should pass through.
|
// The inputs are sent to the input stream "foo", they should pass through.
|
||||||
EXPECT_EQ(out_foo.size(), t);
|
EXPECT_EQ(out_foo.size(), t);
|
||||||
|
@ -175,12 +176,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
|
||||||
|
|
||||||
// Send int value packets at {110K, 120K, ..., 200K}.
|
// Send int value packets at {110K, 120K, ..., 200K}.
|
||||||
for (uint64 t = 11; t <= packet_count * 2; ++t) {
|
for (uint64 t = 11; t <= packet_count * 2; ++t) {
|
||||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
|
||||||
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
|
||||||
if (send_bounds) {
|
if (send_bounds) {
|
||||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||||
"enable", MakePacket<bool>(false).At(Timestamp(t * 10000))));
|
"enable", MakePacket<bool>(false).At(Timestamp(t * 10000))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
}
|
}
|
||||||
|
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||||
|
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
||||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
// The inputs are sent to the input stream "foo", they should pass through.
|
// The inputs are sent to the input stream "foo", they should pass through.
|
||||||
EXPECT_EQ(out_foo.size(), t);
|
EXPECT_EQ(out_foo.size(), t);
|
||||||
|
|
|
@ -143,11 +143,12 @@ absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) {
|
||||||
absl::Status PerformBasicTransforms(
|
absl::Status PerformBasicTransforms(
|
||||||
const CalculatorGraphConfig& input_graph_config,
|
const CalculatorGraphConfig& input_graph_config,
|
||||||
const GraphRegistry* graph_registry,
|
const GraphRegistry* graph_registry,
|
||||||
|
const Subgraph::SubgraphOptions* graph_options,
|
||||||
const GraphServiceManager* service_manager,
|
const GraphServiceManager* service_manager,
|
||||||
CalculatorGraphConfig* output_graph_config) {
|
CalculatorGraphConfig* output_graph_config) {
|
||||||
*output_graph_config = input_graph_config;
|
*output_graph_config = input_graph_config;
|
||||||
MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
|
MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
|
||||||
service_manager));
|
graph_options, service_manager));
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));
|
MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));
|
||||||
|
|
||||||
|
@ -347,6 +348,7 @@ absl::Status NodeTypeInfo::Initialize(
|
||||||
absl::Status ValidatedGraphConfig::Initialize(
|
absl::Status ValidatedGraphConfig::Initialize(
|
||||||
const CalculatorGraphConfig& input_config,
|
const CalculatorGraphConfig& input_config,
|
||||||
const GraphRegistry* graph_registry,
|
const GraphRegistry* graph_registry,
|
||||||
|
const Subgraph::SubgraphOptions* graph_options,
|
||||||
const GraphServiceManager* service_manager) {
|
const GraphServiceManager* service_manager) {
|
||||||
RET_CHECK(!initialized_)
|
RET_CHECK(!initialized_)
|
||||||
<< "ValidatedGraphConfig can be initialized only once.";
|
<< "ValidatedGraphConfig can be initialized only once.";
|
||||||
|
@ -356,8 +358,8 @@ absl::Status ValidatedGraphConfig::Initialize(
|
||||||
<< input_config.DebugString();
|
<< input_config.DebugString();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(PerformBasicTransforms(input_config, graph_registry,
|
MP_RETURN_IF_ERROR(PerformBasicTransforms(
|
||||||
service_manager, &config_));
|
input_config, graph_registry, graph_options, service_manager, &config_));
|
||||||
|
|
||||||
// Initialize the basic node information.
|
// Initialize the basic node information.
|
||||||
MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
|
MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
|
||||||
|
@ -431,22 +433,24 @@ absl::Status ValidatedGraphConfig::Initialize(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ValidatedGraphConfig::Initialize(
|
absl::Status ValidatedGraphConfig::Initialize(
|
||||||
const std::string& graph_type, const Subgraph::SubgraphOptions* options,
|
const std::string& graph_type, const GraphRegistry* graph_registry,
|
||||||
const GraphRegistry* graph_registry,
|
const Subgraph::SubgraphOptions* graph_options,
|
||||||
const GraphServiceManager* service_manager) {
|
const GraphServiceManager* service_manager) {
|
||||||
graph_registry =
|
graph_registry =
|
||||||
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
|
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
|
||||||
SubgraphContext subgraph_context(options, service_manager);
|
SubgraphContext subgraph_context(graph_options, service_manager);
|
||||||
auto status_or_config =
|
auto status_or_config =
|
||||||
graph_registry->CreateByName("", graph_type, &subgraph_context);
|
graph_registry->CreateByName("", graph_type, &subgraph_context);
|
||||||
MP_RETURN_IF_ERROR(status_or_config.status());
|
MP_RETURN_IF_ERROR(status_or_config.status());
|
||||||
return Initialize(status_or_config.value(), graph_registry, service_manager);
|
return Initialize(status_or_config.value(), graph_registry, graph_options,
|
||||||
|
service_manager);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ValidatedGraphConfig::Initialize(
|
absl::Status ValidatedGraphConfig::Initialize(
|
||||||
const std::vector<CalculatorGraphConfig>& input_configs,
|
const std::vector<CalculatorGraphConfig>& input_configs,
|
||||||
const std::vector<CalculatorGraphTemplate>& input_templates,
|
const std::vector<CalculatorGraphTemplate>& input_templates,
|
||||||
const std::string& graph_type, const Subgraph::SubgraphOptions* arguments,
|
const std::string& graph_type,
|
||||||
|
const Subgraph::SubgraphOptions* graph_options,
|
||||||
const GraphServiceManager* service_manager) {
|
const GraphServiceManager* service_manager) {
|
||||||
GraphRegistry graph_registry;
|
GraphRegistry graph_registry;
|
||||||
for (auto& config : input_configs) {
|
for (auto& config : input_configs) {
|
||||||
|
@ -455,7 +459,8 @@ absl::Status ValidatedGraphConfig::Initialize(
|
||||||
for (auto& templ : input_templates) {
|
for (auto& templ : input_templates) {
|
||||||
graph_registry.Register(templ.config().type(), templ);
|
graph_registry.Register(templ.config().type(), templ);
|
||||||
}
|
}
|
||||||
return Initialize(graph_type, arguments, &graph_registry, service_manager);
|
return Initialize(graph_type, &graph_registry, graph_options,
|
||||||
|
service_manager);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {
|
absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {
|
||||||
|
|
|
@ -195,17 +195,20 @@ class ValidatedGraphConfig {
|
||||||
// Initializes the ValidatedGraphConfig. This function must be called
|
// Initializes the ValidatedGraphConfig. This function must be called
|
||||||
// before any other functions. Subgraphs are specified through the
|
// before any other functions. Subgraphs are specified through the
|
||||||
// global graph registry or an optional local graph registry.
|
// global graph registry or an optional local graph registry.
|
||||||
absl::Status Initialize(const CalculatorGraphConfig& input_config,
|
absl::Status Initialize(
|
||||||
|
const CalculatorGraphConfig& input_config,
|
||||||
const GraphRegistry* graph_registry = nullptr,
|
const GraphRegistry* graph_registry = nullptr,
|
||||||
|
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||||
const GraphServiceManager* service_manager = nullptr);
|
const GraphServiceManager* service_manager = nullptr);
|
||||||
|
|
||||||
// Initializes the ValidatedGraphConfig from registered graph and subgraph
|
// Initializes the ValidatedGraphConfig from registered graph and subgraph
|
||||||
// configs. Subgraphs are retrieved from the specified graph registry or from
|
// configs. Subgraphs are retrieved from the specified graph registry or from
|
||||||
// the global graph registry. A subgraph can be instantiated directly by
|
// the global graph registry. A subgraph can be instantiated directly by
|
||||||
// specifying its type in |graph_type|.
|
// specifying its type in |graph_type|.
|
||||||
absl::Status Initialize(const std::string& graph_type,
|
absl::Status Initialize(
|
||||||
const Subgraph::SubgraphOptions* options = nullptr,
|
const std::string& graph_type,
|
||||||
const GraphRegistry* graph_registry = nullptr,
|
const GraphRegistry* graph_registry = nullptr,
|
||||||
|
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||||
const GraphServiceManager* service_manager = nullptr);
|
const GraphServiceManager* service_manager = nullptr);
|
||||||
|
|
||||||
// Initializes the ValidatedGraphConfig from the specified graph and subgraph
|
// Initializes the ValidatedGraphConfig from the specified graph and subgraph
|
||||||
|
@ -218,7 +221,7 @@ class ValidatedGraphConfig {
|
||||||
const std::vector<CalculatorGraphConfig>& input_configs,
|
const std::vector<CalculatorGraphConfig>& input_configs,
|
||||||
const std::vector<CalculatorGraphTemplate>& input_templates,
|
const std::vector<CalculatorGraphTemplate>& input_templates,
|
||||||
const std::string& graph_type = "",
|
const std::string& graph_type = "",
|
||||||
const Subgraph::SubgraphOptions* arguments = nullptr,
|
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||||
const GraphServiceManager* service_manager = nullptr);
|
const GraphServiceManager* service_manager = nullptr);
|
||||||
|
|
||||||
// Returns true if the ValidatedGraphConfig has been initialized.
|
// Returns true if the ValidatedGraphConfig has been initialized.
|
||||||
|
|
|
@ -155,6 +155,7 @@ TEST(ValidatedGraphConfigTest, InitializeSubgraphWithServiceCalculatorB) {
|
||||||
kStringTestService, std::make_shared<std::string>(calculator_name)));
|
kStringTestService, std::make_shared<std::string>(calculator_name)));
|
||||||
MP_EXPECT_OK(config.Initialize(graph,
|
MP_EXPECT_OK(config.Initialize(graph,
|
||||||
/*graph_registry=*/nullptr,
|
/*graph_registry=*/nullptr,
|
||||||
|
/*subgraph_options=*/nullptr,
|
||||||
/*service_manager=*/&service_manager));
|
/*service_manager=*/&service_manager));
|
||||||
ASSERT_TRUE(config.Initialized());
|
ASSERT_TRUE(config.Initialized());
|
||||||
EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfigExpandedFromGraph(
|
EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfigExpandedFromGraph(
|
||||||
|
|
|
@ -196,7 +196,9 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":gl_base",
|
":gl_base",
|
||||||
":gl_context",
|
":gl_context",
|
||||||
|
":gl_texture_view",
|
||||||
":gpu_buffer_format",
|
":gpu_buffer_format",
|
||||||
|
":gpu_buffer_storage",
|
||||||
# TODO: remove this dependency. Some other teams' tests
|
# TODO: remove this dependency. Some other teams' tests
|
||||||
# depend on having an indirect image_frame dependency, need to be
|
# depend on having an indirect image_frame dependency, need to be
|
||||||
# fixed first.
|
# fixed first.
|
||||||
|
@ -205,6 +207,17 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gl_texture_view",
|
||||||
|
srcs = ["gl_texture_view.cc"],
|
||||||
|
hdrs = ["gl_texture_view.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":gl_base",
|
||||||
|
":gl_context",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gpu_buffer",
|
name = "gpu_buffer",
|
||||||
srcs = ["gpu_buffer.cc"],
|
srcs = ["gpu_buffer.cc"],
|
||||||
|
@ -214,12 +227,15 @@ cc_library(
|
||||||
":gl_base",
|
":gl_base",
|
||||||
":gl_context",
|
":gl_context",
|
||||||
":gpu_buffer_format",
|
":gpu_buffer_format",
|
||||||
|
":gpu_buffer_storage",
|
||||||
|
":gl_texture_view",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
":gl_texture_buffer",
|
":gl_texture_buffer",
|
||||||
],
|
],
|
||||||
"//mediapipe:ios": [
|
"//mediapipe:ios": [
|
||||||
|
":gpu_buffer_storage_cv_pixel_buffer",
|
||||||
"//mediapipe/objc:util",
|
"//mediapipe/objc:util",
|
||||||
"//mediapipe/objc:CFHolder",
|
"//mediapipe/objc:CFHolder",
|
||||||
],
|
],
|
||||||
|
@ -244,6 +260,35 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gpu_buffer_storage",
|
||||||
|
hdrs = ["gpu_buffer_storage.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":gl_base",
|
||||||
|
":gpu_buffer_format",
|
||||||
|
"//mediapipe/framework/deps:no_destructor",
|
||||||
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
"//mediapipe/framework/port:logging",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gpu_buffer_storage_cv_pixel_buffer",
|
||||||
|
srcs = ["gpu_buffer_storage_cv_pixel_buffer.cc"],
|
||||||
|
hdrs = ["gpu_buffer_storage_cv_pixel_buffer.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":gl_base",
|
||||||
|
":gl_context",
|
||||||
|
":gl_texture_view",
|
||||||
|
":gpu_buffer_storage",
|
||||||
|
"//mediapipe/objc:CFHolder",
|
||||||
|
"//mediapipe/objc:util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "gpu_origin_proto",
|
name = "gpu_origin_proto",
|
||||||
srcs = ["gpu_origin.proto"],
|
srcs = ["gpu_origin.proto"],
|
||||||
|
|
|
@ -109,12 +109,10 @@ GlTexture GlCalculatorHelper::CreateSourceTexture(
|
||||||
return impl_->CreateSourceTexture(image_frame);
|
return impl_->CreateSourceTexture(image_frame);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __APPLE__
|
|
||||||
GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer,
|
GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer,
|
||||||
int plane) {
|
int plane) {
|
||||||
return impl_->CreateSourceTexture(pixel_buffer, plane);
|
return impl_->CreateSourceTexture(pixel_buffer, plane);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer,
|
void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer,
|
||||||
int* width, int* height) {
|
int* width, int* height) {
|
||||||
|
|
|
@ -29,10 +29,6 @@
|
||||||
#include "mediapipe/gpu/gpu_buffer.h"
|
#include "mediapipe/gpu/gpu_buffer.h"
|
||||||
#include "mediapipe/gpu/graph_support.h"
|
#include "mediapipe/gpu/graph_support.h"
|
||||||
|
|
||||||
#ifdef __APPLE__
|
|
||||||
#include <CoreVideo/CoreVideo.h>
|
|
||||||
#endif // __APPLE__
|
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
class GlCalculatorHelperImpl;
|
class GlCalculatorHelperImpl;
|
||||||
|
@ -111,24 +107,35 @@ class GlCalculatorHelper {
|
||||||
// where it is supported (iOS, for now) they take advantage of memory sharing
|
// where it is supported (iOS, for now) they take advantage of memory sharing
|
||||||
// between the CPU and GPU, avoiding memory copies.
|
// between the CPU and GPU, avoiding memory copies.
|
||||||
|
|
||||||
// Creates a texture representing an input frame, and manages sync token.
|
// Gives access to an input frame as an OpenGL texture for reading (sampling).
|
||||||
|
//
|
||||||
|
// IMPORTANT: the returned GlTexture should be treated as a short-term view
|
||||||
|
// into the frame (typically for the duration of a Process call). Do not store
|
||||||
|
// it as a member in your calculator. If you need to keep a frame around,
|
||||||
|
// store the GpuBuffer instead, and call CreateSourceTexture again on each
|
||||||
|
// Process call.
|
||||||
|
//
|
||||||
|
// TODO: rename this; the use of "Create" makes this sound more expensive than
|
||||||
|
// it is.
|
||||||
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer);
|
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer);
|
||||||
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
|
|
||||||
GlTexture CreateSourceTexture(const mediapipe::Image& image);
|
GlTexture CreateSourceTexture(const mediapipe::Image& image);
|
||||||
|
|
||||||
#ifdef __APPLE__
|
// Gives read access to a plane of a planar buffer.
|
||||||
// Creates a texture from a plane of a planar buffer.
|
|
||||||
// The plane index is zero-based. The number of planes depends on the
|
// The plane index is zero-based. The number of planes depends on the
|
||||||
// internal format of the buffer.
|
// internal format of the buffer.
|
||||||
|
// Note: multi-plane support is not available on all platforms.
|
||||||
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer, int plane);
|
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer, int plane);
|
||||||
#endif
|
|
||||||
|
// Convenience function for converting an ImageFrame to GpuBuffer and then
|
||||||
|
// accessing it as a texture.
|
||||||
|
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
|
||||||
|
|
||||||
// Extracts GpuBuffer dimensions without creating a texture.
|
// Extracts GpuBuffer dimensions without creating a texture.
|
||||||
ABSL_DEPRECATED("Use width and height methods on GpuBuffer instead")
|
ABSL_DEPRECATED("Use width and height methods on GpuBuffer instead")
|
||||||
void GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width,
|
void GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width,
|
||||||
int* height);
|
int* height);
|
||||||
|
|
||||||
// Creates a texture representing an output frame, and manages sync token.
|
// Gives access to an OpenGL texture for writing (rendering) a new frame.
|
||||||
// TODO: This should either return errors or a status.
|
// TODO: This should either return errors or a status.
|
||||||
GlTexture CreateDestinationTexture(
|
GlTexture CreateDestinationTexture(
|
||||||
int output_width, int output_height,
|
int output_width, int output_height,
|
||||||
|
|
|
@ -62,10 +62,7 @@ class GlCalculatorHelperImpl {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Makes a GpuBuffer accessible as a texture in the GL context.
|
// Makes a GpuBuffer accessible as a texture in the GL context.
|
||||||
GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, int plane,
|
GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view);
|
||||||
bool for_reading);
|
|
||||||
void AttachGlTexture(GlTexture& texture, const GpuBuffer& gpu_buffer,
|
|
||||||
int plane, bool for_reading);
|
|
||||||
|
|
||||||
// Create the framebuffer for rendering.
|
// Create the framebuffer for rendering.
|
||||||
void CreateFramebuffer();
|
void CreateFramebuffer();
|
||||||
|
|
|
@ -91,9 +91,7 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) {
|
||||||
}
|
}
|
||||||
|
|
||||||
GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
|
GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
|
||||||
int plane, bool for_reading) {
|
GlTextureView view) {
|
||||||
GlTextureView view = gpu_buffer.GetGlTextureView(plane, for_reading);
|
|
||||||
|
|
||||||
if (gpu_buffer.format() != GpuBufferFormat::kUnknown) {
|
if (gpu_buffer.format() != GpuBufferFormat::kUnknown) {
|
||||||
// TODO: do the params need to be reset here??
|
// TODO: do the params need to be reset here??
|
||||||
glBindTexture(view.target(), view.name());
|
glBindTexture(view.target(), view.name());
|
||||||
|
@ -109,19 +107,18 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
|
||||||
|
|
||||||
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
||||||
const GpuBuffer& gpu_buffer) {
|
const GpuBuffer& gpu_buffer) {
|
||||||
return MapGpuBuffer(gpu_buffer, 0, true);
|
return CreateSourceTexture(gpu_buffer, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
||||||
const GpuBuffer& gpu_buffer, int plane) {
|
const GpuBuffer& gpu_buffer, int plane) {
|
||||||
return MapGpuBuffer(gpu_buffer, plane, true);
|
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(plane));
|
||||||
}
|
}
|
||||||
|
|
||||||
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
||||||
const ImageFrame& image_frame) {
|
const ImageFrame& image_frame) {
|
||||||
GlTexture texture =
|
auto gpu_buffer = GpuBuffer::CopyingImageFrame(image_frame);
|
||||||
MapGpuBuffer(GpuBuffer::CopyingImageFrame(image_frame), 0, true);
|
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(0));
|
||||||
return texture;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -150,11 +147,9 @@ GlTexture GlCalculatorHelperImpl::CreateDestinationTexture(
|
||||||
CreateFramebuffer();
|
CreateFramebuffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
GpuBuffer buffer =
|
GpuBuffer gpu_buffer =
|
||||||
gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format);
|
gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format);
|
||||||
GlTexture texture = MapGpuBuffer(buffer, 0, false);
|
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureWriteView(0));
|
||||||
|
|
||||||
return texture;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -14,6 +14,9 @@
|
||||||
|
|
||||||
#include "mediapipe/gpu/gl_texture_buffer.h"
|
#include "mediapipe/gpu/gl_texture_buffer.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/gpu/gl_texture_view.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Wrap(
|
std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Wrap(
|
||||||
|
@ -122,6 +125,15 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
|
||||||
|
|
||||||
if (alignment != 4 && data) glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
|
if (alignment != 4 && data) glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
|
||||||
|
|
||||||
|
// TODO: does this need to set the texture params? We set them again when the
|
||||||
|
// texture is actually acccessed via GlTexture[View]. Or should they always be
|
||||||
|
// set on creation?
|
||||||
|
if (format_ != GpuBufferFormat::kUnknown) {
|
||||||
|
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||||
|
format_, /*plane=*/0, context->GetGlVersion());
|
||||||
|
context->SetStandardTextureParams(target_, info.gl_internal_format);
|
||||||
|
}
|
||||||
|
|
||||||
glBindTexture(target_, 0);
|
glBindTexture(target_, 0);
|
||||||
|
|
||||||
// Use the deletion callback to delete the texture on the context
|
// Use the deletion callback to delete the texture on the context
|
||||||
|
@ -167,7 +179,7 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
|
||||||
producer_context_ = producer_sync_->GetContext();
|
producer_context_ = producer_sync_->GetContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) {
|
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
|
||||||
absl::MutexLock lock(&consumer_sync_mutex_);
|
absl::MutexLock lock(&consumer_sync_mutex_);
|
||||||
consumer_multi_sync_->Add(std::move(cons_token));
|
consumer_multi_sync_->Add(std::move(cons_token));
|
||||||
}
|
}
|
||||||
|
@ -181,7 +193,7 @@ GlTextureBuffer::~GlTextureBuffer() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlTextureBuffer::WaitUntilComplete() {
|
void GlTextureBuffer::WaitUntilComplete() const {
|
||||||
// Buffers created by the application (using the constructor that wraps an
|
// Buffers created by the application (using the constructor that wraps an
|
||||||
// existing texture) have no sync token and are assumed to be already
|
// existing texture) have no sync token and are assumed to be already
|
||||||
// complete.
|
// complete.
|
||||||
|
@ -190,7 +202,7 @@ void GlTextureBuffer::WaitUntilComplete() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlTextureBuffer::WaitOnGpu() {
|
void GlTextureBuffer::WaitOnGpu() const {
|
||||||
// Buffers created by the application (using the constructor that wraps an
|
// Buffers created by the application (using the constructor that wraps an
|
||||||
// existing texture) have no sync token and are assumed to be already
|
// existing texture) have no sync token and are assumed to be already
|
||||||
// complete.
|
// complete.
|
||||||
|
@ -212,4 +224,127 @@ void GlTextureBuffer::WaitForConsumersOnGpu() {
|
||||||
// precisely, on only one GL context.
|
// precisely, on only one GL context.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GlTextureView GlTextureBuffer::GetGlTextureReadView(
|
||||||
|
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
|
||||||
|
auto gl_context = GlContext::GetCurrent();
|
||||||
|
CHECK(gl_context);
|
||||||
|
CHECK_EQ(plane, 0);
|
||||||
|
// Insert wait call to sync with the producer.
|
||||||
|
WaitOnGpu();
|
||||||
|
GlTextureView::DetachFn detach = [this](mediapipe::GlTextureView& texture) {
|
||||||
|
// Inform the GlTextureBuffer that we have finished accessing its
|
||||||
|
// contents, and create a consumer sync point.
|
||||||
|
DidRead(texture.gl_context()->CreateSyncToken());
|
||||||
|
};
|
||||||
|
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
|
||||||
|
std::move(gpu_buffer), plane, std::move(detach),
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
GlTextureView GlTextureBuffer::GetGlTextureWriteView(
|
||||||
|
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
|
||||||
|
auto gl_context = GlContext::GetCurrent();
|
||||||
|
CHECK(gl_context);
|
||||||
|
CHECK_EQ(plane, 0);
|
||||||
|
// Insert wait call to sync with the producer.
|
||||||
|
WaitOnGpu();
|
||||||
|
Reuse(); // TODO: the producer wait should probably be part of Reuse in the
|
||||||
|
// case when there are no consumers.
|
||||||
|
GlTextureView::DoneWritingFn done_writing =
|
||||||
|
[this](const mediapipe::GlTextureView& texture) {
|
||||||
|
ViewDoneWriting(texture);
|
||||||
|
};
|
||||||
|
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
|
||||||
|
std::move(gpu_buffer), plane, nullptr,
|
||||||
|
std::move(done_writing));
|
||||||
|
}
|
||||||
|
|
||||||
|
void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) {
|
||||||
|
// Inform the GlTextureBuffer that we have produced new content, and create
|
||||||
|
// a producer sync point.
|
||||||
|
Updated(view.gl_context()->CreateSyncToken());
|
||||||
|
|
||||||
|
#ifdef __ANDROID__
|
||||||
|
// On (some?) Android devices, the texture may need to be explicitly
|
||||||
|
// detached from the current framebuffer.
|
||||||
|
// TODO: is this necessary even with the unbind in BindFramebuffer?
|
||||||
|
// It is not clear if this affected other contexts too, but let's keep it
|
||||||
|
// while in doubt.
|
||||||
|
GLint type = GL_NONE;
|
||||||
|
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||||
|
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
|
||||||
|
&type);
|
||||||
|
if (type == GL_TEXTURE) {
|
||||||
|
GLint color_attachment = 0;
|
||||||
|
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||||
|
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
||||||
|
&color_attachment);
|
||||||
|
if (color_attachment == name()) {
|
||||||
|
glBindFramebuffer(GL_FRAMEBUFFER, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some Android drivers log a GL_INVALID_ENUM error after the first
|
||||||
|
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
|
||||||
|
// even though it should be ok to ask for the type and get back GL_NONE.
|
||||||
|
// Let's just ignore any pending errors here.
|
||||||
|
GLenum error;
|
||||||
|
while ((error = glGetError()) != GL_NO_ERROR) {
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // __ANDROID__
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
|
||||||
|
void* output, size_t size) {
|
||||||
|
// TODO: check buffer size? We could use glReadnPixels where available
|
||||||
|
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
|
||||||
|
// won't overflow the buffer with glReadPixels, we'd also need to check or
|
||||||
|
// reset several glPixelStore parameters (e.g. what if someone had the
|
||||||
|
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
|
||||||
|
CHECK(view.gl_context());
|
||||||
|
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||||
|
format, view.plane(), view.gl_context()->GetGlVersion());
|
||||||
|
|
||||||
|
GLint current_fbo;
|
||||||
|
glGetIntegerv(GL_FRAMEBUFFER_BINDING, ¤t_fbo);
|
||||||
|
CHECK_NE(current_fbo, 0);
|
||||||
|
|
||||||
|
GLint color_attachment_name;
|
||||||
|
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||||
|
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
||||||
|
&color_attachment_name);
|
||||||
|
if (color_attachment_name != view.name()) {
|
||||||
|
// Save the viewport. Note that we assume that the color attachment is a
|
||||||
|
// GL_TEXTURE_2D texture.
|
||||||
|
GLint viewport[4];
|
||||||
|
glGetIntegerv(GL_VIEWPORT, viewport);
|
||||||
|
|
||||||
|
// Set the data from GLTextureView object.
|
||||||
|
glViewport(0, 0, view.width(), view.height());
|
||||||
|
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
|
||||||
|
view.name(), 0);
|
||||||
|
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
||||||
|
info.gl_type, output);
|
||||||
|
|
||||||
|
// Restore from the saved viewport and color attachment name.
|
||||||
|
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
|
||||||
|
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
|
||||||
|
color_attachment_name, 0);
|
||||||
|
} else {
|
||||||
|
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
||||||
|
info.gl_type, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ImageFrame> GlTextureBuffer::AsImageFrame() const {
|
||||||
|
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
|
||||||
|
auto output = absl::make_unique<ImageFrame>(
|
||||||
|
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
|
||||||
|
auto view = GetGlTextureReadView(nullptr, 0);
|
||||||
|
ReadTexture(view, format(), output->MutablePixelData(),
|
||||||
|
output->PixelDataSize());
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -25,13 +25,14 @@
|
||||||
#include "mediapipe/gpu/gl_base.h"
|
#include "mediapipe/gpu/gl_base.h"
|
||||||
#include "mediapipe/gpu/gl_context.h"
|
#include "mediapipe/gpu/gl_context.h"
|
||||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||||
|
#include "mediapipe/gpu/gpu_buffer_storage.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
class GlCalculatorHelperImpl;
|
class GlCalculatorHelperImpl;
|
||||||
|
|
||||||
// Implements a GPU memory buffer as an OpenGL texture. For internal use.
|
// Implements a GPU memory buffer as an OpenGL texture. For internal use.
|
||||||
class GlTextureBuffer {
|
class GlTextureBuffer : public mediapipe::internal::GpuBufferStorage {
|
||||||
public:
|
public:
|
||||||
// This is called when the texture buffer is deleted. It is passed a sync
|
// This is called when the texture buffer is deleted. It is passed a sync
|
||||||
// token created at that time on the GlContext. If the GlTextureBuffer has
|
// token created at that time on the GlContext. If the GlTextureBuffer has
|
||||||
|
@ -85,6 +86,13 @@ class GlTextureBuffer {
|
||||||
int height() const { return height_; }
|
int height() const { return height_; }
|
||||||
GpuBufferFormat format() const { return format_; }
|
GpuBufferFormat format() const { return format_; }
|
||||||
|
|
||||||
|
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||||
|
int plane) const override;
|
||||||
|
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||||
|
int plane) override;
|
||||||
|
void ViewDoneWriting(const GlTextureView& view) override;
|
||||||
|
std::unique_ptr<ImageFrame> AsImageFrame() const override;
|
||||||
|
|
||||||
// If this texture is going to be used outside of the context that produced
|
// If this texture is going to be used outside of the context that produced
|
||||||
// it, this method should be called to ensure that its updated contents are
|
// it, this method should be called to ensure that its updated contents are
|
||||||
// available. When this method returns, all changed made before the call to
|
// available. When this method returns, all changed made before the call to
|
||||||
|
@ -94,13 +102,13 @@ class GlTextureBuffer {
|
||||||
// NOTE: This blocks the current CPU thread and makes the changes visible
|
// NOTE: This blocks the current CPU thread and makes the changes visible
|
||||||
// to the CPU. If you want to access the data via OpenGL, use WaitOnGpu
|
// to the CPU. If you want to access the data via OpenGL, use WaitOnGpu
|
||||||
// instead.
|
// instead.
|
||||||
void WaitUntilComplete();
|
void WaitUntilComplete() const;
|
||||||
|
|
||||||
// Call this method to synchronize the current GL context with the texture's
|
// Call this method to synchronize the current GL context with the texture's
|
||||||
// producer. This will not block the current CPU thread, but will ensure that
|
// producer. This will not block the current CPU thread, but will ensure that
|
||||||
// subsequent GL commands see the texture in its complete status, with all
|
// subsequent GL commands see the texture in its complete status, with all
|
||||||
// rendering done on the GPU by the generating context.
|
// rendering done on the GPU by the generating context.
|
||||||
void WaitOnGpu();
|
void WaitOnGpu() const;
|
||||||
|
|
||||||
// Informs the buffer that its contents are going to be overwritten.
|
// Informs the buffer that its contents are going to be overwritten.
|
||||||
// This invalidates the current sync token.
|
// This invalidates the current sync token.
|
||||||
|
@ -114,7 +122,7 @@ class GlTextureBuffer {
|
||||||
void Updated(std::shared_ptr<GlSyncPoint> prod_token);
|
void Updated(std::shared_ptr<GlSyncPoint> prod_token);
|
||||||
|
|
||||||
// Informs the buffer that a consumer has finished reading from it.
|
// Informs the buffer that a consumer has finished reading from it.
|
||||||
void DidRead(std::shared_ptr<GlSyncPoint> cons_token);
|
void DidRead(std::shared_ptr<GlSyncPoint> cons_token) const;
|
||||||
|
|
||||||
// Waits for all pending consumers to finish accessing the current content
|
// Waits for all pending consumers to finish accessing the current content
|
||||||
// of the texture. This (preferably the OnGpu version) should be called
|
// of the texture. This (preferably the OnGpu version) should be called
|
||||||
|
@ -143,10 +151,11 @@ class GlTextureBuffer {
|
||||||
const GLenum target_ = GL_TEXTURE_2D;
|
const GLenum target_ = GL_TEXTURE_2D;
|
||||||
// Token tracking changes to this texture. Used by WaitUntilComplete.
|
// Token tracking changes to this texture. Used by WaitUntilComplete.
|
||||||
std::shared_ptr<GlSyncPoint> producer_sync_;
|
std::shared_ptr<GlSyncPoint> producer_sync_;
|
||||||
absl::Mutex consumer_sync_mutex_;
|
mutable absl::Mutex consumer_sync_mutex_;
|
||||||
// Tokens tracking the point when consumers finished using this texture.
|
// Tokens tracking the point when consumers finished using this texture.
|
||||||
std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_ ABSL_GUARDED_BY(
|
mutable std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_
|
||||||
consumer_sync_mutex_) = absl::make_unique<GlMultiSyncPoint>();
|
ABSL_GUARDED_BY(consumer_sync_mutex_) =
|
||||||
|
absl::make_unique<GlMultiSyncPoint>();
|
||||||
DeletionCallback deletion_callback_;
|
DeletionCallback deletion_callback_;
|
||||||
std::shared_ptr<GlContext> producer_context_;
|
std::shared_ptr<GlContext> producer_context_;
|
||||||
};
|
};
|
||||||
|
|
16
mediapipe/gpu/gl_texture_view.cc
Normal file
16
mediapipe/gpu/gl_texture_view.cc
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
#include "mediapipe/gpu/gl_texture_view.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
void GlTextureView::Release() {
|
||||||
|
if (detach_) detach_(*this);
|
||||||
|
detach_ = nullptr;
|
||||||
|
gl_context_ = nullptr;
|
||||||
|
gpu_buffer_ = nullptr;
|
||||||
|
plane_ = 0;
|
||||||
|
name_ = 0;
|
||||||
|
width_ = 0;
|
||||||
|
height_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
86
mediapipe/gpu/gl_texture_view.h
Normal file
86
mediapipe/gpu/gl_texture_view.h
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
// Copyright 2019 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
|
||||||
|
#define MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mediapipe/gpu/gl_base.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
class GlContext;
|
||||||
|
class GlTextureViewManager;
|
||||||
|
class GpuBuffer;
|
||||||
|
|
||||||
|
class GlTextureView {
|
||||||
|
public:
|
||||||
|
GlTextureView() {}
|
||||||
|
~GlTextureView() { Release(); }
|
||||||
|
// TODO: make this class move-only.
|
||||||
|
|
||||||
|
GlContext* gl_context() const { return gl_context_; }
|
||||||
|
int width() const { return width_; }
|
||||||
|
int height() const { return height_; }
|
||||||
|
GLenum target() const { return target_; }
|
||||||
|
GLuint name() const { return name_; }
|
||||||
|
const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; }
|
||||||
|
int plane() const { return plane_; }
|
||||||
|
|
||||||
|
using DetachFn = std::function<void(GlTextureView&)>;
|
||||||
|
using DoneWritingFn = std::function<void(const GlTextureView&)>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class GpuBuffer;
|
||||||
|
friend class GlTextureBuffer;
|
||||||
|
friend class GpuBufferStorageCvPixelBuffer;
|
||||||
|
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
|
||||||
|
int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
|
||||||
|
DetachFn detach, DoneWritingFn done_writing)
|
||||||
|
: gl_context_(context),
|
||||||
|
target_(target),
|
||||||
|
name_(name),
|
||||||
|
width_(width),
|
||||||
|
height_(height),
|
||||||
|
gpu_buffer_(std::move(gpu_buffer)),
|
||||||
|
plane_(plane),
|
||||||
|
detach_(std::move(detach)),
|
||||||
|
done_writing_(std::move(done_writing)) {}
|
||||||
|
|
||||||
|
// TODO: remove this friend declaration.
|
||||||
|
friend class GlTexture;
|
||||||
|
void Release();
|
||||||
|
// TODO: make this non-const.
|
||||||
|
void DoneWriting() const {
|
||||||
|
if (done_writing_) done_writing_(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
GlContext* gl_context_ = nullptr;
|
||||||
|
GLenum target_ = GL_TEXTURE_2D;
|
||||||
|
GLuint name_ = 0;
|
||||||
|
// Note: when scale is not 1, we still give the nominal size of the image.
|
||||||
|
int width_ = 0;
|
||||||
|
int height_ = 0;
|
||||||
|
std::shared_ptr<GpuBuffer> gpu_buffer_; // using shared_ptr temporarily
|
||||||
|
int plane_ = 0;
|
||||||
|
DetachFn detach_;
|
||||||
|
DoneWritingFn done_writing_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
|
|
@ -8,62 +8,7 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
void GlTextureView::Release() {
|
|
||||||
if (detach_) detach_(*this);
|
|
||||||
detach_ = nullptr;
|
|
||||||
gl_context_ = nullptr;
|
|
||||||
gpu_buffer_ = nullptr;
|
|
||||||
plane_ = 0;
|
|
||||||
name_ = 0;
|
|
||||||
width_ = 0;
|
|
||||||
height_ = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
#if TARGET_OS_OSX
|
|
||||||
typedef CVOpenGLTextureRef CVTextureType;
|
|
||||||
#else
|
|
||||||
typedef CVOpenGLESTextureRef CVTextureType;
|
|
||||||
#endif // TARGET_OS_OSX
|
|
||||||
|
|
||||||
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
|
|
||||||
CVReturn err;
|
|
||||||
auto gl_context = GlContext::GetCurrent();
|
|
||||||
CHECK(gl_context);
|
|
||||||
#if TARGET_OS_OSX
|
|
||||||
CVTextureType cv_texture_temp;
|
|
||||||
err = CVOpenGLTextureCacheCreateTextureFromImage(
|
|
||||||
kCFAllocatorDefault, gl_context->cv_texture_cache(),
|
|
||||||
GetCVPixelBufferRef(), NULL, &cv_texture_temp);
|
|
||||||
CHECK(cv_texture_temp && !err)
|
|
||||||
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
|
|
||||||
CFHolder<CVTextureType> cv_texture;
|
|
||||||
cv_texture.adopt(cv_texture_temp);
|
|
||||||
return GlTextureView(
|
|
||||||
gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
|
|
||||||
CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane,
|
|
||||||
[cv_texture](
|
|
||||||
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
|
|
||||||
#else
|
|
||||||
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
|
||||||
format(), plane, gl_context->GetGlVersion());
|
|
||||||
CVTextureType cv_texture_temp;
|
|
||||||
err = CVOpenGLESTextureCacheCreateTextureFromImage(
|
|
||||||
kCFAllocatorDefault, gl_context->cv_texture_cache(),
|
|
||||||
GetCVPixelBufferRef(), NULL, GL_TEXTURE_2D, info.gl_internal_format,
|
|
||||||
width() / info.downscale, height() / info.downscale, info.gl_format,
|
|
||||||
info.gl_type, plane, &cv_texture_temp);
|
|
||||||
CHECK(cv_texture_temp && !err)
|
|
||||||
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
|
|
||||||
CFHolder<CVTextureType> cv_texture;
|
|
||||||
cv_texture.adopt(cv_texture_temp);
|
|
||||||
return GlTextureView(
|
|
||||||
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
|
|
||||||
CVOpenGLESTextureGetName(*cv_texture), width(), height(), *this, plane,
|
|
||||||
[cv_texture](
|
|
||||||
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
|
|
||||||
#endif // TARGET_OS_OSX
|
|
||||||
}
|
|
||||||
|
|
||||||
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
||||||
auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame);
|
auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame);
|
||||||
|
@ -72,187 +17,11 @@ GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
||||||
CHECK_OK(maybe_buffer.status());
|
CHECK_OK(maybe_buffer.status());
|
||||||
return GpuBuffer(std::move(maybe_buffer).value());
|
return GpuBuffer(std::move(maybe_buffer).value());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
|
|
||||||
CHECK(GetCVPixelBufferRef());
|
|
||||||
return CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef());
|
|
||||||
}
|
|
||||||
|
|
||||||
void GlTextureView::DoneWriting() const {
|
|
||||||
CHECK(gpu_buffer_);
|
|
||||||
#if TARGET_IPHONE_SIMULATOR
|
|
||||||
CVPixelBufferRef pixel_buffer = gpu_buffer_.GetCVPixelBufferRef();
|
|
||||||
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
|
|
||||||
CHECK(err == kCVReturnSuccess)
|
|
||||||
<< "CVPixelBufferLockBaseAddress failed: " << err;
|
|
||||||
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
|
|
||||||
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
|
|
||||||
uint8_t* pixel_ptr =
|
|
||||||
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
|
|
||||||
if (pixel_format == kCVPixelFormatType_32BGRA) {
|
|
||||||
// TODO: restore previous framebuffer? Move this to helper so we
|
|
||||||
// can use BindFramebuffer?
|
|
||||||
glViewport(0, 0, width(), height());
|
|
||||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, target(),
|
|
||||||
name(), 0);
|
|
||||||
|
|
||||||
size_t contiguous_bytes_per_row = width() * 4;
|
|
||||||
if (bytes_per_row == contiguous_bytes_per_row) {
|
|
||||||
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
|
||||||
pixel_ptr);
|
|
||||||
} else {
|
|
||||||
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
|
|
||||||
height());
|
|
||||||
uint8_t* temp_ptr = contiguous_buffer.data();
|
|
||||||
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
|
||||||
temp_ptr);
|
|
||||||
for (int i = 0; i < height(); ++i) {
|
|
||||||
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
|
|
||||||
temp_ptr += contiguous_bytes_per_row;
|
|
||||||
pixel_ptr += bytes_per_row;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
|
|
||||||
}
|
|
||||||
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
|
|
||||||
CHECK(err == kCVReturnSuccess)
|
|
||||||
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
|
||||||
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
|
|
||||||
auto gl_context = GlContext::GetCurrent();
|
|
||||||
CHECK(gl_context);
|
|
||||||
const GlTextureBufferSharedPtr& texture_buffer =
|
|
||||||
GetGlTextureBufferSharedPtr();
|
|
||||||
// Insert wait call to sync with the producer.
|
|
||||||
texture_buffer->WaitOnGpu();
|
|
||||||
CHECK_EQ(plane, 0);
|
|
||||||
GlTextureView::DetachFn detach;
|
|
||||||
if (for_reading) {
|
|
||||||
detach = [](mediapipe::GlTextureView& texture) {
|
|
||||||
// Inform the GlTextureBuffer that we have finished accessing its
|
|
||||||
// contents, and create a consumer sync point.
|
|
||||||
texture.gpu_buffer().GetGlTextureBufferSharedPtr()->DidRead(
|
|
||||||
texture.gl_context()->CreateSyncToken());
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return GlTextureView(gl_context.get(), texture_buffer->target(),
|
|
||||||
texture_buffer->name(), width(), height(), *this, plane,
|
|
||||||
std::move(detach));
|
|
||||||
}
|
|
||||||
|
|
||||||
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
||||||
auto gl_context = GlContext::GetCurrent();
|
return GpuBuffer(GlTextureBuffer::Create(image_frame));
|
||||||
CHECK(gl_context);
|
|
||||||
|
|
||||||
auto buffer = GlTextureBuffer::Create(image_frame);
|
|
||||||
|
|
||||||
// TODO: does this need to set the texture params? We set them again when the
|
|
||||||
// texture is actually acccessed via GlTexture[View]. Or should they always be
|
|
||||||
// set on creation?
|
|
||||||
if (buffer->format() != GpuBufferFormat::kUnknown) {
|
|
||||||
glBindTexture(GL_TEXTURE_2D, buffer->name());
|
|
||||||
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
|
||||||
buffer->format(), /*plane=*/0, gl_context->GetGlVersion());
|
|
||||||
gl_context->SetStandardTextureParams(buffer->target(),
|
|
||||||
info.gl_internal_format);
|
|
||||||
glBindTexture(GL_TEXTURE_2D, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return GpuBuffer(std::move(buffer));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ReadTexture(const GlTextureView& view, void* output, size_t size) {
|
|
||||||
// TODO: check buffer size? We could use glReadnPixels where available
|
|
||||||
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
|
|
||||||
// won't overflow the buffer with glReadPixels, we'd also need to check or
|
|
||||||
// reset several glPixelStore parameters (e.g. what if someone had the
|
|
||||||
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
|
|
||||||
CHECK(view.gl_context());
|
|
||||||
GlTextureInfo info =
|
|
||||||
GlTextureInfoForGpuBufferFormat(view.gpu_buffer().format(), view.plane(),
|
|
||||||
view.gl_context()->GetGlVersion());
|
|
||||||
|
|
||||||
GLint current_fbo;
|
|
||||||
glGetIntegerv(GL_FRAMEBUFFER_BINDING, ¤t_fbo);
|
|
||||||
CHECK_NE(current_fbo, 0);
|
|
||||||
|
|
||||||
GLint color_attachment_name;
|
|
||||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
|
||||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
|
||||||
&color_attachment_name);
|
|
||||||
if (color_attachment_name != view.name()) {
|
|
||||||
// Save the viewport. Note that we assume that the color attachment is a
|
|
||||||
// GL_TEXTURE_2D texture.
|
|
||||||
GLint viewport[4];
|
|
||||||
glGetIntegerv(GL_VIEWPORT, viewport);
|
|
||||||
|
|
||||||
// Set the data from GLTextureView object.
|
|
||||||
glViewport(0, 0, view.width(), view.height());
|
|
||||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
|
|
||||||
view.name(), 0);
|
|
||||||
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
|
||||||
info.gl_type, output);
|
|
||||||
|
|
||||||
// Restore from the saved viewport and color attachment name.
|
|
||||||
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
|
|
||||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
|
|
||||||
color_attachment_name, 0);
|
|
||||||
} else {
|
|
||||||
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
|
||||||
info.gl_type, output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
|
|
||||||
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
|
|
||||||
auto output = absl::make_unique<ImageFrame>(
|
|
||||||
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
|
|
||||||
auto view = GetGlTextureView(0, true);
|
|
||||||
ReadTexture(view, output->MutablePixelData(), output->PixelDataSize());
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
void GlTextureView::DoneWriting() const {
|
|
||||||
CHECK(gpu_buffer_);
|
|
||||||
// Inform the GlTextureBuffer that we have produced new content, and create
|
|
||||||
// a producer sync point.
|
|
||||||
gpu_buffer_.GetGlTextureBufferSharedPtr()->Updated(
|
|
||||||
gl_context()->CreateSyncToken());
|
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
|
||||||
// On (some?) Android devices, the texture may need to be explicitly
|
|
||||||
// detached from the current framebuffer.
|
|
||||||
// TODO: is this necessary even with the unbind in BindFramebuffer?
|
|
||||||
// It is not clear if this affected other contexts too, but let's keep it
|
|
||||||
// while in doubt.
|
|
||||||
GLint type = GL_NONE;
|
|
||||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
|
||||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
|
|
||||||
&type);
|
|
||||||
if (type == GL_TEXTURE) {
|
|
||||||
GLint color_attachment = 0;
|
|
||||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
|
||||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
|
||||||
&color_attachment);
|
|
||||||
if (color_attachment == name()) {
|
|
||||||
glBindFramebuffer(GL_FRAMEBUFFER, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Some Android drivers log a GL_INVALID_ENUM error after the first
|
|
||||||
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
|
|
||||||
// even though it should be ok to ask for the type and get back GL_NONE.
|
|
||||||
// Let's just ignore any pending errors here.
|
|
||||||
GLenum error;
|
|
||||||
while ((error = glGetError()) != GL_NO_ERROR) {
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // __ANDROID__
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
|
|
@ -19,7 +19,9 @@
|
||||||
|
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
#include "mediapipe/gpu/gl_base.h"
|
#include "mediapipe/gpu/gl_base.h"
|
||||||
|
#include "mediapipe/gpu/gl_texture_view.h"
|
||||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||||
|
#include "mediapipe/gpu/gpu_buffer_storage.h"
|
||||||
|
|
||||||
#if defined(__APPLE__)
|
#if defined(__APPLE__)
|
||||||
#include <CoreVideo/CoreVideo.h>
|
#include <CoreVideo/CoreVideo.h>
|
||||||
|
@ -27,6 +29,10 @@
|
||||||
#include "mediapipe/objc/CFHolder.h"
|
#include "mediapipe/objc/CFHolder.h"
|
||||||
#endif // defined(__APPLE__)
|
#endif // defined(__APPLE__)
|
||||||
|
|
||||||
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
|
||||||
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
|
||||||
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
#include "mediapipe/gpu/gl_texture_buffer.h"
|
#include "mediapipe/gpu/gl_texture_buffer.h"
|
||||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
@ -34,7 +40,6 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
class GlContext;
|
class GlContext;
|
||||||
class GlTextureView;
|
|
||||||
|
|
||||||
// This class wraps a platform-specific buffer of GPU data.
|
// This class wraps a platform-specific buffer of GPU data.
|
||||||
// An instance of GpuBuffer acts as an opaque reference to the underlying
|
// An instance of GpuBuffer acts as an opaque reference to the underlying
|
||||||
|
@ -71,9 +76,9 @@ class GpuBuffer {
|
||||||
}
|
}
|
||||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
|
||||||
int width() const;
|
int width() const { return current_storage().width(); }
|
||||||
int height() const;
|
int height() const { return current_storage().height(); }
|
||||||
GpuBufferFormat format() const;
|
GpuBufferFormat format() const { return current_storage().format(); }
|
||||||
|
|
||||||
// Converts to true iff valid.
|
// Converts to true iff valid.
|
||||||
explicit operator bool() const { return operator!=(nullptr); }
|
explicit operator bool() const { return operator!=(nullptr); }
|
||||||
|
@ -88,8 +93,15 @@ class GpuBuffer {
|
||||||
// Allow assignment from nullptr.
|
// Allow assignment from nullptr.
|
||||||
GpuBuffer& operator=(std::nullptr_t other);
|
GpuBuffer& operator=(std::nullptr_t other);
|
||||||
|
|
||||||
// TODO: split into read and write, remove const from write.
|
GlTextureView GetGlTextureReadView(int plane) const {
|
||||||
GlTextureView GetGlTextureView(int plane, bool for_reading) const;
|
return current_storage().GetGlTextureReadView(
|
||||||
|
std::make_shared<GpuBuffer>(*this), plane);
|
||||||
|
}
|
||||||
|
|
||||||
|
GlTextureView GetGlTextureWriteView(int plane) {
|
||||||
|
return current_storage().GetGlTextureWriteView(
|
||||||
|
std::make_shared<GpuBuffer>(*this), plane);
|
||||||
|
}
|
||||||
|
|
||||||
// Make a GpuBuffer copying the data from an ImageFrame.
|
// Make a GpuBuffer copying the data from an ImageFrame.
|
||||||
static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame);
|
static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame);
|
||||||
|
@ -99,114 +111,84 @@ class GpuBuffer {
|
||||||
// In order to work correctly across platforms, callers should always treat
|
// In order to work correctly across platforms, callers should always treat
|
||||||
// the returned ImageFrame as if it shares memory with the GpuBuffer, i.e.
|
// the returned ImageFrame as if it shares memory with the GpuBuffer, i.e.
|
||||||
// treat it as immutable if the GpuBuffer must not be modified.
|
// treat it as immutable if the GpuBuffer must not be modified.
|
||||||
std::unique_ptr<ImageFrame> AsImageFrame() const;
|
std::unique_ptr<ImageFrame> AsImageFrame() const {
|
||||||
|
return current_storage().AsImageFrame();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
class PlaceholderGpuBufferStorage
|
||||||
|
: public mediapipe::internal::GpuBufferStorage {
|
||||||
|
public:
|
||||||
|
int width() const override { return 0; }
|
||||||
|
int height() const override { return 0; }
|
||||||
|
virtual GpuBufferFormat format() const override {
|
||||||
|
return GpuBufferFormat::kUnknown;
|
||||||
|
}
|
||||||
|
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||||
|
int plane) const override {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||||
|
int plane) override {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
void ViewDoneWriting(const GlTextureView& view) override{};
|
||||||
|
std::unique_ptr<ImageFrame> AsImageFrame() const override {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
mediapipe::internal::GpuBufferStorage& no_storage() const {
|
||||||
|
static PlaceholderGpuBufferStorage placeholder;
|
||||||
|
return placeholder;
|
||||||
|
}
|
||||||
|
|
||||||
|
const mediapipe::internal::GpuBufferStorage& current_storage() const {
|
||||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
CFHolder<CVPixelBufferRef> pixel_buffer_;
|
if (pixel_buffer_ != nullptr) return pixel_buffer_;
|
||||||
|
#else
|
||||||
|
if (texture_buffer_) return *texture_buffer_;
|
||||||
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
return no_storage();
|
||||||
|
}
|
||||||
|
|
||||||
|
mediapipe::internal::GpuBufferStorage& current_storage() {
|
||||||
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
if (pixel_buffer_ != nullptr) return pixel_buffer_;
|
||||||
|
#else
|
||||||
|
if (texture_buffer_) return *texture_buffer_;
|
||||||
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
return no_storage();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
GpuBufferStorageCvPixelBuffer pixel_buffer_;
|
||||||
#else
|
#else
|
||||||
GlTextureBufferSharedPtr texture_buffer_;
|
GlTextureBufferSharedPtr texture_buffer_;
|
||||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
};
|
};
|
||||||
|
|
||||||
class GlTextureView {
|
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
|
||||||
public:
|
return ¤t_storage() == &no_storage();
|
||||||
GlTextureView() {}
|
}
|
||||||
~GlTextureView() { Release(); }
|
|
||||||
// TODO: make this class move-only.
|
|
||||||
|
|
||||||
GlContext* gl_context() const { return gl_context_; }
|
|
||||||
int width() const { return width_; }
|
|
||||||
int height() const { return height_; }
|
|
||||||
GLenum target() const { return target_; }
|
|
||||||
GLuint name() const { return name_; }
|
|
||||||
const GpuBuffer& gpu_buffer() const { return gpu_buffer_; }
|
|
||||||
int plane() const { return plane_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class GpuBuffer;
|
|
||||||
using DetachFn = std::function<void(GlTextureView&)>;
|
|
||||||
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
|
|
||||||
int height, GpuBuffer gpu_buffer, int plane, DetachFn detach)
|
|
||||||
: gl_context_(context),
|
|
||||||
target_(target),
|
|
||||||
name_(name),
|
|
||||||
width_(width),
|
|
||||||
height_(height),
|
|
||||||
gpu_buffer_(std::move(gpu_buffer)),
|
|
||||||
plane_(plane),
|
|
||||||
detach_(std::move(detach)) {}
|
|
||||||
|
|
||||||
// TODO: remove this friend declaration.
|
|
||||||
friend class GlTexture;
|
|
||||||
void Release();
|
|
||||||
// TODO: make this non-const.
|
|
||||||
void DoneWriting() const;
|
|
||||||
|
|
||||||
GlContext* gl_context_ = nullptr;
|
|
||||||
GLenum target_ = GL_TEXTURE_2D;
|
|
||||||
GLuint name_ = 0;
|
|
||||||
// Note: when scale is not 1, we still give the nominal size of the image.
|
|
||||||
int width_ = 0;
|
|
||||||
int height_ = 0;
|
|
||||||
GpuBuffer gpu_buffer_;
|
|
||||||
int plane_ = 0;
|
|
||||||
DetachFn detach_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
|
||||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
|
||||||
inline int GpuBuffer::width() const {
|
|
||||||
return static_cast<int>(CVPixelBufferGetWidth(*pixel_buffer_));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline int GpuBuffer::height() const {
|
|
||||||
return static_cast<int>(CVPixelBufferGetHeight(*pixel_buffer_));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline GpuBufferFormat GpuBuffer::format() const {
|
|
||||||
return GpuBufferFormatForCVPixelFormat(
|
|
||||||
CVPixelBufferGetPixelFormatType(*pixel_buffer_));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
|
|
||||||
return pixel_buffer_ == other;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
|
|
||||||
return pixel_buffer_ == other.pixel_buffer_;
|
return pixel_buffer_ == other.pixel_buffer_;
|
||||||
}
|
|
||||||
|
|
||||||
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
|
|
||||||
pixel_buffer_.reset(other);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
inline int GpuBuffer::width() const { return texture_buffer_->width(); }
|
|
||||||
|
|
||||||
inline int GpuBuffer::height() const { return texture_buffer_->height(); }
|
|
||||||
|
|
||||||
inline GpuBufferFormat GpuBuffer::format() const {
|
|
||||||
return texture_buffer_->format();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
|
|
||||||
return texture_buffer_ == other;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
|
|
||||||
return texture_buffer_ == other.texture_buffer_;
|
return texture_buffer_ == other.texture_buffer_;
|
||||||
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
}
|
}
|
||||||
|
|
||||||
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
|
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
|
||||||
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
|
pixel_buffer_.reset(other);
|
||||||
|
#else
|
||||||
texture_buffer_ = other;
|
texture_buffer_ = other;
|
||||||
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_GPU_GPU_BUFFER_H_
|
#endif // MEDIAPIPE_GPU_GPU_BUFFER_H_
|
||||||
|
|
41
mediapipe/gpu/gpu_buffer_storage.h
Normal file
41
mediapipe/gpu/gpu_buffer_storage.h
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
|
||||||
|
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
class GlTextureView;
|
||||||
|
class GpuBuffer;
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
using mediapipe::GlTextureView;
|
||||||
|
using mediapipe::GpuBuffer;
|
||||||
|
using mediapipe::GpuBufferFormat;
|
||||||
|
|
||||||
|
class GlTextureViewManager {
|
||||||
|
public:
|
||||||
|
virtual ~GlTextureViewManager() = default;
|
||||||
|
virtual GlTextureView GetGlTextureReadView(
|
||||||
|
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const = 0;
|
||||||
|
virtual GlTextureView GetGlTextureWriteView(
|
||||||
|
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) = 0;
|
||||||
|
virtual void ViewDoneWriting(const GlTextureView& view) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class GpuBufferStorage : public GlTextureViewManager {
|
||||||
|
public:
|
||||||
|
virtual ~GpuBufferStorage() = default;
|
||||||
|
virtual int width() const = 0;
|
||||||
|
virtual int height() const = 0;
|
||||||
|
virtual GpuBufferFormat format() const = 0;
|
||||||
|
virtual std::unique_ptr<ImageFrame> AsImageFrame() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
|
116
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
Normal file
116
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
|
||||||
|
|
||||||
|
#include "mediapipe/gpu/gl_context.h"
|
||||||
|
#include "mediapipe/objc/util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
#if TARGET_OS_OSX
|
||||||
|
typedef CVOpenGLTextureRef CVTextureType;
|
||||||
|
#else
|
||||||
|
typedef CVOpenGLESTextureRef CVTextureType;
|
||||||
|
#endif // TARGET_OS_OSX
|
||||||
|
|
||||||
|
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureReadView(
|
||||||
|
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
|
||||||
|
CVReturn err;
|
||||||
|
auto gl_context = GlContext::GetCurrent();
|
||||||
|
CHECK(gl_context);
|
||||||
|
#if TARGET_OS_OSX
|
||||||
|
CVTextureType cv_texture_temp;
|
||||||
|
err = CVOpenGLTextureCacheCreateTextureFromImage(
|
||||||
|
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
|
||||||
|
&cv_texture_temp);
|
||||||
|
CHECK(cv_texture_temp && !err)
|
||||||
|
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
|
||||||
|
CFHolder<CVTextureType> cv_texture;
|
||||||
|
cv_texture.adopt(cv_texture_temp);
|
||||||
|
return GlTextureView(
|
||||||
|
gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
|
||||||
|
CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane,
|
||||||
|
[cv_texture](
|
||||||
|
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
|
||||||
|
#else
|
||||||
|
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||||
|
format(), plane, gl_context->GetGlVersion());
|
||||||
|
CVTextureType cv_texture_temp;
|
||||||
|
err = CVOpenGLESTextureCacheCreateTextureFromImage(
|
||||||
|
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
|
||||||
|
GL_TEXTURE_2D, info.gl_internal_format, width() / info.downscale,
|
||||||
|
height() / info.downscale, info.gl_format, info.gl_type, plane,
|
||||||
|
&cv_texture_temp);
|
||||||
|
CHECK(cv_texture_temp && !err)
|
||||||
|
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
|
||||||
|
CFHolder<CVTextureType> cv_texture;
|
||||||
|
cv_texture.adopt(cv_texture_temp);
|
||||||
|
return GlTextureView(
|
||||||
|
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
|
||||||
|
CVOpenGLESTextureGetName(*cv_texture), width(), height(),
|
||||||
|
std::move(gpu_buffer), plane,
|
||||||
|
[cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ },
|
||||||
|
// TODO: make GetGlTextureView for write view non-const, remove cast
|
||||||
|
// Note: we have to copy *this here because this storage is currently
|
||||||
|
// stored in GpuBuffer by value, and so the this pointer becomes invalid
|
||||||
|
// if the GpuBuffer is moved/copied. TODO: fix this.
|
||||||
|
[me = *this](const mediapipe::GlTextureView& view) {
|
||||||
|
const_cast<GpuBufferStorageCvPixelBuffer*>(&me)->ViewDoneWriting(view);
|
||||||
|
});
|
||||||
|
#endif // TARGET_OS_OSX
|
||||||
|
}
|
||||||
|
|
||||||
|
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureWriteView(
|
||||||
|
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
|
||||||
|
// For this storage there is currently no difference between read and write
|
||||||
|
// views, so we delegate to the read method.
|
||||||
|
return GetGlTextureReadView(std::move(gpu_buffer), plane);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) {
|
||||||
|
#if TARGET_IPHONE_SIMULATOR
|
||||||
|
CVPixelBufferRef pixel_buffer = **this;
|
||||||
|
CHECK(pixel_buffer);
|
||||||
|
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
|
||||||
|
CHECK(err == kCVReturnSuccess)
|
||||||
|
<< "CVPixelBufferLockBaseAddress failed: " << err;
|
||||||
|
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
|
||||||
|
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
|
||||||
|
uint8_t* pixel_ptr =
|
||||||
|
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
|
||||||
|
if (pixel_format == kCVPixelFormatType_32BGRA) {
|
||||||
|
// TODO: restore previous framebuffer? Move this to helper so we
|
||||||
|
// can use BindFramebuffer?
|
||||||
|
glViewport(0, 0, view.width(), view.height());
|
||||||
|
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
|
||||||
|
view.name(), 0);
|
||||||
|
|
||||||
|
size_t contiguous_bytes_per_row = view.width() * 4;
|
||||||
|
if (bytes_per_row == contiguous_bytes_per_row) {
|
||||||
|
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||||
|
pixel_ptr);
|
||||||
|
} else {
|
||||||
|
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
|
||||||
|
view.height());
|
||||||
|
uint8_t* temp_ptr = contiguous_buffer.data();
|
||||||
|
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||||
|
temp_ptr);
|
||||||
|
for (int i = 0; i < view.height(); ++i) {
|
||||||
|
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
|
||||||
|
temp_ptr += contiguous_bytes_per_row;
|
||||||
|
pixel_ptr += bytes_per_row;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
|
||||||
|
}
|
||||||
|
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
|
||||||
|
CHECK(err == kCVReturnSuccess)
|
||||||
|
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::AsImageFrame()
|
||||||
|
const {
|
||||||
|
return CreateImageFrameForCVPixelBuffer(**this);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
41
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
Normal file
41
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
|
||||||
|
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
|
||||||
|
|
||||||
|
#include <CoreVideo/CoreVideo.h>
|
||||||
|
|
||||||
|
#include "mediapipe/gpu/gl_texture_view.h"
|
||||||
|
#include "mediapipe/gpu/gpu_buffer_storage.h"
|
||||||
|
#include "mediapipe/objc/CFHolder.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
class GlContext;
|
||||||
|
|
||||||
|
class GpuBufferStorageCvPixelBuffer
|
||||||
|
: public mediapipe::internal::GpuBufferStorage,
|
||||||
|
public CFHolder<CVPixelBufferRef> {
|
||||||
|
public:
|
||||||
|
using CFHolder<CVPixelBufferRef>::CFHolder;
|
||||||
|
GpuBufferStorageCvPixelBuffer(const CFHolder<CVPixelBufferRef>& other)
|
||||||
|
: CFHolder(other) {}
|
||||||
|
GpuBufferStorageCvPixelBuffer(CFHolder<CVPixelBufferRef>&& other)
|
||||||
|
: CFHolder(std::move(other)) {}
|
||||||
|
int width() const { return static_cast<int>(CVPixelBufferGetWidth(**this)); }
|
||||||
|
int height() const {
|
||||||
|
return static_cast<int>(CVPixelBufferGetHeight(**this));
|
||||||
|
}
|
||||||
|
virtual GpuBufferFormat format() const {
|
||||||
|
return GpuBufferFormatForCVPixelFormat(
|
||||||
|
CVPixelBufferGetPixelFormatType(**this));
|
||||||
|
}
|
||||||
|
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||||
|
int plane) const override;
|
||||||
|
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||||
|
int plane) override;
|
||||||
|
std::unique_ptr<ImageFrame> AsImageFrame() const override;
|
||||||
|
void ViewDoneWriting(const GlTextureView& view) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
|
|
@ -64,6 +64,13 @@ public class AppTextureFrame implements TextureFrame {
|
||||||
return timestamp;
|
return timestamp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns true if a call to waitUntilReleased() would block waiting for release. */
|
||||||
|
public boolean isNotYetReleased() {
|
||||||
|
synchronized (this) {
|
||||||
|
return inUse && releaseSyncToken == null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Waits until the consumer is done with the texture.
|
* Waits until the consumer is done with the texture.
|
||||||
*
|
*
|
||||||
|
|
|
@ -26,7 +26,8 @@ android_library(
|
||||||
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_gpu_image.binarypb",
|
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_gpu_image.binarypb",
|
||||||
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_image.binarypb",
|
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_image.binarypb",
|
||||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
"//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
|
||||||
|
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||||
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
||||||
],
|
],
|
||||||
assets_dir = "",
|
assets_dir = "",
|
||||||
|
|
|
@ -78,6 +78,7 @@ public class Hands extends ImageSolutionBase {
|
||||||
Connection.create(HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP),
|
Connection.create(HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP),
|
||||||
Connection.create(HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP));
|
Connection.create(HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP));
|
||||||
|
|
||||||
|
private static final String MODEL_COMPLEXITY = "model_complexity";
|
||||||
private static final String NUM_HANDS = "num_hands";
|
private static final String NUM_HANDS = "num_hands";
|
||||||
private static final String USE_PREV_LANDMARKS = "use_prev_landmarks";
|
private static final String USE_PREV_LANDMARKS = "use_prev_landmarks";
|
||||||
private static final String GPU_GRAPH_NAME = "hand_landmark_tracking_gpu_image.binarypb";
|
private static final String GPU_GRAPH_NAME = "hand_landmark_tracking_gpu_image.binarypb";
|
||||||
|
@ -131,6 +132,7 @@ public class Hands extends ImageSolutionBase {
|
||||||
initialize(context, solutionInfo, outputHandler);
|
initialize(context, solutionInfo, outputHandler);
|
||||||
Map<String, Packet> inputSidePackets = new HashMap<>();
|
Map<String, Packet> inputSidePackets = new HashMap<>();
|
||||||
inputSidePackets.put(NUM_HANDS, packetCreator.createInt32(options.maxNumHands()));
|
inputSidePackets.put(NUM_HANDS, packetCreator.createInt32(options.maxNumHands()));
|
||||||
|
inputSidePackets.put(MODEL_COMPLEXITY, packetCreator.createInt32(options.modelComplexity()));
|
||||||
inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode()));
|
inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode()));
|
||||||
start(inputSidePackets);
|
start(inputSidePackets);
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,10 @@ import com.google.auto.value.AutoValue;
|
||||||
* <p>maxNumHands: Maximum number of hands to detect. See details in
|
* <p>maxNumHands: Maximum number of hands to detect. See details in
|
||||||
* https://solutions.mediapipe.dev/hands#max_num_hands.
|
* https://solutions.mediapipe.dev/hands#max_num_hands.
|
||||||
*
|
*
|
||||||
|
* <p>modelComplexity: Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
* inference latency generally go up with the model complexity. See details in
|
||||||
|
* https://solutions.mediapipe.dev/hands#model_complexity.
|
||||||
|
*
|
||||||
* <p>minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for hand detection to be
|
* <p>minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for hand detection to be
|
||||||
* considered successful. See details in
|
* considered successful. See details in
|
||||||
* https://solutions.mediapipe.dev/hands#min_detection_confidence.
|
* https://solutions.mediapipe.dev/hands#min_detection_confidence.
|
||||||
|
@ -43,6 +47,8 @@ public abstract class HandsOptions {
|
||||||
|
|
||||||
public abstract int maxNumHands();
|
public abstract int maxNumHands();
|
||||||
|
|
||||||
|
public abstract int modelComplexity();
|
||||||
|
|
||||||
public abstract float minDetectionConfidence();
|
public abstract float minDetectionConfidence();
|
||||||
|
|
||||||
public abstract float minTrackingConfidence();
|
public abstract float minTrackingConfidence();
|
||||||
|
@ -59,6 +65,7 @@ public abstract class HandsOptions {
|
||||||
public Builder withDefaultValues() {
|
public Builder withDefaultValues() {
|
||||||
return setStaticImageMode(false)
|
return setStaticImageMode(false)
|
||||||
.setMaxNumHands(2)
|
.setMaxNumHands(2)
|
||||||
|
.setModelComplexity(1)
|
||||||
.setMinDetectionConfidence(0.5f)
|
.setMinDetectionConfidence(0.5f)
|
||||||
.setMinTrackingConfidence(0.5f)
|
.setMinTrackingConfidence(0.5f)
|
||||||
.setRunOnGpu(true);
|
.setRunOnGpu(true);
|
||||||
|
@ -68,6 +75,8 @@ public abstract class HandsOptions {
|
||||||
|
|
||||||
public abstract Builder setMaxNumHands(int value);
|
public abstract Builder setMaxNumHands(int value);
|
||||||
|
|
||||||
|
public abstract Builder setModelComplexity(int value);
|
||||||
|
|
||||||
public abstract Builder setMinDetectionConfidence(float value);
|
public abstract Builder setMinDetectionConfidence(float value);
|
||||||
|
|
||||||
public abstract Builder setMinTrackingConfidence(float value);
|
public abstract Builder setMinTrackingConfidence(float value);
|
||||||
|
|
|
@ -22,16 +22,30 @@ licenses(["notice"])
|
||||||
package(default_visibility = ["//visibility:public"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
exports_files([
|
exports_files([
|
||||||
"hand_landmark.tflite",
|
"hand_landmark_full.tflite",
|
||||||
|
"hand_landmark_lite.tflite",
|
||||||
"hand_landmark_sparse.tflite",
|
"hand_landmark_sparse.tflite",
|
||||||
"handedness.txt",
|
"handedness.txt",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
mediapipe_simple_subgraph(
|
||||||
|
name = "hand_landmark_model_loader",
|
||||||
|
graph = "hand_landmark_model_loader.pbtxt",
|
||||||
|
register_as = "HandLandmarkModelLoader",
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||||
|
"//mediapipe/calculators/tflite:tflite_model_calculator",
|
||||||
|
"//mediapipe/calculators/util:local_file_contents_calculator",
|
||||||
|
"//mediapipe/framework/tool:switch_container",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_simple_subgraph(
|
mediapipe_simple_subgraph(
|
||||||
name = "hand_landmark_cpu",
|
name = "hand_landmark_cpu",
|
||||||
graph = "hand_landmark_cpu.pbtxt",
|
graph = "hand_landmark_cpu.pbtxt",
|
||||||
register_as = "HandLandmarkCpu",
|
register_as = "HandLandmarkCpu",
|
||||||
deps = [
|
deps = [
|
||||||
|
":hand_landmark_model_loader",
|
||||||
"//mediapipe/calculators/core:gate_calculator",
|
"//mediapipe/calculators/core:gate_calculator",
|
||||||
"//mediapipe/calculators/core:split_vector_calculator",
|
"//mediapipe/calculators/core:split_vector_calculator",
|
||||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||||
|
@ -50,6 +64,7 @@ mediapipe_simple_subgraph(
|
||||||
graph = "hand_landmark_gpu.pbtxt",
|
graph = "hand_landmark_gpu.pbtxt",
|
||||||
register_as = "HandLandmarkGpu",
|
register_as = "HandLandmarkGpu",
|
||||||
deps = [
|
deps = [
|
||||||
|
":hand_landmark_model_loader",
|
||||||
"//mediapipe/calculators/core:gate_calculator",
|
"//mediapipe/calculators/core:gate_calculator",
|
||||||
"//mediapipe/calculators/core:split_vector_calculator",
|
"//mediapipe/calculators/core:split_vector_calculator",
|
||||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||||
|
|
Binary file not shown.
|
@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
|
||||||
# (NormalizedRect)
|
# (NormalizedRect)
|
||||||
input_stream: "ROI:hand_rect"
|
input_stream: "ROI:hand_rect"
|
||||||
|
|
||||||
|
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
# inference latency generally go up with the model complexity. If unspecified,
|
||||||
|
# functions as set to 1. (int)
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
|
||||||
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
|
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
|
||||||
# NOTE: if a hand is not present within the given ROI, for this particular
|
# NOTE: if a hand is not present within the given ROI, for this particular
|
||||||
# timestamp there will not be an output packet in the LANDMARKS stream. However,
|
# timestamp there will not be an output packet in the LANDMARKS stream. However,
|
||||||
|
@ -40,16 +45,23 @@ node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Loads the hand landmark TF Lite model.
|
||||||
|
node {
|
||||||
|
calculator: "HandLandmarkModelLoader"
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
output_side_packet: "MODEL:model"
|
||||||
|
}
|
||||||
|
|
||||||
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
|
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
|
||||||
# vector of tensors representing, for instance, detection boxes/keypoints and
|
# vector of tensors representing, for instance, detection boxes/keypoints and
|
||||||
# scores.
|
# scores.
|
||||||
node {
|
node {
|
||||||
calculator: "InferenceCalculator"
|
calculator: "InferenceCalculator"
|
||||||
|
input_side_packet: "MODEL:model"
|
||||||
input_stream: "TENSORS:input_tensor"
|
input_stream: "TENSORS:input_tensor"
|
||||||
output_stream: "TENSORS:output_tensors"
|
output_stream: "TENSORS:output_tensors"
|
||||||
options: {
|
options: {
|
||||||
[mediapipe.InferenceCalculatorOptions.ext] {
|
[mediapipe.InferenceCalculatorOptions.ext] {
|
||||||
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
|
||||||
delegate {
|
delegate {
|
||||||
xnnpack {}
|
xnnpack {}
|
||||||
}
|
}
|
||||||
|
|
BIN
mediapipe/modules/hand_landmark/hand_landmark_full.tflite
Executable file
BIN
mediapipe/modules/hand_landmark/hand_landmark_full.tflite
Executable file
Binary file not shown.
|
@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
|
||||||
# (NormalizedRect)
|
# (NormalizedRect)
|
||||||
input_stream: "ROI:hand_rect"
|
input_stream: "ROI:hand_rect"
|
||||||
|
|
||||||
|
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
# inference latency generally go up with the model complexity. If unspecified,
|
||||||
|
# functions as set to 1. (int)
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
|
||||||
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
|
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
|
||||||
# NOTE: if a hand is not present within the given ROI, for this particular
|
# NOTE: if a hand is not present within the given ROI, for this particular
|
||||||
# timestamp there will not be an output packet in the LANDMARKS stream. However,
|
# timestamp there will not be an output packet in the LANDMARKS stream. However,
|
||||||
|
@ -41,18 +46,21 @@ node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Loads the hand landmark TF Lite model.
|
||||||
|
node {
|
||||||
|
calculator: "HandLandmarkModelLoader"
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
output_side_packet: "MODEL:model"
|
||||||
|
}
|
||||||
|
|
||||||
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
|
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
|
||||||
# vector of tensors representing, for instance, detection boxes/keypoints and
|
# vector of tensors representing, for instance, detection boxes/keypoints and
|
||||||
# scores.
|
# scores.
|
||||||
node {
|
node {
|
||||||
calculator: "InferenceCalculator"
|
calculator: "InferenceCalculator"
|
||||||
|
input_side_packet: "MODEL:model"
|
||||||
input_stream: "TENSORS:input_tensor"
|
input_stream: "TENSORS:input_tensor"
|
||||||
output_stream: "TENSORS:output_tensors"
|
output_stream: "TENSORS:output_tensors"
|
||||||
options: {
|
|
||||||
[mediapipe.InferenceCalculatorOptions.ext] {
|
|
||||||
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Splits a vector of tensors to multiple vectors according to the ranges
|
# Splits a vector of tensors to multiple vectors according to the ranges
|
||||||
|
|
BIN
mediapipe/modules/hand_landmark/hand_landmark_lite.tflite
Executable file
BIN
mediapipe/modules/hand_landmark/hand_landmark_lite.tflite
Executable file
Binary file not shown.
|
@ -0,0 +1,63 @@
|
||||||
|
# MediaPipe graph to load a selected hand landmark TF Lite model.
|
||||||
|
|
||||||
|
type: "HandLandmarkModelLoader"
|
||||||
|
|
||||||
|
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
# inference latency generally go up with the model complexity. If unspecified,
|
||||||
|
# functions as set to 1. (int)
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
|
||||||
|
# TF Lite model represented as a FlatBuffer.
|
||||||
|
# (std::unique_ptr<tflite::FlatBufferModel, std::function<void(tflite::FlatBufferModel*)>>)
|
||||||
|
output_side_packet: "MODEL:model"
|
||||||
|
|
||||||
|
# Determines path to the desired pose landmark model file.
|
||||||
|
node {
|
||||||
|
calculator: "SwitchContainer"
|
||||||
|
input_side_packet: "SELECT:model_complexity"
|
||||||
|
output_side_packet: "PACKET:model_path"
|
||||||
|
options: {
|
||||||
|
[mediapipe.SwitchContainerOptions.ext] {
|
||||||
|
select: 1
|
||||||
|
contained_node: {
|
||||||
|
calculator: "ConstantSidePacketCalculator"
|
||||||
|
options: {
|
||||||
|
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||||
|
packet {
|
||||||
|
string_value: "mediapipe/modules/hand_landmark/hand_landmark_lite.tflite"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contained_node: {
|
||||||
|
calculator: "ConstantSidePacketCalculator"
|
||||||
|
options: {
|
||||||
|
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||||
|
packet {
|
||||||
|
string_value: "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Loads the file in the specified path into a blob.
|
||||||
|
node {
|
||||||
|
calculator: "LocalFileContentsCalculator"
|
||||||
|
input_side_packet: "FILE_PATH:model_path"
|
||||||
|
output_side_packet: "CONTENTS:model_blob"
|
||||||
|
options: {
|
||||||
|
[mediapipe.LocalFileContentsCalculatorOptions.ext]: {
|
||||||
|
text_mode: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Converts the input blob into a TF Lite model.
|
||||||
|
node {
|
||||||
|
calculator: "TfLiteModelCalculator"
|
||||||
|
input_side_packet: "MODEL_BLOB:model_blob"
|
||||||
|
output_side_packet: "MODEL:model"
|
||||||
|
}
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
||||||
# Max number of hands to detect/track. (int)
|
# Max number of hands to detect/track. (int)
|
||||||
input_side_packet: "NUM_HANDS:num_hands"
|
input_side_packet: "NUM_HANDS:num_hands"
|
||||||
|
|
||||||
|
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
# inference latency generally go up with the model complexity. If unspecified,
|
||||||
|
# functions as set to 1. (int)
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
|
||||||
# Whether landmarks on the previous image should be used to help localize
|
# Whether landmarks on the previous image should be used to help localize
|
||||||
# landmarks on the current image. (bool)
|
# landmarks on the current image. (bool)
|
||||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||||
|
@ -177,6 +182,7 @@ node {
|
||||||
# Detect hand landmarks for the specific hand rect.
|
# Detect hand landmarks for the specific hand rect.
|
||||||
node {
|
node {
|
||||||
calculator: "HandLandmarkCpu"
|
calculator: "HandLandmarkCpu"
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
input_stream: "IMAGE:image_for_landmarks"
|
input_stream: "IMAGE:image_for_landmarks"
|
||||||
input_stream: "ROI:single_hand_rect"
|
input_stream: "ROI:single_hand_rect"
|
||||||
output_stream: "LANDMARKS:single_hand_landmarks"
|
output_stream: "LANDMARKS:single_hand_landmarks"
|
||||||
|
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
||||||
# Max number of hands to detect/track. (int)
|
# Max number of hands to detect/track. (int)
|
||||||
input_side_packet: "NUM_HANDS:num_hands"
|
input_side_packet: "NUM_HANDS:num_hands"
|
||||||
|
|
||||||
|
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
# inference latency generally go up with the model complexity. If unspecified,
|
||||||
|
# functions as set to 1. (int)
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
|
||||||
# Whether landmarks on the previous image should be used to help localize
|
# Whether landmarks on the previous image should be used to help localize
|
||||||
# landmarks on the current image. (bool)
|
# landmarks on the current image. (bool)
|
||||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||||
|
@ -85,6 +90,7 @@ node {
|
||||||
calculator: "HandLandmarkTrackingCpu"
|
calculator: "HandLandmarkTrackingCpu"
|
||||||
input_stream: "IMAGE:image_frame"
|
input_stream: "IMAGE:image_frame"
|
||||||
input_side_packet: "NUM_HANDS:num_hands"
|
input_side_packet: "NUM_HANDS:num_hands"
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||||
output_stream: "LANDMARKS:multi_hand_landmarks"
|
output_stream: "LANDMARKS:multi_hand_landmarks"
|
||||||
output_stream: "HANDEDNESS:multi_handedness"
|
output_stream: "HANDEDNESS:multi_handedness"
|
||||||
|
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
||||||
# Max number of hands to detect/track. (int)
|
# Max number of hands to detect/track. (int)
|
||||||
input_side_packet: "NUM_HANDS:num_hands"
|
input_side_packet: "NUM_HANDS:num_hands"
|
||||||
|
|
||||||
|
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
# inference latency generally go up with the model complexity. If unspecified,
|
||||||
|
# functions as set to 1. (int)
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
|
||||||
# Whether landmarks on the previous image should be used to help localize
|
# Whether landmarks on the previous image should be used to help localize
|
||||||
# landmarks on the current image. (bool)
|
# landmarks on the current image. (bool)
|
||||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||||
|
@ -178,6 +183,7 @@ node {
|
||||||
# Detect hand landmarks for the specific hand rect.
|
# Detect hand landmarks for the specific hand rect.
|
||||||
node {
|
node {
|
||||||
calculator: "HandLandmarkGpu"
|
calculator: "HandLandmarkGpu"
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
input_stream: "IMAGE:image_for_landmarks"
|
input_stream: "IMAGE:image_for_landmarks"
|
||||||
input_stream: "ROI:single_hand_rect"
|
input_stream: "ROI:single_hand_rect"
|
||||||
output_stream: "LANDMARKS:single_hand_landmarks"
|
output_stream: "LANDMARKS:single_hand_landmarks"
|
||||||
|
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
||||||
# Max number of hands to detect/track. (int)
|
# Max number of hands to detect/track. (int)
|
||||||
input_side_packet: "NUM_HANDS:num_hands"
|
input_side_packet: "NUM_HANDS:num_hands"
|
||||||
|
|
||||||
|
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||||
|
# inference latency generally go up with the model complexity. If unspecified,
|
||||||
|
# functions as set to 1. (int)
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
|
|
||||||
# Whether landmarks on the previous image should be used to help localize
|
# Whether landmarks on the previous image should be used to help localize
|
||||||
# landmarks on the current image. (bool)
|
# landmarks on the current image. (bool)
|
||||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||||
|
@ -85,6 +90,7 @@ node {
|
||||||
calculator: "HandLandmarkTrackingGpu"
|
calculator: "HandLandmarkTrackingGpu"
|
||||||
input_stream: "IMAGE:gpu_buffer"
|
input_stream: "IMAGE:gpu_buffer"
|
||||||
input_side_packet: "NUM_HANDS:num_hands"
|
input_side_packet: "NUM_HANDS:num_hands"
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||||
output_stream: "LANDMARKS:multi_hand_landmarks"
|
output_stream: "LANDMARKS:multi_hand_landmarks"
|
||||||
output_stream: "HANDEDNESS:multi_handedness"
|
output_stream: "HANDEDNESS:multi_handedness"
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
# - "face_landmark.tflite" is available at
|
# - "face_landmark.tflite" is available at
|
||||||
# "mediapipe/modules/face_landmark/face_landmark.tflite"
|
# "mediapipe/modules/face_landmark/face_landmark.tflite"
|
||||||
#
|
#
|
||||||
# - "hand_landmark.tflite" is available at
|
# - "hand_landmark_full.tflite" is available at
|
||||||
# "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
# "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
|
||||||
#
|
#
|
||||||
# - "hand_recrop.tflite" is available at
|
# - "hand_recrop.tflite" is available at
|
||||||
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
|
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
# - "face_landmark.tflite" is available at
|
# - "face_landmark.tflite" is available at
|
||||||
# "mediapipe/modules/face_landmark/face_landmark.tflite"
|
# "mediapipe/modules/face_landmark/face_landmark.tflite"
|
||||||
#
|
#
|
||||||
# - "hand_landmark.tflite" is available at
|
# - "hand_landmark_full.tflite" is available at
|
||||||
# "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
# "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
|
||||||
#
|
#
|
||||||
# - "hand_recrop.tflite" is available at
|
# - "hand_recrop.tflite" is available at
|
||||||
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
|
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
|
||||||
|
|
|
@ -237,7 +237,7 @@ absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FilterDetectionCalculator::IsValidLabel(const std::string& label) {
|
bool FilterDetectionCalculator::IsValidLabel(const std::string& label) {
|
||||||
bool match = !limit_labels_ || ContainsKey(allowed_labels_, label);
|
bool match = !limit_labels_ || allowed_labels_.contains(label);
|
||||||
if (!match) {
|
if (!match) {
|
||||||
// If no exact match is found, check for regular expression
|
// If no exact match is found, check for regular expression
|
||||||
// comparions in the allowed_labels.
|
// comparions in the allowed_labels.
|
||||||
|
|
|
@ -21,6 +21,9 @@ cc_library(
|
||||||
features = ["-parse_headers"],
|
features = ["-parse_headers"],
|
||||||
linkopts = [
|
linkopts = [
|
||||||
"-framework Accelerate",
|
"-framework Accelerate",
|
||||||
|
"-framework CoreFoundation",
|
||||||
|
"-framework CoreGraphics",
|
||||||
|
"-framework CoreVideo",
|
||||||
],
|
],
|
||||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||||
deps = [
|
deps = [
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
#import <Accelerate/Accelerate.h>
|
#import <Accelerate/Accelerate.h>
|
||||||
#import <CoreFoundation/CoreFoundation.h>
|
#import <CoreFoundation/CoreFoundation.h>
|
||||||
|
#import <CoreGraphics/CoreGraphics.h>
|
||||||
#import <CoreVideo/CoreVideo.h>
|
#import <CoreVideo/CoreVideo.h>
|
||||||
|
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
|
|
@ -89,6 +89,7 @@ class Hands(SolutionBase):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
static_image_mode=False,
|
static_image_mode=False,
|
||||||
max_num_hands=2,
|
max_num_hands=2,
|
||||||
|
model_complexity=1,
|
||||||
min_detection_confidence=0.5,
|
min_detection_confidence=0.5,
|
||||||
min_tracking_confidence=0.5):
|
min_tracking_confidence=0.5):
|
||||||
"""Initializes a MediaPipe Hand object.
|
"""Initializes a MediaPipe Hand object.
|
||||||
|
@ -99,6 +100,10 @@ class Hands(SolutionBase):
|
||||||
https://solutions.mediapipe.dev/hands#static_image_mode.
|
https://solutions.mediapipe.dev/hands#static_image_mode.
|
||||||
max_num_hands: Maximum number of hands to detect. See details in
|
max_num_hands: Maximum number of hands to detect. See details in
|
||||||
https://solutions.mediapipe.dev/hands#max_num_hands.
|
https://solutions.mediapipe.dev/hands#max_num_hands.
|
||||||
|
model_complexity: Complexity of the hand landmark model: 0 or 1.
|
||||||
|
Landmark accuracy as well as inference latency generally go up with the
|
||||||
|
model complexity. See details in
|
||||||
|
https://solutions.mediapipe.dev/hands#model_complexity.
|
||||||
min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for hand
|
min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for hand
|
||||||
detection to be considered successful. See details in
|
detection to be considered successful. See details in
|
||||||
https://solutions.mediapipe.dev/hands#min_detection_confidence.
|
https://solutions.mediapipe.dev/hands#min_detection_confidence.
|
||||||
|
@ -109,6 +114,7 @@ class Hands(SolutionBase):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
binary_graph_path=_BINARYPB_FILE_PATH,
|
binary_graph_path=_BINARYPB_FILE_PATH,
|
||||||
side_inputs={
|
side_inputs={
|
||||||
|
'model_complexity': model_complexity,
|
||||||
'num_hands': max_num_hands,
|
'num_hands': max_num_hands,
|
||||||
'use_prev_landmarks': not static_image_mode,
|
'use_prev_landmarks': not static_image_mode,
|
||||||
},
|
},
|
||||||
|
|
|
@ -32,7 +32,8 @@ from mediapipe.python.solutions import hands as mp_hands
|
||||||
|
|
||||||
|
|
||||||
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
|
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
|
||||||
DIFF_THRESHOLD = 20 # pixels
|
LITE_MODEL_DIFF_THRESHOLD = 25 # pixels
|
||||||
|
FULL_MODEL_DIFF_THRESHOLD = 20 # pixels
|
||||||
EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
|
EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
|
||||||
[289, 237], [322, 203], [219, 216],
|
[289, 237], [322, 203], [219, 216],
|
||||||
[238, 138], [249, 90], [253, 51],
|
[238, 138], [249, 90], [253, 51],
|
||||||
|
@ -40,7 +41,7 @@ EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
|
||||||
[185, 19], [138, 208], [131, 127],
|
[185, 19], [138, 208], [131, 127],
|
||||||
[124, 77], [117, 36], [106, 222],
|
[124, 77], [117, 36], [106, 222],
|
||||||
[92, 159], [79, 124], [68, 93]],
|
[92, 159], [79, 124], [68, 93]],
|
||||||
[[580, 36], [504, 50], [459, 94],
|
[[580, 34], [504, 50], [459, 94],
|
||||||
[429, 146], [397, 182], [507, 167],
|
[429, 146], [397, 182], [507, 167],
|
||||||
[479, 245], [469, 292], [464, 330],
|
[479, 245], [469, 292], [464, 330],
|
||||||
[545, 180], [534, 265], [533, 319],
|
[545, 180], [534, 265], [533, 319],
|
||||||
|
@ -75,14 +76,18 @@ class HandsTest(parameterized.TestCase):
|
||||||
self.assertIsNone(results.multi_hand_landmarks)
|
self.assertIsNone(results.multi_hand_landmarks)
|
||||||
self.assertIsNone(results.multi_handedness)
|
self.assertIsNone(results.multi_handedness)
|
||||||
|
|
||||||
@parameterized.named_parameters(('static_image_mode', True, 1),
|
@parameterized.named_parameters(
|
||||||
('video_mode', False, 5))
|
('static_image_mode_with_lite_model', True, 0, 5),
|
||||||
def test_multi_hands(self, static_image_mode, num_frames):
|
('video_mode_with_lite_model', False, 0, 10),
|
||||||
|
('static_image_mode_with_full_model', True, 1, 5),
|
||||||
|
('video_mode_with_full_model', False, 1, 10))
|
||||||
|
def test_multi_hands(self, static_image_mode, model_complexity, num_frames):
|
||||||
image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg')
|
image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg')
|
||||||
image = cv2.imread(image_path)
|
image = cv2.imread(image_path)
|
||||||
with mp_hands.Hands(
|
with mp_hands.Hands(
|
||||||
static_image_mode=static_image_mode,
|
static_image_mode=static_image_mode,
|
||||||
max_num_hands=2,
|
max_num_hands=2,
|
||||||
|
model_complexity=model_complexity,
|
||||||
min_detection_confidence=0.5) as hands:
|
min_detection_confidence=0.5) as hands:
|
||||||
for idx in range(num_frames):
|
for idx in range(num_frames):
|
||||||
results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||||
|
@ -104,7 +109,8 @@ class HandsTest(parameterized.TestCase):
|
||||||
prediction_error = np.abs(
|
prediction_error = np.abs(
|
||||||
np.asarray(multi_hand_coordinates) -
|
np.asarray(multi_hand_coordinates) -
|
||||||
np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION))
|
np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION))
|
||||||
npt.assert_array_less(prediction_error, DIFF_THRESHOLD)
|
diff_threshold = LITE_MODEL_DIFF_THRESHOLD if model_complexity == 0 else FULL_MODEL_DIFF_THRESHOLD
|
||||||
|
npt.assert_array_less(prediction_error, diff_threshold)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
#include "mediapipe/framework/port/canonical_errors.h"
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
@ -81,6 +82,32 @@ ObjectDef GetSSBOObjectDef(int channels) {
|
||||||
return gpu_object_def;
|
return gpu_object_def;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __ANDROID__
|
||||||
|
|
||||||
|
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
|
||||||
|
cl::InferenceOptions result{};
|
||||||
|
result.priority1 = options.priority1;
|
||||||
|
result.priority2 = options.priority2;
|
||||||
|
result.priority3 = options.priority3;
|
||||||
|
result.usage = options.usage;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status VerifyShapes(const std::vector<TensorObjectDef>& actual,
|
||||||
|
const std::vector<BHWC>& expected) {
|
||||||
|
RET_CHECK_EQ(actual.size(), expected.size());
|
||||||
|
const int size = actual.size();
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
const auto& dims = actual[i].dimensions;
|
||||||
|
const BHWC& bhwc = expected[i];
|
||||||
|
RET_CHECK(dims.b == bhwc.b && dims.h == bhwc.h && dims.w == bhwc.w &&
|
||||||
|
dims.c == bhwc.c);
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // __ANDROID__
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::Status TFLiteGPURunner::InitializeWithModel(
|
absl::Status TFLiteGPURunner::InitializeWithModel(
|
||||||
|
@ -139,16 +166,16 @@ absl::Status TFLiteGPURunner::Build() {
|
||||||
// try to build OpenCL first. If something goes wrong, fall back to OpenGL.
|
// try to build OpenCL first. If something goes wrong, fall back to OpenGL.
|
||||||
absl::Status status = InitializeOpenCL(&builder);
|
absl::Status status = InitializeOpenCL(&builder);
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
LOG(INFO) << "OpenCL backend is used.";
|
VLOG(2) << "OpenCL backend is used.";
|
||||||
} else {
|
} else {
|
||||||
LOG(ERROR) << "Falling back to OpenGL: " << status.message();
|
VLOG(2) << "Falling back to OpenGL: " << status.message();
|
||||||
MP_RETURN_IF_ERROR(InitializeOpenGL(&builder));
|
MP_RETURN_IF_ERROR(InitializeOpenGL(&builder));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Both graphs are not needed anymore. Make sure they are deleted.
|
// GL graph not needed anymore, CL graph maybe needed for serialized model
|
||||||
|
// calculation.
|
||||||
graph_gl_.reset(nullptr);
|
graph_gl_.reset(nullptr);
|
||||||
graph_cl_.reset(nullptr);
|
|
||||||
|
|
||||||
// 2. Describe output/input objects for created builder.
|
// 2. Describe output/input objects for created builder.
|
||||||
for (int flow_index = 0; flow_index < input_shapes_.size(); ++flow_index) {
|
for (int flow_index = 0; flow_index < input_shapes_.size(); ++flow_index) {
|
||||||
|
@ -204,18 +231,57 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
|
||||||
env_options.serialized_binary_cache = serialized_binary_cache_;
|
env_options.serialized_binary_cache = serialized_binary_cache_;
|
||||||
}
|
}
|
||||||
cl::InferenceEnvironmentProperties properties;
|
cl::InferenceEnvironmentProperties properties;
|
||||||
cl::InferenceOptions cl_options;
|
|
||||||
cl_options.priority1 = options_.priority1;
|
|
||||||
cl_options.priority2 = options_.priority2;
|
|
||||||
cl_options.priority3 = options_.priority3;
|
|
||||||
cl_options.usage = options_.usage;
|
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
|
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
|
||||||
|
|
||||||
|
// Try to initialize from serialized model first.
|
||||||
|
if (!serialized_model_.empty()) {
|
||||||
|
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
|
||||||
|
if (init_status.ok()) {
|
||||||
|
serialized_model_used_ = true;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
VLOG(2) << "Failed to init from serialized model: [" << init_status
|
||||||
|
<< "]. Trying to init from scratch.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize from scratch.
|
||||||
|
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
|
||||||
|
GraphFloat32 graph_cl;
|
||||||
|
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
|
||||||
MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
|
MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
|
||||||
cl_options, std::move(*graph_cl_), builder));
|
cl_options, std::move(graph_cl), builder));
|
||||||
#endif
|
|
||||||
|
#endif // __ANDROID__
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __ANDROID__
|
||||||
|
|
||||||
|
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
|
||||||
|
std::unique_ptr<InferenceBuilder>* builder) {
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
cl_environment_->NewInferenceBuilder(serialized_model_, builder));
|
||||||
|
MP_RETURN_IF_ERROR(VerifyShapes(builder->get()->inputs(), input_shapes_));
|
||||||
|
return VerifyShapes(builder->get()->outputs(), output_shapes_);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
|
||||||
|
RET_CHECK(runner_) << "Runner is in invalid state.";
|
||||||
|
if (serialized_model_used_) {
|
||||||
|
return serialized_model_;
|
||||||
|
}
|
||||||
|
RET_CHECK(graph_cl_) << "CL graph is not initialized.";
|
||||||
|
GraphFloat32 graph_cl;
|
||||||
|
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
|
||||||
|
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
|
||||||
|
std::vector<uint8_t> serialized_model;
|
||||||
|
MP_RETURN_IF_ERROR(cl_environment_->BuildSerializedModel(
|
||||||
|
cl_options, std::move(graph_cl), &serialized_model));
|
||||||
|
return serialized_model;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // __ANDROID__
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
@ -29,7 +30,7 @@
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#ifdef __ANDROID__
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
||||||
#endif
|
#endif // __ANDROID__
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
@ -90,11 +91,22 @@ class TFLiteGPURunner {
|
||||||
std::vector<uint8_t> GetSerializedBinaryCache() {
|
std::vector<uint8_t> GetSerializedBinaryCache() {
|
||||||
return cl_environment_->GetSerializedBinaryCache();
|
return cl_environment_->GetSerializedBinaryCache();
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
void SetSerializedModel(std::vector<uint8_t>&& serialized_model) {
|
||||||
|
serialized_model_ = std::move(serialized_model);
|
||||||
|
serialized_model_used_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
|
||||||
|
#endif // __ANDROID__
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
|
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
|
||||||
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
|
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
|
||||||
|
#ifdef __ANDROID__
|
||||||
|
absl::Status InitializeOpenCLFromSerializedModel(
|
||||||
|
std::unique_ptr<InferenceBuilder>* builder);
|
||||||
|
#endif // __ANDROID__
|
||||||
|
|
||||||
InferenceOptions options_;
|
InferenceOptions options_;
|
||||||
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
|
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
|
||||||
|
@ -103,9 +115,12 @@ class TFLiteGPURunner {
|
||||||
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
|
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
|
||||||
|
|
||||||
std::vector<uint8_t> serialized_binary_cache_;
|
std::vector<uint8_t> serialized_binary_cache_;
|
||||||
#endif
|
std::vector<uint8_t> serialized_model_;
|
||||||
|
bool serialized_model_used_ = false;
|
||||||
|
#endif // __ANDROID__
|
||||||
|
|
||||||
// graph_ is maintained temporarily and becomes invalid after runner_ is ready
|
// graph_gl_ is maintained temporarily and becomes invalid after runner_ is
|
||||||
|
// ready
|
||||||
std::unique_ptr<GraphFloat32> graph_gl_;
|
std::unique_ptr<GraphFloat32> graph_gl_;
|
||||||
std::unique_ptr<GraphFloat32> graph_cl_;
|
std::unique_ptr<GraphFloat32> graph_cl_;
|
||||||
std::unique_ptr<InferenceRunner> runner_;
|
std::unique_ptr<InferenceRunner> runner_;
|
||||||
|
|
Loading…
Reference in New Issue
Block a user