Merge branch 'google:master' into c-image-embedder-api

Kinar R 2023-11-07 02:00:23 +05:30 committed by GitHub
commit 3b122a1e61
109 changed files with 3581 additions and 447 deletions

View File

@@ -513,6 +513,9 @@ http_archive(
 "@//third_party:org_tensorflow_system_python.diff",
 # Diff is generated with a script, don't update it manually.
 "@//third_party:org_tensorflow_custom_ops.diff",
+# Works around Bazel issue with objc_library.
+# See https://github.com/bazelbuild/bazel/issues/19912
+"@//third_party:org_tensorflow_objc_build_fixes.diff",
 ],
 patch_args = [
 "-p1",

View File

@@ -0,0 +1,342 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 56;
objects = {
/* Begin PBXBuildFile section */
8566B55D2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h in Headers */ = {isa = PBXBuildFile; fileRef = 8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */; settings = {ATTRIBUTES = (Public, ); }; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = MediaPipeTasksDocGen.framework; sourceTree = BUILT_PRODUCTS_DIR; };
8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = MediaPipeTasksDocGen.h; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
8566B5562ABABF9A00AAB22A /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
8566B54F2ABABF9A00AAB22A = {
isa = PBXGroup;
children = (
8566B55B2ABABF9A00AAB22A /* MediaPipeTasksDocGen */,
8566B55A2ABABF9A00AAB22A /* Products */,
);
sourceTree = "<group>";
};
8566B55A2ABABF9A00AAB22A /* Products */ = {
isa = PBXGroup;
children = (
8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */,
);
name = Products;
sourceTree = "<group>";
};
8566B55B2ABABF9A00AAB22A /* MediaPipeTasksDocGen */ = {
isa = PBXGroup;
children = (
8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */,
);
path = MediaPipeTasksDocGen;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
8566B5542ABABF9A00AAB22A /* Headers */ = {
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
8566B55D2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
8566B5582ABABF9A00AAB22A /* MediaPipeTasksDocGen */ = {
isa = PBXNativeTarget;
buildConfigurationList = 8566B5602ABABF9A00AAB22A /* Build configuration list for PBXNativeTarget "MediaPipeTasksDocGen" */;
buildPhases = (
8566B5542ABABF9A00AAB22A /* Headers */,
8566B5552ABABF9A00AAB22A /* Sources */,
8566B5562ABABF9A00AAB22A /* Frameworks */,
8566B5572ABABF9A00AAB22A /* Resources */,
);
buildRules = (
);
dependencies = (
);
name = MediaPipeTasksDocGen;
productName = MediaPipeTasksDocGen;
productReference = 8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */;
productType = "com.apple.product-type.framework";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
8566B5502ABABF9A00AAB22A /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastUpgradeCheck = 1430;
TargetAttributes = {
8566B5582ABABF9A00AAB22A = {
CreatedOnToolsVersion = 14.3.1;
};
};
};
buildConfigurationList = 8566B5532ABABF9A00AAB22A /* Build configuration list for PBXProject "MediaPipeTasksDocGen" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 8566B54F2ABABF9A00AAB22A;
productRefGroup = 8566B55A2ABABF9A00AAB22A /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
8566B5582ABABF9A00AAB22A /* MediaPipeTasksDocGen */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
8566B5572ABABF9A00AAB22A /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
8566B5552ABABF9A00AAB22A /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
8566B55E2ABABF9A00AAB22A /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
GCC_C_LANGUAGE_STANDARD = gnu11;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 16.4;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Debug;
};
8566B55F2ABABF9A00AAB22A /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
GCC_C_LANGUAGE_STANDARD = gnu11;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 16.4;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_OPTIMIZATION_LEVEL = "-O";
VALIDATE_PRODUCT = YES;
VERSIONING_SYSTEM = "apple-generic";
VERSION_INFO_PREFIX = "";
};
name = Release;
};
8566B5612ABABF9A00AAB22A /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu11 gnu++20";
PRODUCT_BUNDLE_IDENTIFIER = com.google.mediapipe.MediaPipeTasksDocGen;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
8566B5622ABABF9A00AAB22A /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEFINES_MODULE = YES;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_MODULE_VERIFIER = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSHumanReadableCopyright = "";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
MARKETING_VERSION = 1.0;
MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu11 gnu++20";
PRODUCT_BUNDLE_IDENTIFIER = com.google.mediapipe.MediaPipeTasksDocGen;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
8566B5532ABABF9A00AAB22A /* Build configuration list for PBXProject "MediaPipeTasksDocGen" */ = {
isa = XCConfigurationList;
buildConfigurations = (
8566B55E2ABABF9A00AAB22A /* Debug */,
8566B55F2ABABF9A00AAB22A /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
8566B5602ABABF9A00AAB22A /* Build configuration list for PBXNativeTarget "MediaPipeTasksDocGen" */ = {
isa = XCConfigurationList;
buildConfigurations = (
8566B5612ABABF9A00AAB22A /* Debug */,
8566B5622ABABF9A00AAB22A /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
};
rootObject = 8566B5502ABABF9A00AAB22A /* Project object */;
}

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>

View File

@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>SchemeUserState</key>
<dict>
<key>MediaPipeTasksDocGen.xcscheme_^#shared#^_</key>
<dict>
<key>orderHint</key>
<integer>0</integer>
</dict>
</dict>
</dict>
</plist>

View File

@@ -0,0 +1,17 @@
//
// MediaPipeTasksDocGen.h
// MediaPipeTasksDocGen
//
// Created by Mark McDonald on 20/9/2023.
//
#import <Foundation/Foundation.h>
//! Project version number for MediaPipeTasksDocGen.
FOUNDATION_EXPORT double MediaPipeTasksDocGenVersionNumber;
//! Project version string for MediaPipeTasksDocGen.
FOUNDATION_EXPORT const unsigned char MediaPipeTasksDocGenVersionString[];
// In this header, you should import all the public headers of your framework using statements like
// #import <MediaPipeTasksDocGen/PublicHeader.h>

View File

@@ -0,0 +1,11 @@
# Uncomment the next line to define a global platform for your project
platform :ios, '15.0'
target 'MediaPipeTasksDocGen' do
# Comment the next line if you don't want to use dynamic frameworks
use_frameworks!
# Pods for MediaPipeTasksDocGen
pod 'MediaPipeTasksText'
pod 'MediaPipeTasksVision'
end

View File

@@ -0,0 +1,9 @@
# MediaPipeTasksDocGen
This empty project is used to generate reference documentation for the
ObjectiveC and Swift libraries.
Docs are generated using [Jazzy](https://github.com/realm/jazzy) and published
to [the developer site](https://developers.google.com/mediapipe/solutions/).
To bump the API version used, edit [`Podfile`](./Podfile).

View File

@@ -727,6 +727,7 @@ cc_library(
 "//mediapipe/framework/port:logging",
 "//mediapipe/framework/port:ret_check",
 "//mediapipe/framework/port:status",
+"@com_google_absl//absl/status",
 ],
 alwayslink = 1,
 )
@@ -742,6 +743,7 @@ cc_test(
 "//mediapipe/framework/port:parse_text_proto",
 "//mediapipe/framework/port:status",
 "//mediapipe/framework/tool:options_util",
+"//mediapipe/util:packet_test_util",
 "@com_google_absl//absl/memory",
 "@com_google_absl//absl/strings",
 ],

View File

@@ -17,6 +17,7 @@
 #include <set>
 #include <string>
+#include "absl/status/status.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
@@ -32,6 +33,7 @@ namespace {
 constexpr char kTagAtPreStream[] = "AT_PRESTREAM";
 constexpr char kTagAtPostStream[] = "AT_POSTSTREAM";
 constexpr char kTagAtZero[] = "AT_ZERO";
+constexpr char kTagAtFirstTick[] = "AT_FIRST_TICK";
 constexpr char kTagAtTick[] = "AT_TICK";
 constexpr char kTagTick[] = "TICK";
 constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP";
@@ -43,6 +45,7 @@ static std::map<std::string, Timestamp>* kTimestampMap = []() {
 res->emplace(kTagAtPostStream, Timestamp::PostStream());
 res->emplace(kTagAtZero, Timestamp(0));
 res->emplace(kTagAtTick, Timestamp::Unset());
+res->emplace(kTagAtFirstTick, Timestamp::Unset());
 res->emplace(kTagAtTimestamp, Timestamp::Unset());
 return res;
 }();
@@ -59,8 +62,8 @@ std::string GetOutputTag(const CC& cc) {
 // timestamp, depending on the tag used to define output stream(s). (One tag can
 // be used only.)
 //
-// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP
-// and corresponding timestamps are Timestamp::PreStream(),
+// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_FIRST_TICK,
+// AT_TIMESTAMP and corresponding timestamps are Timestamp::PreStream(),
 // Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK
 // input, and timestamp received from a side input.
 //
@@ -96,6 +99,7 @@ class SidePacketToStreamCalculator : public CalculatorBase {
 private:
 bool is_tick_processing_ = false;
+bool close_on_first_tick_ = false;
 std::string output_tag_;
 };
 REGISTER_CALCULATOR(SidePacketToStreamCalculator);
@@ -103,13 +107,16 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator);
 absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
 const auto& tags = cc->Outputs().GetTags();
 RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1)
-<< "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
-"AT_TIMESTAMP tags is allowed and required to specify output "
-"stream(s).";
-RET_CHECK(
-(cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) ||
-(!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick)))
-<< "Either both of TICK and AT_TICK should be used or none of them.";
+<< "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, "
+"AT_FIRST_TICK and AT_TIMESTAMP tags is allowed and required to "
+"specify output stream(s).";
+const bool has_tick_output =
+cc->Outputs().HasTag(kTagAtTick) || cc->Outputs().HasTag(kTagAtFirstTick);
+const bool has_tick_input = cc->Inputs().HasTag(kTagTick);
+RET_CHECK((has_tick_output && has_tick_input) ||
+(!has_tick_output && !has_tick_input))
+<< "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
+"should be used or none of them.";
 RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) &&
 cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) ||
 (!cc->Outputs().HasTag(kTagAtTimestamp) &&
@@ -148,11 +155,17 @@ absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) {
 // timestamp bound update.
 cc->SetOffset(TimestampDiff(0));
 }
+if (output_tag_ == kTagAtFirstTick) {
+close_on_first_tick_ = true;
+}
 return absl::OkStatus();
 }
 absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
 if (is_tick_processing_) {
+if (cc->Outputs().Get(output_tag_, 0).IsClosed()) {
+return absl::OkStatus();
+}
 // TICK input is guaranteed to be non-empty, as it's the only input stream
 // for this calculator.
 const auto& timestamp = cc->Inputs().Tag(kTagTick).Value().Timestamp();
@@ -160,6 +173,9 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
 cc->Outputs()
 .Get(output_tag_, i)
 .AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
+if (close_on_first_tick_) {
+cc->Outputs().Get(output_tag_, i).Close();
+}
 }
 return absl::OkStatus();
@@ -170,6 +186,7 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
 absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
 if (!cc->Outputs().HasTag(kTagAtTick) &&
+!cc->Outputs().HasTag(kTagAtFirstTick) &&
 !cc->Outputs().HasTag(kTagAtTimestamp)) {
 const auto& timestamp = kTimestampMap->at(output_tag_);
 for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
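Note on the new tag: with AT_TICK the side packet is re-emitted at every TICK timestamp, while AT_FIRST_TICK emits it only at the timestamp of the first tick and then closes the output stream. A minimal graph sketch, mirroring the test configs added in this commit (the helper name is illustrative only, not an official example):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical helper returning a graph that emits "side_packet" once, at the
// timestamp of the first packet on "tick".
mediapipe::CalculatorGraphConfig MakeFirstTickGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "tick"
    input_side_packet: "side_packet"
    output_stream: "packet"
    node {
      calculator: "SidePacketToStreamCalculator"
      input_stream: "TICK:tick"
      input_side_packet: "side_packet"
      output_stream: "AT_FIRST_TICK:packet"
    }
  )pb");
}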

View File

@@ -27,13 +27,17 @@
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/tool/options_util.h"
+#include "mediapipe/util/packet_test_util.h"
 namespace mediapipe {
 namespace {
-using testing::HasSubstr;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
-TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTick) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(
@@ -52,10 +56,35 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
 EXPECT_THAT(
 status.message(),
 HasSubstr(
-"Either both of TICK and AT_TICK should be used or none of them."));
+"Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
+"should be used or none of them."));
 }
-TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
+TEST(SidePacketToStreamCalculator,
+WrongConfigWithMissingTickForFirstTickProcessing) {
+CalculatorGraphConfig graph_config =
+ParseTextProtoOrDie<CalculatorGraphConfig>(
+R"pb(
+input_stream: "tick"
+input_side_packet: "side_packet"
+output_stream: "packet"
+node {
+calculator: "SidePacketToStreamCalculator"
+input_side_packet: "side_packet"
+output_stream: "AT_FIRST_TICK:packet"
+}
+)pb");
+CalculatorGraph graph;
+auto status = graph.Initialize(graph_config);
+EXPECT_FALSE(status.ok());
+EXPECT_THAT(
+status.message(),
+HasSubstr(
+"Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
+"should be used or none of them."));
+}
+TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTimestampSideInput) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(
@@ -76,7 +105,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
 "or none of them."));
 }
-TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithNonExistentTag) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(
@@ -92,14 +121,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
 CalculatorGraph graph;
 auto status = graph.Initialize(graph_config);
 EXPECT_FALSE(status.ok());
-EXPECT_THAT(
-status.message(),
-HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
-"AT_TIMESTAMP tags is allowed and required to specify output "
-"stream(s)."));
+EXPECT_THAT(status.message(),
+HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
+"AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
+"allowed and required to specify output stream(s)."));
 }
-TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithMixedTags) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(
@@ -117,14 +145,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
 CalculatorGraph graph;
 auto status = graph.Initialize(graph_config);
 EXPECT_FALSE(status.ok());
-EXPECT_THAT(
-status.message(),
-HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
-"AT_TIMESTAMP tags is allowed and required to specify output "
-"stream(s)."));
+EXPECT_THAT(status.message(),
+HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
+"AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
+"allowed and required to specify output stream(s)."));
 }
-TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughSidePackets) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(
@@ -146,7 +173,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
 "Same number of input side packets and output streams is required."));
 }
-TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughOutputStreams) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(
@@ -248,7 +275,50 @@ TEST(SidePacketToStreamCalculator, AtTick) {
 tick_and_verify(/*at_timestamp=*/1025);
 }
-TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
+TEST(SidePacketToStreamCalculator, AtFirstTick) {
+CalculatorGraphConfig graph_config =
+ParseTextProtoOrDie<CalculatorGraphConfig>(
+R"pb(
+input_stream: "tick"
+input_side_packet: "side_packet"
+output_stream: "packet"
+node {
+calculator: "SidePacketToStreamCalculator"
+input_stream: "TICK:tick"
+input_side_packet: "side_packet"
+output_stream: "AT_FIRST_TICK:packet"
+}
+)pb");
+std::vector<Packet> output_packets;
+tool::AddVectorSink("packet", &graph_config, &output_packets);
+CalculatorGraph graph;
+MP_ASSERT_OK(graph.Initialize(graph_config));
+const int expected_value = 20;
+const Timestamp kTestTimestamp(1234);
+MP_ASSERT_OK(
+graph.StartRun({{"side_packet", MakePacket<int>(expected_value)}}));
+auto insert_tick = [&graph](Timestamp at_timestamp) {
+MP_ASSERT_OK(graph.AddPacketToInputStream(
+"tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
+MP_ASSERT_OK(graph.WaitUntilIdle());
+};
+insert_tick(kTestTimestamp);
+EXPECT_THAT(output_packets,
+ElementsAre(PacketContainsTimestampAndPayload<int>(
+Eq(kTestTimestamp), Eq(expected_value))));
+output_packets.clear();
+// Should not result in an additional output.
+insert_tick(kTestTimestamp + 1);
+EXPECT_THAT(output_packets, IsEmpty());
+}
+TEST(SidePacketToStreamCalculator, AtTickWithMultipleSidePackets) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(
@@ -302,6 +372,62 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
 tick_and_verify(/*at_timestamp=*/1025);
 }
+TEST(SidePacketToStreamCalculator, AtFirstTickWithMultipleSidePackets) {
+CalculatorGraphConfig graph_config =
+ParseTextProtoOrDie<CalculatorGraphConfig>(
+R"pb(
+input_stream: "tick"
+input_side_packet: "side_packet0"
+input_side_packet: "side_packet1"
+output_stream: "packet0"
+output_stream: "packet1"
+node {
+calculator: "SidePacketToStreamCalculator"
+input_stream: "TICK:tick"
+input_side_packet: "side_packet0"
+input_side_packet: "side_packet1"
+output_stream: "AT_FIRST_TICK:0:packet0"
+output_stream: "AT_FIRST_TICK:1:packet1"
+}
+)pb");
+std::vector<Packet> output_packets0;
+tool::AddVectorSink("packet0", &graph_config, &output_packets0);
+std::vector<Packet> output_packets1;
+tool::AddVectorSink("packet1", &graph_config, &output_packets1);
+CalculatorGraph graph;
+MP_ASSERT_OK(graph.Initialize(graph_config));
+const int expected_value0 = 20;
+const int expected_value1 = 128;
+const Timestamp kTestTimestamp(1234);
+MP_ASSERT_OK(
+graph.StartRun({{"side_packet0", MakePacket<int>(expected_value0)},
+{"side_packet1", MakePacket<int>(expected_value1)}}));
+auto insert_tick = [&graph](Timestamp at_timestamp) {
+MP_ASSERT_OK(graph.AddPacketToInputStream(
+"tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
+MP_ASSERT_OK(graph.WaitUntilIdle());
+};
+insert_tick(kTestTimestamp);
+EXPECT_THAT(output_packets0,
+ElementsAre(PacketContainsTimestampAndPayload<int>(
+Eq(kTestTimestamp), Eq(expected_value0))));
+EXPECT_THAT(output_packets1,
+ElementsAre(PacketContainsTimestampAndPayload<int>(
+Eq(kTestTimestamp), Eq(expected_value1))));
+output_packets0.clear();
+output_packets1.clear();
+// Should not result in an additional output.
+insert_tick(kTestTimestamp + 1);
+EXPECT_THAT(output_packets0, IsEmpty());
+EXPECT_THAT(output_packets1, IsEmpty());
+}
 TEST(SidePacketToStreamCalculator, AtTimestamp) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
@@ -334,7 +460,7 @@ TEST(SidePacketToStreamCalculator, AtTimestamp) {
 EXPECT_EQ(expected_value, output_packets.back().Get<int>());
 }
-TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) {
+TEST(SidePacketToStreamCalculator, AtTimestampWithMultipleOutputs) {
 CalculatorGraphConfig graph_config =
 ParseTextProtoOrDie<CalculatorGraphConfig>(
 R"pb(

View File

@@ -65,7 +65,7 @@ class ImageCloneCalculator : public Node {
 }
 #else
 MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-cc, /*requesst_gpu_as_optional=*/true));
+cc, /*request_gpu_as_optional=*/true));
 #endif  // MEDIAPIPE_DISABLE_GPU
 return absl::OkStatus();
 }

View File

@@ -118,7 +118,7 @@ absl::Status SegmentationSmoothingCalculator::GetContract(
 #if !MEDIAPIPE_DISABLE_GPU
 MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-cc, /*requesst_gpu_as_optional=*/true));
+cc, /*request_gpu_as_optional=*/true));
 #endif  // !MEDIAPIPE_DISABLE_GPU
 return absl::OkStatus();

View File

@@ -206,7 +206,7 @@ class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl<InterfaceT> {
 if constexpr (std::is_same_v<InterfaceT, WarpAffineCalculatorGpu> ||
 std::is_same_v<InterfaceT, WarpAffineCalculator>) {
 MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-cc, /*requesst_gpu_as_optional=*/true));
+cc, /*request_gpu_as_optional=*/true));
 }
 return absl::OkStatus();
 }

View File

@@ -1480,7 +1480,6 @@ cc_test(
 "@com_google_absl//absl/log",
 "@com_google_absl//absl/log:absl_log",
 "@com_google_absl//absl/strings",
-"@com_google_googletest//:gtest_main",
 ],
 )

View File

@@ -109,7 +109,7 @@ bool IsValidFftSize(int size) {
 // Non-streaming mode: when "stream_mode" is set to false in the calculator
 // options, the calculators treats the packets in the input audio stream as
 // a batch of unrelated audio buffers. In each Process() call, the input
-// buffer will be frist resampled, and framed as fixed-sized, possibly
+// buffer will be first resampled, and framed as fixed-sized, possibly
 // overlapping tensors. The last tensor produced by a Process() invocation
 // will be zero-padding if the remaining samples are insufficient. As the
 // calculator treats the input packets as unrelated, all samples will be
@@ -159,7 +159,7 @@ class AudioToTensorCalculator : public Node {
 public:
 static constexpr Input<Matrix> kAudioIn{"AUDIO"};
 // TODO: Removes this optional input stream when the "AUDIO" stream
-// uses the new mediapipe audio data containers that carry audio metatdata,
+// uses the new mediapipe audio data containers that carry audio metadata,
 // such as sample rate.
 static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
 static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
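As a rough illustration of the framing behaviour described in the comment above (resample, then cut into fixed-size, possibly overlapping frames, zero-padding the last one when samples run out), here is a standalone sketch; it is not the calculator's code, and frame_size/frame_step merely stand in for the calculator's options:

#include <algorithm>
#include <vector>

// Splits samples into frames of frame_size advancing by frame_step; the last
// frame is zero-padded when the remaining samples are insufficient.
std::vector<std::vector<float>> FrameSamples(const std::vector<float>& samples,
                                             int frame_size, int frame_step) {
  std::vector<std::vector<float>> frames;
  for (int start = 0; start < static_cast<int>(samples.size());
       start += frame_step) {
    std::vector<float> frame(frame_size, 0.0f);  // zero-padded by default
    const int n =
        std::min(frame_size, static_cast<int>(samples.size()) - start);
    std::copy_n(samples.begin() + start, n, frame.begin());
    frames.push_back(std::move(frame));
  }
  return frames;
}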

View File

@@ -37,7 +37,7 @@ message AudioToTensorCalculatorOptions {
 // will be converted into tensors.
 optional double target_sample_rate = 4;
-// Whether to treat the input audio stream as a continous stream or a batch
+// Whether to treat the input audio stream as a continuous stream or a batch
 // of unrelated audio buffers.
 optional bool stream_mode = 5 [default = true];

View File

@@ -82,7 +82,7 @@ namespace api2 {
 //
 // Outputs:
 // TENSORS - std::vector<Tensor>
-// Vector containing a single Tensor populated with an extrated RGB image.
+// Vector containing a single Tensor populated with an extracted RGB image.
 // MATRIX - std::array<float, 16> @Optional
 // An std::array<float, 16> representing a 4x4 row-major-order matrix that
 // maps a point on the input image to a point on the output tensor, and
@@ -212,7 +212,7 @@ class ImageToTensorCalculator : public Node {
 std::array<float, 16> matrix;
 GetRotatedSubRectToRectTransformMatrix(
 roi, image->width(), image->height(),
-/*flip_horizontaly=*/false, &matrix);
+/*flip_horizontally=*/false, &matrix);
 kOutMatrix(cc).Send(std::move(matrix));
 }

View File

@@ -206,7 +206,7 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
 } else if (image_channels == 1) {
 return ImageFormat::GRAY8;
 }
-ABSL_CHECK(false) << "Unsupported input image channles: " << image_channels;
+ABSL_CHECK(false) << "Unsupported input image channels: " << image_channels;
 }
 Packet MakeImageFramePacket(cv::Mat input) {

View File

@@ -57,7 +57,7 @@ class SubRectExtractorGl {
 absl::Status ExtractSubRectToBuffer(
 const tflite::gpu::gl::GlTexture& texture,
 const tflite::gpu::HW& texture_size, const RotatedRect& sub_rect,
-bool flip_horizontaly, float alpha, float beta,
+bool flip_horizontally, float alpha, float beta,
 const tflite::gpu::HW& destination_size,
 tflite::gpu::gl::CommandQueue* command_queue,
 tflite::gpu::gl::GlBuffer* destination);
@@ -154,13 +154,13 @@ void main() {
 absl::Status SubRectExtractorGl::ExtractSubRectToBuffer(
 const tflite::gpu::gl::GlTexture& texture,
 const tflite::gpu::HW& texture_size, const RotatedRect& texture_sub_rect,
-bool flip_horizontaly, float alpha, float beta,
+bool flip_horizontally, float alpha, float beta,
 const tflite::gpu::HW& destination_size,
 tflite::gpu::gl::CommandQueue* command_queue,
 tflite::gpu::gl::GlBuffer* destination) {
 std::array<float, 16> transform_mat;
 GetRotatedSubRectToRectTransformMatrix(texture_sub_rect, texture_size.w,
-texture_size.h, flip_horizontaly,
+texture_size.h, flip_horizontally,
 &transform_mat);
 MP_RETURN_IF_ERROR(texture.BindAsSampler2D(0));
@@ -308,7 +308,7 @@ class GlProcessor : public ImageToTensorConverter {
 input_texture,
 tflite::gpu::HW(source_texture.height(), source_texture.width()),
 roi,
-/*flip_horizontaly=*/false, transform.scale, transform.offset,
+/*flip_horizontally=*/false, transform.scale, transform.offset,
 tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
 command_queue_.get(), &output));

View File

@@ -199,7 +199,7 @@ class GlProcessor : public ImageToTensorConverter {
 range_min, range_max));
 auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
 MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
-/*flip_horizontaly=*/false,
+/*flip_horizontally=*/false,
 transform.scale, transform.offset,
 output_shape, &tensor_view));
 return absl::OkStatus();
@@ -210,7 +210,7 @@ class GlProcessor : public ImageToTensorConverter {
 absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
 const RotatedRect& sub_rect,
-bool flip_horizontaly, float alpha, float beta,
+bool flip_horizontally, float alpha, float beta,
 const Tensor::Shape& output_shape,
 Tensor::OpenGlTexture2dView* output) {
 const int output_height = output_shape.dims[1];
@@ -263,13 +263,13 @@ class GlProcessor : public ImageToTensorConverter {
 ABSL_LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread.";
 if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) {
 GetTransposedRotatedSubRectToRectTransformMatrix(
-sub_rect, texture.width(), texture.height(), flip_horizontaly,
+sub_rect, texture.width(), texture.height(), flip_horizontally,
 &transform_mat);
 glUniformMatrix4fv(matrix_id_, 1, GL_FALSE, transform_mat.data());
 } else {
 GetRotatedSubRectToRectTransformMatrix(sub_rect, texture.width(),
-texture.height(), flip_horizontaly,
-&transform_mat);
+texture.height(),
+flip_horizontally, &transform_mat);
 glUniformMatrix4fv(matrix_id_, 1, GL_TRUE, transform_mat.data());
 }

View File

@@ -179,13 +179,13 @@ class SubRectExtractorMetal {
 }
 absl::Status Execute(id<MTLTexture> input_texture,
-const RotatedRect& sub_rect, bool flip_horizontaly,
+const RotatedRect& sub_rect, bool flip_horizontally,
 float alpha, float beta,
 const tflite::gpu::HW& destination_size,
 id<MTLCommandBuffer> command_buffer,
 id<MTLBuffer> destination) {
 auto output_texture = MTLTextureWithBuffer(destination_size, destination);
-return InternalExecute(input_texture, sub_rect, flip_horizontaly, alpha,
+return InternalExecute(input_texture, sub_rect, flip_horizontally, alpha,
 beta, destination_size, command_buffer,
 output_texture);
 }
@@ -211,7 +211,7 @@ class SubRectExtractorMetal {
 absl::Status InternalExecute(id<MTLTexture> input_texture,
 const RotatedRect& sub_rect,
-bool flip_horizontaly, float alpha, float beta,
+bool flip_horizontally, float alpha, float beta,
 const tflite::gpu::HW& destination_size,
 id<MTLCommandBuffer> command_buffer,
 id<MTLTexture> output_texture) {
@@ -223,7 +223,7 @@ class SubRectExtractorMetal {
 std::array<float, 16> transform_mat;
 GetRotatedSubRectToRectTransformMatrix(sub_rect, input_texture.width,
 input_texture.height,
-flip_horizontaly, &transform_mat);
+flip_horizontally, &transform_mat);
 id<MTLBuffer> transform_mat_buffer =
 [device_ newBufferWithBytes:&transform_mat
 length:sizeof(transform_mat)
@@ -383,7 +383,7 @@ class MetalProcessor : public ImageToTensorConverter {
 MtlBufferView::GetWriteView(output_tensor, command_buffer);
 MP_RETURN_IF_ERROR(extractor_->Execute(
 texture, roi,
-/*flip_horizontaly=*/false, transform.scale, transform.offset,
+/*flip_horizontally=*/false, transform.scale, transform.offset,
 tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
 command_buffer, buffer_view.buffer()));
 [command_buffer commit];

View File

@@ -92,7 +92,7 @@ absl::StatusOr<ValueTransformation> GetValueRangeTransformation(
 void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
 int rect_width, int rect_height,
-bool flip_horizontaly,
+bool flip_horizontally,
 std::array<float, 16>* matrix_ptr) {
 std::array<float, 16>& matrix = *matrix_ptr;
 // The resulting matrix is multiplication of below commented out matrices:
@@ -118,7 +118,7 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
 // {0.0f, 0.0f, a, 0.0f}
 // {0.0f, 0.0f, 0.0f, 1.0f}
-const float flip = flip_horizontaly ? -1 : 1;
+const float flip = flip_horizontally ? -1 : 1;
 // Matrix for optional horizontal flip around middle of output image.
 // { fl  , 0.0f, 0.0f, 0.0f}
 // { 0.0f, 1.0f, 0.0f, 0.0f}
@@ -177,13 +177,13 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
 void GetTransposedRotatedSubRectToRectTransformMatrix(
 const RotatedRect& sub_rect, int rect_width, int rect_height,
-bool flip_horizontaly, std::array<float, 16>* matrix_ptr) {
+bool flip_horizontally, std::array<float, 16>* matrix_ptr) {
 std::array<float, 16>& matrix = *matrix_ptr;
 // See comments in GetRotatedSubRectToRectTransformMatrix for detailed
 // calculations.
 const float a = sub_rect.width;
 const float b = sub_rect.height;
-const float flip = flip_horizontaly ? -1 : 1;
+const float flip = flip_horizontally ? -1 : 1;
 const float c = std::cos(sub_rect.rotation);
 const float d = std::sin(sub_rect.rotation);
 const float e = sub_rect.center_x;

View File

@@ -74,7 +74,7 @@ absl::StatusOr<std::array<float, 4>> PadRoi(int input_tensor_width,
 // Represents a transformation of value which involves scaling and offsetting.
 // To apply transformation:
 // ValueTransformation transform = ...
-// float transformed_value = transform.scale * value + transfrom.offset;
+// float transformed_value = transform.scale * value + transform.offset;
 struct ValueTransformation {
 float scale;
 float offset;
@@ -99,11 +99,11 @@ absl::StatusOr<ValueTransformation> GetValueRangeTransformation(
 // @sub_rect - rotated sub rect in absolute coordinates
 // @rect_width - rect width
 // @rect_height - rect height
-// @flip_horizontaly - we need to flip the output buffer.
+// @flip_horizontally - we need to flip the output buffer.
 // @matrix - 4x4 matrix (array of 16 elements) to populate
 void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
 int rect_width, int rect_height,
-bool flip_horizontaly,
+bool flip_horizontally,
 std::array<float, 16>* matrix);
 // Returns the transpose of the matrix found with
@@ -118,11 +118,11 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
 // @sub_rect - rotated sub rect in absolute coordinates
 // @rect_width - rect width
 // @rect_height - rect height
-// @flip_horizontaly - we need to flip the output buffer.
+// @flip_horizontally - we need to flip the output buffer.
 // @matrix - 4x4 matrix (array of 16 elements) to populate
 void GetTransposedRotatedSubRectToRectTransformMatrix(
 const RotatedRect& sub_rect, int rect_width, int rect_height,
-bool flip_horizontaly, std::array<float, 16>* matrix);
+bool flip_horizontally, std::array<float, 16>* matrix);
 // Validates the output dimensions set in the option proto. The input option
 // proto is expected to have to following fields:
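A minimal usage sketch for the declaration above, assuming the RotatedRect fields referenced in the .cc file (center_x/center_y, width, height, rotation in radians) and that the header is included; the concrete values are illustrative only:

#include <array>

// Hypothetical call site: map a rotated sub-rect of a 640x480 image onto the
// output tensor, without a horizontal flip.
void BuildRoiMatrixExample() {
  RotatedRect roi;
  roi.center_x = 320.0f;
  roi.center_y = 240.0f;
  roi.width = 200.0f;
  roi.height = 100.0f;
  roi.rotation = 0.5f;  // radians
  std::array<float, 16> matrix;
  GetRotatedSubRectToRectTransformMatrix(roi, /*rect_width=*/640,
                                         /*rect_height=*/480,
                                         /*flip_horizontally=*/false, &matrix);
}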

View File

@@ -32,7 +32,7 @@ message TensorConverterCalculatorOptions {
 // Custom settings to override the internal scaling factors `div` and `sub`.
 // Both values must be set to non-negative values. Will only take effect on
 // CPU AND when |use_custom_normalization| is set to true. When these custom
-// values take effect, the |zero_center| setting above will be overriden, and
+// values take effect, the |zero_center| setting above will be overridden, and
 // the normalized_value will be calculated as:
 // normalized_value = input / custom_div - custom_sub.
 optional bool use_custom_normalization = 6 [default = false];
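The custom normalization described above is a plain linear rescale; a one-line sketch of the documented formula (names mirror the proto fields):

// normalized_value = input / custom_div - custom_sub, per the comment above.
inline float CustomNormalize(float input, float custom_div, float custom_sub) {
  return input / custom_div - custom_sub;
}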

View File

@@ -34,7 +34,7 @@ message TensorsToClassificationCalculatorOptions {
 repeated Entry entries = 1;
 }
-// Score threshold for perserving the class.
+// Score threshold for preserving the class.
 optional float min_score_threshold = 1;
 // Number of highest scoring labels to output. If top_k is not positive then
 // all labels are used.
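A rough sketch of the selection these two options describe (not the calculator's actual code): drop classes scoring below min_score_threshold, then keep the top_k highest-scoring ones, with a non-positive top_k keeping everything that passed the threshold:

#include <algorithm>
#include <utility>
#include <vector>

// Each pair is (class_index, score); returns the surviving pairs, best first.
std::vector<std::pair<int, float>> SelectClasses(
    std::vector<std::pair<int, float>> scored, float min_score_threshold,
    int top_k) {
  std::vector<std::pair<int, float>> kept;
  for (const auto& entry : scored) {
    if (entry.second >= min_score_threshold) kept.push_back(entry);
  }
  std::sort(kept.begin(), kept.end(),
            [](const auto& a, const auto& b) { return a.second > b.second; });
  if (top_k > 0 && static_cast<int>(kept.size()) > top_k) kept.resize(top_k);
  return kept;
}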

View File

@ -15,7 +15,6 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "absl/log/absl_log.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
@ -147,7 +146,7 @@ BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) {
// TENSORS - Vector of Tensors of type kFloat32. The vector of tensors can have // TENSORS - Vector of Tensors of type kFloat32. The vector of tensors can have
// 2 or 3 tensors. First tensor is the predicted raw boxes/keypoints. // 2 or 3 tensors. First tensor is the predicted raw boxes/keypoints.
// The size of the values must be (num_boxes * num_predicted_values). // The size of the values must be (num_boxes * num_predicted_values).
// Second tensor is the score tensor. The size of the valuse must be // Second tensor is the score tensor. The size of the values must be
// (num_boxes * num_classes). It's optional to pass in a third tensor // (num_boxes * num_classes). It's optional to pass in a third tensor
// for anchors (e.g. for SSD models) depend on the outputs of the // for anchors (e.g. for SSD models) depend on the outputs of the
// detection model. The size of anchor tensor must be (num_boxes * // detection model. The size of anchor tensor must be (num_boxes *
@ -215,7 +214,8 @@ class TensorsToDetectionsCalculator : public Node {
const int* detection_classes, const int* detection_classes,
std::vector<Detection>* output_detections); std::vector<Detection>* output_detections);
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
float box_xmax, float score, int class_id, float box_xmax, absl::Span<const float> scores,
absl::Span<const int> class_ids,
bool flip_vertically); bool flip_vertically);
bool IsClassIndexAllowed(int class_index); bool IsClassIndexAllowed(int class_index);
@ -223,6 +223,7 @@ class TensorsToDetectionsCalculator : public Node {
int num_boxes_ = 0; int num_boxes_ = 0;
int num_coords_ = 0; int num_coords_ = 0;
int max_results_ = -1; int max_results_ = -1;
int classes_per_detection_ = 1;
BoxFormat box_output_format_ = BoxFormat box_output_format_ =
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
@ -267,7 +268,7 @@ absl::Status TensorsToDetectionsCalculator::UpdateContract(
if (CanUseGpu()) { if (CanUseGpu()) {
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract( MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
cc, /*requesst_gpu_as_optional=*/true)); cc, /*request_gpu_as_optional=*/true));
#elif MEDIAPIPE_METAL_ENABLED #elif MEDIAPIPE_METAL_ENABLED
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
@ -484,6 +485,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
auto num_boxes_view = num_boxes_tensor->GetCpuReadView(); auto num_boxes_view = num_boxes_tensor->GetCpuReadView();
auto num_boxes = num_boxes_view.buffer<float>(); auto num_boxes = num_boxes_view.buffer<float>();
num_boxes_ = num_boxes[0]; num_boxes_ = num_boxes[0];
// The detection model with Detection_PostProcess op may output duplicate
// boxes with different classes, in the following format:
// num_boxes_tensor = [num_boxes]
// detection_classes_tensor = [box_1_class_1, box_1_class_2, ...]
// detection_scores_tensor = [box_1_score_1, box_1_score_2, ... ]
// detection_boxes_tensor = [box_1, box1, ... ]
// Each box repeats classes_per_detection_ times.
// Note Detection_PostProcess op is only supported in CPU.
RET_CHECK_EQ(max_detections % num_boxes_, 0);
classes_per_detection_ = max_detections / num_boxes_;
auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView(); auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView();
auto detection_boxes = detection_boxes_view.buffer<float>(); auto detection_boxes = detection_boxes_view.buffer<float>();
@ -493,8 +504,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
auto detection_classes_view = detection_classes_tensor->GetCpuReadView(); auto detection_classes_view = detection_classes_tensor->GetCpuReadView();
auto detection_classes_ptr = detection_classes_view.buffer<float>(); auto detection_classes_ptr = detection_classes_view.buffer<float>();
std::vector<int> detection_classes(num_boxes_); std::vector<int> detection_classes(num_boxes_ * classes_per_detection_);
for (int i = 0; i < num_boxes_; ++i) { for (int i = 0; i < detection_classes.size(); ++i) {
detection_classes[i] = static_cast<int>(detection_classes_ptr[i]); detection_classes[i] = static_cast<int>(detection_classes_ptr[i]);
} }
MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores, MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
@ -863,24 +874,25 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
absl::Status TensorsToDetectionsCalculator::ConvertToDetections( absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
const float* detection_boxes, const float* detection_scores, const float* detection_boxes, const float* detection_scores,
const int* detection_classes, std::vector<Detection>* output_detections) { const int* detection_classes, std::vector<Detection>* output_detections) {
for (int i = 0; i < num_boxes_; ++i) { for (int i = 0; i < num_boxes_ * classes_per_detection_;
i += classes_per_detection_) {
if (max_results_ > 0 && output_detections->size() == max_results_) { if (max_results_ > 0 && output_detections->size() == max_results_) {
break; break;
} }
if (options_.has_min_score_thresh() &&
detection_scores[i] < options_.min_score_thresh()) {
continue;
}
if (!IsClassIndexAllowed(detection_classes[i])) {
continue;
}
const int box_offset = i * num_coords_; const int box_offset = i * num_coords_;
Detection detection = ConvertToDetection( Detection detection = ConvertToDetection(
/*box_ymin=*/detection_boxes[box_offset + box_indices_[0]], /*box_ymin=*/detection_boxes[box_offset + box_indices_[0]],
/*box_xmin=*/detection_boxes[box_offset + box_indices_[1]], /*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]], /*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]], /*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
detection_scores[i], detection_classes[i], options_.flip_vertically()); absl::MakeConstSpan(detection_scores + i, classes_per_detection_),
absl::MakeConstSpan(detection_classes + i, classes_per_detection_),
options_.flip_vertically());
// If all the scores and classes are filtered out, we skip the empty
// detection.
if (detection.score().empty()) {
continue;
}
const auto& bbox = detection.location_data().relative_bounding_box(); const auto& bbox = detection.location_data().relative_bounding_box();
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
std::isnan(bbox.height())) { std::isnan(bbox.height())) {
@ -910,11 +922,21 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
} }
Detection TensorsToDetectionsCalculator::ConvertToDetection( Detection TensorsToDetectionsCalculator::ConvertToDetection(
float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, float box_ymin, float box_xmin, float box_ymax, float box_xmax,
int class_id, bool flip_vertically) { absl::Span<const float> scores, absl::Span<const int> class_ids,
bool flip_vertically) {
Detection detection; Detection detection;
detection.add_score(score); for (int i = 0; i < scores.size(); ++i) {
detection.add_label_id(class_id); if (!IsClassIndexAllowed(class_ids[i])) {
continue;
}
if (options_.has_min_score_thresh() &&
scores[i] < options_.min_score_thresh()) {
continue;
}
detection.add_score(scores[i]);
detection.add_label_id(class_ids[i]);
}
LocationData* location_data = detection.mutable_location_data(); LocationData* location_data = detection.mutable_location_data();
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
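Aside (not part of the commit): a rough standalone sketch of the duplicated-box layout described in the comments above, using hypothetical names and plain std containers. Each of the num_boxes entries carries classes_per_detection_ candidate classes and scores, and the per-class score-threshold filtering now happens while a single Detection is assembled, so a box whose candidates are all filtered out is dropped.

#include <iostream>
#include <utility>
#include <vector>

// Standalone sketch of the duplicated-box layout (illustrative names only).
struct SimpleDetection {
  std::vector<float> scores;
  std::vector<int> label_ids;
};

std::vector<SimpleDetection> ConvertDuplicatedBoxes(
    const std::vector<float>& scores,  // num_boxes * classes_per_detection
    const std::vector<int>& classes,   // same layout as scores
    int num_boxes, int classes_per_detection, float min_score_thresh) {
  std::vector<SimpleDetection> detections;
  for (int i = 0; i < num_boxes * classes_per_detection;
       i += classes_per_detection) {
    SimpleDetection detection;
    // Keep only the candidates of this box that pass the score threshold.
    for (int j = 0; j < classes_per_detection; ++j) {
      if (scores[i + j] < min_score_thresh) continue;
      detection.scores.push_back(scores[i + j]);
      detection.label_ids.push_back(classes[i + j]);
    }
    // A box whose candidates were all filtered out yields no detection.
    if (detection.scores.empty()) continue;
    detections.push_back(std::move(detection));
  }
  return detections;
}

int main() {
  // Two boxes, two candidate classes per box; box 2 is fully below threshold.
  const std::vector<float> scores = {0.9f, 0.2f, 0.1f, 0.05f};
  const std::vector<int> classes = {1, 7, 3, 4};
  const auto detections = ConvertDuplicatedBoxes(
      scores, classes, /*num_boxes=*/2, /*classes_per_detection=*/2,
      /*min_score_thresh=*/0.5f);
  std::cout << detections.size() << " detection(s) kept\n";  // prints 1
}

The stride-by-classes_per_detection loop mirrors the indexing in ConvertToDetections; the real calculator additionally checks the class allow-list and rejects boxes with negative or NaN extents.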

View File

@ -75,7 +75,7 @@ message TensorsToDetectionsCalculatorOptions {
// representation has a bottom-left origin (e.g., in OpenGL). // representation has a bottom-left origin (e.g., in OpenGL).
optional bool flip_vertically = 18 [default = false]; optional bool flip_vertically = 18 [default = false];
// Score threshold for perserving decoded detections. // Score threshold for preserving decoded detections.
optional float min_score_thresh = 19; optional float min_score_thresh = 19;
// The maximum number of the detection results to return. If < 0, all // The maximum number of the detection results to return. If < 0, all

View File

@ -124,7 +124,7 @@ absl::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) {
kFlipVertically(cc).IsConnected())) { kFlipVertically(cc).IsConnected())) {
RET_CHECK(options_.has_input_image_height() && RET_CHECK(options_.has_input_image_height() &&
options_.has_input_image_width()) options_.has_input_image_width())
<< "Must provide input width/height for using flipping when outputing " << "Must provide input width/height for using flipping when outputting "
"landmarks in absolute coordinates."; "landmarks in absolute coordinates.";
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -208,7 +208,7 @@ absl::Status TensorsToSegmentationCalculator::GetContract(
if (CanUseGpu()) { if (CanUseGpu()) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract( MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
cc, /*requesst_gpu_as_optional=*/true)); cc, /*request_gpu_as_optional=*/true));
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED

View File

@ -60,24 +60,22 @@ struct FormattingTestCase {
std::vector<float> inputs; std::vector<float> inputs;
std::vector<float> expected_outputs; std::vector<float> expected_outputs;
Options::Activation activation; Options::Activation activation;
int rows; int rows = 1;
int cols; int cols = 1;
int channels; int rows_new = 1;
int cols_new = 1;
int channels = 1;
double max_abs_diff = 1e-7;
}; };
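Aside (illustration only, not the test's real types): the refactor above leans on default member initializers plus designated initializers and a structured binding, so each test case spells out only the fields it cares about. A minimal self-contained sketch with made-up fields:

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for the test-case struct: members get defaults, so a
// test case only needs to list the fields that differ from those defaults.
struct ExampleTestCase {
  std::string test_name;
  std::vector<float> inputs;
  int rows = 1;
  int cols = 1;
  int rows_new = 1;
  int cols_new = 1;
  double max_abs_diff = 1e-7;
};

int main() {
  // Designated initializers (C++20) keep each case short and readable.
  ExampleTestCase test_case{.test_name = "OutputResizeOnly",
                            .inputs = {1.0f, 2.0f, 3.0f, 4.0f},
                            .rows = 2,
                            .cols = 2,
                            .rows_new = 3,
                            .cols_new = 3,
                            .max_abs_diff = 1e-6};
  // Structured bindings unpack the fields in declaration order.
  const auto& [test_name, inputs, rows, cols, rows_new, cols_new,
               max_abs_diff] = test_case;
  std::cout << test_name << ": resize " << rows << "x" << cols << " -> "
            << rows_new << "x" << cols_new << " (tol " << max_abs_diff
            << ")\n";
}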
using TensorsToSegmentationCalculatorTest = TestWithParam<FormattingTestCase>; using TensorsToSegmentationCalculatorTest = TestWithParam<FormattingTestCase>;
// Currently only useable for tests with no output resize.
TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) { TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
const FormattingTestCase& test_case = GetParam(); const auto& [test_name, inputs, expected_outputs, activation, rows, cols,
std::vector<float> inputs = test_case.inputs; rows_new, cols_new, channels, max_abs_diff] = GetParam();
std::vector<float> expected_outputs = test_case.expected_outputs;
Options::Activation activation = test_case.activation;
int rows = test_case.rows;
int cols = test_case.cols;
int channels = test_case.channels;
std::string string_config = absl::Substitute( auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
R"pb( R"pb(
input_stream: "tensors" input_stream: "tensors"
input_stream: "size" input_stream: "size"
@ -93,9 +91,7 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
} }
} }
)pb", )pb",
ActivationTypeToString(activation)); ActivationTypeToString(activation)));
auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(string_config);
std::vector<Packet> output_packets; std::vector<Packet> output_packets;
tool::AddVectorSink("image_as_mask", &graph_config, &output_packets); tool::AddVectorSink("image_as_mask", &graph_config, &output_packets);
@ -119,28 +115,34 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
MP_ASSERT_OK(graph.AddPacketToInputStream( MP_ASSERT_OK(graph.AddPacketToInputStream(
"tensors", mediapipe::Adopt(tensors.release()).At(Timestamp(0)))); "tensors", mediapipe::Adopt(tensors.release()).At(Timestamp(0))));
} }
// The output size is defined as pair(new_width, new_height).
MP_ASSERT_OK(graph.AddPacketToInputStream( MP_ASSERT_OK(graph.AddPacketToInputStream(
"size", "size", mediapipe::Adopt(new std::pair<int, int>(cols_new, rows_new))
mediapipe::Adopt(new std::pair<int, int>(rows, cols)).At(Timestamp(0)))); .At(Timestamp(0))));
MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_THAT(output_packets, SizeIs(1)); ASSERT_THAT(output_packets, SizeIs(1));
const Image& image_as_mask = output_packets[0].Get<Image>(); const Image& image_as_mask = output_packets[0].Get<Image>();
EXPECT_FALSE(image_as_mask.UsesGpu());
std::shared_ptr<cv::Mat> result_mat = formats::MatView(&image_as_mask); std::shared_ptr<cv::Mat> result_mat = formats::MatView(&image_as_mask);
EXPECT_EQ(result_mat->rows, rows); EXPECT_EQ(result_mat->rows, rows_new);
EXPECT_EQ(result_mat->cols, cols); EXPECT_EQ(result_mat->cols, cols_new);
EXPECT_EQ(result_mat->channels(), channels); EXPECT_EQ(result_mat->channels(), 1);
// Compare the real result with the expected result. // Compare the real result with the expected result.
cv::Mat expected_result = cv::Mat( cv::Mat expected_result =
rows, cols, CV_32FC1, const_cast<float*>(expected_outputs.data())); cv::Mat(rows_new, cols_new, CV_32FC1,
const_cast<float*>(expected_outputs.data()));
cv::Mat diff; cv::Mat diff;
cv::absdiff(*result_mat, expected_result, diff); cv::absdiff(*result_mat, expected_result, diff);
double max_val; double max_val;
cv::minMaxLoc(diff, nullptr, &max_val); cv::minMaxLoc(diff, nullptr, &max_val);
// Expects the maximum absolute pixel-by-pixel difference is less than 1e-5.
// This delta is for passthorugh accuracy only. // The max allowable diff between output and expected output varies between
EXPECT_LE(max_val, 1e-5); // tests.
EXPECT_LE(max_val, max_abs_diff);
MP_ASSERT_OK(graph.CloseInputStream("tensors")); MP_ASSERT_OK(graph.CloseInputStream("tensors"));
MP_ASSERT_OK(graph.CloseInputStream("size")); MP_ASSERT_OK(graph.CloseInputStream("size"));
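Aside: the tolerance check above is just an element-wise absolute difference followed by a max reduction. A minimal OpenCV-only sketch of that comparison (requires linking opencv_core; the matrices and tolerance are made up):

#include <iostream>
#include <opencv2/core.hpp>

int main() {
  cv::Mat result = (cv::Mat_<float>(2, 2) << 1.0f, 2.0f, 3.0f, 4.0f);
  cv::Mat expected = (cv::Mat_<float>(2, 2) << 1.0f, 2.0f, 3.0f, 4.000001f);

  cv::Mat diff;
  cv::absdiff(result, expected, diff);   // element-wise |result - expected|
  double max_val = 0.0;
  cv::minMaxLoc(diff, /*minVal=*/nullptr, &max_val);

  const double max_abs_diff = 1e-5;      // per-test tolerance
  std::cout << (max_val <= max_abs_diff ? "within tolerance" : "too large")
            << " (max diff = " << max_val << ")\n";
}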
@ -150,17 +152,96 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
TensorsToSegmentationCalculatorTests, TensorsToSegmentationCalculatorTest, TensorsToSegmentationCalculatorTests, TensorsToSegmentationCalculatorTest,
testing::ValuesIn<FormattingTestCase>({ testing::ValuesIn<FormattingTestCase>({
{/*test_name=*/"NoActivationAndNoOutputResize", {.test_name = "NoActivationAndNoOutputResize",
/*inputs=*/ .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 12.0, 13.0, 14.0, 15.0, 16.0},
14.0, 15.0, 16.0}, .expected_outputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
/*expected_outputs=*/ 11.0, 12.0, 13.0, 14.0, 15.0, 16.0},
{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, .activation = Options::NONE,
14.0, 15.0, 16.0}, .rows = 4,
/*activation=*/Options::NONE, .cols = 4,
/*rows=*/4, .rows_new = 4,
/*cols=*/4, .cols_new = 4,
/*channels=*/1}, .channels = 1,
.max_abs_diff = 1e-7},
{.test_name = "OutputResizeOnly",
.inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
12.0, 13.0, 14.0, 15.0, 16.0},
.expected_outputs = {1, 1.5, 2.166667, 2.833333, 3.5, 4,
3.8, 4.3, 4.966667, 5.633333, 6.3, 6.8,
7, 7.5, 8.166667, 8.833333, 9.5, 10,
10.2, 10.7, 11.366667, 12.033333, 12.7, 13.2,
13, 13.5, 14.166667, 14.833333, 15.5, 16},
.activation = Options::NONE,
.rows = 4,
.cols = 4,
.rows_new = 5,
.cols_new = 6,
.channels = 1,
.max_abs_diff = 1e-6},
{.test_name = "SigmoidActivationWithNoOutputResize",
.inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
12.0, 13.0, 14.0, 15.0, 16.0},
.expected_outputs = {0.731059, 0.880797, 0.952574, 0.982014, 0.993307,
0.997527, 0.999089, 0.999665, 0.999877, 0.999955,
0.999983, 0.999994, 0.999998, 0.999999, 1.0, 1.0},
.activation = Options::SIGMOID,
.rows = 4,
.cols = 4,
.rows_new = 4,
.cols_new = 4,
.channels = 1,
.max_abs_diff = 1e-6},
{.test_name = "SigmoidActivationWithOutputResize",
.inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
12.0, 13.0, 14.0, 15.0, 16.0},
.expected_outputs = {0.731059, 0.805928, 0.89276, 0.940611, 0.967294,
0.982014, 0.914633, 0.93857, 0.966279, 0.981363,
0.989752, 0.994369, 0.996592, 0.997666, 0.998873,
0.999404, 0.999683, 0.999829, 0.999913, 0.99994,
0.999971, 0.999985, 0.999992, 0.999996, 0.999998,
0.999998, 0.999999, 1.0, 1.0, 1.0},
.activation = Options::SIGMOID,
.rows = 4,
.cols = 4,
.rows_new = 5,
.cols_new = 6,
.channels = 1,
.max_abs_diff = 1e-6},
{.test_name = "SoftmaxActivationWithNoOutputResize",
.inputs = {1.0, 2.0, 4.0, 2.0, 3.0, 5.0, 6.0, 1.5,
7.0, 10.0, 11.0, 4.0, 12.0, 15.0, 16.0, 18.5,
19.0, 20.0, 22.0, 23.0, 24.5, 23.4, 25.6, 28.3,
29.2, 30.0, 24.6, 29.2, 30.0, 24.9, 31.2, 30.3},
.expected_outputs = {0.731059, 0.119203, 0.880797, 0.0109869, 0.952574,
0.000911051, 0.952574, 0.924142, 0.731059,
0.731059, 0.24974, 0.937027, 0.689974, 0.990048,
0.0060598, 0.28905},
.activation = Options::SOFTMAX,
.rows = 4,
.cols = 4,
.rows_new = 4,
.cols_new = 4,
.channels = 2,
.max_abs_diff = 1e-6},
{.test_name = "SoftmaxActivationWithOutputResize",
.inputs = {1.0, 2.0, 4.0, 2.0, 3.0, 5.0, 6.0, 1.5,
7.0, 10.0, 11.0, 4.0, 12.0, 15.0, 16.0, 18.5,
19.0, 20.0, 22.0, 23.0, 24.5, 23.4, 25.6, 28.3,
29.2, 30.0, 24.6, 29.2, 30.0, 24.9, 31.2, 30.3},
.expected_outputs = {0.731059, 0.425131, 0.246135, 0.753865, 0.445892,
0.0109869, 0.886119, 0.461259, 0.185506, 0.781934,
0.790618, 0.650195, 0.841816, 0.603901, 0.40518,
0.561962, 0.765871, 0.930584, 0.718733, 0.763744,
0.703402, 0.281989, 0.459635, 0.742634, 0.689974,
0.840011, 0.82605, 0.170058, 0.147555, 0.28905},
.activation = Options::SOFTMAX,
.rows = 4,
.cols = 4,
.rows_new = 5,
.cols_new = 6,
.channels = 2,
.max_abs_diff = 1e-6},
}), }),
[](const testing::TestParamInfo< [](const testing::TestParamInfo<
TensorsToSegmentationCalculatorTest::ParamType>& info) { TensorsToSegmentationCalculatorTest::ParamType>& info) {

View File

@ -79,7 +79,7 @@ namespace mpms = mediapipe::mediasequence;
// and label and label_id are optional but at least one of them should be set. // and label and label_id are optional but at least one of them should be set.
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store // "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
// prefixed versions of each stream, which allows for multiple image streams to // prefixed versions of each stream, which allows for multiple image streams to
// be included. However, the default names are suppored by more tools. // be included. However, the default names are supported by more tools.
// //
// Example config: // Example config:
// node { // node {

View File

@ -67,8 +67,8 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
// -- 1-D or 2-D Tensor // -- 1-D or 2-D Tensor
// Output: // Output:
// -- Matrix with the same values as the Tensor // -- Matrix with the same values as the Tensor
// If input tensor is 1 dimensional, the ouput Matrix is of (1xn) shape. // If input tensor is 1 dimensional, the output Matrix is of (1xn) shape.
// If input tensor is 2 dimensional (batched), the ouput Matrix is (mxn) shape. // If input tensor is 2 dimensional (batched), the output Matrix is (mxn) shape.
// //
// Example Config // Example Config
// node: { // node: {

View File

@ -111,8 +111,8 @@ class InferenceState {
// input_side_packet. // input_side_packet.
// //
// The input and output streams are TensorFlow tensors labeled by tags. The tags // The input and output streams are TensorFlow tensors labeled by tags. The tags
// for the streams are matched to feeds and fetchs in a TensorFlow session using // for the streams are matched to feeds and fetches in a TensorFlow session
// a named_signature.generic_signature in the ModelManifest. The // using a named_signature.generic_signature in the ModelManifest. The
// generic_signature is used as key-value pairs between the MediaPipe tag and // generic_signature is used as key-value pairs between the MediaPipe tag and
// the TensorFlow tensor. The signature_name in the options proto determines // the TensorFlow tensor. The signature_name in the options proto determines
// which named_signature is used. The keys in the generic_signature must be // which named_signature is used. The keys in the generic_signature must be
@ -128,7 +128,7 @@ class InferenceState {
// addition. Once batch_size inputs have been provided, the batch will be run // addition. Once batch_size inputs have been provided, the batch will be run
// and the output tensors sent out on the output streams with timestamps // and the output tensors sent out on the output streams with timestamps
// corresponding to the input stream packets. Setting the batch_size to 1 // corresponding to the input stream packets. Setting the batch_size to 1
// completely disables batching, but is indepdent of add_batch_dim_to_tensors. // completely disables batching, but is independent of add_batch_dim_to_tensors.
// //
// The TensorFlowInferenceCalculator also supports feeding states recurrently for // The TensorFlowInferenceCalculator also supports feeding states recurrently for
// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the // RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the

View File

@ -42,7 +42,7 @@ message TensorFlowInferenceCalculatorOptions {
// If the 0th dimension is the batch dimension, then the tensors are // If the 0th dimension is the batch dimension, then the tensors are
// concatenated on that dimension. If the 0th is a data dimension, then a 0th // concatenated on that dimension. If the 0th is a data dimension, then a 0th
// dimension is added before concatenating. If added, the extra dimension is // dimension is added before concatenating. If added, the extra dimension is
// removed before outputing the tensor. Examples of each case: If you want // removed before outputting the tensor. Examples of each case: If you want
// to batch spectra of audio over time for an LSTM, a time-frequency // to batch spectra of audio over time for an LSTM, a time-frequency
// representation has a 0th dimension as the batch dimension. If you want to // representation has a 0th dimension as the batch dimension. If you want to
// batch frames of video that are [width, height, channels], the batch // batch frames of video that are [width, height, channels], the batch

View File

@ -1,6 +1,7 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
networkTimeout=10000 networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -83,10 +83,8 @@ done
# This is normally unused # This is normally unused
# shellcheck disable=SC2034 # shellcheck disable=SC2034
APP_BASE_NAME=${0##*/} APP_BASE_NAME=${0##*/}
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit # Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value. # Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum MAX_FD=maximum
@ -133,18 +131,21 @@ location of your Java installation."
fi fi
else else
JAVACMD=java JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. if ! command -v java >/dev/null 2>&1
then
die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the Please set the JAVA_HOME variable in your environment to match the
location of your Java installation." location of your Java installation."
fi fi
fi
# Increase the maximum file descriptors if we can. # Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #( case $MAX_FD in #(
max*) max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045 # shellcheck disable=SC2039,SC3045
MAX_FD=$( ulimit -H -n ) || MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit" warn "Could not query maximum file descriptor limit"
esac esac
@ -152,7 +153,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
'' | soft) :;; #( '' | soft) :;; #(
*) *)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045 # shellcheck disable=SC2039,SC3045
ulimit -n "$MAX_FD" || ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD" warn "Could not set maximum file descriptor limit to $MAX_FD"
esac esac
@ -197,11 +198,15 @@ if "$cygwin" || "$msys" ; then
done done
fi fi
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
# shell script including quotes and variable substitutions, so put them in DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# double quotes to make sure that they get re-expanded; and
# * put everything else in single quotes, so that it's not re-expanded. # Collect all arguments for the java command:
# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
# and any embedded shellness will be escaped.
# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
# treated as '${Hostname}' itself on the command line.
set -- \ set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \ "-Dorg.gradle.appname=$APP_BASE_NAME" \

View File

@ -56,7 +56,7 @@ absl::Status RunMPPGraph() {
for (const std::string& kv_pair : kv_pairs) { for (const std::string& kv_pair : kv_pairs) {
std::vector<std::string> name_and_value = absl::StrSplit(kv_pair, '='); std::vector<std::string> name_and_value = absl::StrSplit(kv_pair, '=');
RET_CHECK(name_and_value.size() == 2); RET_CHECK(name_and_value.size() == 2);
RET_CHECK(!mediapipe::ContainsKey(input_side_packets, name_and_value[0])); RET_CHECK(!input_side_packets.contains(name_and_value[0]));
std::string input_side_packet_contents; std::string input_side_packet_contents;
MP_RETURN_IF_ERROR(mediapipe::file::GetContents( MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
name_and_value[1], &input_side_packet_contents)); name_and_value[1], &input_side_packet_contents));
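Aside (sketch, not the tool's real code): the same name=value parsing and duplicate check can be written with plain std containers; std::map::contains (C++20) plays the role of the contains() call above:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical command-line side packets in "name=value" form.
  const std::vector<std::string> kv_pairs = {"model_path=/tmp/model.tflite",
                                             "num_threads=4"};
  std::map<std::string, std::string> input_side_packets;

  for (const std::string& kv_pair : kv_pairs) {
    const auto pos = kv_pair.find('=');
    if (pos == std::string::npos) {
      std::cerr << "expected name=value, got: " << kv_pair << "\n";
      return 1;
    }
    const std::string name = kv_pair.substr(0, pos);
    const std::string value = kv_pair.substr(pos + 1);
    if (input_side_packets.contains(name)) {  // mirrors the RET_CHECK above
      std::cerr << "duplicate side packet: " << name << "\n";
      return 1;
    }
    input_side_packets[name] = value;
  }
  std::cout << "parsed " << input_side_packets.size() << " side packets\n";
}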

View File

@ -616,6 +616,7 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/functional/bind_front.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"

View File

@ -134,7 +134,7 @@ absl::Status ParseTagAndName(absl::string_view tag_and_name, std::string* tag,
RET_CHECK(name); RET_CHECK(name);
absl::Status tag_status = absl::OkStatus(); absl::Status tag_status = absl::OkStatus();
absl::Status name_status = absl::UnknownError(""); absl::Status name_status = absl::UnknownError("");
int name_index = 0; int name_index = -1;
std::vector<std::string> v = absl::StrSplit(tag_and_name, ':'); std::vector<std::string> v = absl::StrSplit(tag_and_name, ':');
if (v.size() == 1) { if (v.size() == 1) {
name_status = ValidateName(v[0]); name_status = ValidateName(v[0]);
@ -143,7 +143,7 @@ absl::Status ParseTagAndName(absl::string_view tag_and_name, std::string* tag,
tag_status = ValidateTag(v[0]); tag_status = ValidateTag(v[0]);
name_status = ValidateName(v[1]); name_status = ValidateName(v[1]);
name_index = 1; name_index = 1;
} } // else omitted, name_index == -1, triggering error.
if (name_index == -1 || tag_status != absl::OkStatus() || if (name_index == -1 || tag_status != absl::OkStatus() ||
name_status != absl::OkStatus()) { name_status != absl::OkStatus()) {
tag->clear(); tag->clear();
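Aside: a simplified standalone sketch of the sentinel-based parsing above, with the Validate* calls stubbed out. The point of starting name_index at -1 is that any input other than "name" or "TAG:name" falls through to the error branch:

#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the real parser; names and behavior are
// illustrative only.
bool ParseTagAndNameSketch(const std::string& tag_and_name, std::string* tag,
                           std::string* name) {
  std::vector<std::string> v;
  size_t start = 0;
  while (true) {
    const size_t pos = tag_and_name.find(':', start);
    v.push_back(tag_and_name.substr(start, pos - start));
    if (pos == std::string::npos) break;
    start = pos + 1;
  }

  int name_index = -1;
  if (v.size() == 1) {
    tag->clear();              // "name" only: no tag
    name_index = 0;
  } else if (v.size() == 2) {
    *tag = v[0];               // "TAG:name"
    name_index = 1;
  }                            // else: name_index stays -1, error below
  if (name_index == -1) {
    tag->clear();
    name->clear();
    return false;
  }
  *name = v[name_index];
  return true;
}

int main() {
  std::string tag, name;
  std::cout << ParseTagAndNameSketch("IMAGE:input_frames", &tag, &name)
            << " tag=" << tag << " name=" << name << "\n";  // 1 IMAGE input_frames
  std::cout << ParseTagAndNameSketch("A:B:C", &tag, &name) << "\n";  // 0 (error)
}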

View File

@ -516,6 +516,7 @@ cc_library(
":gpu_buffer_storage", ":gpu_buffer_storage",
":image_frame_view", ":image_frame_view",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],
) )
@ -526,12 +527,14 @@ mediapipe_proto_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
objc_library( cc_library(
name = "pixel_buffer_pool_util", name = "pixel_buffer_pool_util",
srcs = ["pixel_buffer_pool_util.mm"], srcs = ["pixel_buffer_pool_util.cc"],
hdrs = ["pixel_buffer_pool_util.h"], hdrs = ["pixel_buffer_pool_util.h"],
copts = [ copts = [
"-x objective-c++",
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-fobjc-arc", # enable reference-counting
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
@ -542,13 +545,14 @@ objc_library(
], ],
) )
objc_library( cc_library(
name = "metal_shared_resources", name = "metal_shared_resources",
srcs = ["metal_shared_resources.mm"], srcs = ["metal_shared_resources.cc"],
hdrs = ["metal_shared_resources.h"], hdrs = ["metal_shared_resources.h"],
copts = [ copts = [
"-x objective-c++", "-x objective-c++",
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-fobjc-arc", # enable reference-counting
], ],
features = ["-layering_check"], features = ["-layering_check"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@ -557,15 +561,17 @@ objc_library(
"@google_toolbox_for_mac//:GTM_Defines", "@google_toolbox_for_mac//:GTM_Defines",
] + [ ] + [
], ],
alwayslink = 1,
) )
objc_library( cc_library(
name = "MPPMetalUtil", name = "MPPMetalUtil",
srcs = ["MPPMetalUtil.mm"], srcs = ["MPPMetalUtil.cc"],
hdrs = ["MPPMetalUtil.h"], hdrs = ["MPPMetalUtil.h"],
copts = [ copts = [
"-x objective-c++", "-x objective-c++",
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-fobjc-arc", # enable reference-counting
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
@ -575,6 +581,7 @@ objc_library(
"@com_google_absl//absl/time", "@com_google_absl//absl/time",
"@google_toolbox_for_mac//:GTM_Defines", "@google_toolbox_for_mac//:GTM_Defines",
], ],
alwayslink = 1,
) )
mediapipe_proto_library( mediapipe_proto_library(
@ -857,12 +864,14 @@ cc_library(
}), }),
) )
objc_library( cc_library(
name = "MPPMetalHelper", name = "MPPMetalHelper",
srcs = ["MPPMetalHelper.mm"], srcs = ["MPPMetalHelper.cc"],
hdrs = ["MPPMetalHelper.h"], hdrs = ["MPPMetalHelper.h"],
copts = [ copts = [
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-x objective-c++",
"-fobjc-arc",
], ],
features = ["-layering_check"], features = ["-layering_check"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@ -1215,9 +1224,13 @@ mediapipe_cc_test(
], ],
requires_full_emulation = True, requires_full_emulation = True,
deps = [ deps = [
":gl_texture_buffer",
":gl_texture_util",
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_storage_ahwb", ":gpu_buffer_storage_ahwb",
":gpu_test_base",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/tool:test_util",
], ],
) )

View File

@ -14,15 +14,14 @@
#import "mediapipe/gpu/MPPMetalHelper.h" #import "mediapipe/gpu/MPPMetalHelper.h"
#import "GTMDefines.h"
#include "absl/log/absl_check.h" #include "absl/log/absl_check.h"
#include "absl/log/absl_log.h" #include "absl/log/absl_log.h"
#include "mediapipe/framework/port/ret_check.h"
#import "mediapipe/gpu/gpu_buffer.h" #import "mediapipe/gpu/gpu_buffer.h"
#import "mediapipe/gpu/gpu_service.h" #import "mediapipe/gpu/gpu_service.h"
#import "mediapipe/gpu/graph_support.h" #import "mediapipe/gpu/graph_support.h"
#import "mediapipe/gpu/metal_shared_resources.h" #import "mediapipe/gpu/metal_shared_resources.h"
#import "GTMDefines.h"
#include "mediapipe/framework/port/ret_check.h"
@interface MPPMetalHelper () { @interface MPPMetalHelper () {
mediapipe::GpuResources* _gpuResources; mediapipe::GpuResources* _gpuResources;
@ -31,7 +30,8 @@
namespace mediapipe { namespace mediapipe {
// Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport. // Using a C++ class so it can be declared as a friend of
// LegacyCalculatorSupport.
class MetalHelperLegacySupport { class MetalHelperLegacySupport {
public: public:
static CalculatorContract* GetCalculatorContract() { static CalculatorContract* GetCalculatorContract() {
@ -61,7 +61,8 @@ class MetalHelperLegacySupport {
- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc { - (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc {
if (!cc) return nil; if (!cc) return nil;
return [self initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()]; return [self
initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()];
} }
+ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc { + (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc {
@ -77,7 +78,8 @@ class MetalHelperLegacySupport {
} }
// Legacy support. // Legacy support.
- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets { - (instancetype)initWithSidePackets:
(const mediapipe::PacketSet&)inputSidePackets {
auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContext(); auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContext();
if (cc) { if (cc) {
ABSL_CHECK_EQ(&inputSidePackets, &cc->InputSidePackets()); ABSL_CHECK_EQ(&inputSidePackets, &cc->InputSidePackets());
@ -85,16 +87,19 @@ class MetalHelperLegacySupport {
} }
// TODO: remove when we can. // TODO: remove when we can.
ABSL_LOG(WARNING) << "CalculatorContext not available. If this calculator uses " ABSL_LOG(WARNING)
<< "CalculatorContext not available. If this calculator uses "
"CalculatorBase, call initWithCalculatorContext instead."; "CalculatorBase, call initWithCalculatorContext instead.";
mediapipe::GpuSharedData* gpu_shared = mediapipe::GpuSharedData* gpu_shared =
inputSidePackets.Tag(mediapipe::kGpuSharedTagName).Get<mediapipe::GpuSharedData*>(); inputSidePackets.Tag(mediapipe::kGpuSharedTagName)
.Get<mediapipe::GpuSharedData*>();
return [self initWithGpuResources:gpu_shared->gpu_resources.get()]; return [self initWithGpuResources:gpu_shared->gpu_resources.get()];
} }
// Legacy support. // Legacy support.
+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets { + (absl::Status)setupInputSidePackets:
(mediapipe::PacketTypeSet*)inputSidePackets {
auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContract(); auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContract();
if (cc) { if (cc) {
ABSL_CHECK_EQ(inputSidePackets, &cc->InputSidePackets()); ABSL_CHECK_EQ(inputSidePackets, &cc->InputSidePackets());
@ -102,11 +107,11 @@ class MetalHelperLegacySupport {
} }
// TODO: remove when we can. // TODO: remove when we can.
ABSL_LOG(WARNING) << "CalculatorContract not available. If you're calling this " ABSL_LOG(WARNING)
<< "CalculatorContract not available. If you're calling this "
"from a GetContract method, call updateContract instead."; "from a GetContract method, call updateContract instead.";
auto id = inputSidePackets->GetId(mediapipe::kGpuSharedTagName, 0); auto id = inputSidePackets->GetId(mediapipe::kGpuSharedTagName, 0);
RET_CHECK(id.IsValid()) RET_CHECK(id.IsValid()) << "A " << mediapipe::kGpuSharedTagName
<< "A " << mediapipe::kGpuSharedTagName
<< " input side packet is required here."; << " input side packet is required here.";
inputSidePackets->Get(id).Set<mediapipe::GpuSharedData*>(); inputSidePackets->Get(id).Set<mediapipe::GpuSharedData*>();
return absl::OkStatus(); return absl::OkStatus();
@ -125,10 +130,12 @@ class MetalHelperLegacySupport {
} }
- (id<MTLCommandBuffer>)commandBuffer { - (id<MTLCommandBuffer>)commandBuffer {
return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; return
[_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer];
} }
- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:
(const mediapipe::GpuBuffer&)gpuBuffer
plane:(size_t)plane { plane:(size_t)plane {
CVPixelBufferRef pixel_buffer = mediapipe::GetCVPixelBufferRef(gpuBuffer); CVPixelBufferRef pixel_buffer = mediapipe::GetCVPixelBufferRef(gpuBuffer);
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
@ -178,40 +185,47 @@ class MetalHelperLegacySupport {
CVMetalTextureRef texture; CVMetalTextureRef texture;
CVReturn err = CVMetalTextureCacheCreateTextureFromImage( CVReturn err = CVMetalTextureCacheCreateTextureFromImage(
NULL, _gpuResources->metal_shared().resources().mtlTextureCache, NULL, _gpuResources->metal_shared().resources().mtlTextureCache,
mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width,
&texture); height, plane, &texture);
ABSL_CHECK_EQ(err, kCVReturnSuccess); ABSL_CHECK_EQ(err, kCVReturnSuccess);
return texture; return texture;
} }
- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer { - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:
(const mediapipe::GpuBuffer&)gpuBuffer {
return [self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:0]; return [self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:0];
} }
- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer { - (id<MTLTexture>)metalTextureWithGpuBuffer:
(const mediapipe::GpuBuffer&)gpuBuffer {
return [self metalTextureWithGpuBuffer:gpuBuffer plane:0]; return [self metalTextureWithGpuBuffer:gpuBuffer plane:0];
} }
- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer - (id<MTLTexture>)metalTextureWithGpuBuffer:
(const mediapipe::GpuBuffer&)gpuBuffer
plane:(size_t)plane { plane:(size_t)plane {
CFHolder<CVMetalTextureRef> cvTexture; CFHolder<CVMetalTextureRef> cvTexture;
cvTexture.adopt([self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:plane]); cvTexture.adopt([self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:plane]);
return CVMetalTextureGetTexture(*cvTexture); return CVMetalTextureGetTexture(*cvTexture);
} }
- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width
height:(int)height {
return _gpuResources->gpu_buffer_pool().GetBuffer(width, height); return _gpuResources->gpu_buffer_pool().GetBuffer(width, height);
} }
- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width
height:(int)height height:(int)height
format:(mediapipe::GpuBufferFormat)format { format:(mediapipe::GpuBufferFormat)
format {
return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format); return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format);
} }
- (id<MTLLibrary>)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { - (id<MTLLibrary>)newLibraryWithResourceName:(NSString*)name
error:(NSError* _Nullable*)error {
return [_gpuResources->metal_shared().resources().mtlDevice return [_gpuResources->metal_shared().resources().mtlDevice
newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name newLibraryWithFile:[[NSBundle bundleForClass:[self class]]
pathForResource:name
ofType:@"metallib"] ofType:@"metallib"]
error:error]; error:error];
} }

View File

@ -69,10 +69,10 @@
while (!bufferCompleted) { while (!bufferCompleted) {
auto duration = absl::Now() - start_time; auto duration = absl::Now() - start_time;
// If the spin-lock takes more than 5 ms then go to blocking wait: // If the spin-lock takes more than 5 ms then go to blocking wait:
// - it frees the CPU core for another threads: increase the performance/decrease power // - it frees the CPU core for another threads: increase the
// consumption. // performance/decrease power consumption.
// - if a driver thread that notifies that the GPU buffer is completed has lower priority then // - if a driver thread that notifies that the GPU buffer is completed has
// the CPU core is allocated for the thread. // lower priority then the CPU core is allocated for the thread.
if (duration >= absl::Milliseconds(5)) { if (duration >= absl::Milliseconds(5)) {
[commandBuffer waitUntilCompleted]; [commandBuffer waitUntilCompleted];
break; break;
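Aside (illustrative pattern, not MediaPipe code): the spin-then-block strategy described in this comment can be sketched with std::atomic and a condition variable. The 5 ms budget below matches the arbitrary threshold mentioned above; all class and function names here are made up.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Toy stand-in for "wait for GPU work": spin briefly for low latency, then
// fall back to a blocking wait so the CPU core is freed for other threads.
class CompletionFlag {
 public:
  void Signal() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    cv_.notify_all();
  }

  void Wait(std::chrono::milliseconds spin_budget) {
    const auto start = std::chrono::steady_clock::now();
    // Spin while the work is expected to finish almost immediately.
    while (!done_.load(std::memory_order_acquire)) {
      if (std::chrono::steady_clock::now() - start >= spin_budget) {
        // Past the budget: block instead of burning a CPU core.
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return done_.load(); });
        return;
      }
    }
  }

 private:
  std::atomic<bool> done_{false};
  std::mutex mutex_;
  std::condition_variable cv_;
};

int main() {
  CompletionFlag flag;
  std::thread worker([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(20));
    flag.Signal();
  });
  flag.Wait(std::chrono::milliseconds(5));  // spins ~5 ms, then blocks
  worker.join();
  std::cout << "work completed\n";
}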

View File

@ -57,8 +57,8 @@ void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) {
// static // static
absl::Status GlCalculatorHelper::UpdateContract(CalculatorContract* cc, absl::Status GlCalculatorHelper::UpdateContract(CalculatorContract* cc,
bool requesst_gpu_as_optional) { bool request_gpu_as_optional) {
if (requesst_gpu_as_optional) { if (request_gpu_as_optional) {
cc->UseService(kGpuService).Optional(); cc->UseService(kGpuService).Optional();
} else { } else {
cc->UseService(kGpuService); cc->UseService(kGpuService);

View File

@ -68,7 +68,7 @@ class GlCalculatorHelper {
// This method can be called from GetContract to set up the needed GPU // This method can be called from GetContract to set up the needed GPU
// resources. // resources.
static absl::Status UpdateContract(CalculatorContract* cc, static absl::Status UpdateContract(CalculatorContract* cc,
bool requesst_gpu_as_optional = false); bool request_gpu_as_optional = false);
// This method can be called from FillExpectations to set the correct types // This method can be called from FillExpectations to set the correct types
// for the shared GL input side packet(s). // for the shared GL input side packet(s).

View File

@ -14,6 +14,8 @@
#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_buffer.h"
#include <cstdint>
#include "absl/log/absl_check.h" #include "absl/log/absl_check.h"
#include "absl/log/absl_log.h" #include "absl/log/absl_log.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -131,6 +133,13 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
SymbolAvailable(&glTexStorage2D)) { SymbolAvailable(&glTexStorage2D)) {
ABSL_CHECK(data == nullptr) << "unimplemented"; ABSL_CHECK(data == nullptr) << "unimplemented";
glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_); glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_);
} else if (info.immutable) {
ABSL_CHECK(SymbolAvailable(&glTexStorage2D) &&
context->GetGlVersion() != GlVersion::kGLES2)
<< "Immutable GpuBuffer format requested is not supported in this "
<< "GlContext. Format was " << static_cast<uint32_t>(format_);
ABSL_CHECK(data == nullptr) << "unimplemented";
glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_);
} else { } else {
glTexImage2D(target_, 0 /* level */, info.gl_internal_format, width_, glTexImage2D(target_, 0 /* level */, info.gl_internal_format, width_,
height_, 0 /* border */, info.gl_format, info.gl_type, data); height_, 0 /* border */, info.gl_format, info.gl_type, data);
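Aside (logic-only sketch): the new branch above makes the allocation path depend on whether the format requests immutable storage and whether glTexStorage2D is usable on the current context. The toy function below models just that decision, with no real GL calls; the names and the assert are illustrative.

#include <cassert>
#include <iostream>
#include <string>

enum class GlVersion { kGLES2, kGLES3, kGL };

struct TextureInfo {
  bool immutable = false;  // format explicitly requests immutable storage
};

std::string ChooseAllocationPath(bool use_compute_shaders,
                                 bool tex_storage_available,
                                 GlVersion version, const TextureInfo& info) {
  if (use_compute_shaders && tex_storage_available) {
    return "glTexStorage2D";  // compute usage prefers immutable storage
  }
  if (info.immutable) {
    // Immutable formats require glTexStorage2D and a GLES3+/GL context.
    assert(tex_storage_available && version != GlVersion::kGLES2);
    return "glTexStorage2D";
  }
  return "glTexImage2D";  // mutable storage; data may be uploaded immediately
}

int main() {
  const TextureInfo immutable_info{/*immutable=*/true};
  std::cout << ChooseAllocationPath(false, true, GlVersion::kGLES3,
                                    immutable_info)
            << "\n";  // prints glTexStorage2D
}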

View File

@ -35,6 +35,10 @@ namespace mediapipe {
#endif // GL_HALF_FLOAT_OES #endif // GL_HALF_FLOAT_OES
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
#ifndef GL_RGBA8
#define GL_RGBA8 0x8058
#endif // GL_RGBA8
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#ifdef GL_ES_VERSION_2_0 #ifdef GL_ES_VERSION_2_0
static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
@ -163,6 +167,14 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
{ {
{GL_RGBA32F, GL_RGBA, GL_FLOAT, 1}, {GL_RGBA32F, GL_RGBA, GL_FLOAT, 1},
}}, }},
{GpuBufferFormat::kImmutableRGBAFloat128,
{
{GL_RGBA32F, GL_RGBA, GL_FLOAT, 1, true /* immutable */},
}},
{GpuBufferFormat::kImmutableRGBA32,
{
{GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, 1, true /* immutable */},
}},
}}; }};
static const auto* gles2_format_info = ([] { static const auto* gles2_format_info = ([] {
@ -206,6 +218,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
switch (format) { switch (format) {
case GpuBufferFormat::kImmutableRGBA32:
case GpuBufferFormat::kBGRA32: case GpuBufferFormat::kBGRA32:
// TODO: verify we are handling order of channels correctly. // TODO: verify we are handling order of channels correctly.
return ImageFormat::SRGBA; return ImageFormat::SRGBA;
@ -221,10 +234,11 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
return ImageFormat::SRGB; return ImageFormat::SRGB;
case GpuBufferFormat::kTwoComponentFloat32: case GpuBufferFormat::kTwoComponentFloat32:
return ImageFormat::VEC32F2; return ImageFormat::VEC32F2;
case GpuBufferFormat::kImmutableRGBAFloat128:
case GpuBufferFormat::kRGBAFloat128: case GpuBufferFormat::kRGBAFloat128:
return ImageFormat::VEC32F4; return ImageFormat::VEC32F4;
case GpuBufferFormat::kRGBA32: case GpuBufferFormat::kRGBA32:
// TODO: this likely maps to ImageFormat::SRGBA return ImageFormat::SRGBA;
case GpuBufferFormat::kGrayHalf16: case GpuBufferFormat::kGrayHalf16:
case GpuBufferFormat::kOneComponent8Alpha: case GpuBufferFormat::kOneComponent8Alpha:
case GpuBufferFormat::kOneComponent8Red: case GpuBufferFormat::kOneComponent8Red:

View File

@ -53,6 +53,10 @@ enum class GpuBufferFormat : uint32_t {
kRGB24 = 0x00000018, // Note: prefer BGRA32 whenever possible. kRGB24 = 0x00000018, // Note: prefer BGRA32 whenever possible.
kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'), kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'),
kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'), kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'),
// Immutable version of kRGBA32
kImmutableRGBA32 = MEDIAPIPE_FOURCC('4', 'C', 'I', '8'),
// Immutable version of kRGBAFloat128
kImmutableRGBAFloat128 = MEDIAPIPE_FOURCC('4', 'C', 'I', 'f'),
// 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling. // 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling.
kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'), kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'),
// 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling. // 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling.
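Aside: the new enum values are FOURCC-style tags. The helper below is a hypothetical stand-in for MEDIAPIPE_FOURCC (the macro's actual byte order is not shown in this diff, so treat this ordering as an assumption); it just packs four characters into a distinct 32-bit value:

#include <cstdint>
#include <iostream>

// Hypothetical MakeFourCc, illustrating how '4','C','I','8' becomes a
// unique 32-bit enum value such as kImmutableRGBA32.
constexpr uint32_t MakeFourCc(char a, char b, char c, char d) {
  return (static_cast<uint32_t>(a) << 24) | (static_cast<uint32_t>(b) << 16) |
         (static_cast<uint32_t>(c) << 8) | static_cast<uint32_t>(d);
}

int main() {
  constexpr uint32_t kImmutableRgba32 = MakeFourCc('4', 'C', 'I', '8');
  std::cout << std::hex << kImmutableRgba32 << "\n";
}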
@ -78,6 +82,9 @@ struct GlTextureInfo {
// For multiplane buffers, this represents how many times smaller than // For multiplane buffers, this represents how many times smaller than
// the nominal image size a plane is. // the nominal image size a plane is.
int downscale; int downscale;
// For GLES3.1+ compute shaders, users may explicitly request immutable
// textures.
bool immutable = false;
}; };
const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
@ -121,6 +128,8 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
return kCVPixelFormatType_64RGBAHalf; return kCVPixelFormatType_64RGBAHalf;
case GpuBufferFormat::kRGBAFloat128: case GpuBufferFormat::kRGBAFloat128:
return kCVPixelFormatType_128RGBAFloat; return kCVPixelFormatType_128RGBAFloat;
case GpuBufferFormat::kImmutableRGBA32:
case GpuBufferFormat::kImmutableRGBAFloat128:
case GpuBufferFormat::kNV12: case GpuBufferFormat::kNV12:
case GpuBufferFormat::kNV21: case GpuBufferFormat::kNV21:
case GpuBufferFormat::kI420: case GpuBufferFormat::kI420:

View File

@ -151,7 +151,7 @@ static std::shared_ptr<GpuBufferStorageCvPixelBuffer> ConvertFromImageFrame(
std::shared_ptr<GpuBufferStorageImageFrame> frame) { std::shared_ptr<GpuBufferStorageImageFrame> frame) {
auto status_or_buffer = auto status_or_buffer =
CreateCVPixelBufferForImageFrame(frame->image_frame()); CreateCVPixelBufferForImageFrame(frame->image_frame());
ABSL_CHECK(status_or_buffer.ok()); ABSL_CHECK_OK(status_or_buffer);
return std::make_shared<GpuBufferStorageCvPixelBuffer>( return std::make_shared<GpuBufferStorageCvPixelBuffer>(
std::move(status_or_buffer).value()); std::move(status_or_buffer).value());
} }

View File

@ -50,9 +50,10 @@
- (CVMetalTextureCacheRef)mtlTextureCache { - (CVMetalTextureCacheRef)mtlTextureCache {
@synchronized(self) { @synchronized(self) {
if (!_mtlTextureCache) { if (!_mtlTextureCache) {
CVReturn __unused err = CVReturn __unused err = CVMetalTextureCacheCreate(
CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err, NSAssert(err == kCVReturnSuccess,
@"Error at CVMetalTextureCacheCreate %d ; device %@", err,
self.mtlDevice); self.mtlDevice);
// TODO: register and flush metal caches too. // TODO: register and flush metal caches too.
} }

View File

@ -24,23 +24,27 @@
namespace mediapipe { namespace mediapipe {
CVPixelBufferPoolRef CreateCVPixelBufferPool( CVPixelBufferPoolRef CreateCVPixelBufferPool(int width, int height,
int width, int height, OSType pixelFormat, int keepCount, OSType pixelFormat, int keepCount,
CFTimeInterval maxAge) { CFTimeInterval maxAge) {
CVPixelBufferPoolRef pool = NULL; CVPixelBufferPoolRef pool = NULL;
NSMutableDictionary *sourcePixelBufferOptions = NSMutableDictionary *sourcePixelBufferOptions =
[(__bridge NSDictionary*)GetCVPixelBufferAttributesForGlCompatibility() mutableCopy]; [(__bridge NSDictionary *)GetCVPixelBufferAttributesForGlCompatibility()
mutableCopy];
[sourcePixelBufferOptions addEntriesFromDictionary:@{ [sourcePixelBufferOptions addEntriesFromDictionary:@{
(id)kCVPixelBufferPixelFormatTypeKey : @(pixelFormat), (id)kCVPixelBufferPixelFormatTypeKey : @(pixelFormat),
(id)kCVPixelBufferWidthKey : @(width), (id)kCVPixelBufferWidthKey : @(width),
(id)kCVPixelBufferHeightKey : @(height), (id)kCVPixelBufferHeightKey : @(height),
}]; }];
NSMutableDictionary *pixelBufferPoolOptions = [[NSMutableDictionary alloc] init]; NSMutableDictionary *pixelBufferPoolOptions =
pixelBufferPoolOptions[(id)kCVPixelBufferPoolMinimumBufferCountKey] = @(keepCount); [[NSMutableDictionary alloc] init];
pixelBufferPoolOptions[(id)kCVPixelBufferPoolMinimumBufferCountKey] =
@(keepCount);
if (maxAge > 0) { if (maxAge > 0) {
pixelBufferPoolOptions[(id)kCVPixelBufferPoolMaximumBufferAgeKey] = @(maxAge); pixelBufferPoolOptions[(id)kCVPixelBufferPoolMaximumBufferAgeKey] =
@(maxAge);
} }
CVPixelBufferPoolCreate( CVPixelBufferPoolCreate(
@ -50,8 +54,9 @@ CVPixelBufferPoolRef CreateCVPixelBufferPool(
return pool; return pool;
} }
OSStatus PreallocateCVPixelBufferPoolBuffers( OSStatus PreallocateCVPixelBufferPoolBuffers(CVPixelBufferPoolRef pool,
CVPixelBufferPoolRef pool, int count, CFDictionaryRef auxAttributes) { int count,
CFDictionaryRef auxAttributes) {
CVReturn err = kCVReturnSuccess; CVReturn err = kCVReturnSuccess;
NSMutableArray *pixelBuffers = [[NSMutableArray alloc] init]; NSMutableArray *pixelBuffers = [[NSMutableArray alloc] init];
for (int i = 0; i < count && err == kCVReturnSuccess; i++) { for (int i = 0; i < count && err == kCVReturnSuccess; i++) {
@ -68,30 +73,37 @@ OSStatus PreallocateCVPixelBufferPoolBuffers(
return err; return err;
} }
CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold(int allocationThreshold) { CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold(
int allocationThreshold) {
if (allocationThreshold > 0) { if (allocationThreshold > 0) {
return (CFDictionaryRef)CFBridgingRetain( return (CFDictionaryRef)CFBridgingRetain(@{
@{(id)kCVPixelBufferPoolAllocationThresholdKey: @(allocationThreshold)}); (id)kCVPixelBufferPoolAllocationThresholdKey : @(allocationThreshold)
});
} else { } else {
return nil; return nil;
} }
} }
CVReturn CreateCVPixelBufferWithPool( CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool,
CVPixelBufferPoolRef pool, CFDictionaryRef auxAttributes, CFDictionaryRef auxAttributes,
CVTextureCacheType textureCache, CVPixelBufferRef* outBuffer) { CVTextureCacheType textureCache,
return CreateCVPixelBufferWithPool(pool, auxAttributes, [textureCache](){ CVPixelBufferRef *outBuffer) {
return CreateCVPixelBufferWithPool(
pool, auxAttributes,
[textureCache]() {
#if TARGET_OS_OSX #if TARGET_OS_OSX
CVOpenGLTextureCacheFlush(textureCache, 0); CVOpenGLTextureCacheFlush(textureCache, 0);
#else #else
CVOpenGLESTextureCacheFlush(textureCache, 0); CVOpenGLESTextureCacheFlush(textureCache, 0);
#endif // TARGET_OS_OSX #endif // TARGET_OS_OSX
}, outBuffer); },
outBuffer);
} }
CVReturn CreateCVPixelBufferWithPool( CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool,
CVPixelBufferPoolRef pool, CFDictionaryRef auxAttributes, CFDictionaryRef auxAttributes,
std::function<void(void)> flush, CVPixelBufferRef* outBuffer) { std::function<void(void)> flush,
CVPixelBufferRef *outBuffer) {
CVReturn err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( CVReturn err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes(
kCFAllocatorDefault, pool, auxAttributes, outBuffer); kCFAllocatorDefault, pool, auxAttributes, outBuffer);
if (err == kCVReturnWouldExceedAllocationThreshold) { if (err == kCVReturnWouldExceedAllocationThreshold) {
@ -103,10 +115,12 @@ CVReturn CreateCVPixelBufferWithPool(
kCFAllocatorDefault, pool, auxAttributes, outBuffer); kCFAllocatorDefault, pool, auxAttributes, outBuffer);
} }
if (err == kCVReturnWouldExceedAllocationThreshold) { if (err == kCVReturnWouldExceedAllocationThreshold) {
// TODO: allow the application to set the threshold. For now, disable it by // TODO: allow the application to set the threshold. For now, disable it
// default, since the threshold we are using is arbitrary and some graphs routinely cross it. // by default, since the threshold we are using is arbitrary and some
// graphs routinely cross it.
#ifdef ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK #ifdef ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK
NSLog(@"Using more buffers than expected! This is a debug-only warning, " NSLog(
@"Using more buffers than expected! This is a debug-only warning, "
"you can ignore it if your app works fine otherwise."); "you can ignore it if your app works fine otherwise.");
#ifdef DEBUG #ifdef DEBUG
NSLog(@"Pool status: %@", ((__bridge NSObject *)pool).description); NSLog(@"Pool status: %@", ((__bridge NSObject *)pool).description);

View File

@ -52,9 +52,9 @@ objc_library(
) )
MEDIAPIPE_IOS_SRCS = [ MEDIAPIPE_IOS_SRCS = [
"MPPGraph.mm", "MPPGraph.cc",
"MPPTimestampConverter.mm", "MPPTimestampConverter.cc",
"NSError+util_status.mm", "NSError+util_status.cc",
] ]
MEDIAPIPE_IOS_HDRS = [ MEDIAPIPE_IOS_HDRS = [
@ -63,11 +63,13 @@ MEDIAPIPE_IOS_HDRS = [
"NSError+util_status.h", "NSError+util_status.h",
] ]
objc_library( cc_library(
name = "mediapipe_framework_ios", name = "mediapipe_framework_ios",
srcs = MEDIAPIPE_IOS_SRCS, srcs = MEDIAPIPE_IOS_SRCS,
hdrs = MEDIAPIPE_IOS_HDRS, hdrs = MEDIAPIPE_IOS_HDRS,
copts = [ copts = [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
], ],
# This build rule is public to allow external customers to build their own iOS apps. # This build rule is public to allow external customers to build their own iOS apps.
@ -99,6 +101,7 @@ objc_library(
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
"@google_toolbox_for_mac//:GTM_Defines", "@google_toolbox_for_mac//:GTM_Defines",
], ],
alwayslink = 1,
) )
objc_library( objc_library(

View File

@ -19,6 +19,7 @@
#include <atomic> #include <atomic>
#import "GTMDefines.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
@ -26,22 +27,22 @@
#include "mediapipe/framework/graph_service.h" #include "mediapipe/framework/graph_service.h"
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/gpu/gpu_shared_data_internal.h"
#import "mediapipe/objc/NSError+util_status.h"
#include "mediapipe/objc/util.h" #include "mediapipe/objc/util.h"
#import "mediapipe/objc/NSError+util_status.h"
#import "GTMDefines.h"
@implementation MPPGraph { @implementation MPPGraph {
// Graph is wrapped in a unique_ptr because it was generating 39+KB of unnecessary ObjC runtime // Graph is wrapped in a unique_ptr because it was generating 39+KB of
// information. See https://medium.com/@dmaclach/objective-c-encoding-and-you-866624cc02de // unnecessary ObjC runtime information. See
// for details. // https://medium.com/@dmaclach/objective-c-encoding-and-you-866624cc02de for
// details.
std::unique_ptr<mediapipe::CalculatorGraph> _graph; std::unique_ptr<mediapipe::CalculatorGraph> _graph;
/// Input side packets that will be added to the graph when it is started. /// Input side packets that will be added to the graph when it is started.
std::map<std::string, mediapipe::Packet> _inputSidePackets; std::map<std::string, mediapipe::Packet> _inputSidePackets;
/// Packet headers that will be added to the graph when it is started. /// Packet headers that will be added to the graph when it is started.
std::map<std::string, mediapipe::Packet> _streamHeaders; std::map<std::string, mediapipe::Packet> _streamHeaders;
/// Service packets to be added to the graph when it is started. /// Service packets to be added to the graph when it is started.
std::map<const mediapipe::GraphServiceBase*, mediapipe::Packet> _servicePackets; std::map<const mediapipe::GraphServiceBase*, mediapipe::Packet>
_servicePackets;
/// Number of frames currently being processed by the graph. /// Number of frames currently being processed by the graph.
std::atomic<int32_t> _framesInFlight; std::atomic<int32_t> _framesInFlight;
@ -56,7 +57,8 @@
BOOL _started; BOOL _started;
} }
- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config { - (instancetype)initWithGraphConfig:
(const mediapipe::CalculatorGraphConfig&)config {
self = [super init]; self = [super init];
if (self) { if (self) {
// Turn on Cocoa multithreading, since MediaPipe uses threads. // Turn on Cocoa multithreading, since MediaPipe uses threads.
@ -76,34 +78,41 @@
return _graph->GetGraphInputStreamAddMode(); return _graph->GetGraphInputStreamAddMode();
} }
- (void)setPacketAddMode:(mediapipe::CalculatorGraph::GraphInputStreamAddMode)mode { - (void)setPacketAddMode:
(mediapipe::CalculatorGraph::GraphInputStreamAddMode)mode {
_graph->SetGraphInputStreamAddMode(mode); _graph->SetGraphInputStreamAddMode(mode);
} }
- (void)addFrameOutputStream:(const std::string&)outputStreamName - (void)addFrameOutputStream:(const std::string&)outputStreamName
outputPacketType:(MPPPacketType)packetType { outputPacketType:(MPPPacketType)packetType {
std::string callbackInputName; std::string callbackInputName;
mediapipe::tool::AddCallbackCalculator(outputStreamName, &_config, &callbackInputName, mediapipe::tool::AddCallbackCalculator(outputStreamName, &_config,
&callbackInputName,
/*use_std_function=*/true); /*use_std_function=*/true);
// No matter what ownership qualifiers are put on the pointer, NewPermanentCallback will // No matter what ownership qualifiers are put on the pointer,
// still end up with a strong pointer to MPPGraph*. That is why we use void* instead. // NewPermanentCallback will still end up with a strong pointer to MPPGraph*.
// That is why we use void* instead.
void* wrapperVoid = (__bridge void*)self; void* wrapperVoid = (__bridge void*)self;
_inputSidePackets[callbackInputName] = _inputSidePackets[callbackInputName] =
mediapipe::MakePacket<std::function<void(const mediapipe::Packet&)>>( mediapipe::MakePacket<std::function<void(const mediapipe::Packet&)>>(
[wrapperVoid, outputStreamName, packetType](const mediapipe::Packet& packet) { [wrapperVoid, outputStreamName,
CallFrameDelegate(wrapperVoid, outputStreamName, packetType, packet); packetType](const mediapipe::Packet& packet) {
CallFrameDelegate(wrapperVoid, outputStreamName, packetType,
packet);
}); });
} }
- (NSString*)description { - (NSString*)description {
return [NSString stringWithFormat:@"<%@: %p; framesInFlight = %d>", [self class], self, return [NSString
stringWithFormat:@"<%@: %p; framesInFlight = %d>", [self class], self,
_framesInFlight.load(std::memory_order_relaxed)]; _framesInFlight.load(std::memory_order_relaxed)];
} }
/// This is the function that gets called by the CallbackCalculator that /// This is the function that gets called by the CallbackCalculator that
/// receives the graph's output. /// receives the graph's output.
void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
MPPPacketType packetType, const mediapipe::Packet& packet) { MPPPacketType packetType,
const mediapipe::Packet& packet) {
MPPGraph* wrapper = (__bridge MPPGraph*)wrapperVoid; MPPGraph* wrapper = (__bridge MPPGraph*)wrapperVoid;
@autoreleasepool { @autoreleasepool {
if (packetType == MPPPacketTypeRaw) { if (packetType == MPPPacketTypeRaw) {
@ -118,13 +127,16 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
if (format == mediapipe::ImageFormat::SRGBA || if (format == mediapipe::ImageFormat::SRGBA ||
format == mediapipe::ImageFormat::GRAY8) { format == mediapipe::ImageFormat::GRAY8) {
CVPixelBufferRef pixelBuffer; CVPixelBufferRef pixelBuffer;
// If kCVPixelFormatType_32RGBA does not work, it returns kCVReturnInvalidPixelFormat. // If kCVPixelFormatType_32RGBA does not work, it returns
// kCVReturnInvalidPixelFormat.
CVReturn error = CVPixelBufferCreate( CVReturn error = CVPixelBufferCreate(
NULL, frame.Width(), frame.Height(), kCVPixelFormatType_32BGRA, NULL, frame.Width(), frame.Height(), kCVPixelFormatType_32BGRA,
GetCVPixelBufferAttributesForGlCompatibility(), &pixelBuffer); GetCVPixelBufferAttributesForGlCompatibility(), &pixelBuffer);
_GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferCreate failed: %d", error); _GTMDevAssert(error == kCVReturnSuccess,
@"CVPixelBufferCreate failed: %d", error);
error = CVPixelBufferLockBaseAddress(pixelBuffer, 0); error = CVPixelBufferLockBaseAddress(pixelBuffer, 0);
_GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferLockBaseAddress failed: %d", error); _GTMDevAssert(error == kCVReturnSuccess,
@"CVPixelBufferLockBaseAddress failed: %d", error);
vImage_Buffer vDestination = vImageForCVPixelBuffer(pixelBuffer); vImage_Buffer vDestination = vImageForCVPixelBuffer(pixelBuffer);
// Note: we have to throw away const here, but we should not overwrite // Note: we have to throw away const here, but we should not overwrite
@ -133,26 +145,31 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
if (format == mediapipe::ImageFormat::SRGBA) { if (format == mediapipe::ImageFormat::SRGBA) {
// Swap R and B channels. // Swap R and B channels.
const uint8_t permuteMap[4] = {2, 1, 0, 3}; const uint8_t permuteMap[4] = {2, 1, 0, 3};
vImage_Error __unused vError = vImage_Error __unused vError = vImagePermuteChannels_ARGB8888(
vImagePermuteChannels_ARGB8888(&vSource, &vDestination, permuteMap, kvImageNoFlags); &vSource, &vDestination, permuteMap, kvImageNoFlags);
_GTMDevAssert(vError == kvImageNoError, @"vImagePermuteChannels failed: %zd", vError); _GTMDevAssert(vError == kvImageNoError,
@"vImagePermuteChannels failed: %zd", vError);
} else { } else {
// Convert grayscale back to BGRA // Convert grayscale back to BGRA
vImage_Error __unused vError = vImageGrayToBGRA(&vSource, &vDestination); vImage_Error __unused vError =
_GTMDevAssert(vError == kvImageNoError, @"vImageGrayToBGRA failed: %zd", vError); vImageGrayToBGRA(&vSource, &vDestination);
_GTMDevAssert(vError == kvImageNoError,
@"vImageGrayToBGRA failed: %zd", vError);
} }
error = CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); error = CVPixelBufferUnlockBaseAddress(pixelBuffer, 0);
_GTMDevAssert(error == kCVReturnSuccess, _GTMDevAssert(error == kCVReturnSuccess,
@"CVPixelBufferUnlockBaseAddress failed: %d", error); @"CVPixelBufferUnlockBaseAddress failed: %d", error);
if ([wrapper.delegate respondsToSelector:@selector if ([wrapper.delegate
respondsToSelector:@selector
(mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) { (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) {
[wrapper.delegate mediapipeGraph:wrapper [wrapper.delegate mediapipeGraph:wrapper
didOutputPixelBuffer:pixelBuffer didOutputPixelBuffer:pixelBuffer
fromStream:streamName fromStream:streamName
timestamp:packet.Timestamp()]; timestamp:packet.Timestamp()];
} else if ([wrapper.delegate respondsToSelector:@selector } else if ([wrapper.delegate
respondsToSelector:@selector
(mediapipeGraph:didOutputPixelBuffer:fromStream:)]) { (mediapipeGraph:didOutputPixelBuffer:fromStream:)]) {
[wrapper.delegate mediapipeGraph:wrapper [wrapper.delegate mediapipeGraph:wrapper
didOutputPixelBuffer:pixelBuffer didOutputPixelBuffer:pixelBuffer
@ -168,7 +185,8 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
wrapper->_framesInFlight--; wrapper->_framesInFlight--;
CVPixelBufferRef pixelBuffer; CVPixelBufferRef pixelBuffer;
if (packetType == MPPPacketTypePixelBuffer) if (packetType == MPPPacketTypePixelBuffer)
pixelBuffer = mediapipe::GetCVPixelBufferRef(packet.Get<mediapipe::GpuBuffer>()); pixelBuffer =
mediapipe::GetCVPixelBufferRef(packet.Get<mediapipe::GpuBuffer>());
else else
pixelBuffer = packet.Get<mediapipe::Image>().GetCVPixelBufferRef(); pixelBuffer = packet.Get<mediapipe::Image>().GetCVPixelBufferRef();
if ([wrapper.delegate if ([wrapper.delegate
@ -192,13 +210,15 @@ if ([wrapper.delegate
} }
} }
- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName { - (void)setHeaderPacket:(const mediapipe::Packet&)packet
forStream:(const std::string&)streamName {
_GTMDevAssert(!_started, @"%@ must be called before the graph is started", _GTMDevAssert(!_started, @"%@ must be called before the graph is started",
NSStringFromSelector(_cmd)); NSStringFromSelector(_cmd));
_streamHeaders[streamName] = packet; _streamHeaders[streamName] = packet;
} }
- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name { - (void)setSidePacket:(const mediapipe::Packet&)packet
named:(const std::string&)name {
_GTMDevAssert(!_started, @"%@ must be called before the graph is started", _GTMDevAssert(!_started, @"%@ must be called before the graph is started",
NSStringFromSelector(_cmd)); NSStringFromSelector(_cmd));
_inputSidePackets[name] = packet; _inputSidePackets[name] = packet;
@ -211,7 +231,8 @@ if ([wrapper.delegate
_servicePackets[&service] = std::move(packet); _servicePackets[&service] = std::move(packet);
} }
- (void)addSidePackets:(const std::map<std::string, mediapipe::Packet>&)extraSidePackets { - (void)addSidePackets:
(const std::map<std::string, mediapipe::Packet>&)extraSidePackets {
_GTMDevAssert(!_started, @"%@ must be called before the graph is started", _GTMDevAssert(!_started, @"%@ must be called before the graph is started",
NSStringFromSelector(_cmd)); NSStringFromSelector(_cmd));
_inputSidePackets.insert(extraSidePackets.begin(), extraSidePackets.end()); _inputSidePackets.insert(extraSidePackets.begin(), extraSidePackets.end());
@ -232,7 +253,8 @@ if ([wrapper.delegate
- (absl::Status)performStart { - (absl::Status)performStart {
absl::Status status; absl::Status status;
for (const auto& service_packet : _servicePackets) { for (const auto& service_packet : _servicePackets) {
status = _graph->SetServicePacket(*service_packet.first, service_packet.second); status =
_graph->SetServicePacket(*service_packet.first, service_packet.second);
if (!status.ok()) { if (!status.ok()) {
return status; return status;
} }
@ -269,10 +291,11 @@ if ([wrapper.delegate
} }
- (BOOL)waitUntilDoneWithError:(NSError**)error { - (BOOL)waitUntilDoneWithError:(NSError**)error {
// Since this method blocks with no timeout, it should not be called in the main thread in // Since this method blocks with no timeout, it should not be called in the
// an app. However, it's fine to allow that in a test. // main thread in an app. However, it's fine to allow that in a test.
// TODO: is this too heavy-handed? Maybe a warning would be fine. // TODO: is this too heavy-handed? Maybe a warning would be fine.
_GTMDevAssert(![NSThread isMainThread] || (NSClassFromString(@"XCTest")), _GTMDevAssert(
![NSThread isMainThread] || (NSClassFromString(@"XCTest")),
@"waitUntilDoneWithError: should not be called on the main thread"); @"waitUntilDoneWithError: should not be called on the main thread");
absl::Status status = _graph->WaitUntilDone(); absl::Status status = _graph->WaitUntilDone();
_started = NO; _started = NO;
@ -289,7 +312,8 @@ if ([wrapper.delegate
- (BOOL)movePacket:(mediapipe::Packet&&)packet - (BOOL)movePacket:(mediapipe::Packet&&)packet
intoStream:(const std::string&)streamName intoStream:(const std::string&)streamName
error:(NSError**)error { error:(NSError**)error {
absl::Status status = _graph->AddPacketToInputStream(streamName, std::move(packet)); absl::Status status =
_graph->AddPacketToInputStream(streamName, std::move(packet));
if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status];
return status.ok(); return status.ok();
} }
@ -305,7 +329,8 @@ if ([wrapper.delegate
- (BOOL)setMaxQueueSize:(int)maxQueueSize - (BOOL)setMaxQueueSize:(int)maxQueueSize
forStream:(const std::string&)streamName forStream:(const std::string&)streamName
error:(NSError**)error { error:(NSError**)error {
absl::Status status = _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize); absl::Status status =
_graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize);
if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status];
return status.ok(); return status.ok();
} }
@ -313,7 +338,8 @@ if ([wrapper.delegate
- (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)imageBuffer - (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)imageBuffer
packetType:(MPPPacketType)packetType { packetType:(MPPPacketType)packetType {
mediapipe::Packet packet; mediapipe::Packet packet;
if (packetType == MPPPacketTypeImageFrame || packetType == MPPPacketTypeImageFrameBGRANoSwap) { if (packetType == MPPPacketTypeImageFrame ||
packetType == MPPPacketTypeImageFrameBGRANoSwap) {
auto frame = CreateImageFrameForCVPixelBuffer( auto frame = CreateImageFrameForCVPixelBuffer(
imageBuffer, /* canOverwrite = */ false, imageBuffer, /* canOverwrite = */ false,
/* bgrAsRgb = */ packetType == MPPPacketTypeImageFrameBGRANoSwap); /* bgrAsRgb = */ packetType == MPPPacketTypeImageFrameBGRANoSwap);
@ -328,7 +354,8 @@ if ([wrapper.delegate
packet = mediapipe::MakePacket<mediapipe::Image>(imageBuffer); packet = mediapipe::MakePacket<mediapipe::Image>(imageBuffer);
#else #else
// CPU // CPU
auto frame = CreateImageFrameForCVPixelBuffer(imageBuffer, /* canOverwrite = */ false, auto frame = CreateImageFrameForCVPixelBuffer(imageBuffer,
/* canOverwrite = */ false,
/* bgrAsRgb = */ false); /* bgrAsRgb = */ false);
packet = mediapipe::MakePacket<mediapipe::Image>(std::move(frame)); packet = mediapipe::MakePacket<mediapipe::Image>(std::move(frame));
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
@ -339,7 +366,8 @@ if ([wrapper.delegate
} }
- (mediapipe::Packet)imagePacketWithPixelBuffer:(CVPixelBufferRef)pixelBuffer { - (mediapipe::Packet)imagePacketWithPixelBuffer:(CVPixelBufferRef)pixelBuffer {
return [self packetWithPixelBuffer:(pixelBuffer) packetType:(MPPPacketTypeImage)]; return [self packetWithPixelBuffer:(pixelBuffer)
packetType:(MPPPacketTypeImage)];
} }
- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer - (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
@ -367,13 +395,16 @@ if ([wrapper.delegate
allowOverwrite:(BOOL)allowOverwrite allowOverwrite:(BOOL)allowOverwrite
error:(NSError**)error { error:(NSError**)error {
if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO; if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO;
mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType]; mediapipe::Packet packet =
[self packetWithPixelBuffer:imageBuffer packetType:packetType];
BOOL success; BOOL success;
if (allowOverwrite) { if (allowOverwrite) {
packet = std::move(packet).At(timestamp); packet = std::move(packet).At(timestamp);
success = [self movePacket:std::move(packet) intoStream:inputName error:error]; success =
[self movePacket:std::move(packet) intoStream:inputName error:error];
} else { } else {
success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error]; success =
[self sendPacket:packet.At(timestamp) intoStream:inputName error:error];
} }
if (success) _framesInFlight++; if (success) _framesInFlight++;
return success; return success;
@ -407,7 +438,8 @@ if ([wrapper.delegate
} }
- (void)debugPrintGlInfo { - (void)debugPrintGlInfo {
std::shared_ptr<mediapipe::GpuResources> gpu_resources = _graph->GetGpuResources(); std::shared_ptr<mediapipe::GpuResources> gpu_resources =
_graph->GetGpuResources();
if (!gpu_resources) { if (!gpu_resources) {
NSLog(@"GPU not set up."); NSLog(@"GPU not set up.");
return; return;
@ -415,14 +447,15 @@ if ([wrapper.delegate
NSString* extensionString; NSString* extensionString;
(void)gpu_resources->gl_context()->Run([&extensionString] { (void)gpu_resources->gl_context()->Run([&extensionString] {
extensionString = [NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)]; extensionString =
[NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)];
return absl::OkStatus(); return absl::OkStatus();
}); });
NSArray* extensions = [extensionString componentsSeparatedByCharactersInSet: NSArray* extensions = [extensionString
[NSCharacterSet whitespaceCharacterSet]]; componentsSeparatedByCharactersInSet:[NSCharacterSet
for (NSString* oneExtension in extensions) whitespaceCharacterSet]];
NSLog(@"%@", oneExtension); for (NSString* oneExtension in extensions) NSLog(@"%@", oneExtension);
} }
@end @end
View File
@ -20,8 +20,7 @@
mediapipe::TimestampDiff _timestampOffset; mediapipe::TimestampDiff _timestampOffset;
} }
- (instancetype)init - (instancetype)init {
{
self = [super init]; self = [super init];
if (self) { if (self) {
[self reset]; [self reset];
@ -36,11 +35,14 @@
} }
- (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime { - (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime {
Float64 sampleSeconds = CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0; Float64 sampleSeconds =
const int64 sampleUsec = sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond; CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0;
const int64 sampleUsec =
sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond;
_mediapipeTimestamp = mediapipe::Timestamp(sampleUsec) + _timestampOffset; _mediapipeTimestamp = mediapipe::Timestamp(sampleUsec) + _timestampOffset;
if (_mediapipeTimestamp <= _lastTimestamp) { if (_mediapipeTimestamp <= _lastTimestamp) {
_timestampOffset = _timestampOffset + _lastTimestamp + 1 - _mediapipeTimestamp; _timestampOffset =
_timestampOffset + _lastTimestamp + 1 - _mediapipeTimestamp;
_mediapipeTimestamp = _lastTimestamp + 1; _mediapipeTimestamp = _lastTimestamp + 1;
} }
_lastTimestamp = _mediapipeTimestamp; _lastTimestamp = _mediapipeTimestamp;
View File
@ -0,0 +1,72 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/objc/NSError+util_status.h"
@implementation GUSUtilStatusWrapper
+ (instancetype)wrapStatus:(const absl::Status &)status {
return [[self alloc] initWithStatus:status];
}
- (instancetype)initWithStatus:(const absl::Status &)status {
self = [super init];
if (self) {
_status = status;
}
return self;
}
- (NSString *)description {
return [NSString stringWithFormat:@"<%@: %p; status = %s>", [self class],
self, _status.message().data()];
}
@end
@implementation NSError (GUSGoogleUtilStatus)
NSString *const kGUSGoogleUtilStatusErrorDomain =
@"GoogleUtilStatusErrorDomain";
NSString *const kGUSGoogleUtilStatusErrorKey = @"GUSGoogleUtilStatusErrorKey";
+ (NSError *)gus_errorWithStatus:(const absl::Status &)status {
NSDictionary *userInfo = @{
NSLocalizedDescriptionKey : @(status.message().data()),
kGUSGoogleUtilStatusErrorKey : [GUSUtilStatusWrapper wrapStatus:status],
};
NSError *error =
[NSError errorWithDomain:kGUSGoogleUtilStatusErrorDomain
code:static_cast<NSInteger>(status.code())
userInfo:userInfo];
return error;
}
- (absl::Status)gus_status {
NSString *domain = self.domain;
if ([domain isEqual:kGUSGoogleUtilStatusErrorDomain]) {
GUSUtilStatusWrapper *wrapper = self.userInfo[kGUSGoogleUtilStatusErrorKey];
if (wrapper) return wrapper.status;
#if 0
// Unfortunately, util/task/posixerrorspace.h is not in portable status yet.
// TODO: fix that.
} else if ([domain isEqual:NSPOSIXErrorDomain]) {
return ::util::PosixErrorToStatus(self.code, self.localizedDescription.UTF8String);
#endif
}
return absl::Status(absl::StatusCode::kUnknown,
self.localizedDescription.UTF8String);
}
@end
View File
@ -207,8 +207,12 @@ class ImageTest(absltest.TestCase):
loaded_image = Image.create_from_file(image_path) loaded_image = Image.create_from_file(image_path)
self.assertEqual(loaded_image.width, 720) self.assertEqual(loaded_image.width, 720)
self.assertEqual(loaded_image.height, 382) self.assertEqual(loaded_image.height, 382)
self.assertEqual(loaded_image.channels, 3) # On Mac w/ GPU support, images use 4 channels (SRGBA). Otherwise, all
self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) # images use 3 channels (SRGB).
self.assertIn(loaded_image.channels, [3, 4])
self.assertIn(
loaded_image.image_format, [ImageFormat.SRGB, ImageFormat.SRGBA]
)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()
View File
@ -51,10 +51,10 @@ void ImageSubmodule(pybind11::module* module) {
```python ```python
import cv2 import cv2
cv_mat = cv2.imread(input_file)[:, :, ::-1] cv_mat = cv2.imread(input_file)
rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat) rgb_frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=cv_mat)
gray_frame = mp.Image( gray_frame = mp.Image(
image_format=ImageFormat.GRAY, image_format=mp.ImageFormat.GRAY8,
data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
from PIL import Image from PIL import Image
@ -244,12 +244,26 @@ void ImageSubmodule(pybind11::module* module) {
image.def_static( image.def_static(
"create_from_file", "create_from_file",
[](const std::string& file_name) { [](const std::string& file_name) {
unsigned char* image_data = nullptr;
int width; int width;
int height; int height;
int channels; int channels;
auto* image_data =
stbi_load(file_name.c_str(), &width, &height, &channels, #if TARGET_OS_OSX && !MEDIAPIPE_DISABLE_GPU
// Our ObjC layer does not support 3-channel images, so we read the
// number of channels first and request RGBA if needed.
if (stbi_info(file_name.c_str(), &width, &height, &channels)) {
if (channels == 3) {
channels = 4;
}
int unused;
image_data =
stbi_load(file_name.c_str(), &width, &height, &unused, channels);
}
#else
image_data = stbi_load(file_name.c_str(), &width, &height, &channels,
/*desired_channels=*/0); /*desired_channels=*/0);
#endif // TARGET_OS_OSX && !MEDIAPIPE_DISABLE_GPU
if (image_data == nullptr) { if (image_data == nullptr) {
throw RaisePyError(PyExc_RuntimeError, throw RaisePyError(PyExc_RuntimeError,
absl::StrFormat("Image decoding failed (%s): %s", absl::StrFormat("Image decoding failed (%s): %s",
@ -263,11 +277,13 @@ void ImageSubmodule(pybind11::module* module) {
ImageFormat::GRAY8, width, height, width, image_data, ImageFormat::GRAY8, width, height, width, image_data,
stbi_image_free); stbi_image_free);
break; break;
#if !TARGET_OS_OSX || MEDIAPIPE_DISABLE_GPU
case 3: case 3:
image_frame = std::make_shared<ImageFrame>( image_frame = std::make_shared<ImageFrame>(
ImageFormat::SRGB, width, height, 3 * width, image_data, ImageFormat::SRGB, width, height, 3 * width, image_data,
stbi_image_free); stbi_image_free);
break; break;
#endif // !TARGET_OS_OSX || MEDIAPIPE_DISABLE_GPU
case 4: case 4:
image_frame = std::make_shared<ImageFrame>( image_frame = std::make_shared<ImageFrame>(
ImageFormat::SRGBA, width, height, 4 * width, image_data, ImageFormat::SRGBA, width, height, 4 * width, image_data,
View File
@ -81,8 +81,10 @@ void ImageFrameSubmodule(pybind11::module* module) {
become immutable after creation. become immutable after creation.
Creation examples: Creation examples:
```python
import cv2 import cv2
cv_mat = cv2.imread(input_file)[:, :, ::-1] cv_mat = cv2.imread(input_file)
rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat) rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat)
gray_frame = mp.ImageFrame( gray_frame = mp.ImageFrame(
image_format=ImageFormat.GRAY, image_format=ImageFormat.GRAY,
@ -92,6 +94,7 @@ void ImageFrameSubmodule(pybind11::module* module) {
pil_img = Image.new('RGB', (60, 30), color = 'red') pil_img = Image.new('RGB', (60, 30), color = 'red')
image_frame = mp.ImageFrame( image_frame = mp.ImageFrame(
image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
```
The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling
`ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the `ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the
View File
@ -30,13 +30,12 @@ cc_library(
"//mediapipe/tasks/c/components/processors:classifier_options_converter", "//mediapipe/tasks/c/components/processors:classifier_options_converter",
"//mediapipe/tasks/c/core:base_options", "//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter", "//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/image_classifier", "//mediapipe/tasks/cc/vision/image_classifier",
"//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
], ],
alwayslink = 1, alwayslink = 1,
) )
View File
@ -15,6 +15,8 @@ limitations under the License.
#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h" #include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"
#include <cstdint>
#include <cstdlib>
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -26,6 +28,7 @@ limitations under the License.
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h" #include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" #include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h" #include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" #include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
@ -41,7 +44,10 @@ using ::mediapipe::tasks::c::components::processors::
CppConvertToClassifierOptions; CppConvertToClassifierOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::vision::CreateImageFromBuffer; using ::mediapipe::tasks::vision::CreateImageFromBuffer;
using ::mediapipe::tasks::vision::core::RunningMode;
using ::mediapipe::tasks::vision::image_classifier::ImageClassifier; using ::mediapipe::tasks::vision::image_classifier::ImageClassifier;
typedef ::mediapipe::tasks::vision::image_classifier::ImageClassifierResult
CppImageClassifierResult;
int CppProcessError(absl::Status status, char** error_msg) { int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) { if (error_msg) {
@ -60,6 +66,53 @@ ImageClassifier* CppImageClassifierCreate(const ImageClassifierOptions& options,
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
CppConvertToClassifierOptions(options.classifier_options, CppConvertToClassifierOptions(options.classifier_options,
&cpp_options->classifier_options); &cpp_options->classifier_options);
cpp_options->running_mode = static_cast<RunningMode>(options.running_mode);
// Enable callback for processing live stream data when the running mode is
// set to RunningMode::LIVE_STREAM.
if (cpp_options->running_mode == RunningMode::LIVE_STREAM) {
if (options.result_callback == nullptr) {
const absl::Status status = absl::InvalidArgumentError(
"Provided null pointer to callback function.");
ABSL_LOG(ERROR) << "Failed to create ImageClassifier: " << status;
CppProcessError(status, error_msg);
return nullptr;
}
ImageClassifierOptions::result_callback_fn result_callback =
options.result_callback;
cpp_options->result_callback =
[result_callback](absl::StatusOr<CppImageClassifierResult> cpp_result,
const Image& image, int64_t timestamp) {
char* error_msg = nullptr;
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status();
CppProcessError(cpp_result.status(), &error_msg);
result_callback(nullptr, MpImage(), timestamp, error_msg);
free(error_msg);
return;
}
// Result is valid for the lifetime of the callback function.
ImageClassifierResult result;
CppConvertToClassificationResult(*cpp_result, &result);
const auto& image_frame = image.GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {
.format = static_cast<::ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
result_callback(&result, mp_image, timestamp,
/* error_msg= */ nullptr);
CppCloseClassificationResult(&result);
};
}
auto classifier = ImageClassifier::Create(std::move(cpp_options)); auto classifier = ImageClassifier::Create(std::move(cpp_options));
if (!classifier.ok()) { if (!classifier.ok()) {
@ -75,8 +128,8 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image,
ImageClassifierResult* result, ImageClassifierResult* result,
char** error_msg) { char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) { if (image->type == MpImage::GPU_BUFFER) {
absl::Status status = const absl::Status status =
absl::InvalidArgumentError("gpu buffer not supported yet"); absl::InvalidArgumentError("GPU Buffer not supported yet.");
ABSL_LOG(ERROR) << "Classification failed: " << status.message(); ABSL_LOG(ERROR) << "Classification failed: " << status.message();
return CppProcessError(status, error_msg); return CppProcessError(status, error_msg);
@ -102,6 +155,68 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image,
return 0; return 0;
} }
int CppImageClassifierClassifyForVideo(void* classifier, const MpImage* image,
int64_t timestamp_ms,
ImageClassifierResult* result,
char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Classification failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_classifier = static_cast<ImageClassifier*>(classifier);
auto cpp_result = cpp_classifier->ClassifyForVideo(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToClassificationResult(*cpp_result, result);
return 0;
}
int CppImageClassifierClassifyAsync(void* classifier, const MpImage* image,
int64_t timestamp_ms, char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Classification failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_classifier = static_cast<ImageClassifier*>(classifier);
auto cpp_result = cpp_classifier->ClassifyAsync(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Data preparation for the image classification failed: "
<< cpp_result;
return CppProcessError(cpp_result, error_msg);
}
return 0;
}
void CppImageClassifierCloseResult(ImageClassifierResult* result) { void CppImageClassifierCloseResult(ImageClassifierResult* result) {
CppCloseClassificationResult(result); CppCloseClassificationResult(result);
} }
@ -134,6 +249,22 @@ int image_classifier_classify_image(void* classifier, const MpImage* image,
CppImageClassifierClassify(classifier, image, result, error_msg); CppImageClassifierClassify(classifier, image, result, error_msg);
} }
int image_classifier_classify_for_video(void* classifier, const MpImage* image,
int64_t timestamp_ms,
ImageClassifierResult* result,
char** error_msg) {
return mediapipe::tasks::c::vision::image_classifier::
CppImageClassifierClassifyForVideo(classifier, image, timestamp_ms,
result, error_msg);
}
int image_classifier_classify_async(void* classifier, const MpImage* image,
int64_t timestamp_ms, char** error_msg) {
return mediapipe::tasks::c::vision::image_classifier::
CppImageClassifierClassifyAsync(classifier, image, timestamp_ms,
error_msg);
}
void image_classifier_close_result(ImageClassifierResult* result) { void image_classifier_close_result(ImageClassifierResult* result) {
mediapipe::tasks::c::vision::image_classifier::CppImageClassifierCloseResult( mediapipe::tasks::c::vision::image_classifier::CppImageClassifierCloseResult(
result); result);
View File
@ -92,9 +92,16 @@ struct ImageClassifierOptions {
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM. Arguments of the callback function include:
typedef void (*result_callback_fn)(ImageClassifierResult*, const MpImage*, // the pointer to classification result, the image that result was obtained
int64_t); // on, the timestamp relevant to classification results and pointer to error
// message in case of any failure. The passed arguments remain valid only for
// the lifetime of the callback function.
//
// The caller is responsible for closing the image classifier result.
typedef void (*result_callback_fn)(ImageClassifierResult* result,
const MpImage image, int64_t timestamp_ms,
char* error_msg);
result_callback_fn result_callback; result_callback_fn result_callback;
}; };
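To make the documented signature concrete, a hypothetical live-stream callback in C++ might look like the sketch below. The field accesses mirror those used in the C tests later in this change; the logging is illustrative only, and the note on `error_msg` ownership follows the implementation above, which frees it after the callback returns.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical result callback matching result_callback_fn above. The result
// and image are only valid while the callback runs; error_msg is non-null only
// on failure and is freed by the code that invokes the callback.
void OnClassification(ImageClassifierResult* result, const MpImage image,
                      int64_t timestamp_ms, char* error_msg) {
  if (error_msg != nullptr) {
    std::fprintf(stderr, "Classification at %lld ms failed: %s\n",
                 static_cast<long long>(timestamp_ms), error_msg);
    return;
  }
  std::printf("t=%lld ms: %u classification result(s) on a %dx%d frame\n",
              static_cast<long long>(timestamp_ms),
              static_cast<unsigned>(result->classifications_count),
              image.image_frame.width, image.image_frame.height);
}
```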
@ -110,13 +117,22 @@ MP_EXPORT void* image_classifier_create(struct ImageClassifierOptions* options,
// If an error occurs, returns an error code and sets the error parameter to an // If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not nullptr). You must free the memory // an error message (if `error_msg` is not nullptr). You must free the memory
// allocated for the error message. // allocated for the error message.
//
// TODO: Add API for video and live stream processing.
MP_EXPORT int image_classifier_classify_image(void* classifier, MP_EXPORT int image_classifier_classify_image(void* classifier,
const MpImage* image, const MpImage* image,
ImageClassifierResult* result, ImageClassifierResult* result,
char** error_msg = nullptr); char** error_msg = nullptr);
MP_EXPORT int image_classifier_classify_for_video(void* classifier,
const MpImage* image,
int64_t timestamp_ms,
ImageClassifierResult* result,
char** error_msg = nullptr);
MP_EXPORT int image_classifier_classify_async(void* classifier,
const MpImage* image,
int64_t timestamp_ms,
char** error_msg = nullptr);
// Frees the memory allocated inside a ImageClassifierResult result. // Frees the memory allocated inside a ImageClassifierResult result.
// Does not free the result pointer itself. // Does not free the result pointer itself.
MP_EXPORT void image_classifier_close_result(ImageClassifierResult* result); MP_EXPORT void image_classifier_close_result(ImageClassifierResult* result);
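The new video and live-stream entry points follow the same calling pattern as the single-image call. A minimal, hypothetical video-mode loop against this C API could look as follows; the model path, option values, and frame source are placeholders, and error handling is elided (the tests that follow show full option initialization):

```cpp
#include <cstdint>

#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"

MpImage GetNextFrame();  // Hypothetical caller-provided frame source.

// Hypothetical usage sketch of the video-mode C API declared above.
void ClassifyVideoFrames() {
  ImageClassifierOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ "model.tflite"},  // Placeholder.
      /* running_mode= */ RunningMode::VIDEO,
      /* classifier_options= */ {},  // Elided here; see the tests below.
      /* result_callback= */ nullptr,
  };
  void* classifier = image_classifier_create(&options);
  if (classifier == nullptr) return;
  for (int64_t frame_index = 0; frame_index < 10; ++frame_index) {
    MpImage frame = GetNextFrame();
    ImageClassifierResult result;
    if (image_classifier_classify_for_video(classifier, &frame, frame_index,
                                            &result) == 0) {
      // Consume result.classifications here, then release the result.
      image_classifier_close_result(&result);
    }
  }
  image_classifier_close(classifier);
}
```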
View File
@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h" #include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"
#include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <string> #include <string>
@ -36,12 +37,13 @@ using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kModelName[] = "mobilenet_v2_1.0_224.tflite"; constexpr char kModelName[] = "mobilenet_v2_1.0_224.tflite";
constexpr float kPrecision = 1e-4; constexpr float kPrecision = 1e-4;
constexpr int kIterations = 100;
std::string GetFullPath(absl::string_view file_name) { std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name); return JoinPath("./", kTestDataDirectory, file_name);
} }
TEST(ImageClassifierTest, SmokeTest) { TEST(ImageClassifierTest, ImageModeTest) {
const auto image = DecodeImageFromFile(GetFullPath("burger.jpg")); const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
ASSERT_TRUE(image.ok()); ASSERT_TRUE(image.ok());
@ -63,14 +65,13 @@ TEST(ImageClassifierTest, SmokeTest) {
void* classifier = image_classifier_create(&options); void* classifier = image_classifier_create(&options);
EXPECT_NE(classifier, nullptr); EXPECT_NE(classifier, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = { const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME, .type = MpImage::IMAGE_FRAME,
.image_frame = { .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.format = static_cast<ImageFormat>( .image_buffer = image_frame->PixelData(),
image->GetImageFrameSharedPtr()->Format()), .width = image_frame->Width(),
.image_buffer = image->GetImageFrameSharedPtr()->PixelData(), .height = image_frame->Height()}};
.width = image->GetImageFrameSharedPtr()->Width(),
.height = image->GetImageFrameSharedPtr()->Height()}};
ImageClassifierResult result; ImageClassifierResult result;
image_classifier_classify_image(classifier, &mp_image, &result); image_classifier_classify_image(classifier, &mp_image, &result);
@ -84,6 +85,120 @@ TEST(ImageClassifierTest, SmokeTest) {
image_classifier_close(classifier); image_classifier_close(classifier);
} }
TEST(ImageClassifierTest, VideoModeTest) {
const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::VIDEO,
/* classifier_options= */
{/* display_names_locale= */ nullptr,
/* max_results= */ 3,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0},
/* result_callback= */ nullptr,
};
void* classifier = image_classifier_create(&options);
EXPECT_NE(classifier, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
ImageClassifierResult result;
image_classifier_classify_for_video(classifier, &mp_image, i, &result);
EXPECT_EQ(result.classifications_count, 1);
EXPECT_EQ(result.classifications[0].categories_count, 3);
EXPECT_EQ(
std::string{result.classifications[0].categories[0].category_name},
"cheeseburger");
EXPECT_NEAR(result.classifications[0].categories[0].score, 0.7939f,
kPrecision);
image_classifier_close_result(&result);
}
image_classifier_close(classifier);
}
// A structure to support LiveStreamModeTest below. This structure holds a
// static method `Fn` used as the C API callback function. The `static`
// qualifier allows taking the address of the method to match the API style.
// Another static member, `last_timestamp`, is used to verify that the current
// timestamp is greater than the previous one.
struct LiveStreamModeCallback {
static int64_t last_timestamp;
static void Fn(ImageClassifierResult* classifier_result, const MpImage image,
int64_t timestamp, char* error_msg) {
ASSERT_NE(classifier_result, nullptr);
ASSERT_EQ(error_msg, nullptr);
EXPECT_EQ(
std::string{
classifier_result->classifications[0].categories[0].category_name},
"cheeseburger");
EXPECT_NEAR(classifier_result->classifications[0].categories[0].score,
0.7939f, kPrecision);
EXPECT_GT(image.image_frame.width, 0);
EXPECT_GT(image.image_frame.height, 0);
EXPECT_GT(timestamp, last_timestamp);
last_timestamp++;
}
};
int64_t LiveStreamModeCallback::last_timestamp = -1;
TEST(ImageClassifierTest, LiveStreamModeTest) {
const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::LIVE_STREAM,
/* classifier_options= */
{/* display_names_locale= */ nullptr,
/* max_results= */ 3,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0},
/* result_callback= */ LiveStreamModeCallback::Fn,
};
void* classifier = image_classifier_create(&options);
EXPECT_NE(classifier, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
EXPECT_GE(image_classifier_classify_async(classifier, &mp_image, i), 0);
}
image_classifier_close(classifier);
// Due to the flow limiter, the total number of outputs might be smaller than
// the number of iterations.
EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
}
TEST(ImageClassifierTest, InvalidArgumentHandling) { TEST(ImageClassifierTest, InvalidArgumentHandling) {
// It is an error to set neither the asset buffer nor the path. // It is an error to set neither the asset buffer nor the path.
ImageClassifierOptions options = { ImageClassifierOptions options = {
@ -124,7 +239,7 @@ TEST(ImageClassifierTest, FailedClassificationHandling) {
ImageClassifierResult result; ImageClassifierResult result;
char* error_msg; char* error_msg;
image_classifier_classify_image(classifier, &mp_image, &result, &error_msg); image_classifier_classify_image(classifier, &mp_image, &result, &error_msg);
EXPECT_THAT(error_msg, HasSubstr("gpu buffer not supported yet")); EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet"));
free(error_msg); free(error_msg);
image_classifier_close(classifier); image_classifier_close(classifier);
} }
View File
@ -98,3 +98,9 @@ mediapipe_proto_library(
name = "transformer_params_proto", name = "transformer_params_proto",
srcs = ["transformer_params.proto"], srcs = ["transformer_params.proto"],
) )
mediapipe_proto_library(
name = "llm_params_proto",
srcs = ["llm_params.proto"],
deps = [":transformer_params_proto"],
)
View File
@ -0,0 +1,41 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.components.processors.proto;
import "mediapipe/tasks/cc/components/processors/proto/transformer_params.proto";
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "LLMParametersProto";
// Parameters for Large Language Models (LLM).
message LLMParameters {
TransformerParameters transformer_parameters = 1;
// Size of vocabulary.
int32 vocab_size = 2;
// Whether or not to disable the KV cache, which is also referred to as state
// elsewhere.
bool disable_kv_cache = 3;
// Id of the start token.
int32 start_token_id = 4;
// Token to determine the end of output stream.
string stop_token = 5;
}
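For orientation, the message added above would be populated through the usual protobuf-generated C++ API, roughly as sketched below; the values are made up and the generated header path is assumed from the proto's location:

```cpp
#include "mediapipe/tasks/cc/components/processors/proto/llm_params.pb.h"  // Assumed path.

namespace {

// Illustrative only: populating the new LLMParameters message.
mediapipe::tasks::components::processors::proto::LLMParameters MakeExampleLlmParams() {
  mediapipe::tasks::components::processors::proto::LLMParameters params;
  params.set_vocab_size(32000);   // Made-up vocabulary size.
  params.set_disable_kv_cache(false);
  params.set_start_token_id(1);
  params.set_stop_token("</s>");  // Made-up stop token.
  params.mutable_transformer_parameters()->set_num_stacks(12);  // Made up.
  return params;
}

}  // namespace
```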
View File
@ -44,6 +44,21 @@ message TransformerParameters {
// Number of stacked transformers, `N` in the paper. // Number of stacked transformers, `N` in the paper.
int32 num_stacks = 7; int32 num_stacks = 7;
// Whether to use Multi-Query-Attention (MQA). // Deprecated: bool use_mqa. Use num_kv_heads below.
bool use_mqa = 8; reserved 8;
// Number of kv heads. 0 means Multi-Head-Attention (MHA), key and value have
// same number of heads as query; 1 means Multi-Query-Attention (MQA), key and
// value have one head; otherwise, this specifies the number of heads for key
// and value, and Grouped-Query-Attention (GQA) will be used. See
// https://arxiv.org/pdf/2305.13245.pdf for details.
int32 num_kv_heads = 9;
// Supported attention mask types.
enum AttentionMaskType {
UNSPECIFIED = 0;
CAUSAL = 1;
PREFIX = 2;
}
AttentionMaskType attention_mask_type = 10;
} }
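To make the semantics of the new field concrete, the mapping it describes can be expressed as a small helper; this is illustrative only and not part of MediaPipe, with `num_query_heads` standing in for the model's number of attention heads:

```cpp
#include <string>

// Illustrative helper: classifies the attention variant implied by
// num_kv_heads, given the model's number of query heads.
std::string AttentionVariant(int num_query_heads, int num_kv_heads) {
  if (num_kv_heads == 0 || num_kv_heads == num_query_heads) {
    return "MHA";  // Keys and values have as many heads as queries.
  }
  if (num_kv_heads == 1) {
    return "MQA";  // All query heads share a single key/value head.
  }
  // Grouped-Query-Attention: each group of num_query_heads / num_kv_heads
  // query heads shares one key/value head.
  return "GQA";
}
```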
View File
@ -264,6 +264,7 @@ cc_library_with_tflite(
"//mediapipe/framework:executor", "//mediapipe/framework:executor",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util", "//mediapipe/framework/tool:name_util",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
View File
@ -39,6 +39,10 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace core { namespace core {
@ -88,16 +92,34 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
} // namespace } // namespace
/* static */ /* static */
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets,
std::shared_ptr<::mediapipe::GpuResources> resources) {
#else
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create( absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver, std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback, PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor, std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets) { std::optional<PacketMap> input_side_packets) {
#endif // !MEDIAPIPE_DISABLE_GPU
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
MP_RETURN_IF_ERROR(task_runner->Initialize( MP_RETURN_IF_ERROR(task_runner->Initialize(
std::move(config), std::move(op_resolver), std::move(default_executor), std::move(config), std::move(op_resolver), std::move(default_executor),
std::move(input_side_packets))); std::move(input_side_packets)));
#if !MEDIAPIPE_DISABLE_GPU
if (resources) {
MP_RETURN_IF_ERROR(
task_runner->graph_.SetGpuResources(std::move(resources)));
}
#endif // !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(task_runner->Start()); MP_RETURN_IF_ERROR(task_runner->Start());
return task_runner; return task_runner;
} }
View File
@ -42,6 +42,11 @@ limitations under the License.
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe { namespace mediapipe {
#if !MEDIAPIPE_DISABLE_GPU
class GpuResources;
#endif // !MEDIAPIPE_DISABLE_GPU
namespace tasks { namespace tasks {
namespace core { namespace core {
@ -72,12 +77,22 @@ class TaskRunner {
// asynchronous method, Send(), to provide the input packets. If the packets // asynchronous method, Send(), to provide the input packets. If the packets
// callback is absent, clients must use the synchronous method, Process(), to // callback is absent, clients must use the synchronous method, Process(), to
// provide the input packets and receive the output packets. // provide the input packets and receive the output packets.
#if !MEDIAPIPE_DISABLE_GPU
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt,
std::shared_ptr<::mediapipe::GpuResources> resources = nullptr);
#else
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create( static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr, std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr, PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr, std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt); std::optional<PacketMap> input_side_packets = std::nullopt);
#endif // !MEDIAPIPE_DISABLE_GPU
// TaskRunner is neither copyable nor movable. // TaskRunner is neither copyable nor movable.
TaskRunner(const TaskRunner&) = delete; TaskRunner(const TaskRunner&) = delete;
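Under a GPU-enabled build, the extra parameter lets callers inject a shared GpuResources instance when constructing a runner, so several tasks can reuse one GPU context. A rough sketch, not taken from this change, with header paths assumed:

```cpp
#include <memory>
#include <optional>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/gpu/gpu_shared_data_internal.h"  // GpuResources.
#include "mediapipe/tasks/cc/core/task_runner.h"     // Assumed path.

absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
CreateRunnerWithSharedGpu(
    mediapipe::CalculatorGraphConfig config,
    std::shared_ptr<mediapipe::GpuResources> shared_gpu_resources) {
  // The four optional arguments keep their defaults; the final argument routes
  // the shared GpuResources into the underlying CalculatorGraph before Start().
  return mediapipe::tasks::core::TaskRunner::Create(
      std::move(config), /*op_resolver=*/nullptr,
      /*packets_callback=*/nullptr, /*default_executor=*/nullptr,
      /*input_side_packets=*/std::nullopt, std::move(shared_gpu_resources));
}
```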
View File
@ -57,6 +57,7 @@ CALCULATORS_AND_GRAPHS = [
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
] ]
@ -83,6 +84,7 @@ strip_api_include_path_prefix(
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h", "//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPImage.h", "//mediapipe/tasks/ios/vision/core:sources/MPPImage.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPMask.h",
"//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetector.h", "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetector.h",
"//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorOptions.h", "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorOptions.h",
"//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorResult.h", "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorResult.h",
@ -98,6 +100,9 @@ strip_api_include_path_prefix(
"//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifier.h", "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifier.h",
"//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierOptions.h", "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierOptions.h",
"//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h", "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
"//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenter.h",
"//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenterOptions.h",
"//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenterResult.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",
@ -178,6 +183,7 @@ apple_static_xcframework(
":MPPTaskOptions.h", ":MPPTaskOptions.h",
":MPPTaskResult.h", ":MPPTaskResult.h",
":MPPImage.h", ":MPPImage.h",
":MPPMask.h",
":MPPRunningMode.h", ":MPPRunningMode.h",
":MPPFaceDetector.h", ":MPPFaceDetector.h",
":MPPFaceDetectorOptions.h", ":MPPFaceDetectorOptions.h",
@ -188,6 +194,9 @@ apple_static_xcframework(
":MPPImageClassifier.h", ":MPPImageClassifier.h",
":MPPImageClassifierOptions.h", ":MPPImageClassifierOptions.h",
":MPPImageClassifierResult.h", ":MPPImageClassifierResult.h",
":MPPImageSegmenter.h",
":MPPImageSegmenterOptions.h",
":MPPImageSegmenterResult.h",
":MPPHandLandmarker.h", ":MPPHandLandmarker.h",
":MPPHandLandmarkerOptions.h", ":MPPHandLandmarkerOptions.h",
":MPPHandLandmarkerResult.h", ":MPPHandLandmarkerResult.h",
@ -204,6 +213,7 @@ apple_static_xcframework(
"//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer", "//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer",
"//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker", "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker",
"//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier", "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",
"//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenter",
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetector", "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetector",
], ],
) )
View File
@ -14,6 +14,15 @@
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
/**
* The delegate used to run MediaPipe. If the delegate is not set, the CPU
* delegate is used by default.
*/
typedef NS_ENUM(NSUInteger, MPPDelegate) {
MPPDelegateCPU,
MPPDelegateGPU,
} NS_SWIFT_NAME(Delegate);
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
/** /**
@ -26,6 +35,9 @@ NS_SWIFT_NAME(BaseOptions)
/** The path to the model asset to open and mmap in memory. */ /** The path to the model asset to open and mmap in memory. */
@property(nonatomic, copy) NSString *modelAssetPath; @property(nonatomic, copy) NSString *modelAssetPath;
/** Overrides the default backend to use for the provided model. */
@property(nonatomic) MPPDelegate delegate;
@end @end
NS_ASSUME_NONNULL_END NS_ASSUME_NONNULL_END
View File
@ -20,6 +20,7 @@
self = [super init]; self = [super init];
if (self) { if (self) {
self.modelAssetPath = [[NSString alloc] init]; self.modelAssetPath = [[NSString alloc] init];
self.delegate = MPPDelegateCPU;
} }
return self; return self;
} }
@ -28,6 +29,7 @@
MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
baseOptions.modelAssetPath = self.modelAssetPath; baseOptions.modelAssetPath = self.modelAssetPath;
baseOptions.delegate = self.delegate;
return baseOptions; return baseOptions;
} }
View File
@ -21,6 +21,7 @@ objc_library(
srcs = ["sources/MPPBaseOptions+Helpers.mm"], srcs = ["sources/MPPBaseOptions+Helpers.mm"],
hdrs = ["sources/MPPBaseOptions+Helpers.h"], hdrs = ["sources/MPPBaseOptions+Helpers.h"],
deps = [ deps = [
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
View File
@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
namespace { namespace {
using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
using InferenceCalculatorOptionsProto = ::mediapipe::InferenceCalculatorOptions;
} }
@implementation MPPBaseOptions (Helpers) @implementation MPPBaseOptions (Helpers)
@ -33,6 +35,11 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
if (self.modelAssetPath) { if (self.modelAssetPath) {
baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
} }
if (self.delegate == MPPDelegateGPU) {
baseOptionsProto->mutable_acceleration()->mutable_gpu()->MergeFrom(
InferenceCalculatorOptionsProto::Delegate::Gpu());
}
} }
@end @end
View File
@ -31,3 +31,28 @@ objc_library(
"//mediapipe/tasks/ios/core:MPPTaskResult", "//mediapipe/tasks/ios/core:MPPTaskResult",
], ],
) )
objc_library(
name = "MPPLanguageDetector",
srcs = ["sources/MPPLanguageDetector.mm"],
hdrs = ["sources/MPPLanguageDetector.h"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
module_name = "MPPLanguageDetector",
deps = [
":MPPLanguageDetectorOptions",
":MPPLanguageDetectorResult",
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
"//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorOptionsHelpers",
"//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorResultHelpers",
],
)
View File
@ -0,0 +1,88 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.h"
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN
/**
* @brief Predicts the language of an input text.
*
* This API expects a TFLite model with [TFLite Model
* Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory
* (described below) input tensor, output tensor, and the language codes in an AssociatedFile.
*
* Metadata is required for models with int32 input tensors because it contains the input
* process unit for the model's Tokenizer. No metadata is required for models with string
* input tensors.
*
* Input tensor
* - One input tensor (`kTfLiteString`) of shape `[1]` containing the input string.
*
* Output tensor
* - One output tensor (`kTfLiteFloat32`) of shape `[1 x N]` where `N` is the number of languages.
*/
NS_SWIFT_NAME(LanguageDetector)
@interface MPPLanguageDetector : NSObject
/**
* Creates a new instance of `LanguageDetector` from an absolute path to a TensorFlow Lite
* model file stored locally on the device and the default `LanguageDetectorOptions`.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
* @param error An optional error parameter populated when there is an error in initializing the
* language detector.
*
* @return A new instance of `LanguageDetector` with the given model path. `nil` if there is an
* error in initializing the language detector.
*/
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
/**
* Creates a new instance of `LanguageDetector` from the given `LanguageDetectorOptions`.
*
* @param options The options of type `LanguageDetectorOptions` to use for configuring the
* `LanguageDetector`.
* @param error An optional error parameter populated when there is an error in initializing the
* language detector.
*
* @return A new instance of `LanguageDetector` with the given options. `nil` if there is an
* error in initializing the language detector.
*/
- (nullable instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/**
* Predicts the language of the input text.
*
* @param text The `NSString` for which language is to be predicted.
* @param error An optional error parameter populated when there is an error in performing
* language prediction on the input text.
*
* @return A `LanguageDetectorResult` object that contains a list of language predictions.
*/
- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text
error:(NSError **)error NS_SWIFT_NAME(detect(text:));
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END
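For readers of this header, here is a minimal Swift usage sketch of the API declared above. The module name `MediaPipeTasksText`, the bundled model file name, and the property names on the result objects (inferred from the initializers added elsewhere in this change) are assumptions for illustration, not part of the diff.
import MediaPipeTasksText  // assumed module name for the iOS Text Tasks framework
// Create a detector from a bundled model file (the file name is a placeholder).
let modelPath = Bundle.main.path(forResource: "language_detector", ofType: "tflite")!
let detector = try LanguageDetector(modelPath: modelPath)
// Run language prediction on a string and print the ranked predictions.
let result = try detector.detect(text: "il y a un chat sur la table")
for prediction in result.languagePredictions {
  print("\(prediction.languageCode): \(prediction.probability)")
}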

View File

@ -0,0 +1,96 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.h"
#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h"
namespace {
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
} // namespace
static NSString *const kClassificationsStreamName = @"classifications_out";
static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
static NSString *const kTextInStreamName = @"text_in";
static NSString *const kTextTag = @"TEXT";
static NSString *const kTaskGraphName =
@"mediapipe.tasks.text.language_detector.LanguageDetectorGraph";
@interface MPPLanguageDetector () {
/** iOS Text Task Runner */
MPPTextTaskRunner *_textTaskRunner;
}
@end
@implementation MPPLanguageDetector
- (instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options error:(NSError **)error {
self = [super init];
if (self) {
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
initWithTaskGraphName:kTaskGraphName
inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ]
outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
kClassificationsStreamName] ]
taskOptions:options
enableFlowLimiting:NO
error:error];
if (!taskInfo) {
return nil;
}
_textTaskRunner =
[[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
error:error];
if (!_textTaskRunner) {
return nil;
}
}
return self;
}
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
MPPLanguageDetectorOptions *options = [[MPPLanguageDetectorOptions alloc] init];
options.baseOptions.modelAssetPath = modelPath;
return [self initWithOptions:options error:error];
}
- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text error:(NSError **)error {
Packet packet = [MPPTextPacketCreator createWithText:text];
std::map<std::string, Packet> packetMap = {{kTextInStreamName.cppString, packet}};
std::optional<PacketMap> outputPacketMap = [_textTaskRunner processPacketMap:packetMap
error:error];
if (!outputPacketMap.has_value()) {
return nil;
}
return
[MPPLanguageDetectorResult languageDetectorResultWithClassificationsPacket:
outputPacketMap.value()[kClassificationsStreamName.cppString]];
}
@end

View File

@ -30,3 +30,15 @@ objc_library(
"//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetectorOptions", "//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetectorOptions",
], ],
) )
objc_library(
name = "MPPLanguageDetectorResultHelpers",
srcs = ["sources/MPPLanguageDetectorResult+Helpers.mm"],
hdrs = ["sources/MPPLanguageDetectorResult+Helpers.h"],
deps = [
"//mediapipe/framework:packet",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
"//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetectorResult",
],
)

View File

@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorResult.h"
#include "mediapipe/framework/packet.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPLanguageDetectorResult (Helpers)
+ (MPPLanguageDetectorResult *)languageDetectorResultWithClassificationsPacket:
(const mediapipe::Packet &)packet;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,61 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000;
namespace {
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
} // namespace
@implementation MPPLanguageDetectorResult (Helpers)
+ (MPPLanguageDetectorResult *)languageDetectorResultWithClassificationsPacket:
(const mediapipe::Packet &)packet {
MPPClassificationResult *classificationResult = [MPPClassificationResult
classificationResultWithProto:packet.Get<ClassificationResultProto>()];
return [MPPLanguageDetectorResult
languageDetectorResultWithClassificationResult:classificationResult
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}
+ (MPPLanguageDetectorResult *)
languageDetectorResultWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
NSMutableArray<MPPLanguagePrediction *> *languagePredictions =
[NSMutableArray arrayWithCapacity:classificationResult.classifications.count];
if (classificationResult.classifications.count > 0) {
for (MPPCategory *category in classificationResult.classifications[0].categories) {
MPPLanguagePrediction *languagePrediction =
[[MPPLanguagePrediction alloc] initWithLanguageCode:category.categoryName
probability:category.score];
[languagePredictions addObject:languagePrediction];
}
}
return [[MPPLanguageDetectorResult alloc] initWithLanguagePredictions:languagePredictions
timestampInMilliseconds:timestampInMilliseconds];
}
@end

View File

@ -37,7 +37,7 @@ vImage_Buffer allocatedVImageBuffer(vImagePixelCount width, vImagePixelCount hei
} }
static void FreeDataProviderReleaseCallback(void *buffer, const void *data, size_t size) { static void FreeDataProviderReleaseCallback(void *buffer, const void *data, size_t size) {
delete (vImage_Buffer *)buffer; delete[] (vImage_Buffer *)buffer;
} }
} // namespace } // namespace

View File

@ -47,7 +47,7 @@ using ::mediapipe::Packet;
->PixelData() ->PixelData()
width:confidenceMask.width() width:confidenceMask.width()
height:confidenceMask.height() height:confidenceMask.height()
shouldCopy:shouldCopyMaskPacketData ? YES : NO]]; shouldCopy:shouldCopyMaskPacketData]];
} }
} }
@ -57,7 +57,7 @@ using ::mediapipe::Packet;
initWithUInt8Data:(UInt8 *)cppCategoryMask.GetImageFrameSharedPtr().get()->PixelData() initWithUInt8Data:(UInt8 *)cppCategoryMask.GetImageFrameSharedPtr().get()->PixelData()
width:cppCategoryMask.width() width:cppCategoryMask.width()
height:cppCategoryMask.height() height:cppCategoryMask.height()
shouldCopy:shouldCopyMaskPacketData ? YES : NO]; shouldCopy:shouldCopyMaskPacketData];
} }
if (qualityScoresPacket.ValidateAsType<std::vector<float>>().ok()) { if (qualityScoresPacket.ValidateAsType<std::vector<float>>().ok()) {

View File

@ -37,3 +37,23 @@ objc_library(
"//mediapipe/tasks/ios/vision/core:MPPRunningMode", "//mediapipe/tasks/ios/vision/core:MPPRunningMode",
], ],
) )
objc_library(
name = "MPPPoseLandmarksConnections",
hdrs = ["sources/MPPPoseLandmarksConnections.h"],
module_name = "MPPPoseLandmarksConnections",
deps = ["//mediapipe/tasks/ios/components/containers:MPPConnection"],
)
objc_library(
name = "MPPPoseLandmarker",
hdrs = ["sources/MPPPoseLandmarker.h"],
module_name = "MPPPoseLandmarker",
deps = [
":MPPPoseLandmarkerOptions",
":MPPPoseLandmarkerResult",
":MPPPoseLandmarksConnections",
"//mediapipe/tasks/ios/components/containers:MPPConnection",
"//mediapipe/tasks/ios/vision/core:MPPImage",
],
)

View File

@ -0,0 +1,160 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerOptions.h"
#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h"
NS_ASSUME_NONNULL_BEGIN
/**
* @brief Performs pose landmarks detection on images.
*
* This API expects a pre-trained pose landmarks model asset bundle.
*/
NS_SWIFT_NAME(PoseLandmarker)
@interface MPPPoseLandmarker : NSObject
/** The array of connections between all the landmarks in the detected pose. */
@property(class, nonatomic, readonly) NSArray<MPPConnection *> *poseLandmarks;
/**
* Creates a new instance of `PoseLandmarker` from an absolute path to a model asset bundle stored
* locally on the device and the default `PoseLandmarkerOptions`.
*
* @param modelPath An absolute path to a model asset bundle stored locally on the device.
* @param error An optional error parameter populated when there is an error in initializing the
* pose landmarker.
*
* @return A new instance of `PoseLandmarker` with the given model path. `nil` if there is an error
* in initializing the pose landmarker.
*/
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
/**
* Creates a new instance of `PoseLandmarker` from the given `PoseLandmarkerOptions`.
*
* @param options The options of type `PoseLandmarkerOptions` to use for configuring the
* `PoseLandmarker`.
* @param error An optional error parameter populated when there is an error in initializing the
* pose landmarker.
*
* @return A new instance of `PoseLandmarker` with the given options. `nil` if there is an error in
* initializing the pose landmarker.
*/
- (nullable instancetype)initWithOptions:(MPPPoseLandmarkerOptions *)options
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/**
* Performs pose landmarks detection on the provided `MPImage` using the whole image as region of
* interest. Rotation will be applied according to the `orientation` property of the provided
* `MPImage`. Only use this method when the `PoseLandmarker` is created with running mode `.image`.
*
* This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a
* source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use
* `kCVPixelFormatType_32BGRA` as its pixel format.
 *
 * If your `MPImage` has a source type of `.image`, ensure that the color space is RGB with an Alpha
 * channel.
*
* @param image The `MPImage` on which pose landmarks detection is to be performed.
* @param error An optional error parameter populated when there is an error in performing pose
* landmark detection on the input image.
*
 * @return A `PoseLandmarkerResult` object that contains the pose landmarks detection
* results.
*/
- (nullable MPPPoseLandmarkerResult *)detectImage:(MPPImage *)image
error:(NSError **)error NS_SWIFT_NAME(detect(image:));
/**
* Performs pose landmarks detection on the provided video frame of type `MPImage` using the whole
* image as region of interest. Rotation will be applied according to the `orientation` property of
* the provided `MPImage`. Only use this method when the `PoseLandmarker` is created with running
* mode `.video`.
*
* It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must
* be monotonically increasing.
*
* This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a
* source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use
* `kCVPixelFormatType_32BGRA` as its pixel format.
 *
 * If your `MPImage` has a source type of `.image`, ensure that the color space is RGB with an Alpha
 * channel.
*
* @param image The `MPImage` on which pose landmarks detection is to be performed.
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing pose
* landmark detection on the input video frame.
*
 * @return A `PoseLandmarkerResult` object that contains the pose landmarks detection
* results.
*/
- (nullable MPPPoseLandmarkerResult *)detectVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
/**
* Sends live stream image data of type `MPImage` to perform pose landmarks detection using the
* whole image as region of interest. Rotation will be applied according to the `orientation`
* property of the provided `MPImage`. Only use this method when the `PoseLandmarker` is created
 * with running mode `.liveStream`.
*
* The object which needs to be continuously notified of the available results of pose landmark
 * detection must conform to the `PoseLandmarkerLiveStreamDelegate` protocol and implement the
* `poseLandmarker(_:didFinishDetectionWithResult:timestampInMilliseconds:error:)` delegate method.
*
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the pose landmarker. The input timestamps must be monotonically increasing.
*
* This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a
* source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use
* `kCVPixelFormatType_32BGRA` as its pixel format.
*
 * If the input `MPImage` has a source type of `.image`, ensure that the color space is RGB with an
* Alpha channel.
*
* If this method is used for performing pose landmarks detection on live camera frames using
* `AVFoundation`, ensure that you request `AVCaptureVideoDataOutput` to output frames in
* `kCMPixelFormat_32BGRA` using its `videoSettings` property.
*
* @param image A live stream image data of type `MPImage` on which pose landmarks detection is to
* be performed.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image is sent to the pose landmarker. The input timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing pose
* landmark detection on the input live stream image data.
*
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)detectAsyncImage:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END
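A hedged Swift sketch of the image-mode flow documented above. The module name `MediaPipeTasksVision`, the `runningMode` property, the `MPImage(uiImage:)` convenience initializer, and the model file name are assumptions for illustration; only the initializers and `detect(image:)` come from this header.
import MediaPipeTasksVision  // assumed module name for the iOS Vision Tasks framework
import UIKit
// Configure the landmarker for single-image detection (model file name is a placeholder).
let options = PoseLandmarkerOptions()
options.baseOptions.modelAssetPath = Bundle.main.path(forResource: "pose_landmarker", ofType: "task")!
options.runningMode = .image  // assumed property and enum case, mirroring the `.image` mode described above
let landmarker = try PoseLandmarker(options: options)
let input = try MPImage(uiImage: UIImage(named: "pose.jpg")!)  // assumed convenience initializer
let result = try landmarker.detect(image: input)
print("Detected \(result.landmarks.count) pose(s)")
For the `.liveStream` mode, the caller additionally conforms to `PoseLandmarkerLiveStreamDelegate` and implements `poseLandmarker(_:didFinishDetectionWithResult:timestampInMilliseconds:error:)`, as the comment above states; a delegate sketch is omitted here because the exact Swift mapping of that selector is not shown in this change.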

View File

@ -46,7 +46,7 @@ NS_SWIFT_NAME(PoseLandmarkerResult)
*/ */
- (instancetype)initWithLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks - (instancetype)initWithLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
segmentationMasks:(NSArray<MPPMask *> *)segmentationMasks segmentationMasks:(nullable NSArray<MPPMask *> *)segmentationMasks
timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER; timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER;
- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_UNAVAILABLE; - (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_UNAVAILABLE;

View File

@ -0,0 +1,40 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
NS_ASSUME_NONNULL_BEGIN
NSArray<MPPConnection *> *const MPPPoseLandmarksConnections = @[
[[MPPConnection alloc] initWithStart:0 end:1], [[MPPConnection alloc] initWithStart:1 end:2],
[[MPPConnection alloc] initWithStart:2 end:3], [[MPPConnection alloc] initWithStart:3 end:7],
[[MPPConnection alloc] initWithStart:0 end:4], [[MPPConnection alloc] initWithStart:4 end:5],
[[MPPConnection alloc] initWithStart:5 end:6], [[MPPConnection alloc] initWithStart:6 end:8],
[[MPPConnection alloc] initWithStart:9 end:10], [[MPPConnection alloc] initWithStart:11 end:12],
[[MPPConnection alloc] initWithStart:11 end:13], [[MPPConnection alloc] initWithStart:13 end:15],
[[MPPConnection alloc] initWithStart:15 end:17], [[MPPConnection alloc] initWithStart:15 end:19],
[[MPPConnection alloc] initWithStart:15 end:21], [[MPPConnection alloc] initWithStart:17 end:19],
[[MPPConnection alloc] initWithStart:12 end:14], [[MPPConnection alloc] initWithStart:14 end:16],
[[MPPConnection alloc] initWithStart:16 end:18], [[MPPConnection alloc] initWithStart:16 end:20],
[[MPPConnection alloc] initWithStart:16 end:22], [[MPPConnection alloc] initWithStart:18 end:20],
[[MPPConnection alloc] initWithStart:11 end:23], [[MPPConnection alloc] initWithStart:12 end:24],
[[MPPConnection alloc] initWithStart:23 end:24], [[MPPConnection alloc] initWithStart:23 end:25],
[[MPPConnection alloc] initWithStart:26 end:28], [[MPPConnection alloc] initWithStart:27 end:29],
[[MPPConnection alloc] initWithStart:28 end:30], [[MPPConnection alloc] initWithStart:29 end:31],
[[MPPConnection alloc] initWithStart:30 end:32], [[MPPConnection alloc] initWithStart:27 end:31],
[[MPPConnection alloc] initWithStart:28 end:32]
];
NS_ASSUME_NONNULL_END
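A short Swift sketch of how a caller might pair these connections with a `PoseLandmarkerResult`, using the `poseLandmarks` class property declared in MPPPoseLandmarker.h above. The `start`/`end` property names on `MPPConnection` are inferred from the `initWithStart:end:` calls and should be treated as assumptions; `render(from:to:)` is a hypothetical app-side helper.
// Assumes `result` is a PoseLandmarkerResult with at least one detected pose.
let landmarks = result.landmarks[0]
for connection in PoseLandmarker.poseLandmarks {
  let startLandmark = landmarks[Int(connection.start)]  // `start`/`end` are assumed property names
  let endLandmark = landmarks[Int(connection.end)]
  // Hand each landmark pair to whatever renderer the app uses, e.g. to draw the skeleton.
  render(from: startLandmark, to: endLandmark)  // hypothetical helper, not part of MediaPipe
}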

View File

@ -36,3 +36,21 @@ objc_library(
"//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarkerOptions", "//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarkerOptions",
], ],
) )
objc_library(
name = "MPPPoseLandmarkerResultHelpers",
srcs = ["sources/MPPPoseLandmarkerResult+Helpers.mm"],
hdrs = ["sources/MPPPoseLandmarkerResult+Helpers.h"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
deps = [
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/tasks/ios/components/containers/utils:MPPLandmarkHelpers",
"//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarkerResult",
],
)

View File

@ -0,0 +1,63 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPPoseLandmarkerResult (Helpers)
/**
* Creates an `MPPPoseLandmarkerResult` from landmarks, world landmarks and segmentation mask
* packets.
*
 * @param landmarksPacket A MediaPipe packet wrapping a `std::vector<NormalizedLandmarkListProto>`.
* @param worldLandmarksPacket A MediaPipe packet wrapping a `std::vector<LandmarkListProto>`.
 * @param segmentationMasksPacket A MediaPipe packet wrapping a `std::vector<Image>`.
*
 * @return An `MPPPoseLandmarkerResult` object that contains the pose landmark detection
* results.
*/
+ (MPPPoseLandmarkerResult *)
poseLandmarkerResultWithLandmarksPacket:(const mediapipe::Packet &)landmarksPacket
worldLandmarksPacket:(const mediapipe::Packet &)worldLandmarksPacket
segmentationMasksPacket:(const mediapipe::Packet *)segmentationMasksPacket;
/**
* Creates an `MPPPoseLandmarkerResult` from landmarks, world landmarks and segmentation mask
* images.
*
 * @param landmarksProto A vector of protos of type `std::vector<NormalizedLandmarkListProto>`.
* @param worldLandmarksProto A vector of protos of type `std::vector<LandmarkListProto>`.
* @param segmentationMasks A vector of type `std::vector<Image>`.
* @param timestampInMilliSeconds The timestamp of the Packet that contained the result.
*
* @return An `MPPPoseLandmarkerResult` object that contains the pose landmark detection
* results.
*/
+ (MPPPoseLandmarkerResult *)
poseLandmarkerResultWithLandmarksProto:
(const std::vector<::mediapipe::NormalizedLandmarkList> &)landmarksProto
worldLandmarksProto:
(const std::vector<::mediapipe::LandmarkList> &)worldLandmarksProto
segmentationMasks:(const std::vector<mediapipe::Image> *)segmentationMasks
timestampInMilliSeconds:(NSInteger)timestampInMilliseconds;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,124 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPLandmark+Helpers.h"
namespace {
using LandmarkListProto = ::mediapipe::LandmarkList;
using NormalizedLandmarkListProto = ::mediapipe::NormalizedLandmarkList;
using ::mediapipe::Image;
using ::mediapipe::Packet;
static const int kMicroSecondsPerMilliSecond = 1000;
} // namespace
@implementation MPPPoseLandmarkerResult (Helpers)
+ (MPPPoseLandmarkerResult *)emptyPoseLandmarkerResultWithTimestampInMilliseconds:
(NSInteger)timestampInMilliseconds {
return [[MPPPoseLandmarkerResult alloc] initWithLandmarks:@[]
worldLandmarks:@[]
segmentationMasks:@[]
timestampInMilliseconds:timestampInMilliseconds];
}
+ (MPPPoseLandmarkerResult *)
poseLandmarkerResultWithLandmarksProto:
(const std::vector<NormalizedLandmarkListProto> &)landmarksProto
worldLandmarksProto:
(const std::vector<LandmarkListProto> &)worldLandmarksProto
segmentationMasks:(const std::vector<Image> *)segmentationMasks
timestampInMilliSeconds:(NSInteger)timestampInMilliseconds {
NSMutableArray<NSMutableArray<MPPNormalizedLandmark *> *> *multiplePoseLandmarks =
[NSMutableArray arrayWithCapacity:(NSUInteger)landmarksProto.size()];
for (const auto &landmarkListProto : landmarksProto) {
NSMutableArray<MPPNormalizedLandmark *> *landmarks =
[NSMutableArray arrayWithCapacity:(NSUInteger)landmarkListProto.landmark().size()];
for (const auto &normalizedLandmarkProto : landmarkListProto.landmark()) {
MPPNormalizedLandmark *normalizedLandmark =
[MPPNormalizedLandmark normalizedLandmarkWithProto:normalizedLandmarkProto];
[landmarks addObject:normalizedLandmark];
}
[multiplePoseLandmarks addObject:landmarks];
}
NSMutableArray<NSMutableArray<MPPLandmark *> *> *multiplePoseWorldLandmarks =
[NSMutableArray arrayWithCapacity:(NSUInteger)worldLandmarksProto.size()];
for (const auto &worldLandmarkListProto : worldLandmarksProto) {
NSMutableArray<MPPLandmark *> *worldLandmarks =
[NSMutableArray arrayWithCapacity:(NSUInteger)worldLandmarkListProto.landmark().size()];
for (const auto &landmarkProto : worldLandmarkListProto.landmark()) {
MPPLandmark *landmark = [MPPLandmark landmarkWithProto:landmarkProto];
[worldLandmarks addObject:landmark];
}
[multiplePoseWorldLandmarks addObject:worldLandmarks];
}
NSMutableArray<MPPMask *> *confidenceMasks =
[NSMutableArray arrayWithCapacity:(NSUInteger)segmentationMasks->size()];
for (const auto &segmentationMask : *segmentationMasks) {
[confidenceMasks addObject:[[MPPMask alloc] initWithFloat32Data:(float *)segmentationMask
.GetImageFrameSharedPtr()
.get()
->PixelData()
width:segmentationMask.width()
height:segmentationMask.height()
/** Always deep copy */
shouldCopy:YES]];
}
MPPPoseLandmarkerResult *poseLandmarkerResult =
[[MPPPoseLandmarkerResult alloc] initWithLandmarks:multiplePoseLandmarks
worldLandmarks:multiplePoseWorldLandmarks
segmentationMasks:confidenceMasks
timestampInMilliseconds:timestampInMilliseconds];
return poseLandmarkerResult;
}
+ (MPPPoseLandmarkerResult *)
poseLandmarkerResultWithLandmarksPacket:(const Packet &)landmarksPacket
worldLandmarksPacket:(const Packet &)worldLandmarksPacket
segmentationMasksPacket:(const Packet *)segmentationMasksPacket {
NSInteger timestampInMilliseconds =
(NSInteger)(landmarksPacket.Timestamp().Value() / kMicroSecondsPerMilliSecond);
if (landmarksPacket.IsEmpty()) {
return [MPPPoseLandmarkerResult
emptyPoseLandmarkerResultWithTimestampInMilliseconds:timestampInMilliseconds];
}
if (!landmarksPacket.ValidateAsType<std::vector<NormalizedLandmarkListProto>>().ok() ||
!worldLandmarksPacket.ValidateAsType<std::vector<LandmarkListProto>>().ok()) {
return [MPPPoseLandmarkerResult
emptyPoseLandmarkerResultWithTimestampInMilliseconds:timestampInMilliseconds];
}
const std::vector<Image> *segmentationMasks =
segmentationMasksPacket ? &(segmentationMasksPacket->Get<std::vector<Image>>()) : nullptr;
return [MPPPoseLandmarkerResult
poseLandmarkerResultWithLandmarksProto:landmarksPacket
.Get<std::vector<NormalizedLandmarkListProto>>()
worldLandmarksProto:worldLandmarksPacket
.Get<std::vector<LandmarkListProto>>()
segmentationMasks:segmentationMasks
timestampInMilliSeconds:timestampInMilliseconds];
}
@end

View File

@ -70,7 +70,7 @@ class BaseOptions:
platform_name = platform.system() platform_name = platform.system()
if self.delegate == BaseOptions.Delegate.GPU: if self.delegate == BaseOptions.Delegate.GPU:
if platform_name == 'Linux': if platform_name in ['Linux', 'Darwin']:
acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu()) acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu())
else: else:
raise NotImplementedError( raise NotImplementedError(

View File

@ -26,9 +26,11 @@ pybind_library(
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/python/pybind:util", "//mediapipe/python/pybind:util",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
], ],

View File

@ -14,6 +14,7 @@
#include "mediapipe/tasks/python/core/pybind/task_runner.h" #include "mediapipe/tasks/python/core/pybind/task_runner.h"
#include "absl/log/absl_log.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/python/pybind/util.h" #include "mediapipe/python/pybind/util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
@ -21,6 +22,9 @@
#include "pybind11/stl.h" #include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h" #include "pybind11_protobuf/native_proto_caster.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc");
return absl::OkStatus(); return absl::OkStatus();
}; };
} }
#if !MEDIAPIPE_DISABLE_GPU
auto gpu_resources_ = mediapipe::GpuResources::Create();
if (!gpu_resources_.ok()) {
ABSL_LOG(INFO) << "GPU suport is not available: "
<< gpu_resources_.status();
gpu_resources_ = nullptr;
}
auto task_runner = TaskRunner::Create(
std::move(graph_config),
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
std::move(callback),
/* default_executor= */ nullptr,
    /* input_side_packets= */ std::nullopt, std::move(*gpu_resources_));
#else
auto task_runner = TaskRunner::Create( auto task_runner = TaskRunner::Create(
std::move(graph_config), std::move(graph_config),
absl::make_unique<core::MediaPipeBuiltinOpResolver>(), absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
std::move(callback)); std::move(callback));
#endif // !MEDIAPIPE_DISABLE_GPU
RaisePyErrorIfNotOk(task_runner.status()); RaisePyErrorIfNotOk(task_runner.status());
return std::move(*task_runner); return std::move(*task_runner);
}, },

View File

@ -211,3 +211,20 @@ py_test(
"//mediapipe/tasks/python/vision/core:image_processing_options", "//mediapipe/tasks/python/vision/core:image_processing_options",
], ],
) )
py_test(
name = "face_stylizer_test",
srcs = ["face_stylizer_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:face_stylizer",
"//mediapipe/tasks/python/vision/core:image_processing_options",
],
)

View File

@ -0,0 +1,191 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for face stylizer."""
import enum
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import face_stylizer
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
_BaseOptions = base_options_module.BaseOptions
_Rect = rect.Rect
_Image = image_module.Image
_FaceStylizer = face_stylizer.FaceStylizer
_FaceStylizerOptions = face_stylizer.FaceStylizerOptions
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL = 'face_stylizer_color_ink.task'
_LARGE_FACE_IMAGE = 'portrait.jpg'
_MODEL_IMAGE_SIZE = 256
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class FaceStylizerTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
)
)
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL)
)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _FaceStylizer.create_from_model_path(self.model_path) as stylizer:
self.assertIsInstance(stylizer, _FaceStylizer)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceStylizerOptions(base_options=base_options)
with _FaceStylizer.create_from_options(options) as stylizer:
self.assertIsInstance(stylizer, _FaceStylizer)
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite'
)
options = _FaceStylizerOptions(base_options=base_options)
_FaceStylizer.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _FaceStylizerOptions(base_options=base_options)
stylizer = _FaceStylizer.create_from_options(options)
self.assertIsInstance(stylizer, _FaceStylizer)
@parameterized.parameters(
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
)
def test_stylize(self, model_file_type, image_file_name):
# Load the test image.
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, image_file_name)
)
)
# Creates stylizer.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceStylizerOptions(base_options=base_options)
stylizer = _FaceStylizer.create_from_options(options)
# Performs face stylization on the input.
stylized_image = stylizer.stylize(self.test_image)
self.assertIsInstance(stylized_image, _Image)
# Closes the stylizer explicitly when the stylizer is not used in
# a context.
stylizer.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
)
def test_stylize_in_context(self, model_file_type, image_file_name):
# Load the test image.
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, image_file_name)
)
)
# Creates stylizer.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceStylizerOptions(base_options=base_options)
with _FaceStylizer.create_from_options(options) as stylizer:
# Performs face stylization on the input.
stylized_image = stylizer.stylize(self.test_image)
self.assertIsInstance(stylized_image, _Image)
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
def test_stylize_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceStylizerOptions(base_options=base_options)
with _FaceStylizer.create_from_options(options) as stylizer:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
)
)
# Region-of-interest around the face.
roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
image_processing_options = _ImageProcessingOptions(roi)
# Performs face stylization on the input.
stylized_image = stylizer.stylize(test_image, image_processing_options)
self.assertIsInstance(stylized_image, _Image)
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
def test_stylize_succeeds_with_no_face_detected(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceStylizerOptions(base_options=base_options)
with _FaceStylizer.create_from_options(options) as stylizer:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
)
)
# Region-of-interest that doesn't contain a human face.
roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
image_processing_options = _ImageProcessingOptions(roi)
# Performs face stylization on the input.
stylized_image = stylizer.stylize(test_image, image_processing_options)
self.assertIsNone(stylized_image)
if __name__ == '__main__':
absltest.main()

View File

@ -48,6 +48,7 @@ mediapipe_files(srcs = [
"face_landmark.tflite", "face_landmark.tflite",
"face_landmarker.task", "face_landmarker.task",
"face_landmarker_v2.task", "face_landmarker_v2.task",
"face_stylizer_color_ink.task",
"fist.jpg", "fist.jpg",
"fist.png", "fist.png",
"gesture_recognizer.task", "gesture_recognizer.task",
@ -183,6 +184,7 @@ filegroup(
"face_detection_short_range.tflite", "face_detection_short_range.tflite",
"face_landmarker.task", "face_landmarker.task",
"face_landmarker_v2.task", "face_landmarker_v2.task",
"face_stylizer_color_ink.task",
"hair_segmentation.tflite", "hair_segmentation.tflite",
"hand_landmark_full.tflite", "hand_landmark_full.tflite",
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",

View File

@ -2854,7 +2854,262 @@ auxiliary_landmarks {
face_blendshapes { face_blendshapes {
classification { classification {
index: 0 index: 0
score: 1.6770242e-05 score: 8.47715e-07
label: "tongueOut" label: "_neutral"
}
classification {
index: 1
score: 0.020850565
label: "browDownLeft"
}
classification {
index: 2
score: 0.007629181
label: "browDownRight"
}
classification {
index: 3
score: 0.26410568
label: "browInnerUp"
}
classification {
index: 4
score: 0.04212071
label: "browOuterUpLeft"
}
classification {
index: 5
score: 0.07319052
label: "browOuterUpRight"
}
classification {
index: 6
score: 9.39117e-06
label: "cheekPuff"
}
classification {
index: 7
score: 1.9243858e-07
label: "cheekSquintLeft"
}
classification {
index: 8
score: 4.066475e-08
label: "cheekSquintRight"
}
classification {
index: 9
score: 0.46092203
label: "eyeBlinkLeft"
}
classification {
index: 10
score: 0.40371567
label: "eyeBlinkRight"
}
classification {
index: 11
score: 0.65011656
label: "eyeLookDownLeft"
}
classification {
index: 12
score: 0.6423024
label: "eyeLookDownRight"
}
classification {
index: 13
score: 0.04721973
label: "eyeLookInLeft"
}
classification {
index: 14
score: 0.08176838
label: "eyeLookInRight"
}
classification {
index: 15
score: 0.09520102
label: "eyeLookOutLeft"
}
classification {
index: 16
score: 0.07271895
label: "eyeLookOutRight"
}
classification {
index: 17
score: 0.011193463
label: "eyeLookUpLeft"
}
classification {
index: 18
score: 0.007041815
label: "eyeLookUpRight"
}
classification {
index: 19
score: 0.27120194
label: "eyeSquintLeft"
}
classification {
index: 20
score: 0.21675573
label: "eyeSquintRight"
}
classification {
index: 21
score: 0.0018824162
label: "eyeWideLeft"
}
classification {
index: 22
score: 0.0011966582
label: "eyeWideRight"
}
classification {
index: 23
score: 1.9298719e-05
label: "jawForward"
}
classification {
index: 24
score: 9.670858e-06
label: "jawLeft"
}
classification {
index: 25
score: 0.000115385694
label: "jawOpen"
}
classification {
index: 26
score: 0.00023342477
label: "jawRight"
}
classification {
index: 27
score: 2.8894076e-05
label: "mouthClose"
}
classification {
index: 28
score: 0.003933548
label: "mouthDimpleLeft"
}
classification {
index: 29
score: 0.0051949574
label: "mouthDimpleRight"
}
classification {
index: 30
score: 0.00067943585
label: "mouthFrownLeft"
}
classification {
index: 31
score: 0.0006520291
label: "mouthFrownRight"
}
classification {
index: 32
score: 0.0006695333
label: "mouthFunnel"
}
classification {
index: 33
score: 8.578597e-05
label: "mouthLeft"
}
classification {
index: 34
score: 2.6707421e-05
label: "mouthLowerDownLeft"
}
classification {
index: 35
score: 2.153054e-05
label: "mouthLowerDownRight"
}
classification {
index: 36
score: 0.0132145975
label: "mouthPressLeft"
}
classification {
index: 37
score: 0.009528495
label: "mouthPressRight"
}
classification {
index: 38
score: 0.056963783
label: "mouthPucker"
}
classification {
index: 39
score: 0.027331185
label: "mouthRight"
}
classification {
index: 40
score: 0.00072388636
label: "mouthRollLower"
}
classification {
index: 41
score: 0.00021191382
label: "mouthRollUpper"
}
classification {
index: 42
score: 0.23938002
label: "mouthShrugLower"
}
classification {
index: 43
score: 0.052946873
label: "mouthShrugUpper"
}
classification {
index: 44
score: 0.68681276
label: "mouthSmileLeft"
}
classification {
index: 45
score: 0.68557316
label: "mouthSmileRight"
}
classification {
index: 46
score: 0.0030625665
label: "mouthStretchLeft"
}
classification {
index: 47
score: 0.003999545
label: "mouthStretchRight"
}
classification {
index: 48
score: 0.013184475
label: "mouthUpperUpLeft"
}
classification {
index: 49
score: 0.017995607
label: "mouthUpperUpRight"
}
classification {
index: 50
score: 2.0452394e-06
label: "noseSneerLeft"
}
classification {
index: 51
score: 3.7912793e-07
label: "noseSneerRight"
} }
} }

View File

@ -31,27 +31,57 @@ mediapipe_ts_library(
mediapipe_ts_library( mediapipe_ts_library(
name = "drawing_utils", name = "drawing_utils",
srcs = ["drawing_utils.ts"], srcs = [
"drawing_utils.ts",
"drawing_utils_category_mask.ts",
],
deps = [ deps = [
":image",
":image_shader_context",
":mask",
":types", ":types",
"//mediapipe/tasks/web/components/containers:bounding_box", "//mediapipe/tasks/web/components/containers:bounding_box",
"//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/web/graph_runner:graph_runner_ts",
], ],
) )
mediapipe_ts_library( mediapipe_ts_library(
name = "image", name = "drawing_utils_test_lib",
srcs = [ testonly = True,
"image.ts", srcs = ["drawing_utils.test.ts"],
"image_shader_context.ts", deps = [
":drawing_utils",
":image",
":image_shader_context",
":mask",
], ],
) )
jasmine_node_test(
name = "drawing_utils_test",
deps = [":drawing_utils_test_lib"],
)
mediapipe_ts_library(
name = "image",
srcs = ["image.ts"],
deps = ["image_shader_context"],
)
mediapipe_ts_library(
name = "image_shader_context",
srcs = ["image_shader_context.ts"],
)
mediapipe_ts_library( mediapipe_ts_library(
name = "image_test_lib", name = "image_test_lib",
testonly = True, testonly = True,
srcs = ["image.test.ts"], srcs = ["image.test.ts"],
deps = [":image"], deps = [
":image",
":image_shader_context",
],
) )
jasmine_node_test( jasmine_node_test(
@ -64,6 +94,7 @@ mediapipe_ts_library(
srcs = ["mask.ts"], srcs = ["mask.ts"],
deps = [ deps = [
":image", ":image",
":image_shader_context",
"//mediapipe/web/graph_runner:platform_utils", "//mediapipe/web/graph_runner:platform_utils",
], ],
) )
@ -74,6 +105,7 @@ mediapipe_ts_library(
srcs = ["mask.test.ts"], srcs = ["mask.test.ts"],
deps = [ deps = [
":image", ":image",
":image_shader_context",
":mask", ":mask",
], ],
) )
@ -89,6 +121,7 @@ mediapipe_ts_library(
deps = [ deps = [
":image", ":image",
":image_processing_options", ":image_processing_options",
":image_shader_context",
":mask", ":mask",
":vision_task_options", ":vision_task_options",
"//mediapipe/framework/formats:rect_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto",

View File

@ -0,0 +1,103 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {DrawingUtils} from './drawing_utils';
import {MPImageShaderContext} from './image_shader_context';
import {MPMask} from './mask';
const WIDTH = 2;
const HEIGHT = 2;
const skip = typeof document === 'undefined';
if (skip) {
console.log('These tests must be run in a browser.');
}
(skip ? xdescribe : describe)('DrawingUtils', () => {
let shaderContext = new MPImageShaderContext();
let canvas2D: HTMLCanvasElement;
let context2D: CanvasRenderingContext2D;
let drawingUtils2D: DrawingUtils;
let canvasWebGL: HTMLCanvasElement;
let contextWebGL: WebGL2RenderingContext;
let drawingUtilsWebGL: DrawingUtils;
beforeEach(() => {
shaderContext = new MPImageShaderContext();
canvasWebGL = document.createElement('canvas');
canvasWebGL.width = WIDTH;
canvasWebGL.height = HEIGHT;
contextWebGL = canvasWebGL.getContext('webgl2')!;
drawingUtilsWebGL = new DrawingUtils(contextWebGL);
canvas2D = document.createElement('canvas');
canvas2D.width = WIDTH;
canvas2D.height = HEIGHT;
context2D = canvas2D.getContext('2d')!;
drawingUtils2D = new DrawingUtils(context2D, contextWebGL);
});
afterEach(() => {
shaderContext.close();
drawingUtils2D.close();
drawingUtilsWebGL.close();
});
describe('drawCategoryMask() ', () => {
const colors = [
[0, 0, 0, 255],
[0, 255, 0, 255],
[0, 0, 255, 255],
[255, 255, 255, 255],
];
const expectedResult = new Uint8Array(
[0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255],
);
it('on 2D canvas', () => {
const categoryMask = new MPMask(
[new Uint8Array([0, 1, 2, 3])],
/* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH,
HEIGHT);
drawingUtils2D.drawCategoryMask(categoryMask, colors);
const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data;
expect(actualResult)
.toEqual(new Uint8ClampedArray(expectedResult.buffer));
});
it('on WebGL canvas', () => {
const categoryMask = new MPMask(
[new Uint8Array([2, 3, 0, 1])], // Note: Vertically flipped
/* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH,
HEIGHT);
drawingUtilsWebGL.drawCategoryMask(categoryMask, colors);
const actualResult = new Uint8Array(WIDTH * HEIGHT * 4);
contextWebGL.readPixels(
0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE,
actualResult);
expect(actualResult).toEqual(expectedResult);
});
});
// TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox
});

Some files were not shown because too many files have changed in this diff.