Project import generated by Copybara.
GitOrigin-RevId: bbbbcb4f5174dea33525729ede47c770069157cd
This commit is contained in:
parent
33d683c671
commit
1faeaae7e5
|
@ -120,7 +120,7 @@ just 86.22%.
|
|||
### Hand Landmark Model
|
||||
|
||||
After the palm detection over the whole image our subsequent hand landmark
|
||||
[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite)
|
||||
[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite)
|
||||
performs precise keypoint localization of 21 3D hand-knuckle coordinates inside
|
||||
the detected hand regions via regression, that is direct coordinate prediction.
|
||||
The model learns a consistent internal hand pose representation and is robust
|
||||
|
@ -163,6 +163,11 @@ unrelated, images. Default to `false`.
|
|||
|
||||
Maximum number of hands to detect. Default to `2`.
|
||||
|
||||
#### model_complexity
|
||||
|
||||
Complexity of the hand landmark model: `0` or `1`. Landmark accuracy as well as
|
||||
inference latency generally go up with the model complexity. Default to `1`.
|
||||
|
||||
#### min_detection_confidence
|
||||
|
||||
Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the
|
||||
|
@ -212,6 +217,7 @@ Supported configuration options:
|
|||
|
||||
* [static_image_mode](#static_image_mode)
|
||||
* [max_num_hands](#max_num_hands)
|
||||
* [model_complexity](#model_complexity)
|
||||
* [min_detection_confidence](#min_detection_confidence)
|
||||
* [min_tracking_confidence](#min_tracking_confidence)
|
||||
|
||||
|
@ -260,6 +266,7 @@ with mp_hands.Hands(
|
|||
# For webcam input:
|
||||
cap = cv2.VideoCapture(0)
|
||||
with mp_hands.Hands(
|
||||
model_complexity=0,
|
||||
min_detection_confidence=0.5,
|
||||
min_tracking_confidence=0.5) as hands:
|
||||
while cap.isOpened():
|
||||
|
@ -302,6 +309,7 @@ and a [fun application], and the following usage example.
|
|||
Supported configuration options:
|
||||
|
||||
* [maxNumHands](#max_num_hands)
|
||||
* [modelComplexity](#model_complexity)
|
||||
* [minDetectionConfidence](#min_detection_confidence)
|
||||
* [minTrackingConfidence](#min_tracking_confidence)
|
||||
|
||||
|
@ -351,6 +359,7 @@ const hands = new Hands({locateFile: (file) => {
|
|||
}});
|
||||
hands.setOptions({
|
||||
maxNumHands: 2,
|
||||
modelComplexity: 1,
|
||||
minDetectionConfidence: 0.5,
|
||||
minTrackingConfidence: 0.5
|
||||
});
|
||||
|
|
|
@ -58,10 +58,12 @@ one over the other.
|
|||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite),
|
||||
[TF.js model](https://tfhub.dev/mediapipe/handdetector/1)
|
||||
* Hand landmark model:
|
||||
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite),
|
||||
[TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_lite.tflite),
|
||||
[TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite),
|
||||
[TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite),
|
||||
[TF.js model](https://tfhub.dev/mediapipe/handskeleton/1)
|
||||
* [Model card](https://mediapipe.page.link/handmc), [Model card (sparse)](https://mediapipe.page.link/handmc-sparse)
|
||||
* [Model card](https://mediapipe.page.link/handmc),
|
||||
[Model card (sparse)](https://mediapipe.page.link/handmc-sparse)
|
||||
|
||||
### [Pose](https://google.github.io/mediapipe/solutions/pose)
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ hip midpoints.
|
|||
:----------------------------------------------------------------------------------------------------: |
|
||||
*Fig 3. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* |
|
||||
|
||||
### Pose Landmark Model (BlazePose GHUM 3D)
|
||||
### Pose Landmark Model (BlazePose [GHUM](https://github.com/google-research/google-research/tree/master/ghum) 3D)
|
||||
|
||||
The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
|
||||
(see figure below).
|
||||
|
@ -486,6 +486,7 @@ on how to build MediaPipe examples.
|
|||
[BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204)
|
||||
([presentation](https://youtu.be/YPpUOTRn5tA))
|
||||
* [Models and model cards](./models.md#pose)
|
||||
* [GHUM & GHUML: Generative 3D Human Shape and Articulated Pose Models](https://github.com/google-research/google-research/tree/master/ghum)
|
||||
* [Web demo](https://code.mediapipe.dev/codepen/pose)
|
||||
* [Python Colab](https://mediapipe.page.link/pose_py_colab)
|
||||
|
||||
|
|
|
@ -531,9 +531,13 @@ cc_test(
|
|||
":split_vector_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -47,4 +47,8 @@ typedef BeginLoopCalculator<std::vector<std::vector<Matrix>>>
|
|||
BeginLoopMatrixVectorCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
|
||||
|
||||
// A calculator to process std::vector<uint64_t>.
|
||||
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -14,7 +14,11 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -301,4 +305,99 @@ TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class PassThroughAndTsBoundUpdateNode : public mediapipe::api2::Node {
|
||||
public:
|
||||
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
|
||||
static constexpr mediapipe::api2::Output<int> kOutValue{"VALUE"};
|
||||
static constexpr mediapipe::api2::Output<int> kOutTsBoundUpdate{
|
||||
"TS_BOUND_UPDATE"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kInValue, kOutValue, kOutTsBoundUpdate);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
kOutValue(cc).Send(kInValue(cc));
|
||||
kOutTsBoundUpdate(cc).SetNextTimestampBound(
|
||||
cc->InputTimestamp().NextAllowedInStream());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(PassThroughAndTsBoundUpdateNode);
|
||||
|
||||
class ToOptionalNode : public mediapipe::api2::Node {
|
||||
public:
|
||||
static constexpr mediapipe::api2::Input<int> kTick{"TICK"};
|
||||
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
|
||||
static constexpr mediapipe::api2::Output<absl::optional<int>> kOutValue{
|
||||
"OUTPUT"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kTick, kInValue, kOutValue);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (kInValue(cc).IsEmpty()) {
|
||||
kOutValue(cc).Send(absl::nullopt);
|
||||
} else {
|
||||
kOutValue(cc).Send({kInValue(cc).Get()});
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ToOptionalNode);
|
||||
|
||||
namespace {
|
||||
|
||||
TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
input_stream: "select"
|
||||
node {
|
||||
calculator: "PassThroughAndTsBoundUpdateNode"
|
||||
input_stream: "VALUE:select"
|
||||
output_stream: "VALUE:select_ps"
|
||||
output_stream: "TS_BOUND_UPDATE:ts_bound_update"
|
||||
}
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:select_ps"
|
||||
input_stream: "INPUT:1:ts_bound_update"
|
||||
input_stream: "SELECT:select"
|
||||
output_stream: "OUTPUT:select_or_ts_bound_update"
|
||||
}
|
||||
node {
|
||||
calculator: "ToOptionalNode"
|
||||
input_stream: "TICK:select"
|
||||
input_stream: "VALUE:select_or_ts_bound_update"
|
||||
output_stream: "OUTPUT:output"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
auto send_value_fn = [&](int value, Timestamp ts) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(
|
||||
graph.AddPacketToInputStream("select", MakePacket<int>(value).At(ts)));
|
||||
return graph.WaitUntilIdle();
|
||||
};
|
||||
|
||||
MP_ASSERT_OK(send_value_fn(0, Timestamp(1)));
|
||||
ASSERT_EQ(output_packets.size(), 1);
|
||||
EXPECT_EQ(output_packets[0].Get<absl::optional<int>>(), 0);
|
||||
|
||||
MP_ASSERT_OK(send_value_fn(1, Timestamp(2)));
|
||||
ASSERT_EQ(output_packets.size(), 2);
|
||||
EXPECT_EQ(output_packets[1].Get<absl::optional<int>>(), absl::nullopt);
|
||||
|
||||
MP_ASSERT_OK(send_value_fn(0, Timestamp(3)));
|
||||
ASSERT_EQ(output_packets.size(), 3);
|
||||
EXPECT_EQ(output_packets[2].Get<absl::optional<int>>(), 0);
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -34,7 +34,6 @@ option java_outer_classname = "InferenceCalculatorProto";
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
message InferenceCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional InferenceCalculatorOptions ext = 336783863;
|
||||
|
@ -69,8 +68,30 @@ message InferenceCalculatorOptions {
|
|||
// Load pre-compiled serialized binary cache to accelerate init process.
|
||||
// Only available for OpenCL delegate on Android.
|
||||
// Kernel caching will only be enabled if this path is set.
|
||||
//
|
||||
// NOTE: binary cache usage may be skipped if valid serialized model,
|
||||
// specified by "serialized_model_dir", exists.
|
||||
//
|
||||
// TODO: update to cached_kernel_dir
|
||||
optional string cached_kernel_path = 2;
|
||||
|
||||
// A dir to load from and save to a pre-compiled serialized model used to
|
||||
// accelerate init process.
|
||||
//
|
||||
// NOTE: available for OpenCL delegate on Android only when
|
||||
// "use_advanced_gpu_api" is set to true and "model_token" is set
|
||||
// properly.
|
||||
//
|
||||
// NOTE: serialized model takes precedence over binary cache
|
||||
// specified by "cached_kernel_path", which still can be used if
|
||||
// serialized model is invalid or missing.
|
||||
optional string serialized_model_dir = 7;
|
||||
|
||||
// Unique token identifying the model. Used in conjunction with
|
||||
// "serialized_model_dir". It is the caller's responsibility to ensure
|
||||
// there is no clash of the tokens.
|
||||
optional string model_token = 8;
|
||||
|
||||
// Encapsulated compilation/runtime tradeoffs.
|
||||
enum InferenceUsage {
|
||||
UNSPECIFIED = 0;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/util/tflite/config.h"
|
||||
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
@ -49,8 +50,8 @@ class InferenceCalculatorGlImpl
|
|||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status ReadKernelsFromFile();
|
||||
absl::Status WriteKernelsToFile();
|
||||
absl::Status ReadGpuCaches();
|
||||
absl::Status SaveGpuCaches();
|
||||
absl::Status LoadModel(CalculatorContext* cc);
|
||||
absl::Status LoadDelegate(CalculatorContext* cc);
|
||||
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
|
||||
|
@ -82,6 +83,8 @@ class InferenceCalculatorGlImpl
|
|||
|
||||
bool use_kernel_caching_ = false;
|
||||
std::string cached_kernel_filename_;
|
||||
bool use_serialized_model_ = false;
|
||||
std::string serialized_model_path_;
|
||||
};
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
|
||||
|
@ -114,6 +117,9 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
|||
tflite_gpu_runner_usage_ = delegate.gpu().usage();
|
||||
use_kernel_caching_ =
|
||||
use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
|
||||
use_serialized_model_ = use_advanced_gpu_api_ &&
|
||||
delegate.gpu().has_serialized_model_dir() &&
|
||||
delegate.gpu().has_model_token();
|
||||
use_gpu_delegate_ = !use_advanced_gpu_api_;
|
||||
|
||||
if (use_kernel_caching_) {
|
||||
|
@ -123,6 +129,12 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
|||
".ker";
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
}
|
||||
if (use_serialized_model_) {
|
||||
#ifdef MEDIAPIPE_ANDROID
|
||||
serialized_model_path_ = mediapipe::file::JoinPath(
|
||||
delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
}
|
||||
|
||||
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
|
||||
// for everything.
|
||||
|
@ -210,7 +222,7 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() {
|
||||
absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() {
|
||||
#ifdef MEDIAPIPE_ANDROID
|
||||
if (use_kernel_caching_) {
|
||||
// Save kernel file.
|
||||
|
@ -220,12 +232,22 @@ absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() {
|
|||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
|
||||
}
|
||||
if (use_serialized_model_) {
|
||||
// Save serialized model file.
|
||||
ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
|
||||
tflite_gpu_runner_->GetSerializedModel());
|
||||
absl::string_view serialized_model(
|
||||
reinterpret_cast<char*>(serialized_model_vec.data()),
|
||||
serialized_model_vec.size());
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::file::SetContents(serialized_model_path_, serialized_model));
|
||||
}
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
|
||||
MP_RETURN_IF_ERROR(WriteKernelsToFile());
|
||||
MP_RETURN_IF_ERROR(SaveGpuCaches());
|
||||
if (use_gpu_delegate_) {
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
|
||||
gpu_buffers_in_.clear();
|
||||
|
@ -239,17 +261,24 @@ absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::ReadKernelsFromFile() {
|
||||
absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
|
||||
#ifdef MEDIAPIPE_ANDROID
|
||||
if (use_kernel_caching_) {
|
||||
if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
|
||||
// Load pre-compiled kernel file.
|
||||
if (mediapipe::File::Exists(cached_kernel_filename_)) {
|
||||
std::string cache_str;
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
|
||||
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
|
||||
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
|
||||
}
|
||||
std::string cache_str;
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
|
||||
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
|
||||
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
|
||||
}
|
||||
if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
|
||||
// Load serialized model file.
|
||||
std::string serialized_model_str;
|
||||
MP_RETURN_IF_ERROR(
|
||||
file::GetContents(serialized_model_path_, &serialized_model_str));
|
||||
std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
|
||||
serialized_model_str.end());
|
||||
tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
|
||||
}
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
return absl::OkStatus();
|
||||
|
@ -313,7 +342,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
|||
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(ReadKernelsFromFile());
|
||||
MP_RETURN_IF_ERROR(ReadGpuCaches());
|
||||
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
|
||||
|
||||
|
|
|
@ -24,20 +24,20 @@ message SsdAnchorsCalculatorOptions {
|
|||
optional SsdAnchorsCalculatorOptions ext = 247258239;
|
||||
}
|
||||
// Size of input images.
|
||||
required int32 input_size_width = 1;
|
||||
required int32 input_size_height = 2;
|
||||
optional int32 input_size_width = 1; // required
|
||||
optional int32 input_size_height = 2; // required
|
||||
|
||||
// Min and max scales for generating anchor boxes on feature maps.
|
||||
required float min_scale = 3;
|
||||
required float max_scale = 4;
|
||||
optional float min_scale = 3; // required
|
||||
optional float max_scale = 4; // required
|
||||
|
||||
// The offset for the center of anchors. The value is in the scale of stride.
|
||||
// E.g. 0.5 meaning 0.5 * |current_stride| in pixels.
|
||||
required float anchor_offset_x = 5 [default = 0.5];
|
||||
required float anchor_offset_y = 6 [default = 0.5];
|
||||
optional float anchor_offset_x = 5 [default = 0.5]; // required
|
||||
optional float anchor_offset_y = 6 [default = 0.5]; // required
|
||||
|
||||
// Number of output feature maps to generate the anchors on.
|
||||
required int32 num_layers = 7;
|
||||
optional int32 num_layers = 7; // required
|
||||
// Sizes of output feature maps to create anchors. Either feature_map size or
|
||||
// stride should be provided.
|
||||
repeated int32 feature_map_width = 8;
|
||||
|
|
|
@ -26,12 +26,12 @@ message TfLiteTensorsToDetectionsCalculatorOptions {
|
|||
}
|
||||
|
||||
// The number of output classes predicted by the detection model.
|
||||
required int32 num_classes = 1;
|
||||
optional int32 num_classes = 1; // required
|
||||
// The number of output boxes predicted by the detection model.
|
||||
required int32 num_boxes = 2;
|
||||
optional int32 num_boxes = 2; // required
|
||||
// The number of output values per boxes predicted by the detection model. The
|
||||
// values contain bounding boxes, keypoints, etc.
|
||||
required int32 num_coords = 3;
|
||||
optional int32 num_coords = 3; // required
|
||||
|
||||
// The offset of keypoint coordinates in the location tensor.
|
||||
optional int32 keypoint_coord_offset = 9;
|
||||
|
|
|
@ -31,7 +31,7 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
|
|||
}
|
||||
|
||||
// Number of landmarks from the output of the model.
|
||||
required int32 num_landmarks = 1;
|
||||
optional int32 num_landmarks = 1; // required
|
||||
|
||||
// Size of the input image for the model. These options are used only when
|
||||
// normalized landmarks are needed. Z coordinate is scaled as X assuming
|
||||
|
|
|
@ -24,9 +24,9 @@ message TfLiteTensorsToSegmentationCalculatorOptions {
|
|||
}
|
||||
|
||||
// Dimensions of input segmentation tensor to process.
|
||||
required int32 tensor_width = 1;
|
||||
required int32 tensor_height = 2;
|
||||
required int32 tensor_channels = 3;
|
||||
optional int32 tensor_width = 1; // required
|
||||
optional int32 tensor_height = 2; // required
|
||||
optional int32 tensor_channels = 3; // required
|
||||
|
||||
// How much to use previous mask when computing current one; range [0-1].
|
||||
// This is a tradeoff between responsiveness (0.0) and accuracy (1.0).
|
||||
|
|
|
@ -98,6 +98,43 @@ public class MainActivity extends AppCompatActivity {
|
|||
}
|
||||
}
|
||||
|
||||
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
|
||||
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
|
||||
int width = imageView.getWidth();
|
||||
int height = imageView.getHeight();
|
||||
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
|
||||
width = (int) (height * aspectRatio);
|
||||
} else {
|
||||
height = (int) (width / aspectRatio);
|
||||
}
|
||||
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||
}
|
||||
|
||||
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
||||
int orientation =
|
||||
new ExifInterface(imageData)
|
||||
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
||||
return inputBitmap;
|
||||
}
|
||||
Matrix matrix = new Matrix();
|
||||
switch (orientation) {
|
||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||
matrix.postRotate(90);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||
matrix.postRotate(180);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||
matrix.postRotate(270);
|
||||
break;
|
||||
default:
|
||||
matrix.postRotate(0);
|
||||
}
|
||||
return Bitmap.createBitmap(
|
||||
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
||||
}
|
||||
|
||||
/** Sets up the UI components for the static image demo. */
|
||||
private void setupStaticImageDemoUiComponents() {
|
||||
// The Intent to access gallery and read images as bitmap.
|
||||
|
@ -111,37 +148,16 @@ public class MainActivity extends AppCompatActivity {
|
|||
Bitmap bitmap = null;
|
||||
try {
|
||||
bitmap =
|
||||
MediaStore.Images.Media.getBitmap(
|
||||
this.getContentResolver(), resultIntent.getData());
|
||||
downscaleBitmap(
|
||||
MediaStore.Images.Media.getBitmap(
|
||||
this.getContentResolver(), resultIntent.getData()));
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap reading error:" + e);
|
||||
}
|
||||
try {
|
||||
InputStream imageData =
|
||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||
int orientation =
|
||||
new ExifInterface(imageData)
|
||||
.getAttributeInt(
|
||||
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
|
||||
Matrix matrix = new Matrix();
|
||||
switch (orientation) {
|
||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||
matrix.postRotate(90);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||
matrix.postRotate(180);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||
matrix.postRotate(270);
|
||||
break;
|
||||
default:
|
||||
matrix.postRotate(0);
|
||||
}
|
||||
bitmap =
|
||||
Bitmap.createBitmap(
|
||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
|
||||
}
|
||||
bitmap = rotateBitmap(bitmap, imageData);
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ android {
|
|||
buildToolsVersion "30.0.3"
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.google.mediapipe.apps.hands"
|
||||
applicationId "com.google.mediapipe.apps.facemesh"
|
||||
minSdkVersion 21
|
||||
targetSdkVersion 30
|
||||
versionCode 1
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
<application
|
||||
android:allowBackup="true"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="MediaPipe FaceMesh"
|
||||
android:label="MediaPipe Face Mesh"
|
||||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/AppTheme">
|
||||
|
|
|
@ -99,6 +99,43 @@ public class MainActivity extends AppCompatActivity {
|
|||
}
|
||||
}
|
||||
|
||||
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
|
||||
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
|
||||
int width = imageView.getWidth();
|
||||
int height = imageView.getHeight();
|
||||
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
|
||||
width = (int) (height * aspectRatio);
|
||||
} else {
|
||||
height = (int) (width / aspectRatio);
|
||||
}
|
||||
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||
}
|
||||
|
||||
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
||||
int orientation =
|
||||
new ExifInterface(imageData)
|
||||
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
||||
return inputBitmap;
|
||||
}
|
||||
Matrix matrix = new Matrix();
|
||||
switch (orientation) {
|
||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||
matrix.postRotate(90);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||
matrix.postRotate(180);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||
matrix.postRotate(270);
|
||||
break;
|
||||
default:
|
||||
matrix.postRotate(0);
|
||||
}
|
||||
return Bitmap.createBitmap(
|
||||
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
||||
}
|
||||
|
||||
/** Sets up the UI components for the static image demo. */
|
||||
private void setupStaticImageDemoUiComponents() {
|
||||
// The Intent to access gallery and read images as bitmap.
|
||||
|
@ -112,37 +149,16 @@ public class MainActivity extends AppCompatActivity {
|
|||
Bitmap bitmap = null;
|
||||
try {
|
||||
bitmap =
|
||||
MediaStore.Images.Media.getBitmap(
|
||||
this.getContentResolver(), resultIntent.getData());
|
||||
downscaleBitmap(
|
||||
MediaStore.Images.Media.getBitmap(
|
||||
this.getContentResolver(), resultIntent.getData()));
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap reading error:" + e);
|
||||
}
|
||||
try {
|
||||
InputStream imageData =
|
||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||
int orientation =
|
||||
new ExifInterface(imageData)
|
||||
.getAttributeInt(
|
||||
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
|
||||
Matrix matrix = new Matrix();
|
||||
switch (orientation) {
|
||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||
matrix.postRotate(90);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||
matrix.postRotate(180);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||
matrix.postRotate(270);
|
||||
break;
|
||||
default:
|
||||
matrix.postRotate(0);
|
||||
}
|
||||
bitmap =
|
||||
Bitmap.createBitmap(
|
||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
|
||||
}
|
||||
bitmap = rotateBitmap(bitmap, imageData);
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||
}
|
||||
|
|
|
@ -100,6 +100,43 @@ public class MainActivity extends AppCompatActivity {
|
|||
}
|
||||
}
|
||||
|
||||
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
|
||||
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
|
||||
int width = imageView.getWidth();
|
||||
int height = imageView.getHeight();
|
||||
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
|
||||
width = (int) (height * aspectRatio);
|
||||
} else {
|
||||
height = (int) (width / aspectRatio);
|
||||
}
|
||||
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||
}
|
||||
|
||||
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
||||
int orientation =
|
||||
new ExifInterface(imageData)
|
||||
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
||||
return inputBitmap;
|
||||
}
|
||||
Matrix matrix = new Matrix();
|
||||
switch (orientation) {
|
||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||
matrix.postRotate(90);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||
matrix.postRotate(180);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||
matrix.postRotate(270);
|
||||
break;
|
||||
default:
|
||||
matrix.postRotate(0);
|
||||
}
|
||||
return Bitmap.createBitmap(
|
||||
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
||||
}
|
||||
|
||||
/** Sets up the UI components for the static image demo. */
|
||||
private void setupStaticImageDemoUiComponents() {
|
||||
// The Intent to access gallery and read images as bitmap.
|
||||
|
@ -113,37 +150,16 @@ public class MainActivity extends AppCompatActivity {
|
|||
Bitmap bitmap = null;
|
||||
try {
|
||||
bitmap =
|
||||
MediaStore.Images.Media.getBitmap(
|
||||
this.getContentResolver(), resultIntent.getData());
|
||||
downscaleBitmap(
|
||||
MediaStore.Images.Media.getBitmap(
|
||||
this.getContentResolver(), resultIntent.getData()));
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap reading error:" + e);
|
||||
}
|
||||
try {
|
||||
InputStream imageData =
|
||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||
int orientation =
|
||||
new ExifInterface(imageData)
|
||||
.getAttributeInt(
|
||||
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
|
||||
Matrix matrix = new Matrix();
|
||||
switch (orientation) {
|
||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||
matrix.postRotate(90);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||
matrix.postRotate(180);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||
matrix.postRotate(270);
|
||||
break;
|
||||
default:
|
||||
matrix.postRotate(0);
|
||||
}
|
||||
bitmap =
|
||||
Bitmap.createBitmap(
|
||||
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
|
||||
}
|
||||
bitmap = rotateBitmap(bitmap, imageData);
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ android_binary(
|
|||
assets = [
|
||||
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
|
||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
||||
],
|
||||
assets_dir = "",
|
||||
|
|
|
@ -39,7 +39,7 @@ android_binary(
|
|||
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
|
||||
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
|
||||
"//mediapipe/modules/face_landmark:face_landmark.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
|
|
|
@ -62,7 +62,7 @@ objc_library(
|
|||
copts = ["-std=c++17"],
|
||||
data = [
|
||||
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
||||
],
|
||||
|
|
|
@ -57,7 +57,7 @@ objc_library(
|
|||
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
|
||||
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
|
||||
"//mediapipe/modules/face_landmark:face_landmark.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
|
||||
"//mediapipe/modules/pose_detection:pose_detection.tflite",
|
||||
|
|
|
@ -427,7 +427,8 @@ absl::Status CalculatorGraph::Initialize(
|
|||
const std::map<std::string, Packet>& side_packets) {
|
||||
auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
|
||||
MP_RETURN_IF_ERROR(validated_graph->Initialize(
|
||||
input_config, /*graph_registry=*/nullptr, &service_manager_));
|
||||
input_config, /*graph_registry=*/nullptr, /*graph_options=*/nullptr,
|
||||
&service_manager_));
|
||||
return Initialize(std::move(validated_graph), side_packets);
|
||||
}
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ class CalculatorRunner {
|
|||
const StreamContentsSet& Outputs() const { return *outputs_; }
|
||||
|
||||
// Returns the access to the output side packets.
|
||||
const PacketSet& OutputSidePackets() { return *output_side_packets_.get(); }
|
||||
const PacketSet& OutputSidePackets() { return *output_side_packets_; }
|
||||
|
||||
// Returns a graph counter.
|
||||
mediapipe::Counter* GetCounter(const std::string& name);
|
||||
|
|
|
@ -77,13 +77,6 @@ bool Image::ConvertToGpu() const {
|
|||
#else
|
||||
// GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture)
|
||||
auto buffer = mediapipe::GlTextureBuffer::Create(*image_frame_);
|
||||
glBindTexture(GL_TEXTURE_2D, buffer->name());
|
||||
// See GlCalculatorHelperImpl::SetStandardTextureParams
|
||||
glTexParameteri(buffer->target(), GL_TEXTURE_MIN_FILTER, GL_LINEAR);
|
||||
glTexParameteri(buffer->target(), GL_TEXTURE_MAG_FILTER, GL_LINEAR);
|
||||
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
|
||||
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
glFlush();
|
||||
gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer));
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
|
|
@ -244,6 +244,8 @@ cc_test(
|
|||
srcs = ["mux_input_stream_handler_test.cc"],
|
||||
deps = [
|
||||
":mux_input_stream_handler",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:make_pair_calculator",
|
||||
"//mediapipe/calculators/core:mux_calculator",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/calculators/core:round_robin_demux_calculator",
|
||||
|
|
|
@ -75,13 +75,30 @@ class MuxInputStreamHandler : public InputStreamHandler {
|
|||
int control_value = control_packet.Get<int>();
|
||||
CHECK_LE(0, control_value);
|
||||
CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1);
|
||||
|
||||
const auto& data_stream = input_stream_managers_.Get(
|
||||
input_stream_managers_.BeginId() + control_value);
|
||||
|
||||
// Data stream may contain some outdated packets which failed to be popped
|
||||
// out during "FillInputSet". (This handler doesn't sync input streams,
|
||||
// hence "FillInputSet" can be triggerred before every input stream is
|
||||
// filled with packets corresponding to the same timestamp.)
|
||||
data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
|
||||
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
|
||||
if (empty) {
|
||||
CHECK_LE(stream_timestamp, *min_stream_timestamp);
|
||||
return NodeReadiness::kNotReady;
|
||||
if (stream_timestamp <= *min_stream_timestamp) {
|
||||
// "data_stream" didn't receive a packet corresponding to the current
|
||||
// "control_stream" packet yet.
|
||||
return NodeReadiness::kNotReady;
|
||||
}
|
||||
// "data_stream" timestamp bound update detected.
|
||||
return NodeReadiness::kReadyForProcess;
|
||||
}
|
||||
if (stream_timestamp > *min_stream_timestamp) {
|
||||
// The earliest packet "data_stream" holds corresponds to a control packet
|
||||
// yet to arrive, which means there won't be a "data_stream" packet
|
||||
// corresponding to the current "control_stream" packet, which should be
|
||||
// indicated as timestamp boun update.
|
||||
return NodeReadiness::kReadyForProcess;
|
||||
}
|
||||
CHECK_EQ(stream_timestamp, *min_stream_timestamp);
|
||||
return NodeReadiness::kReadyForProcess;
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -19,9 +20,10 @@
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
// A regression test for b/31620439. MuxInputStreamHandler's accesses to the
|
||||
// control and data streams should be atomic so that it has a consistent view
|
||||
// of the two streams. None of the CHECKs in the GetNodeReadiness() method of
|
||||
|
@ -87,5 +89,561 @@ TEST(MuxInputStreamHandlerTest, AtomicAccessToControlAndDataStreams) {
|
|||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
MATCHER_P2(IntPacket, value, ts, "") {
|
||||
return arg.template Get<int>() == value && arg.Timestamp() == ts;
|
||||
}
|
||||
|
||||
struct GateAndMuxGraphInput {
|
||||
int input0;
|
||||
int input1;
|
||||
int input2;
|
||||
int select;
|
||||
bool allow0;
|
||||
bool allow1;
|
||||
bool allow2;
|
||||
Timestamp at;
|
||||
};
|
||||
|
||||
constexpr char kGateAndMuxGraph[] = R"pb(
|
||||
input_stream: "input0"
|
||||
input_stream: "input1"
|
||||
input_stream: "input2"
|
||||
input_stream: "select"
|
||||
input_stream: "allow0"
|
||||
input_stream: "allow1"
|
||||
input_stream: "allow2"
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "ALLOW:allow0"
|
||||
input_stream: "input0"
|
||||
output_stream: "output0"
|
||||
}
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "ALLOW:allow1"
|
||||
input_stream: "input1"
|
||||
output_stream: "output1"
|
||||
}
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "ALLOW:allow2"
|
||||
input_stream: "input2"
|
||||
output_stream: "output2"
|
||||
}
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:output0"
|
||||
input_stream: "INPUT:1:output1"
|
||||
input_stream: "INPUT:2:output2"
|
||||
input_stream: "SELECT:select"
|
||||
output_stream: "OUTPUT:output"
|
||||
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||
})pb";
|
||||
|
||||
absl::Status SendInput(GateAndMuxGraphInput in, CalculatorGraph& graph) {
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(in.input0).At(in.at)));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(in.input1).At(in.at)));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"input2", MakePacket<int>(in.input2).At(in.at)));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(in.select).At(in.at)));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"allow0", MakePacket<bool>(in.allow0).At(in.at)));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"allow1", MakePacket<bool>(in.allow1).At(in.at)));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"allow2", MakePacket<bool>(in.allow2).At(in.at)));
|
||||
return graph.WaitUntilIdle();
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest, BasicMuxing) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 0,
|
||||
.allow0 = true,
|
||||
.allow1 = false,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(1)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 1,
|
||||
.allow0 = false,
|
||||
.allow1 = true,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(2)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||
IntPacket(900, Timestamp(2))));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 2,
|
||||
.allow0 = false,
|
||||
.allow1 = false,
|
||||
.allow2 = true,
|
||||
.at = Timestamp(3)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||
IntPacket(900, Timestamp(2)),
|
||||
IntPacket(800, Timestamp(3))));
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest, MuxingNonEmptyInputs) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 0,
|
||||
.allow0 = true,
|
||||
.allow1 = true,
|
||||
.allow2 = true,
|
||||
.at = Timestamp(1)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 1,
|
||||
.allow0 = true,
|
||||
.allow1 = true,
|
||||
.allow2 = true,
|
||||
.at = Timestamp(2)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||
IntPacket(900, Timestamp(2))));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 2,
|
||||
.allow0 = true,
|
||||
.allow1 = true,
|
||||
.allow2 = true,
|
||||
.at = Timestamp(3)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
|
||||
IntPacket(900, Timestamp(2)),
|
||||
IntPacket(800, Timestamp(3))));
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest, MuxingAllTimestampBoundUpdates) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 0,
|
||||
.allow0 = false,
|
||||
.allow1 = false,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(1)},
|
||||
graph));
|
||||
EXPECT_TRUE(output_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 1,
|
||||
.allow0 = false,
|
||||
.allow1 = false,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(2)},
|
||||
graph));
|
||||
EXPECT_TRUE(output_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 2,
|
||||
.allow0 = false,
|
||||
.allow1 = false,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(3)},
|
||||
graph));
|
||||
EXPECT_TRUE(output_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest, MuxingSlectedTimestampBoundUpdates) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 0,
|
||||
.allow0 = false,
|
||||
.allow1 = true,
|
||||
.allow2 = true,
|
||||
.at = Timestamp(1)},
|
||||
graph));
|
||||
EXPECT_TRUE(output_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 1,
|
||||
.allow0 = true,
|
||||
.allow1 = false,
|
||||
.allow2 = true,
|
||||
.at = Timestamp(2)},
|
||||
graph));
|
||||
EXPECT_TRUE(output_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 2,
|
||||
.allow0 = true,
|
||||
.allow1 = true,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(3)},
|
||||
graph));
|
||||
EXPECT_TRUE(output_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest, MuxingSometimesTimestampBoundUpdates) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 0,
|
||||
.allow0 = false,
|
||||
.allow1 = false,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(1)},
|
||||
graph));
|
||||
EXPECT_TRUE(output_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 1,
|
||||
.allow0 = false,
|
||||
.allow1 = true,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(2)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 2,
|
||||
.allow0 = true,
|
||||
.allow1 = true,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(3)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 1000,
|
||||
.input1 = 900,
|
||||
.input2 = 800,
|
||||
.select = 2,
|
||||
.allow0 = true,
|
||||
.allow1 = true,
|
||||
.allow2 = true,
|
||||
.at = Timestamp(4)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
|
||||
IntPacket(800, Timestamp(4))));
|
||||
|
||||
MP_ASSERT_OK(SendInput({.input0 = 700,
|
||||
.input1 = 600,
|
||||
.input2 = 500,
|
||||
.select = 0,
|
||||
.allow0 = true,
|
||||
.allow1 = false,
|
||||
.allow2 = false,
|
||||
.at = Timestamp(5)},
|
||||
graph));
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
|
||||
IntPacket(800, Timestamp(4)),
|
||||
IntPacket(700, Timestamp(5))));
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
MATCHER_P(EmptyPacket, ts, "") {
|
||||
return arg.IsEmpty() && arg.Timestamp() == ts;
|
||||
}
|
||||
|
||||
MATCHER_P2(Pair, m1, m2, "") {
|
||||
const auto& p = arg.template Get<std::pair<Packet, Packet>>();
|
||||
return testing::Matches(m1)(p.first) && testing::Matches(m2)(p.second);
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest,
|
||||
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacket) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input0"
|
||||
input_stream: "input1"
|
||||
input_stream: "select"
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:input0"
|
||||
input_stream: "INPUT:1:input1"
|
||||
input_stream: "SELECT:select"
|
||||
output_stream: "OUTPUT:output"
|
||||
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||
}
|
||||
node {
|
||||
calculator: "MakePairCalculator"
|
||||
input_stream: "select"
|
||||
input_stream: "output"
|
||||
output_stream: "pair"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> pair_packets;
|
||||
tool::AddVectorSink("pair", &config, &pair_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(0).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_TRUE(pair_packets.empty());
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(1000).At(Timestamp(2))));
|
||||
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
|
||||
EmptyPacket(Timestamp(1)))));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(900).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(800).At(Timestamp(4))));
|
||||
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
|
||||
EmptyPacket(Timestamp(1)))));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(0).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
pair_packets,
|
||||
ElementsAre(
|
||||
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2)))));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(1).At(Timestamp(3))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
pair_packets,
|
||||
ElementsAre(
|
||||
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
|
||||
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3)))));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(1).At(Timestamp(4))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
pair_packets,
|
||||
ElementsAre(
|
||||
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
|
||||
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
|
||||
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest,
|
||||
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacketPacketsAtOnce) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input0"
|
||||
input_stream: "input1"
|
||||
input_stream: "select"
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:input0"
|
||||
input_stream: "INPUT:1:input1"
|
||||
input_stream: "SELECT:select"
|
||||
output_stream: "OUTPUT:output"
|
||||
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||
}
|
||||
node {
|
||||
calculator: "MakePairCalculator"
|
||||
input_stream: "select"
|
||||
input_stream: "output"
|
||||
output_stream: "pair"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> pair_packets;
|
||||
tool::AddVectorSink("pair", &config, &pair_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(0).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(1000).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(900).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(800).At(Timestamp(4))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(0).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(1).At(Timestamp(3))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(1).At(Timestamp(4))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
pair_packets,
|
||||
ElementsAre(
|
||||
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
|
||||
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
|
||||
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest,
|
||||
TimestampBoundUpdateTriggersTimestampBoundUpdate) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input0"
|
||||
input_stream: "input1"
|
||||
input_stream: "select"
|
||||
input_stream: "allow0"
|
||||
input_stream: "allow1"
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "ALLOW:allow0"
|
||||
input_stream: "input0"
|
||||
output_stream: "output0"
|
||||
}
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "ALLOW:allow1"
|
||||
input_stream: "input1"
|
||||
output_stream: "output1"
|
||||
}
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:output0"
|
||||
input_stream: "INPUT:1:output1"
|
||||
input_stream: "SELECT:select"
|
||||
output_stream: "OUTPUT:output"
|
||||
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||
}
|
||||
node {
|
||||
calculator: "MakePairCalculator"
|
||||
input_stream: "select"
|
||||
input_stream: "output"
|
||||
output_stream: "pair"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> pair_packets;
|
||||
tool::AddVectorSink("pair", &config, &pair_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(0).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(1000).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"allow0", MakePacket<bool>(false).At(Timestamp(1))));
|
||||
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
|
||||
EmptyPacket(Timestamp(1)))));
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(0).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(900).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"allow0", MakePacket<bool>(true).At(Timestamp(2))));
|
||||
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
pair_packets,
|
||||
ElementsAre(
|
||||
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
|
||||
Pair(IntPacket(0, Timestamp(2)), IntPacket(900, Timestamp(2)))));
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -56,7 +56,9 @@ class SubgraphContext {
|
|||
return options_map_.Get<T>();
|
||||
}
|
||||
|
||||
const CalculatorGraphConfig::Node& OriginalNode() { return original_node_; }
|
||||
const CalculatorGraphConfig::Node& OriginalNode() const {
|
||||
return original_node_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ServiceBinding<T> Service(const GraphService<T>& service) const {
|
||||
|
|
|
@ -724,6 +724,7 @@ cc_test(
|
|||
srcs = ["subgraph_expansion_test.cc"],
|
||||
deps = [
|
||||
":node_chain_subgraph_cc_proto",
|
||||
":node_chain_subgraph_options_lib",
|
||||
":subgraph_expansion",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/type_map.h"
|
||||
|
||||
#define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging()
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tool {
|
||||
|
||||
|
@ -47,13 +49,13 @@ absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in,
|
|||
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
|
||||
if (IsLengthDelimited(wire_type)) {
|
||||
uint32 length;
|
||||
RET_CHECK(in->ReadVarint32(&length));
|
||||
RET_CHECK(in->ReadString(result, length));
|
||||
RET_CHECK_NO_LOG(in->ReadVarint32(&length));
|
||||
RET_CHECK_NO_LOG(in->ReadString(result, length));
|
||||
} else {
|
||||
std::string field_data;
|
||||
StringOutputStream sos(&field_data);
|
||||
CodedOutputStream cos(&sos);
|
||||
RET_CHECK(WireFormatLite::SkipField(in, tag, &cos));
|
||||
RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, &cos));
|
||||
// Skip the tag written by SkipField.
|
||||
int tag_size = CodedOutputStream::VarintSize32(tag);
|
||||
cos.Trim();
|
||||
|
@ -67,13 +69,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
|
|||
CodedInputStream* in,
|
||||
std::vector<std::string>* field_values) {
|
||||
uint32 data_size;
|
||||
RET_CHECK(in->ReadVarint32(&data_size));
|
||||
RET_CHECK_NO_LOG(in->ReadVarint32(&data_size));
|
||||
// fake_tag encodes the wire-type for calls to WireFormatLite::SkipField.
|
||||
uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type);
|
||||
while (data_size > 0) {
|
||||
std::string number;
|
||||
MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number));
|
||||
RET_CHECK_LE(number.size(), data_size);
|
||||
RET_CHECK_NO_LOG(number.size() <= data_size);
|
||||
field_values->push_back(number);
|
||||
data_size -= number.size();
|
||||
}
|
||||
|
@ -98,7 +100,7 @@ absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
|
|||
field_values->push_back(value);
|
||||
}
|
||||
} else {
|
||||
RET_CHECK(WireFormatLite::SkipField(in, tag, out));
|
||||
RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, out));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -157,12 +159,12 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
|
|||
MP_RETURN_IF_ERROR(access.SetMessage(*message));
|
||||
std::vector<std::string>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK(index >= 0 && index < v.size());
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
|
||||
field_type, field_values));
|
||||
} else {
|
||||
RET_CHECK(index >= 0 && index <= v.size());
|
||||
RET_CHECK(index + length >= 0 && index + length <= v.size());
|
||||
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
|
||||
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
|
||||
v.erase(v.begin() + index, v.begin() + index + length);
|
||||
v.insert(v.begin() + index, field_values.begin(), field_values.end());
|
||||
}
|
||||
|
@ -184,12 +186,12 @@ absl::Status ProtoUtilLite::GetFieldRange(
|
|||
MP_RETURN_IF_ERROR(access.SetMessage(message));
|
||||
std::vector<std::string>& v = *access.mutable_field_values();
|
||||
if (!proto_path.empty()) {
|
||||
RET_CHECK(index >= 0 && index < v.size());
|
||||
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
|
||||
MP_RETURN_IF_ERROR(
|
||||
GetFieldRange(v[index], proto_path, length, field_type, field_values));
|
||||
} else {
|
||||
RET_CHECK(index >= 0 && index <= v.size());
|
||||
RET_CHECK(index + length >= 0 && index + length <= v.size());
|
||||
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
|
||||
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
|
||||
field_values->insert(field_values->begin(), v.begin() + index,
|
||||
v.begin() + index + length);
|
||||
}
|
||||
|
|
|
@ -274,12 +274,14 @@ absl::Status ConnectSubgraphStreams(
|
|||
|
||||
absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
|
||||
const GraphRegistry* graph_registry,
|
||||
const Subgraph::SubgraphOptions* graph_options,
|
||||
const GraphServiceManager* service_manager) {
|
||||
graph_registry =
|
||||
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
|
||||
RET_CHECK(config);
|
||||
|
||||
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(
|
||||
CalculatorGraphConfig::Node(), config));
|
||||
graph_options ? *graph_options : CalculatorGraphConfig::Node(), config));
|
||||
auto* nodes = config->mutable_node();
|
||||
while (1) {
|
||||
auto subgraph_nodes_start = std::stable_partition(
|
||||
|
|
|
@ -72,6 +72,7 @@ absl::Status ConnectSubgraphStreams(
|
|||
absl::Status ExpandSubgraphs(
|
||||
CalculatorGraphConfig* config,
|
||||
const GraphRegistry* graph_registry = nullptr,
|
||||
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||
const GraphServiceManager* service_manager = nullptr);
|
||||
|
||||
// Creates a graph wrapping the provided node and exposing all of its
|
||||
|
|
|
@ -560,9 +560,111 @@ TEST(SubgraphExpansionTest, GraphServicesUsage) {
|
|||
MP_ASSERT_OK(service_manager.SetServiceObject(
|
||||
kStringTestService, std::make_shared<std::string>("ExpectedNode")));
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr,
|
||||
/*graph_options=*/nullptr,
|
||||
&service_manager));
|
||||
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
// Shows SubgraphOptions consumed by GraphRegistry::CreateByName.
|
||||
TEST(SubgraphExpansionTest, SubgraphOptionsUsage) {
|
||||
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
|
||||
GraphRegistry graph_registry;
|
||||
|
||||
// CalculatorGraph::Initialize passes the SubgraphOptions into:
|
||||
// (1) GraphRegistry::CreateByName("NodeChainSubgraph", options)
|
||||
// (2) tool::ExpandSubgraphs(&config, options)
|
||||
auto graph_options =
|
||||
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
|
||||
options {
|
||||
[mediapipe.NodeChainSubgraphOptions.ext] {
|
||||
node_type: "DoubleIntCalculator"
|
||||
chain_length: 3
|
||||
}
|
||||
})pb");
|
||||
SubgraphContext context(&graph_options, /*service_manager=*/nullptr);
|
||||
|
||||
// "NodeChainSubgraph" consumes graph_options only in CreateByName.
|
||||
auto subgraph_status =
|
||||
graph_registry.CreateByName("", "NodeChainSubgraph", &context);
|
||||
MP_ASSERT_OK(subgraph_status);
|
||||
auto subgraph = std::move(subgraph_status).value();
|
||||
MP_ASSERT_OK(
|
||||
tool::ExpandSubgraphs(&subgraph, &graph_registry, &graph_options));
|
||||
|
||||
CalculatorGraphConfig expected_graph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "DoubleIntCalculator"
|
||||
input_stream: "stream_0"
|
||||
output_stream: "stream_1"
|
||||
}
|
||||
node {
|
||||
calculator: "DoubleIntCalculator"
|
||||
input_stream: "stream_1"
|
||||
output_stream: "stream_2"
|
||||
}
|
||||
node {
|
||||
calculator: "DoubleIntCalculator"
|
||||
input_stream: "stream_2"
|
||||
output_stream: "stream_3"
|
||||
}
|
||||
input_stream: "INPUT:stream_0"
|
||||
output_stream: "OUTPUT:stream_3"
|
||||
)pb");
|
||||
|
||||
EXPECT_THAT(subgraph, mediapipe::EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
// Shows SubgraphOptions consumed by tool::ExpandSubgraphs.
|
||||
TEST(SubgraphExpansionTest, SimpleSubgraphOptionsUsage) {
|
||||
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
|
||||
GraphRegistry graph_registry;
|
||||
auto moon_options =
|
||||
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
|
||||
options {
|
||||
[mediapipe.NodeChainSubgraphOptions.ext] {
|
||||
node_type: "DoubleIntCalculator"
|
||||
chain_length: 3
|
||||
}
|
||||
})pb");
|
||||
auto moon_subgraph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
type: "MoonSubgraph"
|
||||
graph_options: {
|
||||
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
|
||||
}
|
||||
node: {
|
||||
calculator: "MoonCalculator"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
|
||||
}
|
||||
option_value: "chain_length:options/chain_length"
|
||||
}
|
||||
)pb");
|
||||
|
||||
// The moon_options are copied into the graph_options of moon_subgraph.
|
||||
MP_ASSERT_OK(
|
||||
tool::ExpandSubgraphs(&moon_subgraph, &graph_registry, &moon_options));
|
||||
|
||||
// The field chain_length is copied from moon_options into MoonCalculator.
|
||||
CalculatorGraphConfig expected_graph =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "MoonCalculator"
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {
|
||||
chain_length: 3
|
||||
}
|
||||
}
|
||||
option_value: "chain_length:options/chain_length"
|
||||
}
|
||||
type: "MoonSubgraph"
|
||||
graph_options {
|
||||
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
|
||||
}
|
||||
)pb");
|
||||
EXPECT_THAT(moon_subgraph, mediapipe::EqualsProto(expected_graph));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -150,12 +150,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
|
|||
const int packet_count = 10;
|
||||
// Send int value packets at {10K, 20K, 30K, ..., 100K}.
|
||||
for (uint64 t = 1; t <= packet_count; ++t) {
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
||||
if (send_bounds) {
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"enable", MakePacket<bool>(true).At(Timestamp(t * 10000))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
}
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
// The inputs are sent to the input stream "foo", they should pass through.
|
||||
EXPECT_EQ(out_foo.size(), t);
|
||||
|
@ -175,12 +176,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
|
|||
|
||||
// Send int value packets at {110K, 120K, ..., 200K}.
|
||||
for (uint64 t = 11; t <= packet_count * 2; ++t) {
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
||||
if (send_bounds) {
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"enable", MakePacket<bool>(false).At(Timestamp(t * 10000))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
}
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
// The inputs are sent to the input stream "foo", they should pass through.
|
||||
EXPECT_EQ(out_foo.size(), t);
|
||||
|
|
|
@ -143,11 +143,12 @@ absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) {
|
|||
absl::Status PerformBasicTransforms(
|
||||
const CalculatorGraphConfig& input_graph_config,
|
||||
const GraphRegistry* graph_registry,
|
||||
const Subgraph::SubgraphOptions* graph_options,
|
||||
const GraphServiceManager* service_manager,
|
||||
CalculatorGraphConfig* output_graph_config) {
|
||||
*output_graph_config = input_graph_config;
|
||||
MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
|
||||
service_manager));
|
||||
graph_options, service_manager));
|
||||
|
||||
MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));
|
||||
|
||||
|
@ -347,6 +348,7 @@ absl::Status NodeTypeInfo::Initialize(
|
|||
absl::Status ValidatedGraphConfig::Initialize(
|
||||
const CalculatorGraphConfig& input_config,
|
||||
const GraphRegistry* graph_registry,
|
||||
const Subgraph::SubgraphOptions* graph_options,
|
||||
const GraphServiceManager* service_manager) {
|
||||
RET_CHECK(!initialized_)
|
||||
<< "ValidatedGraphConfig can be initialized only once.";
|
||||
|
@ -356,8 +358,8 @@ absl::Status ValidatedGraphConfig::Initialize(
|
|||
<< input_config.DebugString();
|
||||
#endif
|
||||
|
||||
MP_RETURN_IF_ERROR(PerformBasicTransforms(input_config, graph_registry,
|
||||
service_manager, &config_));
|
||||
MP_RETURN_IF_ERROR(PerformBasicTransforms(
|
||||
input_config, graph_registry, graph_options, service_manager, &config_));
|
||||
|
||||
// Initialize the basic node information.
|
||||
MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
|
||||
|
@ -431,22 +433,24 @@ absl::Status ValidatedGraphConfig::Initialize(
|
|||
}
|
||||
|
||||
absl::Status ValidatedGraphConfig::Initialize(
|
||||
const std::string& graph_type, const Subgraph::SubgraphOptions* options,
|
||||
const GraphRegistry* graph_registry,
|
||||
const std::string& graph_type, const GraphRegistry* graph_registry,
|
||||
const Subgraph::SubgraphOptions* graph_options,
|
||||
const GraphServiceManager* service_manager) {
|
||||
graph_registry =
|
||||
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
|
||||
SubgraphContext subgraph_context(options, service_manager);
|
||||
SubgraphContext subgraph_context(graph_options, service_manager);
|
||||
auto status_or_config =
|
||||
graph_registry->CreateByName("", graph_type, &subgraph_context);
|
||||
MP_RETURN_IF_ERROR(status_or_config.status());
|
||||
return Initialize(status_or_config.value(), graph_registry, service_manager);
|
||||
return Initialize(status_or_config.value(), graph_registry, graph_options,
|
||||
service_manager);
|
||||
}
|
||||
|
||||
absl::Status ValidatedGraphConfig::Initialize(
|
||||
const std::vector<CalculatorGraphConfig>& input_configs,
|
||||
const std::vector<CalculatorGraphTemplate>& input_templates,
|
||||
const std::string& graph_type, const Subgraph::SubgraphOptions* arguments,
|
||||
const std::string& graph_type,
|
||||
const Subgraph::SubgraphOptions* graph_options,
|
||||
const GraphServiceManager* service_manager) {
|
||||
GraphRegistry graph_registry;
|
||||
for (auto& config : input_configs) {
|
||||
|
@ -455,7 +459,8 @@ absl::Status ValidatedGraphConfig::Initialize(
|
|||
for (auto& templ : input_templates) {
|
||||
graph_registry.Register(templ.config().type(), templ);
|
||||
}
|
||||
return Initialize(graph_type, arguments, &graph_registry, service_manager);
|
||||
return Initialize(graph_type, &graph_registry, graph_options,
|
||||
service_manager);
|
||||
}
|
||||
|
||||
absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {
|
||||
|
|
|
@ -195,18 +195,21 @@ class ValidatedGraphConfig {
|
|||
// Initializes the ValidatedGraphConfig. This function must be called
|
||||
// before any other functions. Subgraphs are specified through the
|
||||
// global graph registry or an optional local graph registry.
|
||||
absl::Status Initialize(const CalculatorGraphConfig& input_config,
|
||||
const GraphRegistry* graph_registry = nullptr,
|
||||
const GraphServiceManager* service_manager = nullptr);
|
||||
absl::Status Initialize(
|
||||
const CalculatorGraphConfig& input_config,
|
||||
const GraphRegistry* graph_registry = nullptr,
|
||||
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||
const GraphServiceManager* service_manager = nullptr);
|
||||
|
||||
// Initializes the ValidatedGraphConfig from registered graph and subgraph
|
||||
// configs. Subgraphs are retrieved from the specified graph registry or from
|
||||
// the global graph registry. A subgraph can be instantiated directly by
|
||||
// specifying its type in |graph_type|.
|
||||
absl::Status Initialize(const std::string& graph_type,
|
||||
const Subgraph::SubgraphOptions* options = nullptr,
|
||||
const GraphRegistry* graph_registry = nullptr,
|
||||
const GraphServiceManager* service_manager = nullptr);
|
||||
absl::Status Initialize(
|
||||
const std::string& graph_type,
|
||||
const GraphRegistry* graph_registry = nullptr,
|
||||
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||
const GraphServiceManager* service_manager = nullptr);
|
||||
|
||||
// Initializes the ValidatedGraphConfig from the specified graph and subgraph
|
||||
// configs. Template graph and subgraph configs can be specified through
|
||||
|
@ -218,7 +221,7 @@ class ValidatedGraphConfig {
|
|||
const std::vector<CalculatorGraphConfig>& input_configs,
|
||||
const std::vector<CalculatorGraphTemplate>& input_templates,
|
||||
const std::string& graph_type = "",
|
||||
const Subgraph::SubgraphOptions* arguments = nullptr,
|
||||
const Subgraph::SubgraphOptions* graph_options = nullptr,
|
||||
const GraphServiceManager* service_manager = nullptr);
|
||||
|
||||
// Returns true if the ValidatedGraphConfig has been initialized.
|
||||
|
|
|
@ -155,6 +155,7 @@ TEST(ValidatedGraphConfigTest, InitializeSubgraphWithServiceCalculatorB) {
|
|||
kStringTestService, std::make_shared<std::string>(calculator_name)));
|
||||
MP_EXPECT_OK(config.Initialize(graph,
|
||||
/*graph_registry=*/nullptr,
|
||||
/*subgraph_options=*/nullptr,
|
||||
/*service_manager=*/&service_manager));
|
||||
ASSERT_TRUE(config.Initialized());
|
||||
EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfigExpandedFromGraph(
|
||||
|
|
|
@ -196,7 +196,9 @@ cc_library(
|
|||
deps = [
|
||||
":gl_base",
|
||||
":gl_context",
|
||||
":gl_texture_view",
|
||||
":gpu_buffer_format",
|
||||
":gpu_buffer_storage",
|
||||
# TODO: remove this dependency. Some other teams' tests
|
||||
# depend on having an indirect image_frame dependency, need to be
|
||||
# fixed first.
|
||||
|
@ -205,6 +207,17 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gl_texture_view",
|
||||
srcs = ["gl_texture_view.cc"],
|
||||
hdrs = ["gl_texture_view.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gl_base",
|
||||
":gl_context",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_buffer",
|
||||
srcs = ["gpu_buffer.cc"],
|
||||
|
@ -214,12 +227,15 @@ cc_library(
|
|||
":gl_base",
|
||||
":gl_context",
|
||||
":gpu_buffer_format",
|
||||
":gpu_buffer_storage",
|
||||
":gl_texture_view",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
":gl_texture_buffer",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
":gpu_buffer_storage_cv_pixel_buffer",
|
||||
"//mediapipe/objc:util",
|
||||
"//mediapipe/objc:CFHolder",
|
||||
],
|
||||
|
@ -244,6 +260,35 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_buffer_storage",
|
||||
hdrs = ["gpu_buffer_storage.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gl_base",
|
||||
":gpu_buffer_format",
|
||||
"//mediapipe/framework/deps:no_destructor",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_buffer_storage_cv_pixel_buffer",
|
||||
srcs = ["gpu_buffer_storage_cv_pixel_buffer.cc"],
|
||||
hdrs = ["gpu_buffer_storage_cv_pixel_buffer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gl_base",
|
||||
":gl_context",
|
||||
":gl_texture_view",
|
||||
":gpu_buffer_storage",
|
||||
"//mediapipe/objc:CFHolder",
|
||||
"//mediapipe/objc:util",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "gpu_origin_proto",
|
||||
srcs = ["gpu_origin.proto"],
|
||||
|
|
|
@ -109,12 +109,10 @@ GlTexture GlCalculatorHelper::CreateSourceTexture(
|
|||
return impl_->CreateSourceTexture(image_frame);
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer,
|
||||
int plane) {
|
||||
return impl_->CreateSourceTexture(pixel_buffer, plane);
|
||||
}
|
||||
#endif
|
||||
|
||||
void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer,
|
||||
int* width, int* height) {
|
||||
|
|
|
@ -29,10 +29,6 @@
|
|||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "mediapipe/gpu/graph_support.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <CoreVideo/CoreVideo.h>
|
||||
#endif // __APPLE__
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class GlCalculatorHelperImpl;
|
||||
|
@ -111,24 +107,35 @@ class GlCalculatorHelper {
|
|||
// where it is supported (iOS, for now) they take advantage of memory sharing
|
||||
// between the CPU and GPU, avoiding memory copies.
|
||||
|
||||
// Creates a texture representing an input frame, and manages sync token.
|
||||
// Gives access to an input frame as an OpenGL texture for reading (sampling).
|
||||
//
|
||||
// IMPORTANT: the returned GlTexture should be treated as a short-term view
|
||||
// into the frame (typically for the duration of a Process call). Do not store
|
||||
// it as a member in your calculator. If you need to keep a frame around,
|
||||
// store the GpuBuffer instead, and call CreateSourceTexture again on each
|
||||
// Process call.
|
||||
//
|
||||
// TODO: rename this; the use of "Create" makes this sound more expensive than
|
||||
// it is.
|
||||
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer);
|
||||
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
|
||||
GlTexture CreateSourceTexture(const mediapipe::Image& image);
|
||||
|
||||
#ifdef __APPLE__
|
||||
// Creates a texture from a plane of a planar buffer.
|
||||
// Gives read access to a plane of a planar buffer.
|
||||
// The plane index is zero-based. The number of planes depends on the
|
||||
// internal format of the buffer.
|
||||
// Note: multi-plane support is not available on all platforms.
|
||||
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer, int plane);
|
||||
#endif
|
||||
|
||||
// Convenience function for converting an ImageFrame to GpuBuffer and then
|
||||
// accessing it as a texture.
|
||||
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
|
||||
|
||||
// Extracts GpuBuffer dimensions without creating a texture.
|
||||
ABSL_DEPRECATED("Use width and height methods on GpuBuffer instead")
|
||||
void GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width,
|
||||
int* height);
|
||||
|
||||
// Creates a texture representing an output frame, and manages sync token.
|
||||
// Gives access to an OpenGL texture for writing (rendering) a new frame.
|
||||
// TODO: This should either return errors or a status.
|
||||
GlTexture CreateDestinationTexture(
|
||||
int output_width, int output_height,
|
||||
|
|
|
@ -62,10 +62,7 @@ class GlCalculatorHelperImpl {
|
|||
|
||||
private:
|
||||
// Makes a GpuBuffer accessible as a texture in the GL context.
|
||||
GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, int plane,
|
||||
bool for_reading);
|
||||
void AttachGlTexture(GlTexture& texture, const GpuBuffer& gpu_buffer,
|
||||
int plane, bool for_reading);
|
||||
GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view);
|
||||
|
||||
// Create the framebuffer for rendering.
|
||||
void CreateFramebuffer();
|
||||
|
|
|
@ -91,9 +91,7 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) {
|
|||
}
|
||||
|
||||
GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
|
||||
int plane, bool for_reading) {
|
||||
GlTextureView view = gpu_buffer.GetGlTextureView(plane, for_reading);
|
||||
|
||||
GlTextureView view) {
|
||||
if (gpu_buffer.format() != GpuBufferFormat::kUnknown) {
|
||||
// TODO: do the params need to be reset here??
|
||||
glBindTexture(view.target(), view.name());
|
||||
|
@ -109,19 +107,18 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
|
|||
|
||||
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
||||
const GpuBuffer& gpu_buffer) {
|
||||
return MapGpuBuffer(gpu_buffer, 0, true);
|
||||
return CreateSourceTexture(gpu_buffer, 0);
|
||||
}
|
||||
|
||||
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
||||
const GpuBuffer& gpu_buffer, int plane) {
|
||||
return MapGpuBuffer(gpu_buffer, plane, true);
|
||||
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(plane));
|
||||
}
|
||||
|
||||
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
|
||||
const ImageFrame& image_frame) {
|
||||
GlTexture texture =
|
||||
MapGpuBuffer(GpuBuffer::CopyingImageFrame(image_frame), 0, true);
|
||||
return texture;
|
||||
auto gpu_buffer = GpuBuffer::CopyingImageFrame(image_frame);
|
||||
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(0));
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -150,11 +147,9 @@ GlTexture GlCalculatorHelperImpl::CreateDestinationTexture(
|
|||
CreateFramebuffer();
|
||||
}
|
||||
|
||||
GpuBuffer buffer =
|
||||
GpuBuffer gpu_buffer =
|
||||
gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format);
|
||||
GlTexture texture = MapGpuBuffer(buffer, 0, false);
|
||||
|
||||
return texture;
|
||||
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureWriteView(0));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
|
||||
#include "mediapipe/gpu/gl_texture_buffer.h"
|
||||
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/gpu/gl_texture_view.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Wrap(
|
||||
|
@ -122,6 +125,15 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
|
|||
|
||||
if (alignment != 4 && data) glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
|
||||
|
||||
// TODO: does this need to set the texture params? We set them again when the
|
||||
// texture is actually acccessed via GlTexture[View]. Or should they always be
|
||||
// set on creation?
|
||||
if (format_ != GpuBufferFormat::kUnknown) {
|
||||
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||
format_, /*plane=*/0, context->GetGlVersion());
|
||||
context->SetStandardTextureParams(target_, info.gl_internal_format);
|
||||
}
|
||||
|
||||
glBindTexture(target_, 0);
|
||||
|
||||
// Use the deletion callback to delete the texture on the context
|
||||
|
@ -167,7 +179,7 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
|
|||
producer_context_ = producer_sync_->GetContext();
|
||||
}
|
||||
|
||||
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) {
|
||||
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
|
||||
absl::MutexLock lock(&consumer_sync_mutex_);
|
||||
consumer_multi_sync_->Add(std::move(cons_token));
|
||||
}
|
||||
|
@ -181,7 +193,7 @@ GlTextureBuffer::~GlTextureBuffer() {
|
|||
}
|
||||
}
|
||||
|
||||
void GlTextureBuffer::WaitUntilComplete() {
|
||||
void GlTextureBuffer::WaitUntilComplete() const {
|
||||
// Buffers created by the application (using the constructor that wraps an
|
||||
// existing texture) have no sync token and are assumed to be already
|
||||
// complete.
|
||||
|
@ -190,7 +202,7 @@ void GlTextureBuffer::WaitUntilComplete() {
|
|||
}
|
||||
}
|
||||
|
||||
void GlTextureBuffer::WaitOnGpu() {
|
||||
void GlTextureBuffer::WaitOnGpu() const {
|
||||
// Buffers created by the application (using the constructor that wraps an
|
||||
// existing texture) have no sync token and are assumed to be already
|
||||
// complete.
|
||||
|
@ -212,4 +224,127 @@ void GlTextureBuffer::WaitForConsumersOnGpu() {
|
|||
// precisely, on only one GL context.
|
||||
}
|
||||
|
||||
GlTextureView GlTextureBuffer::GetGlTextureReadView(
|
||||
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
|
||||
auto gl_context = GlContext::GetCurrent();
|
||||
CHECK(gl_context);
|
||||
CHECK_EQ(plane, 0);
|
||||
// Insert wait call to sync with the producer.
|
||||
WaitOnGpu();
|
||||
GlTextureView::DetachFn detach = [this](mediapipe::GlTextureView& texture) {
|
||||
// Inform the GlTextureBuffer that we have finished accessing its
|
||||
// contents, and create a consumer sync point.
|
||||
DidRead(texture.gl_context()->CreateSyncToken());
|
||||
};
|
||||
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
|
||||
std::move(gpu_buffer), plane, std::move(detach),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
GlTextureView GlTextureBuffer::GetGlTextureWriteView(
|
||||
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
|
||||
auto gl_context = GlContext::GetCurrent();
|
||||
CHECK(gl_context);
|
||||
CHECK_EQ(plane, 0);
|
||||
// Insert wait call to sync with the producer.
|
||||
WaitOnGpu();
|
||||
Reuse(); // TODO: the producer wait should probably be part of Reuse in the
|
||||
// case when there are no consumers.
|
||||
GlTextureView::DoneWritingFn done_writing =
|
||||
[this](const mediapipe::GlTextureView& texture) {
|
||||
ViewDoneWriting(texture);
|
||||
};
|
||||
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
|
||||
std::move(gpu_buffer), plane, nullptr,
|
||||
std::move(done_writing));
|
||||
}
|
||||
|
||||
void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) {
|
||||
// Inform the GlTextureBuffer that we have produced new content, and create
|
||||
// a producer sync point.
|
||||
Updated(view.gl_context()->CreateSyncToken());
|
||||
|
||||
#ifdef __ANDROID__
|
||||
// On (some?) Android devices, the texture may need to be explicitly
|
||||
// detached from the current framebuffer.
|
||||
// TODO: is this necessary even with the unbind in BindFramebuffer?
|
||||
// It is not clear if this affected other contexts too, but let's keep it
|
||||
// while in doubt.
|
||||
GLint type = GL_NONE;
|
||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
|
||||
&type);
|
||||
if (type == GL_TEXTURE) {
|
||||
GLint color_attachment = 0;
|
||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
||||
&color_attachment);
|
||||
if (color_attachment == name()) {
|
||||
glBindFramebuffer(GL_FRAMEBUFFER, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Some Android drivers log a GL_INVALID_ENUM error after the first
|
||||
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
|
||||
// even though it should be ok to ask for the type and get back GL_NONE.
|
||||
// Let's just ignore any pending errors here.
|
||||
GLenum error;
|
||||
while ((error = glGetError()) != GL_NO_ERROR) {
|
||||
}
|
||||
|
||||
#endif // __ANDROID__
|
||||
}
|
||||
|
||||
static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
|
||||
void* output, size_t size) {
|
||||
// TODO: check buffer size? We could use glReadnPixels where available
|
||||
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
|
||||
// won't overflow the buffer with glReadPixels, we'd also need to check or
|
||||
// reset several glPixelStore parameters (e.g. what if someone had the
|
||||
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
|
||||
CHECK(view.gl_context());
|
||||
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||
format, view.plane(), view.gl_context()->GetGlVersion());
|
||||
|
||||
GLint current_fbo;
|
||||
glGetIntegerv(GL_FRAMEBUFFER_BINDING, ¤t_fbo);
|
||||
CHECK_NE(current_fbo, 0);
|
||||
|
||||
GLint color_attachment_name;
|
||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
||||
&color_attachment_name);
|
||||
if (color_attachment_name != view.name()) {
|
||||
// Save the viewport. Note that we assume that the color attachment is a
|
||||
// GL_TEXTURE_2D texture.
|
||||
GLint viewport[4];
|
||||
glGetIntegerv(GL_VIEWPORT, viewport);
|
||||
|
||||
// Set the data from GLTextureView object.
|
||||
glViewport(0, 0, view.width(), view.height());
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
|
||||
view.name(), 0);
|
||||
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
||||
info.gl_type, output);
|
||||
|
||||
// Restore from the saved viewport and color attachment name.
|
||||
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
|
||||
color_attachment_name, 0);
|
||||
} else {
|
||||
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
||||
info.gl_type, output);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<ImageFrame> GlTextureBuffer::AsImageFrame() const {
|
||||
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
|
||||
auto output = absl::make_unique<ImageFrame>(
|
||||
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
|
||||
auto view = GetGlTextureReadView(nullptr, 0);
|
||||
ReadTexture(view, format(), output->MutablePixelData(),
|
||||
output->PixelDataSize());
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -25,13 +25,14 @@
|
|||
#include "mediapipe/gpu/gl_base.h"
|
||||
#include "mediapipe/gpu/gl_context.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_storage.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class GlCalculatorHelperImpl;
|
||||
|
||||
// Implements a GPU memory buffer as an OpenGL texture. For internal use.
|
||||
class GlTextureBuffer {
|
||||
class GlTextureBuffer : public mediapipe::internal::GpuBufferStorage {
|
||||
public:
|
||||
// This is called when the texture buffer is deleted. It is passed a sync
|
||||
// token created at that time on the GlContext. If the GlTextureBuffer has
|
||||
|
@ -85,6 +86,13 @@ class GlTextureBuffer {
|
|||
int height() const { return height_; }
|
||||
GpuBufferFormat format() const { return format_; }
|
||||
|
||||
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||
int plane) const override;
|
||||
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||
int plane) override;
|
||||
void ViewDoneWriting(const GlTextureView& view) override;
|
||||
std::unique_ptr<ImageFrame> AsImageFrame() const override;
|
||||
|
||||
// If this texture is going to be used outside of the context that produced
|
||||
// it, this method should be called to ensure that its updated contents are
|
||||
// available. When this method returns, all changed made before the call to
|
||||
|
@ -94,13 +102,13 @@ class GlTextureBuffer {
|
|||
// NOTE: This blocks the current CPU thread and makes the changes visible
|
||||
// to the CPU. If you want to access the data via OpenGL, use WaitOnGpu
|
||||
// instead.
|
||||
void WaitUntilComplete();
|
||||
void WaitUntilComplete() const;
|
||||
|
||||
// Call this method to synchronize the current GL context with the texture's
|
||||
// producer. This will not block the current CPU thread, but will ensure that
|
||||
// subsequent GL commands see the texture in its complete status, with all
|
||||
// rendering done on the GPU by the generating context.
|
||||
void WaitOnGpu();
|
||||
void WaitOnGpu() const;
|
||||
|
||||
// Informs the buffer that its contents are going to be overwritten.
|
||||
// This invalidates the current sync token.
|
||||
|
@ -114,7 +122,7 @@ class GlTextureBuffer {
|
|||
void Updated(std::shared_ptr<GlSyncPoint> prod_token);
|
||||
|
||||
// Informs the buffer that a consumer has finished reading from it.
|
||||
void DidRead(std::shared_ptr<GlSyncPoint> cons_token);
|
||||
void DidRead(std::shared_ptr<GlSyncPoint> cons_token) const;
|
||||
|
||||
// Waits for all pending consumers to finish accessing the current content
|
||||
// of the texture. This (preferably the OnGpu version) should be called
|
||||
|
@ -143,10 +151,11 @@ class GlTextureBuffer {
|
|||
const GLenum target_ = GL_TEXTURE_2D;
|
||||
// Token tracking changes to this texture. Used by WaitUntilComplete.
|
||||
std::shared_ptr<GlSyncPoint> producer_sync_;
|
||||
absl::Mutex consumer_sync_mutex_;
|
||||
mutable absl::Mutex consumer_sync_mutex_;
|
||||
// Tokens tracking the point when consumers finished using this texture.
|
||||
std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_ ABSL_GUARDED_BY(
|
||||
consumer_sync_mutex_) = absl::make_unique<GlMultiSyncPoint>();
|
||||
mutable std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_
|
||||
ABSL_GUARDED_BY(consumer_sync_mutex_) =
|
||||
absl::make_unique<GlMultiSyncPoint>();
|
||||
DeletionCallback deletion_callback_;
|
||||
std::shared_ptr<GlContext> producer_context_;
|
||||
};
|
||||
|
|
16
mediapipe/gpu/gl_texture_view.cc
Normal file
16
mediapipe/gpu/gl_texture_view.cc
Normal file
|
@ -0,0 +1,16 @@
|
|||
#include "mediapipe/gpu/gl_texture_view.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
void GlTextureView::Release() {
|
||||
if (detach_) detach_(*this);
|
||||
detach_ = nullptr;
|
||||
gl_context_ = nullptr;
|
||||
gpu_buffer_ = nullptr;
|
||||
plane_ = 0;
|
||||
name_ = 0;
|
||||
width_ = 0;
|
||||
height_ = 0;
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
86
mediapipe/gpu/gl_texture_view.h
Normal file
86
mediapipe/gpu/gl_texture_view.h
Normal file
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
|
||||
#define MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class GlContext;
|
||||
class GlTextureViewManager;
|
||||
class GpuBuffer;
|
||||
|
||||
class GlTextureView {
|
||||
public:
|
||||
GlTextureView() {}
|
||||
~GlTextureView() { Release(); }
|
||||
// TODO: make this class move-only.
|
||||
|
||||
GlContext* gl_context() const { return gl_context_; }
|
||||
int width() const { return width_; }
|
||||
int height() const { return height_; }
|
||||
GLenum target() const { return target_; }
|
||||
GLuint name() const { return name_; }
|
||||
const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; }
|
||||
int plane() const { return plane_; }
|
||||
|
||||
using DetachFn = std::function<void(GlTextureView&)>;
|
||||
using DoneWritingFn = std::function<void(const GlTextureView&)>;
|
||||
|
||||
private:
|
||||
friend class GpuBuffer;
|
||||
friend class GlTextureBuffer;
|
||||
friend class GpuBufferStorageCvPixelBuffer;
|
||||
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
|
||||
int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
|
||||
DetachFn detach, DoneWritingFn done_writing)
|
||||
: gl_context_(context),
|
||||
target_(target),
|
||||
name_(name),
|
||||
width_(width),
|
||||
height_(height),
|
||||
gpu_buffer_(std::move(gpu_buffer)),
|
||||
plane_(plane),
|
||||
detach_(std::move(detach)),
|
||||
done_writing_(std::move(done_writing)) {}
|
||||
|
||||
// TODO: remove this friend declaration.
|
||||
friend class GlTexture;
|
||||
void Release();
|
||||
// TODO: make this non-const.
|
||||
void DoneWriting() const {
|
||||
if (done_writing_) done_writing_(*this);
|
||||
}
|
||||
|
||||
GlContext* gl_context_ = nullptr;
|
||||
GLenum target_ = GL_TEXTURE_2D;
|
||||
GLuint name_ = 0;
|
||||
// Note: when scale is not 1, we still give the nominal size of the image.
|
||||
int width_ = 0;
|
||||
int height_ = 0;
|
||||
std::shared_ptr<GpuBuffer> gpu_buffer_; // using shared_ptr temporarily
|
||||
int plane_ = 0;
|
||||
DetachFn detach_;
|
||||
DoneWritingFn done_writing_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
|
|
@ -8,62 +8,7 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
void GlTextureView::Release() {
|
||||
if (detach_) detach_(*this);
|
||||
detach_ = nullptr;
|
||||
gl_context_ = nullptr;
|
||||
gpu_buffer_ = nullptr;
|
||||
plane_ = 0;
|
||||
name_ = 0;
|
||||
width_ = 0;
|
||||
height_ = 0;
|
||||
}
|
||||
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
#if TARGET_OS_OSX
|
||||
typedef CVOpenGLTextureRef CVTextureType;
|
||||
#else
|
||||
typedef CVOpenGLESTextureRef CVTextureType;
|
||||
#endif // TARGET_OS_OSX
|
||||
|
||||
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
|
||||
CVReturn err;
|
||||
auto gl_context = GlContext::GetCurrent();
|
||||
CHECK(gl_context);
|
||||
#if TARGET_OS_OSX
|
||||
CVTextureType cv_texture_temp;
|
||||
err = CVOpenGLTextureCacheCreateTextureFromImage(
|
||||
kCFAllocatorDefault, gl_context->cv_texture_cache(),
|
||||
GetCVPixelBufferRef(), NULL, &cv_texture_temp);
|
||||
CHECK(cv_texture_temp && !err)
|
||||
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
|
||||
CFHolder<CVTextureType> cv_texture;
|
||||
cv_texture.adopt(cv_texture_temp);
|
||||
return GlTextureView(
|
||||
gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
|
||||
CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane,
|
||||
[cv_texture](
|
||||
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
|
||||
#else
|
||||
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||
format(), plane, gl_context->GetGlVersion());
|
||||
CVTextureType cv_texture_temp;
|
||||
err = CVOpenGLESTextureCacheCreateTextureFromImage(
|
||||
kCFAllocatorDefault, gl_context->cv_texture_cache(),
|
||||
GetCVPixelBufferRef(), NULL, GL_TEXTURE_2D, info.gl_internal_format,
|
||||
width() / info.downscale, height() / info.downscale, info.gl_format,
|
||||
info.gl_type, plane, &cv_texture_temp);
|
||||
CHECK(cv_texture_temp && !err)
|
||||
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
|
||||
CFHolder<CVTextureType> cv_texture;
|
||||
cv_texture.adopt(cv_texture_temp);
|
||||
return GlTextureView(
|
||||
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
|
||||
CVOpenGLESTextureGetName(*cv_texture), width(), height(), *this, plane,
|
||||
[cv_texture](
|
||||
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
|
||||
#endif // TARGET_OS_OSX
|
||||
}
|
||||
|
||||
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
||||
auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame);
|
||||
|
@ -72,187 +17,11 @@ GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
|||
CHECK_OK(maybe_buffer.status());
|
||||
return GpuBuffer(std::move(maybe_buffer).value());
|
||||
}
|
||||
|
||||
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
|
||||
CHECK(GetCVPixelBufferRef());
|
||||
return CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef());
|
||||
}
|
||||
|
||||
void GlTextureView::DoneWriting() const {
|
||||
CHECK(gpu_buffer_);
|
||||
#if TARGET_IPHONE_SIMULATOR
|
||||
CVPixelBufferRef pixel_buffer = gpu_buffer_.GetCVPixelBufferRef();
|
||||
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
|
||||
CHECK(err == kCVReturnSuccess)
|
||||
<< "CVPixelBufferLockBaseAddress failed: " << err;
|
||||
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
|
||||
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
|
||||
uint8_t* pixel_ptr =
|
||||
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
|
||||
if (pixel_format == kCVPixelFormatType_32BGRA) {
|
||||
// TODO: restore previous framebuffer? Move this to helper so we
|
||||
// can use BindFramebuffer?
|
||||
glViewport(0, 0, width(), height());
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, target(),
|
||||
name(), 0);
|
||||
|
||||
size_t contiguous_bytes_per_row = width() * 4;
|
||||
if (bytes_per_row == contiguous_bytes_per_row) {
|
||||
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||
pixel_ptr);
|
||||
} else {
|
||||
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
|
||||
height());
|
||||
uint8_t* temp_ptr = contiguous_buffer.data();
|
||||
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||
temp_ptr);
|
||||
for (int i = 0; i < height(); ++i) {
|
||||
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
|
||||
temp_ptr += contiguous_bytes_per_row;
|
||||
pixel_ptr += bytes_per_row;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
|
||||
}
|
||||
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
|
||||
CHECK(err == kCVReturnSuccess)
|
||||
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
|
||||
#endif
|
||||
}
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
||||
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
|
||||
auto gl_context = GlContext::GetCurrent();
|
||||
CHECK(gl_context);
|
||||
const GlTextureBufferSharedPtr& texture_buffer =
|
||||
GetGlTextureBufferSharedPtr();
|
||||
// Insert wait call to sync with the producer.
|
||||
texture_buffer->WaitOnGpu();
|
||||
CHECK_EQ(plane, 0);
|
||||
GlTextureView::DetachFn detach;
|
||||
if (for_reading) {
|
||||
detach = [](mediapipe::GlTextureView& texture) {
|
||||
// Inform the GlTextureBuffer that we have finished accessing its
|
||||
// contents, and create a consumer sync point.
|
||||
texture.gpu_buffer().GetGlTextureBufferSharedPtr()->DidRead(
|
||||
texture.gl_context()->CreateSyncToken());
|
||||
};
|
||||
}
|
||||
return GlTextureView(gl_context.get(), texture_buffer->target(),
|
||||
texture_buffer->name(), width(), height(), *this, plane,
|
||||
std::move(detach));
|
||||
}
|
||||
|
||||
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
|
||||
auto gl_context = GlContext::GetCurrent();
|
||||
CHECK(gl_context);
|
||||
|
||||
auto buffer = GlTextureBuffer::Create(image_frame);
|
||||
|
||||
// TODO: does this need to set the texture params? We set them again when the
|
||||
// texture is actually acccessed via GlTexture[View]. Or should they always be
|
||||
// set on creation?
|
||||
if (buffer->format() != GpuBufferFormat::kUnknown) {
|
||||
glBindTexture(GL_TEXTURE_2D, buffer->name());
|
||||
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||
buffer->format(), /*plane=*/0, gl_context->GetGlVersion());
|
||||
gl_context->SetStandardTextureParams(buffer->target(),
|
||||
info.gl_internal_format);
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
}
|
||||
|
||||
return GpuBuffer(std::move(buffer));
|
||||
}
|
||||
|
||||
static void ReadTexture(const GlTextureView& view, void* output, size_t size) {
|
||||
// TODO: check buffer size? We could use glReadnPixels where available
|
||||
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
|
||||
// won't overflow the buffer with glReadPixels, we'd also need to check or
|
||||
// reset several glPixelStore parameters (e.g. what if someone had the
|
||||
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
|
||||
CHECK(view.gl_context());
|
||||
GlTextureInfo info =
|
||||
GlTextureInfoForGpuBufferFormat(view.gpu_buffer().format(), view.plane(),
|
||||
view.gl_context()->GetGlVersion());
|
||||
|
||||
GLint current_fbo;
|
||||
glGetIntegerv(GL_FRAMEBUFFER_BINDING, ¤t_fbo);
|
||||
CHECK_NE(current_fbo, 0);
|
||||
|
||||
GLint color_attachment_name;
|
||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
||||
&color_attachment_name);
|
||||
if (color_attachment_name != view.name()) {
|
||||
// Save the viewport. Note that we assume that the color attachment is a
|
||||
// GL_TEXTURE_2D texture.
|
||||
GLint viewport[4];
|
||||
glGetIntegerv(GL_VIEWPORT, viewport);
|
||||
|
||||
// Set the data from GLTextureView object.
|
||||
glViewport(0, 0, view.width(), view.height());
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
|
||||
view.name(), 0);
|
||||
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
||||
info.gl_type, output);
|
||||
|
||||
// Restore from the saved viewport and color attachment name.
|
||||
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
|
||||
color_attachment_name, 0);
|
||||
} else {
|
||||
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
|
||||
info.gl_type, output);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
|
||||
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
|
||||
auto output = absl::make_unique<ImageFrame>(
|
||||
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
|
||||
auto view = GetGlTextureView(0, true);
|
||||
ReadTexture(view, output->MutablePixelData(), output->PixelDataSize());
|
||||
return output;
|
||||
}
|
||||
|
||||
void GlTextureView::DoneWriting() const {
|
||||
CHECK(gpu_buffer_);
|
||||
// Inform the GlTextureBuffer that we have produced new content, and create
|
||||
// a producer sync point.
|
||||
gpu_buffer_.GetGlTextureBufferSharedPtr()->Updated(
|
||||
gl_context()->CreateSyncToken());
|
||||
|
||||
#ifdef __ANDROID__
|
||||
// On (some?) Android devices, the texture may need to be explicitly
|
||||
// detached from the current framebuffer.
|
||||
// TODO: is this necessary even with the unbind in BindFramebuffer?
|
||||
// It is not clear if this affected other contexts too, but let's keep it
|
||||
// while in doubt.
|
||||
GLint type = GL_NONE;
|
||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
|
||||
&type);
|
||||
if (type == GL_TEXTURE) {
|
||||
GLint color_attachment = 0;
|
||||
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
|
||||
&color_attachment);
|
||||
if (color_attachment == name()) {
|
||||
glBindFramebuffer(GL_FRAMEBUFFER, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Some Android drivers log a GL_INVALID_ENUM error after the first
|
||||
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
|
||||
// even though it should be ok to ask for the type and get back GL_NONE.
|
||||
// Let's just ignore any pending errors here.
|
||||
GLenum error;
|
||||
while ((error = glGetError()) != GL_NO_ERROR) {
|
||||
}
|
||||
|
||||
#endif // __ANDROID__
|
||||
return GpuBuffer(GlTextureBuffer::Create(image_frame));
|
||||
}
|
||||
|
||||
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
|
|
@ -19,7 +19,9 @@
|
|||
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
#include "mediapipe/gpu/gl_texture_view.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_storage.h"
|
||||
|
||||
#if defined(__APPLE__)
|
||||
#include <CoreVideo/CoreVideo.h>
|
||||
|
@ -27,6 +29,10 @@
|
|||
#include "mediapipe/objc/CFHolder.h"
|
||||
#endif // defined(__APPLE__)
|
||||
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
||||
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
#include "mediapipe/gpu/gl_texture_buffer.h"
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
@ -34,7 +40,6 @@
|
|||
namespace mediapipe {
|
||||
|
||||
class GlContext;
|
||||
class GlTextureView;
|
||||
|
||||
// This class wraps a platform-specific buffer of GPU data.
|
||||
// An instance of GpuBuffer acts as an opaque reference to the underlying
|
||||
|
@ -71,9 +76,9 @@ class GpuBuffer {
|
|||
}
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
||||
int width() const;
|
||||
int height() const;
|
||||
GpuBufferFormat format() const;
|
||||
int width() const { return current_storage().width(); }
|
||||
int height() const { return current_storage().height(); }
|
||||
GpuBufferFormat format() const { return current_storage().format(); }
|
||||
|
||||
// Converts to true iff valid.
|
||||
explicit operator bool() const { return operator!=(nullptr); }
|
||||
|
@ -88,8 +93,15 @@ class GpuBuffer {
|
|||
// Allow assignment from nullptr.
|
||||
GpuBuffer& operator=(std::nullptr_t other);
|
||||
|
||||
// TODO: split into read and write, remove const from write.
|
||||
GlTextureView GetGlTextureView(int plane, bool for_reading) const;
|
||||
GlTextureView GetGlTextureReadView(int plane) const {
|
||||
return current_storage().GetGlTextureReadView(
|
||||
std::make_shared<GpuBuffer>(*this), plane);
|
||||
}
|
||||
|
||||
GlTextureView GetGlTextureWriteView(int plane) {
|
||||
return current_storage().GetGlTextureWriteView(
|
||||
std::make_shared<GpuBuffer>(*this), plane);
|
||||
}
|
||||
|
||||
// Make a GpuBuffer copying the data from an ImageFrame.
|
||||
static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame);
|
||||
|
@ -99,114 +111,84 @@ class GpuBuffer {
|
|||
// In order to work correctly across platforms, callers should always treat
|
||||
// the returned ImageFrame as if it shares memory with the GpuBuffer, i.e.
|
||||
// treat it as immutable if the GpuBuffer must not be modified.
|
||||
std::unique_ptr<ImageFrame> AsImageFrame() const;
|
||||
std::unique_ptr<ImageFrame> AsImageFrame() const {
|
||||
return current_storage().AsImageFrame();
|
||||
}
|
||||
|
||||
private:
|
||||
class PlaceholderGpuBufferStorage
|
||||
: public mediapipe::internal::GpuBufferStorage {
|
||||
public:
|
||||
int width() const override { return 0; }
|
||||
int height() const override { return 0; }
|
||||
virtual GpuBufferFormat format() const override {
|
||||
return GpuBufferFormat::kUnknown;
|
||||
}
|
||||
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||
int plane) const override {
|
||||
return {};
|
||||
}
|
||||
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||
int plane) override {
|
||||
return {};
|
||||
}
|
||||
void ViewDoneWriting(const GlTextureView& view) override{};
|
||||
std::unique_ptr<ImageFrame> AsImageFrame() const override {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
mediapipe::internal::GpuBufferStorage& no_storage() const {
|
||||
static PlaceholderGpuBufferStorage placeholder;
|
||||
return placeholder;
|
||||
}
|
||||
|
||||
const mediapipe::internal::GpuBufferStorage& current_storage() const {
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
CFHolder<CVPixelBufferRef> pixel_buffer_;
|
||||
if (pixel_buffer_ != nullptr) return pixel_buffer_;
|
||||
#else
|
||||
if (texture_buffer_) return *texture_buffer_;
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
return no_storage();
|
||||
}
|
||||
|
||||
mediapipe::internal::GpuBufferStorage& current_storage() {
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
if (pixel_buffer_ != nullptr) return pixel_buffer_;
|
||||
#else
|
||||
if (texture_buffer_) return *texture_buffer_;
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
return no_storage();
|
||||
}
|
||||
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
GpuBufferStorageCvPixelBuffer pixel_buffer_;
|
||||
#else
|
||||
GlTextureBufferSharedPtr texture_buffer_;
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
};
|
||||
|
||||
class GlTextureView {
|
||||
public:
|
||||
GlTextureView() {}
|
||||
~GlTextureView() { Release(); }
|
||||
// TODO: make this class move-only.
|
||||
|
||||
GlContext* gl_context() const { return gl_context_; }
|
||||
int width() const { return width_; }
|
||||
int height() const { return height_; }
|
||||
GLenum target() const { return target_; }
|
||||
GLuint name() const { return name_; }
|
||||
const GpuBuffer& gpu_buffer() const { return gpu_buffer_; }
|
||||
int plane() const { return plane_; }
|
||||
|
||||
private:
|
||||
friend class GpuBuffer;
|
||||
using DetachFn = std::function<void(GlTextureView&)>;
|
||||
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
|
||||
int height, GpuBuffer gpu_buffer, int plane, DetachFn detach)
|
||||
: gl_context_(context),
|
||||
target_(target),
|
||||
name_(name),
|
||||
width_(width),
|
||||
height_(height),
|
||||
gpu_buffer_(std::move(gpu_buffer)),
|
||||
plane_(plane),
|
||||
detach_(std::move(detach)) {}
|
||||
|
||||
// TODO: remove this friend declaration.
|
||||
friend class GlTexture;
|
||||
void Release();
|
||||
// TODO: make this non-const.
|
||||
void DoneWriting() const;
|
||||
|
||||
GlContext* gl_context_ = nullptr;
|
||||
GLenum target_ = GL_TEXTURE_2D;
|
||||
GLuint name_ = 0;
|
||||
// Note: when scale is not 1, we still give the nominal size of the image.
|
||||
int width_ = 0;
|
||||
int height_ = 0;
|
||||
GpuBuffer gpu_buffer_;
|
||||
int plane_ = 0;
|
||||
DetachFn detach_;
|
||||
};
|
||||
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
|
||||
return ¤t_storage() == &no_storage();
|
||||
}
|
||||
|
||||
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
||||
inline int GpuBuffer::width() const {
|
||||
return static_cast<int>(CVPixelBufferGetWidth(*pixel_buffer_));
|
||||
}
|
||||
|
||||
inline int GpuBuffer::height() const {
|
||||
return static_cast<int>(CVPixelBufferGetHeight(*pixel_buffer_));
|
||||
}
|
||||
|
||||
inline GpuBufferFormat GpuBuffer::format() const {
|
||||
return GpuBufferFormatForCVPixelFormat(
|
||||
CVPixelBufferGetPixelFormatType(*pixel_buffer_));
|
||||
}
|
||||
|
||||
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
|
||||
return pixel_buffer_ == other;
|
||||
}
|
||||
|
||||
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
|
||||
return pixel_buffer_ == other.pixel_buffer_;
|
||||
}
|
||||
|
||||
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
|
||||
pixel_buffer_.reset(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
inline int GpuBuffer::width() const { return texture_buffer_->width(); }
|
||||
|
||||
inline int GpuBuffer::height() const { return texture_buffer_->height(); }
|
||||
|
||||
inline GpuBufferFormat GpuBuffer::format() const {
|
||||
return texture_buffer_->format();
|
||||
}
|
||||
|
||||
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
|
||||
return texture_buffer_ == other;
|
||||
}
|
||||
|
||||
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
|
||||
return texture_buffer_ == other.texture_buffer_;
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
}
|
||||
|
||||
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
|
||||
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
pixel_buffer_.reset(other);
|
||||
#else
|
||||
texture_buffer_ = other;
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
return *this;
|
||||
}
|
||||
|
||||
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_GPU_GPU_BUFFER_H_
|
||||
|
|
41
mediapipe/gpu/gpu_buffer_storage.h
Normal file
41
mediapipe/gpu/gpu_buffer_storage.h
Normal file
|
@ -0,0 +1,41 @@
|
|||
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
|
||||
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
|
||||
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||
|
||||
namespace mediapipe {
|
||||
class GlTextureView;
|
||||
class GpuBuffer;
|
||||
} // namespace mediapipe
|
||||
|
||||
namespace mediapipe {
|
||||
namespace internal {
|
||||
|
||||
using mediapipe::GlTextureView;
|
||||
using mediapipe::GpuBuffer;
|
||||
using mediapipe::GpuBufferFormat;
|
||||
|
||||
class GlTextureViewManager {
|
||||
public:
|
||||
virtual ~GlTextureViewManager() = default;
|
||||
virtual GlTextureView GetGlTextureReadView(
|
||||
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const = 0;
|
||||
virtual GlTextureView GetGlTextureWriteView(
|
||||
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) = 0;
|
||||
virtual void ViewDoneWriting(const GlTextureView& view) = 0;
|
||||
};
|
||||
|
||||
class GpuBufferStorage : public GlTextureViewManager {
|
||||
public:
|
||||
virtual ~GpuBufferStorage() = default;
|
||||
virtual int width() const = 0;
|
||||
virtual int height() const = 0;
|
||||
virtual GpuBufferFormat format() const = 0;
|
||||
virtual std::unique_ptr<ImageFrame> AsImageFrame() const = 0;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
|
116
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
Normal file
116
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
Normal file
|
@ -0,0 +1,116 @@
|
|||
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
|
||||
|
||||
#include "mediapipe/gpu/gl_context.h"
|
||||
#include "mediapipe/objc/util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
#if TARGET_OS_OSX
|
||||
typedef CVOpenGLTextureRef CVTextureType;
|
||||
#else
|
||||
typedef CVOpenGLESTextureRef CVTextureType;
|
||||
#endif // TARGET_OS_OSX
|
||||
|
||||
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureReadView(
|
||||
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
|
||||
CVReturn err;
|
||||
auto gl_context = GlContext::GetCurrent();
|
||||
CHECK(gl_context);
|
||||
#if TARGET_OS_OSX
|
||||
CVTextureType cv_texture_temp;
|
||||
err = CVOpenGLTextureCacheCreateTextureFromImage(
|
||||
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
|
||||
&cv_texture_temp);
|
||||
CHECK(cv_texture_temp && !err)
|
||||
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
|
||||
CFHolder<CVTextureType> cv_texture;
|
||||
cv_texture.adopt(cv_texture_temp);
|
||||
return GlTextureView(
|
||||
gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
|
||||
CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane,
|
||||
[cv_texture](
|
||||
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
|
||||
#else
|
||||
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
|
||||
format(), plane, gl_context->GetGlVersion());
|
||||
CVTextureType cv_texture_temp;
|
||||
err = CVOpenGLESTextureCacheCreateTextureFromImage(
|
||||
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
|
||||
GL_TEXTURE_2D, info.gl_internal_format, width() / info.downscale,
|
||||
height() / info.downscale, info.gl_format, info.gl_type, plane,
|
||||
&cv_texture_temp);
|
||||
CHECK(cv_texture_temp && !err)
|
||||
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
|
||||
CFHolder<CVTextureType> cv_texture;
|
||||
cv_texture.adopt(cv_texture_temp);
|
||||
return GlTextureView(
|
||||
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
|
||||
CVOpenGLESTextureGetName(*cv_texture), width(), height(),
|
||||
std::move(gpu_buffer), plane,
|
||||
[cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ },
|
||||
// TODO: make GetGlTextureView for write view non-const, remove cast
|
||||
// Note: we have to copy *this here because this storage is currently
|
||||
// stored in GpuBuffer by value, and so the this pointer becomes invalid
|
||||
// if the GpuBuffer is moved/copied. TODO: fix this.
|
||||
[me = *this](const mediapipe::GlTextureView& view) {
|
||||
const_cast<GpuBufferStorageCvPixelBuffer*>(&me)->ViewDoneWriting(view);
|
||||
});
|
||||
#endif // TARGET_OS_OSX
|
||||
}
|
||||
|
||||
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureWriteView(
|
||||
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
|
||||
// For this storage there is currently no difference between read and write
|
||||
// views, so we delegate to the read method.
|
||||
return GetGlTextureReadView(std::move(gpu_buffer), plane);
|
||||
}
|
||||
|
||||
void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) {
|
||||
#if TARGET_IPHONE_SIMULATOR
|
||||
CVPixelBufferRef pixel_buffer = **this;
|
||||
CHECK(pixel_buffer);
|
||||
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
|
||||
CHECK(err == kCVReturnSuccess)
|
||||
<< "CVPixelBufferLockBaseAddress failed: " << err;
|
||||
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
|
||||
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
|
||||
uint8_t* pixel_ptr =
|
||||
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
|
||||
if (pixel_format == kCVPixelFormatType_32BGRA) {
|
||||
// TODO: restore previous framebuffer? Move this to helper so we
|
||||
// can use BindFramebuffer?
|
||||
glViewport(0, 0, view.width(), view.height());
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
|
||||
view.name(), 0);
|
||||
|
||||
size_t contiguous_bytes_per_row = view.width() * 4;
|
||||
if (bytes_per_row == contiguous_bytes_per_row) {
|
||||
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||
pixel_ptr);
|
||||
} else {
|
||||
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
|
||||
view.height());
|
||||
uint8_t* temp_ptr = contiguous_buffer.data();
|
||||
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||
temp_ptr);
|
||||
for (int i = 0; i < view.height(); ++i) {
|
||||
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
|
||||
temp_ptr += contiguous_bytes_per_row;
|
||||
pixel_ptr += bytes_per_row;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
|
||||
}
|
||||
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
|
||||
CHECK(err == kCVReturnSuccess)
|
||||
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::AsImageFrame()
|
||||
const {
|
||||
return CreateImageFrameForCVPixelBuffer(**this);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
41
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
Normal file
41
mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
Normal file
|
@ -0,0 +1,41 @@
|
|||
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
|
||||
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
|
||||
|
||||
#include <CoreVideo/CoreVideo.h>
|
||||
|
||||
#include "mediapipe/gpu/gl_texture_view.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_storage.h"
|
||||
#include "mediapipe/objc/CFHolder.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class GlContext;
|
||||
|
||||
class GpuBufferStorageCvPixelBuffer
|
||||
: public mediapipe::internal::GpuBufferStorage,
|
||||
public CFHolder<CVPixelBufferRef> {
|
||||
public:
|
||||
using CFHolder<CVPixelBufferRef>::CFHolder;
|
||||
GpuBufferStorageCvPixelBuffer(const CFHolder<CVPixelBufferRef>& other)
|
||||
: CFHolder(other) {}
|
||||
GpuBufferStorageCvPixelBuffer(CFHolder<CVPixelBufferRef>&& other)
|
||||
: CFHolder(std::move(other)) {}
|
||||
int width() const { return static_cast<int>(CVPixelBufferGetWidth(**this)); }
|
||||
int height() const {
|
||||
return static_cast<int>(CVPixelBufferGetHeight(**this));
|
||||
}
|
||||
virtual GpuBufferFormat format() const {
|
||||
return GpuBufferFormatForCVPixelFormat(
|
||||
CVPixelBufferGetPixelFormatType(**this));
|
||||
}
|
||||
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||
int plane) const override;
|
||||
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
|
||||
int plane) override;
|
||||
std::unique_ptr<ImageFrame> AsImageFrame() const override;
|
||||
void ViewDoneWriting(const GlTextureView& view) override;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
|
|
@ -64,6 +64,13 @@ public class AppTextureFrame implements TextureFrame {
|
|||
return timestamp;
|
||||
}
|
||||
|
||||
/** Returns true if a call to waitUntilReleased() would block waiting for release. */
|
||||
public boolean isNotYetReleased() {
|
||||
synchronized (this) {
|
||||
return inUse && releaseSyncToken == null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Waits until the consumer is done with the texture.
|
||||
*
|
||||
|
|
|
@ -26,7 +26,8 @@ android_library(
|
|||
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_gpu_image.binarypb",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_image.binarypb",
|
||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||
"//mediapipe/modules/palm_detection:palm_detection.tflite",
|
||||
],
|
||||
assets_dir = "",
|
||||
|
|
|
@ -78,6 +78,7 @@ public class Hands extends ImageSolutionBase {
|
|||
Connection.create(HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP),
|
||||
Connection.create(HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP));
|
||||
|
||||
private static final String MODEL_COMPLEXITY = "model_complexity";
|
||||
private static final String NUM_HANDS = "num_hands";
|
||||
private static final String USE_PREV_LANDMARKS = "use_prev_landmarks";
|
||||
private static final String GPU_GRAPH_NAME = "hand_landmark_tracking_gpu_image.binarypb";
|
||||
|
@ -131,6 +132,7 @@ public class Hands extends ImageSolutionBase {
|
|||
initialize(context, solutionInfo, outputHandler);
|
||||
Map<String, Packet> inputSidePackets = new HashMap<>();
|
||||
inputSidePackets.put(NUM_HANDS, packetCreator.createInt32(options.maxNumHands()));
|
||||
inputSidePackets.put(MODEL_COMPLEXITY, packetCreator.createInt32(options.modelComplexity()));
|
||||
inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode()));
|
||||
start(inputSidePackets);
|
||||
}
|
||||
|
|
|
@ -26,6 +26,10 @@ import com.google.auto.value.AutoValue;
|
|||
* <p>maxNumHands: Maximum number of hands to detect. See details in
|
||||
* https://solutions.mediapipe.dev/hands#max_num_hands.
|
||||
*
|
||||
* <p>modelComplexity: Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
* inference latency generally go up with the model complexity. See details in
|
||||
* https://solutions.mediapipe.dev/hands#model_complexity.
|
||||
*
|
||||
* <p>minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for hand detection to be
|
||||
* considered successful. See details in
|
||||
* https://solutions.mediapipe.dev/hands#min_detection_confidence.
|
||||
|
@ -43,6 +47,8 @@ public abstract class HandsOptions {
|
|||
|
||||
public abstract int maxNumHands();
|
||||
|
||||
public abstract int modelComplexity();
|
||||
|
||||
public abstract float minDetectionConfidence();
|
||||
|
||||
public abstract float minTrackingConfidence();
|
||||
|
@ -59,6 +65,7 @@ public abstract class HandsOptions {
|
|||
public Builder withDefaultValues() {
|
||||
return setStaticImageMode(false)
|
||||
.setMaxNumHands(2)
|
||||
.setModelComplexity(1)
|
||||
.setMinDetectionConfidence(0.5f)
|
||||
.setMinTrackingConfidence(0.5f)
|
||||
.setRunOnGpu(true);
|
||||
|
@ -68,6 +75,8 @@ public abstract class HandsOptions {
|
|||
|
||||
public abstract Builder setMaxNumHands(int value);
|
||||
|
||||
public abstract Builder setModelComplexity(int value);
|
||||
|
||||
public abstract Builder setMinDetectionConfidence(float value);
|
||||
|
||||
public abstract Builder setMinTrackingConfidence(float value);
|
||||
|
|
|
@ -22,16 +22,30 @@ licenses(["notice"])
|
|||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
exports_files([
|
||||
"hand_landmark.tflite",
|
||||
"hand_landmark_full.tflite",
|
||||
"hand_landmark_lite.tflite",
|
||||
"hand_landmark_sparse.tflite",
|
||||
"handedness.txt",
|
||||
])
|
||||
|
||||
mediapipe_simple_subgraph(
|
||||
name = "hand_landmark_model_loader",
|
||||
graph = "hand_landmark_model_loader.pbtxt",
|
||||
register_as = "HandLandmarkModelLoader",
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||
"//mediapipe/calculators/tflite:tflite_model_calculator",
|
||||
"//mediapipe/calculators/util:local_file_contents_calculator",
|
||||
"//mediapipe/framework/tool:switch_container",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_simple_subgraph(
|
||||
name = "hand_landmark_cpu",
|
||||
graph = "hand_landmark_cpu.pbtxt",
|
||||
register_as = "HandLandmarkCpu",
|
||||
deps = [
|
||||
":hand_landmark_model_loader",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||
|
@ -50,6 +64,7 @@ mediapipe_simple_subgraph(
|
|||
graph = "hand_landmark_gpu.pbtxt",
|
||||
register_as = "HandLandmarkGpu",
|
||||
deps = [
|
||||
":hand_landmark_model_loader",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||
|
|
Binary file not shown.
|
@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
|
|||
# (NormalizedRect)
|
||||
input_stream: "ROI:hand_rect"
|
||||
|
||||
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
# inference latency generally go up with the model complexity. If unspecified,
|
||||
# functions as set to 1. (int)
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
|
||||
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
|
||||
# NOTE: if a hand is not present within the given ROI, for this particular
|
||||
# timestamp there will not be an output packet in the LANDMARKS stream. However,
|
||||
|
@ -40,16 +45,23 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
# Loads the hand landmark TF Lite model.
|
||||
node {
|
||||
calculator: "HandLandmarkModelLoader"
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
output_side_packet: "MODEL:model"
|
||||
}
|
||||
|
||||
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
|
||||
# vector of tensors representing, for instance, detection boxes/keypoints and
|
||||
# scores.
|
||||
node {
|
||||
calculator: "InferenceCalculator"
|
||||
input_side_packet: "MODEL:model"
|
||||
input_stream: "TENSORS:input_tensor"
|
||||
output_stream: "TENSORS:output_tensors"
|
||||
options: {
|
||||
[mediapipe.InferenceCalculatorOptions.ext] {
|
||||
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
||||
delegate {
|
||||
xnnpack {}
|
||||
}
|
||||
|
|
BIN
mediapipe/modules/hand_landmark/hand_landmark_full.tflite
Executable file
BIN
mediapipe/modules/hand_landmark/hand_landmark_full.tflite
Executable file
Binary file not shown.
|
@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
|
|||
# (NormalizedRect)
|
||||
input_stream: "ROI:hand_rect"
|
||||
|
||||
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
# inference latency generally go up with the model complexity. If unspecified,
|
||||
# functions as set to 1. (int)
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
|
||||
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
|
||||
# NOTE: if a hand is not present within the given ROI, for this particular
|
||||
# timestamp there will not be an output packet in the LANDMARKS stream. However,
|
||||
|
@ -41,18 +46,21 @@ node {
|
|||
}
|
||||
}
|
||||
|
||||
# Loads the hand landmark TF Lite model.
|
||||
node {
|
||||
calculator: "HandLandmarkModelLoader"
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
output_side_packet: "MODEL:model"
|
||||
}
|
||||
|
||||
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
|
||||
# vector of tensors representing, for instance, detection boxes/keypoints and
|
||||
# scores.
|
||||
node {
|
||||
calculator: "InferenceCalculator"
|
||||
input_side_packet: "MODEL:model"
|
||||
input_stream: "TENSORS:input_tensor"
|
||||
output_stream: "TENSORS:output_tensors"
|
||||
options: {
|
||||
[mediapipe.InferenceCalculatorOptions.ext] {
|
||||
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Splits a vector of tensors to multiple vectors according to the ranges
|
||||
|
|
BIN
mediapipe/modules/hand_landmark/hand_landmark_lite.tflite
Executable file
BIN
mediapipe/modules/hand_landmark/hand_landmark_lite.tflite
Executable file
Binary file not shown.
|
@ -0,0 +1,63 @@
|
|||
# MediaPipe graph to load a selected hand landmark TF Lite model.
|
||||
|
||||
type: "HandLandmarkModelLoader"
|
||||
|
||||
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
# inference latency generally go up with the model complexity. If unspecified,
|
||||
# functions as set to 1. (int)
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
|
||||
# TF Lite model represented as a FlatBuffer.
|
||||
# (std::unique_ptr<tflite::FlatBufferModel, std::function<void(tflite::FlatBufferModel*)>>)
|
||||
output_side_packet: "MODEL:model"
|
||||
|
||||
# Determines path to the desired pose landmark model file.
|
||||
node {
|
||||
calculator: "SwitchContainer"
|
||||
input_side_packet: "SELECT:model_complexity"
|
||||
output_side_packet: "PACKET:model_path"
|
||||
options: {
|
||||
[mediapipe.SwitchContainerOptions.ext] {
|
||||
select: 1
|
||||
contained_node: {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet {
|
||||
string_value: "mediapipe/modules/hand_landmark/hand_landmark_lite.tflite"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
contained_node: {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet {
|
||||
string_value: "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Loads the file in the specified path into a blob.
|
||||
node {
|
||||
calculator: "LocalFileContentsCalculator"
|
||||
input_side_packet: "FILE_PATH:model_path"
|
||||
output_side_packet: "CONTENTS:model_blob"
|
||||
options: {
|
||||
[mediapipe.LocalFileContentsCalculatorOptions.ext]: {
|
||||
text_mode: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Converts the input blob into a TF Lite model.
|
||||
node {
|
||||
calculator: "TfLiteModelCalculator"
|
||||
input_side_packet: "MODEL_BLOB:model_blob"
|
||||
output_side_packet: "MODEL:model"
|
||||
}
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
|||
# Max number of hands to detect/track. (int)
|
||||
input_side_packet: "NUM_HANDS:num_hands"
|
||||
|
||||
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
# inference latency generally go up with the model complexity. If unspecified,
|
||||
# functions as set to 1. (int)
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
|
||||
# Whether landmarks on the previous image should be used to help localize
|
||||
# landmarks on the current image. (bool)
|
||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||
|
@ -177,6 +182,7 @@ node {
|
|||
# Detect hand landmarks for the specific hand rect.
|
||||
node {
|
||||
calculator: "HandLandmarkCpu"
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
input_stream: "IMAGE:image_for_landmarks"
|
||||
input_stream: "ROI:single_hand_rect"
|
||||
output_stream: "LANDMARKS:single_hand_landmarks"
|
||||
|
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
|||
# Max number of hands to detect/track. (int)
|
||||
input_side_packet: "NUM_HANDS:num_hands"
|
||||
|
||||
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
# inference latency generally go up with the model complexity. If unspecified,
|
||||
# functions as set to 1. (int)
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
|
||||
# Whether landmarks on the previous image should be used to help localize
|
||||
# landmarks on the current image. (bool)
|
||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||
|
@ -85,6 +90,7 @@ node {
|
|||
calculator: "HandLandmarkTrackingCpu"
|
||||
input_stream: "IMAGE:image_frame"
|
||||
input_side_packet: "NUM_HANDS:num_hands"
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||
output_stream: "LANDMARKS:multi_hand_landmarks"
|
||||
output_stream: "HANDEDNESS:multi_handedness"
|
||||
|
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
|||
# Max number of hands to detect/track. (int)
|
||||
input_side_packet: "NUM_HANDS:num_hands"
|
||||
|
||||
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
# inference latency generally go up with the model complexity. If unspecified,
|
||||
# functions as set to 1. (int)
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
|
||||
# Whether landmarks on the previous image should be used to help localize
|
||||
# landmarks on the current image. (bool)
|
||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||
|
@ -178,6 +183,7 @@ node {
|
|||
# Detect hand landmarks for the specific hand rect.
|
||||
node {
|
||||
calculator: "HandLandmarkGpu"
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
input_stream: "IMAGE:image_for_landmarks"
|
||||
input_stream: "ROI:single_hand_rect"
|
||||
output_stream: "LANDMARKS:single_hand_landmarks"
|
||||
|
|
|
@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
|
|||
# Max number of hands to detect/track. (int)
|
||||
input_side_packet: "NUM_HANDS:num_hands"
|
||||
|
||||
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
|
||||
# inference latency generally go up with the model complexity. If unspecified,
|
||||
# functions as set to 1. (int)
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
|
||||
# Whether landmarks on the previous image should be used to help localize
|
||||
# landmarks on the current image. (bool)
|
||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||
|
@ -85,6 +90,7 @@ node {
|
|||
calculator: "HandLandmarkTrackingGpu"
|
||||
input_stream: "IMAGE:gpu_buffer"
|
||||
input_side_packet: "NUM_HANDS:num_hands"
|
||||
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
|
||||
output_stream: "LANDMARKS:multi_hand_landmarks"
|
||||
output_stream: "HANDEDNESS:multi_handedness"
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
# - "face_landmark.tflite" is available at
|
||||
# "mediapipe/modules/face_landmark/face_landmark.tflite"
|
||||
#
|
||||
# - "hand_landmark.tflite" is available at
|
||||
# "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
||||
# - "hand_landmark_full.tflite" is available at
|
||||
# "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
|
||||
#
|
||||
# - "hand_recrop.tflite" is available at
|
||||
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
# - "face_landmark.tflite" is available at
|
||||
# "mediapipe/modules/face_landmark/face_landmark.tflite"
|
||||
#
|
||||
# - "hand_landmark.tflite" is available at
|
||||
# "mediapipe/modules/hand_landmark/hand_landmark.tflite"
|
||||
# - "hand_landmark_full.tflite" is available at
|
||||
# "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
|
||||
#
|
||||
# - "hand_recrop.tflite" is available at
|
||||
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"
|
||||
|
|
|
@ -237,7 +237,7 @@ absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
bool FilterDetectionCalculator::IsValidLabel(const std::string& label) {
|
||||
bool match = !limit_labels_ || ContainsKey(allowed_labels_, label);
|
||||
bool match = !limit_labels_ || allowed_labels_.contains(label);
|
||||
if (!match) {
|
||||
// If no exact match is found, check for regular expression
|
||||
// comparions in the allowed_labels.
|
||||
|
|
|
@ -21,6 +21,9 @@ cc_library(
|
|||
features = ["-parse_headers"],
|
||||
linkopts = [
|
||||
"-framework Accelerate",
|
||||
"-framework CoreFoundation",
|
||||
"-framework CoreGraphics",
|
||||
"-framework CoreVideo",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
deps = [
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#import <Accelerate/Accelerate.h>
|
||||
#import <CoreFoundation/CoreFoundation.h>
|
||||
#import <CoreGraphics/CoreGraphics.h>
|
||||
#import <CoreVideo/CoreVideo.h>
|
||||
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
|
|
|
@ -89,6 +89,7 @@ class Hands(SolutionBase):
|
|||
def __init__(self,
|
||||
static_image_mode=False,
|
||||
max_num_hands=2,
|
||||
model_complexity=1,
|
||||
min_detection_confidence=0.5,
|
||||
min_tracking_confidence=0.5):
|
||||
"""Initializes a MediaPipe Hand object.
|
||||
|
@ -99,6 +100,10 @@ class Hands(SolutionBase):
|
|||
https://solutions.mediapipe.dev/hands#static_image_mode.
|
||||
max_num_hands: Maximum number of hands to detect. See details in
|
||||
https://solutions.mediapipe.dev/hands#max_num_hands.
|
||||
model_complexity: Complexity of the hand landmark model: 0 or 1.
|
||||
Landmark accuracy as well as inference latency generally go up with the
|
||||
model complexity. See details in
|
||||
https://solutions.mediapipe.dev/hands#model_complexity.
|
||||
min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for hand
|
||||
detection to be considered successful. See details in
|
||||
https://solutions.mediapipe.dev/hands#min_detection_confidence.
|
||||
|
@ -109,6 +114,7 @@ class Hands(SolutionBase):
|
|||
super().__init__(
|
||||
binary_graph_path=_BINARYPB_FILE_PATH,
|
||||
side_inputs={
|
||||
'model_complexity': model_complexity,
|
||||
'num_hands': max_num_hands,
|
||||
'use_prev_landmarks': not static_image_mode,
|
||||
},
|
||||
|
|
|
@ -32,7 +32,8 @@ from mediapipe.python.solutions import hands as mp_hands
|
|||
|
||||
|
||||
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
|
||||
DIFF_THRESHOLD = 20 # pixels
|
||||
LITE_MODEL_DIFF_THRESHOLD = 25 # pixels
|
||||
FULL_MODEL_DIFF_THRESHOLD = 20 # pixels
|
||||
EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
|
||||
[289, 237], [322, 203], [219, 216],
|
||||
[238, 138], [249, 90], [253, 51],
|
||||
|
@ -40,7 +41,7 @@ EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
|
|||
[185, 19], [138, 208], [131, 127],
|
||||
[124, 77], [117, 36], [106, 222],
|
||||
[92, 159], [79, 124], [68, 93]],
|
||||
[[580, 36], [504, 50], [459, 94],
|
||||
[[580, 34], [504, 50], [459, 94],
|
||||
[429, 146], [397, 182], [507, 167],
|
||||
[479, 245], [469, 292], [464, 330],
|
||||
[545, 180], [534, 265], [533, 319],
|
||||
|
@ -75,14 +76,18 @@ class HandsTest(parameterized.TestCase):
|
|||
self.assertIsNone(results.multi_hand_landmarks)
|
||||
self.assertIsNone(results.multi_handedness)
|
||||
|
||||
@parameterized.named_parameters(('static_image_mode', True, 1),
|
||||
('video_mode', False, 5))
|
||||
def test_multi_hands(self, static_image_mode, num_frames):
|
||||
@parameterized.named_parameters(
|
||||
('static_image_mode_with_lite_model', True, 0, 5),
|
||||
('video_mode_with_lite_model', False, 0, 10),
|
||||
('static_image_mode_with_full_model', True, 1, 5),
|
||||
('video_mode_with_full_model', False, 1, 10))
|
||||
def test_multi_hands(self, static_image_mode, model_complexity, num_frames):
|
||||
image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg')
|
||||
image = cv2.imread(image_path)
|
||||
with mp_hands.Hands(
|
||||
static_image_mode=static_image_mode,
|
||||
max_num_hands=2,
|
||||
model_complexity=model_complexity,
|
||||
min_detection_confidence=0.5) as hands:
|
||||
for idx in range(num_frames):
|
||||
results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
@ -104,7 +109,8 @@ class HandsTest(parameterized.TestCase):
|
|||
prediction_error = np.abs(
|
||||
np.asarray(multi_hand_coordinates) -
|
||||
np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION))
|
||||
npt.assert_array_less(prediction_error, DIFF_THRESHOLD)
|
||||
diff_threshold = LITE_MODEL_DIFF_THRESHOLD if model_complexity == 0 else FULL_MODEL_DIFF_THRESHOLD
|
||||
npt.assert_array_less(prediction_error, diff_threshold)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -81,6 +82,32 @@ ObjectDef GetSSBOObjectDef(int channels) {
|
|||
return gpu_object_def;
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
|
||||
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
|
||||
cl::InferenceOptions result{};
|
||||
result.priority1 = options.priority1;
|
||||
result.priority2 = options.priority2;
|
||||
result.priority3 = options.priority3;
|
||||
result.usage = options.usage;
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status VerifyShapes(const std::vector<TensorObjectDef>& actual,
|
||||
const std::vector<BHWC>& expected) {
|
||||
RET_CHECK_EQ(actual.size(), expected.size());
|
||||
const int size = actual.size();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const auto& dims = actual[i].dimensions;
|
||||
const BHWC& bhwc = expected[i];
|
||||
RET_CHECK(dims.b == bhwc.b && dims.h == bhwc.h && dims.w == bhwc.w &&
|
||||
dims.c == bhwc.c);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
#endif // __ANDROID__
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::Status TFLiteGPURunner::InitializeWithModel(
|
||||
|
@ -139,16 +166,16 @@ absl::Status TFLiteGPURunner::Build() {
|
|||
// try to build OpenCL first. If something goes wrong, fall back to OpenGL.
|
||||
absl::Status status = InitializeOpenCL(&builder);
|
||||
if (status.ok()) {
|
||||
LOG(INFO) << "OpenCL backend is used.";
|
||||
VLOG(2) << "OpenCL backend is used.";
|
||||
} else {
|
||||
LOG(ERROR) << "Falling back to OpenGL: " << status.message();
|
||||
VLOG(2) << "Falling back to OpenGL: " << status.message();
|
||||
MP_RETURN_IF_ERROR(InitializeOpenGL(&builder));
|
||||
}
|
||||
}
|
||||
|
||||
// Both graphs are not needed anymore. Make sure they are deleted.
|
||||
// GL graph not needed anymore, CL graph maybe needed for serialized model
|
||||
// calculation.
|
||||
graph_gl_.reset(nullptr);
|
||||
graph_cl_.reset(nullptr);
|
||||
|
||||
// 2. Describe output/input objects for created builder.
|
||||
for (int flow_index = 0; flow_index < input_shapes_.size(); ++flow_index) {
|
||||
|
@ -204,18 +231,57 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
|
|||
env_options.serialized_binary_cache = serialized_binary_cache_;
|
||||
}
|
||||
cl::InferenceEnvironmentProperties properties;
|
||||
cl::InferenceOptions cl_options;
|
||||
cl_options.priority1 = options_.priority1;
|
||||
cl_options.priority2 = options_.priority2;
|
||||
cl_options.priority3 = options_.priority3;
|
||||
cl_options.usage = options_.usage;
|
||||
MP_RETURN_IF_ERROR(
|
||||
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
|
||||
|
||||
// Try to initialize from serialized model first.
|
||||
if (!serialized_model_.empty()) {
|
||||
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
|
||||
if (init_status.ok()) {
|
||||
serialized_model_used_ = true;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
VLOG(2) << "Failed to init from serialized model: [" << init_status
|
||||
<< "]. Trying to init from scratch.";
|
||||
}
|
||||
|
||||
// Initialize from scratch.
|
||||
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
|
||||
GraphFloat32 graph_cl;
|
||||
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
|
||||
MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
|
||||
cl_options, std::move(*graph_cl_), builder));
|
||||
#endif
|
||||
cl_options, std::move(graph_cl), builder));
|
||||
|
||||
#endif // __ANDROID__
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
|
||||
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
|
||||
std::unique_ptr<InferenceBuilder>* builder) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
cl_environment_->NewInferenceBuilder(serialized_model_, builder));
|
||||
MP_RETURN_IF_ERROR(VerifyShapes(builder->get()->inputs(), input_shapes_));
|
||||
return VerifyShapes(builder->get()->outputs(), output_shapes_);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
|
||||
RET_CHECK(runner_) << "Runner is in invalid state.";
|
||||
if (serialized_model_used_) {
|
||||
return serialized_model_;
|
||||
}
|
||||
RET_CHECK(graph_cl_) << "CL graph is not initialized.";
|
||||
GraphFloat32 graph_cl;
|
||||
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
|
||||
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
|
||||
std::vector<uint8_t> serialized_model;
|
||||
MP_RETURN_IF_ERROR(cl_environment_->BuildSerializedModel(
|
||||
cl_options, std::move(graph_cl), &serialized_model));
|
||||
return serialized_model;
|
||||
}
|
||||
|
||||
#endif // __ANDROID__
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -29,7 +30,7 @@
|
|||
|
||||
#ifdef __ANDROID__
|
||||
#include "tensorflow/lite/delegates/gpu/cl/api.h"
|
||||
#endif
|
||||
#endif // __ANDROID__
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
@ -90,11 +91,22 @@ class TFLiteGPURunner {
|
|||
std::vector<uint8_t> GetSerializedBinaryCache() {
|
||||
return cl_environment_->GetSerializedBinaryCache();
|
||||
}
|
||||
#endif
|
||||
|
||||
void SetSerializedModel(std::vector<uint8_t>&& serialized_model) {
|
||||
serialized_model_ = std::move(serialized_model);
|
||||
serialized_model_used_ = false;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
|
||||
#endif // __ANDROID__
|
||||
|
||||
private:
|
||||
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
|
||||
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
|
||||
#ifdef __ANDROID__
|
||||
absl::Status InitializeOpenCLFromSerializedModel(
|
||||
std::unique_ptr<InferenceBuilder>* builder);
|
||||
#endif // __ANDROID__
|
||||
|
||||
InferenceOptions options_;
|
||||
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
|
||||
|
@ -103,9 +115,12 @@ class TFLiteGPURunner {
|
|||
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
|
||||
|
||||
std::vector<uint8_t> serialized_binary_cache_;
|
||||
#endif
|
||||
std::vector<uint8_t> serialized_model_;
|
||||
bool serialized_model_used_ = false;
|
||||
#endif // __ANDROID__
|
||||
|
||||
// graph_ is maintained temporarily and becomes invalid after runner_ is ready
|
||||
// graph_gl_ is maintained temporarily and becomes invalid after runner_ is
|
||||
// ready
|
||||
std::unique_ptr<GraphFloat32> graph_gl_;
|
||||
std::unique_ptr<GraphFloat32> graph_cl_;
|
||||
std::unique_ptr<InferenceRunner> runner_;
|
||||
|
|
Loading…
Reference in New Issue
Block a user