Merge branch 'master' into ios-task

This commit is contained in:
Prianka Liz Kariat 2023-01-04 13:40:17 +05:30
parent 7e0fec7c28
commit 7ce21038bb
67 changed files with 593 additions and 457 deletions

View File

@ -15,4 +15,5 @@
# A list of assignees
assignees:
- sureshdagooglecom
- kuaashish
- ayushgdev

View File

@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) {
}
// If options.convert_signature_to_tags() is set, will convert letters to
// uppercase and replace /'s and -'s with _'s. This enables the standard
// uppercase and replace /, -, . and :'s with _'s. This enables the standard
// SavedModel classification, regression, and prediction signatures to be used
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common
// patterns.
@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag(
output.resize(name.length());
std::transform(name.begin(), name.end(), output.begin(),
[](unsigned char c) { return std::toupper(c); });
output = absl::StrReplaceAll(output, {{"/", "_"}});
output = absl::StrReplaceAll(output, {{"-", "_"}});
output = absl::StrReplaceAll(output, {{".", "_"}});
output = absl::StrReplaceAll(
output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
return output;
} else {

View File

@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions {
// The name of the generic signature to load into the mapping from tags to
// tensor names.
optional string signature_name = 2 [default = "serving_default"];
// Whether to convert the signature keys to uppercase as well as switch /'s
// and -'s to _'s, which enables common signatures to be used as Tags.
// Whether to convert the signature keys to uppercase as well as switch
// /, -, .and :'s to _'s, which enables common signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true];
// If true, saved_model_path can have multiple exported models in
// subdirectories saved_model_path/%08d and the alphabetically last (i.e.,

View File

@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) {
}
// If options.convert_signature_to_tags() is set, will convert letters to
// uppercase and replace /'s and -'s with _'s. This enables the standard
// uppercase and replace /, -, and .'s with _'s. This enables the standard
// SavedModel classification, regression, and prediction signatures to be used
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common
// patterns.
@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag(
output.resize(name.length());
std::transform(name.begin(), name.end(), output.begin(),
[](unsigned char c) { return std::toupper(c); });
output = absl::StrReplaceAll(output, {{"/", "_"}});
output = absl::StrReplaceAll(output, {{"-", "_"}});
output = absl::StrReplaceAll(output, {{".", "_"}});
output = absl::StrReplaceAll(
output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
return output;
} else {

View File

@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions {
// The name of the generic signature to load into the mapping from tags to
// tensor names.
optional string signature_name = 2 [default = "serving_default"];
// Whether to convert the signature keys to uppercase as well as switch /'s
// and -'s to _'s, which enables common signatures to be used as Tags.
// Whether to convert the signature keys to uppercase, as well as switch /'s
// -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true];
// If true, saved_model_path can have multiple exported models in
// subdirectories saved_model_path/%08d and the alphabetically last (i.e.,

View File

@ -30,6 +30,10 @@ proto_library(
java_lite_proto_library(
name = "autoflip_messages_java_proto_lite",
visibility = [
"//java/com/google/android/apps/photos:__subpackages__",
"//javatests/com/google/android/apps/photos:__subpackages__",
],
deps = [
":autoflip_messages_proto",
],

View File

@ -455,7 +455,7 @@ cc_library(
],
}),
deps = [
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework:port",

View File

@ -24,7 +24,7 @@
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/formats/tensor_internal.h"
#include "mediapipe/framework/port.h"
@ -434,8 +434,9 @@ class Tensor {
mutable bool use_ahwb_ = false;
mutable uint64_t ahwb_tracking_key_ = 0;
// TODO: Tracks all unique tensors. Can grow to a large number. LRU
// can be more predicted.
static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_;
// (Least Recently Used) can be more predicted.
// The value contains the size alignment parameter.
static inline absl::flat_hash_map<uint64_t, int> ahwb_usage_track_;
// Expects the target SSBO to be already bound.
bool AllocateAhwbMapToSsbo() const;
bool InsertAhwbToSsboFence() const;

View File

@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
// Mark current tracking key as Ahwb-use.
ahwb_usage_track_.insert(ahwb_tracking_key_);
if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_);
it != ahwb_usage_track_.end()) {
size_alignment = it->second;
} else if (ahwb_tracking_key_ != 0) {
ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment});
}
use_ahwb_ = true;
if (__builtin_available(android 26, *)) {
@ -458,7 +463,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
}
}
use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
// Keep flag value if it was set previously.
use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);
}
#else // MEDIAPIPE_TENSOR_USE_AHWB

View File

@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
};
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
{
// Request Ahwb first to get Ahwb storage allocated internally.
auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(view.handle(), nullptr);
view.SetWritingFinishedFD(-1, [](bool) { return true; });
}
RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name();
@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
}
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
{
// Request Ahwb first to get Ahwb storage allocated internally.
auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name();
@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
// Request the CPU view to get the memory to be allocated.
// Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
{
@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
// Request the GPU view to get the ssbo allocated internally.
// Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
RunInGlContext([&tensor] {

View File

@ -1,34 +1,28 @@
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#if !MEDIAPIPE_DISABLE_GPU
namespace mediapipe {
class TensorAhwbTest : public mediapipe::GpuTestBase {
public:
};
TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
TEST(TensorAhwbTest, TestCpuThenAHWB) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{
auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr);
}
{
auto ahwb = tensor.GetAHardwareBufferReadView().handle();
EXPECT_NE(ahwb, nullptr);
auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
}
TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
TEST(TensorAhwbTest, TestAHWBThenCpu) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{
auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
EXPECT_NE(ahwb, nullptr);
auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(view.handle(), nullptr);
view.SetWritingFinishedFD(-1, [](bool) { return true; });
}
{
auto ptr = tensor.GetCpuReadView().buffer<float>();
@ -36,21 +30,71 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
}
}
TEST_F(TensorAhwbTest, TestCpuThenGl) {
RunInGlContext([] {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
TEST(TensorAhwbTest, TestAhwbAlignment) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5});
{
auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr);
auto view = tensor.GetAHardwareBufferWriteView(16);
EXPECT_NE(view.handle(), nullptr);
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc;
AHardwareBuffer_describe(view.handle(), &desc);
// sizeof(float) * 5 = 20, the closest aligned to 16 size is 32.
EXPECT_EQ(desc.width, 32);
}
view.SetWritingFinishedFD(-1, [](bool) { return true; });
}
}
// Tensor::GetCpuView uses source location mechanism that gives source file name
// and line from where the method is called. The function is intended just to
// have two calls providing the same source file name and line.
auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); }
// The test checks the tracking mechanism: when a tensor's Cpu view is retrieved
// for the first time then the source location is attached to the tensor. If the
// Ahwb view is requested then from the tensor then the previously recorded Cpu
// view request source location is marked for using Ahwb storage.
// When a Cpu view with the same source location (but for the newly allocated
// tensor) is requested and the location is marked to use Ahwb storage then the
// Ahwb storage is allocated for the CpuView.
TEST(TensorAhwbTest, TestTrackingAhwb) {
// Create first tensor and request Cpu and then Ahwb view to mark the source
// location for Ahwb storage.
{
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
{
auto view = GetCpuView(tensor);
EXPECT_NE(view.buffer<float>(), nullptr);
}
{
auto ssbo = tensor.GetOpenGlBufferReadView().name();
EXPECT_GT(ssbo, 0);
// Align size of the Ahwb by multiple of 16.
auto view = tensor.GetAHardwareBufferWriteView(16);
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
}
{
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
{
// The second tensor uses the same Cpu view source location so Ahwb
// storage is allocated internally.
auto view = GetCpuView(tensor);
EXPECT_NE(view.buffer<float>(), nullptr);
}
{
// Check the Ahwb size to be aligned to multiple of 16. The alignment is
// stored by previous requesting of the Ahwb view.
auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr);
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc;
AHardwareBuffer_describe(view.handle(), &desc);
// sizeof(float) * 9 = 36. The closest aligned size is 48.
EXPECT_EQ(desc.width, 48);
}
view.SetReadingFinishedFunc([](bool) { return true; });
}
}
});
}
} // namespace mediapipe
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_TENSOR_USE_AHWB

View File

@ -194,6 +194,7 @@ void GraphProfiler::Initialize(
"Calculator \"$0\" has already been added.", node_name);
}
profile_builder_ = std::make_unique<GraphProfileBuilder>(this);
graph_id_ = ++next_instance_id_;
is_initialized_ = true;
}

View File

@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
return validated_graph_;
}
// Gets a numerical identifier for this GraphProfiler object.
uint64_t GetGraphId() { return graph_id_; }
private:
// This can be used to add packet info for the input streams to the graph.
// It treats the stream defined by |stream_name| as a stream produced by a
@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
class GraphProfileBuilder;
std::unique_ptr<GraphProfileBuilder> profile_builder_;
// The globally incrementing identifier for all graphs in a process.
static inline std::atomic_int next_instance_id_ = 0;
// A unique identifier for this object. Only unique within a process.
uint64_t graph_id_;
// For testing.
friend GraphProfilerTestPeer;
};

View File

