diff --git a/WORKSPACE b/WORKSPACE index df2c4f93b..3a539569f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -513,6 +513,9 @@ http_archive( "@//third_party:org_tensorflow_system_python.diff", # Diff is generated with a script, don't update it manually. "@//third_party:org_tensorflow_custom_ops.diff", + # Works around Bazel issue with objc_library. + # See https://github.com/bazelbuild/bazel/issues/19912 + "@//third_party:org_tensorflow_objc_build_fixes.diff", ], patch_args = [ "-p1", diff --git a/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.pbxproj b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.pbxproj new file mode 100644 index 000000000..8a95288c9 --- /dev/null +++ b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.pbxproj @@ -0,0 +1,342 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 56; + objects = { + +/* Begin PBXBuildFile section */ + 8566B55D2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h in Headers */ = {isa = PBXBuildFile; fileRef = 8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */; settings = {ATTRIBUTES = (Public, ); }; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = MediaPipeTasksDocGen.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + 8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = MediaPipeTasksDocGen.h; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 8566B5562ABABF9A00AAB22A /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 8566B54F2ABABF9A00AAB22A = { + isa = PBXGroup; + children = ( + 8566B55B2ABABF9A00AAB22A /* MediaPipeTasksDocGen */, + 8566B55A2ABABF9A00AAB22A /* Products */, + ); + sourceTree = ""; + }; + 8566B55A2ABABF9A00AAB22A /* Products */ = { + isa = PBXGroup; + children = ( + 8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */, + ); + name = Products; + sourceTree = ""; + }; + 8566B55B2ABABF9A00AAB22A /* MediaPipeTasksDocGen */ = { + isa = PBXGroup; + children = ( + 8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */, + ); + path = MediaPipeTasksDocGen; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXHeadersBuildPhase section */ + 8566B5542ABABF9A00AAB22A /* Headers */ = { + isa = PBXHeadersBuildPhase; + buildActionMask = 2147483647; + files = ( + 8566B55D2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h in Headers */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXHeadersBuildPhase section */ + +/* Begin PBXNativeTarget section */ + 8566B5582ABABF9A00AAB22A /* MediaPipeTasksDocGen */ = { + isa = PBXNativeTarget; + buildConfigurationList = 8566B5602ABABF9A00AAB22A /* Build configuration list for PBXNativeTarget "MediaPipeTasksDocGen" */; + buildPhases = ( + 8566B5542ABABF9A00AAB22A /* Headers */, + 8566B5552ABABF9A00AAB22A /* Sources */, + 8566B5562ABABF9A00AAB22A /* Frameworks */, + 8566B5572ABABF9A00AAB22A /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = MediaPipeTasksDocGen; + productName = MediaPipeTasksDocGen; + productReference = 8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */; + productType = 
"com.apple.product-type.framework"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 8566B5502ABABF9A00AAB22A /* Project object */ = { + isa = PBXProject; + attributes = { + BuildIndependentTargetsInParallel = 1; + LastUpgradeCheck = 1430; + TargetAttributes = { + 8566B5582ABABF9A00AAB22A = { + CreatedOnToolsVersion = 14.3.1; + }; + }; + }; + buildConfigurationList = 8566B5532ABABF9A00AAB22A /* Build configuration list for PBXProject "MediaPipeTasksDocGen" */; + compatibilityVersion = "Xcode 14.0"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 8566B54F2ABABF9A00AAB22A; + productRefGroup = 8566B55A2ABABF9A00AAB22A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 8566B5582ABABF9A00AAB22A /* MediaPipeTasksDocGen */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 8566B5572ABABF9A00AAB22A /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 8566B5552ABABF9A00AAB22A /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 8566B55E2ABABF9A00AAB22A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.4; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + 
VERSIONING_SYSTEM = "apple-generic"; + VERSION_INFO_PREFIX = ""; + }; + name = Debug; + }; + 8566B55F2ABABF9A00AAB22A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.4; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + VERSIONING_SYSTEM = "apple-generic"; + VERSION_INFO_PREFIX = ""; + }; + name = Release; + }; + 8566B5612ABABF9A00AAB22A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEFINES_MODULE = YES; + DYLIB_COMPATIBILITY_VERSION = 1; + DYLIB_CURRENT_VERSION = 1; + DYLIB_INSTALL_NAME_BASE = "@rpath"; + ENABLE_MODULE_VERIFIER = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + "@loader_path/Frameworks", + ); + MARKETING_VERSION = 1.0; + MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu11 gnu++20"; + PRODUCT_BUNDLE_IDENTIFIER = com.google.mediapipe.MediaPipeTasksDocGen; + PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + SKIP_INSTALL = YES; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 8566B5622ABABF9A00AAB22A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEFINES_MODULE = YES; + DYLIB_COMPATIBILITY_VERSION = 1; + DYLIB_CURRENT_VERSION = 1; + DYLIB_INSTALL_NAME_BASE = "@rpath"; + ENABLE_MODULE_VERIFIER = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INSTALL_PATH = 
"$(LOCAL_LIBRARY_DIR)/Frameworks"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + "@loader_path/Frameworks", + ); + MARKETING_VERSION = 1.0; + MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu11 gnu++20"; + PRODUCT_BUNDLE_IDENTIFIER = com.google.mediapipe.MediaPipeTasksDocGen; + PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + SKIP_INSTALL = YES; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 8566B5532ABABF9A00AAB22A /* Build configuration list for PBXProject "MediaPipeTasksDocGen" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 8566B55E2ABABF9A00AAB22A /* Debug */, + 8566B55F2ABABF9A00AAB22A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 8566B5602ABABF9A00AAB22A /* Build configuration list for PBXNativeTarget "MediaPipeTasksDocGen" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 8566B5612ABABF9A00AAB22A /* Debug */, + 8566B5622ABABF9A00AAB22A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 8566B5502ABABF9A00AAB22A /* Project object */; +} diff --git a/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 000000000..919434a62 --- /dev/null +++ b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 000000000..b3ea17378 --- /dev/null +++ b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/xcuserdata/macd.xcuserdatad/UserInterfaceState.xcuserstate b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/xcuserdata/macd.xcuserdatad/UserInterfaceState.xcuserstate new file mode 100644 index 000000000..d667b462e Binary files /dev/null and b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/project.xcworkspace/xcuserdata/macd.xcuserdatad/UserInterfaceState.xcuserstate differ diff --git a/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/xcuserdata/macd.xcuserdatad/xcschemes/xcschememanagement.plist b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/xcuserdata/macd.xcuserdatad/xcschemes/xcschememanagement.plist new file mode 100644 index 000000000..adc534a03 --- /dev/null +++ b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen.xcodeproj/xcuserdata/macd.xcuserdatad/xcschemes/xcschememanagement.plist @@ -0,0 +1,14 @@ + + + + + SchemeUserState + + MediaPipeTasksDocGen.xcscheme_^#shared#^_ + + orderHint + 0 + + + + diff --git a/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen/MediaPipeTasksDocGen.h b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen/MediaPipeTasksDocGen.h new file mode 
100644 index 000000000..2ce44b27b --- /dev/null +++ b/docs/MediaPipeTasksDocGen/MediaPipeTasksDocGen/MediaPipeTasksDocGen.h @@ -0,0 +1,17 @@ +// +// MediaPipeTasksDocGen.h +// MediaPipeTasksDocGen +// +// Created by Mark McDonald on 20/9/2023. +// + +#import + +//! Project version number for MediaPipeTasksDocGen. +FOUNDATION_EXPORT double MediaPipeTasksDocGenVersionNumber; + +//! Project version string for MediaPipeTasksDocGen. +FOUNDATION_EXPORT const unsigned char MediaPipeTasksDocGenVersionString[]; + +// In this header, you should import all the public headers of your framework using statements like +// #import diff --git a/docs/MediaPipeTasksDocGen/Podfile b/docs/MediaPipeTasksDocGen/Podfile new file mode 100644 index 000000000..3c8d8f09d --- /dev/null +++ b/docs/MediaPipeTasksDocGen/Podfile @@ -0,0 +1,11 @@ +# Uncomment the next line to define a global platform for your project +platform :ios, '15.0' + +target 'MediaPipeTasksDocGen' do + # Comment the next line if you don't want to use dynamic frameworks + use_frameworks! + + # Pods for MediaPipeTasksDocGen + pod 'MediaPipeTasksText' + pod 'MediaPipeTasksVision' +end diff --git a/docs/MediaPipeTasksDocGen/README.md b/docs/MediaPipeTasksDocGen/README.md new file mode 100644 index 000000000..475253057 --- /dev/null +++ b/docs/MediaPipeTasksDocGen/README.md @@ -0,0 +1,9 @@ +# MediaPipeTasksDocGen + +This empty project is used to generate reference documentation for the +ObjectiveC and Swift libraries. + +Docs are generated using [Jazzy](https://github.com/realm/jazzy) and published +to [the developer site](https://developers.google.com/mediapipe/solutions/). + +To bump the API version used, edit [`Podfile`](./Podfile). diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index aacf694c1..729e91492 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -727,6 +727,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", ], alwayslink = 1, ) @@ -742,6 +743,7 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc index 311f7d815..686d705dd 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/status/status.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" @@ -32,6 +33,7 @@ namespace { constexpr char kTagAtPreStream[] = "AT_PRESTREAM"; constexpr char kTagAtPostStream[] = "AT_POSTSTREAM"; constexpr char kTagAtZero[] = "AT_ZERO"; +constexpr char kTagAtFirstTick[] = "AT_FIRST_TICK"; constexpr char kTagAtTick[] = "AT_TICK"; constexpr char kTagTick[] = "TICK"; constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP"; @@ -43,6 +45,7 @@ static std::map* kTimestampMap = []() { res->emplace(kTagAtPostStream, Timestamp::PostStream()); res->emplace(kTagAtZero, Timestamp(0)); res->emplace(kTagAtTick, Timestamp::Unset()); + res->emplace(kTagAtFirstTick, Timestamp::Unset()); res->emplace(kTagAtTimestamp, Timestamp::Unset()); 
return res; }(); @@ -59,8 +62,8 @@ std::string GetOutputTag(const CC& cc) { // timestamp, depending on the tag used to define output stream(s). (One tag can // be used only.) // -// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP -// and corresponding timestamps are Timestamp::PreStream(), +// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_FIRST_TICK, +// AT_TIMESTAMP and corresponding timestamps are Timestamp::PreStream(), // Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK // input, and timestamp received from a side input. // @@ -96,6 +99,7 @@ class SidePacketToStreamCalculator : public CalculatorBase { private: bool is_tick_processing_ = false; + bool close_on_first_tick_ = false; std::string output_tag_; }; REGISTER_CALCULATOR(SidePacketToStreamCalculator); @@ -103,13 +107,16 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) { const auto& tags = cc->Outputs().GetTags(); RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1) - << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - "stream(s)."; - RET_CHECK( - (cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) || - (!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick))) - << "Either both of TICK and AT_TICK should be used or none of them."; + << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, " + "AT_FIRST_TICK and AT_TIMESTAMP tags is allowed and required to " + "specify output stream(s)."; + const bool has_tick_output = + cc->Outputs().HasTag(kTagAtTick) || cc->Outputs().HasTag(kTagAtFirstTick); + const bool has_tick_input = cc->Inputs().HasTag(kTagTick); + RET_CHECK((has_tick_output && has_tick_input) || + (!has_tick_output && !has_tick_input)) + << "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output " + "should be used or none of them."; RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) && cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) || (!cc->Outputs().HasTag(kTagAtTimestamp) && @@ -148,11 +155,17 @@ absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) { // timestamp bound update. cc->SetOffset(TimestampDiff(0)); } + if (output_tag_ == kTagAtFirstTick) { + close_on_first_tick_ = true; + } return absl::OkStatus(); } absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { if (is_tick_processing_) { + if (cc->Outputs().Get(output_tag_, 0).IsClosed()) { + return absl::OkStatus(); + } // TICK input is guaranteed to be non-empty, as it's the only input stream // for this calculator. 
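// Note on the new AT_FIRST_TICK path (a brief sketch; the example node below
// mirrors the AtFirstTick test added in this change): when the output uses the
// AT_FIRST_TICK tag, close_on_first_tick_ is set in Open(), the IsClosed()
// check above short-circuits every tick after the first one, and each output
// stream is closed immediately after its packet is added below.
//
//   node {
//     calculator: "SidePacketToStreamCalculator"
//     input_stream: "TICK:tick"
//     input_side_packet: "side_packet"
//     output_stream: "AT_FIRST_TICK:packet"
//   }
//
// With this configuration the side packet is emitted exactly once, at the
// timestamp of the first TICK packet, and later ticks produce no output.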
const auto& timestamp = cc->Inputs().Tag(kTagTick).Value().Timestamp(); @@ -160,6 +173,9 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { cc->Outputs() .Get(output_tag_, i) .AddPacket(cc->InputSidePackets().Index(i).At(timestamp)); + if (close_on_first_tick_) { + cc->Outputs().Get(output_tag_, i).Close(); + } } return absl::OkStatus(); @@ -170,6 +186,7 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { if (!cc->Outputs().HasTag(kTagAtTick) && + !cc->Outputs().HasTag(kTagAtFirstTick) && !cc->Outputs().HasTag(kTagAtTimestamp)) { const auto& timestamp = kTimestampMap->at(output_tag_); for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) { diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc index 086b73fcd..6c0941b44 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc @@ -27,13 +27,17 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/tool/options_util.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { namespace { -using testing::HasSubstr; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; -TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { +TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTick) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -52,10 +56,35 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { EXPECT_THAT( status.message(), HasSubstr( - "Either both of TICK and AT_TICK should be used or none of them.")); + "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output " + "should be used or none of them.")); } -TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { +TEST(SidePacketToStreamCalculator, + WrongConfigWithMissingTickForFirstTickProcessing) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"pb( + input_stream: "tick" + input_side_packet: "side_packet" + output_stream: "packet" + node { + calculator: "SidePacketToStreamCalculator" + input_side_packet: "side_packet" + output_stream: "AT_FIRST_TICK:packet" + } + )pb"); + CalculatorGraph graph; + auto status = graph.Initialize(graph_config); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + HasSubstr( + "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output " + "should be used or none of them.")); +} + +TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTimestampSideInput) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -76,7 +105,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { "or none of them.")); } -TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { +TEST(SidePacketToStreamCalculator, WrongConfigWithNonExistentTag) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -92,14 +121,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - 
"stream(s).")); + EXPECT_THAT(status.message(), + HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, " + "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is " + "allowed and required to specify output stream(s).")); } -TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { +TEST(SidePacketToStreamCalculator, WrongConfigWithMixedTags) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -117,14 +145,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - "stream(s).")); + EXPECT_THAT(status.message(), + HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, " + "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is " + "allowed and required to specify output stream(s).")); } -TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { +TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughSidePackets) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -146,7 +173,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { "Same number of input side packets and output streams is required.")); } -TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) { +TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughOutputStreams) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -248,7 +275,50 @@ TEST(SidePacketToStreamCalculator, AtTick) { tick_and_verify(/*at_timestamp=*/1025); } -TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) { +TEST(SidePacketToStreamCalculator, AtFirstTick) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"pb( + input_stream: "tick" + input_side_packet: "side_packet" + output_stream: "packet" + node { + calculator: "SidePacketToStreamCalculator" + input_stream: "TICK:tick" + input_side_packet: "side_packet" + output_stream: "AT_FIRST_TICK:packet" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("packet", &graph_config, &output_packets); + CalculatorGraph graph; + + MP_ASSERT_OK(graph.Initialize(graph_config)); + const int expected_value = 20; + const Timestamp kTestTimestamp(1234); + MP_ASSERT_OK( + graph.StartRun({{"side_packet", MakePacket(expected_value)}})); + + auto insert_tick = [&graph](Timestamp at_timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tick", MakePacket(/*doesn't matter*/ 1).At(at_timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + insert_tick(kTestTimestamp); + + EXPECT_THAT(output_packets, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(kTestTimestamp), Eq(expected_value)))); + + output_packets.clear(); + + // Should not result in an additional output. 
+ insert_tick(kTestTimestamp + 1); + EXPECT_THAT(output_packets, IsEmpty()); +} + +TEST(SidePacketToStreamCalculator, AtTickWithMultipleSidePackets) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -302,6 +372,62 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) { tick_and_verify(/*at_timestamp=*/1025); } +TEST(SidePacketToStreamCalculator, AtFirstTickWithMultipleSidePackets) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"pb( + input_stream: "tick" + input_side_packet: "side_packet0" + input_side_packet: "side_packet1" + output_stream: "packet0" + output_stream: "packet1" + node { + calculator: "SidePacketToStreamCalculator" + input_stream: "TICK:tick" + input_side_packet: "side_packet0" + input_side_packet: "side_packet1" + output_stream: "AT_FIRST_TICK:0:packet0" + output_stream: "AT_FIRST_TICK:1:packet1" + } + )pb"); + std::vector output_packets0; + tool::AddVectorSink("packet0", &graph_config, &output_packets0); + std::vector output_packets1; + tool::AddVectorSink("packet1", &graph_config, &output_packets1); + CalculatorGraph graph; + + MP_ASSERT_OK(graph.Initialize(graph_config)); + const int expected_value0 = 20; + const int expected_value1 = 128; + const Timestamp kTestTimestamp(1234); + MP_ASSERT_OK( + graph.StartRun({{"side_packet0", MakePacket(expected_value0)}, + {"side_packet1", MakePacket(expected_value1)}})); + + auto insert_tick = [&graph](Timestamp at_timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tick", MakePacket(/*doesn't matter*/ 1).At(at_timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + insert_tick(kTestTimestamp); + + EXPECT_THAT(output_packets0, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(kTestTimestamp), Eq(expected_value0)))); + EXPECT_THAT(output_packets1, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(kTestTimestamp), Eq(expected_value1)))); + + output_packets0.clear(); + output_packets1.clear(); + + // Should not result in an additional output. 
+ insert_tick(kTestTimestamp + 1); + EXPECT_THAT(output_packets0, IsEmpty()); + EXPECT_THAT(output_packets1, IsEmpty()); +} + TEST(SidePacketToStreamCalculator, AtTimestamp) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( @@ -334,7 +460,7 @@ TEST(SidePacketToStreamCalculator, AtTimestamp) { EXPECT_EQ(expected_value, output_packets.back().Get()); } -TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) { +TEST(SidePacketToStreamCalculator, AtTimestampWithMultipleOutputs) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( diff --git a/mediapipe/calculators/image/image_clone_calculator.cc b/mediapipe/calculators/image/image_clone_calculator.cc index 563b4a4ad..0929e81e5 100644 --- a/mediapipe/calculators/image/image_clone_calculator.cc +++ b/mediapipe/calculators/image/image_clone_calculator.cc @@ -65,7 +65,7 @@ class ImageCloneCalculator : public Node { } #else MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract( - cc, /*requesst_gpu_as_optional=*/true)); + cc, /*request_gpu_as_optional=*/true)); #endif // MEDIAPIPE_DISABLE_GPU return absl::OkStatus(); } diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc index d238975c6..ab2148f36 100644 --- a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc @@ -118,7 +118,7 @@ absl::Status SegmentationSmoothingCalculator::GetContract( #if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract( - cc, /*requesst_gpu_as_optional=*/true)); + cc, /*request_gpu_as_optional=*/true)); #endif // !MEDIAPIPE_DISABLE_GPU return absl::OkStatus(); diff --git a/mediapipe/calculators/image/warp_affine_calculator.cc b/mediapipe/calculators/image/warp_affine_calculator.cc index dba500dfa..0bbf6c72d 100644 --- a/mediapipe/calculators/image/warp_affine_calculator.cc +++ b/mediapipe/calculators/image/warp_affine_calculator.cc @@ -206,7 +206,7 @@ class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl { if constexpr (std::is_same_v || std::is_same_v) { MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract( - cc, /*requesst_gpu_as_optional=*/true)); + cc, /*request_gpu_as_optional=*/true)); } return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 6c433e9b5..ac2ced837 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -1480,7 +1480,6 @@ cc_test( "@com_google_absl//absl/log", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", ], ) diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index c8d38a653..7230e178d 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -109,7 +109,7 @@ bool IsValidFftSize(int size) { // Non-streaming mode: when "stream_mode" is set to false in the calculator // options, the calculators treats the packets in the input audio stream as // a batch of unrelated audio buffers. In each Process() call, the input -// buffer will be frist resampled, and framed as fixed-sized, possibly +// buffer will be first resampled, and framed as fixed-sized, possibly // overlapping tensors. 
The last tensor produced by a Process() invocation // will be zero-padding if the remaining samples are insufficient. As the // calculator treats the input packets as unrelated, all samples will be @@ -159,7 +159,7 @@ class AudioToTensorCalculator : public Node { public: static constexpr Input kAudioIn{"AUDIO"}; // TODO: Removes this optional input stream when the "AUDIO" stream - // uses the new mediapipe audio data containers that carry audio metatdata, + // uses the new mediapipe audio data containers that carry audio metadata, // such as sample rate. static constexpr Input::Optional kAudioSampleRateIn{"SAMPLE_RATE"}; static constexpr Output> kTensorsOut{"TENSORS"}; diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index 948c82a36..a49825586 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -37,7 +37,7 @@ message AudioToTensorCalculatorOptions { // will be converted into tensors. optional double target_sample_rate = 4; - // Whether to treat the input audio stream as a continous stream or a batch + // Whether to treat the input audio stream as a continuous stream or a batch // of unrelated audio buffers. optional bool stream_mode = 5 [default = true]; diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index 171b28eb4..924df6af3 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -82,7 +82,7 @@ namespace api2 { // // Outputs: // TENSORS - std::vector -// Vector containing a single Tensor populated with an extrated RGB image. +// Vector containing a single Tensor populated with an extracted RGB image. 
// MATRIX - std::array @Optional // An std::array representing a 4x4 row-major-order matrix that // maps a point on the input image to a point on the output tensor, and @@ -212,7 +212,7 @@ class ImageToTensorCalculator : public Node { std::array matrix; GetRotatedSubRectToRectTransformMatrix( roi, image->width(), image->height(), - /*flip_horizontaly=*/false, &matrix); + /*flip_horizontally=*/false, &matrix); kOutMatrix(cc).Send(std::move(matrix)); } diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 7017c1e3a..51150a1ca 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -206,7 +206,7 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { } else if (image_channels == 1) { return ImageFormat::GRAY8; } - ABSL_CHECK(false) << "Unsupported input image channles: " << image_channels; + ABSL_CHECK(false) << "Unsupported input image channels: " << image_channels; } Packet MakeImageFramePacket(cv::Mat input) { diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index b32b67869..04b791bd4 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -57,7 +57,7 @@ class SubRectExtractorGl { absl::Status ExtractSubRectToBuffer( const tflite::gpu::gl::GlTexture& texture, const tflite::gpu::HW& texture_size, const RotatedRect& sub_rect, - bool flip_horizontaly, float alpha, float beta, + bool flip_horizontally, float alpha, float beta, const tflite::gpu::HW& destination_size, tflite::gpu::gl::CommandQueue* command_queue, tflite::gpu::gl::GlBuffer* destination); @@ -154,13 +154,13 @@ void main() { absl::Status SubRectExtractorGl::ExtractSubRectToBuffer( const tflite::gpu::gl::GlTexture& texture, const tflite::gpu::HW& texture_size, const RotatedRect& texture_sub_rect, - bool flip_horizontaly, float alpha, float beta, + bool flip_horizontally, float alpha, float beta, const tflite::gpu::HW& destination_size, tflite::gpu::gl::CommandQueue* command_queue, tflite::gpu::gl::GlBuffer* destination) { std::array transform_mat; GetRotatedSubRectToRectTransformMatrix(texture_sub_rect, texture_size.w, - texture_size.h, flip_horizontaly, + texture_size.h, flip_horizontally, &transform_mat); MP_RETURN_IF_ERROR(texture.BindAsSampler2D(0)); @@ -308,7 +308,7 @@ class GlProcessor : public ImageToTensorConverter { input_texture, tflite::gpu::HW(source_texture.height(), source_texture.width()), roi, - /*flip_horizontaly=*/false, transform.scale, transform.offset, + /*flip_horizontally=*/false, transform.scale, transform.offset, tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]), command_queue_.get(), &output)); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 2522cae85..930d9fe21 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -199,7 +199,7 @@ class GlProcessor : public ImageToTensorConverter { range_min, range_max)); auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView(); MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi, - /*flip_horizontaly=*/false, + /*flip_horizontally=*/false, 
transform.scale, transform.offset, output_shape, &tensor_view)); return absl::OkStatus(); @@ -210,7 +210,7 @@ class GlProcessor : public ImageToTensorConverter { absl::Status ExtractSubRect(const mediapipe::GlTexture& texture, const RotatedRect& sub_rect, - bool flip_horizontaly, float alpha, float beta, + bool flip_horizontally, float alpha, float beta, const Tensor::Shape& output_shape, Tensor::OpenGlTexture2dView* output) { const int output_height = output_shape.dims[1]; @@ -263,13 +263,13 @@ class GlProcessor : public ImageToTensorConverter { ABSL_LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread."; if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) { GetTransposedRotatedSubRectToRectTransformMatrix( - sub_rect, texture.width(), texture.height(), flip_horizontaly, + sub_rect, texture.width(), texture.height(), flip_horizontally, &transform_mat); glUniformMatrix4fv(matrix_id_, 1, GL_FALSE, transform_mat.data()); } else { GetRotatedSubRectToRectTransformMatrix(sub_rect, texture.width(), - texture.height(), flip_horizontaly, - &transform_mat); + texture.height(), + flip_horizontally, &transform_mat); glUniformMatrix4fv(matrix_id_, 1, GL_TRUE, transform_mat.data()); } diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index cef2abcd7..f47d2da9a 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -179,13 +179,13 @@ class SubRectExtractorMetal { } absl::Status Execute(id input_texture, - const RotatedRect& sub_rect, bool flip_horizontaly, + const RotatedRect& sub_rect, bool flip_horizontally, float alpha, float beta, const tflite::gpu::HW& destination_size, id command_buffer, id destination) { auto output_texture = MTLTextureWithBuffer(destination_size, destination); - return InternalExecute(input_texture, sub_rect, flip_horizontaly, alpha, + return InternalExecute(input_texture, sub_rect, flip_horizontally, alpha, beta, destination_size, command_buffer, output_texture); } @@ -211,7 +211,7 @@ class SubRectExtractorMetal { absl::Status InternalExecute(id input_texture, const RotatedRect& sub_rect, - bool flip_horizontaly, float alpha, float beta, + bool flip_horizontally, float alpha, float beta, const tflite::gpu::HW& destination_size, id command_buffer, id output_texture) { @@ -223,7 +223,7 @@ class SubRectExtractorMetal { std::array transform_mat; GetRotatedSubRectToRectTransformMatrix(sub_rect, input_texture.width, input_texture.height, - flip_horizontaly, &transform_mat); + flip_horizontally, &transform_mat); id transform_mat_buffer = [device_ newBufferWithBytes:&transform_mat length:sizeof(transform_mat) @@ -383,7 +383,7 @@ class MetalProcessor : public ImageToTensorConverter { MtlBufferView::GetWriteView(output_tensor, command_buffer); MP_RETURN_IF_ERROR(extractor_->Execute( texture, roi, - /*flip_horizontaly=*/false, transform.scale, transform.offset, + /*flip_horizontally=*/false, transform.scale, transform.offset, tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]), command_buffer, buffer_view.buffer())); [command_buffer commit]; diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc index 3f91f3dc2..b6ed5216c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -92,7 +92,7 @@ absl::StatusOr 
GetValueRangeTransformation( void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, int rect_width, int rect_height, - bool flip_horizontaly, + bool flip_horizontally, std::array* matrix_ptr) { std::array& matrix = *matrix_ptr; // The resulting matrix is multiplication of below commented out matrices: @@ -118,7 +118,7 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, // {0.0f, 0.0f, a, 0.0f} // {0.0f, 0.0f, 0.0f, 1.0f} - const float flip = flip_horizontaly ? -1 : 1; + const float flip = flip_horizontally ? -1 : 1; // Matrix for optional horizontal flip around middle of output image. // { fl , 0.0f, 0.0f, 0.0f} // { 0.0f, 1.0f, 0.0f, 0.0f} @@ -177,13 +177,13 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, void GetTransposedRotatedSubRectToRectTransformMatrix( const RotatedRect& sub_rect, int rect_width, int rect_height, - bool flip_horizontaly, std::array* matrix_ptr) { + bool flip_horizontally, std::array* matrix_ptr) { std::array& matrix = *matrix_ptr; // See comments in GetRotatedSubRectToRectTransformMatrix for detailed // calculations. const float a = sub_rect.width; const float b = sub_rect.height; - const float flip = flip_horizontaly ? -1 : 1; + const float flip = flip_horizontally ? -1 : 1; const float c = std::cos(sub_rect.rotation); const float d = std::sin(sub_rect.rotation); const float e = sub_rect.center_x; diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.h b/mediapipe/calculators/tensor/image_to_tensor_utils.h index a73529dce..63810923d 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.h +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.h @@ -74,7 +74,7 @@ absl::StatusOr> PadRoi(int input_tensor_width, // Represents a transformation of value which involves scaling and offsetting. // To apply transformation: // ValueTransformation transform = ... -// float transformed_value = transform.scale * value + transfrom.offset; +// float transformed_value = transform.scale * value + transform.offset; struct ValueTransformation { float scale; float offset; @@ -99,11 +99,11 @@ absl::StatusOr GetValueRangeTransformation( // @sub_rect - rotated sub rect in absolute coordinates // @rect_width - rect width // @rect_height - rect height -// @flip_horizontaly - we need to flip the output buffer. +// @flip_horizontally - we need to flip the output buffer. // @matrix - 4x4 matrix (array of 16 elements) to populate void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, int rect_width, int rect_height, - bool flip_horizontaly, + bool flip_horizontally, std::array* matrix); // Returns the transpose of the matrix found with @@ -118,11 +118,11 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, // @sub_rect - rotated sub rect in absolute coordinates // @rect_width - rect width // @rect_height - rect height -// @flip_horizontaly - we need to flip the output buffer. +// @flip_horizontally - we need to flip the output buffer. // @matrix - 4x4 matrix (array of 16 elements) to populate void GetTransposedRotatedSubRectToRectTransformMatrix( const RotatedRect& sub_rect, int rect_width, int rect_height, - bool flip_horizontaly, std::array* matrix); + bool flip_horizontally, std::array* matrix); // Validates the output dimensions set in the option proto. 
The input option // proto is expected to have to following fields: diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.proto b/mediapipe/calculators/tensor/tensor_converter_calculator.proto index 2c5e0be56..b80d1e805 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.proto +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.proto @@ -32,7 +32,7 @@ message TensorConverterCalculatorOptions { // Custom settings to override the internal scaling factors `div` and `sub`. // Both values must be set to non-negative values. Will only take effect on // CPU AND when |use_custom_normalization| is set to true. When these custom - // values take effect, the |zero_center| setting above will be overriden, and + // values take effect, the |zero_center| setting above will be overridden, and // the normalized_value will be calculated as: // normalized_value = input / custom_div - custom_sub. optional bool use_custom_normalization = 6 [default = false]; diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto index 32bc4b63a..28012a455 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto @@ -34,7 +34,7 @@ message TensorsToClassificationCalculatorOptions { repeated Entry entries = 1; } - // Score threshold for perserving the class. + // Score threshold for preserving the class. optional float min_score_threshold = 1; // Number of highest scoring labels to output. If top_k is not positive then // all labels are used. diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 8e649c0a1..2b4a22fc6 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -15,7 +15,6 @@ #include #include -#include "absl/log/absl_log.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" @@ -147,7 +146,7 @@ BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) { // TENSORS - Vector of Tensors of type kFloat32. The vector of tensors can have // 2 or 3 tensors. First tensor is the predicted raw boxes/keypoints. // The size of the values must be (num_boxes * num_predicted_values). -// Second tensor is the score tensor. The size of the valuse must be +// Second tensor is the score tensor. The size of the values must be // (num_boxes * num_classes). It's optional to pass in a third tensor // for anchors (e.g. for SSD models) depend on the outputs of the // detection model. 
The size of anchor tensor must be (num_boxes * @@ -215,7 +214,8 @@ class TensorsToDetectionsCalculator : public Node { const int* detection_classes, std::vector* output_detections); Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, - float box_xmax, float score, int class_id, + float box_xmax, absl::Span scores, + absl::Span class_ids, bool flip_vertically); bool IsClassIndexAllowed(int class_index); @@ -223,6 +223,7 @@ class TensorsToDetectionsCalculator : public Node { int num_boxes_ = 0; int num_coords_ = 0; int max_results_ = -1; + int classes_per_detection_ = 1; BoxFormat box_output_format_ = mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; @@ -267,7 +268,7 @@ absl::Status TensorsToDetectionsCalculator::UpdateContract( if (CanUseGpu()) { #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract( - cc, /*requesst_gpu_as_optional=*/true)); + cc, /*request_gpu_as_optional=*/true)); #elif MEDIAPIPE_METAL_ENABLED MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) @@ -484,6 +485,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( auto num_boxes_view = num_boxes_tensor->GetCpuReadView(); auto num_boxes = num_boxes_view.buffer(); num_boxes_ = num_boxes[0]; + // The detection model with Detection_PostProcess op may output duplicate + // boxes with different classes, in the following format: + // num_boxes_tensor = [num_boxes] + // detection_classes_tensor = [box_1_class_1, box_1_class_2, ...] + // detection_scores_tensor = [box_1_score_1, box_1_score_2, ... ] + // detection_boxes_tensor = [box_1, box1, ... ] + // Each box repeats classes_per_detection_ times. + // Note Detection_PostProcess op is only supported in CPU. 
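// Worked example with illustrative numbers: if max_detections = 6 and
// num_boxes_ = 3, then classes_per_detection_ computed below is 6 / 3 = 2 and
// the tensors are laid out as
//   detection_classes = [b1_c1, b1_c2, b2_c1, b2_c2, b3_c1, b3_c2]
//   detection_scores  = [b1_s1, b1_s2, b2_s1, b2_s2, b3_s1, b3_s2]
//   detection_boxes   = [b1,    b1,    b2,    b2,    b3,    b3   ]
// ConvertToDetections() then walks the boxes in strides of
// classes_per_detection_, handing ConvertToDetection() a span of two scores
// and two class ids per box.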
+ RET_CHECK_EQ(max_detections % num_boxes_, 0); + classes_per_detection_ = max_detections / num_boxes_; auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView(); auto detection_boxes = detection_boxes_view.buffer(); @@ -493,8 +504,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( auto detection_classes_view = detection_classes_tensor->GetCpuReadView(); auto detection_classes_ptr = detection_classes_view.buffer(); - std::vector detection_classes(num_boxes_); - for (int i = 0; i < num_boxes_; ++i) { + std::vector detection_classes(num_boxes_ * classes_per_detection_); + for (int i = 0; i < detection_classes.size(); ++i) { detection_classes[i] = static_cast(detection_classes_ptr[i]); } MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores, @@ -863,24 +874,25 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes( absl::Status TensorsToDetectionsCalculator::ConvertToDetections( const float* detection_boxes, const float* detection_scores, const int* detection_classes, std::vector* output_detections) { - for (int i = 0; i < num_boxes_; ++i) { + for (int i = 0; i < num_boxes_ * classes_per_detection_; + i += classes_per_detection_) { if (max_results_ > 0 && output_detections->size() == max_results_) { break; } - if (options_.has_min_score_thresh() && - detection_scores[i] < options_.min_score_thresh()) { - continue; - } - if (!IsClassIndexAllowed(detection_classes[i])) { - continue; - } const int box_offset = i * num_coords_; Detection detection = ConvertToDetection( /*box_ymin=*/detection_boxes[box_offset + box_indices_[0]], /*box_xmin=*/detection_boxes[box_offset + box_indices_[1]], /*box_ymax=*/detection_boxes[box_offset + box_indices_[2]], /*box_xmax=*/detection_boxes[box_offset + box_indices_[3]], - detection_scores[i], detection_classes[i], options_.flip_vertically()); + absl::MakeConstSpan(detection_scores + i, classes_per_detection_), + absl::MakeConstSpan(detection_classes + i, classes_per_detection_), + options_.flip_vertically()); + // if all the scores and classes are filtered out, we skip the empty + // detection. 
+ if (detection.score().empty()) { + continue; + } const auto& bbox = detection.location_data().relative_bounding_box(); if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || std::isnan(bbox.height())) { @@ -910,11 +922,21 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections( } Detection TensorsToDetectionsCalculator::ConvertToDetection( - float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, - int class_id, bool flip_vertically) { + float box_ymin, float box_xmin, float box_ymax, float box_xmax, + absl::Span scores, absl::Span class_ids, + bool flip_vertically) { Detection detection; - detection.add_score(score); - detection.add_label_id(class_id); + for (int i = 0; i < scores.size(); ++i) { + if (!IsClassIndexAllowed(class_ids[i])) { + continue; + } + if (options_.has_min_score_thresh() && + scores[i] < options_.min_score_thresh()) { + continue; + } + detection.add_score(scores[i]); + detection.add_label_id(class_ids[i]); + } LocationData* location_data = detection.mutable_location_data(); location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto index 5cedff6c7..49db8e3e7 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto @@ -75,7 +75,7 @@ message TensorsToDetectionsCalculatorOptions { // representation has a bottom-left origin (e.g., in OpenGL). optional bool flip_vertically = 18 [default = false]; - // Score threshold for perserving decoded detections. + // Score threshold for preserving decoded detections. optional float min_score_thresh = 19; // The maximum number of the detection results to return. 
If < 0, all diff --git a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc index 5942f234d..77488443f 100644 --- a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc @@ -124,7 +124,7 @@ absl::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) { kFlipVertically(cc).IsConnected())) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input width/height for using flipping when outputing " + << "Must provide input width/height for using flipping when outputting " "landmarks in absolute coordinates."; } return absl::OkStatus(); diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 6456126ae..24fd1bd52 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -208,7 +208,7 @@ absl::Status TensorsToSegmentationCalculator::GetContract( if (CanUseGpu()) { #if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract( - cc, /*requesst_gpu_as_optional=*/true)); + cc, /*request_gpu_as_optional=*/true)); #if MEDIAPIPE_METAL_ENABLED MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #endif // MEDIAPIPE_METAL_ENABLED diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator_test.cc index 3db9145d2..e5c6b8ade 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator_test.cc @@ -60,42 +60,38 @@ struct FormattingTestCase { std::vector inputs; std::vector expected_outputs; Options::Activation activation; - int rows; - int cols; - int channels; + int rows = 1; + int cols = 1; + int rows_new = 1; + int cols_new = 1; + int channels = 1; + double max_abs_diff = 1e-7; }; using TensorsToSegmentationCalculatorTest = TestWithParam; -// Currently only useable for tests with no output resize. 
TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) { - const FormattingTestCase& test_case = GetParam(); - std::vector inputs = test_case.inputs; - std::vector expected_outputs = test_case.expected_outputs; - Options::Activation activation = test_case.activation; - int rows = test_case.rows; - int cols = test_case.cols; - int channels = test_case.channels; + const auto& [test_name, inputs, expected_outputs, activation, rows, cols, + rows_new, cols_new, channels, max_abs_diff] = GetParam(); - std::string string_config = absl::Substitute( - R"pb( - input_stream: "tensors" - input_stream: "size" - node { - calculator: "TensorsToSegmentationCalculator" - input_stream: "TENSORS:tensors" - input_stream: "OUTPUT_SIZE:size" - output_stream: "MASK:image_as_mask" - options: { - [mediapipe.TensorsToSegmentationCalculatorOptions.ext] { - activation: $0 - } - } - } - )pb", - ActivationTypeToString(activation)); auto graph_config = - mediapipe::ParseTextProtoOrDie(string_config); + mediapipe::ParseTextProtoOrDie(absl::Substitute( + R"pb( + input_stream: "tensors" + input_stream: "size" + node { + calculator: "TensorsToSegmentationCalculator" + input_stream: "TENSORS:tensors" + input_stream: "OUTPUT_SIZE:size" + output_stream: "MASK:image_as_mask" + options: { + [mediapipe.TensorsToSegmentationCalculatorOptions.ext] { + activation: $0 + } + } + } + )pb", + ActivationTypeToString(activation))); std::vector output_packets; tool::AddVectorSink("image_as_mask", &graph_config, &output_packets); @@ -119,28 +115,34 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) { MP_ASSERT_OK(graph.AddPacketToInputStream( "tensors", mediapipe::Adopt(tensors.release()).At(Timestamp(0)))); } + + // The output size is defined as pair(new_width, new_height). MP_ASSERT_OK(graph.AddPacketToInputStream( - "size", - mediapipe::Adopt(new std::pair(rows, cols)).At(Timestamp(0)))); + "size", mediapipe::Adopt(new std::pair(cols_new, rows_new)) + .At(Timestamp(0)))); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_THAT(output_packets, SizeIs(1)); const Image& image_as_mask = output_packets[0].Get(); + EXPECT_FALSE(image_as_mask.UsesGpu()); + std::shared_ptr result_mat = formats::MatView(&image_as_mask); - EXPECT_EQ(result_mat->rows, rows); - EXPECT_EQ(result_mat->cols, cols); - EXPECT_EQ(result_mat->channels(), channels); + EXPECT_EQ(result_mat->rows, rows_new); + EXPECT_EQ(result_mat->cols, cols_new); + EXPECT_EQ(result_mat->channels(), 1); // Compare the real result with the expected result. - cv::Mat expected_result = cv::Mat( - rows, cols, CV_32FC1, const_cast(expected_outputs.data())); + cv::Mat expected_result = + cv::Mat(rows_new, cols_new, CV_32FC1, + const_cast(expected_outputs.data())); cv::Mat diff; cv::absdiff(*result_mat, expected_result, diff); double max_val; cv::minMaxLoc(diff, nullptr, &max_val); - // Expects the maximum absolute pixel-by-pixel difference is less than 1e-5. - // This delta is for passthorugh accuracy only. - EXPECT_LE(max_val, 1e-5); + + // The max allowable diff between output and expected output varies between + // tests. 
+ EXPECT_LE(max_val, max_abs_diff); MP_ASSERT_OK(graph.CloseInputStream("tensors")); MP_ASSERT_OK(graph.CloseInputStream("size")); @@ -150,17 +152,96 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) { INSTANTIATE_TEST_SUITE_P( TensorsToSegmentationCalculatorTests, TensorsToSegmentationCalculatorTest, testing::ValuesIn({ - {/*test_name=*/"NoActivationAndNoOutputResize", - /*inputs=*/ - {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, - 14.0, 15.0, 16.0}, - /*expected_outputs=*/ - {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, - 14.0, 15.0, 16.0}, - /*activation=*/Options::NONE, - /*rows=*/4, - /*cols=*/4, - /*channels=*/1}, + {.test_name = "NoActivationAndNoOutputResize", + .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, 16.0}, + .expected_outputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, + 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + .activation = Options::NONE, + .rows = 4, + .cols = 4, + .rows_new = 4, + .cols_new = 4, + .channels = 1, + .max_abs_diff = 1e-7}, + {.test_name = "OutputResizeOnly", + .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, 16.0}, + .expected_outputs = {1, 1.5, 2.166667, 2.833333, 3.5, 4, + 3.8, 4.3, 4.966667, 5.633333, 6.3, 6.8, + 7, 7.5, 8.166667, 8.833333, 9.5, 10, + 10.2, 10.7, 11.366667, 12.033333, 12.7, 13.2, + 13, 13.5, 14.166667, 14.833333, 15.5, 16}, + .activation = Options::NONE, + .rows = 4, + .cols = 4, + .rows_new = 5, + .cols_new = 6, + .channels = 1, + .max_abs_diff = 1e-6}, + {.test_name = "SigmoidActivationWithNoOutputResize", + .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, 16.0}, + .expected_outputs = {0.731059, 0.880797, 0.952574, 0.982014, 0.993307, + 0.997527, 0.999089, 0.999665, 0.999877, 0.999955, + 0.999983, 0.999994, 0.999998, 0.999999, 1.0, 1.0}, + .activation = Options::SIGMOID, + .rows = 4, + .cols = 4, + .rows_new = 4, + .cols_new = 4, + .channels = 1, + .max_abs_diff = 1e-6}, + {.test_name = "SigmoidActivationWithOutputResize", + .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, 16.0}, + .expected_outputs = {0.731059, 0.805928, 0.89276, 0.940611, 0.967294, + 0.982014, 0.914633, 0.93857, 0.966279, 0.981363, + 0.989752, 0.994369, 0.996592, 0.997666, 0.998873, + 0.999404, 0.999683, 0.999829, 0.999913, 0.99994, + 0.999971, 0.999985, 0.999992, 0.999996, 0.999998, + 0.999998, 0.999999, 1.0, 1.0, 1.0}, + .activation = Options::SIGMOID, + .rows = 4, + .cols = 4, + .rows_new = 5, + .cols_new = 6, + .channels = 1, + .max_abs_diff = 1e-6}, + {.test_name = "SoftmaxActivationWithNoOutputResize", + .inputs = {1.0, 2.0, 4.0, 2.0, 3.0, 5.0, 6.0, 1.5, + 7.0, 10.0, 11.0, 4.0, 12.0, 15.0, 16.0, 18.5, + 19.0, 20.0, 22.0, 23.0, 24.5, 23.4, 25.6, 28.3, + 29.2, 30.0, 24.6, 29.2, 30.0, 24.9, 31.2, 30.3}, + .expected_outputs = {0.731059, 0.119203, 0.880797, 0.0109869, 0.952574, + 0.000911051, 0.952574, 0.924142, 0.731059, + 0.731059, 0.24974, 0.937027, 0.689974, 0.990048, + 0.0060598, 0.28905}, + .activation = Options::SOFTMAX, + .rows = 4, + .cols = 4, + .rows_new = 4, + .cols_new = 4, + .channels = 2, + .max_abs_diff = 1e-6}, + {.test_name = "SoftmaxActivationWithOutputResize", + .inputs = {1.0, 2.0, 4.0, 2.0, 3.0, 5.0, 6.0, 1.5, + 7.0, 10.0, 11.0, 4.0, 12.0, 15.0, 16.0, 18.5, + 19.0, 20.0, 22.0, 23.0, 24.5, 23.4, 25.6, 28.3, + 29.2, 30.0, 24.6, 29.2, 30.0, 24.9, 31.2, 30.3}, + .expected_outputs = {0.731059, 
0.425131, 0.246135, 0.753865, 0.445892, + 0.0109869, 0.886119, 0.461259, 0.185506, 0.781934, + 0.790618, 0.650195, 0.841816, 0.603901, 0.40518, + 0.561962, 0.765871, 0.930584, 0.718733, 0.763744, + 0.703402, 0.281989, 0.459635, 0.742634, 0.689974, + 0.840011, 0.82605, 0.170058, 0.147555, 0.28905}, + .activation = Options::SOFTMAX, + .rows = 4, + .cols = 4, + .rows_new = 5, + .cols_new = 6, + .channels = 2, + .max_abs_diff = 1e-6}, }), [](const testing::TestParamInfo< TensorsToSegmentationCalculatorTest::ParamType>& info) { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 4972b202d..95962c261 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -79,7 +79,7 @@ namespace mpms = mediapipe::mediasequence; // and label and label_id are optional but at least one of them should be set. // "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store // prefixed versions of each stream, which allows for multiple image streams to -// be included. However, the default names are suppored by more tools. +// be included. However, the default names are supported by more tools. // // Example config: // node { diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc index dc3d97844..ed234b3fa 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc @@ -67,8 +67,8 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, // -- 1-D or 2-D Tensor // Output: // -- Matrix with the same values as the Tensor -// If input tensor is 1 dimensional, the ouput Matrix is of (1xn) shape. -// If input tensor is 2 dimensional (batched), the ouput Matrix is (mxn) shape. +// If input tensor is 1 dimensional, the output Matrix is of (1xn) shape. +// If input tensor is 2 dimensional (batched), the output Matrix is (mxn) shape. // // Example Config // node: { diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index 84c32fed6..39993ada0 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -111,8 +111,8 @@ class InferenceState { // input_side_packet. // // The input and output streams are TensorFlow tensors labeled by tags. The tags -// for the streams are matched to feeds and fetchs in a TensorFlow session using -// a named_signature.generic_signature in the ModelManifest. The +// for the streams are matched to feeds and fetches in a TensorFlow session +// using a named_signature.generic_signature in the ModelManifest. The // generic_signature is used as key-value pairs between the MediaPipe tag and // the TensorFlow tensor. The signature_name in the options proto determines // which named_signature is used. The keys in the generic_signature must be @@ -128,7 +128,7 @@ class InferenceState { // addition. Once batch_size inputs have been provided, the batch will be run // and the output tensors sent out on the output streams with timestamps // corresponding to the input stream packets. Setting the batch_size to 1 -// completely disables batching, but is indepdent of add_batch_dim_to_tensors. 
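For the sigmoid test cases above, the expected_outputs are simply 1/(1+e^(-x)) applied to each input, which is where constants such as 0.731059 and 0.880797 come from. A quick check, under the assumption that the calculator applies the activation element-wise:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double inputs[] = {1.0, 2.0, 3.0, 4.0};
  for (double x : inputs) {
    // Matches the leading expected_outputs of the sigmoid cases:
    // 0.731059, 0.880797, 0.952574, 0.982014, ...
    std::printf("sigmoid(%.1f) = %.6f\n", x, 1.0 / (1.0 + std::exp(-x)));
  }
  return 0;
}
```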
+// completely disables batching, but is independent of add_batch_dim_to_tensors. // // The TensorFlowInferenceCalculator also support feeding states recurrently for // RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.proto index a243412c0..f09664592 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.proto @@ -42,7 +42,7 @@ message TensorFlowInferenceCalculatorOptions { // If the 0th dimension is the batch dimension, then the tensors are // concatenated on that dimension. If the 0th is a data dimension, then a 0th // dimension is added before concatenating. If added, the extra dimension is - // removed before outputing the tensor. Examples of each case: If you want + // removed before outputting the tensor. Examples of each case: If you want // to batch spectra of audio over time for an LSTM, a time-frequency // representation has a 0th dimension as the batch dimension. If you want to // batch frames of video that are [width, height, channels], the batch diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar index afba10928..7f93135c4 100644 Binary files a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar and b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar differ diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 4e86b9270..3fa8f862f 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/mediapipe/examples/android/solutions/gradlew b/mediapipe/examples/android/solutions/gradlew index 65dcd68d6..1aa94a426 100755 --- a/mediapipe/examples/android/solutions/gradlew +++ b/mediapipe/examples/android/solutions/gradlew @@ -83,10 +83,8 @@ done # This is normally unused # shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum @@ -133,10 +131,13 @@ location of your Java installation." fi else JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 
Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. @@ -144,7 +145,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then case $MAX_FD in #( max*) # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC2039,SC3045 MAX_FD=$( ulimit -H -n ) || warn "Could not query maximum file descriptor limit" esac @@ -152,7 +153,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then '' | soft) :;; #( *) # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC2039,SC3045 ulimit -n "$MAX_FD" || warn "Could not set maximum file descriptor limit to $MAX_FD" esac @@ -197,11 +198,15 @@ if "$cygwin" || "$msys" ; then done fi -# Collect all arguments for the java command; -# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of -# shell script including quotes and variable substitutions, so put them in -# double quotes to make sure that they get re-expanded; and -# * put everything else in single quotes, so that it's not re-expanded. + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. set -- \ "-Dorg.gradle.appname=$APP_BASE_NAME" \ diff --git a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc index dbabf84b1..f958924f0 100644 --- a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc +++ b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc @@ -56,7 +56,7 @@ absl::Status RunMPPGraph() { for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - RET_CHECK(!mediapipe::ContainsKey(input_side_packets, name_and_value[0])); + RET_CHECK(!input_side_packets.contains(name_and_value[0])); std::string input_side_packet_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( name_and_value[1], &input_side_packet_contents)); diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 77e3ab16d..65c8a15c8 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -616,6 +616,7 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/tool/sink_test.cc b/mediapipe/framework/tool/sink_test.cc index c5316af4d..9769aeeee 100644 --- a/mediapipe/framework/tool/sink_test.cc +++ b/mediapipe/framework/tool/sink_test.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/functional/bind_front.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" diff --git a/mediapipe/framework/tool/validate_name.cc 
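The extract_yt8m_features.cc change above swaps the legacy mediapipe::ContainsKey helper for the member contains() that absl associative containers (and C++20 std containers) provide. A minimal sketch, assuming a flat_hash_map keyed by side-packet name; in the example binary the mapped values are mediapipe::Packet objects, here simplified to strings:

```cpp
#include <iostream>
#include <string>

#include "absl/container/flat_hash_map.h"

int main() {
  absl::flat_hash_map<std::string, std::string> input_side_packets;
  input_side_packets["input_sequence_example"] = "<serialized bytes>";

  // Member lookup replaces mediapipe::ContainsKey(map, key).
  std::cout << input_side_packets.contains("input_sequence_example") << "\n";  // 1
  std::cout << input_side_packets.contains("unknown_key") << "\n";             // 0
  return 0;
}
```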
b/mediapipe/framework/tool/validate_name.cc index 8f9be7687..4415f76b5 100644 --- a/mediapipe/framework/tool/validate_name.cc +++ b/mediapipe/framework/tool/validate_name.cc @@ -134,7 +134,7 @@ absl::Status ParseTagAndName(absl::string_view tag_and_name, std::string* tag, RET_CHECK(name); absl::Status tag_status = absl::OkStatus(); absl::Status name_status = absl::UnknownError(""); - int name_index = 0; + int name_index = -1; std::vector v = absl::StrSplit(tag_and_name, ':'); if (v.size() == 1) { name_status = ValidateName(v[0]); @@ -143,7 +143,7 @@ absl::Status ParseTagAndName(absl::string_view tag_and_name, std::string* tag, tag_status = ValidateTag(v[0]); name_status = ValidateName(v[1]); name_index = 1; - } + } // else omitted, name_index == -1, triggering error. if (name_index == -1 || tag_status != absl::OkStatus() || name_status != absl::OkStatus()) { tag->clear(); diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 27770acaa..f39b8d3f7 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -516,6 +516,7 @@ cc_library( ":gpu_buffer_storage", ":image_frame_view", "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/strings:str_format", ], ) @@ -526,12 +527,14 @@ mediapipe_proto_library( visibility = ["//visibility:public"], ) -objc_library( +cc_library( name = "pixel_buffer_pool_util", - srcs = ["pixel_buffer_pool_util.mm"], + srcs = ["pixel_buffer_pool_util.cc"], hdrs = ["pixel_buffer_pool_util.h"], copts = [ + "-x objective-c++", "-Wno-shorten-64-to-32", + "-fobjc-arc", # enable reference-counting ], visibility = ["//visibility:public"], deps = [ @@ -542,13 +545,14 @@ objc_library( ], ) -objc_library( +cc_library( name = "metal_shared_resources", - srcs = ["metal_shared_resources.mm"], + srcs = ["metal_shared_resources.cc"], hdrs = ["metal_shared_resources.h"], copts = [ "-x objective-c++", "-Wno-shorten-64-to-32", + "-fobjc-arc", # enable reference-counting ], features = ["-layering_check"], visibility = ["//visibility:public"], @@ -557,15 +561,17 @@ objc_library( "@google_toolbox_for_mac//:GTM_Defines", ] + [ ], + alwayslink = 1, ) -objc_library( +cc_library( name = "MPPMetalUtil", - srcs = ["MPPMetalUtil.mm"], + srcs = ["MPPMetalUtil.cc"], hdrs = ["MPPMetalUtil.h"], copts = [ "-x objective-c++", "-Wno-shorten-64-to-32", + "-fobjc-arc", # enable reference-counting ], visibility = ["//visibility:public"], deps = [ @@ -575,6 +581,7 @@ objc_library( "@com_google_absl//absl/time", "@google_toolbox_for_mac//:GTM_Defines", ], + alwayslink = 1, ) mediapipe_proto_library( @@ -857,12 +864,14 @@ cc_library( }), ) -objc_library( +cc_library( name = "MPPMetalHelper", - srcs = ["MPPMetalHelper.mm"], + srcs = ["MPPMetalHelper.cc"], hdrs = ["MPPMetalHelper.h"], copts = [ "-Wno-shorten-64-to-32", + "-x objective-c++", + "-fobjc-arc", ], features = ["-layering_check"], visibility = ["//visibility:public"], @@ -1215,9 +1224,13 @@ mediapipe_cc_test( ], requires_full_emulation = True, deps = [ + ":gl_texture_buffer", + ":gl_texture_util", ":gpu_buffer_format", ":gpu_buffer_storage_ahwb", + ":gpu_test_base", "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:test_util", ], ) diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.cc similarity index 74% rename from mediapipe/gpu/MPPMetalHelper.mm rename to mediapipe/gpu/MPPMetalHelper.cc index c66483698..e92d6aae7 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.cc @@ -14,15 +14,14 @@ #import 
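The validate_name.cc change above makes the "neither one nor two parts" case explicit: name_index starts at -1 and is only set when the input splits into one piece ("name") or two ("TAG:name"), so anything else (e.g. "A:B:C") falls through to the error branch. A rough standalone sketch of that control flow, with the ValidateTag()/ValidateName() checks stubbed out:

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_split.h"

// Stand-in for ParseTagAndName(); the real code also validates tag and name.
bool ParseTagAndNameSketch(const std::string& tag_and_name, std::string* tag,
                           std::string* name) {
  int name_index = -1;  // stays -1 unless we see exactly 1 or 2 parts
  std::vector<std::string> v = absl::StrSplit(tag_and_name, ':');
  if (v.size() == 1) {
    name_index = 0;
  } else if (v.size() == 2) {
    *tag = v[0];
    name_index = 1;
  }  // else omitted: name_index == -1 triggers the error path below.
  if (name_index == -1) {
    tag->clear();
    name->clear();
    return false;
  }
  *name = v[name_index];
  return true;
}

int main() {
  std::string tag, name;
  std::cout << ParseTagAndNameSketch("VIDEO:input_frames", &tag, &name) << "\n";  // 1
  std::cout << ParseTagAndNameSketch("A:B:C", &tag, &name) << "\n";               // 0
  return 0;
}
```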
"mediapipe/gpu/MPPMetalHelper.h" +#import "GTMDefines.h" #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" +#include "mediapipe/framework/port/ret_check.h" #import "mediapipe/gpu/gpu_buffer.h" #import "mediapipe/gpu/gpu_service.h" #import "mediapipe/gpu/graph_support.h" #import "mediapipe/gpu/metal_shared_resources.h" -#import "GTMDefines.h" - -#include "mediapipe/framework/port/ret_check.h" @interface MPPMetalHelper () { mediapipe::GpuResources* _gpuResources; @@ -31,7 +30,8 @@ namespace mediapipe { -// Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport. +// Using a C++ class so it can be declared as a friend of +// LegacyCalculatorSupport. class MetalHelperLegacySupport { public: static CalculatorContract* GetCalculatorContract() { @@ -61,7 +61,8 @@ class MetalHelperLegacySupport { - (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc { if (!cc) return nil; - return [self initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()]; + return [self + initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()]; } + (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc { @@ -77,7 +78,8 @@ class MetalHelperLegacySupport { } // Legacy support. -- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets { +- (instancetype)initWithSidePackets: + (const mediapipe::PacketSet&)inputSidePackets { auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContext(); if (cc) { ABSL_CHECK_EQ(&inputSidePackets, &cc->InputSidePackets()); @@ -85,16 +87,19 @@ class MetalHelperLegacySupport { } // TODO: remove when we can. - ABSL_LOG(WARNING) << "CalculatorContext not available. If this calculator uses " - "CalculatorBase, call initWithCalculatorContext instead."; + ABSL_LOG(WARNING) + << "CalculatorContext not available. If this calculator uses " + "CalculatorBase, call initWithCalculatorContext instead."; mediapipe::GpuSharedData* gpu_shared = - inputSidePackets.Tag(mediapipe::kGpuSharedTagName).Get(); + inputSidePackets.Tag(mediapipe::kGpuSharedTagName) + .Get(); return [self initWithGpuResources:gpu_shared->gpu_resources.get()]; } // Legacy support. -+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets { ++ (absl::Status)setupInputSidePackets: + (mediapipe::PacketTypeSet*)inputSidePackets { auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContract(); if (cc) { ABSL_CHECK_EQ(inputSidePackets, &cc->InputSidePackets()); @@ -102,12 +107,12 @@ class MetalHelperLegacySupport { } // TODO: remove when we can. - ABSL_LOG(WARNING) << "CalculatorContract not available. If you're calling this " - "from a GetContract method, call updateContract instead."; + ABSL_LOG(WARNING) + << "CalculatorContract not available. 
If you're calling this " + "from a GetContract method, call updateContract instead."; auto id = inputSidePackets->GetId(mediapipe::kGpuSharedTagName, 0); - RET_CHECK(id.IsValid()) - << "A " << mediapipe::kGpuSharedTagName - << " input side packet is required here."; + RET_CHECK(id.IsValid()) << "A " << mediapipe::kGpuSharedTagName + << " input side packet is required here."; inputSidePackets->Get(id).Set(); return absl::OkStatus(); } @@ -125,10 +130,12 @@ class MetalHelperLegacySupport { } - (id)commandBuffer { - return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; + return + [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; } -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer: + (const mediapipe::GpuBuffer&)gpuBuffer plane:(size_t)plane { CVPixelBufferRef pixel_buffer = mediapipe::GetCVPixelBufferRef(gpuBuffer); OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); @@ -178,41 +185,48 @@ class MetalHelperLegacySupport { CVMetalTextureRef texture; CVReturn err = CVMetalTextureCacheCreateTextureFromImage( NULL, _gpuResources->metal_shared().resources().mtlTextureCache, - mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, - &texture); + mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, + height, plane, &texture); ABSL_CHECK_EQ(err, kCVReturnSuccess); return texture; } -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer { +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer: + (const mediapipe::GpuBuffer&)gpuBuffer { return [self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:0]; } -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer { +- (id)metalTextureWithGpuBuffer: + (const mediapipe::GpuBuffer&)gpuBuffer { return [self metalTextureWithGpuBuffer:gpuBuffer plane:0]; } -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer - plane:(size_t)plane { +- (id)metalTextureWithGpuBuffer: + (const mediapipe::GpuBuffer&)gpuBuffer + plane:(size_t)plane { CFHolder cvTexture; cvTexture.adopt([self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:plane]); return CVMetalTextureGetTexture(*cvTexture); } -- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { +- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width + height:(int)height { return _gpuResources->gpu_buffer_pool().GetBuffer(width, height); } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width - height:(int)height - format:(mediapipe::GpuBufferFormat)format { + height:(int)height + format:(mediapipe::GpuBufferFormat) + format { return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format); } -- (id)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { +- (id)newLibraryWithResourceName:(NSString*)name + error:(NSError* _Nullable*)error { return [_gpuResources->metal_shared().resources().mtlDevice - newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name - ofType:@"metallib"] + newLibraryWithFile:[[NSBundle bundleForClass:[self class]] + pathForResource:name + ofType:@"metallib"] error:error]; } diff --git a/mediapipe/gpu/MPPMetalUtil.mm b/mediapipe/gpu/MPPMetalUtil.cc similarity index 95% rename from mediapipe/gpu/MPPMetalUtil.mm rename to mediapipe/gpu/MPPMetalUtil.cc index ba8be0dbd..c9bd6798d 100644 --- a/mediapipe/gpu/MPPMetalUtil.mm +++ 
b/mediapipe/gpu/MPPMetalUtil.cc @@ -69,10 +69,10 @@ while (!bufferCompleted) { auto duration = absl::Now() - start_time; // If the spin-lock takes more than 5 ms then go to blocking wait: - // - it frees the CPU core for another threads: increase the performance/decrease power - // consumption. - // - if a driver thread that notifies that the GPU buffer is completed has lower priority then - // the CPU core is allocated for the thread. + // - it frees the CPU core for another threads: increase the + // performance/decrease power consumption. + // - if a driver thread that notifies that the GPU buffer is completed has + // lower priority then the CPU core is allocated for the thread. if (duration >= absl::Milliseconds(5)) { [commandBuffer waitUntilCompleted]; break; diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index 1b113f8ac..20f155e15 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -57,8 +57,8 @@ void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { // static absl::Status GlCalculatorHelper::UpdateContract(CalculatorContract* cc, - bool requesst_gpu_as_optional) { - if (requesst_gpu_as_optional) { + bool request_gpu_as_optional) { + if (request_gpu_as_optional) { cc->UseService(kGpuService).Optional(); } else { cc->UseService(kGpuService); diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index f5d98ebfe..45b25f67e 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -68,7 +68,7 @@ class GlCalculatorHelper { // This method can be called from GetContract to set up the needed GPU // resources. static absl::Status UpdateContract(CalculatorContract* cc, - bool requesst_gpu_as_optional = false); + bool request_gpu_as_optional = false); // This method can be called from FillExpectations to set the correct types // for the shared GL input side packet(s). diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 48afbd219..7e4694a0e 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -14,6 +14,8 @@ #include "mediapipe/gpu/gl_texture_buffer.h" +#include + #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" #include "mediapipe/framework/formats/image_frame.h" @@ -131,6 +133,13 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { SymbolAvailable(&glTexStorage2D)) { ABSL_CHECK(data == nullptr) << "unimplemented"; glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_); + } else if (info.immutable) { + ABSL_CHECK(SymbolAvailable(&glTexStorage2D) && + context->GetGlVersion() != GlVersion::kGLES2) + << "Immutable GpuBuffer format requested is not supported in this " + << "GlContext. 
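On the request_gpu_as_optional rename above (previously misspelled requesst_gpu_as_optional): a calculator that can run on either CPU or GPU typically calls the helper from GetContract with the flag set, so kGpuService is marked Optional() rather than required. A rough sketch under that assumption; the calculator name is hypothetical and Process() is left trivial:

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/gpu/gl_calculator_helper.h"

namespace mediapipe {

// Hypothetical calculator that only uses the GPU when it is available.
class MaybeGpuCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Requests the GPU service as optional, matching the renamed parameter.
    MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(
        cc, /*request_gpu_as_optional=*/true));
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    return absl::OkStatus();
  }
};

}  // namespace mediapipe
```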
Format was " << static_cast(format_); + ABSL_CHECK(data == nullptr) << "unimplemented"; + glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_); } else { glTexImage2D(target_, 0 /* level */, info.gl_internal_format, width_, height_, 0 /* border */, info.gl_format, info.gl_type, data); diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 646fb383f..510a9cd48 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -35,6 +35,10 @@ namespace mediapipe { #endif // GL_HALF_FLOAT_OES #endif // __EMSCRIPTEN__ +#ifndef GL_RGBA8 +#define GL_RGBA8 0x8058 +#endif // GL_RGBA8 + #if !MEDIAPIPE_DISABLE_GPU #ifdef GL_ES_VERSION_2_0 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { @@ -163,6 +167,14 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, { {GL_RGBA32F, GL_RGBA, GL_FLOAT, 1}, }}, + {GpuBufferFormat::kImmutableRGBAFloat128, + { + {GL_RGBA32F, GL_RGBA, GL_FLOAT, 1, true /* immutable */}, + }}, + {GpuBufferFormat::kImmutableRGBA32, + { + {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, 1, true /* immutable */}, + }}, }}; static const auto* gles2_format_info = ([] { @@ -206,6 +218,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { switch (format) { + case GpuBufferFormat::kImmutableRGBA32: case GpuBufferFormat::kBGRA32: // TODO: verify we are handling order of channels correctly. return ImageFormat::SRGBA; @@ -221,10 +234,11 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { return ImageFormat::SRGB; case GpuBufferFormat::kTwoComponentFloat32: return ImageFormat::VEC32F2; + case GpuBufferFormat::kImmutableRGBAFloat128: case GpuBufferFormat::kRGBAFloat128: return ImageFormat::VEC32F4; case GpuBufferFormat::kRGBA32: - // TODO: this likely maps to ImageFormat::SRGBA + return ImageFormat::SRGBA; case GpuBufferFormat::kGrayHalf16: case GpuBufferFormat::kOneComponent8Alpha: case GpuBufferFormat::kOneComponent8Red: diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 06eabda77..223780939 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -53,6 +53,10 @@ enum class GpuBufferFormat : uint32_t { kRGB24 = 0x00000018, // Note: prefer BGRA32 whenever possible. kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'), kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'), + // Immutable version of kRGBA32 + kImmutableRGBA32 = MEDIAPIPE_FOURCC('4', 'C', 'I', '8'), + // Immutable version of kRGBAFloat128 + kImmutableRGBAFloat128 = MEDIAPIPE_FOURCC('4', 'C', 'I', 'f'), // 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling. kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'), // 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling. @@ -78,6 +82,9 @@ struct GlTextureInfo { // For multiplane buffers, this represents how many times smaller than // the nominal image size a plane is. int downscale; + // For GLES3.1+ compute shaders, users may explicitly request immutable + // textures. 
+ bool immutable = false; }; const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, @@ -121,6 +128,8 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) { return kCVPixelFormatType_64RGBAHalf; case GpuBufferFormat::kRGBAFloat128: return kCVPixelFormatType_128RGBAFloat; + case GpuBufferFormat::kImmutableRGBA32: + case GpuBufferFormat::kImmutableRGBAFloat128: case GpuBufferFormat::kNV12: case GpuBufferFormat::kNV21: case GpuBufferFormat::kI420: diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index da6c5a72d..5983758f9 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -151,7 +151,7 @@ static std::shared_ptr ConvertFromImageFrame( std::shared_ptr frame) { auto status_or_buffer = CreateCVPixelBufferForImageFrame(frame->image_frame()); - ABSL_CHECK(status_or_buffer.ok()); + ABSL_CHECK_OK(status_or_buffer); return std::make_shared( std::move(status_or_buffer).value()); } diff --git a/mediapipe/gpu/metal_shared_resources.mm b/mediapipe/gpu/metal_shared_resources.cc similarity index 85% rename from mediapipe/gpu/metal_shared_resources.mm rename to mediapipe/gpu/metal_shared_resources.cc index 80d755a01..925c0f995 100644 --- a/mediapipe/gpu/metal_shared_resources.mm +++ b/mediapipe/gpu/metal_shared_resources.cc @@ -50,9 +50,10 @@ - (CVMetalTextureCacheRef)mtlTextureCache { @synchronized(self) { if (!_mtlTextureCache) { - CVReturn __unused err = - CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); - NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err, + CVReturn __unused err = CVMetalTextureCacheCreate( + NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); + NSAssert(err == kCVReturnSuccess, + @"Error at CVMetalTextureCacheCreate %d ; device %@", err, self.mtlDevice); // TODO: register and flush metal caches too. 
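The gpu_buffer_storage_cv_pixel_buffer.cc change swaps ABSL_CHECK(x.ok()) for ABSL_CHECK_OK(x); the practical difference is that the _OK form logs the failing status (code and message) instead of only the stringified condition. A minimal sketch with a plain absl::StatusOr:

```cpp
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"

absl::StatusOr<int> MakeValue(bool ok) {
  if (!ok) return absl::InvalidArgumentError("pixel buffer creation failed");
  return 42;
}

int main() {
  absl::StatusOr<int> value = MakeValue(true);
  // On failure this logs the full status, e.g.
  //   "... is not OK (INVALID_ARGUMENT: pixel buffer creation failed)",
  // whereas ABSL_CHECK(value.ok()) would only print "value.ok()".
  ABSL_CHECK_OK(value);
  return *value;
}
```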
} diff --git a/mediapipe/gpu/pixel_buffer_pool_util.mm b/mediapipe/gpu/pixel_buffer_pool_util.cc similarity index 63% rename from mediapipe/gpu/pixel_buffer_pool_util.mm rename to mediapipe/gpu/pixel_buffer_pool_util.cc index 0b13cb194..9980d0a5d 100644 --- a/mediapipe/gpu/pixel_buffer_pool_util.mm +++ b/mediapipe/gpu/pixel_buffer_pool_util.cc @@ -24,23 +24,27 @@ namespace mediapipe { -CVPixelBufferPoolRef CreateCVPixelBufferPool( - int width, int height, OSType pixelFormat, int keepCount, - CFTimeInterval maxAge) { +CVPixelBufferPoolRef CreateCVPixelBufferPool(int width, int height, + OSType pixelFormat, int keepCount, + CFTimeInterval maxAge) { CVPixelBufferPoolRef pool = NULL; NSMutableDictionary *sourcePixelBufferOptions = - [(__bridge NSDictionary*)GetCVPixelBufferAttributesForGlCompatibility() mutableCopy]; + [(__bridge NSDictionary *)GetCVPixelBufferAttributesForGlCompatibility() + mutableCopy]; [sourcePixelBufferOptions addEntriesFromDictionary:@{ (id)kCVPixelBufferPixelFormatTypeKey : @(pixelFormat), (id)kCVPixelBufferWidthKey : @(width), (id)kCVPixelBufferHeightKey : @(height), }]; - NSMutableDictionary *pixelBufferPoolOptions = [[NSMutableDictionary alloc] init]; - pixelBufferPoolOptions[(id)kCVPixelBufferPoolMinimumBufferCountKey] = @(keepCount); + NSMutableDictionary *pixelBufferPoolOptions = + [[NSMutableDictionary alloc] init]; + pixelBufferPoolOptions[(id)kCVPixelBufferPoolMinimumBufferCountKey] = + @(keepCount); if (maxAge > 0) { - pixelBufferPoolOptions[(id)kCVPixelBufferPoolMaximumBufferAgeKey] = @(maxAge); + pixelBufferPoolOptions[(id)kCVPixelBufferPoolMaximumBufferAgeKey] = + @(maxAge); } CVPixelBufferPoolCreate( @@ -50,8 +54,9 @@ CVPixelBufferPoolRef CreateCVPixelBufferPool( return pool; } -OSStatus PreallocateCVPixelBufferPoolBuffers( - CVPixelBufferPoolRef pool, int count, CFDictionaryRef auxAttributes) { +OSStatus PreallocateCVPixelBufferPoolBuffers(CVPixelBufferPoolRef pool, + int count, + CFDictionaryRef auxAttributes) { CVReturn err = kCVReturnSuccess; NSMutableArray *pixelBuffers = [[NSMutableArray alloc] init]; for (int i = 0; i < count && err == kCVReturnSuccess; i++) { @@ -68,30 +73,37 @@ OSStatus PreallocateCVPixelBufferPoolBuffers( return err; } -CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold(int allocationThreshold) { +CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold( + int allocationThreshold) { if (allocationThreshold > 0) { - return (CFDictionaryRef)CFBridgingRetain( - @{(id)kCVPixelBufferPoolAllocationThresholdKey: @(allocationThreshold)}); + return (CFDictionaryRef)CFBridgingRetain(@{ + (id)kCVPixelBufferPoolAllocationThresholdKey : @(allocationThreshold) + }); } else { return nil; } } -CVReturn CreateCVPixelBufferWithPool( - CVPixelBufferPoolRef pool, CFDictionaryRef auxAttributes, - CVTextureCacheType textureCache, CVPixelBufferRef* outBuffer) { - return CreateCVPixelBufferWithPool(pool, auxAttributes, [textureCache](){ +CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool, + CFDictionaryRef auxAttributes, + CVTextureCacheType textureCache, + CVPixelBufferRef *outBuffer) { + return CreateCVPixelBufferWithPool( + pool, auxAttributes, + [textureCache]() { #if TARGET_OS_OSX - CVOpenGLTextureCacheFlush(textureCache, 0); + CVOpenGLTextureCacheFlush(textureCache, 0); #else - CVOpenGLESTextureCacheFlush(textureCache, 0); + CVOpenGLESTextureCacheFlush(textureCache, 0); #endif // TARGET_OS_OSX - }, outBuffer); + }, + outBuffer); } -CVReturn CreateCVPixelBufferWithPool( - CVPixelBufferPoolRef pool, 
CFDictionaryRef auxAttributes, - std::function flush, CVPixelBufferRef* outBuffer) { +CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool, + CFDictionaryRef auxAttributes, + std::function flush, + CVPixelBufferRef *outBuffer) { CVReturn err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( kCFAllocatorDefault, pool, auxAttributes, outBuffer); if (err == kCVReturnWouldExceedAllocationThreshold) { @@ -103,11 +115,13 @@ CVReturn CreateCVPixelBufferWithPool( kCFAllocatorDefault, pool, auxAttributes, outBuffer); } if (err == kCVReturnWouldExceedAllocationThreshold) { - // TODO: allow the application to set the threshold. For now, disable it by - // default, since the threshold we are using is arbitrary and some graphs routinely cross it. + // TODO: allow the application to set the threshold. For now, disable it + // by default, since the threshold we are using is arbitrary and some + // graphs routinely cross it. #ifdef ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK - NSLog(@"Using more buffers than expected! This is a debug-only warning, " - "you can ignore it if your app works fine otherwise."); + NSLog( + @"Using more buffers than expected! This is a debug-only warning, " + "you can ignore it if your app works fine otherwise."); #ifdef DEBUG NSLog(@"Pool status: %@", ((__bridge NSObject *)pool).description); #endif // DEBUG diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index df6c8db08..481a60bb6 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -52,9 +52,9 @@ objc_library( ) MEDIAPIPE_IOS_SRCS = [ - "MPPGraph.mm", - "MPPTimestampConverter.mm", - "NSError+util_status.mm", + "MPPGraph.cc", + "MPPTimestampConverter.cc", + "NSError+util_status.cc", ] MEDIAPIPE_IOS_HDRS = [ @@ -63,11 +63,13 @@ MEDIAPIPE_IOS_HDRS = [ "NSError+util_status.h", ] -objc_library( +cc_library( name = "mediapipe_framework_ios", srcs = MEDIAPIPE_IOS_SRCS, hdrs = MEDIAPIPE_IOS_HDRS, copts = [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting "-Wno-shorten-64-to-32", ], # This build rule is public to allow external customers to build their own iOS apps. @@ -99,6 +101,7 @@ objc_library( "@com_google_absl//absl/synchronization", "@google_toolbox_for_mac//:GTM_Defines", ], + alwayslink = 1, ) objc_library( diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.cc similarity index 74% rename from mediapipe/objc/MPPGraph.mm rename to mediapipe/objc/MPPGraph.cc index 3123eb863..df9a1ebd6 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.cc @@ -19,6 +19,7 @@ #include +#import "GTMDefines.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" @@ -26,22 +27,22 @@ #include "mediapipe/framework/graph_service.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" +#import "mediapipe/objc/NSError+util_status.h" #include "mediapipe/objc/util.h" -#import "mediapipe/objc/NSError+util_status.h" -#import "GTMDefines.h" - @implementation MPPGraph { - // Graph is wrapped in a unique_ptr because it was generating 39+KB of unnecessary ObjC runtime - // information. See https://medium.com/@dmaclach/objective-c-encoding-and-you-866624cc02de - // for details. + // Graph is wrapped in a unique_ptr because it was generating 39+KB of + // unnecessary ObjC runtime information. See + // https://medium.com/@dmaclach/objective-c-encoding-and-you-866624cc02de for + // details. 
std::unique_ptr _graph; /// Input side packets that will be added to the graph when it is started. std::map _inputSidePackets; /// Packet headers that will be added to the graph when it is started. std::map _streamHeaders; /// Service packets to be added to the graph when it is started. - std::map _servicePackets; + std::map + _servicePackets; /// Number of frames currently being processed by the graph. std::atomic _framesInFlight; @@ -56,7 +57,8 @@ BOOL _started; } -- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config { +- (instancetype)initWithGraphConfig: + (const mediapipe::CalculatorGraphConfig&)config { self = [super init]; if (self) { // Turn on Cocoa multithreading, since MediaPipe uses threads. @@ -76,40 +78,47 @@ return _graph->GetGraphInputStreamAddMode(); } -- (void)setPacketAddMode:(mediapipe::CalculatorGraph::GraphInputStreamAddMode)mode { +- (void)setPacketAddMode: + (mediapipe::CalculatorGraph::GraphInputStreamAddMode)mode { _graph->SetGraphInputStreamAddMode(mode); } - (void)addFrameOutputStream:(const std::string&)outputStreamName outputPacketType:(MPPPacketType)packetType { std::string callbackInputName; - mediapipe::tool::AddCallbackCalculator(outputStreamName, &_config, &callbackInputName, - /*use_std_function=*/true); - // No matter what ownership qualifiers are put on the pointer, NewPermanentCallback will - // still end up with a strong pointer to MPPGraph*. That is why we use void* instead. + mediapipe::tool::AddCallbackCalculator(outputStreamName, &_config, + &callbackInputName, + /*use_std_function=*/true); + // No matter what ownership qualifiers are put on the pointer, + // NewPermanentCallback will still end up with a strong pointer to MPPGraph*. + // That is why we use void* instead. void* wrapperVoid = (__bridge void*)self; _inputSidePackets[callbackInputName] = mediapipe::MakePacket>( - [wrapperVoid, outputStreamName, packetType](const mediapipe::Packet& packet) { - CallFrameDelegate(wrapperVoid, outputStreamName, packetType, packet); + [wrapperVoid, outputStreamName, + packetType](const mediapipe::Packet& packet) { + CallFrameDelegate(wrapperVoid, outputStreamName, packetType, + packet); }); } -- (NSString *)description { - return [NSString stringWithFormat:@"<%@: %p; framesInFlight = %d>", [self class], self, - _framesInFlight.load(std::memory_order_relaxed)]; +- (NSString*)description { + return [NSString + stringWithFormat:@"<%@: %p; framesInFlight = %d>", [self class], self, + _framesInFlight.load(std::memory_order_relaxed)]; } /// This is the function that gets called by the CallbackCalculator that /// receives the graph's output. void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, - MPPPacketType packetType, const mediapipe::Packet& packet) { + MPPPacketType packetType, + const mediapipe::Packet& packet) { MPPGraph* wrapper = (__bridge MPPGraph*)wrapperVoid; @autoreleasepool { if (packetType == MPPPacketTypeRaw) { [wrapper.delegate mediapipeGraph:wrapper - didOutputPacket:packet - fromStream:streamName]; + didOutputPacket:packet + fromStream:streamName]; } else if (packetType == MPPPacketTypeImageFrame) { wrapper->_framesInFlight--; const auto& frame = packet.Get(); @@ -118,13 +127,16 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, if (format == mediapipe::ImageFormat::SRGBA || format == mediapipe::ImageFormat::GRAY8) { CVPixelBufferRef pixelBuffer; - // If kCVPixelFormatType_32RGBA does not work, it returns kCVReturnInvalidPixelFormat. 
+ // If kCVPixelFormatType_32RGBA does not work, it returns + // kCVReturnInvalidPixelFormat. CVReturn error = CVPixelBufferCreate( NULL, frame.Width(), frame.Height(), kCVPixelFormatType_32BGRA, GetCVPixelBufferAttributesForGlCompatibility(), &pixelBuffer); - _GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferCreate failed: %d", error); + _GTMDevAssert(error == kCVReturnSuccess, + @"CVPixelBufferCreate failed: %d", error); error = CVPixelBufferLockBaseAddress(pixelBuffer, 0); - _GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferLockBaseAddress failed: %d", error); + _GTMDevAssert(error == kCVReturnSuccess, + @"CVPixelBufferLockBaseAddress failed: %d", error); vImage_Buffer vDestination = vImageForCVPixelBuffer(pixelBuffer); // Note: we have to throw away const here, but we should not overwrite @@ -133,30 +145,35 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, if (format == mediapipe::ImageFormat::SRGBA) { // Swap R and B channels. const uint8_t permuteMap[4] = {2, 1, 0, 3}; - vImage_Error __unused vError = - vImagePermuteChannels_ARGB8888(&vSource, &vDestination, permuteMap, kvImageNoFlags); - _GTMDevAssert(vError == kvImageNoError, @"vImagePermuteChannels failed: %zd", vError); + vImage_Error __unused vError = vImagePermuteChannels_ARGB8888( + &vSource, &vDestination, permuteMap, kvImageNoFlags); + _GTMDevAssert(vError == kvImageNoError, + @"vImagePermuteChannels failed: %zd", vError); } else { // Convert grayscale back to BGRA - vImage_Error __unused vError = vImageGrayToBGRA(&vSource, &vDestination); - _GTMDevAssert(vError == kvImageNoError, @"vImageGrayToBGRA failed: %zd", vError); + vImage_Error __unused vError = + vImageGrayToBGRA(&vSource, &vDestination); + _GTMDevAssert(vError == kvImageNoError, + @"vImageGrayToBGRA failed: %zd", vError); } error = CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); _GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferUnlockBaseAddress failed: %d", error); - if ([wrapper.delegate respondsToSelector:@selector - (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) { + if ([wrapper.delegate + respondsToSelector:@selector + (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) { [wrapper.delegate mediapipeGraph:wrapper - didOutputPixelBuffer:pixelBuffer - fromStream:streamName - timestamp:packet.Timestamp()]; - } else if ([wrapper.delegate respondsToSelector:@selector - (mediapipeGraph:didOutputPixelBuffer:fromStream:)]) { + didOutputPixelBuffer:pixelBuffer + fromStream:streamName + timestamp:packet.Timestamp()]; + } else if ([wrapper.delegate + respondsToSelector:@selector + (mediapipeGraph:didOutputPixelBuffer:fromStream:)]) { [wrapper.delegate mediapipeGraph:wrapper - didOutputPixelBuffer:pixelBuffer - fromStream:streamName]; + didOutputPixelBuffer:pixelBuffer + fromStream:streamName]; } CVPixelBufferRelease(pixelBuffer); } else { @@ -168,22 +185,23 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, wrapper->_framesInFlight--; CVPixelBufferRef pixelBuffer; if (packetType == MPPPacketTypePixelBuffer) - pixelBuffer = mediapipe::GetCVPixelBufferRef(packet.Get()); + pixelBuffer = + mediapipe::GetCVPixelBufferRef(packet.Get()); else pixelBuffer = packet.Get().GetCVPixelBufferRef(); -if ([wrapper.delegate + if ([wrapper.delegate respondsToSelector:@selector (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) { [wrapper.delegate mediapipeGraph:wrapper - didOutputPixelBuffer:pixelBuffer - fromStream:streamName - timestamp:packet.Timestamp()]; + 
didOutputPixelBuffer:pixelBuffer + fromStream:streamName + timestamp:packet.Timestamp()]; } else if ([wrapper.delegate respondsToSelector:@selector (mediapipeGraph:didOutputPixelBuffer:fromStream:)]) { [wrapper.delegate mediapipeGraph:wrapper - didOutputPixelBuffer:pixelBuffer - fromStream:streamName]; + didOutputPixelBuffer:pixelBuffer + fromStream:streamName]; } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } else { @@ -192,13 +210,15 @@ if ([wrapper.delegate } } -- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName { +- (void)setHeaderPacket:(const mediapipe::Packet&)packet + forStream:(const std::string&)streamName { _GTMDevAssert(!_started, @"%@ must be called before the graph is started", NSStringFromSelector(_cmd)); _streamHeaders[streamName] = packet; } -- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name { +- (void)setSidePacket:(const mediapipe::Packet&)packet + named:(const std::string&)name { _GTMDevAssert(!_started, @"%@ must be called before the graph is started", NSStringFromSelector(_cmd)); _inputSidePackets[name] = packet; @@ -211,7 +231,8 @@ if ([wrapper.delegate _servicePackets[&service] = std::move(packet); } -- (void)addSidePackets:(const std::map&)extraSidePackets { +- (void)addSidePackets: + (const std::map&)extraSidePackets { _GTMDevAssert(!_started, @"%@ must be called before the graph is started", NSStringFromSelector(_cmd)); _inputSidePackets.insert(extraSidePackets.begin(), extraSidePackets.end()); @@ -232,7 +253,8 @@ if ([wrapper.delegate - (absl::Status)performStart { absl::Status status; for (const auto& service_packet : _servicePackets) { - status = _graph->SetServicePacket(*service_packet.first, service_packet.second); + status = + _graph->SetServicePacket(*service_packet.first, service_packet.second); if (!status.ok()) { return status; } @@ -269,11 +291,12 @@ if ([wrapper.delegate } - (BOOL)waitUntilDoneWithError:(NSError**)error { - // Since this method blocks with no timeout, it should not be called in the main thread in - // an app. However, it's fine to allow that in a test. + // Since this method blocks with no timeout, it should not be called in the + // main thread in an app. However, it's fine to allow that in a test. // TODO: is this too heavy-handed? Maybe a warning would be fine. 
- _GTMDevAssert(![NSThread isMainThread] || (NSClassFromString(@"XCTest")), - @"waitUntilDoneWithError: should not be called on the main thread"); + _GTMDevAssert( + ![NSThread isMainThread] || (NSClassFromString(@"XCTest")), + @"waitUntilDoneWithError: should not be called on the main thread"); absl::Status status = _graph->WaitUntilDone(); _started = NO; if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; @@ -289,7 +312,8 @@ if ([wrapper.delegate - (BOOL)movePacket:(mediapipe::Packet&&)packet intoStream:(const std::string&)streamName error:(NSError**)error { - absl::Status status = _graph->AddPacketToInputStream(streamName, std::move(packet)); + absl::Status status = + _graph->AddPacketToInputStream(streamName, std::move(packet)); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } @@ -305,15 +329,17 @@ if ([wrapper.delegate - (BOOL)setMaxQueueSize:(int)maxQueueSize forStream:(const std::string&)streamName error:(NSError**)error { - absl::Status status = _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize); + absl::Status status = + _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } - (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)imageBuffer - packetType:(MPPPacketType)packetType { + packetType:(MPPPacketType)packetType { mediapipe::Packet packet; - if (packetType == MPPPacketTypeImageFrame || packetType == MPPPacketTypeImageFrameBGRANoSwap) { + if (packetType == MPPPacketTypeImageFrame || + packetType == MPPPacketTypeImageFrameBGRANoSwap) { auto frame = CreateImageFrameForCVPixelBuffer( imageBuffer, /* canOverwrite = */ false, /* bgrAsRgb = */ packetType == MPPPacketTypeImageFrameBGRANoSwap); @@ -328,7 +354,8 @@ if ([wrapper.delegate packet = mediapipe::MakePacket(imageBuffer); #else // CPU - auto frame = CreateImageFrameForCVPixelBuffer(imageBuffer, /* canOverwrite = */ false, + auto frame = CreateImageFrameForCVPixelBuffer(imageBuffer, + /* canOverwrite = */ false, /* bgrAsRgb = */ false); packet = mediapipe::MakePacket(std::move(frame)); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER @@ -339,7 +366,8 @@ if ([wrapper.delegate } - (mediapipe::Packet)imagePacketWithPixelBuffer:(CVPixelBufferRef)pixelBuffer { - return [self packetWithPixelBuffer:(pixelBuffer) packetType:(MPPPacketTypeImage)]; + return [self packetWithPixelBuffer:(pixelBuffer) + packetType:(MPPPacketTypeImage)]; } - (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer @@ -367,13 +395,16 @@ if ([wrapper.delegate allowOverwrite:(BOOL)allowOverwrite error:(NSError**)error { if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO; - mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType]; + mediapipe::Packet packet = + [self packetWithPixelBuffer:imageBuffer packetType:packetType]; BOOL success; if (allowOverwrite) { packet = std::move(packet).At(timestamp); - success = [self movePacket:std::move(packet) intoStream:inputName error:error]; + success = + [self movePacket:std::move(packet) intoStream:inputName error:error]; } else { - success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error]; + success = + [self sendPacket:packet.At(timestamp) intoStream:inputName error:error]; } if (success) _framesInFlight++; return success; @@ -407,22 +438,24 @@ if ([wrapper.delegate } - (void)debugPrintGlInfo { - std::shared_ptr gpu_resources = _graph->GetGpuResources(); + 
std::shared_ptr gpu_resources = + _graph->GetGpuResources(); if (!gpu_resources) { NSLog(@"GPU not set up."); return; } NSString* extensionString; - (void)gpu_resources->gl_context()->Run([&extensionString]{ - extensionString = [NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)]; + (void)gpu_resources->gl_context()->Run([&extensionString] { + extensionString = + [NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)]; return absl::OkStatus(); }); - NSArray* extensions = [extensionString componentsSeparatedByCharactersInSet: - [NSCharacterSet whitespaceCharacterSet]]; - for (NSString* oneExtension in extensions) - NSLog(@"%@", oneExtension); + NSArray* extensions = [extensionString + componentsSeparatedByCharactersInSet:[NSCharacterSet + whitespaceCharacterSet]]; + for (NSString* oneExtension in extensions) NSLog(@"%@", oneExtension); } @end diff --git a/mediapipe/objc/MPPTimestampConverter.mm b/mediapipe/objc/MPPTimestampConverter.cc similarity index 81% rename from mediapipe/objc/MPPTimestampConverter.mm rename to mediapipe/objc/MPPTimestampConverter.cc index e53758d71..44857c8e9 100644 --- a/mediapipe/objc/MPPTimestampConverter.mm +++ b/mediapipe/objc/MPPTimestampConverter.cc @@ -20,8 +20,7 @@ mediapipe::TimestampDiff _timestampOffset; } -- (instancetype)init -{ +- (instancetype)init { self = [super init]; if (self) { [self reset]; @@ -36,11 +35,14 @@ } - (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime { - Float64 sampleSeconds = CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0; - const int64 sampleUsec = sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond; + Float64 sampleSeconds = + CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0; + const int64 sampleUsec = + sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond; _mediapipeTimestamp = mediapipe::Timestamp(sampleUsec) + _timestampOffset; if (_mediapipeTimestamp <= _lastTimestamp) { - _timestampOffset = _timestampOffset + _lastTimestamp + 1 - _mediapipeTimestamp; + _timestampOffset = + _timestampOffset + _lastTimestamp + 1 - _mediapipeTimestamp; _mediapipeTimestamp = _lastTimestamp + 1; } _lastTimestamp = _mediapipeTimestamp; diff --git a/mediapipe/objc/NSError+util_status.cc b/mediapipe/objc/NSError+util_status.cc new file mode 100644 index 000000000..144ec6ed4 --- /dev/null +++ b/mediapipe/objc/NSError+util_status.cc @@ -0,0 +1,72 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
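The MPPTimestampConverter logic reformatted above keeps MediaPipe timestamps strictly increasing even when CMTime inputs repeat or run backwards: it converts seconds to timestamp ticks, adds a running offset, and bumps the offset whenever the result would not move forward. A plain-C++ sketch of that bookkeeping, assuming 1e6 ticks per second (the value of Timestamp::kTimestampUnitsPerSecond) and raw int64_t ticks in place of mediapipe::Timestamp:

```cpp
#include <cstdint>
#include <iostream>

class TimestampConverterSketch {
 public:
  int64_t ForMediaTimeSeconds(double seconds) {
    constexpr double kTicksPerSecond = 1e6;
    int64_t ts = static_cast<int64_t>(seconds * kTicksPerSecond) + offset_;
    if (ts <= last_) {
      // Input repeated or went backwards; shift all future timestamps so the
      // output stays strictly monotonic.
      offset_ += last_ + 1 - ts;
      ts = last_ + 1;
    }
    last_ = ts;
    return ts;
  }

 private:
  int64_t last_ = -1;
  int64_t offset_ = 0;
};

int main() {
  TimestampConverterSketch conv;
  std::cout << conv.ForMediaTimeSeconds(0.10) << "\n";  // 100000
  std::cout << conv.ForMediaTimeSeconds(0.10) << "\n";  // 100001 (bumped)
  std::cout << conv.ForMediaTimeSeconds(0.11) << "\n";  // 110001 (offset kept)
  return 0;
}
```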
+ +#import "mediapipe/objc/NSError+util_status.h" + +@implementation GUSUtilStatusWrapper + ++ (instancetype)wrapStatus:(const absl::Status &)status { + return [[self alloc] initWithStatus:status]; +} + +- (instancetype)initWithStatus:(const absl::Status &)status { + self = [super init]; + if (self) { + _status = status; + } + return self; +} + +- (NSString *)description { + return [NSString stringWithFormat:@"<%@: %p; status = %s>", [self class], + self, _status.message().data()]; +} + +@end + +@implementation NSError (GUSGoogleUtilStatus) + +NSString *const kGUSGoogleUtilStatusErrorDomain = + @"GoogleUtilStatusErrorDomain"; +NSString *const kGUSGoogleUtilStatusErrorKey = @"GUSGoogleUtilStatusErrorKey"; + ++ (NSError *)gus_errorWithStatus:(const absl::Status &)status { + NSDictionary *userInfo = @{ + NSLocalizedDescriptionKey : @(status.message().data()), + kGUSGoogleUtilStatusErrorKey : [GUSUtilStatusWrapper wrapStatus:status], + }; + NSError *error = + [NSError errorWithDomain:kGUSGoogleUtilStatusErrorDomain + code:static_cast(status.code()) + userInfo:userInfo]; + return error; +} + +- (absl::Status)gus_status { + NSString *domain = self.domain; + if ([domain isEqual:kGUSGoogleUtilStatusErrorDomain]) { + GUSUtilStatusWrapper *wrapper = self.userInfo[kGUSGoogleUtilStatusErrorKey]; + if (wrapper) return wrapper.status; +#if 0 + // Unfortunately, util/task/posixerrorspace.h is not in portable status yet. + // TODO: fix that. + } else if ([domain isEqual:NSPOSIXErrorDomain]) { + return ::util::PosixErrorToStatus(self.code, self.localizedDescription.UTF8String); +#endif + } + return absl::Status(absl::StatusCode::kUnknown, + self.localizedDescription.UTF8String); +} + +@end diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index cd9124948..3181fb5f1 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -207,8 +207,12 @@ class ImageTest(absltest.TestCase): loaded_image = Image.create_from_file(image_path) self.assertEqual(loaded_image.width, 720) self.assertEqual(loaded_image.height, 382) - self.assertEqual(loaded_image.channels, 3) - self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) + # On Mac w/ GPU support, images use 4 channels (SRGBA). Otherwise, all + # images use 3 channels (SRGB). 
+ self.assertIn(loaded_image.channels, [3, 4]) + self.assertIn( + loaded_image.image_format, [ImageFormat.SRGB, ImageFormat.SRGBA] + ) if __name__ == '__main__': absltest.main() diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 800e883b4..98f162342 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -51,10 +51,10 @@ void ImageSubmodule(pybind11::module* module) { ```python import cv2 - cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat) + cv_mat = cv2.imread(input_file) + rgb_frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=cv_mat) gray_frame = mp.Image( - image_format=ImageFormat.GRAY, + image_format=mp.ImageFormat.GRAY8, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) from PIL import Image @@ -244,12 +244,26 @@ void ImageSubmodule(pybind11::module* module) { image.def_static( "create_from_file", [](const std::string& file_name) { + unsigned char* image_data = nullptr; int width; int height; int channels; - auto* image_data = - stbi_load(file_name.c_str(), &width, &height, &channels, - /*desired_channels=*/0); + +#if TARGET_OS_OSX && !MEDIAPIPE_DISABLE_GPU + // Our ObjC layer does not support 3-channel images, so we read the + // number of channels first and request RGBA if needed. + if (stbi_info(file_name.c_str(), &width, &height, &channels)) { + if (channels == 3) { + channels = 4; + } + int unused; + image_data = + stbi_load(file_name.c_str(), &width, &height, &unused, channels); + } +#else + image_data = stbi_load(file_name.c_str(), &width, &height, &channels, + /*desired_channels=*/0); +#endif // TARGET_OS_OSX && !MEDIAPIPE_DISABLE_GPU if (image_data == nullptr) { throw RaisePyError(PyExc_RuntimeError, absl::StrFormat("Image decoding failed (%s): %s", @@ -263,11 +277,13 @@ void ImageSubmodule(pybind11::module* module) { ImageFormat::GRAY8, width, height, width, image_data, stbi_image_free); break; +#if !TARGET_OS_OSX || MEDIAPIPE_DISABLE_GPU case 3: image_frame = std::make_shared( ImageFormat::SRGB, width, height, 3 * width, image_data, stbi_image_free); break; +#endif // !TARGET_OS_OSX || MEDIAPIPE_DISABLE_GPU case 4: image_frame = std::make_shared( ImageFormat::SRGBA, width, height, 4 * width, image_data, diff --git a/mediapipe/python/pybind/image_frame.cc b/mediapipe/python/pybind/image_frame.cc index 7348133eb..90db05066 100644 --- a/mediapipe/python/pybind/image_frame.cc +++ b/mediapipe/python/pybind/image_frame.cc @@ -81,17 +81,20 @@ void ImageFrameSubmodule(pybind11::module* module) { become immutable after creation. 
Creation examples: - import cv2 - cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat) - gray_frame = mp.ImageFrame( - image_format=ImageFormat.GRAY, - data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) - from PIL import Image - pil_img = Image.new('RGB', (60, 30), color = 'red') - image_frame = mp.ImageFrame( - image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ```python + import cv2 + cv_mat = cv2.imread(input_file) + rgb_frame = mp.ImageFrame(image_format=mp.ImageFormat.SRGB, data=cv_mat) + gray_frame = mp.ImageFrame( + image_format=mp.ImageFormat.GRAY8, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + + from PIL import Image + pil_img = Image.new('RGB', (60, 30), color = 'red') + image_frame = mp.ImageFrame( + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ``` The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling `ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the diff --git a/mediapipe/tasks/c/vision/image_classifier/BUILD b/mediapipe/tasks/c/vision/image_classifier/BUILD index df0e636c5..e8ac090e9 100644 --- a/mediapipe/tasks/c/vision/image_classifier/BUILD +++ b/mediapipe/tasks/c/vision/image_classifier/BUILD @@ -30,13 +30,12 @@ cc_library( "//mediapipe/tasks/c/components/processors:classifier_options_converter", "//mediapipe/tasks/c/core:base_options", "//mediapipe/tasks/c/core:base_options_converter", + "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/image_classifier", "//mediapipe/tasks/cc/vision/utils:image_utils", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc index 4245ca4cd..ff6f5bdfc 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier.cc @@ -15,6 +15,8 @@ limitations under the License. #include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h" +#include +#include #include #include @@ -26,6 +28,7 @@ limitations under the License.
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h" #include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" #include "mediapipe/tasks/c/core/base_options_converter.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" @@ -41,7 +44,10 @@ using ::mediapipe::tasks::c::components::processors:: CppConvertToClassifierOptions; using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; using ::mediapipe::tasks::vision::CreateImageFromBuffer; +using ::mediapipe::tasks::vision::core::RunningMode; using ::mediapipe::tasks::vision::image_classifier::ImageClassifier; +typedef ::mediapipe::tasks::vision::image_classifier::ImageClassifierResult + CppImageClassifierResult; int CppProcessError(absl::Status status, char** error_msg) { if (error_msg) { @@ -60,6 +66,53 @@ ImageClassifier* CppImageClassifierCreate(const ImageClassifierOptions& options, CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); CppConvertToClassifierOptions(options.classifier_options, &cpp_options->classifier_options); + cpp_options->running_mode = static_cast(options.running_mode); + + // Enable callback for processing live stream data when the running mode is + // set to RunningMode::LIVE_STREAM. + if (cpp_options->running_mode == RunningMode::LIVE_STREAM) { + if (options.result_callback == nullptr) { + const absl::Status status = absl::InvalidArgumentError( + "Provided null pointer to callback function."); + ABSL_LOG(ERROR) << "Failed to create ImageClassifier: " << status; + CppProcessError(status, error_msg); + return nullptr; + } + + ImageClassifierOptions::result_callback_fn result_callback = + options.result_callback; + cpp_options->result_callback = + [result_callback](absl::StatusOr cpp_result, + const Image& image, int64_t timestamp) { + char* error_msg = nullptr; + + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status(); + CppProcessError(cpp_result.status(), &error_msg); + result_callback(nullptr, MpImage(), timestamp, error_msg); + free(error_msg); + return; + } + + // Result is valid for the lifetime of the callback function. 
+ ImageClassifierResult result; + CppConvertToClassificationResult(*cpp_result, &result); + + const auto& image_frame = image.GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = { + .format = static_cast<::ImageFormat>(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + result_callback(&result, mp_image, timestamp, + /* error_msg= */ nullptr); + + CppCloseClassificationResult(&result); + }; + } auto classifier = ImageClassifier::Create(std::move(cpp_options)); if (!classifier.ok()) { @@ -75,8 +128,8 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image, ImageClassifierResult* result, char** error_msg) { if (image->type == MpImage::GPU_BUFFER) { - absl::Status status = - absl::InvalidArgumentError("gpu buffer not supported yet"); + const absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet."); ABSL_LOG(ERROR) << "Classification failed: " << status.message(); return CppProcessError(status, error_msg); @@ -102,6 +155,68 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image, return 0; } +int CppImageClassifierClassifyForVideo(void* classifier, const MpImage* image, + int64_t timestamp_ms, + ImageClassifierResult* result, + char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet"); + + ABSL_LOG(ERROR) << "Classification failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_classifier = static_cast(classifier); + auto cpp_result = cpp_classifier->ClassifyForVideo(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + CppConvertToClassificationResult(*cpp_result, result); + return 0; +} + +int CppImageClassifierClassifyAsync(void* classifier, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet"); + + ABSL_LOG(ERROR) << "Classification failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_classifier = static_cast(classifier); + auto cpp_result = cpp_classifier->ClassifyAsync(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Data preparation for the image classification failed: " + << cpp_result; + return CppProcessError(cpp_result, error_msg); + } + return 0; +} + void CppImageClassifierCloseResult(ImageClassifierResult* result) { CppCloseClassificationResult(result); } @@ -134,6 +249,22 @@ int image_classifier_classify_image(void* classifier, const MpImage* image, CppImageClassifierClassify(classifier, image, result, 
error_msg); } +int image_classifier_classify_for_video(void* classifier, const MpImage* image, + int64_t timestamp_ms, + ImageClassifierResult* result, + char** error_msg) { + return mediapipe::tasks::c::vision::image_classifier:: + CppImageClassifierClassifyForVideo(classifier, image, timestamp_ms, + result, error_msg); +} + +int image_classifier_classify_async(void* classifier, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + return mediapipe::tasks::c::vision::image_classifier:: + CppImageClassifierClassifyAsync(classifier, image, timestamp_ms, + error_msg); +} + void image_classifier_close_result(ImageClassifierResult* result) { mediapipe::tasks::c::vision::image_classifier::CppImageClassifierCloseResult( result); diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier.h b/mediapipe/tasks/c/vision/image_classifier/image_classifier.h index 60dc4a2c4..549c3f300 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier.h @@ -92,9 +92,16 @@ struct ImageClassifierOptions { // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set - // to RunningMode::LIVE_STREAM. - typedef void (*result_callback_fn)(ImageClassifierResult*, const MpImage*, - int64_t); + // to RunningMode::LIVE_STREAM. Arguments of the callback function include: + // the pointer to classification result, the image that result was obtained + // on, the timestamp relevant to classification results and pointer to error + // message in case of any failure. The validity of the passed arguments is + // true for the lifetime of the callback function. + // + // A caller is responsible for closing image classifier result. + typedef void (*result_callback_fn)(ImageClassifierResult* result, + const MpImage image, int64_t timestamp_ms, + char* error_msg); result_callback_fn result_callback; }; @@ -110,13 +117,22 @@ MP_EXPORT void* image_classifier_create(struct ImageClassifierOptions* options, // If an error occurs, returns an error code and sets the error parameter to an // an error message (if `error_msg` is not nullptr). You must free the memory // allocated for the error message. -// -// TODO: Add API for video and live stream processing. MP_EXPORT int image_classifier_classify_image(void* classifier, const MpImage* image, ImageClassifierResult* result, char** error_msg = nullptr); +MP_EXPORT int image_classifier_classify_for_video(void* classifier, + const MpImage* image, + int64_t timestamp_ms, + ImageClassifierResult* result, + char** error_msg = nullptr); + +MP_EXPORT int image_classifier_classify_async(void* classifier, + const MpImage* image, + int64_t timestamp_ms, + char** error_msg = nullptr); + // Frees the memory allocated inside a ImageClassifierResult result. // Does not free the result pointer itself. MP_EXPORT void image_classifier_close_result(ImageClassifierResult* result); diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc index e8e84d864..790f5ce36 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
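Not part of the patch: a minimal sketch of how the new live-stream entry points declared above are intended to be driven, mirroring the tests below. `MakeMpImageFromFrame`, `OnResult`, and the model path are illustrative placeholders rather than names from the MediaPipe C API, so treat this as an assumption-laden sketch, not reference usage.

```cpp
#include <cstdint>
#include <cstdio>

#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"

// Hypothetical helper that wraps an application-owned pixel buffer in an
// MpImage (IMAGE_FRAME type with format/width/height/buffer filled in).
extern MpImage MakeMpImageFromFrame();

// Matches the new result_callback_fn signature. The result and image are only
// valid while this callback is running; copy anything that must outlive it.
void OnResult(ImageClassifierResult* result, const MpImage image,
              int64_t timestamp_ms, char* error_msg) {
  if (error_msg != nullptr) {
    std::fprintf(stderr, "classification failed at %lld ms: %s\n",
                 static_cast<long long>(timestamp_ms), error_msg);
    return;
  }
  std::printf("got %u classification(s) for frame at %lld ms\n",
              static_cast<unsigned>(result->classifications_count),
              static_cast<long long>(timestamp_ms));
}

int RunLiveStream(const char* model_path) {
  ImageClassifierOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ model_path},
      /* running_mode= */ RunningMode::LIVE_STREAM,
      /* classifier_options= */
      {/* display_names_locale= */ nullptr, /* max_results= */ 3,
       /* score_threshold= */ 0.0, /* category_allowlist= */ nullptr,
       /* category_allowlist_count= */ 0, /* category_denylist= */ nullptr,
       /* category_denylist_count= */ 0},
      /* result_callback= */ OnResult,
  };

  void* classifier = image_classifier_create(&options);
  if (classifier == nullptr) return -1;

  const MpImage frame = MakeMpImageFromFrame();
  for (int64_t timestamp_ms = 0; timestamp_ms < 100; ++timestamp_ms) {
    // Timestamps must be monotonically increasing; a non-zero return code
    // means the frame could not be submitted.
    if (image_classifier_classify_async(classifier, &frame, timestamp_ms) != 0) {
      break;
    }
  }
  image_classifier_close(classifier);
  return 0;
}
```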
#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h" +#include #include #include @@ -36,12 +37,13 @@ using testing::HasSubstr; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kModelName[] = "mobilenet_v2_1.0_224.tflite"; constexpr float kPrecision = 1e-4; +constexpr int kIterations = 100; std::string GetFullPath(absl::string_view file_name) { return JoinPath("./", kTestDataDirectory, file_name); } -TEST(ImageClassifierTest, SmokeTest) { +TEST(ImageClassifierTest, ImageModeTest) { const auto image = DecodeImageFromFile(GetFullPath("burger.jpg")); ASSERT_TRUE(image.ok()); @@ -63,14 +65,13 @@ TEST(ImageClassifierTest, SmokeTest) { void* classifier = image_classifier_create(&options); EXPECT_NE(classifier, nullptr); + const auto& image_frame = image->GetImageFrameSharedPtr(); const MpImage mp_image = { .type = MpImage::IMAGE_FRAME, - .image_frame = { - .format = static_cast( - image->GetImageFrameSharedPtr()->Format()), - .image_buffer = image->GetImageFrameSharedPtr()->PixelData(), - .width = image->GetImageFrameSharedPtr()->Width(), - .height = image->GetImageFrameSharedPtr()->Height()}}; + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; ImageClassifierResult result; image_classifier_classify_image(classifier, &mp_image, &result); @@ -84,6 +85,120 @@ TEST(ImageClassifierTest, SmokeTest) { image_classifier_close(classifier); } +TEST(ImageClassifierTest, VideoModeTest) { + const auto image = DecodeImageFromFile(GetFullPath("burger.jpg")); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + ImageClassifierOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::VIDEO, + /* classifier_options= */ + {/* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0}, + /* result_callback= */ nullptr, + }; + + void* classifier = image_classifier_create(&options); + EXPECT_NE(classifier, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + ImageClassifierResult result; + image_classifier_classify_for_video(classifier, &mp_image, i, &result); + EXPECT_EQ(result.classifications_count, 1); + EXPECT_EQ(result.classifications[0].categories_count, 3); + EXPECT_EQ( + std::string{result.classifications[0].categories[0].category_name}, + "cheeseburger"); + EXPECT_NEAR(result.classifications[0].categories[0].score, 0.7939f, + kPrecision); + image_classifier_close_result(&result); + } + image_classifier_close(classifier); +} + +// A structure to support LiveStreamModeTest below. This structure holds a +// static method `Fn` for a callback function of C API. A `static` qualifier +// allows to take an address of the method to follow API style. Another static +// struct member is `last_timestamp` that is used to verify that current +// timestamp is greater than the previous one. 
+struct LiveStreamModeCallback { + static int64_t last_timestamp; + static void Fn(ImageClassifierResult* classifier_result, const MpImage image, + int64_t timestamp, char* error_msg) { + ASSERT_NE(classifier_result, nullptr); + ASSERT_EQ(error_msg, nullptr); + EXPECT_EQ( + std::string{ + classifier_result->classifications[0].categories[0].category_name}, + "cheeseburger"); + EXPECT_NEAR(classifier_result->classifications[0].categories[0].score, + 0.7939f, kPrecision); + EXPECT_GT(image.image_frame.width, 0); + EXPECT_GT(image.image_frame.height, 0); + EXPECT_GT(timestamp, last_timestamp); + last_timestamp++; + } +}; +int64_t LiveStreamModeCallback::last_timestamp = -1; + +TEST(ImageClassifierTest, LiveStreamModeTest) { + const auto image = DecodeImageFromFile(GetFullPath("burger.jpg")); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + + ImageClassifierOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::LIVE_STREAM, + /* classifier_options= */ + {/* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0}, + /* result_callback= */ LiveStreamModeCallback::Fn, + }; + + void* classifier = image_classifier_create(&options); + EXPECT_NE(classifier, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + EXPECT_GE(image_classifier_classify_async(classifier, &mp_image, i), 0); + } + image_classifier_close(classifier); + + // Due to the flow limiter, the total of outputs might be smaller than the + // number of iterations. + EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations); + EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0); +} + TEST(ImageClassifierTest, InvalidArgumentHandling) { // It is an error to set neither the asset buffer nor the path. 
ImageClassifierOptions options = { @@ -124,7 +239,7 @@ TEST(ImageClassifierTest, FailedClassificationHandling) { ImageClassifierResult result; char* error_msg; image_classifier_classify_image(classifier, &mp_image, &result, &error_msg); - EXPECT_THAT(error_msg, HasSubstr("gpu buffer not supported yet")); + EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet")); free(error_msg); image_classifier_close(classifier); } diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index a45c91633..55cf3fca1 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -98,3 +98,9 @@ mediapipe_proto_library( name = "transformer_params_proto", srcs = ["transformer_params.proto"], ) + +mediapipe_proto_library( + name = "llm_params_proto", + srcs = ["llm_params.proto"], + deps = [":transformer_params_proto"], +) diff --git a/mediapipe/tasks/cc/components/processors/proto/llm_params.proto b/mediapipe/tasks/cc/components/processors/proto/llm_params.proto new file mode 100644 index 000000000..b0c253598 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/llm_params.proto @@ -0,0 +1,41 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.components.processors.proto; + +import "mediapipe/tasks/cc/components/processors/proto/transformer_params.proto"; + +option java_package = "com.google.mediapipe.tasks.components.processors.proto"; +option java_outer_classname = "LLMParametersProto"; + +// Parameters for Large Language Models (LLM). +message LLMParameters { + TransformerParameters transformer_parameters = 1; + + // Size of vocabulary. + int32 vocab_size = 2; + + // Whether or not to disable KV cache, which is also referred as state + // somewhere else. + bool disable_kv_cache = 3; + + // Id of the start token. + int32 start_token_id = 4; + + // Token to determine the end of output stream. + string stop_token = 5; +} diff --git a/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto index b2d13c3a2..a04aa9571 100644 --- a/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto +++ b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto @@ -44,6 +44,21 @@ message TransformerParameters { // Number of stacked transformers, `N` in the paper. int32 num_stacks = 7; - // Whether to use Multi-Query-Attention (MQA). - bool use_mqa = 8; + // Deprecated: bool use_mqa. Use num_kv_heads below. + reserved 8; + + // Number of kv heads. 
0 means Multi-Head-Attention (MHA), key and value have + // same number of heads as query; 1 means Multi-Query-Attention (MQA), key and + // value have one head; otherwise, this specifies the number of heads for key + // and value, and Grouped-Query-Attention (GQA) will be used. See + // https://arxiv.org/pdf/2305.13245.pdf for details. + int32 num_kv_heads = 9; + + // Different types of attention mask type. + enum AttentionMaskType { + UNSPECIFIED = 0; + CAUSAL = 1; + PREFIX = 2; + } + AttentionMaskType attention_mask_type = 10; } diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index fa61feb9d..bb0d4b001 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -264,6 +264,7 @@ cc_library_with_tflite( "//mediapipe/framework:executor", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:name_util", + "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/tasks/cc:common", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/mediapipe/tasks/cc/core/task_runner.cc b/mediapipe/tasks/cc/core/task_runner.cc index 88c91bcdb..e3862ddd7 100644 --- a/mediapipe/tasks/cc/core/task_runner.cc +++ b/mediapipe/tasks/cc/core/task_runner.cc @@ -39,6 +39,10 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_shared_data_internal.h" +#endif // !MEDIAPIPE_DISABLE_GPU + namespace mediapipe { namespace tasks { namespace core { @@ -88,16 +92,34 @@ absl::StatusOr GenerateOutputPacketMap( } // namespace /* static */ +#if !MEDIAPIPE_DISABLE_GPU +absl::StatusOr> TaskRunner::Create( + CalculatorGraphConfig config, + std::unique_ptr op_resolver, + PacketsCallback packets_callback, + std::shared_ptr default_executor, + std::optional input_side_packets, + std::shared_ptr<::mediapipe::GpuResources> resources) { +#else absl::StatusOr> TaskRunner::Create( CalculatorGraphConfig config, std::unique_ptr op_resolver, PacketsCallback packets_callback, std::shared_ptr default_executor, std::optional input_side_packets) { +#endif // !MEDIAPIPE_DISABLE_GPU auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); MP_RETURN_IF_ERROR(task_runner->Initialize( std::move(config), std::move(op_resolver), std::move(default_executor), std::move(input_side_packets))); + +#if !MEDIAPIPE_DISABLE_GPU + if (resources) { + MP_RETURN_IF_ERROR( + task_runner->graph_.SetGpuResources(std::move(resources))); + } +#endif // !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(task_runner->Start()); return task_runner; } diff --git a/mediapipe/tasks/cc/core/task_runner.h b/mediapipe/tasks/cc/core/task_runner.h index 810063d4b..ef48bef55 100644 --- a/mediapipe/tasks/cc/core/task_runner.h +++ b/mediapipe/tasks/cc/core/task_runner.h @@ -42,6 +42,11 @@ limitations under the License. #include "tensorflow/lite/core/api/op_resolver.h" namespace mediapipe { + +#if !MEDIAPIPE_DISABLE_GPU +class GpuResources; +#endif // !MEDIAPIPE_DISABLE_GPU + namespace tasks { namespace core { @@ -72,12 +77,22 @@ class TaskRunner { // asynchronous method, Send(), to provide the input packets. If the packets // callback is absent, clients must use the synchronous method, Process(), to // provide the input packets and receive the output packets. 
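Not from the patch: a sketch of how a caller might use the GPU-aware `Create` overload declared just below, assuming GPU support is compiled in, that `GpuResources::Create()` from `mediapipe/gpu/gpu_shared_data_internal.h` returns an `absl::StatusOr<std::shared_ptr<GpuResources>>`, and that `BuildGraphConfig()` stands in for however the task obtains its `CalculatorGraphConfig`.

```cpp
#include <memory>
#include <optional>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#include "mediapipe/tasks/cc/core/task_runner.h"

// Hypothetical helper that returns the graph config for the task.
mediapipe::CalculatorGraphConfig BuildGraphConfig();

absl::StatusOr<std::unique_ptr<mediapipe::tasks::core::TaskRunner>>
CreateGpuTaskRunner() {
  // Create the GpuResources up front so several runners can share one GPU
  // context instead of each creating their own.
  auto gpu_resources = mediapipe::GpuResources::Create();
  if (!gpu_resources.ok()) return gpu_resources.status();

  return mediapipe::tasks::core::TaskRunner::Create(
      BuildGraphConfig(),
      /*op_resolver=*/nullptr,
      /*packets_callback=*/nullptr,
      /*default_executor=*/nullptr,
      /*input_side_packets=*/std::nullopt,
      /*resources=*/std::move(*gpu_resources));
}
```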
+#if !MEDIAPIPE_DISABLE_GPU + static absl::StatusOr> Create( + CalculatorGraphConfig config, + std::unique_ptr op_resolver = nullptr, + PacketsCallback packets_callback = nullptr, + std::shared_ptr default_executor = nullptr, + std::optional input_side_packets = std::nullopt, + std::shared_ptr<::mediapipe::GpuResources> resources = nullptr); +#else static absl::StatusOr> Create( CalculatorGraphConfig config, std::unique_ptr op_resolver = nullptr, PacketsCallback packets_callback = nullptr, std::shared_ptr default_executor = nullptr, std::optional input_side_packets = std::nullopt); +#endif // !MEDIAPIPE_DISABLE_GPU // TaskRunner is neither copyable nor movable. TaskRunner(const TaskRunner&) = delete; diff --git a/mediapipe/tasks/ios/BUILD b/mediapipe/tasks/ios/BUILD index 7f3db7f7a..88b99ffec 100644 --- a/mediapipe/tasks/ios/BUILD +++ b/mediapipe/tasks/ios/BUILD @@ -57,6 +57,7 @@ CALCULATORS_AND_GRAPHS = [ "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", ] @@ -83,6 +84,7 @@ strip_api_include_path_prefix( "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h", "//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h", "//mediapipe/tasks/ios/vision/core:sources/MPPImage.h", + "//mediapipe/tasks/ios/vision/core:sources/MPPMask.h", "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetector.h", "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorOptions.h", "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorResult.h", @@ -98,6 +100,9 @@ strip_api_include_path_prefix( "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifier.h", "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierOptions.h", "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h", + "//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenter.h", + "//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenterOptions.h", + "//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenterResult.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h", @@ -178,6 +183,7 @@ apple_static_xcframework( ":MPPTaskOptions.h", ":MPPTaskResult.h", ":MPPImage.h", + ":MPPMask.h", ":MPPRunningMode.h", ":MPPFaceDetector.h", ":MPPFaceDetectorOptions.h", @@ -188,6 +194,9 @@ apple_static_xcframework( ":MPPImageClassifier.h", ":MPPImageClassifierOptions.h", ":MPPImageClassifierResult.h", + ":MPPImageSegmenter.h", + ":MPPImageSegmenterOptions.h", + ":MPPImageSegmenterResult.h", ":MPPHandLandmarker.h", ":MPPHandLandmarkerOptions.h", ":MPPHandLandmarkerResult.h", @@ -204,6 +213,7 @@ apple_static_xcframework( "//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer", "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker", "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier", + "//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenter", "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetector", ], ) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h 
b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index bef6bb9ee..eecb5e14e 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -14,6 +14,15 @@ #import +/** + * The delegate to run MediaPipe. If the delegate is not set, the default + * delegate CPU is used. + */ +typedef NS_ENUM(NSUInteger, MPPDelegate) { + MPPDelegateCPU, + MPPDelegateGPU, +} NS_SWIFT_NAME(Delegate); + NS_ASSUME_NONNULL_BEGIN /** @@ -26,6 +35,9 @@ NS_SWIFT_NAME(BaseOptions) /** The path to the model asset to open and mmap in memory. */ @property(nonatomic, copy) NSString *modelAssetPath; +/** Overrides the default backend to use for the provided model. */ +@property(nonatomic) MPPDelegate delegate; + @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m index a43119ad8..fac1b94c0 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -20,6 +20,7 @@ self = [super init]; if (self) { self.modelAssetPath = [[NSString alloc] init]; + self.delegate = MPPDelegateCPU; } return self; } @@ -28,6 +29,7 @@ MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; baseOptions.modelAssetPath = self.modelAssetPath; + baseOptions.delegate = self.delegate; return baseOptions; } diff --git a/mediapipe/tasks/ios/core/utils/BUILD b/mediapipe/tasks/ios/core/utils/BUILD index 3cd8bf231..d5a166eb3 100644 --- a/mediapipe/tasks/ios/core/utils/BUILD +++ b/mediapipe/tasks/ios/core/utils/BUILD @@ -21,6 +21,7 @@ objc_library( srcs = ["sources/MPPBaseOptions+Helpers.mm"], hdrs = ["sources/MPPBaseOptions+Helpers.h"], deps = [ + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm index 73bcac49d..9b2307c7e 100644 --- a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" namespace { using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; +using InferenceCalculatorOptionsProto = ::mediapipe::InferenceCalculatorOptions; } @implementation MPPBaseOptions (Helpers) @@ -33,6 +35,11 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; if (self.modelAssetPath) { baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); } + + if (self.delegate == MPPDelegateGPU) { + baseOptionsProto->mutable_acceleration()->mutable_gpu()->MergeFrom( + InferenceCalculatorOptionsProto::Delegate::Gpu()); + } } @end diff --git a/mediapipe/tasks/ios/text/language_detector/BUILD b/mediapipe/tasks/ios/text/language_detector/BUILD index 3b59fbd59..4df278037 100644 --- a/mediapipe/tasks/ios/text/language_detector/BUILD +++ b/mediapipe/tasks/ios/text/language_detector/BUILD @@ -31,3 +31,28 @@ objc_library( "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) + +objc_library( + name = "MPPLanguageDetector", + srcs = ["sources/MPPLanguageDetector.mm"], + hdrs = ["sources/MPPLanguageDetector.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + module_name = "MPPLanguageDetector", + deps = [ + ":MPPLanguageDetectorOptions", + ":MPPLanguageDetectorResult", + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorOptionsHelpers", + "//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorResultHelpers", + ], +) diff --git a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h new file mode 100644 index 000000000..7213a8e5f --- /dev/null +++ b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h @@ -0,0 +1,88 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.h" +#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Predicts the language of an input text. 
+ * + * This API expects a TFLite model with [TFLite Model + * Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory + * (described below) input tensor, output tensor, and the language codes in an AssociatedFile. + * + * Metadata is required for models with int32 input tensors because it contains the input + * process unit for the model's Tokenizer. No metadata is required for models with string + * input tensors. + * + * Input tensor + * - One input tensor (`kTfLiteString`) of shape `[1]` containing the input string. + * + * Output tensor + * - One output tensor (`kTfLiteFloat32`) of shape `[1 x N]` where `N` is the number of languages. + */ +NS_SWIFT_NAME(LanguageDetector) +@interface MPPLanguageDetector : NSObject + +/** + * Creates a new instance of `LanguageDetector` from an absolute path to a TensorFlow Lite + * model file stored locally on the device and the default `LanguageDetectorOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * language detector. + * + * @return A new instance of `LanguageDetector` with the given model path. `nil` if there is an + * error in initializing the language detector. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `LanguageDetector` from the given `LanguageDetectorOptions`. + * + * @param options The options of type `LanguageDetectorOptions` to use for configuring the + * `LanguageDetector`. + * @param error An optional error parameter populated when there is an error in initializing the + * language detector. + * + * @return A new instance of `LanguageDetector` with the given options. `nil` if there is an + * error in initializing the language detector. + */ +- (nullable instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Predicts the language of the input text. + * + * @param text The `NSString` for which language is to be predicted. + * @param error An optional error parameter populated when there is an error in performing + * language prediction on the input text. + * + * @return A `LanguageDetectorResult` object that contains a list of language predictions. + */ +- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(detect(text:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.mm b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.mm new file mode 100644 index 000000000..4c9628c82 --- /dev/null +++ b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.mm @@ -0,0 +1,96 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = + @"mediapipe.tasks.text.language_detector.LanguageDetectorGraph"; + +@interface MPPLanguageDetector () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPLanguageDetector + +- (instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPLanguageDetectorOptions *options = [[MPPLanguageDetectorOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; + + std::map packetMap = {{kTextInStreamName.cppString, packet}}; + std::optional outputPacketMap = [_textTaskRunner processPacketMap:packetMap + error:error]; + + if (!outputPacketMap.has_value()) { + return nil; + } + + return + [MPPLanguageDetectorResult languageDetectorResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; +} + +@end diff --git a/mediapipe/tasks/ios/text/language_detector/utils/BUILD b/mediapipe/tasks/ios/text/language_detector/utils/BUILD index 00a37e940..74de385c0 100644 --- a/mediapipe/tasks/ios/text/language_detector/utils/BUILD +++ b/mediapipe/tasks/ios/text/language_detector/utils/BUILD @@ -30,3 +30,15 @@ objc_library( "//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetectorOptions", ], ) + +objc_library( + name = "MPPLanguageDetectorResultHelpers", + srcs = ["sources/MPPLanguageDetectorResult+Helpers.mm"], + hdrs = ["sources/MPPLanguageDetectorResult+Helpers.h"], + deps = [ + "//mediapipe/framework:packet", + 
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetectorResult", + ], +) diff --git a/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h new file mode 100644 index 000000000..87431d157 --- /dev/null +++ b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorResult.h" + +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPLanguageDetectorResult (Helpers) + ++ (MPPLanguageDetectorResult *)languageDetectorResultWithClassificationsPacket: + (const mediapipe::Packet &)packet; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.mm b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.mm new file mode 100644 index 000000000..3fcbe0b80 --- /dev/null +++ b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.mm @@ -0,0 +1,61 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +static const int kMicroSecondsPerMilliSecond = 1000; + +namespace { +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +} // namespace + +#define int kMicroSecondsPerMilliSecond = 1000; + +@implementation MPPLanguageDetectorResult (Helpers) + ++ (MPPLanguageDetectorResult *)languageDetectorResultWithClassificationsPacket: + (const mediapipe::Packet &)packet { + MPPClassificationResult *classificationResult = [MPPClassificationResult + classificationResultWithProto:packet.Get()]; + + return [MPPLanguageDetectorResult + languageDetectorResultWithClassificationResult:classificationResult + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; +} + ++ (MPPLanguageDetectorResult *) + languageDetectorResultWithClassificationResult:(MPPClassificationResult *)classificationResult + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + NSMutableArray *languagePredictions = + [NSMutableArray arrayWithCapacity:classificationResult.classifications.count]; + + if (classificationResult.classifications.count > 0) { + for (MPPCategory *category in classificationResult.classifications[0].categories) { + MPPLanguagePrediction *languagePrediction = + [[MPPLanguagePrediction alloc] initWithLanguageCode:category.categoryName + probability:category.score]; + [languagePredictions addObject:languagePrediction]; + } + } + + return [[MPPLanguageDetectorResult alloc] initWithLanguagePredictions:languagePredictions + timestampInMilliseconds:timestampInMilliseconds]; +} + +@end diff --git a/mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.mm b/mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.mm index 440b321b9..e80d91253 100644 --- a/mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.mm +++ b/mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.mm @@ -37,7 +37,7 @@ vImage_Buffer allocatedVImageBuffer(vImagePixelCount width, vImagePixelCount hei } static void FreeDataProviderReleaseCallback(void *buffer, const void *data, size_t size) { - delete (vImage_Buffer *)buffer; + delete[] (vImage_Buffer *)buffer; } } // namespace diff --git a/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.mm index 885df734d..c4c3d398c 100644 --- a/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.mm +++ b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.mm @@ -47,7 +47,7 @@ using ::mediapipe::Packet; ->PixelData() width:confidenceMask.width() height:confidenceMask.height() - shouldCopy:shouldCopyMaskPacketData ? YES : NO]]; + shouldCopy:shouldCopyMaskPacketData]]; } } @@ -57,7 +57,7 @@ using ::mediapipe::Packet; initWithUInt8Data:(UInt8 *)cppCategoryMask.GetImageFrameSharedPtr().get()->PixelData() width:cppCategoryMask.width() height:cppCategoryMask.height() - shouldCopy:shouldCopyMaskPacketData ? 
YES : NO]; + shouldCopy:shouldCopyMaskPacketData]; } if (qualityScoresPacket.ValidateAsType>().ok()) { diff --git a/mediapipe/tasks/ios/vision/pose_landmarker/BUILD b/mediapipe/tasks/ios/vision/pose_landmarker/BUILD index 16e791a97..a7b612bce 100644 --- a/mediapipe/tasks/ios/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/ios/vision/pose_landmarker/BUILD @@ -37,3 +37,23 @@ objc_library( "//mediapipe/tasks/ios/vision/core:MPPRunningMode", ], ) + +objc_library( + name = "MPPPoseLandmarksConnections", + hdrs = ["sources/MPPPoseLandmarksConnections.h"], + module_name = "MPPPoseLandmarksConnections", + deps = ["//mediapipe/tasks/ios/components/containers:MPPConnection"], +) + +objc_library( + name = "MPPPoseLandmarker", + hdrs = ["sources/MPPPoseLandmarker.h"], + module_name = "MPPPoseLandmarker", + deps = [ + ":MPPPoseLandmarkerOptions", + ":MPPPoseLandmarkerResult", + ":MPPPoseLandmarksConnections", + "//mediapipe/tasks/ios/components/containers:MPPConnection", + "//mediapipe/tasks/ios/vision/core:MPPImage", + ], +) diff --git a/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarker.h b/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarker.h new file mode 100644 index 000000000..d70d1f129 --- /dev/null +++ b/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarker.h @@ -0,0 +1,160 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" +#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerOptions.h" +#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Performs pose landmarks detection on images. + * + * This API expects a pre-trained pose landmarks model asset bundle. + */ +NS_SWIFT_NAME(PoseLandmarker) +@interface MPPPoseLandmarker : NSObject + +/** The array of connections between all the landmarks in the detected pose. */ +@property(class, nonatomic, readonly) NSArray *poseLandmarks; + +/** + * Creates a new instance of `PoseLandmarker` from an absolute path to a model asset bundle stored + * locally on the device and the default `PoseLandmarkerOptions`. + * + * @param modelPath An absolute path to a model asset bundle stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * pose landmarker. + * + * @return A new instance of `PoseLandmarker` with the given model path. `nil` if there is an error + * in initializing the pose landmarker. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `PoseLandmarker` from the given `PoseLandmarkerOptions`. + * + * @param options The options of type `PoseLandmarkerOptions` to use for configuring the + * `PoseLandmarker`. 
+ * @param error An optional error parameter populated when there is an error in initializing the + * pose landmarker. + * + * @return A new instance of `PoseLandmarker` with the given options. `nil` if there is an error in + * initializing the pose landmarker. + */ +- (nullable instancetype)initWithOptions:(MPPPoseLandmarkerOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs pose landmarks detection on the provided `MPImage` using the whole image as region of + * interest. Rotation will be applied according to the `orientation` property of the provided + * `MPImage`. Only use this method when the `PoseLandmarker` is created with running mode `.image`. + * + * This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a + * source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use + * `kCVPixelFormatType_32BGRA` as its pixel format. + * + * + * If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha + * channel. + * + * @param image The `MPImage` on which pose landmarks detection is to be performed. + * @param error An optional error parameter populated when there is an error in performing pose + * landmark detection on the input image. + * + * @return An `PoseLandmarkerResult` object that contains the pose landmarks detection + * results. + */ +- (nullable MPPPoseLandmarkerResult *)detectImage:(MPPImage *)image + error:(NSError **)error NS_SWIFT_NAME(detect(image:)); + +/** + * Performs pose landmarks detection on the provided video frame of type `MPImage` using the whole + * image as region of interest. Rotation will be applied according to the `orientation` property of + * the provided `MPImage`. Only use this method when the `PoseLandmarker` is created with running + * mode `.video`. + * + * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must + * be monotonically increasing. + * + * This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a + * source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use + * `kCVPixelFormatType_32BGRA` as its pixel format. + * + * + * If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha + * channel. + * + * @param image The `MPImage` on which pose landmarks detection is to be performed. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. + * @param error An optional error parameter populated when there is an error in performing pose + * landmark detection on the input video frame. + * + * @return An `PoseLandmarkerResult` object that contains the pose landmarks detection + * results. + */ +- (nullable MPPPoseLandmarkerResult *)detectVideoFrame:(MPPImage *)image + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:)); + +/** + * Sends live stream image data of type `MPImage` to perform pose landmarks detection using the + * whole image as region of interest. Rotation will be applied according to the `orientation` + * property of the provided `MPImage`. Only use this method when the `PoseLandmarker` is created + * with running mode`.liveStream`. 
+ * + * The object which needs to be continuously notified of the available results of pose landmark + * detection must conform to the `PoseLandmarkerLiveStreamDelegate` protocol and implement the + * `poseLandmarker(_:didFinishDetectionWithResult:timestampInMilliseconds:error:)` delegate method. + * + * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent + * to the pose landmarker. The input timestamps must be monotonically increasing. + * + * This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a + * source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use + * `kCVPixelFormatType_32BGRA` as its pixel format. + * + * If the input `MPImage` has a source type of `.image` ensure that the color space is RGB with an + * Alpha channel. + * + * If this method is used for performing pose landmarks detection on live camera frames using + * `AVFoundation`, ensure that you request `AVCaptureVideoDataOutput` to output frames in + * `kCMPixelFormat_32BGRA` using its `videoSettings` property. + * + * @param image Live stream image data of type `MPImage` on which pose landmarks detection is to + * be performed. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the pose landmarker. The input timestamps must be monotonically increasing. + * @param error An optional error parameter populated when there is an error in performing pose + * landmark detection on the input live stream image data. + * + * @return `YES` if the image was sent to the task successfully, otherwise `NO`. + */ +- (BOOL)detectAsyncImage:(MPPImage *)image + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h b/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h index ff9d001b2..b3dc72e8f 100644 --- a/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h +++ b/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h @@ -46,7 +46,7 @@ NS_SWIFT_NAME(PoseLandmarkerResult) */ - (instancetype)initWithLandmarks:(NSArray *> *)landmarks worldLandmarks:(NSArray *> *)worldLandmarks - segmentationMasks:(NSArray *)segmentationMasks + segmentationMasks:(nullable NSArray *)segmentationMasks timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER; - (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarksConnections.h b/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarksConnections.h new file mode 100644 index 000000000..71dcad6b8 --- /dev/null +++ b/mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarksConnections.h @@ -0,0 +1,40 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> +#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h" + +NS_ASSUME_NONNULL_BEGIN + +NSArray<MPPConnection *> *const MPPPoseLandmarksConnections = @[ + [[MPPConnection alloc] initWithStart:0 end:1], [[MPPConnection alloc] initWithStart:1 end:2], + [[MPPConnection alloc] initWithStart:2 end:3], [[MPPConnection alloc] initWithStart:3 end:7], + [[MPPConnection alloc] initWithStart:0 end:4], [[MPPConnection alloc] initWithStart:4 end:5], + [[MPPConnection alloc] initWithStart:5 end:6], [[MPPConnection alloc] initWithStart:6 end:8], + [[MPPConnection alloc] initWithStart:9 end:10], [[MPPConnection alloc] initWithStart:11 end:12], + [[MPPConnection alloc] initWithStart:11 end:13], [[MPPConnection alloc] initWithStart:13 end:15], + [[MPPConnection alloc] initWithStart:15 end:17], [[MPPConnection alloc] initWithStart:15 end:19], + [[MPPConnection alloc] initWithStart:15 end:21], [[MPPConnection alloc] initWithStart:17 end:19], + [[MPPConnection alloc] initWithStart:12 end:14], [[MPPConnection alloc] initWithStart:14 end:16], + [[MPPConnection alloc] initWithStart:16 end:18], [[MPPConnection alloc] initWithStart:16 end:20], + [[MPPConnection alloc] initWithStart:16 end:22], [[MPPConnection alloc] initWithStart:18 end:20], + [[MPPConnection alloc] initWithStart:11 end:23], [[MPPConnection alloc] initWithStart:12 end:24], + [[MPPConnection alloc] initWithStart:23 end:24], [[MPPConnection alloc] initWithStart:23 end:25], + [[MPPConnection alloc] initWithStart:26 end:28], [[MPPConnection alloc] initWithStart:27 end:29], + [[MPPConnection alloc] initWithStart:28 end:30], [[MPPConnection alloc] initWithStart:29 end:31], + [[MPPConnection alloc] initWithStart:30 end:32], [[MPPConnection alloc] initWithStart:27 end:31], + [[MPPConnection alloc] initWithStart:28 end:32] +]; + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/pose_landmarker/utils/BUILD b/mediapipe/tasks/ios/vision/pose_landmarker/utils/BUILD index 94c945bc1..ee5e15bf9 100644 --- a/mediapipe/tasks/ios/vision/pose_landmarker/utils/BUILD +++ b/mediapipe/tasks/ios/vision/pose_landmarker/utils/BUILD @@ -36,3 +36,21 @@ objc_library( "//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarkerOptions", ], ) + +objc_library( + name = "MPPPoseLandmarkerResultHelpers", + srcs = ["sources/MPPPoseLandmarkerResult+Helpers.mm"], + hdrs = ["sources/MPPPoseLandmarkerResult+Helpers.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/tasks/ios/components/containers/utils:MPPLandmarkHelpers", + "//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarkerResult", + ], +) diff --git a/mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.h b/mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.h new file mode 100644 index 000000000..0b20dc32a --- /dev/null +++ b/mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.h @@ -0,0 +1,63 @@ 
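The MPPPoseLandmarksConnections list above encodes the pose skeleton as (start, end) pairs of landmark indices into the 33-point pose model. As a rough illustration of how such index pairs are typically consumed, here is a minimal Python sketch; the pair subset and the `landmarks` sequence of (x, y) tuples are hypothetical and not part of this patch:

# Illustrative sketch only: turning (start, end) landmark index pairs, like the
# MPPPoseLandmarksConnections entries above, into line segments for drawing.
POSE_CONNECTION_SUBSET = [(0, 1), (1, 2), (2, 3), (3, 7), (0, 4), (4, 5), (5, 6), (6, 8)]

def connection_segments(landmarks, connections=POSE_CONNECTION_SUBSET):
  """Returns one ((x0, y0), (x1, y1)) segment per (start, end) index pair."""
  return [(landmarks[start], landmarks[end]) for start, end in connections]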
+// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h" + +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPPoseLandmarkerResult (Helpers) + +/** + * Creates an `MPPPoseLandmarkerResult` from landmarks, world landmarks and segmentation mask + * packets. + * + * @param landmarksPacket A MediaPipe packet wrapping a `std::vector<mediapipe::NormalizedLandmarkList>`. + * @param worldLandmarksPacket A MediaPipe packet wrapping a `std::vector<mediapipe::LandmarkList>`. + * @param segmentationMasksPacket A MediaPipe packet wrapping a `std::vector<mediapipe::Image>`. + * + * @return An `MPPPoseLandmarkerResult` object that contains the pose landmark detection + * results. + */ ++ (MPPPoseLandmarkerResult *) + poseLandmarkerResultWithLandmarksPacket:(const mediapipe::Packet &)landmarksPacket + worldLandmarksPacket:(const mediapipe::Packet &)worldLandmarksPacket + segmentationMasksPacket:(const mediapipe::Packet *)segmentationMasksPacket; + +/** + * Creates an `MPPPoseLandmarkerResult` from landmarks, world landmarks and segmentation mask + * images. + * + * @param landmarksProto A vector of protos of type `std::vector<mediapipe::NormalizedLandmarkList>`. + * @param worldLandmarksProto A vector of protos of type `std::vector<mediapipe::LandmarkList>`. + * @param segmentationMasks A vector of type `std::vector<mediapipe::Image>`. + * @param timestampInMilliSeconds The timestamp of the Packet that contained the result. + * + * @return An `MPPPoseLandmarkerResult` object that contains the pose landmark detection + * results. + */ ++ (MPPPoseLandmarkerResult *) + poseLandmarkerResultWithLandmarksProto: + (const std::vector<::mediapipe::NormalizedLandmarkList> &)landmarksProto + worldLandmarksProto: + (const std::vector<::mediapipe::LandmarkList> &)worldLandmarksProto + segmentationMasks:(const std::vector<mediapipe::Image> *)segmentationMasks + timestampInMilliSeconds:(NSInteger)timestampInMilliseconds; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.mm b/mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.mm new file mode 100644 index 000000000..6cd67ff9c --- /dev/null +++ b/mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.mm @@ -0,0 +1,124 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.h" + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPLandmark+Helpers.h" + +namespace { +using LandmarkListProto = ::mediapipe::LandmarkList; +using NormalizedLandmarkListProto = ::mediapipe::NormalizedLandmarkList; +using ::mediapipe::Image; +using ::mediapipe::Packet; + +static const int kMicroSecondsPerMilliSecond = 1000; +} // namespace + +@implementation MPPPoseLandmarkerResult (Helpers) + ++ (MPPPoseLandmarkerResult *)emptyPoseLandmarkerResultWithTimestampInMilliseconds: + (NSInteger)timestampInMilliseconds { + return [[MPPPoseLandmarkerResult alloc] initWithLandmarks:@[] + worldLandmarks:@[] + segmentationMasks:@[] + timestampInMilliseconds:timestampInMilliseconds]; +} + ++ (MPPPoseLandmarkerResult *) + poseLandmarkerResultWithLandmarksProto: + (const std::vector &)landmarksProto + worldLandmarksProto: + (const std::vector &)worldLandmarksProto + segmentationMasks:(const std::vector *)segmentationMasks + timestampInMilliSeconds:(NSInteger)timestampInMilliseconds { + NSMutableArray *> *multiplePoseLandmarks = + [NSMutableArray arrayWithCapacity:(NSUInteger)landmarksProto.size()]; + + for (const auto &landmarkListProto : landmarksProto) { + NSMutableArray *landmarks = + [NSMutableArray arrayWithCapacity:(NSUInteger)landmarkListProto.landmark().size()]; + for (const auto &normalizedLandmarkProto : landmarkListProto.landmark()) { + MPPNormalizedLandmark *normalizedLandmark = + [MPPNormalizedLandmark normalizedLandmarkWithProto:normalizedLandmarkProto]; + [landmarks addObject:normalizedLandmark]; + } + [multiplePoseLandmarks addObject:landmarks]; + } + + NSMutableArray *> *multiplePoseWorldLandmarks = + [NSMutableArray arrayWithCapacity:(NSUInteger)worldLandmarksProto.size()]; + + for (const auto &worldLandmarkListProto : worldLandmarksProto) { + NSMutableArray *worldLandmarks = + [NSMutableArray arrayWithCapacity:(NSUInteger)worldLandmarkListProto.landmark().size()]; + for (const auto &landmarkProto : worldLandmarkListProto.landmark()) { + MPPLandmark *landmark = [MPPLandmark landmarkWithProto:landmarkProto]; + [worldLandmarks addObject:landmark]; + } + [multiplePoseWorldLandmarks addObject:worldLandmarks]; + } + + NSMutableArray *confidenceMasks = + [NSMutableArray arrayWithCapacity:(NSUInteger)segmentationMasks->size()]; + + for (const auto &segmentationMask : *segmentationMasks) { + [confidenceMasks addObject:[[MPPMask alloc] initWithFloat32Data:(float *)segmentationMask + .GetImageFrameSharedPtr() + .get() + ->PixelData() + width:segmentationMask.width() + height:segmentationMask.height() + /** Always deep copy */ + shouldCopy:YES]]; + } + + MPPPoseLandmarkerResult *poseLandmarkerResult = + [[MPPPoseLandmarkerResult alloc] initWithLandmarks:multiplePoseLandmarks + worldLandmarks:multiplePoseWorldLandmarks + segmentationMasks:confidenceMasks + timestampInMilliseconds:timestampInMilliseconds]; + return poseLandmarkerResult; +} + ++ (MPPPoseLandmarkerResult *) + poseLandmarkerResultWithLandmarksPacket:(const Packet &)landmarksPacket + worldLandmarksPacket:(const Packet &)worldLandmarksPacket + segmentationMasksPacket:(const Packet *)segmentationMasksPacket { + NSInteger timestampInMilliseconds = + (NSInteger)(landmarksPacket.Timestamp().Value() / kMicroSecondsPerMilliSecond); + + if (landmarksPacket.IsEmpty()) { + return [MPPPoseLandmarkerResult + emptyPoseLandmarkerResultWithTimestampInMilliseconds:timestampInMilliseconds]; + } + + if 
(!landmarksPacket.ValidateAsType<std::vector<NormalizedLandmarkListProto>>().ok() || + !worldLandmarksPacket.ValidateAsType<std::vector<LandmarkListProto>>().ok()) { + return [MPPPoseLandmarkerResult + emptyPoseLandmarkerResultWithTimestampInMilliseconds:timestampInMilliseconds]; + } + + const std::vector<Image> *segmentationMasks = + segmentationMasksPacket ? &(segmentationMasksPacket->Get<std::vector<Image>>()) : nullptr; + + return [MPPPoseLandmarkerResult + poseLandmarkerResultWithLandmarksProto:landmarksPacket + .Get<std::vector<NormalizedLandmarkListProto>>() + worldLandmarksProto:worldLandmarksPacket + .Get<std::vector<LandmarkListProto>>() + segmentationMasks:segmentationMasks + timestampInMilliSeconds:timestampInMilliseconds]; +} + +@end diff --git a/mediapipe/tasks/python/core/base_options.py b/mediapipe/tasks/python/core/base_options.py index 2d4258fed..da81bcd5d 100644 --- a/mediapipe/tasks/python/core/base_options.py +++ b/mediapipe/tasks/python/core/base_options.py @@ -70,7 +70,7 @@ class BaseOptions: platform_name = platform.system() if self.delegate == BaseOptions.Delegate.GPU: - if platform_name == 'Linux': + if platform_name in ['Linux', 'Darwin']: acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu()) else: raise NotImplementedError( diff --git a/mediapipe/tasks/python/core/pybind/BUILD b/mediapipe/tasks/python/core/pybind/BUILD index 88ea05f4f..391712f27 100644 --- a/mediapipe/tasks/python/core/pybind/BUILD +++ b/mediapipe/tasks/python/core/pybind/BUILD @@ -26,9 +26,11 @@ pybind_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/python/pybind:util", "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", "//mediapipe/tasks/cc/core:task_runner", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], diff --git a/mediapipe/tasks/python/core/pybind/task_runner.cc b/mediapipe/tasks/python/core/pybind/task_runner.cc index f95cddde8..0de7d24d8 100644 --- a/mediapipe/tasks/python/core/pybind/task_runner.cc +++ b/mediapipe/tasks/python/core/pybind/task_runner.cc @@ -14,6 +14,7 @@ #include "mediapipe/tasks/python/core/pybind/task_runner.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/python/pybind/util.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" @@ -21,6 +22,9 @@ #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" #include "tensorflow/lite/core/api/op_resolver.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_shared_data_internal.h" +#endif // MEDIAPIPE_DISABLE_GPU namespace mediapipe { namespace tasks { @@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc"); return absl::OkStatus(); }; } + +#if !MEDIAPIPE_DISABLE_GPU + auto gpu_resources_ = mediapipe::GpuResources::Create(); + if (!gpu_resources_.ok()) { + ABSL_LOG(INFO) << "GPU support is not available: " + << gpu_resources_.status(); + gpu_resources_ = nullptr; + } + auto task_runner = TaskRunner::Create( + std::move(graph_config), + absl::make_unique<core::MediaPipeBuiltinOpResolver>(), + std::move(callback), + /* default_executor= */ nullptr, + /* input_side_packets= */ std::nullopt, std::move(*gpu_resources_)); +#else auto task_runner = TaskRunner::Create( std::move(graph_config), absl::make_unique<core::MediaPipeBuiltinOpResolver>(), std::move(callback)); +#endif // !MEDIAPIPE_DISABLE_GPU + RaisePyErrorIfNotOk(task_runner.status()); return std::move(*task_runner); }, diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 
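The base_options.py change above makes the Python Tasks API map the GPU delegate to the Acceleration proto on macOS (Darwin) as well as Linux, and task_runner.cc now plumbs GPU resources into the task runner when GPU support is compiled in. A minimal usage sketch follows; the model path is a placeholder, and any vision task (for example PoseLandmarker) consumes these options the same way:

# Sketch only: requesting the GPU delegate that base_options.py now accepts on
# Linux and Darwin. 'model.task' is a placeholder model asset path.
from mediapipe.tasks.python.core import base_options as base_options_module

options = base_options_module.BaseOptions(
    model_asset_path='model.task',
    delegate=base_options_module.BaseOptions.Delegate.GPU,
)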
ae3d53d61..c6fae0e6c 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -211,3 +211,20 @@ py_test( "//mediapipe/tasks/python/vision/core:image_processing_options", ], ) + +py_test( + name = "face_stylizer_test", + srcs = ["face_stylizer_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:face_stylizer", + "//mediapipe/tasks/python/vision/core:image_processing_options", + ], +) diff --git a/mediapipe/tasks/python/test/vision/face_stylizer_test.py b/mediapipe/tasks/python/test/vision/face_stylizer_test.py new file mode 100644 index 000000000..1f6b35db4 --- /dev/null +++ b/mediapipe/tasks/python/test/vision/face_stylizer_test.py @@ -0,0 +1,191 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for face stylizer.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import rect +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import face_stylizer +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module + + +_BaseOptions = base_options_module.BaseOptions +_Rect = rect.Rect +_Image = image_module.Image +_FaceStylizer = face_stylizer.FaceStylizer +_FaceStylizerOptions = face_stylizer.FaceStylizerOptions +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions + +_MODEL = 'face_stylizer_color_ink.task' +_LARGE_FACE_IMAGE = 'portrait.jpg' +_MODEL_IMAGE_SIZE = 256 +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class FaceStylizerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE) + ) + ) + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _MODEL) + ) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _FaceStylizer.create_from_model_path(self.model_path) as stylizer: + self.assertIsInstance(stylizer, _FaceStylizer) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. 
+ base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceStylizerOptions(base_options=base_options) + with _FaceStylizer.create_from_options(options) as stylizer: + self.assertIsInstance(stylizer, _FaceStylizer) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite' + ) + options = _FaceStylizerOptions(base_options=base_options) + _FaceStylizer.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _FaceStylizerOptions(base_options=base_options) + stylizer = _FaceStylizer.create_from_options(options) + self.assertIsInstance(stylizer, _FaceStylizer) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE), + (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE), + ) + def test_stylize(self, model_file_type, image_file_name): + # Load the test image. + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, image_file_name) + ) + ) + # Creates stylizer. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceStylizerOptions(base_options=base_options) + stylizer = _FaceStylizer.create_from_options(options) + + # Performs face stylization on the input. + stylized_image = stylizer.stylize(self.test_image) + self.assertIsInstance(stylized_image, _Image) + # Closes the stylizer explicitly when the stylizer is not used in + # a context. + stylizer.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE), + (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE), + ) + def test_stylize_in_context(self, model_file_type, image_file_name): + # Load the test image. + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, image_file_name) + ) + ) + # Creates stylizer. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceStylizerOptions(base_options=base_options) + with _FaceStylizer.create_from_options(options) as stylizer: + # Performs face stylization on the input. 
+ stylized_image = stylizer.stylize(self.test_image) + self.assertIsInstance(stylized_image, _Image) + self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE) + self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE) + + def test_stylize_succeeds_with_region_of_interest(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceStylizerOptions(base_options=base_options) + with _FaceStylizer.create_from_options(options) as stylizer: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE) + ) + ) + # Region-of-interest around the face. + roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32) + image_processing_options = _ImageProcessingOptions(roi) + # Performs face stylization on the input. + stylized_image = stylizer.stylize(test_image, image_processing_options) + self.assertIsInstance(stylized_image, _Image) + self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE) + self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE) + + def test_stylize_succeeds_with_no_face_detected(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceStylizerOptions(base_options=base_options) + with _FaceStylizer.create_from_options(options) as stylizer: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE) + ) + ) + # Region-of-interest that doesn't contain a human face. + roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2) + image_processing_options = _ImageProcessingOptions(roi) + # Performs face stylization on the input. + stylized_image = stylizer.stylize(test_image, image_processing_options) + self.assertIsNone(stylized_image) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 6e10663e0..3f83118b0 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -48,6 +48,7 @@ mediapipe_files(srcs = [ "face_landmark.tflite", "face_landmarker.task", "face_landmarker_v2.task", + "face_stylizer_color_ink.task", "fist.jpg", "fist.png", "gesture_recognizer.task", @@ -183,6 +184,7 @@ filegroup( "face_detection_short_range.tflite", "face_landmarker.task", "face_landmarker_v2.task", + "face_stylizer_color_ink.task", "hair_segmentation.tflite", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", diff --git a/mediapipe/tasks/testdata/vision/male_full_height_hands_result_cpu.pbtxt b/mediapipe/tasks/testdata/vision/male_full_height_hands_result_cpu.pbtxt index 199dc6366..e50f777c5 100644 --- a/mediapipe/tasks/testdata/vision/male_full_height_hands_result_cpu.pbtxt +++ b/mediapipe/tasks/testdata/vision/male_full_height_hands_result_cpu.pbtxt @@ -2854,7 +2854,262 @@ auxiliary_landmarks { face_blendshapes { classification { index: 0 - score: 1.6770242e-05 - label: "tongueOut" + score: 8.47715e-07 + label: "_neutral" + } + classification { + index: 1 + score: 0.020850565 + label: "browDownLeft" + } + classification { + index: 2 + score: 0.007629181 + label: "browDownRight" + } + classification { + index: 3 + score: 0.26410568 + label: "browInnerUp" + } + classification { + index: 4 + score: 0.04212071 + label: "browOuterUpLeft" + } + classification { + index: 5 + score: 0.07319052 + label: "browOuterUpRight" + } + classification { + index: 6 + score: 9.39117e-06 + label: "cheekPuff" + } + classification { + index: 7 + score: 
1.9243858e-07 + label: "cheekSquintLeft" + } + classification { + index: 8 + score: 4.066475e-08 + label: "cheekSquintRight" + } + classification { + index: 9 + score: 0.46092203 + label: "eyeBlinkLeft" + } + classification { + index: 10 + score: 0.40371567 + label: "eyeBlinkRight" + } + classification { + index: 11 + score: 0.65011656 + label: "eyeLookDownLeft" + } + classification { + index: 12 + score: 0.6423024 + label: "eyeLookDownRight" + } + classification { + index: 13 + score: 0.04721973 + label: "eyeLookInLeft" + } + classification { + index: 14 + score: 0.08176838 + label: "eyeLookInRight" + } + classification { + index: 15 + score: 0.09520102 + label: "eyeLookOutLeft" + } + classification { + index: 16 + score: 0.07271895 + label: "eyeLookOutRight" + } + classification { + index: 17 + score: 0.011193463 + label: "eyeLookUpLeft" + } + classification { + index: 18 + score: 0.007041815 + label: "eyeLookUpRight" + } + classification { + index: 19 + score: 0.27120194 + label: "eyeSquintLeft" + } + classification { + index: 20 + score: 0.21675573 + label: "eyeSquintRight" + } + classification { + index: 21 + score: 0.0018824162 + label: "eyeWideLeft" + } + classification { + index: 22 + score: 0.0011966582 + label: "eyeWideRight" + } + classification { + index: 23 + score: 1.9298719e-05 + label: "jawForward" + } + classification { + index: 24 + score: 9.670858e-06 + label: "jawLeft" + } + classification { + index: 25 + score: 0.000115385694 + label: "jawOpen" + } + classification { + index: 26 + score: 0.00023342477 + label: "jawRight" + } + classification { + index: 27 + score: 2.8894076e-05 + label: "mouthClose" + } + classification { + index: 28 + score: 0.003933548 + label: "mouthDimpleLeft" + } + classification { + index: 29 + score: 0.0051949574 + label: "mouthDimpleRight" + } + classification { + index: 30 + score: 0.00067943585 + label: "mouthFrownLeft" + } + classification { + index: 31 + score: 0.0006520291 + label: "mouthFrownRight" + } + classification { + index: 32 + score: 0.0006695333 + label: "mouthFunnel" + } + classification { + index: 33 + score: 8.578597e-05 + label: "mouthLeft" + } + classification { + index: 34 + score: 2.6707421e-05 + label: "mouthLowerDownLeft" + } + classification { + index: 35 + score: 2.153054e-05 + label: "mouthLowerDownRight" + } + classification { + index: 36 + score: 0.0132145975 + label: "mouthPressLeft" + } + classification { + index: 37 + score: 0.009528495 + label: "mouthPressRight" + } + classification { + index: 38 + score: 0.056963783 + label: "mouthPucker" + } + classification { + index: 39 + score: 0.027331185 + label: "mouthRight" + } + classification { + index: 40 + score: 0.00072388636 + label: "mouthRollLower" + } + classification { + index: 41 + score: 0.00021191382 + label: "mouthRollUpper" + } + classification { + index: 42 + score: 0.23938002 + label: "mouthShrugLower" + } + classification { + index: 43 + score: 0.052946873 + label: "mouthShrugUpper" + } + classification { + index: 44 + score: 0.68681276 + label: "mouthSmileLeft" + } + classification { + index: 45 + score: 0.68557316 + label: "mouthSmileRight" + } + classification { + index: 46 + score: 0.0030625665 + label: "mouthStretchLeft" + } + classification { + index: 47 + score: 0.003999545 + label: "mouthStretchRight" + } + classification { + index: 48 + score: 0.013184475 + label: "mouthUpperUpLeft" + } + classification { + index: 49 + score: 0.017995607 + label: "mouthUpperUpRight" + } + classification { + index: 50 + score: 2.0452394e-06 + label: 
"noseSneerLeft" + } + classification { + index: 51 + score: 3.7912793e-07 + label: "noseSneerRight" } } diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index dfbbb9f91..31bad937d 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -31,27 +31,57 @@ mediapipe_ts_library( mediapipe_ts_library( name = "drawing_utils", - srcs = ["drawing_utils.ts"], + srcs = [ + "drawing_utils.ts", + "drawing_utils_category_mask.ts", + ], deps = [ + ":image", + ":image_shader_context", + ":mask", ":types", "//mediapipe/tasks/web/components/containers:bounding_box", "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) mediapipe_ts_library( - name = "image", - srcs = [ - "image.ts", - "image_shader_context.ts", + name = "drawing_utils_test_lib", + testonly = True, + srcs = ["drawing_utils.test.ts"], + deps = [ + ":drawing_utils", + ":image", + ":image_shader_context", + ":mask", ], ) +jasmine_node_test( + name = "drawing_utils_test", + deps = [":drawing_utils_test_lib"], +) + +mediapipe_ts_library( + name = "image", + srcs = ["image.ts"], + deps = ["image_shader_context"], +) + +mediapipe_ts_library( + name = "image_shader_context", + srcs = ["image_shader_context.ts"], +) + mediapipe_ts_library( name = "image_test_lib", testonly = True, srcs = ["image.test.ts"], - deps = [":image"], + deps = [ + ":image", + ":image_shader_context", + ], ) jasmine_node_test( @@ -64,6 +94,7 @@ mediapipe_ts_library( srcs = ["mask.ts"], deps = [ ":image", + ":image_shader_context", "//mediapipe/web/graph_runner:platform_utils", ], ) @@ -74,6 +105,7 @@ mediapipe_ts_library( srcs = ["mask.test.ts"], deps = [ ":image", + ":image_shader_context", ":mask", ], ) @@ -89,6 +121,7 @@ mediapipe_ts_library( deps = [ ":image", ":image_processing_options", + ":image_shader_context", ":mask", ":vision_task_options", "//mediapipe/framework/formats:rect_jspb_proto", diff --git a/mediapipe/tasks/web/vision/core/drawing_utils.test.ts b/mediapipe/tasks/web/vision/core/drawing_utils.test.ts new file mode 100644 index 000000000..b5ba8e9a4 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/drawing_utils.test.ts @@ -0,0 +1,103 @@ +/** + * Copyright 2023 The MediaPipe Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {DrawingUtils} from './drawing_utils'; +import {MPImageShaderContext} from './image_shader_context'; +import {MPMask} from './mask'; + +const WIDTH = 2; +const HEIGHT = 2; + +const skip = typeof document === 'undefined'; +if (skip) { + console.log('These tests must be run in a browser.'); +} + +(skip ? 
xdescribe : describe)('DrawingUtils', () => { + let shaderContext = new MPImageShaderContext(); + let canvas2D: HTMLCanvasElement; + let context2D: CanvasRenderingContext2D; + let drawingUtils2D: DrawingUtils; + let canvasWebGL: HTMLCanvasElement; + let contextWebGL: WebGL2RenderingContext; + let drawingUtilsWebGL: DrawingUtils; + + beforeEach(() => { + shaderContext = new MPImageShaderContext(); + + canvasWebGL = document.createElement('canvas'); + canvasWebGL.width = WIDTH; + canvasWebGL.height = HEIGHT; + contextWebGL = canvasWebGL.getContext('webgl2')!; + drawingUtilsWebGL = new DrawingUtils(contextWebGL); + + canvas2D = document.createElement('canvas'); + canvas2D.width = WIDTH; + canvas2D.height = HEIGHT; + context2D = canvas2D.getContext('2d')!; + drawingUtils2D = new DrawingUtils(context2D, contextWebGL); + }); + + afterEach(() => { + shaderContext.close(); + drawingUtils2D.close(); + drawingUtilsWebGL.close(); + }); + + describe('drawCategoryMask() ', () => { + const colors = [ + [0, 0, 0, 255], + [0, 255, 0, 255], + [0, 0, 255, 255], + [255, 255, 255, 255], + ]; + const expectedResult = new Uint8Array( + [0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255], + ); + + it('on 2D canvas', () => { + const categoryMask = new MPMask( + [new Uint8Array([0, 1, 2, 3])], + /* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH, + HEIGHT); + + drawingUtils2D.drawCategoryMask(categoryMask, colors); + + const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data; + expect(actualResult) + .toEqual(new Uint8ClampedArray(expectedResult.buffer)); + }); + + it('on WebGL canvas', () => { + const categoryMask = new MPMask( + [new Uint8Array([2, 3, 0, 1])], // Note: Vertically flipped + /* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH, + HEIGHT); + + drawingUtilsWebGL.drawCategoryMask(categoryMask, colors); + + const actualResult = new Uint8Array(WIDTH * WIDTH * 4); + contextWebGL.readPixels( + 0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE, + actualResult); + expect(actualResult).toEqual(expectedResult); + }); + }); + + // TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox +}); diff --git a/mediapipe/tasks/web/vision/core/drawing_utils.ts b/mediapipe/tasks/web/vision/core/drawing_utils.ts index c1e84fa11..796d7dcb6 100644 --- a/mediapipe/tasks/web/vision/core/drawing_utils.ts +++ b/mediapipe/tasks/web/vision/core/drawing_utils.ts @@ -16,7 +16,11 @@ import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +import {CategoryMaskShaderContext, CategoryToColorMap, RGBAColor} from '../../../../tasks/web/vision/core/drawing_utils_category_mask'; +import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; +import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {Connection} from '../../../../tasks/web/vision/core/types'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; /** * A user-defined callback to take input data and map it to a custom output @@ -24,6 +28,9 @@ import {Connection} from '../../../../tasks/web/vision/core/types'; */ export type Callback = (input: I) => O; +// Used in public API +export {ImageSource}; + /** Data that a user can use to specialize drawing options. 
*/ export declare interface LandmarkData { index?: number; @@ -31,6 +38,32 @@ export declare interface LandmarkData { to?: NormalizedLandmark; } +/** A color map with 22 classes. Used in our demos. */ +export const DEFAULT_CATEGORY_TO_COLOR_MAP = [ + [0, 0, 0, 0], // class 0 is BG = transparent + [255, 0, 0, 255], // class 1 is red + [0, 255, 0, 255], // class 2 is light green + [0, 0, 255, 255], // class 3 is blue + [255, 255, 0, 255], // class 4 is yellow + [255, 0, 255, 255], // class 5 is light purple / magenta + [0, 255, 255, 255], // class 6 is light blue / aqua + [128, 128, 128, 255], // class 7 is gray + [255, 100, 0, 255], // class 8 is dark orange + [128, 0, 255, 255], // class 9 is dark purple + [0, 150, 0, 255], // class 10 is green + [255, 255, 255, 255], // class 11 is white + [255, 105, 180, 255], // class 12 is pink + [255, 150, 0, 255], // class 13 is orange + [255, 250, 224, 255], // class 14 is light yellow + [148, 0, 211, 255], // class 15 is dark violet + [0, 100, 0, 255], // class 16 is dark green + [0, 0, 128, 255], // class 17 is navy blue + [165, 42, 42, 255], // class 18 is brown + [64, 224, 208, 255], // class 19 is turquoise + [255, 218, 185, 255], // class 20 is peach + [192, 192, 192, 255], // class 21 is silver +]; + /** * Options for customizing the drawing routines */ @@ -77,14 +110,50 @@ function resolve(value: O|Callback, data: I): O { return value instanceof Function ? value(data) : value; } +export {RGBAColor, CategoryToColorMap}; + /** Helper class to visualize the result of a MediaPipe Vision task. */ export class DrawingUtils { + private categoryMaskShaderContext?: CategoryMaskShaderContext; + private convertToWebGLTextureShaderContext?: MPImageShaderContext; + private readonly context2d?: CanvasRenderingContext2D| + OffscreenCanvasRenderingContext2D; + private readonly contextWebGL?: WebGL2RenderingContext; + /** * Creates a new DrawingUtils class. * - * @param ctx The canvas to render onto. + * @param gpuContext The WebGL canvas rendering context to render into. If + * your Task is using a GPU delegate, the context must be obtained from + * its canvas (provided via `setOptions({ canvas: .. })`). */ - constructor(private readonly ctx: CanvasRenderingContext2D) {} + constructor(gpuContext: WebGL2RenderingContext); + /** + * Creates a new DrawingUtils class. + * + * @param cpuContext The 2D canvas rendering context to render into. If + * you are rendering GPU data you must also provide `gpuContext` to allow + * for data conversion. + * @param gpuContext A WebGL canvas that is used for GPU rendering and for + * converting GPU to CPU data. If your Task is using a GPU delegate, the + * context must be obtained from its canvas (provided via + * `setOptions({ canvas: .. })`). + */ + constructor( + cpuContext: CanvasRenderingContext2D|OffscreenCanvasRenderingContext2D, + gpuContext?: WebGL2RenderingContext); + constructor( + cpuOrGpuGontext: CanvasRenderingContext2D| + OffscreenCanvasRenderingContext2D|WebGL2RenderingContext, + gpuContext?: WebGL2RenderingContext) { + if (cpuOrGpuGontext instanceof CanvasRenderingContext2D || + cpuOrGpuGontext instanceof OffscreenCanvasRenderingContext2D) { + this.context2d = cpuOrGpuGontext; + this.contextWebGL = gpuContext; + } else { + this.contextWebGL = cpuOrGpuGontext; + } + } /** * Restricts a number between two endpoints (order doesn't matter). 
@@ -120,9 +189,36 @@ export class DrawingUtils { return DrawingUtils.clamp(out, y0, y1); } + private getCanvasRenderingContext(): CanvasRenderingContext2D + |OffscreenCanvasRenderingContext2D { + if (!this.context2d) { + throw new Error( + 'CPU rendering requested but CanvasRenderingContext2D not provided.'); + } + return this.context2d; + } + + private getWebGLRenderingContext(): WebGL2RenderingContext { + if (!this.contextWebGL) { + throw new Error( + 'GPU rendering requested but WebGL2RenderingContext not provided.'); + } + return this.contextWebGL; + } + + private getCategoryMaskShaderContext(): CategoryMaskShaderContext { + if (!this.categoryMaskShaderContext) { + this.categoryMaskShaderContext = new CategoryMaskShaderContext(); + } + return this.categoryMaskShaderContext; + } + /** * Draws circles onto the provided landmarks. * + * This method can only be used when `DrawingUtils` is initialized with a + * `CanvasRenderingContext2D`. + * * @export * @param landmarks The landmarks to draw. * @param style The style to visualize the landmarks. @@ -132,7 +228,7 @@ export class DrawingUtils { if (!landmarks) { return; } - const ctx = this.ctx; + const ctx = this.getCanvasRenderingContext(); const options = addDefaultOptions(style); ctx.save(); const canvas = ctx.canvas; @@ -159,6 +255,9 @@ export class DrawingUtils { /** * Draws lines between landmarks (given a connection graph). * + * This method can only be used when `DrawingUtils` is initialized with a + * `CanvasRenderingContext2D`. + * * @export * @param landmarks The landmarks to draw. * @param connections The connections array that contains the start and the @@ -171,7 +270,7 @@ export class DrawingUtils { if (!landmarks || !connections) { return; } - const ctx = this.ctx; + const ctx = this.getCanvasRenderingContext(); const options = addDefaultOptions(style); ctx.save(); const canvas = ctx.canvas; @@ -195,12 +294,15 @@ export class DrawingUtils { /** * Draws a bounding box. * + * This method can only be used when `DrawingUtils` is initialized with a + * `CanvasRenderingContext2D`. + * * @export * @param boundingBox The bounding box to draw. * @param style The style to visualize the boundin box. */ drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void { - const ctx = this.ctx; + const ctx = this.getCanvasRenderingContext(); const options = addDefaultOptions(style); ctx.save(); ctx.beginPath(); @@ -218,6 +320,118 @@ export class DrawingUtils { ctx.fill(); ctx.restore(); } + + /** Draws a category mask on a CanvasRenderingContext2D. */ + private drawCategoryMask2D( + mask: MPMask, background: RGBAColor|ImageSource, + categoryToColorMap: Map|RGBAColor[]): void { + // Use the WebGL renderer to draw result on our internal canvas. + const gl = this.getWebGLRenderingContext(); + this.runWithWebGLTexture(mask, texture => { + this.drawCategoryMaskWebGL(texture, background, categoryToColorMap); + // Draw the result on the user canvas. + const ctx = this.getCanvasRenderingContext(); + ctx.drawImage(gl.canvas, 0, 0, ctx.canvas.width, ctx.canvas.height); + }); + } + + /** Draws a category mask on a WebGL2RenderingContext2D. */ + private drawCategoryMaskWebGL( + categoryTexture: WebGLTexture, background: RGBAColor|ImageSource, + categoryToColorMap: Map|RGBAColor[]): void { + const shaderContext = this.getCategoryMaskShaderContext(); + const gl = this.getWebGLRenderingContext(); + const backgroundImage = Array.isArray(background) ? 
+ new ImageData(new Uint8ClampedArray(background), 1, 1) : + background; + + shaderContext.run(gl, /* flipTexturesVertically= */ true, () => { + shaderContext.bindAndUploadTextures( + categoryTexture, backgroundImage, categoryToColorMap); + gl.clearColor(0, 0, 0, 0); + gl.clear(gl.COLOR_BUFFER_BIT); + gl.drawArrays(gl.TRIANGLE_FAN, 0, 4); + shaderContext.unbindTextures(); + }); + } + + /** + * Draws a category mask using the provided category-to-color mapping. + * + * @export + * @param mask A category mask that was returned from a segmentation task. + * @param categoryToColorMap A map that maps category indices to RGBA + * values. You must specify a map entry for each category. + * @param background A color or image to use as the background. Defaults to + * black. + */ + drawCategoryMask( + mask: MPMask, categoryToColorMap: Map, + background?: RGBAColor|ImageSource): void; + /** + * Draws a category mask using the provided color array. + * + * @export + * @param mask A category mask that was returned from a segmentation task. + * @param categoryToColorMap An array that maps indices to RGBA values. The + * array's indices must correspond to the category indices of the model + * and an entry must be provided for each category. + * @param background A color or image to use as the background. Defaults to + * black. + */ + drawCategoryMask( + mask: MPMask, categoryToColorMap: RGBAColor[], + background?: RGBAColor|ImageSource): void; + drawCategoryMask( + mask: MPMask, categoryToColorMap: CategoryToColorMap, + background: RGBAColor|ImageSource = [0, 0, 0, 255]): void { + if (this.context2d) { + this.drawCategoryMask2D(mask, background, categoryToColorMap); + } else { + this.drawCategoryMaskWebGL( + mask.getAsWebGLTexture(), background, categoryToColorMap); + } + } + + /** + * Converts the given mask to a WebGLTexture and runs the callback. Cleans + * up any new resources after the callback finished executing. + */ + private runWithWebGLTexture( + mask: MPMask, callback: (texture: WebGLTexture) => void): void { + if (!mask.hasWebGLTexture()) { + // Re-create the MPMask but use our the WebGL canvas so we can draw the + // texture directly. + const data = mask.hasFloat32Array() ? mask.getAsFloat32Array() : + mask.getAsUint8Array(); + this.convertToWebGLTextureShaderContext = + this.convertToWebGLTextureShaderContext ?? new MPImageShaderContext(); + const gl = this.getWebGLRenderingContext(); + + const convertedMask = new MPMask( + [data], + /* ownsWebGlTexture= */ false, + gl.canvas, + this.convertToWebGLTextureShaderContext, + mask.width, + mask.height, + ); + callback(convertedMask.getAsWebGLTexture()); + convertedMask.close(); + } else { + callback(mask.getAsWebGLTexture()); + } + } + /** + * Frees all WebGL resources held by this class. + * @export + */ + close(): void { + this.categoryMaskShaderContext?.close(); + this.categoryMaskShaderContext = undefined; + this.convertToWebGLTextureShaderContext?.close(); + this.convertToWebGLTextureShaderContext = undefined; + } } diff --git a/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts b/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts new file mode 100644 index 000000000..d7706075f --- /dev/null +++ b/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts @@ -0,0 +1,189 @@ +/** + * Copyright 2023 The MediaPipe Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +/** + * A fragment shader that maps categories to colors based on a background + * texture, a mask texture and a 256x1 "color mapping texture" that contains one + * color for each pixel. + */ +const FRAGMENT_SHADER = ` + precision mediump float; + uniform sampler2D backgroundTexture; + uniform sampler2D maskTexture; + uniform sampler2D colorMappingTexture; + varying vec2 vTex; + void main() { + vec4 backgroundColor = texture2D(backgroundTexture, vTex); + float category = texture2D(maskTexture, vTex).r; + vec4 categoryColor = texture2D(colorMappingTexture, vec2(category, 0.0)); + gl_FragColor = mix(backgroundColor, categoryColor, categoryColor.a); + } + `; + +/** + * A four channel color with values for red, green, blue and alpha + * respectively. + */ +export type RGBAColor = [number, number, number, number]|number[]; + +/** + * A category to color mapping that uses either a map or an array to assign + * category indexes to RGBA colors. + */ +export type CategoryToColorMap = Map|RGBAColor[]; + + +/** Checks CategoryToColorMap maps for deep equality. */ +function isEqualColorMap( + a: CategoryToColorMap, b: CategoryToColorMap): boolean { + if (a !== b) { + return false; + } + + const aEntries = a.entries(); + const bEntries = b.entries(); + for (const [aKey, aValue] of aEntries) { + const bNext = bEntries.next(); + if (bNext.done) { + return false; + } + + const [bKey, bValue] = bNext.value; + if (aKey !== bKey) { + return false; + } + + if (aValue[0] !== bValue[0] || aValue[1] !== bValue[1] || + aValue[2] !== bValue[2] || aValue[3] !== bValue[3]) { + return false; + } + } + return !!bEntries.next().done; +} + + +/** A drawing util class for category masks. */ +export class CategoryMaskShaderContext extends MPImageShaderContext { + backgroundTexture?: WebGLTexture; + colorMappingTexture?: WebGLTexture; + colorMappingTextureUniform?: WebGLUniformLocation; + backgroundTextureUniform?: WebGLUniformLocation; + maskTextureUniform?: WebGLUniformLocation; + currentColorMap?: CategoryToColorMap; + + bindAndUploadTextures( + categoryMask: WebGLTexture, background: ImageSource, + colorMap: Map|number[][]) { + const gl = this.gl!; + + // TODO: We should avoid uploading textures from CPU to GPU + // if the textures haven't changed. This can lead to drastic performance + // slowdowns (~50ms per frame). Users can reduce the penalty by passing a + // canvas object instead of ImageData/HTMLImageElement. + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_2D, this.backgroundTexture!); + gl.texImage2D( + gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, background); + + // Bind color mapping texture if changed. 
+ if (!this.currentColorMap || + !isEqualColorMap(this.currentColorMap, colorMap)) { + this.currentColorMap = colorMap; + + const pixels = new Array(256 * 4).fill(0); + colorMap.forEach((rgba, index) => { + if (rgba.length !== 4) { + throw new Error( + `Color at index ${index} is not a four-channel value.`); + } + pixels[index * 4] = rgba[0]; + pixels[index * 4 + 1] = rgba[1]; + pixels[index * 4 + 2] = rgba[2]; + pixels[index * 4 + 3] = rgba[3]; + }); + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!); + gl.texImage2D( + gl.TEXTURE_2D, 0, gl.RGBA, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE, + new Uint8Array(pixels)); + } else { + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!); + } + + // Bind category mask + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, categoryMask); + } + + unbindTextures() { + const gl = this.gl!; + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_2D, null); + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_2D, null); + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, null); + } + + protected override getFragmentShader(): string { + return FRAGMENT_SHADER; + } + + protected override setupTextures(): void { + const gl = this.gl!; + gl.activeTexture(gl.TEXTURE0); + this.backgroundTexture = this.createTexture(gl, gl.LINEAR); + // Use `gl.NEAREST` to prevent interpolating values in our category to + // color map. + this.colorMappingTexture = this.createTexture(gl, gl.NEAREST); + } + + protected override setupShaders(): void { + super.setupShaders(); + const gl = this.gl!; + this.backgroundTextureUniform = assertNotNull( + gl.getUniformLocation(this.program!, 'backgroundTexture'), + 'Uniform location'); + this.colorMappingTextureUniform = assertNotNull( + gl.getUniformLocation(this.program!, 'colorMappingTexture'), + 'Uniform location'); + this.maskTextureUniform = assertNotNull( + gl.getUniformLocation(this.program!, 'maskTexture'), + 'Uniform location'); + } + + protected override configureUniforms(): void { + super.configureUniforms(); + const gl = this.gl!; + gl.uniform1i(this.backgroundTextureUniform!, 0); + gl.uniform1i(this.colorMappingTextureUniform!, 1); + gl.uniform1i(this.maskTextureUniform!, 2); + } + + override close(): void { + if (this.backgroundTexture) { + this.gl!.deleteTexture(this.backgroundTexture); + } + if (this.colorMappingTexture) { + this.gl!.deleteTexture(this.colorMappingTexture); + } + super.close(); + } +} diff --git a/mediapipe/tasks/web/vision/core/image_shader_context.ts b/mediapipe/tasks/web/vision/core/image_shader_context.ts index eb17d001a..3dec9da95 100644 --- a/mediapipe/tasks/web/vision/core/image_shader_context.ts +++ b/mediapipe/tasks/web/vision/core/image_shader_context.ts @@ -27,9 +27,9 @@ const FRAGMENT_SHADER = ` precision mediump float; varying vec2 vTex; uniform sampler2D inputTexture; - void main() { - gl_FragColor = texture2D(inputTexture, vTex); - } + void main() { + gl_FragColor = texture2D(inputTexture, vTex); + } `; /** Helper to assert that `value` is not null. */ @@ -73,9 +73,9 @@ class MPImageShaderBuffers { * For internal use only. 
  */
 export class MPImageShaderContext {
-  private gl?: WebGL2RenderingContext;
+  protected gl?: WebGL2RenderingContext;
   private framebuffer?: WebGLFramebuffer;
-  private program?: WebGLProgram;
+  protected program?: WebGLProgram;
   private vertexShader?: WebGLShader;
   private fragmentShader?: WebGLShader;
   private aVertex?: GLint;
@@ -94,6 +94,14 @@ export class MPImageShaderContext {
    */
  private shaderBuffersFlipVertically?: MPImageShaderBuffers;
 
+  protected getFragmentShader(): string {
+    return FRAGMENT_SHADER;
+  }
+
+  protected getVertexShader(): string {
+    return VERTEX_SHADER;
+  }
+
   private compileShader(source: string, type: number): WebGLShader {
     const gl = this.gl!;
     const shader =
@@ -108,14 +116,15 @@ export class MPImageShaderContext {
 
     return shader;
   }
 
-  private setupShaders(): void {
+  protected setupShaders(): void {
     const gl = this.gl!;
     this.program =
         assertNotNull(gl.createProgram()!, 'Failed to create WebGL program');
 
-    this.vertexShader = this.compileShader(VERTEX_SHADER, gl.VERTEX_SHADER);
+    this.vertexShader =
+        this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER);
     this.fragmentShader =
-        this.compileShader(FRAGMENT_SHADER, gl.FRAGMENT_SHADER);
+        this.compileShader(this.getFragmentShader(), gl.FRAGMENT_SHADER);
     gl.linkProgram(this.program);
     const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS);
@@ -128,6 +137,10 @@ export class MPImageShaderContext {
     this.aTex = gl.getAttribLocation(this.program, 'aTex');
   }
 
+  protected setupTextures(): void {}
+
+  protected configureUniforms(): void {}
+
   private createBuffers(flipVertically: boolean): MPImageShaderBuffers {
     const gl = this.gl!;
     const vertexArrayObject =
@@ -193,17 +206,44 @@ export class MPImageShaderContext {
 
     if (!this.program) {
       this.setupShaders();
+      this.setupTextures();
     }
 
     const shaderBuffers = this.getShaderBuffers(flipVertically);
     gl.useProgram(this.program!);
     shaderBuffers.bind();
+    this.configureUniforms();
     const result = callback();
     shaderBuffers.unbind();
 
     return result;
   }
 
+  /**
+   * Creates and configures a texture.
+   *
+   * @param gl The rendering context.
+   * @param filter The setting to use for `gl.TEXTURE_MIN_FILTER` and
+   *     `gl.TEXTURE_MAG_FILTER`. Defaults to `gl.LINEAR`.
+   * @param wrapping The setting to use for `gl.TEXTURE_WRAP_S` and
+   *     `gl.TEXTURE_WRAP_T`. Defaults to `gl.CLAMP_TO_EDGE`.
+   */
+  createTexture(gl: WebGL2RenderingContext, filter?: GLenum, wrapping?: GLenum):
+      WebGLTexture {
+    this.maybeInitGL(gl);
+    const texture =
+        assertNotNull(gl.createTexture(), 'Failed to create texture');
+    gl.bindTexture(gl.TEXTURE_2D, texture);
+    gl.texParameteri(
+        gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE);
+    gl.texParameteri(
+        gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, wrapping ?? gl.CLAMP_TO_EDGE);
+    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter ?? gl.LINEAR);
+    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter ?? gl.LINEAR);
+    gl.bindTexture(gl.TEXTURE_2D, null);
+    return texture;
+  }
+
   /**
    * Binds a framebuffer to the canvas. If the framebuffer does not yet exist,
    * creates it first. Binds the provided texture to the framebuffer.
diff --git a/mediapipe/tasks/web/vision/core/render_utils.ts b/mediapipe/tasks/web/vision/core/render_utils.ts
index ebb3be16a..3ee981bab 100644
--- a/mediapipe/tasks/web/vision/core/render_utils.ts
+++ b/mediapipe/tasks/web/vision/core/render_utils.ts
@@ -16,24 +16,6 @@
  * limitations under the License.
  */
 
-// Pre-baked color table for a maximum of 12 classes.
-const CM_ALPHA = 128;
-const COLOR_MAP: Array<[number, number, number, number]> = [
-  [0, 0, 0, CM_ALPHA],        // class 0 is BG = transparent
-  [255, 0, 0, CM_ALPHA],      // class 1 is red
-  [0, 255, 0, CM_ALPHA],      // class 2 is light green
-  [0, 0, 255, CM_ALPHA],      // class 3 is blue
-  [255, 255, 0, CM_ALPHA],    // class 4 is yellow
-  [255, 0, 255, CM_ALPHA],    // class 5 is light purple / magenta
-  [0, 255, 255, CM_ALPHA],    // class 6 is light blue / aqua
-  [128, 128, 128, CM_ALPHA],  // class 7 is gray
-  [255, 128, 0, CM_ALPHA],    // class 8 is orange
-  [128, 0, 255, CM_ALPHA],    // class 9 is dark purple
-  [0, 128, 0, CM_ALPHA],      // class 10 is dark green
-  [255, 255, 255, CM_ALPHA]   // class 11 is white; could do black instead?
-];
-
-
 /** Helper function to draw a confidence mask */
 export function drawConfidenceMask(
     ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
@@ -47,23 +29,3 @@ export function drawConfidenceMask(
   }
   ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
 }
-
-/**
- * Helper function to draw a category mask. For GPU, we only have F32Arrays
- * for now.
- */
-export function drawCategoryMask(
-    ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array,
-    width: number, height: number): void {
-  const rgbaArray = new Uint8ClampedArray(width * height * 4);
-  const isFloatArray = image instanceof Float32Array;
-  for (let i = 0; i < image.length; i++) {
-    const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
-    const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
-    rgbaArray[4 * i] = color[0];
-    rgbaArray[4 * i + 1] = color[1];
-    rgbaArray[4 * i + 2] = color[2];
-    rgbaArray[4 * i + 3] = color[3];
-  }
-  ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
-}
diff --git a/mediapipe/web/graph_runner/platform_utils.ts b/mediapipe/web/graph_runner/platform_utils.ts
index d86e002de..a9a62a884 100644
--- a/mediapipe/web/graph_runner/platform_utils.ts
+++ b/mediapipe/web/graph_runner/platform_utils.ts
@@ -32,6 +32,5 @@ export function isIOS() {
     // tslint:disable-next-line:deprecation
   ].includes(navigator.platform)
       // iPad on iOS 13 detection
-      || (navigator.userAgent.includes('Mac') &&
-          (typeof document !== undefined && 'ontouchend' in document));
+      || (navigator.userAgent.includes('Mac') && 'ontouchend' in self.document);
 }
diff --git a/setup.py b/setup.py
index b5b75b73c..aa6004b7e 100644
--- a/setup.py
+++ b/setup.py
@@ -272,13 +272,14 @@ class BuildModules(build_ext.build_ext):
       self._download_external_file(external_file)
 
     binary_graphs = [
-        'face_detection/face_detection_short_range_cpu',
-        'face_detection/face_detection_full_range_cpu',
-        'face_landmark/face_landmark_front_cpu',
-        'hand_landmark/hand_landmark_tracking_cpu',
-        'holistic_landmark/holistic_landmark_cpu', 'objectron/objectron_cpu',
-        'pose_landmark/pose_landmark_cpu',
-        'selfie_segmentation/selfie_segmentation_cpu'
+        'face_detection/face_detection_short_range_cpu.binarypb',
+        'face_detection/face_detection_full_range_cpu.binarypb',
+        'face_landmark/face_landmark_front_cpu.binarypb',
+        'hand_landmark/hand_landmark_tracking_cpu.binarypb',
+        'holistic_landmark/holistic_landmark_cpu.binarypb',
+        'objectron/objectron_cpu.binarypb',
+        'pose_landmark/pose_landmark_cpu.binarypb',
+        'selfie_segmentation/selfie_segmentation_cpu.binarypb'
     ]
     for elem in binary_graphs:
       binary_graph = os.path.join('mediapipe/modules/', elem)
@@ -312,7 +313,7 @@ class BuildModules(build_ext.build_ext):
         bazel_command.append('--define=OPENCV=source')
 
       _invoke_shell_command(bazel_command)
-      _copy_to_build_lib_dir(self.build_lib, binary_graph_target + '.binarypb')
+      _copy_to_build_lib_dir(self.build_lib, binary_graph_target)
 
 
 class GenerateMetadataSchema(build_ext.build_ext):
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 750421b6e..41abd6270 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -424,6 +424,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylization_dummy.tflite?generation=1678323589048063"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_face_stylizer_color_ink_task",
+        sha256 = "887a490b74046ecb2b1d092cc0173a961b4ed3640aaadeafa852b1122ca23b2a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/face_stylizer_color_ink.task?generation=1697732437695259"],
+    )
+
     http_file(
         name = "com_google_mediapipe_face_stylizer_json",
         sha256 = "ad89860d5daba6a1c4163a576428713fc3ddab76d6bbaf06d675164423ae159f",
@@ -550,6 +556,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/hand_roi_refinement_generated_graph.pbtxt?generation=1695159196033618"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_holistic_hand_tracking_left_hand_graph_pbtxt",
+        sha256 = "c964589b448471c0cd9e0f68c243e232e6f8a4c0959b41a3cd1cbb14e9efa6b1",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/holistic_hand_tracking_left_hand_graph.pbtxt?generation=1697732440362430"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_holistic_pose_tracking_graph_pbtxt",
+        sha256 = "1d36d014d38c09fea73042471d5d1a616f3cc9f22c8ca625deabc38efd63f6aa",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/holistic_pose_tracking_graph.pbtxt?generation=1697732442566093"],
+    )
+
     http_file(
         name = "com_google_mediapipe_image_tensor_meta_json",
         sha256 = "aad86fde3defb379c82ff7ee48e50493a58529cdc0623cf0d7bf135c3577060e",
diff --git a/third_party/org_tensorflow_objc_build_fixes.diff b/third_party/org_tensorflow_objc_build_fixes.diff
new file mode 100644
index 000000000..db7b827a9
--- /dev/null
+++ b/third_party/org_tensorflow_objc_build_fixes.diff
@@ -0,0 +1,86 @@
+diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
+index 875c2a4f3da..e513db47388 100644
+--- a/tensorflow/lite/delegates/gpu/BUILD
++++ b/tensorflow/lite/delegates/gpu/BUILD
+@@ -70,14 +70,17 @@ cc_library(
+     }) + tflite_extra_gles_deps(),
+ )
+ 
+-objc_library(
++cc_library(
+     name = "metal_delegate",
+-    srcs = ["metal_delegate.mm"],
++    srcs = ["metal_delegate.cc"],
+     hdrs = ["metal_delegate.h"],
+-    copts = ["-std=c++17"],
++    copts = [
++        "-ObjC++",
++        "-std=c++17",
++        "-fobjc-arc",
++    ],
++    linkopts = ["-framework Metal"],
+     features = ["-layering_check"],
+-    module_name = "TensorFlowLiteCMetal",
+-    sdk_frameworks = ["Metal"],
+     deps = [
+         "//tensorflow/lite:kernel_api",
+         "//tensorflow/lite:minimal_logging",
+@@ -98,14 +101,20 @@ objc_library(
+         "//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor",
+         "@com_google_absl//absl/types:span",
+     ],
++    alwayslink = 1,
+ )
+ 
+-objc_library(
++cc_library(
+     name = "metal_delegate_internal",
+     hdrs = ["metal_delegate_internal.h"],
+-    copts = ["-std=c++17"],
+-    sdk_frameworks = ["Metal"],
++    copts = [
++        "-ObjC++",
++        "-std=c++17",
++        "-fobjc-arc",
++    ],
++    linkopts = ["-framework Metal"],
+     deps = ["//tensorflow/lite/delegates/gpu:metal_delegate"],
++    alwayslink = 1,
+ )
+ 
+ # build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --linkopt -s --strip always :libtensorflowlite_gpu_gl.so
+diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
+index 8571ff7f041..82e6bb91d2d 100644
+--- a/tensorflow/lite/delegates/gpu/metal/BUILD
++++ b/tensorflow/lite/delegates/gpu/metal/BUILD
+@@ -137,15 +137,16 @@ objc_library(
+     ],
+ )
+ 
+-objc_library(
++cc_library(
+     name = "inference_context",
+     srcs = ["inference_context.cc"],
+     hdrs = ["inference_context.h"],
+     copts = DEFAULT_COPTS + [
+         "-ObjC++",
++        "-fobjc-arc",
+     ],
+     features = ["-layering_check"],
+-    sdk_frameworks = ["Metal"],
++    linkopts = ["-framework Metal"],
+     deps = [
+         ":compute_task",
+         ":inference_context_cc_fbs",
+@@ -171,6 +172,7 @@ objc_library(
+         "@com_google_absl//absl/strings",
+         "@com_google_absl//absl/time",
+     ],
++    alwayslink = 1,
+ )
+ 
+ flatbuffer_cc_library(
+diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.cc
+similarity index 100%
+rename from tensorflow/lite/delegates/gpu/metal_delegate.mm
+rename to tensorflow/lite/delegates/gpu/metal_delegate.cc
diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl
index f6256b91b..58c0570e9 100644
--- a/third_party/wasm_files.bzl
+++ b/third_party/wasm_files.bzl
@@ -12,72 +12,72 @@ def wasm_files():
 
     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
-        sha256 = "bb10ba65b0135f3d22c380bd87712f6a859ecdebdf1e3243407bc2f3ac5ccf71",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1696624044559284"],
+        sha256 = "8722e1047a54dcd08206d018a4bc348dd820f479cb10218c5cbcd411dd9e1c0c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1698954798232640"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
-        sha256 = "01079a05e1ce4963e3a78cd5fee8f33be9fda10f1c6c450d00cf71251d817e0c",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1696624047350232"],
+        sha256 = "bcd230238dbabdf09eab58dbbe7e36deacf7e3fc57c2d67af679188d37731883",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1698954800502145"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
-        sha256 = "c9484189261601052357359edd4a8575b0d51d0ce12cf3343d12a933307303c6",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1696624049287593"],
+        sha256 = "a5d80eefde268611ed385b90fab9defc37df50a124a15282961dbaa30b62c14d",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1698954802474619"],
    )
 
     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
-        sha256 = "c00c3455b9ca1f477879383dde7a5853e1bac06cad8c48088c53c6bd1bcd3890",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1696624051752560"],
+        sha256 = "2fc431cc62330332c0c1e730d44b933a79e4572be0dc5c5a82635bd5dc330b94",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1698954804758365"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_internal_js",
-        sha256 = "b5a8b98ac3927c28b462f8a0f9523fb7b0b6ac720a0009ba2c4f211516cc5a5e",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1696624053716124"],
+        sha256 = "b100d299cb06c0fd7cf40099653e8d4a3ac953937402a5d7c3a3a02fa59d8105",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1698954806886809"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
-        sha256 = "a5f28b7aa458f7e34c78cb90aac0fc3a228a5d0f3ae675fb7a5e2dca4299363c",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1696624056400203"],
+        sha256 = "ae1b8f9684b9afa989b1144f25a2ae1bda809c811367475567c823d65d4fef0a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1698954809121561"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
-        sha256 = "0eb6417d62f860161b504a72c88659a6a866b36e477da600630bebaca01bf7c6",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1696624058379631"],
+        sha256 = "d8db720214acfa1b758099daeb07c02e04b7221805523e9b6926a1f11ec00183",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1698954811167986"],
    )
 
     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
-        sha256 = "af9e62a8e63ea38e3f984f2937cd856cb6a599cdda169bb1c29169a1a08f60f9",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1696624060772680"],
+        sha256 = "44b8e5be980e6fe79fa9a8b02551ef50e1d74682dd8f3e6cf92435cf43e8ef91",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1698954813498288"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
-        sha256 = "addc64f89585eaa7322f649158d57ac70fd0e9ae0d37f4aeb4016b04d0e18d2a",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1696624062874405"],
+        sha256 = "f5ba7b1d0adad63c581a80113567913a7106b20f8d26982f82c56998c7d44465",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1698954815469471"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
-        sha256 = "357a65e76313ceb65ffec453c47275abfc387b7b6e77823b6c4c016ab7a40cf5",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1696624065934602"],
+        sha256 = "d502a753b40626a36734806599bf0e765cf3a611653d980b39a5474998f1d6fe",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1698954817976682"],
     )
 
     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js",
-        sha256 = "c41e7c4b00d2ab701d6e805917a26da7b563737fd12671f2f3496c036c94d633",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1696624067936821"],
+        sha256 = "731786df74b19150eecc8fe69ddf16040bbbba8cf2d22c964ef38ecef25d1e1f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1698954819912485"],
    )
 
     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm",
-        sha256 = "620986d18baa69934579bbd701d72532cedc2a6b052932cb19e302bd748afcba",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1696624070556427"],
+        sha256 = "2c16ecc52398857c5ce45d58c98fe16e795b6a6eda6a2a8aa00f519a4bd15f2a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1698954822497945"],
    )