Merge branch 'master' into ios-task

This commit is contained in:
Prianka Liz Kariat 2023-01-04 13:40:17 +05:30
parent 7e0fec7c28
commit 7ce21038bb
67 changed files with 593 additions and 457 deletions

View File

@ -15,4 +15,5 @@
# A list of assignees # A list of assignees
assignees: assignees:
- sureshdagooglecom - kuaashish
- ayushgdev

View File

@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) {
} }
// If options.convert_signature_to_tags() is set, will convert letters to // If options.convert_signature_to_tags() is set, will convert letters to
// uppercase and replace /'s and -'s with _'s. This enables the standard // uppercase and replace /, -, . and :'s with _'s. This enables the standard
// SavedModel classification, regression, and prediction signatures to be used // SavedModel classification, regression, and prediction signatures to be used
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
// patterns. // patterns.
@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag(
output.resize(name.length()); output.resize(name.length());
std::transform(name.begin(), name.end(), output.begin(), std::transform(name.begin(), name.end(), output.begin(),
[](unsigned char c) { return std::toupper(c); }); [](unsigned char c) { return std::toupper(c); });
output = absl::StrReplaceAll(output, {{"/", "_"}}); output = absl::StrReplaceAll(
output = absl::StrReplaceAll(output, {{"-", "_"}}); output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
output = absl::StrReplaceAll(output, {{".", "_"}});
LOG(INFO) << "Renamed TAG from: " << name << " to " << output; LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
return output; return output;
} else { } else {

View File

@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions {
// The name of the generic signature to load into the mapping from tags to // The name of the generic signature to load into the mapping from tags to
// tensor names. // tensor names.
optional string signature_name = 2 [default = "serving_default"]; optional string signature_name = 2 [default = "serving_default"];
// Whether to convert the signature keys to uppercase as well as switch /'s // Whether to convert the signature keys to uppercase as well as switch
// and -'s to _'s, which enables common signatures to be used as Tags. // /, -, .and :'s to _'s, which enables common signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true]; optional bool convert_signature_to_tags = 3 [default = true];
// If true, saved_model_path can have multiple exported models in // If true, saved_model_path can have multiple exported models in
// subdirectories saved_model_path/%08d and the alphabetically last (i.e., // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,

View File

@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) {
} }
// If options.convert_signature_to_tags() is set, will convert letters to // If options.convert_signature_to_tags() is set, will convert letters to
// uppercase and replace /'s and -'s with _'s. This enables the standard // uppercase and replace /, -, and .'s with _'s. This enables the standard
// SavedModel classification, regression, and prediction signatures to be used // SavedModel classification, regression, and prediction signatures to be used
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
// patterns. // patterns.
@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag(
output.resize(name.length()); output.resize(name.length());
std::transform(name.begin(), name.end(), output.begin(), std::transform(name.begin(), name.end(), output.begin(),
[](unsigned char c) { return std::toupper(c); }); [](unsigned char c) { return std::toupper(c); });
output = absl::StrReplaceAll(output, {{"/", "_"}}); output = absl::StrReplaceAll(
output = absl::StrReplaceAll(output, {{"-", "_"}}); output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
output = absl::StrReplaceAll(output, {{".", "_"}});
LOG(INFO) << "Renamed TAG from: " << name << " to " << output; LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
return output; return output;
} else { } else {

View File

@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions {
// The name of the generic signature to load into the mapping from tags to // The name of the generic signature to load into the mapping from tags to
// tensor names. // tensor names.
optional string signature_name = 2 [default = "serving_default"]; optional string signature_name = 2 [default = "serving_default"];
// Whether to convert the signature keys to uppercase as well as switch /'s // Whether to convert the signature keys to uppercase, as well as switch /'s
// and -'s to _'s, which enables common signatures to be used as Tags. // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true]; optional bool convert_signature_to_tags = 3 [default = true];
// If true, saved_model_path can have multiple exported models in // If true, saved_model_path can have multiple exported models in
// subdirectories saved_model_path/%08d and the alphabetically last (i.e., // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,

View File

@ -30,6 +30,10 @@ proto_library(
java_lite_proto_library( java_lite_proto_library(
name = "autoflip_messages_java_proto_lite", name = "autoflip_messages_java_proto_lite",
visibility = [
"//java/com/google/android/apps/photos:__subpackages__",
"//javatests/com/google/android/apps/photos:__subpackages__",
],
deps = [ deps = [
":autoflip_messages_proto", ":autoflip_messages_proto",
], ],

View File

@ -398,7 +398,7 @@ template <class Calc = internal::Generic>
class Node; class Node;
#if __cplusplus >= 201703L #if __cplusplus >= 201703L
// Deduction guide to silence -Wctad-maybe-unsupported. // Deduction guide to silence -Wctad-maybe-unsupported.
explicit Node()->Node<internal::Generic>; explicit Node() -> Node<internal::Generic>;
#endif // C++17 #endif // C++17
template <> template <>

View File

@ -181,7 +181,7 @@ template <typename T = internal::Generic>
class Packet; class Packet;
#if __cplusplus >= 201703L #if __cplusplus >= 201703L
// Deduction guide to silence -Wctad-maybe-unsupported. // Deduction guide to silence -Wctad-maybe-unsupported.
explicit Packet()->Packet<internal::Generic>; explicit Packet() -> Packet<internal::Generic>;
#endif // C++17 #endif // C++17
template <> template <>

View File

@ -455,7 +455,7 @@ cc_library(
], ],
}), }),
deps = [ deps = [
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
"//mediapipe/framework:port", "//mediapipe/framework:port",

View File

@ -24,7 +24,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/formats/tensor_internal.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
@ -434,8 +434,9 @@ class Tensor {
mutable bool use_ahwb_ = false; mutable bool use_ahwb_ = false;
mutable uint64_t ahwb_tracking_key_ = 0; mutable uint64_t ahwb_tracking_key_ = 0;
// TODO: Tracks all unique tensors. Can grow to a large number. LRU // TODO: Tracks all unique tensors. Can grow to a large number. LRU
// can be more predicted. // (Least Recently Used) can be more predicted.
static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_; // The value contains the size alignment parameter.
static inline absl::flat_hash_map<uint64_t, int> ahwb_usage_track_;
// Expects the target SSBO to be already bound. // Expects the target SSBO to be already bound.
bool AllocateAhwbMapToSsbo() const; bool AllocateAhwbMapToSsbo() const;
bool InsertAhwbToSsboFence() const; bool InsertAhwbToSsboFence() const;

View File

@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
// Mark current tracking key as Ahwb-use. // Mark current tracking key as Ahwb-use.
ahwb_usage_track_.insert(ahwb_tracking_key_); if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_);
it != ahwb_usage_track_.end()) {
size_alignment = it->second;
} else if (ahwb_tracking_key_ != 0) {
ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment});
}
use_ahwb_ = true; use_ahwb_ = true;
if (__builtin_available(android 26, *)) { if (__builtin_available(android 26, *)) {
@ -458,7 +463,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
} }
} }
use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); // Keep flag value if it was set previously.
use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);
} }
#else // MEDIAPIPE_TENSOR_USE_AHWB #else // MEDIAPIPE_TENSOR_USE_AHWB