@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) {
"Cannot initialize .* multiple times.");
}
// Tests that graph identifiers are not reused, even after destruction.
TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) {
auto raw_graph_config = R"(
profiler_config {
enable_profiler: true
}
input_stream: "input_stream"
node {
calculator: "DummyTestCalculator"
input_stream: "input_stream"
})";
const int n_iterations = 100;
absl::flat_hash_set<int> seen_ids;
for (int i = 0; i < n_iterations; ++i) {
std::shared_ptr<ProfilingContext> profiler =
std::make_shared<ProfilingContext>();
auto graph_config = CreateGraphConfig(raw_graph_config);
mediapipe::ValidatedGraphConfig validated_graph;
QCHECK_OK(validated_graph.Initialize(graph_config));
profiler->Initialize(validated_graph);
int id = profiler->GetGraphId();
ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id)));
seen_ids.insert(id);
}
}
// Tests that Pause(), Resume(), and Reset() works.
TEST_F(GraphProfilerTestPeer, PauseResumeReset) {
InitializeProfilerWithGraphConfig(R"(

View File

@ -74,6 +74,9 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView(
static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
const GlTextureView& view) {
CHECK(pixel_buffer);
auto ctx = GlContext::GetCurrent().get();
if (!ctx) ctx = view.gl_context();
ctx->Run([pixel_buffer, &view, ctx] {
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferLockBaseAddress failed: " << err;
@ -82,34 +85,40 @@ static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
uint8_t* pixel_ptr =
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
if (pixel_format == kCVPixelFormatType_32BGRA) {
// TODO: restore previous framebuffer? Move this to helper so we
// can use BindFramebuffer?
glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx));
glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
view.name(), 0);
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
view.target(), view.name(), 0);
size_t contiguous_bytes_per_row = view.width() * 4;
if (bytes_per_row == contiguous_bytes_per_row) {
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
pixel_ptr);
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
GL_UNSIGNED_BYTE, pixel_ptr);
} else {
// TODO: use GL_PACK settings for row length. We can expect
// GLES 3.0 on iOS now.
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
view.height());
uint8_t* temp_ptr = contiguous_buffer.data();
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
temp_ptr);
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
GL_UNSIGNED_BYTE, temp_ptr);
for (int i = 0; i < view.height(); ++i) {
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
temp_ptr += contiguous_bytes_per_row;
pixel_ptr += bytes_per_row;
}
}
// TODO: restore previous framebuffer?
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
view.target(), 0, 0);
glBindFramebuffer(GL_FRAMEBUFFER, 0);
} else {
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
}
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
});
}
#endif // TARGET_IPHONE_SIMULATOR

View File

@ -71,9 +71,12 @@ class TextClassifierTest(tf.test.TestCase):
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
filecmp.clear_cache()
self.assertTrue(
filecmp.cmp(output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
filecmp.cmp(
output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
shallow=False))
def test_create_and_train_bert(self):
train_data, validation_data = self._get_data()

View File

@ -135,7 +135,10 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
filecmp.clear_cache()
self.assertTrue(
filecmp.cmp(
output_metadata_file, expected_metadata_file, shallow=False))
def test_continual_training_by_loading_checkpoint(self):
mock_stdout = io.StringIO()

View File

@ -230,16 +230,17 @@ if ([wrapper.delegate
}
- (absl::Status)performStart {
absl::Status status = _graph->Initialize(_config);
if (!status.ok()) {
return status;
}
absl::Status status;
for (const auto& service_packet : _servicePackets) {
status = _graph->SetServicePacket(*service_packet.first, service_packet.second);
if (!status.ok()) {
return status;
}
}
status = _graph->Initialize(_config);
if (!status.ok()) {
return status;
}
status = _graph->StartRun(_inputSidePackets, _streamHeaders);
if (!status.ok()) {
return status;

View File

@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
auto custom_gestures_classifier_options_proto =
std::make_unique<components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto(
&(options->canned_gestures_classifier_options)));
&(options->custom_gestures_classifier_options)));
hand_gesture_recognizer_graph_options
->mutable_custom_gesture_classifier_graph_options()
->mutable_classifier_options()
->Swap(canned_gestures_classifier_options_proto.get());
->Swap(custom_gestures_classifier_options_proto.get());
return options_proto;
}

View File

@ -38,4 +38,3 @@ objc_library(
"-std=c++17",
],
)

View File

@ -13,6 +13,7 @@
// limitations under the License.
#import <Foundation/Foundation.h>
#include "mediapipe/tasks/cc/common.h"
NS_ASSUME_NONNULL_BEGIN
@ -56,6 +57,7 @@ extern NSString *const MPPTasksErrorDomain;
* @param status absl::Status.
* @param error Pointer to the memory location where the created error should be saved. If `nil`,
* no error will be saved.
* @return YES when there is no error, NO otherwise.
*/
+ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error;

View File

@ -20,7 +20,6 @@
#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/cord.h" // from @com_google_absl
#include "mediapipe/tasks/cc/common.h"
/** Error domain of MediaPipe task library errors. */
@ -96,8 +95,8 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
// appropriate MPPTasksErrorCode in default cases. Note:
// The mapping to absl::Status::code() is done to generate a more specific error code than
// MPPTasksErrorCodeError in cases when the payload can't be mapped to
// MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn returned
// without modification by Mediapipe cc library methods.
// MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
// returned without modification by Mediapipe cc library methods.
if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) {
switch (status.code()) {
case absl::StatusCode::kInternal:

View File

@ -13,13 +13,14 @@
// limitations under the License.
#import <Foundation/Foundation.h>
#include <string>
NS_ASSUME_NONNULL_BEGIN
@interface NSString (Helpers)
@property(readonly) std::string cppString;
@property(readonly, nonatomic) std::string cppString;
+ (NSString *)stringWithCppString:(std::string)text;

View File

@ -21,4 +21,3 @@ objc_library(
srcs = ["sources/MPPClassifierOptions.m"],
hdrs = ["sources/MPPClassifierOptions.h"],
)

View File

@ -22,29 +22,34 @@ NS_ASSUME_NONNULL_BEGIN
NS_SWIFT_NAME(ClassifierOptions)
@interface MPPClassifierOptions : NSObject <NSCopying>
/** The locale to use for display names specified through the TFLite Model
/**
* The locale to use for display names specified through the TFLite Model
* Metadata, if any. Defaults to English.
*/
@property(nonatomic, copy) NSString *displayNamesLocale;
/** The maximum number of top-scored classification results to return. If < 0,
/**
* The maximum number of top-scored classification results to return. If < 0,
* all available results will be returned. If 0, an invalid argument error is
* returned.
*/
@property(nonatomic) NSInteger maxResults;
/** Score threshold to override the one provided in the model metadata (if any).
/**
* Score threshold to override the one provided in the model metadata (if any).
* Results below this value are rejected.
*/
@property(nonatomic) float scoreThreshold;
/** The allowlist of category names. If non-empty, detection results whose
/**
* The allowlist of category names. If non-empty, detection results whose
* category name is not in this set will be filtered out. Duplicate or unknown
* category names are ignored. Mutually exclusive with categoryDenylist.
*/
@property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
/** The denylist of category names. If non-empty, detection results whose
/**
* The denylist of category names. If non-empty, detection results whose
* category name is in this set will be filtered out. Duplicate or unknown
* category names are ignored. Mutually exclusive with categoryAllowlist.
*/

View File

@ -19,8 +19,8 @@
- (instancetype)init {
self = [super init];
if (self) {
self.maxResults = -1;
self.scoreThreshold = 0;
_maxResults = -1;
_scoreThreshold = 0;
}
return self;
}

View File

@ -22,8 +22,7 @@ objc_library(
hdrs = ["sources/MPPClassifierOptions+Helpers.h"],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
]
"//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
],
)

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
NS_ASSUME_NONNULL_BEGIN

View File

@ -29,7 +29,6 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
}
classifierOptionsProto->set_max_results((int)self.maxResults);
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
for (NSString *category in self.categoryAllowlist) {

View File

@ -54,14 +54,14 @@ objc_library(
"-std=c++17",
],
deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
":MPPTaskOptions",
":MPPTaskOptionsProtocol",
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/common:MPPCommon",
],
)
@ -83,9 +83,13 @@ objc_library(
name = "MPPTaskRunner",
srcs = ["sources/MPPTaskRunner.mm"],
hdrs = ["sources/MPPTaskRunner.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
],
)

View File

@ -13,7 +13,9 @@
// limitations under the License.
#import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator.pb.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
@ -59,7 +61,7 @@ NS_ASSUME_NONNULL_BEGIN
/**
* Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
*/
- (mediapipe::CalculatorGraphConfig)generateGraphConfig;
- (::mediapipe::CalculatorGraphConfig)generateGraphConfig;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -24,7 +24,6 @@
namespace {
using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
using Node = ::mediapipe::CalculatorGraphConfig::Node;
using ::mediapipe::CalculatorOptions;
using ::mediapipe::FlowLimiterCalculatorOptions;
using ::mediapipe::InputStreamInfo;
} // namespace

View File

@ -13,6 +13,7 @@
// limitations under the License.
#import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator_options.pb.h"
NS_ASSUME_NONNULL_BEGIN
@ -25,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN
/**
* Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
*/
- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;
- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto;
@end

View File

@ -22,7 +22,6 @@ NS_ASSUME_NONNULL_BEGIN
/**
* This class is used to create and call appropriate methods on the C++ Task Runner.
*/
@interface MPPTaskRunner : NSObject
/**
@ -35,11 +34,10 @@ NS_ASSUME_NONNULL_BEGIN
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)
process:(const mediapipe::tasks::core::PacketMap &)packetMap
error:(NSError **)error;
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
(const mediapipe::tasks::core::PacketMap &)packetMap;
- (void)close;
- (absl::Status)close;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -17,7 +17,6 @@
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace
@ -49,8 +48,8 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
return _cppTaskRunner->Process(packetMap);
}
- (void)close {
_cppTaskRunner->Close();
- (absl::Status)close {
return _cppTaskRunner->Close();
}
@end

View File

@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
*
* @param options The options for the audio classifier.
*/
override async setOptions(options: AudioClassifierOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: AudioClassifierOptions): Promise<void> {
this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions()));
this.refreshGraph();
return this.applyOptions(options);
}
/**
@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM);

View File

@ -79,7 +79,8 @@ describe('AudioClassifier', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
audioClassifier = new AudioClassifierFake();
await audioClassifier.setOptions({}); // Initialize graph
await audioClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {

View File

@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
*
* @param options The options for the audio embedder.
*/
override async setOptions(options: AudioEmbedderOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: AudioEmbedderOptions): Promise<void> {
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions()));
this.refreshGraph();
return this.applyOptions(options);
}
/**
@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM);

View File

@ -70,7 +70,8 @@ describe('AudioEmbedder', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
audioEmbedder = new AudioEmbedderFake();
await audioEmbedder.setOptions({}); // Initialize graph
await audioEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', () => {

View File

@ -103,29 +103,3 @@ jasmine_node_test(
name = "embedder_options_test",
deps = [":embedder_options_test_lib"],
)
mediapipe_ts_library(
name = "base_options",
srcs = [
"base_options.ts",
],
deps = [
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core",
],
)
mediapipe_ts_library(
name = "base_options_test_lib",
testonly = True,
srcs = ["base_options.test.ts"],
deps = [":base_options"],
)
jasmine_node_test(
name = "base_options_test",
deps = [":base_options_test_lib"],
)

View File

@ -1,127 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
// Placeholder for internal dependency on trusted resource URL builder
import {convertBaseOptionsToProto} from './base_options';
describe('convertBaseOptionsToProto()', () => {
const mockBytes = new Uint8Array([0, 1, 2, 3]);
const mockBytesResult = {
modelAsset: {
fileContent: Buffer.from(mockBytes).toString('base64'),
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
},
};
let fetchSpy: jasmine.Spy;
beforeEach(() => {
fetchSpy = jasmine.createSpy().and.callFake(async url => {
expect(url).toEqual('foo');
return {
arrayBuffer: () => mockBytes.buffer,
} as unknown as Response;
});
global.fetch = fetchSpy;
});
it('verifies that at least one model asset option is provided', async () => {
await expectAsync(convertBaseOptionsToProto({}))
.toBeRejectedWithError(
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
});
it('verifies that no more than one model asset option is provided', async () => {
await expectAsync(convertBaseOptionsToProto({
modelAssetPath: `foo`,
modelAssetBuffer: new Uint8Array([])
}))
.toBeRejectedWithError(
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
});
it('downloads model', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetPath: `foo`,
});
expect(fetchSpy).toHaveBeenCalled();
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('does not download model when bytes are provided', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
});
expect(fetchSpy).not.toHaveBeenCalled();
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('can enable CPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'CPU',
});
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('can enable GPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
});
expect(baseOptionsProto.toObject()).toEqual({
...mockBytesResult,
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: 0,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: 2,
},
tflite: undefined,
},
});
});
it('can reset delegate', async () => {
let baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
});
// Clear backend
baseOptionsProto =
await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto);
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
});

View File

@ -1,80 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided
*/
export async function convertBaseOptionsToProto(
updatedOptions: BaseOptions,
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
const result =
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
await configureExternalFile(updatedOptions, result);
configureAcceleration(updatedOptions, result);
return result;
}
/**
* Configues the `externalFile` option and validates that a single model is
* provided.
*/
async function configureExternalFile(
options: BaseOptions, proto: BaseOptionsProto) {
const externalFile = proto.getModelAsset() || new ExternalFile();
proto.setModelAsset(externalFile);
if (options.modelAssetPath || options.modelAssetBuffer) {
if (options.modelAssetPath && options.modelAssetBuffer) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
}
let modelAssetBuffer = options.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(options.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
externalFile.setFileContent(modelAssetBuffer);
}
if (!externalFile.hasFileContent()) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
}
/** Configues the `acceleration` option. */
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
const acceleration = proto.getAcceleration() ?? new Acceleration();
if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
}
proto.setAcceleration(acceleration);
}

View File

@ -18,8 +18,10 @@ mediapipe_ts_library(
srcs = ["task_runner.ts"],
deps = [
":core",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
"//mediapipe/web/graph_runner:graph_runner_ts",
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
@ -53,6 +55,7 @@ mediapipe_ts_library(
"task_runner_test.ts",
],
deps = [
":core",
":task_runner",
":task_runner_test_utils",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",

View File

@ -14,9 +14,11 @@
* limitations under the License.
*/
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@ -91,14 +93,52 @@ export abstract class TaskRunner {
this.graphRunner.registerModelResourcesGraphService();
}
/** Configures the shared options of a MediaPipe Task. */
async setOptions(options: TaskRunnerOptions): Promise<void> {
if (options.baseOptions) {
this.baseOptions = await convertBaseOptionsToProto(
options.baseOptions, this.baseOptions);
/** Configures the task with custom options. */
abstract setOptions(options: TaskRunnerOptions): Promise<void>;
/**
* Applies the current set of options, including any base options that have
* not been processed by the task implementation. The options are applied
* synchronously unless a `modelAssetPath` is provided. This ensures that
* for most use cases options are applied directly and immediately affect
* the next inference.
*/
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
const baseOptions: BaseOptions = options.baseOptions || {};
// Validate that exactly one model is configured
if (options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
this.setAcceleration(baseOptions);
if (baseOptions.modelAssetPath) {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => response.arrayBuffer())
.then(buffer => {
this.setExternalFile(new Uint8Array(buffer));
this.refreshGraph();
});
} else {
// Apply the setting synchronously.
this.setExternalFile(baseOptions.modelAssetBuffer);
this.refreshGraph();
return Promise.resolve();
}
}
/** Appliest the current options to the MediaPipe graph. */
protected abstract refreshGraph(): void;
/**
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
* over the video stream. Will replace the previously running MediaPipe graph,
@ -140,6 +180,27 @@ export abstract class TaskRunner {
}
this.processingErrors = [];
}
/** Configures the `externalFile` option */
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
if (modelAssetBuffer) {
externalFile.setFileContent(modelAssetBuffer);
}
this.baseOptions.setModelAsset(externalFile);
}
/** Configures the `acceleration` option. */
private setAcceleration(options: BaseOptions) {
const acceleration =
this.baseOptions.getAcceleration() ?? new Acceleration();
if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
}
this.baseOptions.setAcceleration(acceleration);
}
}

View File

@ -15,18 +15,22 @@
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource URL builder
import {GraphRunnerImageLib} from './task_runner';
import {TaskRunnerOptions} from './task_runner_options.d';
class TaskRunnerFake extends TaskRunner {
protected baseOptions = new BaseOptionsProto();
private errorListener: ErrorListener|undefined;
private errors: string[] = [];
baseOptions = new BaseOptionsProto();
static createFake(): TaskRunnerFake {
const wasmModule = createSpyWasmModule();
return new TaskRunnerFake(wasmModule);
@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
super.finishProcessing();
}
override refreshGraph(): void {}
override setGraph(graphData: Uint8Array, isBinary: boolean): void {
super.setGraph(graphData, isBinary);
}
setOptions(options: TaskRunnerOptions): Promise<void> {
return this.applyOptions(options);
}
private throwErrors(): void {
expect(this.errorListener).toBeDefined();
for (const error of this.errors) {
@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner {
}
describe('TaskRunner', () => {
const mockBytes = new Uint8Array([0, 1, 2, 3]);
const mockBytesResult = {
modelAsset: {
fileContent: Buffer.from(mockBytes).toString('base64'),
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
},
};
let fetchSpy: jasmine.Spy;
let taskRunner: TaskRunnerFake;
beforeEach(() => {
fetchSpy = jasmine.createSpy().and.callFake(async url => {
expect(url).toEqual('foo');
return {
arrayBuffer: () => mockBytes.buffer,
} as unknown as Response;
});
global.fetch = fetchSpy;
taskRunner = TaskRunnerFake.createFake();
});
it('handles errors during graph update', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.enqueueError('Test error');
expect(() => {
@ -85,7 +125,6 @@ describe('TaskRunner', () => {
});
it('handles errors during graph execution', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
taskRunner.enqueueError('Test error');
@ -96,7 +135,6 @@ describe('TaskRunner', () => {
});
it('can handle multiple errors', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.enqueueError('Test error 1');
taskRunner.enqueueError('Test error 2');
@ -104,4 +142,106 @@ describe('TaskRunner', () => {
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
}).toThrowError(/Test error 1, Test error 2/);
});
it('verifies that at least one model asset option is provided', () => {
expect(() => {
taskRunner.setOptions({});
})
.toThrowError(
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
});
it('verifies that no more than one model asset option is provided', () => {
expect(() => {
taskRunner.setOptions({
baseOptions: {
modelAssetPath: `foo`,
modelAssetBuffer: new Uint8Array([])
}
});
})
.toThrowError(
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
});
it('doesn\'t require model once it is configured', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
expect(() => {
taskRunner.setOptions({});
}).not.toThrowError();
});
it('downloads model', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetPath: `foo`}});
expect(fetchSpy).toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('does not download model when bytes are provided', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
expect(fetchSpy).not.toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('changes model synchronously when bytes are provided', () => {
const resolvedPromise = taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
// Check that the change has been applied even though we do not await the
// above Promise
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
return resolvedPromise;
});
it('can enable CPU delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'CPU',
}
});
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('can enable GPU delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
}
});
expect(taskRunner.baseOptions.toObject()).toEqual({
...mockBytesResult,
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: 0,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: 2,
},
tflite: undefined,
},
});
});
it('can reset delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
}
});
// Clear backend
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
});

View File

@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule {
* Sets up our equality testing to use a custom float equality checking function
* to avoid incorrect test results due to minor floating point inaccuracies.
*/
export function addJasmineCustomFloatEqualityTester() {
export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) {
jasmine.addCustomEqualityTester((a, b) => { // Custom float equality
if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
return Math.abs(a - b) < 5e-8;
return Math.abs(a - b) < tolerance;
}
return;
});

View File

@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
*
* @param options The options for the text classifier.
*/
override async setOptions(options: TextClassifierOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: TextClassifierOptions): Promise<void> {
this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions()));
this.refreshGraph();
return this.applyOptions(options);
}
protected override get baseOptions(): BaseOptionsProto {
@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

View File

@ -56,7 +56,8 @@ describe('TextClassifier', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
textClassifier = new TextClassifierFake();
await textClassifier.setOptions({}); // Initialize graph
await textClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {

View File

@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
*
* @param options The options for the text embedder.
*/
override async setOptions(options: TextEmbedderOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: TextEmbedderOptions): Promise<void> {
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions()));
this.refreshGraph();
return this.applyOptions(options);
}
protected override get baseOptions(): BaseOptionsProto {
@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(EMBEDDINGS_STREAM);

View File

@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
textEmbedder = new TextEmbedderFake();
await textEmbedder.setOptions({}); // Initialize graph
await textEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {

View File

@ -29,6 +29,7 @@ mediapipe_ts_library(
testonly = True,
srcs = ["vision_task_runner.test.ts"],
deps = [
":vision_task_options",
":vision_task_runner",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/core:task_runner_test_utils",

View File

@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
import {VisionTaskOptions} from './vision_task_options';
import {VisionTaskRunner} from './vision_task_runner';
class VisionTaskRunnerFake extends VisionTaskRunner<void> {
@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
protected override process(): void {}
protected override refreshGraph(): void {}
override setOptions(options: VisionTaskOptions): Promise<void> {
return this.applyOptions(options);
}
override processImageData(image: ImageSource): void {
super.processImageData(image);
}
@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
}
describe('VisionTaskRunner', () => {
const streamMode = {
modelAsset: undefined,
useStreamMode: true,
acceleration: undefined,
};
const imageMode = {
modelAsset: undefined,
useStreamMode: false,
acceleration: undefined,
};
let visionTaskRunner: VisionTaskRunnerFake;
beforeEach(() => {
beforeEach(async () => {
visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('can enable image mode', async () => {
await visionTaskRunner.setOptions({runningMode: 'image'});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false}));
});
it('can enable video mode', async () => {
await visionTaskRunner.setOptions({runningMode: 'video'});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: true}));
});
it('can clear running mode', async () => {
@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {
// Clear running mode
await visionTaskRunner.setOptions({runningMode: undefined});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false}));
});
it('cannot process images with video mode', async () => {

View File

@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
/** Base class for all MediaPipe Vision Tasks. */
export abstract class VisionTaskRunner<T> extends TaskRunner {
/** Configures the shared options of a vision task. */
override async setOptions(options: VisionTaskOptions): Promise<void> {
await super.setOptions(options);
override applyOptions(options: VisionTaskOptions): Promise<void> {
if ('runningMode' in options) {
const useStreamMode =
!!options.runningMode && options.runningMode !== 'image';
this.baseOptions.setUseStreamMode(useStreamMode);
}
return super.applyOptions(options);
}
/** Sends an image packet to the graph and awaits results. */

View File

@ -169,9 +169,7 @@ export class GestureRecognizer extends
*
* @param options The options for the gesture recognizer.
*/
override async setOptions(options: GestureRecognizerOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: GestureRecognizerOptions): Promise<void> {
if ('numHands' in options) {
this.handDetectorGraphOptions.setNumHands(
options.numHands ?? DEFAULT_NUM_HANDS);
@ -221,7 +219,7 @@ export class GestureRecognizer extends
?.clearClassifierOptions();
}
this.refreshGraph();
return this.applyOptions(options);
}
/**
@ -265,6 +263,15 @@ export class GestureRecognizer extends
NORM_RECT_STREAM, timestamp);
this.finishProcessing();
if (this.gestures.length === 0) {
// If no gestures are detected in the image, just return an empty list
return {
gestures: [],
landmarks: [],
worldLandmarks: [],
handednesses: [],
};
} else {
return {
gestures: this.gestures,
landmarks: this.landmarks,
@ -272,6 +279,7 @@ export class GestureRecognizer extends
handednesses: this.handednesses
};
}
}
/** Sets the default values for the graph. */
private initDefaults(): void {
@ -285,15 +293,19 @@ export class GestureRecognizer extends
}
/** Converts the proto data to a Category[][] structure. */
private toJsCategories(data: Uint8Array[]): Category[][] {
private toJsCategories(data: Uint8Array[], populateIndex = true):
Category[][] {
const result: Category[][] = [];
for (const binaryProto of data) {
const inputList = ClassificationList.deserializeBinary(binaryProto);
const outputList: Category[] = [];
for (const classification of inputList.getClassificationList()) {
const index = populateIndex && classification.hasIndex() ?
classification.getIndex()! :
DEFAULT_CATEGORY_INDEX;
outputList.push({
score: classification.getScore() ?? 0,
index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
index,
categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '',
});
@ -342,7 +354,7 @@ export class GestureRecognizer extends
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM);
@ -377,7 +389,10 @@ export class GestureRecognizer extends
});
this.graphRunner.attachProtoVectorListener(
HAND_GESTURES_STREAM, binaryProto => {
this.gestures.push(...this.toJsCategories(binaryProto));
// Gesture index is not used, because the final gesture result comes
// from multiple classifiers.
this.gestures.push(
...this.toJsCategories(binaryProto, /* populateIndex= */ false));
});
this.graphRunner.attachProtoVectorListener(
HANDEDNESS_STREAM, binaryProto => {

View File

@ -17,6 +17,8 @@
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
export {Category, Landmark, NormalizedLandmark};
/**
* Represents the gesture recognition results generated by `GestureRecognizer`.
*/
@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult {
/** Handedness of detected hands. */
handednesses: Category[][];
/** Recognized hand gestures of detected hands */
/**
* Recognized hand gestures of detected hands. Note that the index of the
* gesture is always -1, because the raw indices from multiple gesture
* classifiers cannot consolidate to a meaningful index.
*/
gestures: Category[][];
}

View File

@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
gestureRecognizer = new GestureRecognizerFake();
await gestureRecognizer.setOptions({}); // Initialize graph
await gestureRecognizer.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {
@ -271,7 +272,7 @@ describe('GestureRecognizer', () => {
expect(gestures).toEqual({
'gestures': [[{
'score': 0.2,
'index': 2,
'index': -1,
'categoryName': 'gesture_label',
'displayName': 'gesture_display_name'
}]],
@ -304,4 +305,25 @@ describe('GestureRecognizer', () => {
// gestures.
expect(gestures2).toEqual(gestures1);
});
it('returns empty results when no gestures are detected', async () => {
// Pass the test data to our listener
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(gestureRecognizer);
gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
gestureRecognizer.listeners.get('world_hand_landmarks')!
(createWorldLandmarks());
gestureRecognizer.listeners.get('handedness')!(createHandednesses());
gestureRecognizer.listeners.get('hand_gestures')!([]);
});
// Invoke the gesture recognizer
const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
expect(gestures).toEqual({
'gestures': [],
'landmarks': [],
'worldLandmarks': [],
'handednesses': []
});
});
});

View File

@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
*
* @param options The options for the hand landmarker.
*/
override async setOptions(options: HandLandmarkerOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: HandLandmarkerOptions): Promise<void> {
// Configure hand detector options.
if ('numHands' in options) {
this.handDetectorGraphOptions.setNumHands(
@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
this.refreshGraph();
return this.applyOptions(options);
}
/**
@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM);

View File

@ -17,6 +17,8 @@
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
export {Landmark, NormalizedLandmark, Category};
/**
* Represents the hand landmarks deection results generated by `HandLandmarker`.
*/

View File

@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
handLandmarker = new HandLandmarkerFake();
await handLandmarker.setOptions({}); // Initialize graph
await handLandmarker.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {

View File

@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
*
* @param options The options for the image classifier.
*/
override async setOptions(options: ImageClassifierOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: ImageClassifierOptions): Promise<void> {
this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions()));
this.refreshGraph();
return this.applyOptions(options);
}
/**
@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

View File

@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
imageClassifier = new ImageClassifierFake();
await imageClassifier.setOptions({}); // Initialize graph
await imageClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {

View File

@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
*
* @param options The options for the image embedder.
*/
override async setOptions(options: ImageEmbedderOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: ImageEmbedderOptions): Promise<void> {
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions()));
this.refreshGraph();
return this.applyOptions(options);
}
/**
@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(EMBEDDINGS_STREAM);

View File

@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
imageEmbedder = new ImageEmbedderFake();
await imageEmbedder.setOptions({}); // Initialize graph
await imageEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {

View File

@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
*
* @param options The options for the object detector.
*/
override async setOptions(options: ObjectDetectorOptions): Promise<void> {
await super.setOptions(options);
override setOptions(options: ObjectDetectorOptions): Promise<void> {
// Note that we have to support both JSPB and ProtobufJS, hence we
// have to expliclity clear the values instead of setting them to
// `undefined`.
@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
this.options.clearCategoryDenylistList();
}
this.refreshGraph();
return this.applyOptions(options);
}
/**
@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(DETECTIONS_STREAM);

View File

@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
objectDetector = new ObjectDetectorFake();
await objectDetector.setOptions({}); // Initialize graph
await objectDetector.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {