Merge branch 'google:master' into c-image-embedder-api

Commit 3b122a1e61
@@ -513,6 +513,9 @@ http_archive(
        "@//third_party:org_tensorflow_system_python.diff",
        # Diff is generated with a script, don't update it manually.
        "@//third_party:org_tensorflow_custom_ops.diff",
+       # Works around Bazel issue with objc_library.
+       # See https://github.com/bazelbuild/bazel/issues/19912
+       "@//third_party:org_tensorflow_objc_build_fixes.diff",
    ],
    patch_args = [
        "-p1",
@@ -0,0 +1,342 @@
// !$*UTF8*$!
{
  archiveVersion = 1;
  classes = {
  };
  objectVersion = 56;
  objects = {

/* Begin PBXBuildFile section */
    8566B55D2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h in Headers */ = {isa = PBXBuildFile; fileRef = 8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */; settings = {ATTRIBUTES = (Public, ); }; };
/* End PBXBuildFile section */

/* Begin PBXFileReference section */
    8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = MediaPipeTasksDocGen.framework; sourceTree = BUILT_PRODUCTS_DIR; };
    8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = MediaPipeTasksDocGen.h; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
    8566B5562ABABF9A00AAB22A /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXFrameworksBuildPhase section */

/* Begin PBXGroup section */
    8566B54F2ABABF9A00AAB22A = {
      isa = PBXGroup;
      children = (
        8566B55B2ABABF9A00AAB22A /* MediaPipeTasksDocGen */,
        8566B55A2ABABF9A00AAB22A /* Products */,
      );
      sourceTree = "<group>";
    };
    8566B55A2ABABF9A00AAB22A /* Products */ = {
      isa = PBXGroup;
      children = (
        8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */,
      );
      name = Products;
      sourceTree = "<group>";
    };
    8566B55B2ABABF9A00AAB22A /* MediaPipeTasksDocGen */ = {
      isa = PBXGroup;
      children = (
        8566B55C2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h */,
      );
      path = MediaPipeTasksDocGen;
      sourceTree = "<group>";
    };
/* End PBXGroup section */

/* Begin PBXHeadersBuildPhase section */
    8566B5542ABABF9A00AAB22A /* Headers */ = {
      isa = PBXHeadersBuildPhase;
      buildActionMask = 2147483647;
      files = (
        8566B55D2ABABF9A00AAB22A /* MediaPipeTasksDocGen.h in Headers */,
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXHeadersBuildPhase section */

/* Begin PBXNativeTarget section */
    8566B5582ABABF9A00AAB22A /* MediaPipeTasksDocGen */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = 8566B5602ABABF9A00AAB22A /* Build configuration list for PBXNativeTarget "MediaPipeTasksDocGen" */;
      buildPhases = (
        8566B5542ABABF9A00AAB22A /* Headers */,
        8566B5552ABABF9A00AAB22A /* Sources */,
        8566B5562ABABF9A00AAB22A /* Frameworks */,
        8566B5572ABABF9A00AAB22A /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
      );
      name = MediaPipeTasksDocGen;
      productName = MediaPipeTasksDocGen;
      productReference = 8566B5592ABABF9A00AAB22A /* MediaPipeTasksDocGen.framework */;
      productType = "com.apple.product-type.framework";
    };
/* End PBXNativeTarget section */

/* Begin PBXProject section */
    8566B5502ABABF9A00AAB22A /* Project object */ = {
      isa = PBXProject;
      attributes = {
        BuildIndependentTargetsInParallel = 1;
        LastUpgradeCheck = 1430;
        TargetAttributes = {
          8566B5582ABABF9A00AAB22A = {
            CreatedOnToolsVersion = 14.3.1;
          };
        };
      };
      buildConfigurationList = 8566B5532ABABF9A00AAB22A /* Build configuration list for PBXProject "MediaPipeTasksDocGen" */;
      compatibilityVersion = "Xcode 14.0";
      developmentRegion = en;
      hasScannedForEncodings = 0;
      knownRegions = (
        en,
        Base,
      );
      mainGroup = 8566B54F2ABABF9A00AAB22A;
      productRefGroup = 8566B55A2ABABF9A00AAB22A /* Products */;
      projectDirPath = "";
      projectRoot = "";
      targets = (
        8566B5582ABABF9A00AAB22A /* MediaPipeTasksDocGen */,
      );
    };
/* End PBXProject section */

/* Begin PBXResourcesBuildPhase section */
    8566B5572ABABF9A00AAB22A /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXResourcesBuildPhase section */

/* Begin PBXSourcesBuildPhase section */
    8566B5552ABABF9A00AAB22A /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXSourcesBuildPhase section */

/* Begin XCBuildConfiguration section */
    8566B55E2ABABF9A00AAB22A /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        CURRENT_PROJECT_VERSION = 1;
        DEBUG_INFORMATION_FORMAT = dwarf;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        ENABLE_TESTABILITY = YES;
        GCC_C_LANGUAGE_STANDARD = gnu11;
        GCC_DYNAMIC_NO_PIC = NO;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_OPTIMIZATION_LEVEL = 0;
        GCC_PREPROCESSOR_DEFINITIONS = (
          "DEBUG=1",
          "$(inherited)",
        );
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 16.4;
        MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
        MTL_FAST_MATH = YES;
        ONLY_ACTIVE_ARCH = YES;
        SDKROOT = iphoneos;
        SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
        SWIFT_OPTIMIZATION_LEVEL = "-Onone";
        VERSIONING_SYSTEM = "apple-generic";
        VERSION_INFO_PREFIX = "";
      };
      name = Debug;
    };
    8566B55F2ABABF9A00AAB22A /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        CURRENT_PROJECT_VERSION = 1;
        DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
        ENABLE_NS_ASSERTIONS = NO;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        GCC_C_LANGUAGE_STANDARD = gnu11;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 16.4;
        MTL_ENABLE_DEBUG_INFO = NO;
        MTL_FAST_MATH = YES;
        SDKROOT = iphoneos;
        SWIFT_COMPILATION_MODE = wholemodule;
        SWIFT_OPTIMIZATION_LEVEL = "-O";
        VALIDATE_PRODUCT = YES;
        VERSIONING_SYSTEM = "apple-generic";
        VERSION_INFO_PREFIX = "";
      };
      name = Release;
    };
    8566B5612ABABF9A00AAB22A /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEFINES_MODULE = YES;
        DYLIB_COMPATIBILITY_VERSION = 1;
        DYLIB_CURRENT_VERSION = 1;
        DYLIB_INSTALL_NAME_BASE = "@rpath";
        ENABLE_MODULE_VERIFIER = YES;
        GENERATE_INFOPLIST_FILE = YES;
        INFOPLIST_KEY_NSHumanReadableCopyright = "";
        INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
        LD_RUNPATH_SEARCH_PATHS = (
          "$(inherited)",
          "@executable_path/Frameworks",
          "@loader_path/Frameworks",
        );
        MARKETING_VERSION = 1.0;
        MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
        MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu11 gnu++20";
        PRODUCT_BUNDLE_IDENTIFIER = com.google.mediapipe.MediaPipeTasksDocGen;
        PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
        SKIP_INSTALL = YES;
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
      };
      name = Debug;
    };
    8566B5622ABABF9A00AAB22A /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEFINES_MODULE = YES;
        DYLIB_COMPATIBILITY_VERSION = 1;
        DYLIB_CURRENT_VERSION = 1;
        DYLIB_INSTALL_NAME_BASE = "@rpath";
        ENABLE_MODULE_VERIFIER = YES;
        GENERATE_INFOPLIST_FILE = YES;
        INFOPLIST_KEY_NSHumanReadableCopyright = "";
        INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
        LD_RUNPATH_SEARCH_PATHS = (
          "$(inherited)",
          "@executable_path/Frameworks",
          "@loader_path/Frameworks",
        );
        MARKETING_VERSION = 1.0;
        MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
        MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu11 gnu++20";
        PRODUCT_BUNDLE_IDENTIFIER = com.google.mediapipe.MediaPipeTasksDocGen;
        PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
        SKIP_INSTALL = YES;
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
      };
      name = Release;
    };
/* End XCBuildConfiguration section */

/* Begin XCConfigurationList section */
    8566B5532ABABF9A00AAB22A /* Build configuration list for PBXProject "MediaPipeTasksDocGen" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        8566B55E2ABABF9A00AAB22A /* Debug */,
        8566B55F2ABABF9A00AAB22A /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    8566B5602ABABF9A00AAB22A /* Build configuration list for PBXNativeTarget "MediaPipeTasksDocGen" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        8566B5612ABABF9A00AAB22A /* Debug */,
        8566B5622ABABF9A00AAB22A /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
/* End XCConfigurationList section */
  };
  rootObject = 8566B5502ABABF9A00AAB22A /* Project object */;
}
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
   version = "1.0">
   <FileRef
      location = "self:">
   </FileRef>
</Workspace>
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>IDEDidComputeMac32BitWarning</key>
  <true/>
</dict>
</plist>
Binary file not shown.
@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>SchemeUserState</key>
  <dict>
    <key>MediaPipeTasksDocGen.xcscheme_^#shared#^_</key>
    <dict>
      <key>orderHint</key>
      <integer>0</integer>
    </dict>
  </dict>
</dict>
</plist>
@@ -0,0 +1,17 @@
//
//  MediaPipeTasksDocGen.h
//  MediaPipeTasksDocGen
//
//  Created by Mark McDonald on 20/9/2023.
//

#import <Foundation/Foundation.h>

//! Project version number for MediaPipeTasksDocGen.
FOUNDATION_EXPORT double MediaPipeTasksDocGenVersionNumber;

//! Project version string for MediaPipeTasksDocGen.
FOUNDATION_EXPORT const unsigned char MediaPipeTasksDocGenVersionString[];

// In this header, you should import all the public headers of your framework using statements like
// #import <MediaPipeTasksDocGen/PublicHeader.h>
docs/MediaPipeTasksDocGen/Podfile (new file, 11 lines)
@@ -0,0 +1,11 @@
# Uncomment the next line to define a global platform for your project
platform :ios, '15.0'

target 'MediaPipeTasksDocGen' do
  # Comment the next line if you don't want to use dynamic frameworks
  use_frameworks!

  # Pods for MediaPipeTasksDocGen
  pod 'MediaPipeTasksText'
  pod 'MediaPipeTasksVision'
end
docs/MediaPipeTasksDocGen/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
# MediaPipeTasksDocGen

This empty project is used to generate reference documentation for the
ObjectiveC and Swift libraries.

Docs are generated using [Jazzy](https://github.com/realm/jazzy) and published
to [the developer site](https://developers.google.com/mediapipe/solutions/).

To bump the API version used, edit [`Podfile`](./Podfile).
@@ -727,6 +727,7 @@ cc_library(
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
+       "@com_google_absl//absl/status",
    ],
    alwayslink = 1,
)

@@ -742,6 +743,7 @@ cc_test(
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:options_util",
+       "//mediapipe/util:packet_test_util",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ],
@@ -17,6 +17,7 @@
#include <set>
#include <string>

+#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
@@ -32,6 +33,7 @@ namespace {
constexpr char kTagAtPreStream[] = "AT_PRESTREAM";
constexpr char kTagAtPostStream[] = "AT_POSTSTREAM";
constexpr char kTagAtZero[] = "AT_ZERO";
constexpr char kTagAtTick[] = "AT_TICK";
+constexpr char kTagAtFirstTick[] = "AT_FIRST_TICK";
constexpr char kTagTick[] = "TICK";
constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP";
@@ -43,6 +45,7 @@ static std::map<std::string, Timestamp>* kTimestampMap = []() {
  res->emplace(kTagAtPostStream, Timestamp::PostStream());
  res->emplace(kTagAtZero, Timestamp(0));
  res->emplace(kTagAtTick, Timestamp::Unset());
+  res->emplace(kTagAtFirstTick, Timestamp::Unset());
  res->emplace(kTagAtTimestamp, Timestamp::Unset());
  return res;
}();
@@ -59,8 +62,8 @@ std::string GetOutputTag(const CC& cc) {
// timestamp, depending on the tag used to define output stream(s). (One tag can
// be used only.)
//
-// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP
-// and corresponding timestamps are Timestamp::PreStream(),
+// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_FIRST_TICK,
+// AT_TIMESTAMP and corresponding timestamps are Timestamp::PreStream(),
// Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK
// input, and timestamp received from a side input.
//
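The hunks in this file add an AT_FIRST_TICK output mode: the side packet is emitted exactly once, at the timestamp of the first TICK packet, after which the output stream is closed. As a minimal sketch (it mirrors the configuration used by the AtFirstTick test added further down; the includes are the usual MediaPipe framework headers):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// "packet" carries the side packet exactly once, stamped with the timestamp
// of the first "tick" packet; subsequent ticks produce no further output.
mediapipe::CalculatorGraphConfig graph_config =
    mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
      input_stream: "tick"
      input_side_packet: "side_packet"
      output_stream: "packet"
      node {
        calculator: "SidePacketToStreamCalculator"
        input_stream: "TICK:tick"
        input_side_packet: "side_packet"
        output_stream: "AT_FIRST_TICK:packet"
      }
    )pb");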
@@ -96,6 +99,7 @@ class SidePacketToStreamCalculator : public CalculatorBase {

 private:
  bool is_tick_processing_ = false;
+  bool close_on_first_tick_ = false;
  std::string output_tag_;
};
REGISTER_CALCULATOR(SidePacketToStreamCalculator);
@@ -103,13 +107,16 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator);
absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
  const auto& tags = cc->Outputs().GetTags();
  RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1)
-      << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
-         "AT_TIMESTAMP tags is allowed and required to specify output "
-         "stream(s).";
-  RET_CHECK(
-      (cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) ||
-      (!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick)))
-      << "Either both of TICK and AT_TICK should be used or none of them.";
+      << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, "
+         "AT_FIRST_TICK and AT_TIMESTAMP tags is allowed and required to "
+         "specify output stream(s).";
+  const bool has_tick_output =
+      cc->Outputs().HasTag(kTagAtTick) || cc->Outputs().HasTag(kTagAtFirstTick);
+  const bool has_tick_input = cc->Inputs().HasTag(kTagTick);
+  RET_CHECK((has_tick_output && has_tick_input) ||
+            (!has_tick_output && !has_tick_input))
+      << "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
+         "should be used or none of them.";
  RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) &&
             cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) ||
            (!cc->Outputs().HasTag(kTagAtTimestamp) &&
@@ -148,11 +155,17 @@ absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) {
    // timestamp bound update.
    cc->SetOffset(TimestampDiff(0));
  }
+  if (output_tag_ == kTagAtFirstTick) {
+    close_on_first_tick_ = true;
+  }
  return absl::OkStatus();
}

absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
  if (is_tick_processing_) {
+    if (cc->Outputs().Get(output_tag_, 0).IsClosed()) {
+      return absl::OkStatus();
+    }
    // TICK input is guaranteed to be non-empty, as it's the only input stream
    // for this calculator.
    const auto& timestamp = cc->Inputs().Tag(kTagTick).Value().Timestamp();
@@ -160,6 +173,9 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
      cc->Outputs()
          .Get(output_tag_, i)
          .AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
+      if (close_on_first_tick_) {
+        cc->Outputs().Get(output_tag_, i).Close();
+      }
    }

    return absl::OkStatus();
@@ -170,6 +186,7 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {

absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
  if (!cc->Outputs().HasTag(kTagAtTick) &&
+      !cc->Outputs().HasTag(kTagAtFirstTick) &&
      !cc->Outputs().HasTag(kTagAtTimestamp)) {
    const auto& timestamp = kTimestampMap->at(output_tag_);
    for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
@@ -27,13 +27,17 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_util.h"
+#include "mediapipe/util/packet_test_util.h"

namespace mediapipe {
namespace {

-using testing::HasSubstr;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;

-TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTick) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -52,10 +56,35 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
  EXPECT_THAT(
      status.message(),
      HasSubstr(
-          "Either both of TICK and AT_TICK should be used or none of them."));
+          "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
+          "should be used or none of them."));
}

-TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
+TEST(SidePacketToStreamCalculator,
+     WrongConfigWithMissingTickForFirstTickProcessing) {
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(
+          R"pb(
+            input_stream: "tick"
+            input_side_packet: "side_packet"
+            output_stream: "packet"
+            node {
+              calculator: "SidePacketToStreamCalculator"
+              input_side_packet: "side_packet"
+              output_stream: "AT_FIRST_TICK:packet"
+            }
+          )pb");
+  CalculatorGraph graph;
+  auto status = graph.Initialize(graph_config);
+  EXPECT_FALSE(status.ok());
+  EXPECT_THAT(
+      status.message(),
+      HasSubstr(
+          "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
+          "should be used or none of them."));
+}
+
+TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTimestampSideInput) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -76,7 +105,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
          "or none of them."));
}

-TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithNonExistentTag) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -92,14 +121,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
  CalculatorGraph graph;
  auto status = graph.Initialize(graph_config);
  EXPECT_FALSE(status.ok());
-  EXPECT_THAT(
-      status.message(),
-      HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
-                "AT_TIMESTAMP tags is allowed and required to specify output "
-                "stream(s)."));
+  EXPECT_THAT(status.message(),
+              HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
+                        "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
+                        "allowed and required to specify output stream(s)."));
}

-TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithMixedTags) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -117,14 +145,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
  CalculatorGraph graph;
  auto status = graph.Initialize(graph_config);
  EXPECT_FALSE(status.ok());
-  EXPECT_THAT(
-      status.message(),
-      HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
-                "AT_TIMESTAMP tags is allowed and required to specify output "
-                "stream(s)."));
+  EXPECT_THAT(status.message(),
+              HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
+                        "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
+                        "allowed and required to specify output stream(s)."));
}

-TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughSidePackets) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -146,7 +173,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
          "Same number of input side packets and output streams is required."));
}

-TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) {
+TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughOutputStreams) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -248,7 +275,50 @@ TEST(SidePacketToStreamCalculator, AtTick) {
  tick_and_verify(/*at_timestamp=*/1025);
}

-TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
+TEST(SidePacketToStreamCalculator, AtFirstTick) {
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(
+          R"pb(
+            input_stream: "tick"
+            input_side_packet: "side_packet"
+            output_stream: "packet"
+            node {
+              calculator: "SidePacketToStreamCalculator"
+              input_stream: "TICK:tick"
+              input_side_packet: "side_packet"
+              output_stream: "AT_FIRST_TICK:packet"
+            }
+          )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("packet", &graph_config, &output_packets);
+  CalculatorGraph graph;
+
+  MP_ASSERT_OK(graph.Initialize(graph_config));
+  const int expected_value = 20;
+  const Timestamp kTestTimestamp(1234);
+  MP_ASSERT_OK(
+      graph.StartRun({{"side_packet", MakePacket<int>(expected_value)}}));
+
+  auto insert_tick = [&graph](Timestamp at_timestamp) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+  };
+
+  insert_tick(kTestTimestamp);
+
+  EXPECT_THAT(output_packets,
+              ElementsAre(PacketContainsTimestampAndPayload<int>(
+                  Eq(kTestTimestamp), Eq(expected_value))));
+
+  output_packets.clear();
+
+  // Should not result in an additional output.
+  insert_tick(kTestTimestamp + 1);
+  EXPECT_THAT(output_packets, IsEmpty());
+}
+
+TEST(SidePacketToStreamCalculator, AtTickWithMultipleSidePackets) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -302,6 +372,62 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
  tick_and_verify(/*at_timestamp=*/1025);
}

+TEST(SidePacketToStreamCalculator, AtFirstTickWithMultipleSidePackets) {
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(
+          R"pb(
+            input_stream: "tick"
+            input_side_packet: "side_packet0"
+            input_side_packet: "side_packet1"
+            output_stream: "packet0"
+            output_stream: "packet1"
+            node {
+              calculator: "SidePacketToStreamCalculator"
+              input_stream: "TICK:tick"
+              input_side_packet: "side_packet0"
+              input_side_packet: "side_packet1"
+              output_stream: "AT_FIRST_TICK:0:packet0"
+              output_stream: "AT_FIRST_TICK:1:packet1"
+            }
+          )pb");
+  std::vector<Packet> output_packets0;
+  tool::AddVectorSink("packet0", &graph_config, &output_packets0);
+  std::vector<Packet> output_packets1;
+  tool::AddVectorSink("packet1", &graph_config, &output_packets1);
+  CalculatorGraph graph;
+
+  MP_ASSERT_OK(graph.Initialize(graph_config));
+  const int expected_value0 = 20;
+  const int expected_value1 = 128;
+  const Timestamp kTestTimestamp(1234);
+  MP_ASSERT_OK(
+      graph.StartRun({{"side_packet0", MakePacket<int>(expected_value0)},
+                      {"side_packet1", MakePacket<int>(expected_value1)}}));
+
+  auto insert_tick = [&graph](Timestamp at_timestamp) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+  };
+
+  insert_tick(kTestTimestamp);
+
+  EXPECT_THAT(output_packets0,
+              ElementsAre(PacketContainsTimestampAndPayload<int>(
+                  Eq(kTestTimestamp), Eq(expected_value0))));
+  EXPECT_THAT(output_packets1,
+              ElementsAre(PacketContainsTimestampAndPayload<int>(
+                  Eq(kTestTimestamp), Eq(expected_value1))));
+
+  output_packets0.clear();
+  output_packets1.clear();
+
+  // Should not result in an additional output.
+  insert_tick(kTestTimestamp + 1);
+  EXPECT_THAT(output_packets0, IsEmpty());
+  EXPECT_THAT(output_packets1, IsEmpty());
+}
+
TEST(SidePacketToStreamCalculator, AtTimestamp) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
@@ -334,7 +460,7 @@ TEST(SidePacketToStreamCalculator, AtTimestamp) {
  EXPECT_EQ(expected_value, output_packets.back().Get<int>());
}

-TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) {
+TEST(SidePacketToStreamCalculator, AtTimestampWithMultipleOutputs) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"pb(
@@ -65,7 +65,7 @@ class ImageCloneCalculator : public Node {
    }
#else
    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-        cc, /*requesst_gpu_as_optional=*/true));
+        cc, /*request_gpu_as_optional=*/true));
#endif  // MEDIAPIPE_DISABLE_GPU
    return absl::OkStatus();
  }
@@ -118,7 +118,7 @@ absl::Status SegmentationSmoothingCalculator::GetContract(

#if !MEDIAPIPE_DISABLE_GPU
  MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-      cc, /*requesst_gpu_as_optional=*/true));
+      cc, /*request_gpu_as_optional=*/true));
#endif  // !MEDIAPIPE_DISABLE_GPU

  return absl::OkStatus();
@@ -206,7 +206,7 @@ class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl<InterfaceT> {
    if constexpr (std::is_same_v<InterfaceT, WarpAffineCalculatorGpu> ||
                  std::is_same_v<InterfaceT, WarpAffineCalculator>) {
      MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-          cc, /*requesst_gpu_as_optional=*/true));
+          cc, /*request_gpu_as_optional=*/true));
    }
    return absl::OkStatus();
  }
@@ -1480,7 +1480,6 @@ cc_test(
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)
@@ -109,7 +109,7 @@ bool IsValidFftSize(int size) {
// Non-streaming mode: when "stream_mode" is set to false in the calculator
// options, the calculators treats the packets in the input audio stream as
// a batch of unrelated audio buffers. In each Process() call, the input
-// buffer will be frist resampled, and framed as fixed-sized, possibly
+// buffer will be first resampled, and framed as fixed-sized, possibly
// overlapping tensors. The last tensor produced by a Process() invocation
// will be zero-padding if the remaining samples are insufficient. As the
// calculator treats the input packets as unrelated, all samples will be
@@ -159,7 +159,7 @@ class AudioToTensorCalculator : public Node {
 public:
  static constexpr Input<Matrix> kAudioIn{"AUDIO"};
  // TODO: Removes this optional input stream when the "AUDIO" stream
-  // uses the new mediapipe audio data containers that carry audio metatdata,
+  // uses the new mediapipe audio data containers that carry audio metadata,
  // such as sample rate.
  static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
@@ -37,7 +37,7 @@ message AudioToTensorCalculatorOptions {
  // will be converted into tensors.
  optional double target_sample_rate = 4;

-  // Whether to treat the input audio stream as a continous stream or a batch
+  // Whether to treat the input audio stream as a continuous stream or a batch
  // of unrelated audio buffers.
  optional bool stream_mode = 5 [default = true];
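For illustration, the batch (non-streaming) behavior described in the comment can be selected in the graph config roughly as sketched below; the sample rate is made up, the extension syntax follows the usual MediaPipe options pattern, and any other fields the calculator requires are omitted:

// Hypothetical node: every input packet is treated as an unrelated batch.
auto node = mediapipe::ParseTextProtoOrDie<
    mediapipe::CalculatorGraphConfig::Node>(R"pb(
  calculator: "AudioToTensorCalculator"
  input_stream: "AUDIO:audio"
  output_stream: "TENSORS:tensors"
  options {
    [mediapipe.AudioToTensorCalculatorOptions.ext] {
      target_sample_rate: 16000.0
      stream_mode: false
    }
  }
)pb");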
@@ -82,7 +82,7 @@ namespace api2 {
//
// Outputs:
//   TENSORS - std::vector<Tensor>
-//     Vector containing a single Tensor populated with an extrated RGB image.
+//     Vector containing a single Tensor populated with an extracted RGB image.
//   MATRIX - std::array<float, 16> @Optional
//     An std::array<float, 16> representing a 4x4 row-major-order matrix that
//     maps a point on the input image to a point on the output tensor, and

@@ -212,7 +212,7 @@ class ImageToTensorCalculator : public Node {
      std::array<float, 16> matrix;
      GetRotatedSubRectToRectTransformMatrix(
          roi, image->width(), image->height(),
-          /*flip_horizontaly=*/false, &matrix);
+          /*flip_horizontally=*/false, &matrix);
      kOutMatrix(cc).Send(std::move(matrix));
    }

@@ -206,7 +206,7 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
  } else if (image_channels == 1) {
    return ImageFormat::GRAY8;
  }
-  ABSL_CHECK(false) << "Unsupported input image channles: " << image_channels;
+  ABSL_CHECK(false) << "Unsupported input image channels: " << image_channels;
}

Packet MakeImageFramePacket(cv::Mat input) {
@@ -57,7 +57,7 @@ class SubRectExtractorGl {
  absl::Status ExtractSubRectToBuffer(
      const tflite::gpu::gl::GlTexture& texture,
      const tflite::gpu::HW& texture_size, const RotatedRect& sub_rect,
-      bool flip_horizontaly, float alpha, float beta,
+      bool flip_horizontally, float alpha, float beta,
      const tflite::gpu::HW& destination_size,
      tflite::gpu::gl::CommandQueue* command_queue,
      tflite::gpu::gl::GlBuffer* destination);

@@ -154,13 +154,13 @@ void main() {
absl::Status SubRectExtractorGl::ExtractSubRectToBuffer(
    const tflite::gpu::gl::GlTexture& texture,
    const tflite::gpu::HW& texture_size, const RotatedRect& texture_sub_rect,
-    bool flip_horizontaly, float alpha, float beta,
+    bool flip_horizontally, float alpha, float beta,
    const tflite::gpu::HW& destination_size,
    tflite::gpu::gl::CommandQueue* command_queue,
    tflite::gpu::gl::GlBuffer* destination) {
  std::array<float, 16> transform_mat;
  GetRotatedSubRectToRectTransformMatrix(texture_sub_rect, texture_size.w,
-                                         texture_size.h, flip_horizontaly,
+                                         texture_size.h, flip_horizontally,
                                         &transform_mat);
  MP_RETURN_IF_ERROR(texture.BindAsSampler2D(0));

@@ -308,7 +308,7 @@ class GlProcessor : public ImageToTensorConverter {
        input_texture,
        tflite::gpu::HW(source_texture.height(), source_texture.width()),
        roi,
-        /*flip_horizontaly=*/false, transform.scale, transform.offset,
+        /*flip_horizontally=*/false, transform.scale, transform.offset,
        tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
        command_queue_.get(), &output));

@@ -199,7 +199,7 @@ class GlProcessor : public ImageToTensorConverter {
                                               range_min, range_max));
    auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
    MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
-                                      /*flip_horizontaly=*/false,
+                                      /*flip_horizontally=*/false,
                                      transform.scale, transform.offset,
                                      output_shape, &tensor_view));
    return absl::OkStatus();

@@ -210,7 +210,7 @@ class GlProcessor : public ImageToTensorConverter {

  absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
                              const RotatedRect& sub_rect,
-                              bool flip_horizontaly, float alpha, float beta,
+                              bool flip_horizontally, float alpha, float beta,
                              const Tensor::Shape& output_shape,
                              Tensor::OpenGlTexture2dView* output) {
    const int output_height = output_shape.dims[1];

@@ -263,13 +263,13 @@ class GlProcessor : public ImageToTensorConverter {
    ABSL_LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread.";
    if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) {
      GetTransposedRotatedSubRectToRectTransformMatrix(
-          sub_rect, texture.width(), texture.height(), flip_horizontaly,
+          sub_rect, texture.width(), texture.height(), flip_horizontally,
          &transform_mat);
      glUniformMatrix4fv(matrix_id_, 1, GL_FALSE, transform_mat.data());
    } else {
      GetRotatedSubRectToRectTransformMatrix(sub_rect, texture.width(),
-                                             texture.height(), flip_horizontaly,
-                                             &transform_mat);
+                                             texture.height(),
+                                             flip_horizontally, &transform_mat);
      glUniformMatrix4fv(matrix_id_, 1, GL_TRUE, transform_mat.data());
    }

@@ -179,13 +179,13 @@ class SubRectExtractorMetal {
  }

  absl::Status Execute(id<MTLTexture> input_texture,
-                       const RotatedRect& sub_rect, bool flip_horizontaly,
+                       const RotatedRect& sub_rect, bool flip_horizontally,
                       float alpha, float beta,
                       const tflite::gpu::HW& destination_size,
                       id<MTLCommandBuffer> command_buffer,
                       id<MTLBuffer> destination) {
    auto output_texture = MTLTextureWithBuffer(destination_size, destination);
-    return InternalExecute(input_texture, sub_rect, flip_horizontaly, alpha,
+    return InternalExecute(input_texture, sub_rect, flip_horizontally, alpha,
                           beta, destination_size, command_buffer,
                           output_texture);
  }

@@ -211,7 +211,7 @@ class SubRectExtractorMetal {

  absl::Status InternalExecute(id<MTLTexture> input_texture,
                               const RotatedRect& sub_rect,
-                               bool flip_horizontaly, float alpha, float beta,
+                               bool flip_horizontally, float alpha, float beta,
                               const tflite::gpu::HW& destination_size,
                               id<MTLCommandBuffer> command_buffer,
                               id<MTLTexture> output_texture) {

@@ -223,7 +223,7 @@ class SubRectExtractorMetal {
    std::array<float, 16> transform_mat;
    GetRotatedSubRectToRectTransformMatrix(sub_rect, input_texture.width,
                                           input_texture.height,
-                                           flip_horizontaly, &transform_mat);
+                                           flip_horizontally, &transform_mat);
    id<MTLBuffer> transform_mat_buffer =
        [device_ newBufferWithBytes:&transform_mat
                             length:sizeof(transform_mat)

@@ -383,7 +383,7 @@ class MetalProcessor : public ImageToTensorConverter {
        MtlBufferView::GetWriteView(output_tensor, command_buffer);
    MP_RETURN_IF_ERROR(extractor_->Execute(
        texture, roi,
-        /*flip_horizontaly=*/false, transform.scale, transform.offset,
+        /*flip_horizontally=*/false, transform.scale, transform.offset,
        tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
        command_buffer, buffer_view.buffer()));
    [command_buffer commit];
@@ -92,7 +92,7 @@ absl::StatusOr<ValueTransformation> GetValueRangeTransformation(

void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
                                            int rect_width, int rect_height,
-                                            bool flip_horizontaly,
+                                            bool flip_horizontally,
                                            std::array<float, 16>* matrix_ptr) {
  std::array<float, 16>& matrix = *matrix_ptr;
  // The resulting matrix is multiplication of below commented out matrices:

@@ -118,7 +118,7 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
  //           {0.0f, 0.0f,    a, 0.0f}
  //           {0.0f, 0.0f, 0.0f, 1.0f}

-  const float flip = flip_horizontaly ? -1 : 1;
+  const float flip = flip_horizontally ? -1 : 1;
  // Matrix for optional horizontal flip around middle of output image.
  // { fl  , 0.0f, 0.0f, 0.0f}
  // { 0.0f, 1.0f, 0.0f, 0.0f}

@@ -177,13 +177,13 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,

void GetTransposedRotatedSubRectToRectTransformMatrix(
    const RotatedRect& sub_rect, int rect_width, int rect_height,
-    bool flip_horizontaly, std::array<float, 16>* matrix_ptr) {
+    bool flip_horizontally, std::array<float, 16>* matrix_ptr) {
  std::array<float, 16>& matrix = *matrix_ptr;
  // See comments in GetRotatedSubRectToRectTransformMatrix for detailed
  // calculations.
  const float a = sub_rect.width;
  const float b = sub_rect.height;
-  const float flip = flip_horizontaly ? -1 : 1;
+  const float flip = flip_horizontally ? -1 : 1;
  const float c = std::cos(sub_rect.rotation);
  const float d = std::sin(sub_rect.rotation);
  const float e = sub_rect.center_x;
@@ -74,7 +74,7 @@ absl::StatusOr<std::array<float, 4>> PadRoi(int input_tensor_width,
// Represents a transformation of value which involves scaling and offsetting.
// To apply transformation:
//   ValueTransformation transform = ...
-//   float transformed_value = transform.scale * value + transfrom.offset;
+//   float transformed_value = transform.scale * value + transform.offset;
struct ValueTransformation {
  float scale;
  float offset;

@@ -99,11 +99,11 @@ absl::StatusOr<ValueTransformation> GetValueRangeTransformation(
// @sub_rect - rotated sub rect in absolute coordinates
// @rect_width - rect width
// @rect_height - rect height
-// @flip_horizontaly - we need to flip the output buffer.
+// @flip_horizontally - we need to flip the output buffer.
// @matrix - 4x4 matrix (array of 16 elements) to populate
void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
                                            int rect_width, int rect_height,
-                                            bool flip_horizontaly,
+                                            bool flip_horizontally,
                                            std::array<float, 16>* matrix);

// Returns the transpose of the matrix found with
@@ -118,11 +118,11 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect,
// @sub_rect - rotated sub rect in absolute coordinates
// @rect_width - rect width
// @rect_height - rect height
-// @flip_horizontaly - we need to flip the output buffer.
+// @flip_horizontally - we need to flip the output buffer.
// @matrix - 4x4 matrix (array of 16 elements) to populate
void GetTransposedRotatedSubRectToRectTransformMatrix(
    const RotatedRect& sub_rect, int rect_width, int rect_height,
-    bool flip_horizontaly, std::array<float, 16>* matrix);
+    bool flip_horizontally, std::array<float, 16>* matrix);

// Validates the output dimensions set in the option proto. The input option
// proto is expected to have to following fields:
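To make the documented parameters concrete, a small usage sketch of the first function follows; the RotatedRect field names come from the surrounding code, while the ROI numbers and the include path are assumptions for illustration:

#include <array>
#include <cmath>

#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"

// Builds the 4x4 matrix mapping a 64x64 sub-rect centered at (128, 128),
// rotated by 90 degrees, out of a 256x256 input, onto the output tensor,
// without horizontal flipping.
mediapipe::RotatedRect roi;
roi.center_x = 128.0f;
roi.center_y = 128.0f;
roi.width = 64.0f;
roi.height = 64.0f;
roi.rotation = static_cast<float>(M_PI / 2.0);
std::array<float, 16> matrix;
mediapipe::GetRotatedSubRectToRectTransformMatrix(
    roi, /*rect_width=*/256, /*rect_height=*/256,
    /*flip_horizontally=*/false, &matrix);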
@@ -32,7 +32,7 @@ message TensorConverterCalculatorOptions {
  // Custom settings to override the internal scaling factors `div` and `sub`.
  // Both values must be set to non-negative values. Will only take effect on
  // CPU AND when |use_custom_normalization| is set to true. When these custom
-  // values take effect, the |zero_center| setting above will be overriden, and
+  // values take effect, the |zero_center| setting above will be overridden, and
  // the normalized_value will be calculated as:
  //   normalized_value = input / custom_div - custom_sub.
  optional bool use_custom_normalization = 6 [default = false];
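A worked instance of the formula above, with illustrative values for the two custom factors:

// With custom_div = 127.5 and custom_sub = 1.0, an 8-bit input is mapped
// into [-1, 1]:
//   input = 0   ->   0 / 127.5 - 1.0 = -1.0
//   input = 255 -> 255 / 127.5 - 1.0 =  1.0
float Normalize(float input) { return input / 127.5f - 1.0f; }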
@@ -34,7 +34,7 @@ message TensorsToClassificationCalculatorOptions {
    repeated Entry entries = 1;
  }

-  // Score threshold for perserving the class.
+  // Score threshold for preserving the class.
  optional float min_score_threshold = 1;
  // Number of highest scoring labels to output. If top_k is not positive then
  // all labels are used.
@@ -15,7 +15,6 @@
#include <unordered_map>
#include <vector>

#include "absl/log/absl_log.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
@@ -147,7 +146,7 @@ BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) {
//  TENSORS - Vector of Tensors of type kFloat32. The vector of tensors can have
//            2 or 3 tensors. First tensor is the predicted raw boxes/keypoints.
//            The size of the values must be (num_boxes * num_predicted_values).
-//            Second tensor is the score tensor. The size of the valuse must be
+//            Second tensor is the score tensor. The size of the values must be
//            (num_boxes * num_classes). It's optional to pass in a third tensor
//            for anchors (e.g. for SSD models) depend on the outputs of the
//            detection model. The size of anchor tensor must be (num_boxes *
@@ -215,7 +214,8 @@ class TensorsToDetectionsCalculator : public Node {
                                   const int* detection_classes,
                                   std::vector<Detection>* output_detections);
  Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
-                               float box_xmax, float score, int class_id,
+                               float box_xmax, absl::Span<const float> scores,
+                               absl::Span<const int> class_ids,
                               bool flip_vertically);
  bool IsClassIndexAllowed(int class_index);

@@ -223,6 +223,7 @@ class TensorsToDetectionsCalculator : public Node {
  int num_boxes_ = 0;
  int num_coords_ = 0;
  int max_results_ = -1;
+  int classes_per_detection_ = 1;
  BoxFormat box_output_format_ =
      mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;

@@ -267,7 +268,7 @@ absl::Status TensorsToDetectionsCalculator::UpdateContract(
  if (CanUseGpu()) {
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-        cc, /*requesst_gpu_as_optional=*/true));
+        cc, /*request_gpu_as_optional=*/true));
#elif MEDIAPIPE_METAL_ENABLED
    MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif  // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
@@ -484,6 +485,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
    auto num_boxes_view = num_boxes_tensor->GetCpuReadView();
    auto num_boxes = num_boxes_view.buffer<float>();
    num_boxes_ = num_boxes[0];
+    // The detection model with Detection_PostProcess op may output duplicate
+    // boxes with different classes, in the following format:
+    //   num_boxes_tensor = [num_boxes]
+    //   detection_classes_tensor = [box_1_class_1, box_1_class_2, ...]
+    //   detection_scores_tensor = [box_1_score_1, box_1_score_2, ... ]
+    //   detection_boxes_tensor = [box_1, box1, ... ]
+    // Each box repeats classes_per_detection_ times.
+    // Note Detection_PostProcess op is only supported in CPU.
+    RET_CHECK_EQ(max_detections % num_boxes_, 0);
+    classes_per_detection_ = max_detections / num_boxes_;

    auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView();
    auto detection_boxes = detection_boxes_view.buffer<float>();
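To make the layout above concrete with made-up numbers: if max_detections is 10 while the num_boxes tensor holds 5, then classes_per_detection_ = 10 / 5 = 2 and the flat tensors are read in strides of two:

// detection_boxes_tensor   = [box_1, box_1, box_2, box_2, ..., box_5, box_5]
// detection_scores_tensor  = [box_1_score_1, box_1_score_2, box_2_score_1, ...]
// detection_classes_tensor = [box_1_class_1, box_1_class_2, box_2_class_1, ...]
// Each of the 5 boxes therefore carries 2 (score, class) candidates, which the
// conversion loop in a later hunk consumes classes_per_detection_ at a time.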
@@ -493,8 +504,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(

    auto detection_classes_view = detection_classes_tensor->GetCpuReadView();
    auto detection_classes_ptr = detection_classes_view.buffer<float>();
-    std::vector<int> detection_classes(num_boxes_);
-    for (int i = 0; i < num_boxes_; ++i) {
+    std::vector<int> detection_classes(num_boxes_ * classes_per_detection_);
+    for (int i = 0; i < detection_classes.size(); ++i) {
      detection_classes[i] = static_cast<int>(detection_classes_ptr[i]);
    }
    MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
@@ -863,24 +874,25 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
    const float* detection_boxes, const float* detection_scores,
    const int* detection_classes, std::vector<Detection>* output_detections) {
-  for (int i = 0; i < num_boxes_; ++i) {
+  for (int i = 0; i < num_boxes_ * classes_per_detection_;
+       i += classes_per_detection_) {
    if (max_results_ > 0 && output_detections->size() == max_results_) {
      break;
    }
-    if (options_.has_min_score_thresh() &&
-        detection_scores[i] < options_.min_score_thresh()) {
-      continue;
-    }
-    if (!IsClassIndexAllowed(detection_classes[i])) {
-      continue;
-    }
    const int box_offset = i * num_coords_;
    Detection detection = ConvertToDetection(
        /*box_ymin=*/detection_boxes[box_offset + box_indices_[0]],
        /*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
        /*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
        /*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
-        detection_scores[i], detection_classes[i], options_.flip_vertically());
+        absl::MakeConstSpan(detection_scores + i, classes_per_detection_),
+        absl::MakeConstSpan(detection_classes + i, classes_per_detection_),
+        options_.flip_vertically());
+    // if all the scores and classes are filtered out, we skip the empty
+    // detection.
+    if (detection.score().empty()) {
+      continue;
+    }
    const auto& bbox = detection.location_data().relative_bounding_box();
    if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
        std::isnan(bbox.height())) {
@@ -910,11 +922,21 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
}

Detection TensorsToDetectionsCalculator::ConvertToDetection(
-    float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
-    int class_id, bool flip_vertically) {
+    float box_ymin, float box_xmin, float box_ymax, float box_xmax,
+    absl::Span<const float> scores, absl::Span<const int> class_ids,
+    bool flip_vertically) {
  Detection detection;
-  detection.add_score(score);
-  detection.add_label_id(class_id);
+  for (int i = 0; i < scores.size(); ++i) {
+    if (!IsClassIndexAllowed(class_ids[i])) {
+      continue;
+    }
+    if (options_.has_min_score_thresh() &&
+        scores[i] < options_.min_score_thresh()) {
+      continue;
+    }
+    detection.add_score(scores[i]);
+    detection.add_label_id(class_ids[i]);
+  }

  LocationData* location_data = detection.mutable_location_data();
  location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
@@ -75,7 +75,7 @@ message TensorsToDetectionsCalculatorOptions {
  // representation has a bottom-left origin (e.g., in OpenGL).
  optional bool flip_vertically = 18 [default = false];

-  // Score threshold for perserving decoded detections.
+  // Score threshold for preserving decoded detections.
  optional float min_score_thresh = 19;

  // The maximum number of the detection results to return. If < 0, all
@@ -124,7 +124,7 @@ absl::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) {
                            kFlipVertically(cc).IsConnected())) {
    RET_CHECK(options_.has_input_image_height() &&
              options_.has_input_image_width())
-        << "Must provide input width/height for using flipping when outputing "
+        << "Must provide input width/height for using flipping when outputting "
           "landmarks in absolute coordinates.";
  }
  return absl::OkStatus();
@@ -208,7 +208,7 @@ absl::Status TensorsToSegmentationCalculator::GetContract(
  if (CanUseGpu()) {
#if !MEDIAPIPE_DISABLE_GPU
    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
-        cc, /*requesst_gpu_as_optional=*/true));
+        cc, /*request_gpu_as_optional=*/true));
#if MEDIAPIPE_METAL_ENABLED
    MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif  // MEDIAPIPE_METAL_ENABLED
@@ -60,24 +60,22 @@ struct FormattingTestCase {
  std::vector<float> inputs;
  std::vector<float> expected_outputs;
  Options::Activation activation;
-  int rows;
-  int cols;
-  int channels;
+  int rows = 1;
+  int cols = 1;
+  int rows_new = 1;
+  int cols_new = 1;
+  int channels = 1;
+  double max_abs_diff = 1e-7;
};

using TensorsToSegmentationCalculatorTest = TestWithParam<FormattingTestCase>;

-// Currently only useable for tests with no output resize.
TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
-  const FormattingTestCase& test_case = GetParam();
-  std::vector<float> inputs = test_case.inputs;
-  std::vector<float> expected_outputs = test_case.expected_outputs;
-  Options::Activation activation = test_case.activation;
-  int rows = test_case.rows;
-  int cols = test_case.cols;
-  int channels = test_case.channels;
+  const auto& [test_name, inputs, expected_outputs, activation, rows, cols,
+               rows_new, cols_new, channels, max_abs_diff] = GetParam();

-  std::string string_config = absl::Substitute(
+  auto graph_config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
      R"pb(
        input_stream: "tensors"
        input_stream: "size"
@@ -93,9 +91,7 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
          }
        }
      )pb",
-      ActivationTypeToString(activation));
-  auto graph_config =
-      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(string_config);
+          ActivationTypeToString(activation)));

  std::vector<Packet> output_packets;
  tool::AddVectorSink("image_as_mask", &graph_config, &output_packets);
@@ -119,28 +115,34 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
    MP_ASSERT_OK(graph.AddPacketToInputStream(
        "tensors", mediapipe::Adopt(tensors.release()).At(Timestamp(0))));
  }
+  // The output size is defined as pair(new_width, new_height).
  MP_ASSERT_OK(graph.AddPacketToInputStream(
-      "size",
-      mediapipe::Adopt(new std::pair<int, int>(rows, cols)).At(Timestamp(0))));
+      "size", mediapipe::Adopt(new std::pair<int, int>(cols_new, rows_new))
+                  .At(Timestamp(0))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  ASSERT_THAT(output_packets, SizeIs(1));
  const Image& image_as_mask = output_packets[0].Get<Image>();
+  EXPECT_FALSE(image_as_mask.UsesGpu());

  std::shared_ptr<cv::Mat> result_mat = formats::MatView(&image_as_mask);
-  EXPECT_EQ(result_mat->rows, rows);
-  EXPECT_EQ(result_mat->cols, cols);
-  EXPECT_EQ(result_mat->channels(), channels);
+  EXPECT_EQ(result_mat->rows, rows_new);
+  EXPECT_EQ(result_mat->cols, cols_new);
+  EXPECT_EQ(result_mat->channels(), 1);

  // Compare the real result with the expected result.
-  cv::Mat expected_result = cv::Mat(
-      rows, cols, CV_32FC1, const_cast<float*>(expected_outputs.data()));
+  cv::Mat expected_result =
+      cv::Mat(rows_new, cols_new, CV_32FC1,
+              const_cast<float*>(expected_outputs.data()));
  cv::Mat diff;
  cv::absdiff(*result_mat, expected_result, diff);
  double max_val;
  cv::minMaxLoc(diff, nullptr, &max_val);
-  // Expects the maximum absolute pixel-by-pixel difference is less than 1e-5.
-  // This delta is for passthorugh accuracy only.
-  EXPECT_LE(max_val, 1e-5);
+  // The max allowable diff between output and expected output varies between
+  // tests.
+  EXPECT_LE(max_val, max_abs_diff);

  MP_ASSERT_OK(graph.CloseInputStream("tensors"));
  MP_ASSERT_OK(graph.CloseInputStream("size"));
@@ -150,17 +152,96 @@ TEST_P(TensorsToSegmentationCalculatorTest, ParameterizedTests) {
 INSTANTIATE_TEST_SUITE_P(
     TensorsToSegmentationCalculatorTests, TensorsToSegmentationCalculatorTest,
     testing::ValuesIn<FormattingTestCase>({
-        {/*test_name=*/"NoActivationAndNoOutputResize",
-         /*inputs=*/
-         {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
-          14.0, 15.0, 16.0},
-         /*expected_outputs=*/
-         {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
-          14.0, 15.0, 16.0},
-         /*activation=*/Options::NONE,
-         /*rows=*/4,
-         /*cols=*/4,
-         /*channels=*/1},
+        {.test_name = "NoActivationAndNoOutputResize",
+         .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
+                    12.0, 13.0, 14.0, 15.0, 16.0},
+         .expected_outputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
+                              11.0, 12.0, 13.0, 14.0, 15.0, 16.0},
+         .activation = Options::NONE,
+         .rows = 4,
+         .cols = 4,
+         .rows_new = 4,
+         .cols_new = 4,
+         .channels = 1,
+         .max_abs_diff = 1e-7},
+        {.test_name = "OutputResizeOnly",
+         .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
+                    12.0, 13.0, 14.0, 15.0, 16.0},
+         .expected_outputs = {1, 1.5, 2.166667, 2.833333, 3.5, 4,
+                              3.8, 4.3, 4.966667, 5.633333, 6.3, 6.8,
+                              7, 7.5, 8.166667, 8.833333, 9.5, 10,
+                              10.2, 10.7, 11.366667, 12.033333, 12.7, 13.2,
+                              13, 13.5, 14.166667, 14.833333, 15.5, 16},
+         .activation = Options::NONE,
+         .rows = 4,
+         .cols = 4,
+         .rows_new = 5,
+         .cols_new = 6,
+         .channels = 1,
+         .max_abs_diff = 1e-6},
+        {.test_name = "SigmoidActivationWithNoOutputResize",
+         .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
+                    12.0, 13.0, 14.0, 15.0, 16.0},
+         .expected_outputs = {0.731059, 0.880797, 0.952574, 0.982014, 0.993307,
+                              0.997527, 0.999089, 0.999665, 0.999877, 0.999955,
+                              0.999983, 0.999994, 0.999998, 0.999999, 1.0, 1.0},
+         .activation = Options::SIGMOID,
+         .rows = 4,
+         .cols = 4,
+         .rows_new = 4,
+         .cols_new = 4,
+         .channels = 1,
+         .max_abs_diff = 1e-6},
+        {.test_name = "SigmoidActivationWithOutputResize",
+         .inputs = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
+                    12.0, 13.0, 14.0, 15.0, 16.0},
+         .expected_outputs = {0.731059, 0.805928, 0.89276, 0.940611, 0.967294,
+                              0.982014, 0.914633, 0.93857, 0.966279, 0.981363,
+                              0.989752, 0.994369, 0.996592, 0.997666, 0.998873,
+                              0.999404, 0.999683, 0.999829, 0.999913, 0.99994,
+                              0.999971, 0.999985, 0.999992, 0.999996, 0.999998,
+                              0.999998, 0.999999, 1.0, 1.0, 1.0},
+         .activation = Options::SIGMOID,
+         .rows = 4,
+         .cols = 4,
+         .rows_new = 5,
+         .cols_new = 6,
+         .channels = 1,
+         .max_abs_diff = 1e-6},
+        {.test_name = "SoftmaxActivationWithNoOutputResize",
+         .inputs = {1.0, 2.0, 4.0, 2.0, 3.0, 5.0, 6.0, 1.5,
+                    7.0, 10.0, 11.0, 4.0, 12.0, 15.0, 16.0, 18.5,
+                    19.0, 20.0, 22.0, 23.0, 24.5, 23.4, 25.6, 28.3,
+                    29.2, 30.0, 24.6, 29.2, 30.0, 24.9, 31.2, 30.3},
+         .expected_outputs = {0.731059, 0.119203, 0.880797, 0.0109869, 0.952574,
+                              0.000911051, 0.952574, 0.924142, 0.731059,
+                              0.731059, 0.24974, 0.937027, 0.689974, 0.990048,
+                              0.0060598, 0.28905},
+         .activation = Options::SOFTMAX,
+         .rows = 4,
+         .cols = 4,
+         .rows_new = 4,
+         .cols_new = 4,
+         .channels = 2,
+         .max_abs_diff = 1e-6},
+        {.test_name = "SoftmaxActivationWithOutputResize",
+         .inputs = {1.0, 2.0, 4.0, 2.0, 3.0, 5.0, 6.0, 1.5,
+                    7.0, 10.0, 11.0, 4.0, 12.0, 15.0, 16.0, 18.5,
+                    19.0, 20.0, 22.0, 23.0, 24.5, 23.4, 25.6, 28.3,
+                    29.2, 30.0, 24.6, 29.2, 30.0, 24.9, 31.2, 30.3},
+         .expected_outputs = {0.731059, 0.425131, 0.246135, 0.753865, 0.445892,
+                              0.0109869, 0.886119, 0.461259, 0.185506, 0.781934,
+                              0.790618, 0.650195, 0.841816, 0.603901, 0.40518,
+                              0.561962, 0.765871, 0.930584, 0.718733, 0.763744,
+                              0.703402, 0.281989, 0.459635, 0.742634, 0.689974,
+                              0.840011, 0.82605, 0.170058, 0.147555, 0.28905},
+         .activation = Options::SOFTMAX,
+         .rows = 4,
+         .cols = 4,
+         .rows_new = 5,
+         .cols_new = 6,
+         .channels = 2,
+         .max_abs_diff = 1e-6},
     }),
     [](const testing::TestParamInfo<
        TensorsToSegmentationCalculatorTest::ParamType>& info) {
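Note: the rewritten test cases above lean on designated initializers plus defaulted struct members, so each case names only the fields it overrides. A minimal self-contained sketch of the same pattern (hypothetical TestCase type, not the MediaPipe struct):

```cpp
#include <string>
#include <vector>

// Hypothetical struct mirroring FormattingTestCase: defaulted members let a
// designated initializer omit everything it does not care about.
struct TestCase {
  std::string test_name;
  std::vector<float> inputs;
  int rows = 1;
  int cols = 1;
  int rows_new = 1;  // output size after the optional resize
  int cols_new = 1;
  double max_abs_diff = 1e-7;
};

// rows_new/cols_new/max_abs_diff keep their defaults here.
const TestCase kNoResize = {.test_name = "NoResize",
                            .inputs = {1.0f, 2.0f, 3.0f, 4.0f},
                            .rows = 2,
                            .cols = 2};
```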
@@ -79,7 +79,7 @@ namespace mpms = mediapipe::mediasequence;
 // and label and label_id are optional but at least one of them should be set.
 // "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
 // prefixed versions of each stream, which allows for multiple image streams to
-// be included. However, the default names are suppored by more tools.
+// be included. However, the default names are supported by more tools.
 //
 // Example config:
 // node {
@@ -67,8 +67,8 @@ absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
 //   -- 1-D or 2-D Tensor
 // Output:
 //   -- Matrix with the same values as the Tensor
-// If input tensor is 1 dimensional, the ouput Matrix is of (1xn) shape.
-// If input tensor is 2 dimensional (batched), the ouput Matrix is (mxn) shape.
+// If input tensor is 1 dimensional, the output Matrix is of (1xn) shape.
+// If input tensor is 2 dimensional (batched), the output Matrix is (mxn) shape.
 //
 // Example Config
 // node: {
@@ -111,8 +111,8 @@ class InferenceState {
 // input_side_packet.
 //
 // The input and output streams are TensorFlow tensors labeled by tags. The tags
-// for the streams are matched to feeds and fetchs in a TensorFlow session using
-// a named_signature.generic_signature in the ModelManifest. The
+// for the streams are matched to feeds and fetches in a TensorFlow session
+// using a named_signature.generic_signature in the ModelManifest. The
 // generic_signature is used as key-value pairs between the MediaPipe tag and
 // the TensorFlow tensor. The signature_name in the options proto determines
 // which named_signature is used. The keys in the generic_signature must be
@@ -128,7 +128,7 @@ class InferenceState {
 // addition. Once batch_size inputs have been provided, the batch will be run
 // and the output tensors sent out on the output streams with timestamps
 // corresponding to the input stream packets. Setting the batch_size to 1
-// completely disables batching, but is indepdent of add_batch_dim_to_tensors.
+// completely disables batching, but is independent of add_batch_dim_to_tensors.
 //
 // The TensorFlowInferenceCalculator also support feeding states recurrently for
 // RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the
@@ -42,7 +42,7 @@ message TensorFlowInferenceCalculatorOptions {
 // If the 0th dimension is the batch dimension, then the tensors are
 // concatenated on that dimension. If the 0th is a data dimension, then a 0th
 // dimension is added before concatenating. If added, the extra dimension is
-// removed before outputing the tensor. Examples of each case: If you want
+// removed before outputting the tensor. Examples of each case: If you want
 // to batch spectra of audio over time for an LSTM, a time-frequency
 // representation has a 0th dimension as the batch dimension. If you want to
 // batch frames of video that are [width, height, channels], the batch
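Note: a rough sketch of the batch-dimension bookkeeping described in the comment above — assumed shapes only, not the calculator's real tensor code:

```cpp
#include <cstddef>
#include <vector>

// Sketch: prepend a batch dimension of size one before concatenating
// [width, height, channels] frames, and drop it again before output.
std::vector<size_t> AddBatchDim(std::vector<size_t> shape) {
  shape.insert(shape.begin(), 1);  // [w, h, c] -> [1, w, h, c]
  return shape;
}

std::vector<size_t> RemoveBatchDim(const std::vector<size_t>& shape) {
  return {shape.begin() + 1, shape.end()};
}
```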
Binary file not shown.

@@ -1,6 +1,7 @@
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
 networkTimeout=10000
+validateDistributionUrl=true
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists
mediapipe/examples/android/solutions/gradlew (vendored): 29 changed lines
@@ -83,10 +83,8 @@ done
 # This is normally unused
 # shellcheck disable=SC2034
 APP_BASE_NAME=${0##*/}
-APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
-
-# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
-DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
+APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit

 # Use the maximum available, or set MAX_FD != -1 to use that value.
 MAX_FD=maximum
@@ -133,10 +131,13 @@ location of your Java installation."
     fi
 else
     JAVACMD=java
-    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+    if ! command -v java >/dev/null 2>&1
+    then
+        die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.

 Please set the JAVA_HOME variable in your environment to match the
 location of your Java installation."
+    fi
 fi

 # Increase the maximum file descriptors if we can.
@@ -144,7 +145,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
     case $MAX_FD in #(
       max*)
         # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
-        # shellcheck disable=SC3045
+        # shellcheck disable=SC2039,SC3045
         MAX_FD=$( ulimit -H -n ) ||
             warn "Could not query maximum file descriptor limit"
     esac
@@ -152,7 +153,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
       '' | soft) :;; #(
       *)
         # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
-        # shellcheck disable=SC3045
+        # shellcheck disable=SC2039,SC3045
         ulimit -n "$MAX_FD" ||
             warn "Could not set maximum file descriptor limit to $MAX_FD"
     esac
@@ -197,11 +198,15 @@ if "$cygwin" || "$msys" ; then
     done
 fi

-# Collect all arguments for the java command;
-#   * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
-#     shell script including quotes and variable substitutions, so put them in
-#     double quotes to make sure that they get re-expanded; and
-#   * put everything else in single quotes, so that it's not re-expanded.
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# Collect all arguments for the java command:
+#   * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
+#     and any embedded shellness will be escaped.
+#   * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
+#     treated as '${Hostname}' itself on the command line.

 set -- \
         "-Dorg.gradle.appname=$APP_BASE_NAME" \
@@ -56,7 +56,7 @@ absl::Status RunMPPGraph() {
   for (const std::string& kv_pair : kv_pairs) {
     std::vector<std::string> name_and_value = absl::StrSplit(kv_pair, '=');
     RET_CHECK(name_and_value.size() == 2);
-    RET_CHECK(!mediapipe::ContainsKey(input_side_packets, name_and_value[0]));
+    RET_CHECK(!input_side_packets.contains(name_and_value[0]));
     std::string input_side_packet_contents;
     MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
         name_and_value[1], &input_side_packet_contents));
@@ -616,6 +616,7 @@ cc_test(
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:parse_text_proto",
+        "@com_google_absl//absl/functional:bind_front",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -17,6 +17,7 @@
 #include <memory>
 #include <vector>

+#include "absl/functional/bind_front.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
@@ -134,7 +134,7 @@ absl::Status ParseTagAndName(absl::string_view tag_and_name, std::string* tag,
   RET_CHECK(name);
   absl::Status tag_status = absl::OkStatus();
   absl::Status name_status = absl::UnknownError("");
-  int name_index = 0;
+  int name_index = -1;
   std::vector<std::string> v = absl::StrSplit(tag_and_name, ':');
   if (v.size() == 1) {
     name_status = ValidateName(v[0]);
@@ -143,7 +143,7 @@ absl::Status ParseTagAndName(absl::string_view tag_and_name, std::string* tag,
     tag_status = ValidateTag(v[0]);
     name_status = ValidateName(v[1]);
     name_index = 1;
-  }
+  }  // else omitted, name_index == -1, triggering error.
   if (name_index == -1 || tag_status != absl::OkStatus() ||
       name_status != absl::OkStatus()) {
     tag->clear();
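Note: the -1 sentinel above distinguishes "name", "TAG:name", and malformed inputs. A simplified sketch of the split logic (validation elided):

```cpp
#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

// Returns the index of the name part, or -1 to signal an error, mirroring
// the sentinel used by ParseTagAndName above.
int NameIndexFor(absl::string_view tag_and_name) {
  std::vector<std::string> v = absl::StrSplit(tag_and_name, ':');
  if (v.size() == 1) return 0;  // "name"
  if (v.size() == 2) return 1;  // "TAG:name"
  return -1;                    // anything else triggers an error
}
```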
@@ -516,6 +516,7 @@ cc_library(
         ":gpu_buffer_storage",
         ":image_frame_view",
         "//mediapipe/framework/formats:image_frame",
+        "//mediapipe/framework/port:ret_check",
         "@com_google_absl//absl/strings:str_format",
     ],
 )
@@ -526,12 +527,14 @@ mediapipe_proto_library(
     visibility = ["//visibility:public"],
 )

-objc_library(
+cc_library(
     name = "pixel_buffer_pool_util",
-    srcs = ["pixel_buffer_pool_util.mm"],
+    srcs = ["pixel_buffer_pool_util.cc"],
     hdrs = ["pixel_buffer_pool_util.h"],
     copts = [
+        "-x objective-c++",
         "-Wno-shorten-64-to-32",
+        "-fobjc-arc",  # enable reference-counting
     ],
     visibility = ["//visibility:public"],
     deps = [
@@ -542,13 +545,14 @@ objc_library(
     ],
 )

-objc_library(
+cc_library(
     name = "metal_shared_resources",
-    srcs = ["metal_shared_resources.mm"],
+    srcs = ["metal_shared_resources.cc"],
     hdrs = ["metal_shared_resources.h"],
     copts = [
+        "-x objective-c++",
         "-Wno-shorten-64-to-32",
+        "-fobjc-arc",  # enable reference-counting
     ],
+    features = ["-layering_check"],
     visibility = ["//visibility:public"],
@@ -557,15 +561,17 @@ objc_library(
         "@google_toolbox_for_mac//:GTM_Defines",
-    ] + [
     ],
+    alwayslink = 1,
 )

-objc_library(
+cc_library(
     name = "MPPMetalUtil",
-    srcs = ["MPPMetalUtil.mm"],
+    srcs = ["MPPMetalUtil.cc"],
     hdrs = ["MPPMetalUtil.h"],
     copts = [
+        "-x objective-c++",
         "-Wno-shorten-64-to-32",
+        "-fobjc-arc",  # enable reference-counting
     ],
     visibility = ["//visibility:public"],
     deps = [
@@ -575,6 +581,7 @@ objc_library(
         "@com_google_absl//absl/time",
         "@google_toolbox_for_mac//:GTM_Defines",
     ],
+    alwayslink = 1,
 )

 mediapipe_proto_library(
@@ -857,12 +864,14 @@ cc_library(
     }),
 )

-objc_library(
+cc_library(
     name = "MPPMetalHelper",
-    srcs = ["MPPMetalHelper.mm"],
+    srcs = ["MPPMetalHelper.cc"],
     hdrs = ["MPPMetalHelper.h"],
     copts = [
         "-Wno-shorten-64-to-32",
+        "-x objective-c++",
+        "-fobjc-arc",
     ],
+    features = ["-layering_check"],
     visibility = ["//visibility:public"],
@@ -1215,9 +1224,13 @@ mediapipe_cc_test(
     ],
     requires_full_emulation = True,
     deps = [
+        ":gl_texture_buffer",
+        ":gl_texture_util",
+        ":gpu_buffer_format",
         ":gpu_buffer_storage_ahwb",
+        ":gpu_test_base",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/tool:test_util",
     ],
 )
@@ -14,15 +14,14 @@

 #import "mediapipe/gpu/MPPMetalHelper.h"

+#import "GTMDefines.h"
 #include "absl/log/absl_check.h"
 #include "absl/log/absl_log.h"
+#include "mediapipe/framework/port/ret_check.h"
 #import "mediapipe/gpu/gpu_buffer.h"
 #import "mediapipe/gpu/gpu_service.h"
 #import "mediapipe/gpu/graph_support.h"
 #import "mediapipe/gpu/metal_shared_resources.h"
-#import "GTMDefines.h"
-
-#include "mediapipe/framework/port/ret_check.h"

 @interface MPPMetalHelper () {
   mediapipe::GpuResources* _gpuResources;
@@ -31,7 +30,8 @@

 namespace mediapipe {

-// Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport.
+// Using a C++ class so it can be declared as a friend of
+// LegacyCalculatorSupport.
 class MetalHelperLegacySupport {
  public:
   static CalculatorContract* GetCalculatorContract() {
@@ -61,7 +61,8 @@ class MetalHelperLegacySupport {

 - (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc {
   if (!cc) return nil;
-  return [self initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()];
+  return [self
+      initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()];
 }

 + (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc {
@@ -77,7 +78,8 @@ class MetalHelperLegacySupport {
 }

 // Legacy support.
-- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets {
+- (instancetype)initWithSidePackets:
+    (const mediapipe::PacketSet&)inputSidePackets {
   auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContext();
   if (cc) {
     ABSL_CHECK_EQ(&inputSidePackets, &cc->InputSidePackets());
@@ -85,16 +87,19 @@ class MetalHelperLegacySupport {
   }

   // TODO: remove when we can.
-  ABSL_LOG(WARNING) << "CalculatorContext not available. If this calculator uses "
-                       "CalculatorBase, call initWithCalculatorContext instead.";
+  ABSL_LOG(WARNING)
+      << "CalculatorContext not available. If this calculator uses "
+         "CalculatorBase, call initWithCalculatorContext instead.";
   mediapipe::GpuSharedData* gpu_shared =
-      inputSidePackets.Tag(mediapipe::kGpuSharedTagName).Get<mediapipe::GpuSharedData*>();
+      inputSidePackets.Tag(mediapipe::kGpuSharedTagName)
+          .Get<mediapipe::GpuSharedData*>();

   return [self initWithGpuResources:gpu_shared->gpu_resources.get()];
 }

 // Legacy support.
-+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets {
++ (absl::Status)setupInputSidePackets:
+    (mediapipe::PacketTypeSet*)inputSidePackets {
   auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContract();
   if (cc) {
     ABSL_CHECK_EQ(inputSidePackets, &cc->InputSidePackets());
@@ -102,11 +107,11 @@ class MetalHelperLegacySupport {
   }

   // TODO: remove when we can.
-  ABSL_LOG(WARNING) << "CalculatorContract not available. If you're calling this "
-                       "from a GetContract method, call updateContract instead.";
+  ABSL_LOG(WARNING)
+      << "CalculatorContract not available. If you're calling this "
+         "from a GetContract method, call updateContract instead.";
   auto id = inputSidePackets->GetId(mediapipe::kGpuSharedTagName, 0);
-  RET_CHECK(id.IsValid())
-      << "A " << mediapipe::kGpuSharedTagName
-      << " input side packet is required here.";
+  RET_CHECK(id.IsValid()) << "A " << mediapipe::kGpuSharedTagName
+                          << " input side packet is required here.";
   inputSidePackets->Get(id).Set<mediapipe::GpuSharedData*>();
   return absl::OkStatus();
@@ -125,10 +130,12 @@ class MetalHelperLegacySupport {
 }

 - (id<MTLCommandBuffer>)commandBuffer {
-  return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer];
+  return
+      [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer];
 }

-- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer
+- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:
+    (const mediapipe::GpuBuffer&)gpuBuffer
                                                plane:(size_t)plane {
   CVPixelBufferRef pixel_buffer = mediapipe::GetCVPixelBufferRef(gpuBuffer);
   OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
@@ -178,40 +185,47 @@ class MetalHelperLegacySupport {
   CVMetalTextureRef texture;
   CVReturn err = CVMetalTextureCacheCreateTextureFromImage(
       NULL, _gpuResources->metal_shared().resources().mtlTextureCache,
-      mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane,
-      &texture);
+      mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width,
+      height, plane, &texture);
   ABSL_CHECK_EQ(err, kCVReturnSuccess);
   return texture;
 }

-- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer {
+- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:
+    (const mediapipe::GpuBuffer&)gpuBuffer {
   return [self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:0];
 }

-- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer {
+- (id<MTLTexture>)metalTextureWithGpuBuffer:
+    (const mediapipe::GpuBuffer&)gpuBuffer {
   return [self metalTextureWithGpuBuffer:gpuBuffer plane:0];
 }

-- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer
+- (id<MTLTexture>)metalTextureWithGpuBuffer:
+    (const mediapipe::GpuBuffer&)gpuBuffer
                                       plane:(size_t)plane {
   CFHolder<CVMetalTextureRef> cvTexture;
   cvTexture.adopt([self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:plane]);
   return CVMetalTextureGetTexture(*cvTexture);
 }

-- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height {
+- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width
+                                             height:(int)height {
   return _gpuResources->gpu_buffer_pool().GetBuffer(width, height);
 }

 - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width
                                              height:(int)height
-                                             format:(mediapipe::GpuBufferFormat)format {
+                                             format:(mediapipe::GpuBufferFormat)
+                                                        format {
   return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format);
 }

-- (id<MTLLibrary>)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error {
+- (id<MTLLibrary>)newLibraryWithResourceName:(NSString*)name
+                                       error:(NSError* _Nullable*)error {
   return [_gpuResources->metal_shared().resources().mtlDevice
-      newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name
+      newLibraryWithFile:[[NSBundle bundleForClass:[self class]]
+                             pathForResource:name
                                       ofType:@"metallib"]
                    error:error];
 }
@@ -69,10 +69,10 @@
   while (!bufferCompleted) {
     auto duration = absl::Now() - start_time;
     // If the spin-lock takes more than 5 ms then go to blocking wait:
-    // - it frees the CPU core for another threads: increase the performance/decrease power
-    //   consumption.
-    // - if a driver thread that notifies that the GPU buffer is completed has lower priority then
-    //   the CPU core is allocated for the thread.
+    // - it frees the CPU core for another threads: increase the
+    //   performance/decrease power consumption.
+    // - if a driver thread that notifies that the GPU buffer is completed has
+    //   lower priority then the CPU core is allocated for the thread.
    if (duration >= absl::Milliseconds(5)) {
      [commandBuffer waitUntilCompleted];
      break;
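Note: the comment above describes a spin-then-block wait. A generic C++ sketch of that pattern, with a condition variable standing in for the Metal completion handler (hypothetical names):

```cpp
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>

// Spin briefly for low latency, then fall back to a blocking wait so the
// CPU core is freed if completion takes longer than ~5 ms.
void WaitForCompletion(std::atomic<bool>& done, std::mutex& mu,
                       std::condition_variable& cv) {
  const auto start = std::chrono::steady_clock::now();
  while (!done.load(std::memory_order_acquire)) {
    if (std::chrono::steady_clock::now() - start >
        std::chrono::milliseconds(5)) {
      std::unique_lock<std::mutex> lock(mu);
      cv.wait(lock, [&] { return done.load(std::memory_order_acquire); });
      break;
    }
  }
}
```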
@@ -57,8 +57,8 @@ void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) {

 // static
 absl::Status GlCalculatorHelper::UpdateContract(CalculatorContract* cc,
-                                                bool requesst_gpu_as_optional) {
-  if (requesst_gpu_as_optional) {
+                                                bool request_gpu_as_optional) {
+  if (request_gpu_as_optional) {
     cc->UseService(kGpuService).Optional();
   } else {
     cc->UseService(kGpuService);
@@ -68,7 +68,7 @@ class GlCalculatorHelper {
   // This method can be called from GetContract to set up the needed GPU
   // resources.
   static absl::Status UpdateContract(CalculatorContract* cc,
-                                     bool requesst_gpu_as_optional = false);
+                                     bool request_gpu_as_optional = false);

   // This method can be called from FillExpectations to set the correct types
   // for the shared GL input side packet(s).
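Note: with the renamed parameter, a calculator that can run on either CPU or GPU would request the service optionally from GetContract. A sketch under the assumption that the calling calculator handles the CPU fallback itself:

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/gpu/gl_calculator_helper.h"

// Sketch: requesting the GPU service as optional keeps the graph runnable on
// CPU-only hosts; pass false (the default) to require the service instead.
static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
  MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(
      cc, /*request_gpu_as_optional=*/true));
  return absl::OkStatus();
}
```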
@@ -14,6 +14,8 @@

 #include "mediapipe/gpu/gl_texture_buffer.h"

+#include <cstdint>
+
 #include "absl/log/absl_check.h"
 #include "absl/log/absl_log.h"
 #include "mediapipe/framework/formats/image_frame.h"
@@ -131,6 +133,13 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
              SymbolAvailable(&glTexStorage2D)) {
     ABSL_CHECK(data == nullptr) << "unimplemented";
     glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_);
+  } else if (info.immutable) {
+    ABSL_CHECK(SymbolAvailable(&glTexStorage2D) &&
+               context->GetGlVersion() != GlVersion::kGLES2)
+        << "Immutable GpuBuffer format requested is not supported in this "
+        << "GlContext. Format was " << static_cast<uint32_t>(format_);
+    ABSL_CHECK(data == nullptr) << "unimplemented";
+    glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_);
   } else {
     glTexImage2D(target_, 0 /* level */, info.gl_internal_format, width_,
                  height_, 0 /* border */, info.gl_format, info.gl_type, data);
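Note: the new branch enforces that immutable formats are only used where glTexStorage2D is available. A bare-GL sketch of the mutable/immutable distinction, assuming a GLES3 context and header:

```cpp
#include <GLES3/gl3.h>

// Immutable allocation: levels, format, and size are fixed up front, which
// is what GLES 3.1 compute shaders can require of their storage.
void AllocImmutable(GLuint tex, int w, int h) {
  glBindTexture(GL_TEXTURE_2D, tex);
  glTexStorage2D(GL_TEXTURE_2D, /*levels=*/1, GL_RGBA8, w, h);
}

// Mutable allocation: storage can later be respecified.
void AllocMutable(GLuint tex, int w, int h) {
  glBindTexture(GL_TEXTURE_2D, tex);
  glTexImage2D(GL_TEXTURE_2D, /*level=*/0, GL_RGBA8, w, h, /*border=*/0,
               GL_RGBA, GL_UNSIGNED_BYTE, /*data=*/nullptr);
}
```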
@@ -35,6 +35,10 @@ namespace mediapipe {
 #endif  // GL_HALF_FLOAT_OES
 #endif  // __EMSCRIPTEN__

+#ifndef GL_RGBA8
+#define GL_RGBA8 0x8058
+#endif  // GL_RGBA8
+
 #if !MEDIAPIPE_DISABLE_GPU
 #ifdef GL_ES_VERSION_2_0
 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) {
@@ -163,6 +167,14 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
        {
            {GL_RGBA32F, GL_RGBA, GL_FLOAT, 1},
        }},
+      {GpuBufferFormat::kImmutableRGBAFloat128,
+       {
+           {GL_RGBA32F, GL_RGBA, GL_FLOAT, 1, true /* immutable */},
+       }},
+      {GpuBufferFormat::kImmutableRGBA32,
+       {
+           {GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, 1, true /* immutable */},
+       }},
   }};

   static const auto* gles2_format_info = ([] {
@@ -206,6 +218,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,

 ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
   switch (format) {
+    case GpuBufferFormat::kImmutableRGBA32:
     case GpuBufferFormat::kBGRA32:
       // TODO: verify we are handling order of channels correctly.
       return ImageFormat::SRGBA;
@@ -221,10 +234,11 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
       return ImageFormat::SRGB;
     case GpuBufferFormat::kTwoComponentFloat32:
       return ImageFormat::VEC32F2;
+    case GpuBufferFormat::kImmutableRGBAFloat128:
     case GpuBufferFormat::kRGBAFloat128:
       return ImageFormat::VEC32F4;
     case GpuBufferFormat::kRGBA32:
       // TODO: this likely maps to ImageFormat::SRGBA
       return ImageFormat::SRGBA;
     case GpuBufferFormat::kGrayHalf16:
     case GpuBufferFormat::kOneComponent8Alpha:
     case GpuBufferFormat::kOneComponent8Red:
@@ -53,6 +53,10 @@ enum class GpuBufferFormat : uint32_t {
   kRGB24 = 0x00000018,  // Note: prefer BGRA32 whenever possible.
   kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'),
   kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'),
+  // Immutable version of kRGBA32
+  kImmutableRGBA32 = MEDIAPIPE_FOURCC('4', 'C', 'I', '8'),
+  // Immutable version of kRGBAFloat128
+  kImmutableRGBAFloat128 = MEDIAPIPE_FOURCC('4', 'C', 'I', 'f'),
   // 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling.
   kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'),
   // 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling.
@@ -78,6 +82,9 @@ struct GlTextureInfo {
   // For multiplane buffers, this represents how many times smaller than
   // the nominal image size a plane is.
   int downscale;
+  // For GLES3.1+ compute shaders, users may explicitly request immutable
+  // textures.
+  bool immutable = false;
 };

 const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
@@ -121,6 +128,8 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
       return kCVPixelFormatType_64RGBAHalf;
     case GpuBufferFormat::kRGBAFloat128:
       return kCVPixelFormatType_128RGBAFloat;
+    case GpuBufferFormat::kImmutableRGBA32:
+    case GpuBufferFormat::kImmutableRGBAFloat128:
     case GpuBufferFormat::kNV12:
     case GpuBufferFormat::kNV21:
     case GpuBufferFormat::kI420:
@@ -151,7 +151,7 @@ static std::shared_ptr<GpuBufferStorageCvPixelBuffer> ConvertFromImageFrame(
     std::shared_ptr<GpuBufferStorageImageFrame> frame) {
   auto status_or_buffer =
       CreateCVPixelBufferForImageFrame(frame->image_frame());
-  ABSL_CHECK(status_or_buffer.ok());
+  ABSL_CHECK_OK(status_or_buffer);
   return std::make_shared<GpuBufferStorageCvPixelBuffer>(
       std::move(status_or_buffer).value());
 }
@@ -50,9 +50,10 @@
 - (CVMetalTextureCacheRef)mtlTextureCache {
   @synchronized(self) {
     if (!_mtlTextureCache) {
-      CVReturn __unused err =
-          CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
-      NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err,
+      CVReturn __unused err = CVMetalTextureCacheCreate(
+          NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
+      NSAssert(err == kCVReturnSuccess,
+               @"Error at CVMetalTextureCacheCreate %d ; device %@", err,
                self.mtlDevice);
       // TODO: register and flush metal caches too.
     }
@@ -24,23 +24,27 @@

 namespace mediapipe {

-CVPixelBufferPoolRef CreateCVPixelBufferPool(
-    int width, int height, OSType pixelFormat, int keepCount,
+CVPixelBufferPoolRef CreateCVPixelBufferPool(int width, int height,
+                                             OSType pixelFormat, int keepCount,
                                              CFTimeInterval maxAge) {
   CVPixelBufferPoolRef pool = NULL;

   NSMutableDictionary *sourcePixelBufferOptions =
-      [(__bridge NSDictionary*)GetCVPixelBufferAttributesForGlCompatibility() mutableCopy];
+      [(__bridge NSDictionary *)GetCVPixelBufferAttributesForGlCompatibility()
+          mutableCopy];
   [sourcePixelBufferOptions addEntriesFromDictionary:@{
     (id)kCVPixelBufferPixelFormatTypeKey : @(pixelFormat),
     (id)kCVPixelBufferWidthKey : @(width),
     (id)kCVPixelBufferHeightKey : @(height),
   }];

-  NSMutableDictionary *pixelBufferPoolOptions = [[NSMutableDictionary alloc] init];
-  pixelBufferPoolOptions[(id)kCVPixelBufferPoolMinimumBufferCountKey] = @(keepCount);
+  NSMutableDictionary *pixelBufferPoolOptions =
+      [[NSMutableDictionary alloc] init];
+  pixelBufferPoolOptions[(id)kCVPixelBufferPoolMinimumBufferCountKey] =
+      @(keepCount);
   if (maxAge > 0) {
-    pixelBufferPoolOptions[(id)kCVPixelBufferPoolMaximumBufferAgeKey] = @(maxAge);
+    pixelBufferPoolOptions[(id)kCVPixelBufferPoolMaximumBufferAgeKey] =
+        @(maxAge);
   }

   CVPixelBufferPoolCreate(
@@ -50,8 +54,9 @@ CVPixelBufferPoolRef CreateCVPixelBufferPool(
   return pool;
 }

-OSStatus PreallocateCVPixelBufferPoolBuffers(
-    CVPixelBufferPoolRef pool, int count, CFDictionaryRef auxAttributes) {
+OSStatus PreallocateCVPixelBufferPoolBuffers(CVPixelBufferPoolRef pool,
+                                             int count,
+                                             CFDictionaryRef auxAttributes) {
   CVReturn err = kCVReturnSuccess;
   NSMutableArray *pixelBuffers = [[NSMutableArray alloc] init];
   for (int i = 0; i < count && err == kCVReturnSuccess; i++) {
@@ -68,30 +73,37 @@ OSStatus PreallocateCVPixelBufferPoolBuffers(
   return err;
 }

-CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold(int allocationThreshold) {
+CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold(
+    int allocationThreshold) {
   if (allocationThreshold > 0) {
-    return (CFDictionaryRef)CFBridgingRetain(
-        @{(id)kCVPixelBufferPoolAllocationThresholdKey: @(allocationThreshold)});
+    return (CFDictionaryRef)CFBridgingRetain(@{
+      (id)kCVPixelBufferPoolAllocationThresholdKey : @(allocationThreshold)
+    });
   } else {
     return nil;
   }
 }

-CVReturn CreateCVPixelBufferWithPool(
-    CVPixelBufferPoolRef pool, CFDictionaryRef auxAttributes,
-    CVTextureCacheType textureCache, CVPixelBufferRef* outBuffer) {
-  return CreateCVPixelBufferWithPool(pool, auxAttributes, [textureCache](){
+CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool,
+                                     CFDictionaryRef auxAttributes,
+                                     CVTextureCacheType textureCache,
+                                     CVPixelBufferRef *outBuffer) {
+  return CreateCVPixelBufferWithPool(
+      pool, auxAttributes,
+      [textureCache]() {
 #if TARGET_OS_OSX
         CVOpenGLTextureCacheFlush(textureCache, 0);
 #else
         CVOpenGLESTextureCacheFlush(textureCache, 0);
 #endif  // TARGET_OS_OSX
-  }, outBuffer);
+      },
+      outBuffer);
 }

-CVReturn CreateCVPixelBufferWithPool(
-    CVPixelBufferPoolRef pool, CFDictionaryRef auxAttributes,
-    std::function<void(void)> flush, CVPixelBufferRef* outBuffer) {
+CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool,
+                                     CFDictionaryRef auxAttributes,
+                                     std::function<void(void)> flush,
+                                     CVPixelBufferRef *outBuffer) {
   CVReturn err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes(
       kCFAllocatorDefault, pool, auxAttributes, outBuffer);
   if (err == kCVReturnWouldExceedAllocationThreshold) {
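Note: a usage sketch of the pool helpers reformatted above — create a pool, draw one buffer with a flush callback, release both. It assumes an Objective-C++ translation unit on an Apple platform; the header path is taken from the BUILD change earlier in this commit:

```cpp
#include <CoreVideo/CoreVideo.h>

#include "mediapipe/gpu/pixel_buffer_pool_util.h"

void PoolDemo() {
  CVPixelBufferPoolRef pool = mediapipe::CreateCVPixelBufferPool(
      /*width=*/640, /*height=*/480, kCVPixelFormatType_32BGRA,
      /*keepCount=*/4, /*maxAge=*/1.0);
  CVPixelBufferRef buffer = nullptr;
  // The flush callback runs when the pool would exceed its threshold, giving
  // texture caches a chance to release buffers they still retain.
  CVReturn err = mediapipe::CreateCVPixelBufferWithPool(
      pool, /*auxAttributes=*/nullptr,
      [] { /* flush GL/Metal texture caches here if needed */ }, &buffer);
  if (err == kCVReturnSuccess) CVPixelBufferRelease(buffer);
  CVPixelBufferPoolRelease(pool);
}
```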
@@ -103,10 +115,12 @@ CVReturn CreateCVPixelBufferWithPool(
         kCFAllocatorDefault, pool, auxAttributes, outBuffer);
   }
   if (err == kCVReturnWouldExceedAllocationThreshold) {
-    // TODO: allow the application to set the threshold. For now, disable it by
-    // default, since the threshold we are using is arbitrary and some graphs routinely cross it.
+    // TODO: allow the application to set the threshold. For now, disable it
+    // by default, since the threshold we are using is arbitrary and some
+    // graphs routinely cross it.
 #ifdef ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK
-    NSLog(@"Using more buffers than expected! This is a debug-only warning, "
+    NSLog(
+        @"Using more buffers than expected! This is a debug-only warning, "
          "you can ignore it if your app works fine otherwise.");
 #ifdef DEBUG
     NSLog(@"Pool status: %@", ((__bridge NSObject *)pool).description);
@@ -52,9 +52,9 @@ objc_library(
 )

 MEDIAPIPE_IOS_SRCS = [
-    "MPPGraph.mm",
-    "MPPTimestampConverter.mm",
-    "NSError+util_status.mm",
+    "MPPGraph.cc",
+    "MPPTimestampConverter.cc",
+    "NSError+util_status.cc",
 ]

 MEDIAPIPE_IOS_HDRS = [
@@ -63,11 +63,13 @@ MEDIAPIPE_IOS_HDRS = [
     "NSError+util_status.h",
 ]

-objc_library(
+cc_library(
     name = "mediapipe_framework_ios",
     srcs = MEDIAPIPE_IOS_SRCS,
     hdrs = MEDIAPIPE_IOS_HDRS,
     copts = [
+        "-x objective-c++",
+        "-fobjc-arc",  # enable reference-counting
         "-Wno-shorten-64-to-32",
     ],
     # This build rule is public to allow external customers to build their own iOS apps.
@@ -99,6 +101,7 @@ objc_library(
         "@com_google_absl//absl/synchronization",
         "@google_toolbox_for_mac//:GTM_Defines",
     ],
+    alwayslink = 1,
 )

 objc_library(
@@ -19,6 +19,7 @@

 #include <atomic>

+#import "GTMDefines.h"
 #include "absl/memory/memory.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image.h"
@@ -26,22 +27,22 @@
 #include "mediapipe/framework/graph_service.h"
 #include "mediapipe/gpu/gl_base.h"
 #include "mediapipe/gpu/gpu_shared_data_internal.h"
+#import "mediapipe/objc/NSError+util_status.h"
 #include "mediapipe/objc/util.h"

-#import "mediapipe/objc/NSError+util_status.h"
-#import "GTMDefines.h"
-
 @implementation MPPGraph {
-  // Graph is wrapped in a unique_ptr because it was generating 39+KB of unnecessary ObjC runtime
-  // information. See https://medium.com/@dmaclach/objective-c-encoding-and-you-866624cc02de
-  // for details.
+  // Graph is wrapped in a unique_ptr because it was generating 39+KB of
+  // unnecessary ObjC runtime information. See
+  // https://medium.com/@dmaclach/objective-c-encoding-and-you-866624cc02de for
+  // details.
   std::unique_ptr<mediapipe::CalculatorGraph> _graph;
   /// Input side packets that will be added to the graph when it is started.
   std::map<std::string, mediapipe::Packet> _inputSidePackets;
   /// Packet headers that will be added to the graph when it is started.
   std::map<std::string, mediapipe::Packet> _streamHeaders;
   /// Service packets to be added to the graph when it is started.
-  std::map<const mediapipe::GraphServiceBase*, mediapipe::Packet> _servicePackets;
+  std::map<const mediapipe::GraphServiceBase*, mediapipe::Packet>
+      _servicePackets;

   /// Number of frames currently being processed by the graph.
   std::atomic<int32_t> _framesInFlight;
@@ -56,7 +57,8 @@
   BOOL _started;
 }

-- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config {
+- (instancetype)initWithGraphConfig:
+    (const mediapipe::CalculatorGraphConfig&)config {
   self = [super init];
   if (self) {
     // Turn on Cocoa multithreading, since MediaPipe uses threads.
@@ -76,34 +78,41 @@
   return _graph->GetGraphInputStreamAddMode();
 }

-- (void)setPacketAddMode:(mediapipe::CalculatorGraph::GraphInputStreamAddMode)mode {
+- (void)setPacketAddMode:
+    (mediapipe::CalculatorGraph::GraphInputStreamAddMode)mode {
   _graph->SetGraphInputStreamAddMode(mode);
 }

 - (void)addFrameOutputStream:(const std::string&)outputStreamName
             outputPacketType:(MPPPacketType)packetType {
   std::string callbackInputName;
-  mediapipe::tool::AddCallbackCalculator(outputStreamName, &_config, &callbackInputName,
+  mediapipe::tool::AddCallbackCalculator(outputStreamName, &_config,
+                                         &callbackInputName,
                                          /*use_std_function=*/true);
-  // No matter what ownership qualifiers are put on the pointer, NewPermanentCallback will
-  // still end up with a strong pointer to MPPGraph*. That is why we use void* instead.
+  // No matter what ownership qualifiers are put on the pointer,
+  // NewPermanentCallback will still end up with a strong pointer to MPPGraph*.
+  // That is why we use void* instead.
   void* wrapperVoid = (__bridge void*)self;
   _inputSidePackets[callbackInputName] =
       mediapipe::MakePacket<std::function<void(const mediapipe::Packet&)>>(
-          [wrapperVoid, outputStreamName, packetType](const mediapipe::Packet& packet) {
-            CallFrameDelegate(wrapperVoid, outputStreamName, packetType, packet);
+          [wrapperVoid, outputStreamName,
+           packetType](const mediapipe::Packet& packet) {
+            CallFrameDelegate(wrapperVoid, outputStreamName, packetType,
+                              packet);
           });
 }

-- (NSString *)description {
-  return [NSString stringWithFormat:@"<%@: %p; framesInFlight = %d>", [self class], self,
+- (NSString*)description {
+  return [NSString
+      stringWithFormat:@"<%@: %p; framesInFlight = %d>", [self class], self,
                        _framesInFlight.load(std::memory_order_relaxed)];
 }

 /// This is the function that gets called by the CallbackCalculator that
 /// receives the graph's output.
 void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
-                       MPPPacketType packetType, const mediapipe::Packet& packet) {
+                       MPPPacketType packetType,
+                       const mediapipe::Packet& packet) {
   MPPGraph* wrapper = (__bridge MPPGraph*)wrapperVoid;
   @autoreleasepool {
     if (packetType == MPPPacketTypeRaw) {
@@ -118,13 +127,16 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
       if (format == mediapipe::ImageFormat::SRGBA ||
           format == mediapipe::ImageFormat::GRAY8) {
         CVPixelBufferRef pixelBuffer;
-        // If kCVPixelFormatType_32RGBA does not work, it returns kCVReturnInvalidPixelFormat.
+        // If kCVPixelFormatType_32RGBA does not work, it returns
+        // kCVReturnInvalidPixelFormat.
         CVReturn error = CVPixelBufferCreate(
             NULL, frame.Width(), frame.Height(), kCVPixelFormatType_32BGRA,
             GetCVPixelBufferAttributesForGlCompatibility(), &pixelBuffer);
-        _GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferCreate failed: %d", error);
+        _GTMDevAssert(error == kCVReturnSuccess,
+                      @"CVPixelBufferCreate failed: %d", error);
         error = CVPixelBufferLockBaseAddress(pixelBuffer, 0);
-        _GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferLockBaseAddress failed: %d", error);
+        _GTMDevAssert(error == kCVReturnSuccess,
+                      @"CVPixelBufferLockBaseAddress failed: %d", error);

         vImage_Buffer vDestination = vImageForCVPixelBuffer(pixelBuffer);
         // Note: we have to throw away const here, but we should not overwrite
@@ -133,26 +145,31 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
         if (format == mediapipe::ImageFormat::SRGBA) {
           // Swap R and B channels.
           const uint8_t permuteMap[4] = {2, 1, 0, 3};
-          vImage_Error __unused vError =
-              vImagePermuteChannels_ARGB8888(&vSource, &vDestination, permuteMap, kvImageNoFlags);
-          _GTMDevAssert(vError == kvImageNoError, @"vImagePermuteChannels failed: %zd", vError);
+          vImage_Error __unused vError = vImagePermuteChannels_ARGB8888(
+              &vSource, &vDestination, permuteMap, kvImageNoFlags);
+          _GTMDevAssert(vError == kvImageNoError,
+                        @"vImagePermuteChannels failed: %zd", vError);
         } else {
           // Convert grayscale back to BGRA
-          vImage_Error __unused vError = vImageGrayToBGRA(&vSource, &vDestination);
-          _GTMDevAssert(vError == kvImageNoError, @"vImageGrayToBGRA failed: %zd", vError);
+          vImage_Error __unused vError =
+              vImageGrayToBGRA(&vSource, &vDestination);
+          _GTMDevAssert(vError == kvImageNoError,
+                        @"vImageGrayToBGRA failed: %zd", vError);
         }

         error = CVPixelBufferUnlockBaseAddress(pixelBuffer, 0);
         _GTMDevAssert(error == kCVReturnSuccess,
                       @"CVPixelBufferUnlockBaseAddress failed: %d", error);

-        if ([wrapper.delegate respondsToSelector:@selector
+        if ([wrapper.delegate
+                respondsToSelector:@selector
                 (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) {
           [wrapper.delegate mediapipeGraph:wrapper
                       didOutputPixelBuffer:pixelBuffer
                                 fromStream:streamName
                                  timestamp:packet.Timestamp()];
-        } else if ([wrapper.delegate respondsToSelector:@selector
+        } else if ([wrapper.delegate
+                        respondsToSelector:@selector
                 (mediapipeGraph:didOutputPixelBuffer:fromStream:)]) {
           [wrapper.delegate mediapipeGraph:wrapper
                       didOutputPixelBuffer:pixelBuffer
@@ -168,10 +185,11 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName,
       wrapper->_framesInFlight--;
       CVPixelBufferRef pixelBuffer;
       if (packetType == MPPPacketTypePixelBuffer)
-        pixelBuffer = mediapipe::GetCVPixelBufferRef(packet.Get<mediapipe::GpuBuffer>());
+        pixelBuffer =
+            mediapipe::GetCVPixelBufferRef(packet.Get<mediapipe::GpuBuffer>());
       else
         pixelBuffer = packet.Get<mediapipe::Image>().GetCVPixelBufferRef();
       if ([wrapper.delegate
               respondsToSelector:@selector
               (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) {
         [wrapper.delegate mediapipeGraph:wrapper
@@ -192,13 +210,15 @@
   }
 }

-- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName {
+- (void)setHeaderPacket:(const mediapipe::Packet&)packet
+              forStream:(const std::string&)streamName {
   _GTMDevAssert(!_started, @"%@ must be called before the graph is started",
                 NSStringFromSelector(_cmd));
   _streamHeaders[streamName] = packet;
 }

-- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name {
+- (void)setSidePacket:(const mediapipe::Packet&)packet
+                named:(const std::string&)name {
   _GTMDevAssert(!_started, @"%@ must be called before the graph is started",
                 NSStringFromSelector(_cmd));
   _inputSidePackets[name] = packet;
@@ -211,7 +231,8 @@
   _servicePackets[&service] = std::move(packet);
 }

-- (void)addSidePackets:(const std::map<std::string, mediapipe::Packet>&)extraSidePackets {
+- (void)addSidePackets:
+    (const std::map<std::string, mediapipe::Packet>&)extraSidePackets {
   _GTMDevAssert(!_started, @"%@ must be called before the graph is started",
                 NSStringFromSelector(_cmd));
   _inputSidePackets.insert(extraSidePackets.begin(), extraSidePackets.end());
@@ -232,7 +253,8 @@
 - (absl::Status)performStart {
   absl::Status status;
   for (const auto& service_packet : _servicePackets) {
-    status = _graph->SetServicePacket(*service_packet.first, service_packet.second);
+    status =
+        _graph->SetServicePacket(*service_packet.first, service_packet.second);
     if (!status.ok()) {
       return status;
     }
@@ -269,10 +291,11 @@
 }

 - (BOOL)waitUntilDoneWithError:(NSError**)error {
-  // Since this method blocks with no timeout, it should not be called in the main thread in
-  // an app. However, it's fine to allow that in a test.
+  // Since this method blocks with no timeout, it should not be called in the
+  // main thread in an app. However, it's fine to allow that in a test.
   // TODO: is this too heavy-handed? Maybe a warning would be fine.
-  _GTMDevAssert(![NSThread isMainThread] || (NSClassFromString(@"XCTest")),
-                @"waitUntilDoneWithError: should not be called on the main thread");
+  _GTMDevAssert(
+      ![NSThread isMainThread] || (NSClassFromString(@"XCTest")),
+      @"waitUntilDoneWithError: should not be called on the main thread");
   absl::Status status = _graph->WaitUntilDone();
   _started = NO;
@@ -289,7 +312,8 @@
 - (BOOL)movePacket:(mediapipe::Packet&&)packet
         intoStream:(const std::string&)streamName
              error:(NSError**)error {
-  absl::Status status = _graph->AddPacketToInputStream(streamName, std::move(packet));
+  absl::Status status =
+      _graph->AddPacketToInputStream(streamName, std::move(packet));
   if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status];
   return status.ok();
 }
@@ -305,7 +329,8 @@
 - (BOOL)setMaxQueueSize:(int)maxQueueSize
               forStream:(const std::string&)streamName
                   error:(NSError**)error {
-  absl::Status status = _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize);
+  absl::Status status =
+      _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize);
   if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status];
   return status.ok();
 }
@@ -313,7 +338,8 @@
 - (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)imageBuffer
                                 packetType:(MPPPacketType)packetType {
   mediapipe::Packet packet;
-  if (packetType == MPPPacketTypeImageFrame || packetType == MPPPacketTypeImageFrameBGRANoSwap) {
+  if (packetType == MPPPacketTypeImageFrame ||
+      packetType == MPPPacketTypeImageFrameBGRANoSwap) {
     auto frame = CreateImageFrameForCVPixelBuffer(
         imageBuffer, /* canOverwrite = */ false,
         /* bgrAsRgb = */ packetType == MPPPacketTypeImageFrameBGRANoSwap);
@@ -328,7 +354,8 @@
     packet = mediapipe::MakePacket<mediapipe::Image>(imageBuffer);
 #else
     // CPU
-    auto frame = CreateImageFrameForCVPixelBuffer(imageBuffer, /* canOverwrite = */ false,
+    auto frame = CreateImageFrameForCVPixelBuffer(imageBuffer,
+                                                  /* canOverwrite = */ false,
                                                   /* bgrAsRgb = */ false);
     packet = mediapipe::MakePacket<mediapipe::Image>(std::move(frame));
 #endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
@@ -339,7 +366,8 @@
 }

 - (mediapipe::Packet)imagePacketWithPixelBuffer:(CVPixelBufferRef)pixelBuffer {
-  return [self packetWithPixelBuffer:(pixelBuffer) packetType:(MPPPacketTypeImage)];
+  return [self packetWithPixelBuffer:(pixelBuffer)
+                          packetType:(MPPPacketTypeImage)];
 }

 - (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer
@@ -367,13 +395,16 @@
          allowOverwrite:(BOOL)allowOverwrite
                   error:(NSError**)error {
   if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO;
-  mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType];
+  mediapipe::Packet packet =
+      [self packetWithPixelBuffer:imageBuffer packetType:packetType];
   BOOL success;
   if (allowOverwrite) {
     packet = std::move(packet).At(timestamp);
-    success = [self movePacket:std::move(packet) intoStream:inputName error:error];
+    success =
+        [self movePacket:std::move(packet) intoStream:inputName error:error];
   } else {
-    success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error];
+    success =
+        [self sendPacket:packet.At(timestamp) intoStream:inputName error:error];
   }
   if (success) _framesInFlight++;
   return success;
@@ -407,22 +438,24 @@
 }

 - (void)debugPrintGlInfo {
-  std::shared_ptr<mediapipe::GpuResources> gpu_resources = _graph->GetGpuResources();
+  std::shared_ptr<mediapipe::GpuResources> gpu_resources =
+      _graph->GetGpuResources();
   if (!gpu_resources) {
     NSLog(@"GPU not set up.");
     return;
   }

   NSString* extensionString;
-  (void)gpu_resources->gl_context()->Run([&extensionString]{
-    extensionString = [NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)];
+  (void)gpu_resources->gl_context()->Run([&extensionString] {
+    extensionString =
+        [NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)];
     return absl::OkStatus();
   });

-  NSArray* extensions = [extensionString componentsSeparatedByCharactersInSet:
-                            [NSCharacterSet whitespaceCharacterSet]];
-  for (NSString* oneExtension in extensions)
-    NSLog(@"%@", oneExtension);
+  NSArray* extensions = [extensionString
+      componentsSeparatedByCharactersInSet:[NSCharacterSet
+                                               whitespaceCharacterSet]];
+  for (NSString* oneExtension in extensions) NSLog(@"%@", oneExtension);
 }

 @end
@@ -20,8 +20,7 @@
   mediapipe::TimestampDiff _timestampOffset;
 }

-- (instancetype)init
-{
+- (instancetype)init {
   self = [super init];
   if (self) {
     [self reset];
@@ -36,11 +35,14 @@
 }

 - (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime {
-  Float64 sampleSeconds = CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0;
-  const int64 sampleUsec = sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond;
+  Float64 sampleSeconds =
+      CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0;
+  const int64 sampleUsec =
+      sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond;
   _mediapipeTimestamp = mediapipe::Timestamp(sampleUsec) + _timestampOffset;
   if (_mediapipeTimestamp <= _lastTimestamp) {
-    _timestampOffset = _timestampOffset + _lastTimestamp + 1 - _mediapipeTimestamp;
+    _timestampOffset =
+        _timestampOffset + _lastTimestamp + 1 - _mediapipeTimestamp;
     _mediapipeTimestamp = _lastTimestamp + 1;
   }
   _lastTimestamp = _mediapipeTimestamp;
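Note: the converter above keeps MediaPipe timestamps strictly increasing by growing an offset whenever a converted media time fails to advance. The same invariant in a standalone sketch, using plain int64 microseconds instead of mediapipe::Timestamp:

```cpp
#include <cstdint>

// If a converted timestamp does not advance past the last one, grow the
// offset so this and all later timestamps land strictly after it.
class MonotonicTimestampMapper {
 public:
  int64_t Map(int64_t sample_usec) {
    int64_t ts = sample_usec + offset_;
    if (ts <= last_) {
      offset_ += last_ + 1 - ts;
      ts = last_ + 1;
    }
    last_ = ts;
    return ts;
  }

 private:
  int64_t offset_ = 0;
  int64_t last_ = -1;
};
```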
mediapipe/objc/NSError+util_status.cc (new file): 72 lines
@@ -0,0 +1,72 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/objc/NSError+util_status.h"

@implementation GUSUtilStatusWrapper

+ (instancetype)wrapStatus:(const absl::Status &)status {
  return [[self alloc] initWithStatus:status];
}

- (instancetype)initWithStatus:(const absl::Status &)status {
  self = [super init];
  if (self) {
    _status = status;
  }
  return self;
}

- (NSString *)description {
  return [NSString stringWithFormat:@"<%@: %p; status = %s>", [self class],
                                    self, _status.message().data()];
}

@end

@implementation NSError (GUSGoogleUtilStatus)

NSString *const kGUSGoogleUtilStatusErrorDomain =
    @"GoogleUtilStatusErrorDomain";
NSString *const kGUSGoogleUtilStatusErrorKey = @"GUSGoogleUtilStatusErrorKey";

+ (NSError *)gus_errorWithStatus:(const absl::Status &)status {
  NSDictionary *userInfo = @{
    NSLocalizedDescriptionKey : @(status.message().data()),
    kGUSGoogleUtilStatusErrorKey : [GUSUtilStatusWrapper wrapStatus:status],
  };
  NSError *error =
      [NSError errorWithDomain:kGUSGoogleUtilStatusErrorDomain
                          code:static_cast<NSInteger>(status.code())
                      userInfo:userInfo];
  return error;
}

- (absl::Status)gus_status {
  NSString *domain = self.domain;
  if ([domain isEqual:kGUSGoogleUtilStatusErrorDomain]) {
    GUSUtilStatusWrapper *wrapper = self.userInfo[kGUSGoogleUtilStatusErrorKey];
    if (wrapper) return wrapper.status;
#if 0
    // Unfortunately, util/task/posixerrorspace.h is not in portable status yet.
    // TODO: fix that.
  } else if ([domain isEqual:NSPOSIXErrorDomain]) {
    return ::util::PosixErrorToStatus(self.code, self.localizedDescription.UTF8String);
#endif
  }
  return absl::Status(absl::StatusCode::kUnknown,
                      self.localizedDescription.UTF8String);
}

@end
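The two categories above convert in both directions between `absl::Status` and `NSError`. A minimal Objective-C++ sketch of the round trip (the function name is ours; `gus_errorWithStatus:` and `gus_status` are the APIs added in this file):

```objc
#include <cassert>

#include "absl/status/status.h"
#import "mediapipe/objc/NSError+util_status.h"

// Wrap an absl::Status in an NSError and recover it. The code and message
// survive because the original status rides along in userInfo via
// GUSUtilStatusWrapper.
static void StatusRoundTripExample() {
  absl::Status status = absl::InvalidArgumentError("bad input");
  NSError *error = [NSError gus_errorWithStatus:status];
  NSLog(@"domain=%@ code=%ld", error.domain, (long)error.code);

  absl::Status recovered = [error gus_status];
  assert(recovered.code() == absl::StatusCode::kInvalidArgument);
}
```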
@@ -207,8 +207,12 @@ class ImageTest(absltest.TestCase):
     loaded_image = Image.create_from_file(image_path)
     self.assertEqual(loaded_image.width, 720)
     self.assertEqual(loaded_image.height, 382)
-    self.assertEqual(loaded_image.channels, 3)
-    self.assertEqual(loaded_image.image_format, ImageFormat.SRGB)
+    # On Mac w/ GPU support, images use 4 channels (SRGBA). Otherwise, all
+    # images use 3 channels (SRGB).
+    self.assertIn(loaded_image.channels, [3, 4])
+    self.assertIn(
+        loaded_image.image_format, [ImageFormat.SRGB, ImageFormat.SRGBA]
+    )
 
 if __name__ == '__main__':
   absltest.main()
@@ -51,10 +51,10 @@ void ImageSubmodule(pybind11::module* module) {
 
   ```python
   import cv2
-  cv_mat = cv2.imread(input_file)[:, :, ::-1]
-  rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat)
+  cv_mat = cv2.imread(input_file)
+  rgb_frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=cv_mat)
   gray_frame = mp.Image(
-      image_format=ImageFormat.GRAY,
+      image_format=mp.ImageFormat.GRAY8,
       data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
 
   from PIL import Image
@@ -244,12 +244,26 @@ void ImageSubmodule(pybind11::module* module) {
   image.def_static(
       "create_from_file",
       [](const std::string& file_name) {
+        unsigned char* image_data = nullptr;
         int width;
         int height;
         int channels;
-        auto* image_data =
-            stbi_load(file_name.c_str(), &width, &height, &channels,
-                      /*desired_channels=*/0);
+#if TARGET_OS_OSX && !MEDIAPIPE_DISABLE_GPU
+        // Our ObjC layer does not support 3-channel images, so we read the
+        // number of channels first and request RGBA if needed.
+        if (stbi_info(file_name.c_str(), &width, &height, &channels)) {
+          if (channels == 3) {
+            channels = 4;
+          }
+          int unused;
+          image_data =
+              stbi_load(file_name.c_str(), &width, &height, &unused, channels);
+        }
+#else
+        image_data = stbi_load(file_name.c_str(), &width, &height, &channels,
+                               /*desired_channels=*/0);
+#endif  // TARGET_OS_OSX && !MEDIAPIPE_DISABLE_GPU
         if (image_data == nullptr) {
           throw RaisePyError(PyExc_RuntimeError,
                              absl::StrFormat("Image decoding failed (%s): %s",

@@ -263,11 +277,13 @@ void ImageSubmodule(pybind11::module* module) {
             ImageFormat::GRAY8, width, height, width, image_data,
             stbi_image_free);
         break;
+#if !TARGET_OS_OSX || MEDIAPIPE_DISABLE_GPU
       case 3:
         image_frame = std::make_shared<ImageFrame>(
             ImageFormat::SRGB, width, height, 3 * width, image_data,
             stbi_image_free);
         break;
+#endif  // !TARGET_OS_OSX || MEDIAPIPE_DISABLE_GPU
       case 4:
         image_frame = std::make_shared<ImageFrame>(
             ImageFormat::SRGBA, width, height, 4 * width, image_data,
@@ -81,8 +81,10 @@ void ImageFrameSubmodule(pybind11::module* module) {
   become immutable after creation.
 
   Creation examples:
 
   ```python
   import cv2
-  cv_mat = cv2.imread(input_file)[:, :, ::-1]
+  cv_mat = cv2.imread(input_file)
   rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat)
   gray_frame = mp.ImageFrame(
       image_format=ImageFormat.GRAY,

@@ -92,6 +94,7 @@ void ImageFrameSubmodule(pybind11::module* module) {
   pil_img = Image.new('RGB', (60, 30), color = 'red')
   image_frame = mp.ImageFrame(
       image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
   ```
 
   The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling
   `ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the
@@ -30,13 +30,12 @@ cc_library(
        "//mediapipe/tasks/c/components/processors:classifier_options_converter",
        "//mediapipe/tasks/c/core:base_options",
        "//mediapipe/tasks/c/core:base_options_converter",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/image_classifier",
        "//mediapipe/tasks/cc/vision/utils:image_utils",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
    ],
    alwayslink = 1,
)
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"
 
+#include <cstdint>
+#include <cstdlib>
 #include <memory>
 #include <utility>

@@ -26,6 +28,7 @@ limitations under the License.
 #include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
 #include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
 #include "mediapipe/tasks/c/core/base_options_converter.h"
+#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 #include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
 #include "mediapipe/tasks/cc/vision/utils/image_utils.h"

@@ -41,7 +44,10 @@ using ::mediapipe::tasks::c::components::processors::
     CppConvertToClassifierOptions;
 using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
 using ::mediapipe::tasks::vision::CreateImageFromBuffer;
+using ::mediapipe::tasks::vision::core::RunningMode;
 using ::mediapipe::tasks::vision::image_classifier::ImageClassifier;
+typedef ::mediapipe::tasks::vision::image_classifier::ImageClassifierResult
+    CppImageClassifierResult;
 
 int CppProcessError(absl::Status status, char** error_msg) {
   if (error_msg) {
@@ -60,6 +66,53 @@ ImageClassifier* CppImageClassifierCreate(const ImageClassifierOptions& options,
  CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
  CppConvertToClassifierOptions(options.classifier_options,
                                &cpp_options->classifier_options);
  cpp_options->running_mode = static_cast<RunningMode>(options.running_mode);

  // Enable callback for processing live stream data when the running mode is
  // set to RunningMode::LIVE_STREAM.
  if (cpp_options->running_mode == RunningMode::LIVE_STREAM) {
    if (options.result_callback == nullptr) {
      const absl::Status status = absl::InvalidArgumentError(
          "Provided null pointer to callback function.");
      ABSL_LOG(ERROR) << "Failed to create ImageClassifier: " << status;
      CppProcessError(status, error_msg);
      return nullptr;
    }

    ImageClassifierOptions::result_callback_fn result_callback =
        options.result_callback;
    cpp_options->result_callback =
        [result_callback](absl::StatusOr<CppImageClassifierResult> cpp_result,
                          const Image& image, int64_t timestamp) {
          char* error_msg = nullptr;

          if (!cpp_result.ok()) {
            ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status();
            CppProcessError(cpp_result.status(), &error_msg);
            result_callback(nullptr, MpImage(), timestamp, error_msg);
            free(error_msg);
            return;
          }

          // Result is valid for the lifetime of the callback function.
          ImageClassifierResult result;
          CppConvertToClassificationResult(*cpp_result, &result);

          const auto& image_frame = image.GetImageFrameSharedPtr();
          const MpImage mp_image = {
              .type = MpImage::IMAGE_FRAME,
              .image_frame = {
                  .format = static_cast<::ImageFormat>(image_frame->Format()),
                  .image_buffer = image_frame->PixelData(),
                  .width = image_frame->Width(),
                  .height = image_frame->Height()}};

          result_callback(&result, mp_image, timestamp,
                          /* error_msg= */ nullptr);

          CppCloseClassificationResult(&result);
        };
  }

  auto classifier = ImageClassifier::Create(std::move(cpp_options));
  if (!classifier.ok()) {
@@ -75,8 +128,8 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image,
                                ImageClassifierResult* result,
                                char** error_msg) {
   if (image->type == MpImage::GPU_BUFFER) {
-    absl::Status status =
-        absl::InvalidArgumentError("gpu buffer not supported yet");
+    const absl::Status status =
+        absl::InvalidArgumentError("GPU Buffer not supported yet.");
 
     ABSL_LOG(ERROR) << "Classification failed: " << status.message();
     return CppProcessError(status, error_msg);
@@ -102,6 +155,68 @@ int CppImageClassifierClassify(void* classifier, const MpImage* image,
  return 0;
}

int CppImageClassifierClassifyForVideo(void* classifier, const MpImage* image,
                                       int64_t timestamp_ms,
                                       ImageClassifierResult* result,
                                       char** error_msg) {
  if (image->type == MpImage::GPU_BUFFER) {
    absl::Status status =
        absl::InvalidArgumentError("GPU Buffer not supported yet");

    ABSL_LOG(ERROR) << "Classification failed: " << status.message();
    return CppProcessError(status, error_msg);
  }

  const auto img = CreateImageFromBuffer(
      static_cast<ImageFormat::Format>(image->image_frame.format),
      image->image_frame.image_buffer, image->image_frame.width,
      image->image_frame.height);

  if (!img.ok()) {
    ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
    return CppProcessError(img.status(), error_msg);
  }

  auto cpp_classifier = static_cast<ImageClassifier*>(classifier);
  auto cpp_result = cpp_classifier->ClassifyForVideo(*img, timestamp_ms);
  if (!cpp_result.ok()) {
    ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status();
    return CppProcessError(cpp_result.status(), error_msg);
  }
  CppConvertToClassificationResult(*cpp_result, result);
  return 0;
}

int CppImageClassifierClassifyAsync(void* classifier, const MpImage* image,
                                    int64_t timestamp_ms, char** error_msg) {
  if (image->type == MpImage::GPU_BUFFER) {
    absl::Status status =
        absl::InvalidArgumentError("GPU Buffer not supported yet");

    ABSL_LOG(ERROR) << "Classification failed: " << status.message();
    return CppProcessError(status, error_msg);
  }

  const auto img = CreateImageFromBuffer(
      static_cast<ImageFormat::Format>(image->image_frame.format),
      image->image_frame.image_buffer, image->image_frame.width,
      image->image_frame.height);

  if (!img.ok()) {
    ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
    return CppProcessError(img.status(), error_msg);
  }

  auto cpp_classifier = static_cast<ImageClassifier*>(classifier);
  auto cpp_result = cpp_classifier->ClassifyAsync(*img, timestamp_ms);
  if (!cpp_result.ok()) {
    ABSL_LOG(ERROR) << "Data preparation for the image classification failed: "
                    << cpp_result;
    return CppProcessError(cpp_result, error_msg);
  }
  return 0;
}

void CppImageClassifierCloseResult(ImageClassifierResult* result) {
  CppCloseClassificationResult(result);
}
@@ -134,6 +249,22 @@ int image_classifier_classify_image(void* classifier, const MpImage* image,
      CppImageClassifierClassify(classifier, image, result, error_msg);
}

int image_classifier_classify_for_video(void* classifier, const MpImage* image,
                                        int64_t timestamp_ms,
                                        ImageClassifierResult* result,
                                        char** error_msg) {
  return mediapipe::tasks::c::vision::image_classifier::
      CppImageClassifierClassifyForVideo(classifier, image, timestamp_ms,
                                         result, error_msg);
}

int image_classifier_classify_async(void* classifier, const MpImage* image,
                                    int64_t timestamp_ms, char** error_msg) {
  return mediapipe::tasks::c::vision::image_classifier::
      CppImageClassifierClassifyAsync(classifier, image, timestamp_ms,
                                      error_msg);
}

void image_classifier_close_result(ImageClassifierResult* result) {
  mediapipe::tasks::c::vision::image_classifier::CppImageClassifierCloseResult(
      result);
@@ -92,9 +92,16 @@ struct ImageClassifierOptions {
 
   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
-  // to RunningMode::LIVE_STREAM.
-  typedef void (*result_callback_fn)(ImageClassifierResult*, const MpImage*,
-                                     int64_t);
+  // to RunningMode::LIVE_STREAM. Arguments of the callback function include:
+  // the pointer to classification result, the image that result was obtained
+  // on, the timestamp relevant to classification results and pointer to error
+  // message in case of any failure. The validity of the passed arguments is
+  // true for the lifetime of the callback function.
+  //
+  // A caller is responsible for closing image classifier result.
+  typedef void (*result_callback_fn)(ImageClassifierResult* result,
+                                     const MpImage image, int64_t timestamp_ms,
+                                     char* error_msg);
   result_callback_fn result_callback;
 };
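The documented contract maps directly onto a plain function. A minimal Objective-C++ sketch of a conforming callback (the function name and logging are ours; the fields read here, `classifications_count` and `image_frame.width`/`.height`, appear in the tests later in this diff):

```objc
#include <cinttypes>
#include <cstdio>

#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"

// Per the header comment: on failure `result` is null and `error_msg` is set;
// all arguments are only valid while the callback runs, and the framework
// closes the result after the callback returns.
static void OnClassificationResult(ImageClassifierResult* result,
                                   const MpImage image, int64_t timestamp_ms,
                                   char* error_msg) {
  if (result == nullptr) {
    std::fprintf(stderr, "classification failed at %" PRId64 ": %s\n",
                 timestamp_ms, error_msg ? error_msg : "unknown error");
    return;
  }
  std::printf("%u classification(s) at %" PRId64 " on a %dx%d frame\n",
              result->classifications_count, timestamp_ms,
              image.image_frame.width, image.image_frame.height);
}
```

An options struct would then set `result_callback = OnClassificationResult` together with the live stream running mode.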
@@ -110,13 +117,22 @@ MP_EXPORT void* image_classifier_create(struct ImageClassifierOptions* options,
 // If an error occurs, returns an error code and sets the error parameter to
 // an error message (if `error_msg` is not nullptr). You must free the memory
 // allocated for the error message.
-//
-// TODO: Add API for video and live stream processing.
 MP_EXPORT int image_classifier_classify_image(void* classifier,
                                               const MpImage* image,
                                               ImageClassifierResult* result,
                                               char** error_msg = nullptr);
 
+MP_EXPORT int image_classifier_classify_for_video(void* classifier,
+                                                  const MpImage* image,
+                                                  int64_t timestamp_ms,
+                                                  ImageClassifierResult* result,
+                                                  char** error_msg = nullptr);
+
+MP_EXPORT int image_classifier_classify_async(void* classifier,
+                                              const MpImage* image,
+                                              int64_t timestamp_ms,
+                                              char** error_msg = nullptr);
+
 // Frees the memory allocated inside a ImageClassifierResult result.
 // Does not free the result pointer itself.
 MP_EXPORT void image_classifier_close_result(ImageClassifierResult* result);
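For orientation, a hedged usage sketch of the new video-mode entry point (classifier creation and frame construction are omitted; `frame` stands for an `MpImage` built as in the tests below):

```objc
#include <cstdio>
#include <cstdlib>

#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"

static void ClassifyOneVideoFrame(void* classifier, const MpImage* frame,
                                  int64_t timestamp_ms) {
  ImageClassifierResult result;
  char* error_msg = nullptr;
  if (image_classifier_classify_for_video(classifier, frame, timestamp_ms,
                                          &result, &error_msg) != 0) {
    std::fprintf(stderr, "error: %s\n", error_msg);
    std::free(error_msg);  // per the header, the caller frees the message
    return;
  }
  // ... read result.classifications ...
  image_classifier_close_result(&result);
}
```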
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"
 
+#include <cstdint>
 #include <cstdlib>
 #include <string>

@@ -36,12 +37,13 @@ using testing::HasSubstr;
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 constexpr char kModelName[] = "mobilenet_v2_1.0_224.tflite";
 constexpr float kPrecision = 1e-4;
+constexpr int kIterations = 100;
 
 std::string GetFullPath(absl::string_view file_name) {
   return JoinPath("./", kTestDataDirectory, file_name);
 }
 
-TEST(ImageClassifierTest, SmokeTest) {
+TEST(ImageClassifierTest, ImageModeTest) {
   const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
   ASSERT_TRUE(image.ok());
|
|||
void* classifier = image_classifier_create(&options);
|
||||
EXPECT_NE(classifier, nullptr);
|
||||
|
||||
const auto& image_frame = image->GetImageFrameSharedPtr();
|
||||
const MpImage mp_image = {
|
||||
.type = MpImage::IMAGE_FRAME,
|
||||
.image_frame = {
|
||||
.format = static_cast<ImageFormat>(
|
||||
image->GetImageFrameSharedPtr()->Format()),
|
||||
.image_buffer = image->GetImageFrameSharedPtr()->PixelData(),
|
||||
.width = image->GetImageFrameSharedPtr()->Width(),
|
||||
.height = image->GetImageFrameSharedPtr()->Height()}};
|
||||
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
|
||||
.image_buffer = image_frame->PixelData(),
|
||||
.width = image_frame->Width(),
|
||||
.height = image_frame->Height()}};
|
||||
|
||||
ImageClassifierResult result;
|
||||
image_classifier_classify_image(classifier, &mp_image, &result);
|
||||
|
@@ -84,6 +85,120 @@ TEST(ImageClassifierTest, SmokeTest) {
  image_classifier_close(classifier);
}

TEST(ImageClassifierTest, VideoModeTest) {
  const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
  ASSERT_TRUE(image.ok());

  const std::string model_path = GetFullPath(kModelName);
  ImageClassifierOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::VIDEO,
      /* classifier_options= */
      {/* display_names_locale= */ nullptr,
       /* max_results= */ 3,
       /* score_threshold= */ 0.0,
       /* category_allowlist= */ nullptr,
       /* category_allowlist_count= */ 0,
       /* category_denylist= */ nullptr,
       /* category_denylist_count= */ 0},
      /* result_callback= */ nullptr,
  };

  void* classifier = image_classifier_create(&options);
  EXPECT_NE(classifier, nullptr);

  const auto& image_frame = image->GetImageFrameSharedPtr();
  const MpImage mp_image = {
      .type = MpImage::IMAGE_FRAME,
      .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
                      .image_buffer = image_frame->PixelData(),
                      .width = image_frame->Width(),
                      .height = image_frame->Height()}};

  for (int i = 0; i < kIterations; ++i) {
    ImageClassifierResult result;
    image_classifier_classify_for_video(classifier, &mp_image, i, &result);
    EXPECT_EQ(result.classifications_count, 1);
    EXPECT_EQ(result.classifications[0].categories_count, 3);
    EXPECT_EQ(
        std::string{result.classifications[0].categories[0].category_name},
        "cheeseburger");
    EXPECT_NEAR(result.classifications[0].categories[0].score, 0.7939f,
                kPrecision);
    image_classifier_close_result(&result);
  }
  image_classifier_close(classifier);
}

// A structure to support LiveStreamModeTest below. This structure holds a
// static method `Fn` for a callback function of C API. A `static` qualifier
// allows to take an address of the method to follow API style. Another static
// struct member is `last_timestamp` that is used to verify that current
// timestamp is greater than the previous one.
struct LiveStreamModeCallback {
  static int64_t last_timestamp;
  static void Fn(ImageClassifierResult* classifier_result, const MpImage image,
                 int64_t timestamp, char* error_msg) {
    ASSERT_NE(classifier_result, nullptr);
    ASSERT_EQ(error_msg, nullptr);
    EXPECT_EQ(
        std::string{
            classifier_result->classifications[0].categories[0].category_name},
        "cheeseburger");
    EXPECT_NEAR(classifier_result->classifications[0].categories[0].score,
                0.7939f, kPrecision);
    EXPECT_GT(image.image_frame.width, 0);
    EXPECT_GT(image.image_frame.height, 0);
    EXPECT_GT(timestamp, last_timestamp);
    last_timestamp++;
  }
};
int64_t LiveStreamModeCallback::last_timestamp = -1;

TEST(ImageClassifierTest, LiveStreamModeTest) {
  const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
  ASSERT_TRUE(image.ok());

  const std::string model_path = GetFullPath(kModelName);

  ImageClassifierOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::LIVE_STREAM,
      /* classifier_options= */
      {/* display_names_locale= */ nullptr,
       /* max_results= */ 3,
       /* score_threshold= */ 0.0,
       /* category_allowlist= */ nullptr,
       /* category_allowlist_count= */ 0,
       /* category_denylist= */ nullptr,
       /* category_denylist_count= */ 0},
      /* result_callback= */ LiveStreamModeCallback::Fn,
  };

  void* classifier = image_classifier_create(&options);
  EXPECT_NE(classifier, nullptr);

  const auto& image_frame = image->GetImageFrameSharedPtr();
  const MpImage mp_image = {
      .type = MpImage::IMAGE_FRAME,
      .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
                      .image_buffer = image_frame->PixelData(),
                      .width = image_frame->Width(),
                      .height = image_frame->Height()}};

  for (int i = 0; i < kIterations; ++i) {
    EXPECT_GE(image_classifier_classify_async(classifier, &mp_image, i), 0);
  }
  image_classifier_close(classifier);

  // Due to the flow limiter, the total of outputs might be smaller than the
  // number of iterations.
  EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
  EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
}

TEST(ImageClassifierTest, InvalidArgumentHandling) {
  // It is an error to set neither the asset buffer nor the path.
  ImageClassifierOptions options = {
|
|||
ImageClassifierResult result;
|
||||
char* error_msg;
|
||||
image_classifier_classify_image(classifier, &mp_image, &result, &error_msg);
|
||||
EXPECT_THAT(error_msg, HasSubstr("gpu buffer not supported yet"));
|
||||
EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet"));
|
||||
free(error_msg);
|
||||
image_classifier_close(classifier);
|
||||
}
|
||||
|
|
|
@@ -98,3 +98,9 @@ mediapipe_proto_library(
     name = "transformer_params_proto",
     srcs = ["transformer_params.proto"],
 )
+
+mediapipe_proto_library(
+    name = "llm_params_proto",
+    srcs = ["llm_params.proto"],
+    deps = [":transformer_params_proto"],
+)
@@ -0,0 +1,41 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package mediapipe.tasks.components.processors.proto;

import "mediapipe/tasks/cc/components/processors/proto/transformer_params.proto";

option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "LLMParametersProto";

// Parameters for Large Language Models (LLM).
message LLMParameters {
  TransformerParameters transformer_parameters = 1;

  // Size of vocabulary.
  int32 vocab_size = 2;

  // Whether or not to disable KV cache, which is also referred to as state
  // elsewhere.
  bool disable_kv_cache = 3;

  // Id of the start token.
  int32 start_token_id = 4;

  // Token to determine the end of output stream.
  string stop_token = 5;
}
@@ -44,6 +44,21 @@ message TransformerParameters {
   // Number of stacked transformers, `N` in the paper.
   int32 num_stacks = 7;
 
-  // Whether to use Multi-Query-Attention (MQA).
-  bool use_mqa = 8;
+  // Deprecated: bool use_mqa. Use num_kv_heads below.
+  reserved 8;
+
+  // Number of kv heads. 0 means Multi-Head-Attention (MHA), key and value have
+  // same number of heads as query; 1 means Multi-Query-Attention (MQA), key
+  // and value have one head; otherwise, this specifies the number of heads for
+  // key and value, and Grouped-Query-Attention (GQA) will be used. See
+  // https://arxiv.org/pdf/2305.13245.pdf for details.
+  int32 num_kv_heads = 9;
+
+  // Different types of attention mask type.
+  enum AttentionMaskType {
+    UNSPECIFIED = 0;
+    CAUSAL = 1;
+    PREFIX = 2;
+  }
+  AttentionMaskType attention_mask_type = 10;
 }
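To make the MHA/MQA/GQA mapping concrete, a hedged C++ sketch of building these protos (the generated `*.pb.h` headers and setter names follow standard protobuf codegen for the fields above; the concrete numbers are illustrative only):

```objc
#include "mediapipe/tasks/cc/components/processors/proto/llm_params.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/transformer_params.pb.h"

void ConfigureLlmExample() {
  using ::mediapipe::tasks::components::processors::proto::LLMParameters;
  using ::mediapipe::tasks::components::processors::proto::TransformerParameters;

  LLMParameters params;
  params.set_vocab_size(32000);    // illustrative vocabulary size
  params.set_start_token_id(1);
  params.set_stop_token("</s>");

  TransformerParameters* transformer = params.mutable_transformer_parameters();
  transformer->set_num_stacks(24);
  // Per the comment above: 0 selects MHA, 1 selects MQA, and any other value
  // selects GQA with that many key/value heads.
  transformer->set_num_kv_heads(4);
  transformer->set_attention_mask_type(TransformerParameters::CAUSAL);
}
```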
@@ -264,6 +264,7 @@ cc_library_with_tflite(
         "//mediapipe/framework:executor",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/tool:name_util",
+        "//mediapipe/gpu:gpu_shared_data_internal",
         "//mediapipe/tasks/cc:common",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -39,6 +39,10 @@ limitations under the License.
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/core/model_resources_cache.h"
 
+#if !MEDIAPIPE_DISABLE_GPU
+#include "mediapipe/gpu/gpu_shared_data_internal.h"
+#endif  // !MEDIAPIPE_DISABLE_GPU
+
 namespace mediapipe {
 namespace tasks {
 namespace core {
@@ -88,16 +92,34 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
}  // namespace

/* static */
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
    CalculatorGraphConfig config,
    std::unique_ptr<tflite::OpResolver> op_resolver,
    PacketsCallback packets_callback,
    std::shared_ptr<Executor> default_executor,
    std::optional<PacketMap> input_side_packets,
    std::shared_ptr<::mediapipe::GpuResources> resources) {
#else
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
    CalculatorGraphConfig config,
    std::unique_ptr<tflite::OpResolver> op_resolver,
    PacketsCallback packets_callback,
    std::shared_ptr<Executor> default_executor,
    std::optional<PacketMap> input_side_packets) {
#endif  // !MEDIAPIPE_DISABLE_GPU
  auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
  MP_RETURN_IF_ERROR(task_runner->Initialize(
      std::move(config), std::move(op_resolver), std::move(default_executor),
      std::move(input_side_packets)));

#if !MEDIAPIPE_DISABLE_GPU
  if (resources) {
    MP_RETURN_IF_ERROR(
        task_runner->graph_.SetGpuResources(std::move(resources)));
  }
#endif  // !MEDIAPIPE_DISABLE_GPU

  MP_RETURN_IF_ERROR(task_runner->Start());
  return task_runner;
}
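The new `resources` parameter lets several runners share one GPU context. A hedged Objective-C++ sketch under stated assumptions (`config` is a placeholder `CalculatorGraphConfig`; `GpuResources::Create()` is assumed from `gpu_shared_data_internal.h`, and `MEDIAPIPE_DISABLE_GPU` is assumed unset):

```objc
#include <optional>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#include "mediapipe/tasks/cc/core/task_runner.h"

absl::Status CreateRunnerWithSharedGpu(mediapipe::CalculatorGraphConfig config) {
  // One GpuResources instance can back multiple runners.
  auto gpu_resources = mediapipe::GpuResources::Create();
  if (!gpu_resources.ok()) return gpu_resources.status();

  auto runner = mediapipe::tasks::core::TaskRunner::Create(
      std::move(config), /*op_resolver=*/nullptr, /*packets_callback=*/nullptr,
      /*default_executor=*/nullptr, /*input_side_packets=*/std::nullopt,
      *gpu_resources);  // a second runner may reuse the same shared_ptr
  if (!runner.ok()) return runner.status();
  return absl::OkStatus();
}
```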
@@ -42,6 +42,11 @@ limitations under the License.
 #include "tensorflow/lite/core/api/op_resolver.h"
 
 namespace mediapipe {
+
+#if !MEDIAPIPE_DISABLE_GPU
+class GpuResources;
+#endif  // !MEDIAPIPE_DISABLE_GPU
+
 namespace tasks {
 namespace core {
@@ -72,12 +77,22 @@ class TaskRunner {
  // asynchronous method, Send(), to provide the input packets. If the packets
  // callback is absent, clients must use the synchronous method, Process(), to
  // provide the input packets and receive the output packets.
#if !MEDIAPIPE_DISABLE_GPU
  static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
      CalculatorGraphConfig config,
      std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
      PacketsCallback packets_callback = nullptr,
      std::shared_ptr<Executor> default_executor = nullptr,
      std::optional<PacketMap> input_side_packets = std::nullopt,
      std::shared_ptr<::mediapipe::GpuResources> resources = nullptr);
#else
  static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
      CalculatorGraphConfig config,
      std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
      PacketsCallback packets_callback = nullptr,
      std::shared_ptr<Executor> default_executor = nullptr,
      std::optional<PacketMap> input_side_packets = std::nullopt);
#endif  // !MEDIAPIPE_DISABLE_GPU

  // TaskRunner is neither copyable nor movable.
  TaskRunner(const TaskRunner&) = delete;
@@ -57,6 +57,7 @@ CALCULATORS_AND_GRAPHS = [
     "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
     "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
     "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
+    "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
     "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
 ]

@@ -83,6 +84,7 @@ strip_api_include_path_prefix(
     "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h",
     "//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h",
     "//mediapipe/tasks/ios/vision/core:sources/MPPImage.h",
+    "//mediapipe/tasks/ios/vision/core:sources/MPPMask.h",
     "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetector.h",
     "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorOptions.h",
     "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorResult.h",

@@ -98,6 +100,9 @@ strip_api_include_path_prefix(
     "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifier.h",
     "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierOptions.h",
     "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
+    "//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenter.h",
+    "//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenterOptions.h",
+    "//mediapipe/tasks/ios/vision/image_segmenter:sources/MPPImageSegmenterResult.h",
     "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
     "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
     "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",

@@ -178,6 +183,7 @@ apple_static_xcframework(
         ":MPPTaskOptions.h",
         ":MPPTaskResult.h",
         ":MPPImage.h",
+        ":MPPMask.h",
         ":MPPRunningMode.h",
         ":MPPFaceDetector.h",
         ":MPPFaceDetectorOptions.h",

@@ -188,6 +194,9 @@ apple_static_xcframework(
         ":MPPImageClassifier.h",
         ":MPPImageClassifierOptions.h",
         ":MPPImageClassifierResult.h",
+        ":MPPImageSegmenter.h",
+        ":MPPImageSegmenterOptions.h",
+        ":MPPImageSegmenterResult.h",
         ":MPPHandLandmarker.h",
         ":MPPHandLandmarkerOptions.h",
         ":MPPHandLandmarkerResult.h",

@@ -204,6 +213,7 @@ apple_static_xcframework(
         "//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer",
         "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker",
         "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",
+        "//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenter",
         "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetector",
     ],
 )
@@ -14,6 +14,15 @@
 
 #import <Foundation/Foundation.h>
 
+/**
+ * The delegate to run MediaPipe. If the delegate is not set, the default
+ * delegate CPU is used.
+ */
+typedef NS_ENUM(NSUInteger, MPPDelegate) {
+  MPPDelegateCPU,
+  MPPDelegateGPU,
+} NS_SWIFT_NAME(Delegate);
+
 NS_ASSUME_NONNULL_BEGIN
 
 /**

@@ -26,6 +35,9 @@ NS_SWIFT_NAME(BaseOptions)
 /** The path to the model asset to open and mmap in memory. */
 @property(nonatomic, copy) NSString *modelAssetPath;
 
+/** Overrides the default backend to use for the provided model. */
+@property(nonatomic) MPPDelegate delegate;
+
 @end
 
 NS_ASSUME_NONNULL_END
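A hedged Objective-C sketch of opting a task into the GPU backend via the new property (the model path is a placeholder, and the header path is an assumption; any task options object exposing a `baseOptions` of this class works the same way):

```objc
#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"

static MPPBaseOptions *MakeGpuBaseOptions(void) {
  MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
  baseOptions.modelAssetPath = @"/path/to/model.tflite";  // placeholder path
  baseOptions.delegate = MPPDelegateGPU;  // the default is MPPDelegateCPU
  return baseOptions;
}
```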
@@ -20,6 +20,7 @@
   self = [super init];
   if (self) {
     self.modelAssetPath = [[NSString alloc] init];
+    self.delegate = MPPDelegateCPU;
   }
   return self;
 }

@@ -28,6 +29,7 @@
   MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
 
   baseOptions.modelAssetPath = self.modelAssetPath;
+  baseOptions.delegate = self.delegate;
 
   return baseOptions;
 }
@@ -21,6 +21,7 @@ objc_library(
     srcs = ["sources/MPPBaseOptions+Helpers.mm"],
     hdrs = ["sources/MPPBaseOptions+Helpers.h"],
     deps = [
+        "//mediapipe/calculators/tensor:inference_calculator_cc_proto",
         "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
         "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
@@ -12,12 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
 #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
 
 namespace {
 using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
+using InferenceCalculatorOptionsProto = ::mediapipe::InferenceCalculatorOptions;
 }
 
 @implementation MPPBaseOptions (Helpers)

@@ -33,6 +35,11 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
   if (self.modelAssetPath) {
     baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
   }
+
+  if (self.delegate == MPPDelegateGPU) {
+    baseOptionsProto->mutable_acceleration()->mutable_gpu()->MergeFrom(
+        InferenceCalculatorOptionsProto::Delegate::Gpu());
+  }
 }
 
 @end
@@ -31,3 +31,28 @@ objc_library(
         "//mediapipe/tasks/ios/core:MPPTaskResult",
     ],
 )
+
+objc_library(
+    name = "MPPLanguageDetector",
+    srcs = ["sources/MPPLanguageDetector.mm"],
+    hdrs = ["sources/MPPLanguageDetector.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+        "-x objective-c++",
+    ],
+    module_name = "MPPLanguageDetector",
+    deps = [
+        ":MPPLanguageDetectorOptions",
+        ":MPPLanguageDetectorResult",
+        "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
+        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
+        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        "//mediapipe/tasks/ios/core:MPPTaskInfo",
+        "//mediapipe/tasks/ios/core:MPPTaskOptions",
+        "//mediapipe/tasks/ios/core:MPPTextPacketCreator",
+        "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
+        "//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorOptionsHelpers",
+        "//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorResultHelpers",
+    ],
+)
@@ -0,0 +1,88 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>

#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.h"
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorResult.h"

NS_ASSUME_NONNULL_BEGIN

/**
 * @brief Predicts the language of an input text.
 *
 * This API expects a TFLite model with [TFLite Model
 * Metadata](https://www.tensorflow.org/lite/convert/metadata) that contains the mandatory
 * (described below) input tensor, output tensor, and the language codes in an AssociatedFile.
 *
 * Metadata is required for models with int32 input tensors because it contains the input
 * process unit for the model's Tokenizer. No metadata is required for models with string
 * input tensors.
 *
 * Input tensor
 *  - One input tensor (`kTfLiteString`) of shape `[1]` containing the input string.
 *
 * Output tensor
 *  - One output tensor (`kTfLiteFloat32`) of shape `[1 x N]` where `N` is the number of languages.
 */
NS_SWIFT_NAME(LanguageDetector)
@interface MPPLanguageDetector : NSObject

/**
 * Creates a new instance of `LanguageDetector` from an absolute path to a TensorFlow Lite
 * model file stored locally on the device and the default `LanguageDetectorOptions`.
 *
 * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
 * @param error An optional error parameter populated when there is an error in initializing the
 * language detector.
 *
 * @return A new instance of `LanguageDetector` with the given model path. `nil` if there is an
 * error in initializing the language detector.
 */
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;

/**
 * Creates a new instance of `LanguageDetector` from the given `LanguageDetectorOptions`.
 *
 * @param options The options of type `LanguageDetectorOptions` to use for configuring the
 * `LanguageDetector`.
 * @param error An optional error parameter populated when there is an error in initializing the
 * language detector.
 *
 * @return A new instance of `LanguageDetector` with the given options. `nil` if there is an
 * error in initializing the language detector.
 */
- (nullable instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options
                                   error:(NSError **)error NS_DESIGNATED_INITIALIZER;

/**
 * Predicts the language of the input text.
 *
 * @param text The `NSString` for which language is to be predicted.
 * @param error An optional error parameter populated when there is an error in performing
 * language prediction on the input text.
 *
 * @return A `LanguageDetectorResult` object that contains a list of language predictions.
 */
- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text
                                             error:(NSError **)error NS_SWIFT_NAME(detect(text:));

- (instancetype)init NS_UNAVAILABLE;

+ (instancetype)new NS_UNAVAILABLE;

@end

NS_ASSUME_NONNULL_END
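A hedged Objective-C usage sketch of the header above (the model path is a placeholder; the result property names are inferred from the result initializers later in this diff, not confirmed by it):

```objc
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h"

static void DetectLanguageExample(void) {
  NSError *error = nil;
  MPPLanguageDetector *detector = [[MPPLanguageDetector alloc]
      initWithModelPath:@"/path/to/language_detector.tflite"
                  error:&error];
  if (!detector) return;

  MPPLanguageDetectorResult *result = [detector detectText:@"il y a un mouton ici"
                                                     error:&error];
  // languagePredictions / languageCode / probability are assumed from the
  // initWithLanguagePredictions: and initWithLanguageCode:probability:
  // initializers used by the helpers below.
  for (MPPLanguagePrediction *prediction in result.languagePredictions) {
    NSLog(@"%@ (p=%.3f)", prediction.languageCode, prediction.probability);
  }
}
```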
@@ -0,0 +1,96 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h"

#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.h"
#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h"

namespace {
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
}  // namespace

static NSString *const kClassificationsStreamName = @"classifications_out";
static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
static NSString *const kTextInStreamName = @"text_in";
static NSString *const kTextTag = @"TEXT";
static NSString *const kTaskGraphName =
    @"mediapipe.tasks.text.language_detector.LanguageDetectorGraph";

@interface MPPLanguageDetector () {
  /** iOS Text Task Runner */
  MPPTextTaskRunner *_textTaskRunner;
}
@end

@implementation MPPLanguageDetector

- (instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options error:(NSError **)error {
  self = [super init];
  if (self) {
    MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
        initWithTaskGraphName:kTaskGraphName
                 inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ]
                outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
                                                            kClassificationsStreamName] ]
                  taskOptions:options
           enableFlowLimiting:NO
                        error:error];

    if (!taskInfo) {
      return nil;
    }

    _textTaskRunner =
        [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
                                                           error:error];

    if (!_textTaskRunner) {
      return nil;
    }
  }
  return self;
}

- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
  MPPLanguageDetectorOptions *options = [[MPPLanguageDetectorOptions alloc] init];

  options.baseOptions.modelAssetPath = modelPath;

  return [self initWithOptions:options error:error];
}

- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text error:(NSError **)error {
  Packet packet = [MPPTextPacketCreator createWithText:text];

  std::map<std::string, Packet> packetMap = {{kTextInStreamName.cppString, packet}};
  std::optional<PacketMap> outputPacketMap = [_textTaskRunner processPacketMap:packetMap
                                                                         error:error];

  if (!outputPacketMap.has_value()) {
    return nil;
  }

  return
      [MPPLanguageDetectorResult languageDetectorResultWithClassificationsPacket:
                                     outputPacketMap.value()[kClassificationsStreamName.cppString]];
}

@end
@@ -30,3 +30,15 @@ objc_library(
         "//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetectorOptions",
     ],
 )
+
+objc_library(
+    name = "MPPLanguageDetectorResultHelpers",
+    srcs = ["sources/MPPLanguageDetectorResult+Helpers.mm"],
+    hdrs = ["sources/MPPLanguageDetectorResult+Helpers.h"],
+    deps = [
+        "//mediapipe/framework:packet",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
+        "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
+        "//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetectorResult",
+    ],
+)
@@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorResult.h"

#include "mediapipe/framework/packet.h"

NS_ASSUME_NONNULL_BEGIN

@interface MPPLanguageDetectorResult (Helpers)

+ (MPPLanguageDetectorResult *)languageDetectorResultWithClassificationsPacket:
    (const mediapipe::Packet &)packet;

@end

NS_ASSUME_NONNULL_END
@@ -0,0 +1,61 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h"

#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

static const int kMicroSecondsPerMilliSecond = 1000;

namespace {
using ClassificationResultProto =
    ::mediapipe::tasks::components::containers::proto::ClassificationResult;
}  // namespace

@implementation MPPLanguageDetectorResult (Helpers)

+ (MPPLanguageDetectorResult *)languageDetectorResultWithClassificationsPacket:
    (const mediapipe::Packet &)packet {
  MPPClassificationResult *classificationResult = [MPPClassificationResult
      classificationResultWithProto:packet.Get<ClassificationResultProto>()];

  return [MPPLanguageDetectorResult
      languageDetectorResultWithClassificationResult:classificationResult
                             timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
                                                                 kMicroSecondsPerMilliSecond)];
}

+ (MPPLanguageDetectorResult *)
    languageDetectorResultWithClassificationResult:(MPPClassificationResult *)classificationResult
                           timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
  NSMutableArray<MPPLanguagePrediction *> *languagePredictions =
      [NSMutableArray arrayWithCapacity:classificationResult.classifications.count];

  if (classificationResult.classifications.count > 0) {
    for (MPPCategory *category in classificationResult.classifications[0].categories) {
      MPPLanguagePrediction *languagePrediction =
          [[MPPLanguagePrediction alloc] initWithLanguageCode:category.categoryName
                                                  probability:category.score];
      [languagePredictions addObject:languagePrediction];
    }
  }

  return [[MPPLanguageDetectorResult alloc] initWithLanguagePredictions:languagePredictions
                                                timestampInMilliseconds:timestampInMilliseconds];
}

@end
@@ -37,7 +37,7 @@ vImage_Buffer allocatedVImageBuffer(vImagePixelCount width, vImagePixelCount hei
 }
 
 static void FreeDataProviderReleaseCallback(void *buffer, const void *data, size_t size) {
-  delete (vImage_Buffer *)buffer;
+  delete[] (vImage_Buffer *)buffer;
 }
 
 }  // namespace
@@ -47,7 +47,7 @@ using ::mediapipe::Packet;
                              ->PixelData()
                        width:confidenceMask.width()
                       height:confidenceMask.height()
-                  shouldCopy:shouldCopyMaskPacketData ? YES : NO]];
+                  shouldCopy:shouldCopyMaskPacketData]];
   }
 }

@@ -57,7 +57,7 @@ using ::mediapipe::Packet;
       initWithUInt8Data:(UInt8 *)cppCategoryMask.GetImageFrameSharedPtr().get()->PixelData()
                   width:cppCategoryMask.width()
                  height:cppCategoryMask.height()
-             shouldCopy:shouldCopyMaskPacketData ? YES : NO];
+             shouldCopy:shouldCopyMaskPacketData];
 }
 
 if (qualityScoresPacket.ValidateAsType<std::vector<float>>().ok()) {
@@ -37,3 +37,23 @@ objc_library(
         "//mediapipe/tasks/ios/vision/core:MPPRunningMode",
     ],
 )
+
+objc_library(
+    name = "MPPPoseLandmarksConnections",
+    hdrs = ["sources/MPPPoseLandmarksConnections.h"],
+    module_name = "MPPPoseLandmarksConnections",
+    deps = ["//mediapipe/tasks/ios/components/containers:MPPConnection"],
+)
+
+objc_library(
+    name = "MPPPoseLandmarker",
+    hdrs = ["sources/MPPPoseLandmarker.h"],
+    module_name = "MPPPoseLandmarker",
+    deps = [
+        ":MPPPoseLandmarkerOptions",
+        ":MPPPoseLandmarkerResult",
+        ":MPPPoseLandmarksConnections",
+        "//mediapipe/tasks/ios/components/containers:MPPConnection",
+        "//mediapipe/tasks/ios/vision/core:MPPImage",
+    ],
+)
@ -0,0 +1,160 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
|
||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
|
||||
#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerOptions.h"
|
||||
#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* @brief Performs pose landmarks detection on images.
|
||||
*
|
||||
* This API expects a pre-trained pose landmarks model asset bundle.
|
||||
*/
|
||||
NS_SWIFT_NAME(PoseLandmarker)
|
||||
@interface MPPPoseLandmarker : NSObject
|
||||
|
||||
/** The array of connections between all the landmarks in the detected pose. */
|
||||
@property(class, nonatomic, readonly) NSArray<MPPConnection *> *poseLandmarks;
|
||||
|
||||
/**
|
||||
* Creates a new instance of `PoseLandmarker` from an absolute path to a model asset bundle stored
|
||||
* locally on the device and the default `PoseLandmarkerOptions`.
|
||||
*
|
||||
* @param modelPath An absolute path to a model asset bundle stored locally on the device.
|
||||
* @param error An optional error parameter populated when there is an error in initializing the
|
||||
* pose landmarker.
|
||||
*
|
||||
* @return A new instance of `PoseLandmarker` with the given model path. `nil` if there is an error
|
||||
* in initializing the pose landmarker.
|
||||
*/
|
||||
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
|
||||
|
||||
/**
|
||||
* Creates a new instance of `PoseLandmarker` from the given `PoseLandmarkerOptions`.
|
||||
*
|
||||
* @param options The options of type `PoseLandmarkerOptions` to use for configuring the
|
||||
* `PoseLandmarker`.
|
||||
* @param error An optional error parameter populated when there is an error in initializing the
|
||||
* pose landmarker.
|
||||
*
|
||||
* @return A new instance of `PoseLandmarker` with the given options. `nil` if there is an error in
|
||||
* initializing the pose landmarker.
|
||||
*/
|
||||
- (nullable instancetype)initWithOptions:(MPPPoseLandmarkerOptions *)options
|
||||
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
* Performs pose landmarks detection on the provided `MPImage` using the whole image as region of
|
||||
* interest. Rotation will be applied according to the `orientation` property of the provided
|
||||
* `MPImage`. Only use this method when the `PoseLandmarker` is created with running mode `.image`.
|
||||
*
|
||||
* This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a
|
||||
* source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use
|
||||
* `kCVPixelFormatType_32BGRA` as its pixel format.
|
||||
*
|
||||
*
|
||||
* If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha
|
||||
* channel.
|
||||
*
|
||||
* @param image The `MPImage` on which pose landmarks detection is to be performed.
|
||||
* @param error An optional error parameter populated when there is an error in performing pose
|
||||
* landmark detection on the input image.
|
||||
*
|
||||
* @return An `PoseLandmarkerResult` object that contains the pose landmarks detection
|
||||
* results.
|
||||
*/
|
||||
- (nullable MPPPoseLandmarkerResult *)detectImage:(MPPImage *)image
|
||||
error:(NSError **)error NS_SWIFT_NAME(detect(image:));

/**
 * Performs pose landmarks detection on the provided video frame of type `MPImage` using the whole
 * image as region of interest. Rotation will be applied according to the `orientation` property of
 * the provided `MPImage`. Only use this method when the `PoseLandmarker` is created with running
 * mode `.video`.
 *
 * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must
 * be monotonically increasing.
 *
 * This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a
 * source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use
 * `kCVPixelFormatType_32BGRA` as its pixel format.
 *
 * If your `MPImage` has a source type of `.image`, ensure that the color space is RGB with an
 * Alpha channel.
 *
 * @param image The `MPImage` on which pose landmarks detection is to be performed.
 * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
 * timestamps must be monotonically increasing.
 * @param error An optional error parameter populated when there is an error in performing pose
 * landmark detection on the input video frame.
 *
 * @return A `PoseLandmarkerResult` object that contains the pose landmarks detection results.
 */
- (nullable MPPPoseLandmarkerResult *)detectVideoFrame:(MPPImage *)image
                               timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                                                 error:(NSError **)error
    NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
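
A hedged video-mode sketch (not part of this diff): the `frames` collection and the fixed 33 ms
step are illustrative app-side assumptions.

NSInteger frameIndex = 0;
for (MPPImage *frame in frames) {  // `frames` is assumed app-side data.
  NSError *frameError = nil;
  MPPPoseLandmarkerResult *frameResult =
      [poseLandmarker detectVideoFrame:frame
               timestampInMilliseconds:frameIndex * 33  // Must increase monotonically.
                                 error:&frameError];
  frameIndex++;
}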

/**
 * Sends live stream image data of type `MPImage` to perform pose landmarks detection using the
 * whole image as region of interest. Rotation will be applied according to the `orientation`
 * property of the provided `MPImage`. Only use this method when the `PoseLandmarker` is created
 * with running mode `.liveStream`.
 *
 * The object which needs to be continuously notified of the available results of pose landmark
 * detection must conform to the `PoseLandmarkerLiveStreamDelegate` protocol and implement the
 * `poseLandmarker(_:didFinishDetectionWithResult:timestampInMilliseconds:error:)` delegate method.
 *
 * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
 * to the pose landmarker. The input timestamps must be monotonically increasing.
 *
 * This method supports performing pose landmarks detection on RGBA images. If your `MPImage` has a
 * source type of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use
 * `kCVPixelFormatType_32BGRA` as its pixel format.
 *
 * If the input `MPImage` has a source type of `.image`, ensure that the color space is RGB with an
 * Alpha channel.
 *
 * If this method is used for performing pose landmarks detection on live camera frames using
 * `AVFoundation`, ensure that you request `AVCaptureVideoDataOutput` to output frames in
 * `kCMPixelFormat_32BGRA` using its `videoSettings` property.
 *
 * @param image A live stream image data of type `MPImage` on which pose landmarks detection is to
 * be performed.
 * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
 * image is sent to the pose landmarker. The input timestamps must be monotonically increasing.
 * @param error An optional error parameter populated when there is an error in performing pose
 * landmark detection on the input live stream image data.
 *
 * @return `YES` if the image was sent to the task successfully, otherwise `NO`.
 */
- (BOOL)detectAsyncImage:(MPPImage *)image
    timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                      error:(NSError **)error
    NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
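
In live-stream mode, results are delivered through the delegate rather than returned. A hedged
sketch of a conforming class follows; the Objective-C protocol name
`MPPPoseLandmarkerLiveStreamDelegate` and the exact selector are inferred from the Swift names
documented above, not confirmed by this diff.

@interface PoseResultHandler : NSObject <MPPPoseLandmarkerLiveStreamDelegate>
@end

@implementation PoseResultHandler
// Selector inferred from the documented Swift name; verify against the released header.
- (void)poseLandmarker:(MPPPoseLandmarker *)poseLandmarker
    didFinishDetectionWithResult:(nullable MPPPoseLandmarkerResult *)result
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                           error:(nullable NSError *)error {
  if (result) {
    NSLog(@"Pose result at %ld ms", (long)timestampInMilliseconds);
  }
}
@end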

- (instancetype)init NS_UNAVAILABLE;

+ (instancetype)new NS_UNAVAILABLE;

@end

NS_ASSUME_NONNULL_END

@@ -46,7 +46,7 @@ NS_SWIFT_NAME(PoseLandmarkerResult)
 */
- (instancetype)initWithLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)landmarks
                   worldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)worldLandmarks
                segmentationMasks:(NSArray<MPPMask *> *)segmentationMasks
                segmentationMasks:(nullable NSArray<MPPMask *> *)segmentationMasks
          timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER;

- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_UNAVAILABLE;

@@ -0,0 +1,40 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"

NS_ASSUME_NONNULL_BEGIN

NSArray<MPPConnection *> *const MPPPoseLandmarksConnections = @[
  [[MPPConnection alloc] initWithStart:0 end:1], [[MPPConnection alloc] initWithStart:1 end:2],
  [[MPPConnection alloc] initWithStart:2 end:3], [[MPPConnection alloc] initWithStart:3 end:7],
  [[MPPConnection alloc] initWithStart:0 end:4], [[MPPConnection alloc] initWithStart:4 end:5],
  [[MPPConnection alloc] initWithStart:5 end:6], [[MPPConnection alloc] initWithStart:6 end:8],
  [[MPPConnection alloc] initWithStart:9 end:10], [[MPPConnection alloc] initWithStart:11 end:12],
  [[MPPConnection alloc] initWithStart:11 end:13], [[MPPConnection alloc] initWithStart:13 end:15],
  [[MPPConnection alloc] initWithStart:15 end:17], [[MPPConnection alloc] initWithStart:15 end:19],
  [[MPPConnection alloc] initWithStart:15 end:21], [[MPPConnection alloc] initWithStart:17 end:19],
  [[MPPConnection alloc] initWithStart:12 end:14], [[MPPConnection alloc] initWithStart:14 end:16],
  [[MPPConnection alloc] initWithStart:16 end:18], [[MPPConnection alloc] initWithStart:16 end:20],
  [[MPPConnection alloc] initWithStart:16 end:22], [[MPPConnection alloc] initWithStart:18 end:20],
  [[MPPConnection alloc] initWithStart:11 end:23], [[MPPConnection alloc] initWithStart:12 end:24],
  [[MPPConnection alloc] initWithStart:23 end:24], [[MPPConnection alloc] initWithStart:23 end:25],
  [[MPPConnection alloc] initWithStart:26 end:28], [[MPPConnection alloc] initWithStart:27 end:29],
  [[MPPConnection alloc] initWithStart:28 end:30], [[MPPConnection alloc] initWithStart:29 end:31],
  [[MPPConnection alloc] initWithStart:30 end:32], [[MPPConnection alloc] initWithStart:27 end:31],
  [[MPPConnection alloc] initWithStart:28 end:32]
];

NS_ASSUME_NONNULL_END
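
The array above encodes the pose skeleton as pairs of indices into a detection's landmark list. A
hedged consumption sketch (not part of this diff): `landmarks` is one pose's
`NSArray<MPPNormalizedLandmark *>` from a result, and `renderer` with its drawing call is
hypothetical app code.

for (MPPConnection *connection in MPPPoseLandmarksConnections) {
  MPPNormalizedLandmark *start = landmarks[connection.start];
  MPPNormalizedLandmark *end = landmarks[connection.end];
  [renderer drawLineFromX:start.x y:start.y toX:end.x y:end.y];  // Hypothetical app-side call.
}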

@@ -36,3 +36,21 @@ objc_library(
        "//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarkerOptions",
    ],
)

objc_library(
    name = "MPPPoseLandmarkerResultHelpers",
    srcs = ["sources/MPPPoseLandmarkerResult+Helpers.mm"],
    hdrs = ["sources/MPPPoseLandmarkerResult+Helpers.h"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    deps = [
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/tasks/ios/components/containers/utils:MPPLandmarkHelpers",
        "//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarkerResult",
    ],
)

@@ -0,0 +1,63 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/vision/pose_landmarker/sources/MPPPoseLandmarkerResult.h"

#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"

NS_ASSUME_NONNULL_BEGIN

@interface MPPPoseLandmarkerResult (Helpers)

/**
 * Creates an `MPPPoseLandmarkerResult` from landmarks, world landmarks and segmentation mask
 * packets.
 *
 * @param landmarksPacket A MediaPipe packet wrapping a `std::vector<NormalizedLandmarkListProto>`.
 * @param worldLandmarksPacket A MediaPipe packet wrapping a `std::vector<LandmarkListProto>`.
 * @param segmentationMasksPacket A MediaPipe packet wrapping a `std::vector<Image>`.
 *
 * @return An `MPPPoseLandmarkerResult` object that contains the pose landmark detection
 * results.
 */
+ (MPPPoseLandmarkerResult *)
    poseLandmarkerResultWithLandmarksPacket:(const mediapipe::Packet &)landmarksPacket
                       worldLandmarksPacket:(const mediapipe::Packet &)worldLandmarksPacket
                    segmentationMasksPacket:(const mediapipe::Packet *)segmentationMasksPacket;

/**
 * Creates an `MPPPoseLandmarkerResult` from landmarks, world landmarks and segmentation mask
 * images.
 *
 * @param landmarksProto A vector of protos of type `std::vector<NormalizedLandmarkListProto>`.
 * @param worldLandmarksProto A vector of protos of type `std::vector<LandmarkListProto>`.
 * @param segmentationMasks A vector of type `std::vector<Image>`.
 * @param timestampInMilliSeconds The timestamp of the Packet that contained the result.
 *
 * @return An `MPPPoseLandmarkerResult` object that contains the pose landmark detection
 * results.
 */
+ (MPPPoseLandmarkerResult *)
    poseLandmarkerResultWithLandmarksProto:
        (const std::vector<::mediapipe::NormalizedLandmarkList> &)landmarksProto
                       worldLandmarksProto:
                           (const std::vector<::mediapipe::LandmarkList> &)worldLandmarksProto
                         segmentationMasks:(const std::vector<mediapipe::Image> *)segmentationMasks
                   timestampInMilliSeconds:(NSInteger)timestampInMilliseconds;
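
These helpers are meant to be called from the task's Obj-C++ glue code. A hedged call-site sketch
(not part of this diff): the `outputPacketMap` variable and the stream names are illustrative
placeholders, not confirmed by this change.

// Hypothetical call site in a .mm file; stream names are illustrative only.
MPPPoseLandmarkerResult *result = [MPPPoseLandmarkerResult
    poseLandmarkerResultWithLandmarksPacket:outputPacketMap["pose_landmarks"]
                       worldLandmarksPacket:outputPacketMap["world_landmarks"]
                    segmentationMasksPacket:&outputPacketMap["segmentation_mask"]];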

@end

NS_ASSUME_NONNULL_END

@@ -0,0 +1,124 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/vision/pose_landmarker/utils/sources/MPPPoseLandmarkerResult+Helpers.h"

#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPLandmark+Helpers.h"

namespace {
using LandmarkListProto = ::mediapipe::LandmarkList;
using NormalizedLandmarkListProto = ::mediapipe::NormalizedLandmarkList;
using ::mediapipe::Image;
using ::mediapipe::Packet;

static const int kMicroSecondsPerMilliSecond = 1000;
}  // namespace

@implementation MPPPoseLandmarkerResult (Helpers)

+ (MPPPoseLandmarkerResult *)emptyPoseLandmarkerResultWithTimestampInMilliseconds:
    (NSInteger)timestampInMilliseconds {
  return [[MPPPoseLandmarkerResult alloc] initWithLandmarks:@[]
                                              worldLandmarks:@[]
                                           segmentationMasks:@[]
                                     timestampInMilliseconds:timestampInMilliseconds];
}

+ (MPPPoseLandmarkerResult *)
    poseLandmarkerResultWithLandmarksProto:
        (const std::vector<NormalizedLandmarkListProto> &)landmarksProto
                       worldLandmarksProto:
                           (const std::vector<LandmarkListProto> &)worldLandmarksProto
                         segmentationMasks:(const std::vector<Image> *)segmentationMasks
                   timestampInMilliSeconds:(NSInteger)timestampInMilliseconds {
  NSMutableArray<NSMutableArray<MPPNormalizedLandmark *> *> *multiplePoseLandmarks =
      [NSMutableArray arrayWithCapacity:(NSUInteger)landmarksProto.size()];

  for (const auto &landmarkListProto : landmarksProto) {
    NSMutableArray<MPPNormalizedLandmark *> *landmarks =
        [NSMutableArray arrayWithCapacity:(NSUInteger)landmarkListProto.landmark().size()];
    for (const auto &normalizedLandmarkProto : landmarkListProto.landmark()) {
      MPPNormalizedLandmark *normalizedLandmark =
          [MPPNormalizedLandmark normalizedLandmarkWithProto:normalizedLandmarkProto];
      [landmarks addObject:normalizedLandmark];
    }
    [multiplePoseLandmarks addObject:landmarks];
  }

  NSMutableArray<NSMutableArray<MPPLandmark *> *> *multiplePoseWorldLandmarks =
      [NSMutableArray arrayWithCapacity:(NSUInteger)worldLandmarksProto.size()];

  for (const auto &worldLandmarkListProto : worldLandmarksProto) {
    NSMutableArray<MPPLandmark *> *worldLandmarks =
        [NSMutableArray arrayWithCapacity:(NSUInteger)worldLandmarkListProto.landmark().size()];
    for (const auto &landmarkProto : worldLandmarkListProto.landmark()) {
      MPPLandmark *landmark = [MPPLandmark landmarkWithProto:landmarkProto];
      [worldLandmarks addObject:landmark];
    }
    [multiplePoseWorldLandmarks addObject:worldLandmarks];
  }

  NSMutableArray<MPPMask *> *confidenceMasks =
      [NSMutableArray arrayWithCapacity:(NSUInteger)segmentationMasks->size()];

  for (const auto &segmentationMask : *segmentationMasks) {
    [confidenceMasks addObject:[[MPPMask alloc] initWithFloat32Data:(float *)segmentationMask
                                                                        .GetImageFrameSharedPtr()
                                                                        .get()
                                                                        ->PixelData()
                                                              width:segmentationMask.width()
                                                             height:segmentationMask.height()
                                                         /** Always deep copy */
                                                         shouldCopy:YES]];
  }

  MPPPoseLandmarkerResult *poseLandmarkerResult =
      [[MPPPoseLandmarkerResult alloc] initWithLandmarks:multiplePoseLandmarks
                                          worldLandmarks:multiplePoseWorldLandmarks
                                       segmentationMasks:confidenceMasks
                                 timestampInMilliseconds:timestampInMilliseconds];
  return poseLandmarkerResult;
}

+ (MPPPoseLandmarkerResult *)
    poseLandmarkerResultWithLandmarksPacket:(const Packet &)landmarksPacket
                       worldLandmarksPacket:(const Packet &)worldLandmarksPacket
                    segmentationMasksPacket:(const Packet *)segmentationMasksPacket {
  NSInteger timestampInMilliseconds =
      (NSInteger)(landmarksPacket.Timestamp().Value() / kMicroSecondsPerMilliSecond);

  if (landmarksPacket.IsEmpty()) {
    return [MPPPoseLandmarkerResult
        emptyPoseLandmarkerResultWithTimestampInMilliseconds:timestampInMilliseconds];
  }

  if (!landmarksPacket.ValidateAsType<std::vector<NormalizedLandmarkListProto>>().ok() ||
      !worldLandmarksPacket.ValidateAsType<std::vector<LandmarkListProto>>().ok()) {
    return [MPPPoseLandmarkerResult
        emptyPoseLandmarkerResultWithTimestampInMilliseconds:timestampInMilliseconds];
  }

  const std::vector<Image> *segmentationMasks =
      segmentationMasksPacket ? &(segmentationMasksPacket->Get<std::vector<Image>>()) : nullptr;

  return [MPPPoseLandmarkerResult
      poseLandmarkerResultWithLandmarksProto:landmarksPacket
                                                 .Get<std::vector<NormalizedLandmarkListProto>>()
                         worldLandmarksProto:worldLandmarksPacket
                                                 .Get<std::vector<LandmarkListProto>>()
                           segmentationMasks:segmentationMasks
                     timestampInMilliSeconds:timestampInMilliseconds];
}

@end

@@ -70,7 +70,7 @@ class BaseOptions:
    platform_name = platform.system()

    if self.delegate == BaseOptions.Delegate.GPU:
      if platform_name == 'Linux':
      if platform_name in ['Linux', 'Darwin']:
        acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu())
      else:
        raise NotImplementedError(

@@ -26,9 +26,11 @@ pybind_library(
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/gpu:gpu_shared_data_internal",
        "//mediapipe/python/pybind:util",
        "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
        "//mediapipe/tasks/cc/core:task_runner",
        "@com_google_absl//absl/log:absl_log",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
        "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
    ],

@@ -14,6 +14,7 @@

#include "mediapipe/tasks/python/core/pybind/task_runner.h"

#include "absl/log/absl_log.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/python/pybind/util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"

@@ -21,6 +22,9 @@
#include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif  // MEDIAPIPE_DISABLE_GPU

namespace mediapipe {
namespace tasks {

@@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc");
          return absl::OkStatus();
        };
      }

#if !MEDIAPIPE_DISABLE_GPU
      auto gpu_resources_ = mediapipe::GpuResources::Create();
      if (!gpu_resources_.ok()) {
        ABSL_LOG(INFO) << "GPU support is not available: " << gpu_resources_.status();
        gpu_resources_ = nullptr;
      }
      auto task_runner = TaskRunner::Create(
          std::move(graph_config),
          absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
          std::move(callback),
          /* default_executor= */ nullptr,
          /* input_side_packets= */ std::nullopt, std::move(*gpu_resources_));
#else
      auto task_runner = TaskRunner::Create(
          std::move(graph_config),
          absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
          std::move(callback));
#endif  // !MEDIAPIPE_DISABLE_GPU

      RaisePyErrorIfNotOk(task_runner.status());
      return std::move(*task_runner);
    },

@@ -211,3 +211,20 @@ py_test(
        "//mediapipe/tasks/python/vision/core:image_processing_options",
    ],
)

py_test(
    name = "face_stylizer_test",
    srcs = ["face_stylizer_test.py"],
    data = [
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_models",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/tasks/python/components/containers:rect",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
        "//mediapipe/tasks/python/vision:face_stylizer",
        "//mediapipe/tasks/python/vision/core:image_processing_options",
    ],
)

mediapipe/tasks/python/test/vision/face_stylizer_test.py (new file)
@@ -0,0 +1,191 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for face stylizer."""

import enum
import os

from absl.testing import absltest
from absl.testing import parameterized

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import face_stylizer
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module


_BaseOptions = base_options_module.BaseOptions
_Rect = rect.Rect
_Image = image_module.Image
_FaceStylizer = face_stylizer.FaceStylizer
_FaceStylizerOptions = face_stylizer.FaceStylizerOptions
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

_MODEL = 'face_stylizer_color_ink.task'
_LARGE_FACE_IMAGE = 'portrait.jpg'
_MODEL_IMAGE_SIZE = 256
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class FaceStylizerTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
        )
    )
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _MODEL)
    )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _FaceStylizer.create_from_model_path(self.model_path) as stylizer:
      self.assertIsInstance(stylizer, _FaceStylizer)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceStylizerOptions(base_options=base_options)
    with _FaceStylizer.create_from_options(options) as stylizer:
      self.assertIsInstance(stylizer, _FaceStylizer)

  def test_create_from_options_fails_with_invalid_model_path(self):
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
    ):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite'
      )
      options = _FaceStylizerOptions(base_options=base_options)
      _FaceStylizer.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _FaceStylizerOptions(base_options=base_options)
      stylizer = _FaceStylizer.create_from_options(options)
      self.assertIsInstance(stylizer, _FaceStylizer)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
      (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
  )
  def test_stylize(self, model_file_type, image_file_name):
    # Load the test image.
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, image_file_name)
        )
    )
    # Creates stylizer.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen.
      raise ValueError('model_file_type is invalid.')

    options = _FaceStylizerOptions(base_options=base_options)
    stylizer = _FaceStylizer.create_from_options(options)

    # Performs face stylization on the input.
    stylized_image = stylizer.stylize(self.test_image)
    self.assertIsInstance(stylized_image, _Image)
    # Closes the stylizer explicitly when the stylizer is not used in
    # a context.
    stylizer.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
      (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
  )
  def test_stylize_in_context(self, model_file_type, image_file_name):
    # Load the test image.
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, image_file_name)
        )
    )
    # Creates stylizer.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen.
      raise ValueError('model_file_type is invalid.')

    options = _FaceStylizerOptions(base_options=base_options)
    with _FaceStylizer.create_from_options(options) as stylizer:
      # Performs face stylization on the input.
      stylized_image = stylizer.stylize(self.test_image)
      self.assertIsInstance(stylized_image, _Image)
      self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
      self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)

  def test_stylize_succeeds_with_region_of_interest(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceStylizerOptions(base_options=base_options)
    with _FaceStylizer.create_from_options(options) as stylizer:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(
              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
          )
      )
      # Region-of-interest around the face.
      roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
      image_processing_options = _ImageProcessingOptions(roi)
      # Performs face stylization on the input.
      stylized_image = stylizer.stylize(test_image, image_processing_options)
      self.assertIsInstance(stylized_image, _Image)
      self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
      self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)

  def test_stylize_succeeds_with_no_face_detected(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceStylizerOptions(base_options=base_options)
    with _FaceStylizer.create_from_options(options) as stylizer:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(
              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
          )
      )
      # Region-of-interest that doesn't contain a human face.
      roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
      image_processing_options = _ImageProcessingOptions(roi)
      # Performs face stylization on the input.
      stylized_image = stylizer.stylize(test_image, image_processing_options)
      self.assertIsNone(stylized_image)


if __name__ == '__main__':
  absltest.main()

mediapipe/tasks/testdata/vision/BUILD (vendored)
@@ -48,6 +48,7 @@ mediapipe_files(srcs = [
    "face_landmark.tflite",
    "face_landmarker.task",
    "face_landmarker_v2.task",
    "face_stylizer_color_ink.task",
    "fist.jpg",
    "fist.png",
    "gesture_recognizer.task",

@@ -183,6 +184,7 @@ filegroup(
        "face_detection_short_range.tflite",
        "face_landmarker.task",
        "face_landmarker_v2.task",
        "face_stylizer_color_ink.task",
        "hair_segmentation.tflite",
        "hand_landmark_full.tflite",
        "hand_landmark_lite.tflite",

@@ -2854,7 +2854,262 @@ auxiliary_landmarks {
face_blendshapes {
  classification {
    index: 0
    score: 1.6770242e-05
    label: "tongueOut"
    score: 8.47715e-07
    label: "_neutral"
  }
  classification {
    index: 1
    score: 0.020850565
    label: "browDownLeft"
  }
  classification {
    index: 2
    score: 0.007629181
    label: "browDownRight"
  }
  classification {
    index: 3
    score: 0.26410568
    label: "browInnerUp"
  }
  classification {
    index: 4
    score: 0.04212071
    label: "browOuterUpLeft"
  }
  classification {
    index: 5
    score: 0.07319052
    label: "browOuterUpRight"
  }
  classification {
    index: 6
    score: 9.39117e-06
    label: "cheekPuff"
  }
  classification {
    index: 7
    score: 1.9243858e-07
    label: "cheekSquintLeft"
  }
  classification {
    index: 8
    score: 4.066475e-08
    label: "cheekSquintRight"
  }
  classification {
    index: 9
    score: 0.46092203
    label: "eyeBlinkLeft"
  }
  classification {
    index: 10
    score: 0.40371567
    label: "eyeBlinkRight"
  }
  classification {
    index: 11
    score: 0.65011656
    label: "eyeLookDownLeft"
  }
  classification {
    index: 12
    score: 0.6423024
    label: "eyeLookDownRight"
  }
  classification {
    index: 13
    score: 0.04721973
    label: "eyeLookInLeft"
  }
  classification {
    index: 14
    score: 0.08176838
    label: "eyeLookInRight"
  }
  classification {
    index: 15
    score: 0.09520102
    label: "eyeLookOutLeft"
  }
  classification {
    index: 16
    score: 0.07271895
    label: "eyeLookOutRight"
  }
  classification {
    index: 17
    score: 0.011193463
    label: "eyeLookUpLeft"
  }
  classification {
    index: 18
    score: 0.007041815
    label: "eyeLookUpRight"
  }
  classification {
    index: 19
    score: 0.27120194
    label: "eyeSquintLeft"
  }
  classification {
    index: 20
    score: 0.21675573
    label: "eyeSquintRight"
  }
  classification {
    index: 21
    score: 0.0018824162
    label: "eyeWideLeft"
  }
  classification {
    index: 22
    score: 0.0011966582
    label: "eyeWideRight"
  }
  classification {
    index: 23
    score: 1.9298719e-05
    label: "jawForward"
  }
  classification {
    index: 24
    score: 9.670858e-06
    label: "jawLeft"
  }
  classification {
    index: 25
    score: 0.000115385694
    label: "jawOpen"
  }
  classification {
    index: 26
    score: 0.00023342477
    label: "jawRight"
  }
  classification {
    index: 27
    score: 2.8894076e-05
    label: "mouthClose"
  }
  classification {
    index: 28
    score: 0.003933548
    label: "mouthDimpleLeft"
  }
  classification {
    index: 29
    score: 0.0051949574
    label: "mouthDimpleRight"
  }
  classification {
    index: 30
    score: 0.00067943585
    label: "mouthFrownLeft"
  }
  classification {
    index: 31
    score: 0.0006520291
    label: "mouthFrownRight"
  }
  classification {
    index: 32
    score: 0.0006695333
    label: "mouthFunnel"
  }
  classification {
    index: 33
    score: 8.578597e-05
    label: "mouthLeft"
  }
  classification {
    index: 34
    score: 2.6707421e-05
    label: "mouthLowerDownLeft"
  }
  classification {
    index: 35
    score: 2.153054e-05
    label: "mouthLowerDownRight"
  }
  classification {
    index: 36
    score: 0.0132145975
    label: "mouthPressLeft"
  }
  classification {
    index: 37
    score: 0.009528495
    label: "mouthPressRight"
  }
  classification {
    index: 38
    score: 0.056963783
    label: "mouthPucker"
  }
  classification {
    index: 39
    score: 0.027331185
    label: "mouthRight"
  }
  classification {
    index: 40
    score: 0.00072388636
    label: "mouthRollLower"
  }
  classification {
    index: 41
    score: 0.00021191382
    label: "mouthRollUpper"
  }
  classification {
    index: 42
    score: 0.23938002
    label: "mouthShrugLower"
  }
  classification {
    index: 43
    score: 0.052946873
    label: "mouthShrugUpper"
  }
  classification {
    index: 44
    score: 0.68681276
    label: "mouthSmileLeft"
  }
  classification {
    index: 45
    score: 0.68557316
    label: "mouthSmileRight"
  }
  classification {
    index: 46
    score: 0.0030625665
    label: "mouthStretchLeft"
  }
  classification {
    index: 47
    score: 0.003999545
    label: "mouthStretchRight"
  }
  classification {
    index: 48
    score: 0.013184475
    label: "mouthUpperUpLeft"
  }
  classification {
    index: 49
    score: 0.017995607
    label: "mouthUpperUpRight"
  }
  classification {
    index: 50
    score: 2.0452394e-06
    label: "noseSneerLeft"
  }
  classification {
    index: 51
    score: 3.7912793e-07
    label: "noseSneerRight"
  }
}

@@ -31,27 +31,57 @@ mediapipe_ts_library(

mediapipe_ts_library(
    name = "drawing_utils",
    srcs = ["drawing_utils.ts"],
    srcs = [
        "drawing_utils.ts",
        "drawing_utils_category_mask.ts",
    ],
    deps = [
        ":image",
        ":image_shader_context",
        ":mask",
        ":types",
        "//mediapipe/tasks/web/components/containers:bounding_box",
        "//mediapipe/tasks/web/components/containers:landmark",
        "//mediapipe/web/graph_runner:graph_runner_ts",
    ],
)

mediapipe_ts_library(
    name = "image",
    srcs = [
        "image.ts",
        "image_shader_context.ts",
    name = "drawing_utils_test_lib",
    testonly = True,
    srcs = ["drawing_utils.test.ts"],
    deps = [
        ":drawing_utils",
        ":image",
        ":image_shader_context",
        ":mask",
    ],
)

jasmine_node_test(
    name = "drawing_utils_test",
    deps = [":drawing_utils_test_lib"],
)

mediapipe_ts_library(
    name = "image",
    srcs = ["image.ts"],
    deps = [":image_shader_context"],
)

mediapipe_ts_library(
    name = "image_shader_context",
    srcs = ["image_shader_context.ts"],
)

mediapipe_ts_library(
    name = "image_test_lib",
    testonly = True,
    srcs = ["image.test.ts"],
    deps = [":image"],
    deps = [
        ":image",
        ":image_shader_context",
    ],
)

jasmine_node_test(

@@ -64,6 +94,7 @@ mediapipe_ts_library(
    srcs = ["mask.ts"],
    deps = [
        ":image",
        ":image_shader_context",
        "//mediapipe/web/graph_runner:platform_utils",
    ],
)

@@ -74,6 +105,7 @@ mediapipe_ts_library(
    srcs = ["mask.test.ts"],
    deps = [
        ":image",
        ":image_shader_context",
        ":mask",
    ],
)

@@ -89,6 +121,7 @@ mediapipe_ts_library(
    deps = [
        ":image",
        ":image_processing_options",
        ":image_shader_context",
        ":mask",
        ":vision_task_options",
        "//mediapipe/framework/formats:rect_jspb_proto",

mediapipe/tasks/web/vision/core/drawing_utils.test.ts (new file)
@@ -0,0 +1,103 @@
/**
 * Copyright 2023 The MediaPipe Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

import {DrawingUtils} from './drawing_utils';
import {MPImageShaderContext} from './image_shader_context';
import {MPMask} from './mask';

const WIDTH = 2;
const HEIGHT = 2;

const skip = typeof document === 'undefined';
if (skip) {
  console.log('These tests must be run in a browser.');
}

(skip ? xdescribe : describe)('DrawingUtils', () => {
  let shaderContext = new MPImageShaderContext();
  let canvas2D: HTMLCanvasElement;
  let context2D: CanvasRenderingContext2D;
  let drawingUtils2D: DrawingUtils;
  let canvasWebGL: HTMLCanvasElement;
  let contextWebGL: WebGL2RenderingContext;
  let drawingUtilsWebGL: DrawingUtils;

  beforeEach(() => {
    shaderContext = new MPImageShaderContext();

    canvasWebGL = document.createElement('canvas');
    canvasWebGL.width = WIDTH;
    canvasWebGL.height = HEIGHT;
    contextWebGL = canvasWebGL.getContext('webgl2')!;
    drawingUtilsWebGL = new DrawingUtils(contextWebGL);

    canvas2D = document.createElement('canvas');
    canvas2D.width = WIDTH;
    canvas2D.height = HEIGHT;
    context2D = canvas2D.getContext('2d')!;
    drawingUtils2D = new DrawingUtils(context2D, contextWebGL);
  });

  afterEach(() => {
    shaderContext.close();
    drawingUtils2D.close();
    drawingUtilsWebGL.close();
  });

  describe('drawCategoryMask() ', () => {
    const colors = [
      [0, 0, 0, 255],
      [0, 255, 0, 255],
      [0, 0, 255, 255],
      [255, 255, 255, 255],
    ];
    const expectedResult = new Uint8Array(
        [0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255],
    );

    it('on 2D canvas', () => {
      const categoryMask = new MPMask(
          [new Uint8Array([0, 1, 2, 3])],
          /* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH,
          HEIGHT);

      drawingUtils2D.drawCategoryMask(categoryMask, colors);

      const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data;
      expect(actualResult)
          .toEqual(new Uint8ClampedArray(expectedResult.buffer));
    });

    it('on WebGL canvas', () => {
      const categoryMask = new MPMask(
          [new Uint8Array([2, 3, 0, 1])],  // Note: Vertically flipped
          /* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH,
          HEIGHT);

      drawingUtilsWebGL.drawCategoryMask(categoryMask, colors);

      const actualResult = new Uint8Array(WIDTH * HEIGHT * 4);
      contextWebGL.readPixels(
          0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE,
          actualResult);
      expect(actualResult).toEqual(expectedResult);
    });
  });

  // TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox
});