View File

@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
}; };
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
{
// Request Ahwb first to get Ahwb storage allocated internally.
auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(view.handle(), nullptr);
view.SetWritingFinishedFD(-1, [](bool) { return true; });
}
RunInGlContext([&tensor] { RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name(); auto ssbo_name = ssbo_view.name();
@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
} }
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
{
// Request Ahwb first to get Ahwb storage allocated internally.
auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
RunInGlContext([&tensor] { RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name(); auto ssbo_name = ssbo_view.name();
@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
// Request the CPU view to get the memory to be allocated. // Request the CPU view to get the memory to be allocated.
// Request Ahwb view then to transform the storage into Ahwb. // Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
{ {
@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
// Request the GPU view to get the ssbo allocated internally. // Request the GPU view to get the ssbo allocated internally.
// Request Ahwb view then to transform the storage into Ahwb. // Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
RunInGlContext([&tensor] { RunInGlContext([&tensor] {

View File

@ -1,34 +1,28 @@
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "testing/base/public/gmock.h" #include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h" #include "testing/base/public/gunit.h"
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#if !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
class TensorAhwbTest : public mediapipe::GpuTestBase { TEST(TensorAhwbTest, TestCpuThenAHWB) {
public:
};
TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{ {
auto ptr = tensor.GetCpuWriteView().buffer<float>(); auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr); EXPECT_NE(ptr, nullptr);
} }
{ {
auto ahwb = tensor.GetAHardwareBufferReadView().handle(); auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(ahwb, nullptr); EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
} }
} }
TEST_F(TensorAhwbTest, TestAHWBThenCpu) { TEST(TensorAhwbTest, TestAHWBThenCpu) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{ {
auto ahwb = tensor.GetAHardwareBufferWriteView().handle(); auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(ahwb, nullptr); EXPECT_NE(view.handle(), nullptr);
view.SetWritingFinishedFD(-1, [](bool) { return true; });
} }
{ {
auto ptr = tensor.GetCpuReadView().buffer<float>(); auto ptr = tensor.GetCpuReadView().buffer<float>();
@ -36,21 +30,71 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
} }
} }
TEST_F(TensorAhwbTest, TestCpuThenGl) { TEST(TensorAhwbTest, TestAhwbAlignment) {
RunInGlContext([] { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5});
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); {
auto view = tensor.GetAHardwareBufferWriteView(16);
EXPECT_NE(view.handle(), nullptr);
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc;
AHardwareBuffer_describe(view.handle(), &desc);
// sizeof(float) * 5 = 20, the closest aligned to 16 size is 32.
EXPECT_EQ(desc.width, 32);
}
view.SetWritingFinishedFD(-1, [](bool) { return true; });
}
}
// Tensor::GetCpuView uses source location mechanism that gives source file name
// and line from where the method is called. The function is intended just to
// have two calls providing the same source file name and line.
auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); }
// The test checks the tracking mechanism: when a tensor's Cpu view is retrieved
// for the first time then the source location is attached to the tensor. If the
// Ahwb view is requested then from the tensor then the previously recorded Cpu
// view request source location is marked for using Ahwb storage.
// When a Cpu view with the same source location (but for the newly allocated
// tensor) is requested and the location is marked to use Ahwb storage then the
// Ahwb storage is allocated for the CpuView.
TEST(TensorAhwbTest, TestTrackingAhwb) {
// Create first tensor and request Cpu and then Ahwb view to mark the source
// location for Ahwb storage.
{
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
{ {
auto ptr = tensor.GetCpuWriteView().buffer<float>(); auto view = GetCpuView(tensor);
EXPECT_NE(ptr, nullptr); EXPECT_NE(view.buffer<float>(), nullptr);
} }
{ {
auto ssbo = tensor.GetOpenGlBufferReadView().name(); // Align size of the Ahwb by multiple of 16.
EXPECT_GT(ssbo, 0); auto view = tensor.GetAHardwareBufferWriteView(16);
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
} }
}); }
{
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
{
// The second tensor uses the same Cpu view source location so Ahwb
// storage is allocated internally.
auto view = GetCpuView(tensor);
EXPECT_NE(view.buffer<float>(), nullptr);
}
{
// Check the Ahwb size to be aligned to multiple of 16. The alignment is
// stored by previous requesting of the Ahwb view.
auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr);
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc;
AHardwareBuffer_describe(view.handle(), &desc);
// sizeof(float) * 9 = 36. The closest aligned size is 48.
EXPECT_EQ(desc.width, 48);
}
view.SetReadingFinishedFunc([](bool) { return true; });
}
}
} }
} // namespace mediapipe } // namespace mediapipe
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_TENSOR_USE_AHWB

View File

@ -194,6 +194,7 @@ void GraphProfiler::Initialize(
"Calculator \"$0\" has already been added.", node_name); "Calculator \"$0\" has already been added.", node_name);
} }
profile_builder_ = std::make_unique<GraphProfileBuilder>(this); profile_builder_ = std::make_unique<GraphProfileBuilder>(this);
graph_id_ = ++next_instance_id_;
is_initialized_ = true; is_initialized_ = true;
} }

View File

@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
return validated_graph_; return validated_graph_;
} }
// Gets a numerical identifier for this GraphProfiler object.
uint64_t GetGraphId() { return graph_id_; }
private: private:
// This can be used to add packet info for the input streams to the graph. // This can be used to add packet info for the input streams to the graph.
// It treats the stream defined by |stream_name| as a stream produced by a // It treats the stream defined by |stream_name| as a stream produced by a
@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
class GraphProfileBuilder; class GraphProfileBuilder;
std::unique_ptr<GraphProfileBuilder> profile_builder_; std::unique_ptr<GraphProfileBuilder> profile_builder_;
// The globally incrementing identifier for all graphs in a process.
static inline std::atomic_int next_instance_id_ = 0;
// A unique identifier for this object. Only unique within a process.
uint64_t graph_id_;
// For testing. // For testing.
friend GraphProfilerTestPeer; friend GraphProfilerTestPeer;
}; };

View File

@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) {
"Cannot initialize .* multiple times."); "Cannot initialize .* multiple times.");
} }
// Tests that graph identifiers are not reused, even after destruction.
TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) {
auto raw_graph_config = R"(
profiler_config {
enable_profiler: true
}
input_stream: "input_stream"
node {
calculator: "DummyTestCalculator"
input_stream: "input_stream"
})";
const int n_iterations = 100;
absl::flat_hash_set<int> seen_ids;
for (int i = 0; i < n_iterations; ++i) {
std::shared_ptr<ProfilingContext> profiler =
std::make_shared<ProfilingContext>();
auto graph_config = CreateGraphConfig(raw_graph_config);
mediapipe::ValidatedGraphConfig validated_graph;
QCHECK_OK(validated_graph.Initialize(graph_config));
profiler->Initialize(validated_graph);
int id = profiler->GetGraphId();
ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id)));
seen_ids.insert(id);
}
}
// Tests that Pause(), Resume(), and Reset() works. // Tests that Pause(), Resume(), and Reset() works.
TEST_F(GraphProfilerTestPeer, PauseResumeReset) { TEST_F(GraphProfilerTestPeer, PauseResumeReset) {
InitializeProfilerWithGraphConfig(R"( InitializeProfilerWithGraphConfig(R"(

View File

@ -74,42 +74,51 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView(
static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
const GlTextureView& view) { const GlTextureView& view) {
CHECK(pixel_buffer); CHECK(pixel_buffer);
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); auto ctx = GlContext::GetCurrent().get();
CHECK(err == kCVReturnSuccess) if (!ctx) ctx = view.gl_context();
<< "CVPixelBufferLockBaseAddress failed: " << err; ctx->Run([pixel_buffer, &view, ctx] {
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); CHECK(err == kCVReturnSuccess)
uint8_t* pixel_ptr = << "CVPixelBufferLockBaseAddress failed: " << err;
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer)); OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
if (pixel_format == kCVPixelFormatType_32BGRA) { size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
// TODO: restore previous framebuffer? Move this to helper so we uint8_t* pixel_ptr =
// can use BindFramebuffer? static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
glViewport(0, 0, view.width(), view.height()); if (pixel_format == kCVPixelFormatType_32BGRA) {
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx));
view.name(), 0); glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
view.target(), view.name(), 0);
size_t contiguous_bytes_per_row = view.width() * 4; size_t contiguous_bytes_per_row = view.width() * 4;
if (bytes_per_row == contiguous_bytes_per_row) { if (bytes_per_row == contiguous_bytes_per_row) {
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
pixel_ptr); GL_UNSIGNED_BYTE, pixel_ptr);
} else { } else {
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row * // TODO: use GL_PACK settings for row length. We can expect
view.height()); // GLES 3.0 on iOS now.
uint8_t* temp_ptr = contiguous_buffer.data(); std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, view.height());
temp_ptr); uint8_t* temp_ptr = contiguous_buffer.data();
for (int i = 0; i < view.height(); ++i) { glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); GL_UNSIGNED_BYTE, temp_ptr);
temp_ptr += contiguous_bytes_per_row; for (int i = 0; i < view.height(); ++i) {
pixel_ptr += bytes_per_row; memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
temp_ptr += contiguous_bytes_per_row;
pixel_ptr += bytes_per_row;
}
} }
// TODO: restore previous framebuffer?
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
view.target(), 0, 0);
glBindFramebuffer(GL_FRAMEBUFFER, 0);
} else {
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
} }
} else { err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
LOG(ERROR) << "unsupported pixel format: " << pixel_format; CHECK(err == kCVReturnSuccess)
} << "CVPixelBufferUnlockBaseAddress failed: " << err;
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); });
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
} }
#endif // TARGET_IPHONE_SIMULATOR #endif // TARGET_IPHONE_SIMULATOR

View File

@ -71,9 +71,12 @@ class TextClassifierTest(tf.test.TestCase):
self.assertTrue(os.path.exists(output_metadata_file)) self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
filecmp.clear_cache()
self.assertTrue( self.assertTrue(
filecmp.cmp(output_metadata_file, filecmp.cmp(
self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
shallow=False))
def test_create_and_train_bert(self): def test_create_and_train_bert(self):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()

View File

@ -135,7 +135,10 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
self.assertTrue(os.path.exists(output_metadata_file)) self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) filecmp.clear_cache()
self.assertTrue(
filecmp.cmp(
output_metadata_file, expected_metadata_file, shallow=False))
def test_continual_training_by_loading_checkpoint(self): def test_continual_training_by_loading_checkpoint(self):
mock_stdout = io.StringIO() mock_stdout = io.StringIO()

View File

@ -230,16 +230,17 @@ if ([wrapper.delegate
} }
- (absl::Status)performStart { - (absl::Status)performStart {
absl::Status status = _graph->Initialize(_config); absl::Status status;
if (!status.ok()) {
return status;
}
for (const auto& service_packet : _servicePackets) { for (const auto& service_packet : _servicePackets) {
status = _graph->SetServicePacket(*service_packet.first, service_packet.second); status = _graph->SetServicePacket(*service_packet.first, service_packet.second);
if (!status.ok()) { if (!status.ok()) {
return status; return status;
} }
} }
status = _graph->Initialize(_config);
if (!status.ok()) {
return status;
}
status = _graph->StartRun(_inputSidePackets, _streamHeaders); status = _graph->StartRun(_inputSidePackets, _streamHeaders);
if (!status.ok()) { if (!status.ok()) {
return status; return status;

View File

@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
auto custom_gestures_classifier_options_proto = auto custom_gestures_classifier_options_proto =
std::make_unique<components::processors::proto::ClassifierOptions>( std::make_unique<components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto( components::processors::ConvertClassifierOptionsToProto(
&(options->canned_gestures_classifier_options))); &(options->custom_gestures_classifier_options)));
hand_gesture_recognizer_graph_options hand_gesture_recognizer_graph_options
->mutable_custom_gesture_classifier_graph_options() ->mutable_custom_gesture_classifier_graph_options()
->mutable_classifier_options() ->mutable_classifier_options()
->Swap(canned_gestures_classifier_options_proto.get()); ->Swap(custom_gestures_classifier_options_proto.get());
return options_proto; return options_proto;
} }

View File

@ -38,4 +38,3 @@ objc_library(
"-std=c++17", "-std=c++17",
], ],
) )

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@ -56,6 +57,7 @@ extern NSString *const MPPTasksErrorDomain;
* @param status absl::Status. * @param status absl::Status.
* @param error Pointer to the memory location where the created error should be saved. If `nil`, * @param error Pointer to the memory location where the created error should be saved. If `nil`,
* no error will be saved. * no error will be saved.
* @return YES when there is no error, NO otherwise.
*/ */
+ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error; + (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error;

View File

@ -20,7 +20,6 @@
#include "absl/status/status.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/cord.h" // from @com_google_absl #include "absl/strings/cord.h" // from @com_google_absl
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
/** Error domain of MediaPipe task library errors. */ /** Error domain of MediaPipe task library errors. */
@ -96,8 +95,8 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
// appropriate MPPTasksErrorCode in default cases. Note: // appropriate MPPTasksErrorCode in default cases. Note:
// The mapping to absl::Status::code() is done to generate a more specific error code than // The mapping to absl::Status::code() is done to generate a more specific error code than
// MPPTasksErrorCodeError in cases when the payload can't be mapped to // MPPTasksErrorCodeError in cases when the payload can't be mapped to
// MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn returned // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
// without modification by Mediapipe cc library methods. // returned without modification by Mediapipe cc library methods.
if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) {
switch (status.code()) { switch (status.code()) {
case absl::StatusCode::kInternal: case absl::StatusCode::kInternal:

View File

@ -13,13 +13,14 @@
// limitations under the License. // limitations under the License.
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#include <string> #include <string>
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@interface NSString (Helpers) @interface NSString (Helpers)
@property(readonly) std::string cppString; @property(readonly, nonatomic) std::string cppString;
+ (NSString *)stringWithCppString:(std::string)text; + (NSString *)stringWithCppString:(std::string)text;

View File

@ -21,4 +21,3 @@ objc_library(
srcs = ["sources/MPPClassifierOptions.m"], srcs = ["sources/MPPClassifierOptions.m"],
hdrs = ["sources/MPPClassifierOptions.h"], hdrs = ["sources/MPPClassifierOptions.h"],
) )

View File

@ -22,29 +22,34 @@ NS_ASSUME_NONNULL_BEGIN
NS_SWIFT_NAME(ClassifierOptions) NS_SWIFT_NAME(ClassifierOptions)
@interface MPPClassifierOptions : NSObject <NSCopying> @interface MPPClassifierOptions : NSObject <NSCopying>
/** The locale to use for display names specified through the TFLite Model /**
* The locale to use for display names specified through the TFLite Model
* Metadata, if any. Defaults to English. * Metadata, if any. Defaults to English.
*/ */
@property(nonatomic, copy) NSString *displayNamesLocale; @property(nonatomic, copy) NSString *displayNamesLocale;
/** The maximum number of top-scored classification results to return. If < 0, /**
* The maximum number of top-scored classification results to return. If < 0,
* all available results will be returned. If 0, an invalid argument error is * all available results will be returned. If 0, an invalid argument error is
* returned. * returned.
*/ */
@property(nonatomic) NSInteger maxResults; @property(nonatomic) NSInteger maxResults;
/** Score threshold to override the one provided in the model metadata (if any). /**
* Score threshold to override the one provided in the model metadata (if any).
* Results below this value are rejected. * Results below this value are rejected.
*/ */
@property(nonatomic) float scoreThreshold; @property(nonatomic) float scoreThreshold;
/** The allowlist of category names. If non-empty, detection results whose /**
* The allowlist of category names. If non-empty, detection results whose
* category name is not in this set will be filtered out. Duplicate or unknown * category name is not in this set will be filtered out. Duplicate or unknown
* category names are ignored. Mutually exclusive with categoryDenylist. * category names are ignored. Mutually exclusive with categoryDenylist.
*/ */
@property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist; @property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
/** The denylist of category names. If non-empty, detection results whose /**
* The denylist of category names. If non-empty, detection results whose
* category name is in this set will be filtered out. Duplicate or unknown * category name is in this set will be filtered out. Duplicate or unknown
* category names are ignored. Mutually exclusive with categoryAllowlist. * category names are ignored. Mutually exclusive with categoryAllowlist.
*/ */

View File

@ -19,8 +19,8 @@
- (instancetype)init { - (instancetype)init {
self = [super init]; self = [super init];
if (self) { if (self) {
self.maxResults = -1; _maxResults = -1;
self.scoreThreshold = 0; _scoreThreshold = 0;
} }
return self; return self;
} }

View File

@ -21,9 +21,8 @@ objc_library(
srcs = ["sources/MPPClassifierOptions+Helpers.mm"], srcs = ["sources/MPPClassifierOptions+Helpers.mm"],
hdrs = ["sources/MPPClassifierOptions+Helpers.h"], hdrs = ["sources/MPPClassifierOptions+Helpers.h"],
deps = [ deps = [
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers", "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
] ],
) )

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN

View File

@ -29,7 +29,6 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
} }
classifierOptionsProto->set_max_results((int)self.maxResults); classifierOptionsProto->set_max_results((int)self.maxResults);
classifierOptionsProto->set_score_threshold(self.scoreThreshold); classifierOptionsProto->set_score_threshold(self.scoreThreshold);
for (NSString *category in self.categoryAllowlist) { for (NSString *category in self.categoryAllowlist) {

View File

@ -54,14 +54,14 @@ objc_library(
"-std=c++17", "-std=c++17",
], ],
deps = [ deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
":MPPTaskOptions", ":MPPTaskOptions",
":MPPTaskOptionsProtocol", ":MPPTaskOptionsProtocol",
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers", "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/common:MPPCommon",
], ],
) )
@ -83,9 +83,13 @@ objc_library(
name = "MPPTaskRunner", name = "MPPTaskRunner",
srcs = ["sources/MPPTaskRunner.mm"], srcs = ["sources/MPPTaskRunner.mm"],
hdrs = ["sources/MPPTaskRunner.h"], hdrs = ["sources/MPPTaskRunner.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [ deps = [
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
], ],
) )

View File

@ -13,7 +13,9 @@
// limitations under the License. // limitations under the License.
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
@ -59,7 +61,7 @@ NS_ASSUME_NONNULL_BEGIN
/** /**
* Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance. * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
*/ */
- (mediapipe::CalculatorGraphConfig)generateGraphConfig; - (::mediapipe::CalculatorGraphConfig)generateGraphConfig;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -24,7 +24,6 @@
namespace { namespace {
using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
using Node = ::mediapipe::CalculatorGraphConfig::Node; using Node = ::mediapipe::CalculatorGraphConfig::Node;
using ::mediapipe::CalculatorOptions;
using ::mediapipe::FlowLimiterCalculatorOptions; using ::mediapipe::FlowLimiterCalculatorOptions;
using ::mediapipe::InputStreamInfo; using ::mediapipe::InputStreamInfo;
} // namespace } // namespace

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/calculator_options.pb.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@ -25,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN
/** /**
* Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto. * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
*/ */
- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; - (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto;
@end @end

View File

@ -22,7 +22,6 @@ NS_ASSUME_NONNULL_BEGIN
/** /**
* This class is used to create and call appropriate methods on the C++ Task Runner. * This class is used to create and call appropriate methods on the C++ Task Runner.
*/ */
@interface MPPTaskRunner : NSObject @interface MPPTaskRunner : NSObject
/** /**
@ -35,11 +34,10 @@ NS_ASSUME_NONNULL_BEGIN
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error NS_DESIGNATED_INITIALIZER; error:(NSError **)error NS_DESIGNATED_INITIALIZER;
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>) - (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
process:(const mediapipe::tasks::core::PacketMap &)packetMap (const mediapipe::tasks::core::PacketMap &)packetMap;
error:(NSError **)error;
- (void)close; - (absl::Status)close;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -17,7 +17,6 @@
namespace { namespace {
using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace } // namespace
@ -49,8 +48,8 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
return _cppTaskRunner->Process(packetMap); return _cppTaskRunner->Process(packetMap);
} }
- (void)close { - (absl::Status)close {
_cppTaskRunner->Close(); return _cppTaskRunner->Close();
} }
@end @end

View File

@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
* *
* @param options The options for the audio classifier. * @param options The options for the audio classifier.
*/ */
override async setOptions(options: AudioClassifierOptions): Promise<void> { override setOptions(options: AudioClassifierOptions): Promise<void> {
await super.setOptions(options);
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM);

View File

@ -79,7 +79,8 @@ describe('AudioClassifier', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
audioClassifier = new AudioClassifierFake(); audioClassifier = new AudioClassifierFake();
await audioClassifier.setOptions({}); // Initialize graph await audioClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
* *
* @param options The options for the audio embedder. * @param options The options for the audio embedder.
*/ */
override async setOptions(options: AudioEmbedderOptions): Promise<void> { override setOptions(options: AudioEmbedderOptions): Promise<void> {
await super.setOptions(options);
this.options.setEmbedderOptions(convertEmbedderOptionsToProto( this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions())); options, this.options.getEmbedderOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM);

View File

@ -70,7 +70,8 @@ describe('AudioEmbedder', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
audioEmbedder = new AudioEmbedderFake(); audioEmbedder = new AudioEmbedderFake();
await audioEmbedder.setOptions({}); // Initialize graph await audioEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', () => { it('initializes graph', () => {

View File

@ -103,29 +103,3 @@ jasmine_node_test(
name = "embedder_options_test", name = "embedder_options_test",
deps = [":embedder_options_test_lib"], deps = [":embedder_options_test_lib"],
) )
mediapipe_ts_library(
name = "base_options",
srcs = [
"base_options.ts",
],
deps = [
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core",
],
)
mediapipe_ts_library(
name = "base_options_test_lib",
testonly = True,
srcs = ["base_options.test.ts"],
deps = [":base_options"],
)
jasmine_node_test(
name = "base_options_test",
deps = [":base_options_test_lib"],
)

View File

@ -1,127 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
// Placeholder for internal dependency on trusted resource URL builder
import {convertBaseOptionsToProto} from './base_options';
describe('convertBaseOptionsToProto()', () => {
const mockBytes = new Uint8Array([0, 1, 2, 3]);
const mockBytesResult = {
modelAsset: {
fileContent: Buffer.from(mockBytes).toString('base64'),
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
},
};
let fetchSpy: jasmine.Spy;
beforeEach(() => {
fetchSpy = jasmine.createSpy().and.callFake(async url => {
expect(url).toEqual('foo');
return {
arrayBuffer: () => mockBytes.buffer,
} as unknown as Response;
});
global.fetch = fetchSpy;
});
it('verifies that at least one model asset option is provided', async () => {
await expectAsync(convertBaseOptionsToProto({}))
.toBeRejectedWithError(
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
});
it('verifies that no more than one model asset option is provided', async () => {
await expectAsync(convertBaseOptionsToProto({
modelAssetPath: `foo`,
modelAssetBuffer: new Uint8Array([])
}))
.toBeRejectedWithError(
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
});
it('downloads model', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetPath: `foo`,
});
expect(fetchSpy).toHaveBeenCalled();
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('does not download model when bytes are provided', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
});
expect(fetchSpy).not.toHaveBeenCalled();
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('can enable CPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'CPU',
});
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('can enable GPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
});
expect(baseOptionsProto.toObject()).toEqual({
...mockBytesResult,
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: 0,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: 2,
},
tflite: undefined,
},
});
});
it('can reset delegate', async () => {
let baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
});
// Clear backend
baseOptionsProto =
await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto);
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
});

View File

@ -1,80 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided
*/
export async function convertBaseOptionsToProto(
updatedOptions: BaseOptions,
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
const result =
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
await configureExternalFile(updatedOptions, result);
configureAcceleration(updatedOptions, result);
return result;
}
/**
* Configues the `externalFile` option and validates that a single model is
* provided.
*/
async function configureExternalFile(
options: BaseOptions, proto: BaseOptionsProto) {
const externalFile = proto.getModelAsset() || new ExternalFile();
proto.setModelAsset(externalFile);
if (options.modelAssetPath || options.modelAssetBuffer) {
if (options.modelAssetPath && options.modelAssetBuffer) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
}
let modelAssetBuffer = options.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(options.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
externalFile.setFileContent(modelAssetBuffer);
}
if (!externalFile.hasFileContent()) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
}
/** Configues the `acceleration` option. */
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
const acceleration = proto.getAcceleration() ?? new Acceleration();
if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
}
proto.setAcceleration(acceleration);
}

View File

@ -18,8 +18,10 @@ mediapipe_ts_library(
srcs = ["task_runner.ts"], srcs = ["task_runner.ts"],
deps = [ deps = [
":core", ":core",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
@ -53,6 +55,7 @@ mediapipe_ts_library(
"task_runner_test.ts", "task_runner_test.ts",
], ],
deps = [ deps = [
":core",
":task_runner", ":task_runner",
":task_runner_test_utils", ":task_runner_test_utils",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",

View File

@ -14,9 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@ -91,14 +93,52 @@ export abstract class TaskRunner {
this.graphRunner.registerModelResourcesGraphService(); this.graphRunner.registerModelResourcesGraphService();
} }
/** Configures the shared options of a MediaPipe Task. */ /** Configures the task with custom options. */
async setOptions(options: TaskRunnerOptions): Promise<void> { abstract setOptions(options: TaskRunnerOptions): Promise<void>;
if (options.baseOptions) {
this.baseOptions = await convertBaseOptionsToProto( /**
options.baseOptions, this.baseOptions); * Applies the current set of options, including any base options that have
* not been processed by the task implementation. The options are applied
* synchronously unless a `modelAssetPath` is provided. This ensures that
* for most use cases options are applied directly and immediately affect
* the next inference.
*/
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
const baseOptions: BaseOptions = options.baseOptions || {};
// Validate that exactly one model is configured
if (options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
this.setAcceleration(baseOptions);
if (baseOptions.modelAssetPath) {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => response.arrayBuffer())
.then(buffer => {
this.setExternalFile(new Uint8Array(buffer));
this.refreshGraph();
});
} else {
// Apply the setting synchronously.
this.setExternalFile(baseOptions.modelAssetBuffer);
this.refreshGraph();
return Promise.resolve();
} }
} }
/** Appliest the current options to the MediaPipe graph. */
protected abstract refreshGraph(): void;
/** /**
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
* over the video stream. Will replace the previously running MediaPipe graph, * over the video stream. Will replace the previously running MediaPipe graph,
@ -140,6 +180,27 @@ export abstract class TaskRunner {
} }
this.processingErrors = []; this.processingErrors = [];
} }
/** Configures the `externalFile` option */
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
if (modelAssetBuffer) {
externalFile.setFileContent(modelAssetBuffer);
}
this.baseOptions.setModelAsset(externalFile);
}
/** Configures the `acceleration` option. */
private setAcceleration(options: BaseOptions) {
const acceleration =
this.baseOptions.getAcceleration() ?? new Acceleration();
if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
}
this.baseOptions.setAcceleration(acceleration);
}
} }

View File

@ -15,18 +15,22 @@
*/ */
import 'jasmine'; import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
import {ErrorListener} from '../../../web/graph_runner/graph_runner'; import {ErrorListener} from '../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource URL builder
import {GraphRunnerImageLib} from './task_runner'; import {GraphRunnerImageLib} from './task_runner';
import {TaskRunnerOptions} from './task_runner_options.d';
class TaskRunnerFake extends TaskRunner { class TaskRunnerFake extends TaskRunner {
protected baseOptions = new BaseOptionsProto();
private errorListener: ErrorListener|undefined; private errorListener: ErrorListener|undefined;
private errors: string[] = []; private errors: string[] = [];
baseOptions = new BaseOptionsProto();
static createFake(): TaskRunnerFake { static createFake(): TaskRunnerFake {
const wasmModule = createSpyWasmModule(); const wasmModule = createSpyWasmModule();
return new TaskRunnerFake(wasmModule); return new TaskRunnerFake(wasmModule);
@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
super.finishProcessing(); super.finishProcessing();
} }
override refreshGraph(): void {}
override setGraph(graphData: Uint8Array, isBinary: boolean): void { override setGraph(graphData: Uint8Array, isBinary: boolean): void {
super.setGraph(graphData, isBinary); super.setGraph(graphData, isBinary);
} }
setOptions(options: TaskRunnerOptions): Promise<void> {
return this.applyOptions(options);
}
private throwErrors(): void { private throwErrors(): void {
expect(this.errorListener).toBeDefined(); expect(this.errorListener).toBeDefined();
for (const error of this.errors) { for (const error of this.errors) {
@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner {
} }
describe('TaskRunner', () => { describe('TaskRunner', () => {
const mockBytes = new Uint8Array([0, 1, 2, 3]);
const mockBytesResult = {
modelAsset: {
fileContent: Buffer.from(mockBytes).toString('base64'),
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
},
};
let fetchSpy: jasmine.Spy;
let taskRunner: TaskRunnerFake;
beforeEach(() => {
fetchSpy = jasmine.createSpy().and.callFake(async url => {
expect(url).toEqual('foo');
return {
arrayBuffer: () => mockBytes.buffer,
} as unknown as Response;
});
global.fetch = fetchSpy;
taskRunner = TaskRunnerFake.createFake();
});
it('handles errors during graph update', () => { it('handles errors during graph update', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.enqueueError('Test error'); taskRunner.enqueueError('Test error');
expect(() => { expect(() => {
@ -85,7 +125,6 @@ describe('TaskRunner', () => {
}); });
it('handles errors during graph execution', () => { it('handles errors during graph execution', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
taskRunner.enqueueError('Test error'); taskRunner.enqueueError('Test error');
@ -96,7 +135,6 @@ describe('TaskRunner', () => {
}); });
it('can handle multiple errors', () => { it('can handle multiple errors', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.enqueueError('Test error 1'); taskRunner.enqueueError('Test error 1');
taskRunner.enqueueError('Test error 2'); taskRunner.enqueueError('Test error 2');
@ -104,4 +142,106 @@ describe('TaskRunner', () => {
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
}).toThrowError(/Test error 1, Test error 2/); }).toThrowError(/Test error 1, Test error 2/);
}); });
it('verifies that at least one model asset option is provided', () => {
expect(() => {
taskRunner.setOptions({});
})
.toThrowError(
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
});
it('verifies that no more than one model asset option is provided', () => {
expect(() => {
taskRunner.setOptions({
baseOptions: {
modelAssetPath: `foo`,
modelAssetBuffer: new Uint8Array([])
}
});
})
.toThrowError(
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
});
it('doesn\'t require model once it is configured', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
expect(() => {
taskRunner.setOptions({});
}).not.toThrowError();
});
it('downloads model', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetPath: `foo`}});
expect(fetchSpy).toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('does not download model when bytes are provided', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
expect(fetchSpy).not.toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('changes model synchronously when bytes are provided', () => {
const resolvedPromise = taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
// Check that the change has been applied even though we do not await the
// above Promise
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
return resolvedPromise;
});
it('can enable CPU delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'CPU',
}
});
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('can enable GPU delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
}
});
expect(taskRunner.baseOptions.toObject()).toEqual({
...mockBytesResult,
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: 0,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: 2,
},
tflite: undefined,
},
});
});
it('can reset delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
}
});
// Clear backend
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
}); });

View File

@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule {
* Sets up our equality testing to use a custom float equality checking function * Sets up our equality testing to use a custom float equality checking function
* to avoid incorrect test results due to minor floating point inaccuracies. * to avoid incorrect test results due to minor floating point inaccuracies.
*/ */
export function addJasmineCustomFloatEqualityTester() { export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) {
jasmine.addCustomEqualityTester((a, b) => { // Custom float equality jasmine.addCustomEqualityTester((a, b) => { // Custom float equality
if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) { if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
return Math.abs(a - b) < 5e-8; return Math.abs(a - b) < tolerance;
} }
return; return;
}); });

View File

@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
* *
* @param options The options for the text classifier. * @param options The options for the text classifier.
*/ */
override async setOptions(options: TextClassifierOptions): Promise<void> { override setOptions(options: TextClassifierOptions): Promise<void> {
await super.setOptions(options);
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
protected override get baseOptions(): BaseOptionsProto { protected override get baseOptions(): BaseOptionsProto {
@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

View File

@ -56,7 +56,8 @@ describe('TextClassifier', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
textClassifier = new TextClassifierFake(); textClassifier = new TextClassifierFake();
await textClassifier.setOptions({}); // Initialize graph await textClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
* *
* @param options The options for the text embedder. * @param options The options for the text embedder.
*/ */
override async setOptions(options: TextEmbedderOptions): Promise<void> { override setOptions(options: TextEmbedderOptions): Promise<void> {
await super.setOptions(options);
this.options.setEmbedderOptions(convertEmbedderOptionsToProto( this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions())); options, this.options.getEmbedderOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
protected override get baseOptions(): BaseOptionsProto { protected override get baseOptions(): BaseOptionsProto {
@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(EMBEDDINGS_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM);

View File

@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
textEmbedder = new TextEmbedderFake(); textEmbedder = new TextEmbedderFake();
await textEmbedder.setOptions({}); // Initialize graph await textEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -29,6 +29,7 @@ mediapipe_ts_library(
testonly = True, testonly = True,
srcs = ["vision_task_runner.test.ts"], srcs = ["vision_task_runner.test.ts"],
deps = [ deps = [
":vision_task_options",
":vision_task_runner", ":vision_task_runner",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",

View File

@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner';
import {VisionTaskOptions} from './vision_task_options';
import {VisionTaskRunner} from './vision_task_runner'; import {VisionTaskRunner} from './vision_task_runner';
class VisionTaskRunnerFake extends VisionTaskRunner<void> { class VisionTaskRunnerFake extends VisionTaskRunner<void> {
@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
protected override process(): void {} protected override process(): void {}
protected override refreshGraph(): void {}
override setOptions(options: VisionTaskOptions): Promise<void> {
return this.applyOptions(options);
}
override processImageData(image: ImageSource): void { override processImageData(image: ImageSource): void {
super.processImageData(image); super.processImageData(image);
} }
@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
} }
describe('VisionTaskRunner', () => { describe('VisionTaskRunner', () => {
const streamMode = {
modelAsset: undefined,
useStreamMode: true,
acceleration: undefined,
};
const imageMode = {
modelAsset: undefined,
useStreamMode: false,
acceleration: undefined,
};
let visionTaskRunner: VisionTaskRunnerFake; let visionTaskRunner: VisionTaskRunnerFake;
beforeEach(() => { beforeEach(async () => {
visionTaskRunner = new VisionTaskRunnerFake(); visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('can enable image mode', async () => { it('can enable image mode', async () => {
await visionTaskRunner.setOptions({runningMode: 'image'}); await visionTaskRunner.setOptions({runningMode: 'image'});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false}));
}); });
it('can enable video mode', async () => { it('can enable video mode', async () => {
await visionTaskRunner.setOptions({runningMode: 'video'}); await visionTaskRunner.setOptions({runningMode: 'video'});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: true}));
}); });
it('can clear running mode', async () => { it('can clear running mode', async () => {
@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {
// Clear running mode // Clear running mode
await visionTaskRunner.setOptions({runningMode: undefined}); await visionTaskRunner.setOptions({runningMode: undefined});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false}));
}); });
it('cannot process images with video mode', async () => { it('cannot process images with video mode', async () => {

View File

@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
/** Base class for all MediaPipe Vision Tasks. */ /** Base class for all MediaPipe Vision Tasks. */
export abstract class VisionTaskRunner<T> extends TaskRunner { export abstract class VisionTaskRunner<T> extends TaskRunner {
/** Configures the shared options of a vision task. */ /** Configures the shared options of a vision task. */
override async setOptions(options: VisionTaskOptions): Promise<void> { override applyOptions(options: VisionTaskOptions): Promise<void> {
await super.setOptions(options);
if ('runningMode' in options) { if ('runningMode' in options) {
const useStreamMode = const useStreamMode =
!!options.runningMode && options.runningMode !== 'image'; !!options.runningMode && options.runningMode !== 'image';
this.baseOptions.setUseStreamMode(useStreamMode); this.baseOptions.setUseStreamMode(useStreamMode);
} }
return super.applyOptions(options);
} }
/** Sends an image packet to the graph and awaits results. */ /** Sends an image packet to the graph and awaits results. */

View File

@ -169,9 +169,7 @@ export class GestureRecognizer extends
* *
* @param options The options for the gesture recognizer. * @param options The options for the gesture recognizer.
*/ */
override async setOptions(options: GestureRecognizerOptions): Promise<void> { override setOptions(options: GestureRecognizerOptions): Promise<void> {
await super.setOptions(options);
if ('numHands' in options) { if ('numHands' in options) {
this.handDetectorGraphOptions.setNumHands( this.handDetectorGraphOptions.setNumHands(
options.numHands ?? DEFAULT_NUM_HANDS); options.numHands ?? DEFAULT_NUM_HANDS);
@ -221,7 +219,7 @@ export class GestureRecognizer extends
?.clearClassifierOptions(); ?.clearClassifierOptions();
} }
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -265,12 +263,22 @@ export class GestureRecognizer extends
NORM_RECT_STREAM, timestamp); NORM_RECT_STREAM, timestamp);
this.finishProcessing(); this.finishProcessing();
return { if (this.gestures.length === 0) {
gestures: this.gestures, // If no gestures are detected in the image, just return an empty list
landmarks: this.landmarks, return {
worldLandmarks: this.worldLandmarks, gestures: [],
handednesses: this.handednesses landmarks: [],
}; worldLandmarks: [],
handednesses: [],
};
} else {
return {
gestures: this.gestures,
landmarks: this.landmarks,
worldLandmarks: this.worldLandmarks,
handednesses: this.handednesses
};
}
} }
/** Sets the default values for the graph. */ /** Sets the default values for the graph. */
@ -285,15 +293,19 @@ export class GestureRecognizer extends
} }
/** Converts the proto data to a Category[][] structure. */ /** Converts the proto data to a Category[][] structure. */
private toJsCategories(data: Uint8Array[]): Category[][] { private toJsCategories(data: Uint8Array[], populateIndex = true):
Category[][] {
const result: Category[][] = []; const result: Category[][] = [];
for (const binaryProto of data) { for (const binaryProto of data) {
const inputList = ClassificationList.deserializeBinary(binaryProto); const inputList = ClassificationList.deserializeBinary(binaryProto);
const outputList: Category[] = []; const outputList: Category[] = [];
for (const classification of inputList.getClassificationList()) { for (const classification of inputList.getClassificationList()) {
const index = populateIndex && classification.hasIndex() ?
classification.getIndex()! :
DEFAULT_CATEGORY_INDEX;
outputList.push({ outputList.push({
score: classification.getScore() ?? 0, score: classification.getScore() ?? 0,
index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, index,
categoryName: classification.getLabel() ?? '', categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '', displayName: classification.getDisplayName() ?? '',
}); });
@ -342,7 +354,7 @@ export class GestureRecognizer extends
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM);
@ -377,7 +389,10 @@ export class GestureRecognizer extends
}); });
this.graphRunner.attachProtoVectorListener( this.graphRunner.attachProtoVectorListener(
HAND_GESTURES_STREAM, binaryProto => { HAND_GESTURES_STREAM, binaryProto => {
this.gestures.push(...this.toJsCategories(binaryProto)); // Gesture index is not used, because the final gesture result comes
// from multiple classifiers.
this.gestures.push(
...this.toJsCategories(binaryProto, /* populateIndex= */ false));
}); });
this.graphRunner.attachProtoVectorListener( this.graphRunner.attachProtoVectorListener(
HANDEDNESS_STREAM, binaryProto => { HANDEDNESS_STREAM, binaryProto => {

View File

@ -17,6 +17,8 @@
import {Category} from '../../../../tasks/web/components/containers/category'; import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
export {Category, Landmark, NormalizedLandmark};
/** /**
* Represents the gesture recognition results generated by `GestureRecognizer`. * Represents the gesture recognition results generated by `GestureRecognizer`.
*/ */
@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult {
/** Handedness of detected hands. */ /** Handedness of detected hands. */
handednesses: Category[][]; handednesses: Category[][];
/** Recognized hand gestures of detected hands */ /**
* Recognized hand gestures of detected hands. Note that the index of the
* gesture is always -1, because the raw indices from multiple gesture
* classifiers cannot consolidate to a meaningful index.
*/
gestures: Category[][]; gestures: Category[][];
} }

View File

@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
gestureRecognizer = new GestureRecognizerFake(); gestureRecognizer = new GestureRecognizerFake();
await gestureRecognizer.setOptions({}); // Initialize graph await gestureRecognizer.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {
@ -271,7 +272,7 @@ describe('GestureRecognizer', () => {
expect(gestures).toEqual({ expect(gestures).toEqual({
'gestures': [[{ 'gestures': [[{
'score': 0.2, 'score': 0.2,
'index': 2, 'index': -1,
'categoryName': 'gesture_label', 'categoryName': 'gesture_label',
'displayName': 'gesture_display_name' 'displayName': 'gesture_display_name'
}]], }]],
@ -304,4 +305,25 @@ describe('GestureRecognizer', () => {
// gestures. // gestures.
expect(gestures2).toEqual(gestures1); expect(gestures2).toEqual(gestures1);
}); });
it('returns empty results when no gestures are detected', async () => {
// Pass the test data to our listener
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(gestureRecognizer);
gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
gestureRecognizer.listeners.get('world_hand_landmarks')!
(createWorldLandmarks());
gestureRecognizer.listeners.get('handedness')!(createHandednesses());
gestureRecognizer.listeners.get('hand_gestures')!([]);
});
// Invoke the gesture recognizer
const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
expect(gestures).toEqual({
'gestures': [],
'landmarks': [],
'worldLandmarks': [],
'handednesses': []
});
});
}); });

View File

@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
* *
* @param options The options for the hand landmarker. * @param options The options for the hand landmarker.
*/ */
override async setOptions(options: HandLandmarkerOptions): Promise<void> { override setOptions(options: HandLandmarkerOptions): Promise<void> {
await super.setOptions(options);
// Configure hand detector options. // Configure hand detector options.
if ('numHands' in options) { if ('numHands' in options) {
this.handDetectorGraphOptions.setNumHands( this.handDetectorGraphOptions.setNumHands(
@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
} }
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM);

View File

@ -17,6 +17,8 @@
import {Category} from '../../../../tasks/web/components/containers/category'; import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
export {Landmark, NormalizedLandmark, Category};
/** /**
* Represents the hand landmarks deection results generated by `HandLandmarker`. * Represents the hand landmarks deection results generated by `HandLandmarker`.
*/ */

View File

@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
handLandmarker = new HandLandmarkerFake(); handLandmarker = new HandLandmarkerFake();
await handLandmarker.setOptions({}); // Initialize graph await handLandmarker.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
* *
* @param options The options for the image classifier. * @param options The options for the image classifier.
*/ */
override async setOptions(options: ImageClassifierOptions): Promise<void> { override setOptions(options: ImageClassifierOptions): Promise<void> {
await super.setOptions(options);
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

View File

@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
imageClassifier = new ImageClassifierFake(); imageClassifier = new ImageClassifierFake();
await imageClassifier.setOptions({}); // Initialize graph await imageClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
* *
* @param options The options for the image embedder. * @param options The options for the image embedder.
*/ */
override async setOptions(options: ImageEmbedderOptions): Promise<void> { override setOptions(options: ImageEmbedderOptions): Promise<void> {
await super.setOptions(options);
this.options.setEmbedderOptions(convertEmbedderOptionsToProto( this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions())); options, this.options.getEmbedderOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(EMBEDDINGS_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM);

View File

@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
imageEmbedder = new ImageEmbedderFake(); imageEmbedder = new ImageEmbedderFake();
await imageEmbedder.setOptions({}); // Initialize graph await imageEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
* *
* @param options The options for the object detector. * @param options The options for the object detector.
*/ */
override async setOptions(options: ObjectDetectorOptions): Promise<void> { override setOptions(options: ObjectDetectorOptions): Promise<void> {
await super.setOptions(options);
// Note that we have to support both JSPB and ProtobufJS, hence we // Note that we have to support both JSPB and ProtobufJS, hence we
// have to expliclity clear the values instead of setting them to // have to expliclity clear the values instead of setting them to
// `undefined`. // `undefined`.
@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
this.options.clearCategoryDenylistList(); this.options.clearCategoryDenylistList();
} }
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(DETECTIONS_STREAM); graphConfig.addOutputStream(DETECTIONS_STREAM);

View File

@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
objectDetector = new ObjectDetectorFake(); objectDetector = new ObjectDetectorFake();
await objectDetector.setOptions({}); // Initialize graph await objectDetector.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {