Merge branch 'master' into ios-task

parent 7e0fec7c28
commit 7ce21038bb

.github/bot_config.yml (vendored): 3 changes
@@ -15,4 +15,5 @@
 
 # A list of assignees
 assignees:
-   - sureshdagooglecom
+   - kuaashish
+   - ayushgdev
@@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) {
 }
 
 // If options.convert_signature_to_tags() is set, will convert letters to
-// uppercase and replace /'s and -'s with _'s. This enables the standard
+// uppercase and replace /, -, . and :'s with _'s. This enables the standard
 // SavedModel classification, regression, and prediction signatures to be used
 // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
 // patterns.
@@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag(
     output.resize(name.length());
     std::transform(name.begin(), name.end(), output.begin(),
                    [](unsigned char c) { return std::toupper(c); });
-    output = absl::StrReplaceAll(output, {{"/", "_"}});
-    output = absl::StrReplaceAll(output, {{"-", "_"}});
-    output = absl::StrReplaceAll(output, {{".", "_"}});
+    output = absl::StrReplaceAll(
+        output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
     LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
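Note: the three chained StrReplaceAll calls collapse into one because absl::StrReplaceAll accepts a list of {from, to} pairs and applies them in a single pass over the string. A minimal standalone sketch of the consolidated call (illustrative only, not the calculator code itself; the sample signature string is made up):

#include <iostream>
#include <string>

#include "absl/strings/str_replace.h"

int main() {
  // All four separators are rewritten in a single pass.
  std::string tag = absl::StrReplaceAll(
      "serving/default-v1.0:predict",
      {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
  std::cout << tag << std::endl;  // prints serving_default_v1_0_predict
  return 0;
}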
@@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions {
   // The name of the generic signature to load into the mapping from tags to
   // tensor names.
   optional string signature_name = 2 [default = "serving_default"];
-  // Whether to convert the signature keys to uppercase as well as switch /'s
-  // and -'s to _'s, which enables common signatures to be used as Tags.
+  // Whether to convert the signature keys to uppercase as well as switch
+  // /, -, . and :'s to _'s, which enables common signatures to be used as Tags.
   optional bool convert_signature_to_tags = 3 [default = true];
   // If true, saved_model_path can have multiple exported models in
   // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
@@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) {
 }
 
 // If options.convert_signature_to_tags() is set, will convert letters to
-// uppercase and replace /'s and -'s with _'s. This enables the standard
+// uppercase and replace /, -, and .'s with _'s. This enables the standard
 // SavedModel classification, regression, and prediction signatures to be used
 // as uppercase INPUTS and OUTPUTS tags for streams and supports other common
 // patterns.
@@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag(
     output.resize(name.length());
     std::transform(name.begin(), name.end(), output.begin(),
                    [](unsigned char c) { return std::toupper(c); });
-    output = absl::StrReplaceAll(output, {{"/", "_"}});
-    output = absl::StrReplaceAll(output, {{"-", "_"}});
-    output = absl::StrReplaceAll(output, {{".", "_"}});
+    output = absl::StrReplaceAll(
+        output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
     LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
@@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions {
   // The name of the generic signature to load into the mapping from tags to
   // tensor names.
   optional string signature_name = 2 [default = "serving_default"];
-  // Whether to convert the signature keys to uppercase as well as switch /'s
-  // and -'s to _'s, which enables common signatures to be used as Tags.
+  // Whether to convert the signature keys to uppercase, as well as switch /'s
+  // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags.
   optional bool convert_signature_to_tags = 3 [default = true];
   // If true, saved_model_path can have multiple exported models in
   // subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
@@ -30,6 +30,10 @@ proto_library(
 
 java_lite_proto_library(
     name = "autoflip_messages_java_proto_lite",
+    visibility = [
+        "//java/com/google/android/apps/photos:__subpackages__",
+        "//javatests/com/google/android/apps/photos:__subpackages__",
+    ],
     deps = [
         ":autoflip_messages_proto",
     ],
@@ -455,7 +455,7 @@ cc_library(
         ],
     }),
     deps = [
-        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/synchronization",
        "//mediapipe/framework:port",
@@ -24,7 +24,7 @@
 #include <utility>
 #include <vector>
 
-#include "absl/container/flat_hash_set.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/framework/formats/tensor_internal.h"
 #include "mediapipe/framework/port.h"
@@ -434,8 +434,9 @@ class Tensor {
   mutable bool use_ahwb_ = false;
   mutable uint64_t ahwb_tracking_key_ = 0;
   // TODO: Tracks all unique tensors. Can grow to a large number. LRU
-  // can be more predicted.
-  static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_;
+  // (Least Recently Used) can be more predicted.
+  // The value contains the size alignment parameter.
+  static inline absl::flat_hash_map<uint64_t, int> ahwb_usage_track_;
   // Expects the target SSBO to be already bound.
   bool AllocateAhwbMapToSsbo() const;
   bool InsertAhwbToSsboFence() const;
@@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
 
 bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
   // Mark current tracking key as Ahwb-use.
-  ahwb_usage_track_.insert(ahwb_tracking_key_);
+  if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_);
+      it != ahwb_usage_track_.end()) {
+    size_alignment = it->second;
+  } else if (ahwb_tracking_key_ != 0) {
+    ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment});
+  }
   use_ahwb_ = true;
 
   if (__builtin_available(android 26, *)) {
@@ -458,7 +463,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
      ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
    }
  }
-  use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
+  // Keep flag value if it was set previously.
+  use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);
 }
 
 #else  // MEDIAPIPE_TENSOR_USE_AHWB
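Note: taken together, the two hunks above turn the flat_hash_set of "call sites that wanted AHWB storage" into a map that also remembers the size alignment first requested at each call site, so later tensors created from the same call site reuse it. A self-contained sketch of that lookup-or-record pattern (hypothetical names, not the Tensor internals):

#include <cstdint>
#include <iostream>

#include "absl/container/flat_hash_map.h"

// key -> size alignment recorded at the first allocation from that call site.
absl::flat_hash_map<uint64_t, int> usage_track;

int LookupOrRecordAlignment(uint64_t key, int requested_alignment) {
  if (auto it = usage_track.find(key); it != usage_track.end()) {
    return it->second;  // Reuse the alignment recorded earlier.
  }
  if (key != 0) {       // Key 0 stands for "no tracked call site".
    usage_track.insert({key, requested_alignment});
  }
  return requested_alignment;
}

int main() {
  std::cout << LookupOrRecordAlignment(42, 16) << "\n";  // records and prints 16
  std::cout << LookupOrRecordAlignment(42, 4) << "\n";   // reuses 16
  return 0;
}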
@@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
 };
 
 TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  {
+    // Request Ahwb first to get Ahwb storage allocated internally.
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
+  }
   RunInGlContext([&tensor] {
     auto ssbo_view = tensor.GetOpenGlBufferWriteView();
     auto ssbo_name = ssbo_view.name();
@@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
 }
 
 TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
+  {
+    // Request Ahwb first to get Ahwb storage allocated internally.
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
+  }
   RunInGlContext([&tensor] {
     auto ssbo_view = tensor.GetOpenGlBufferWriteView();
     auto ssbo_name = ssbo_view.name();
@@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
 TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
   // Request the CPU view to get the memory to be allocated.
   // Request Ahwb view then to transform the storage into Ahwb.
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
   {
@@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
 TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
   // Request the GPU view to get the ssbo allocated internally.
   // Request Ahwb view then to transform the storage into Ahwb.
-  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
   constexpr size_t num_elements = 20;
   Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
   RunInGlContext([&tensor] {
@@ -1,34 +1,28 @@
 #include "mediapipe/framework/formats/tensor.h"
-#include "mediapipe/gpu/gpu_test_base.h"
 #include "testing/base/public/gmock.h"
 #include "testing/base/public/gunit.h"
 
-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-#if !MEDIAPIPE_DISABLE_GPU
-
 namespace mediapipe {
 
-class TensorAhwbTest : public mediapipe::GpuTestBase {
- public:
-};
-
-TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
+TEST(TensorAhwbTest, TestCpuThenAHWB) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
     auto ptr = tensor.GetCpuWriteView().buffer<float>();
     EXPECT_NE(ptr, nullptr);
   }
   {
-    auto ahwb = tensor.GetAHardwareBufferReadView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferReadView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
   }
 }
 
-TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
+TEST(TensorAhwbTest, TestAHWBThenCpu) {
   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
   {
-    auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
-    EXPECT_NE(ahwb, nullptr);
+    auto view = tensor.GetAHardwareBufferWriteView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
   }
   {
     auto ptr = tensor.GetCpuReadView().buffer<float>();
@@ -36,21 +30,71 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
   }
 }
 
-TEST_F(TensorAhwbTest, TestCpuThenGl) {
-  RunInGlContext([] {
-    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
-    {
-      auto ptr = tensor.GetCpuWriteView().buffer<float>();
-      EXPECT_NE(ptr, nullptr);
-    }
-    {
-      auto ssbo = tensor.GetOpenGlBufferReadView().name();
-      EXPECT_GT(ssbo, 0);
-    }
-  });
-}
+TEST(TensorAhwbTest, TestAhwbAlignment) {
+  Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5});
+  {
+    auto view = tensor.GetAHardwareBufferWriteView(16);
+    EXPECT_NE(view.handle(), nullptr);
+    if (__builtin_available(android 26, *)) {
+      AHardwareBuffer_Desc desc;
+      AHardwareBuffer_describe(view.handle(), &desc);
+      // sizeof(float) * 5 = 20, the closest aligned to 16 size is 32.
+      EXPECT_EQ(desc.width, 32);
+    }
+    view.SetWritingFinishedFD(-1, [](bool) { return true; });
+  }
+}
+
+// Tensor::GetCpuView uses a source location mechanism that gives the source
+// file name and line from where the method is called. The function is
+// intended just to have two calls providing the same source file name and
+// line.
+auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); }
+
+// The test checks the tracking mechanism: when a tensor's Cpu view is
+// retrieved for the first time, the source location is attached to the
+// tensor. If the Ahwb view is then requested from the tensor, the previously
+// recorded Cpu view request source location is marked for using Ahwb storage.
+// When a Cpu view with the same source location (but for a newly allocated
+// tensor) is requested and the location is marked to use Ahwb storage, the
+// Ahwb storage is allocated for the CpuView.
+TEST(TensorAhwbTest, TestTrackingAhwb) {
+  // Create the first tensor and request the Cpu and then the Ahwb view to
+  // mark the source location for Ahwb storage.
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
+    {
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
+    }
+    {
+      // Align size of the Ahwb by multiple of 16.
+      auto view = tensor.GetAHardwareBufferWriteView(16);
+      EXPECT_NE(view.handle(), nullptr);
+      view.SetReadingFinishedFunc([](bool) { return true; });
+    }
+  }
+  {
+    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
+    {
+      // The second tensor uses the same Cpu view source location so Ahwb
+      // storage is allocated internally.
+      auto view = GetCpuView(tensor);
+      EXPECT_NE(view.buffer<float>(), nullptr);
+    }
+    {
+      // Check the Ahwb size to be aligned to multiple of 16. The alignment
+      // is stored by previous requesting of the Ahwb view.
+      auto view = tensor.GetAHardwareBufferReadView();
+      EXPECT_NE(view.handle(), nullptr);
+      if (__builtin_available(android 26, *)) {
+        AHardwareBuffer_Desc desc;
+        AHardwareBuffer_describe(view.handle(), &desc);
+        // sizeof(float) * 9 = 36. The closest aligned size is 48.
+        EXPECT_EQ(desc.width, 48);
+      }
+      view.SetReadingFinishedFunc([](bool) { return true; });
+    }
+  }
+}
 
 } // namespace mediapipe
-
-#endif // !MEDIAPIPE_DISABLE_GPU
-#endif // MEDIAPIPE_TENSOR_USE_AHWB
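Note: the two width expectations in the new tests (20 bytes rounding up to 32, 36 bytes rounding up to 48) follow the usual round-up-to-alignment rule. A small compile-time check that reproduces the arithmetic (an assumed helper for illustration, not part of the Tensor API):

#include <cstddef>

// Rounds size up to the next multiple of alignment (alignment > 0).
constexpr size_t AlignedSize(size_t size, size_t alignment) {
  return ((size + alignment - 1) / alignment) * alignment;
}

static_assert(AlignedSize(sizeof(float) * 5, 16) == 32);  // 20 -> 32
static_assert(AlignedSize(sizeof(float) * 9, 16) == 48);  // 36 -> 48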
@@ -194,6 +194,7 @@ void GraphProfiler::Initialize(
         "Calculator \"$0\" has already been added.", node_name);
   }
   profile_builder_ = std::make_unique<GraphProfileBuilder>(this);
+  graph_id_ = ++next_instance_id_;
 
   is_initialized_ = true;
 }
@@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
     return validated_graph_;
   }
 
+  // Gets a numerical identifier for this GraphProfiler object.
+  uint64_t GetGraphId() { return graph_id_; }
+
  private:
   // This can be used to add packet info for the input streams to the graph.
   // It treats the stream defined by |stream_name| as a stream produced by a
@@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
   class GraphProfileBuilder;
   std::unique_ptr<GraphProfileBuilder> profile_builder_;
 
+  // The globally incrementing identifier for all graphs in a process.
+  static inline std::atomic_int next_instance_id_ = 0;
+
+  // A unique identifier for this object. Only unique within a process.
+  uint64_t graph_id_;
+
   // For testing.
   friend GraphProfilerTestPeer;
 };
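Note: the id scheme pairs a static atomic counter with a per-instance field, so ids are unique within the process and never reused even after a profiler is destroyed. A compilable sketch of the same pattern (simplified: the id is assigned at construction here, whereas the patch assigns it in Initialize()):

#include <atomic>
#include <cstdint>
#include <iostream>

class Profiler {
 public:
  Profiler() : graph_id_(++next_instance_id_) {}
  uint64_t GetGraphId() const { return graph_id_; }

 private:
  // Globally incrementing; shared by all instances in the process.
  static inline std::atomic_int next_instance_id_ = 0;
  uint64_t graph_id_;
};

int main() {
  Profiler a, b;
  std::cout << a.GetGraphId() << " " << b.GetGraphId() << "\n";  // 1 2
  return 0;
}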
@@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) {
       "Cannot initialize .* multiple times.");
 }
 
+// Tests that graph identifiers are not reused, even after destruction.
+TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) {
+  auto raw_graph_config = R"(
+    profiler_config {
+      enable_profiler: true
+    }
+    input_stream: "input_stream"
+    node {
+      calculator: "DummyTestCalculator"
+      input_stream: "input_stream"
+    })";
+  const int n_iterations = 100;
+  absl::flat_hash_set<int> seen_ids;
+  for (int i = 0; i < n_iterations; ++i) {
+    std::shared_ptr<ProfilingContext> profiler =
+        std::make_shared<ProfilingContext>();
+    auto graph_config = CreateGraphConfig(raw_graph_config);
+    mediapipe::ValidatedGraphConfig validated_graph;
+    QCHECK_OK(validated_graph.Initialize(graph_config));
+    profiler->Initialize(validated_graph);
+
+    int id = profiler->GetGraphId();
+    ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id)));
+    seen_ids.insert(id);
+  }
+}
+
 // Tests that Pause(), Resume(), and Reset() works.
 TEST_F(GraphProfilerTestPeer, PauseResumeReset) {
   InitializeProfilerWithGraphConfig(R"(
@@ -74,6 +74,9 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView(
 static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
                                                const GlTextureView& view) {
   CHECK(pixel_buffer);
+  auto ctx = GlContext::GetCurrent().get();
+  if (!ctx) ctx = view.gl_context();
+  ctx->Run([pixel_buffer, &view, ctx] {
   CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
   CHECK(err == kCVReturnSuccess)
       << "CVPixelBufferLockBaseAddress failed: " << err;
@@ -82,34 +85,40 @@ static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
   uint8_t* pixel_ptr =
       static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
   if (pixel_format == kCVPixelFormatType_32BGRA) {
-    // TODO: restore previous framebuffer? Move this to helper so we
-    // can use BindFramebuffer?
+    glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx));
     glViewport(0, 0, view.width(), view.height());
-    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
-                           view.name(), 0);
+    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
+                           view.target(), view.name(), 0);
 
     size_t contiguous_bytes_per_row = view.width() * 4;
     if (bytes_per_row == contiguous_bytes_per_row) {
-      glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
-                   pixel_ptr);
+      glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
+                   GL_UNSIGNED_BYTE, pixel_ptr);
     } else {
+      // TODO: use GL_PACK settings for row length. We can expect
+      // GLES 3.0 on iOS now.
       std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
                                              view.height());
       uint8_t* temp_ptr = contiguous_buffer.data();
-      glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
-                   temp_ptr);
+      glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
+                   GL_UNSIGNED_BYTE, temp_ptr);
       for (int i = 0; i < view.height(); ++i) {
        memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
        temp_ptr += contiguous_bytes_per_row;
        pixel_ptr += bytes_per_row;
      }
    }
+    // TODO: restore previous framebuffer?
+    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
+                           view.target(), 0, 0);
+    glBindFramebuffer(GL_FRAMEBUFFER, 0);
   } else {
     LOG(ERROR) << "unsupported pixel format: " << pixel_format;
   }
   err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
   CHECK(err == kCVReturnSuccess)
       << "CVPixelBufferUnlockBaseAddress failed: " << err;
+  });
 }
 #endif  // TARGET_IPHONE_SIMULATOR
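Note: when the pixel buffer's bytes_per_row is larger than width * 4, the workaround above reads into a tightly packed scratch buffer and then copies row by row into the padded destination. A standalone sketch of that strided copy (illustrative helper only, with an assumed name):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Copies tightly packed BGRA rows (width * 4 bytes each) into a destination
// whose rows are padded out to dst_bytes_per_row.
void CopyPackedRowsToStrided(const uint8_t* src, uint8_t* dst, int width,
                             int height, size_t dst_bytes_per_row) {
  const size_t src_bytes_per_row = static_cast<size_t>(width) * 4;
  for (int i = 0; i < height; ++i) {
    std::memcpy(dst, src, src_bytes_per_row);
    src += src_bytes_per_row;
    dst += dst_bytes_per_row;  // skip the padding in the destination
  }
}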
@@ -71,9 +71,12 @@ class TextClassifierTest(tf.test.TestCase):
 
     self.assertTrue(os.path.exists(output_metadata_file))
     self.assertGreater(os.path.getsize(output_metadata_file), 0)
+    filecmp.clear_cache()
     self.assertTrue(
-        filecmp.cmp(output_metadata_file,
-                    self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
+        filecmp.cmp(
+            output_metadata_file,
+            self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
+            shallow=False))
 
   def test_create_and_train_bert(self):
     train_data, validation_data = self._get_data()
@@ -135,7 +135,10 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
 
     self.assertTrue(os.path.exists(output_metadata_file))
     self.assertGreater(os.path.getsize(output_metadata_file), 0)
-    self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
+    filecmp.clear_cache()
+    self.assertTrue(
+        filecmp.cmp(
+            output_metadata_file, expected_metadata_file, shallow=False))
 
   def test_continual_training_by_loading_checkpoint(self):
     mock_stdout = io.StringIO()
@@ -230,16 +230,17 @@ if ([wrapper.delegate
 }
 
 - (absl::Status)performStart {
-  absl::Status status = _graph->Initialize(_config);
-  if (!status.ok()) {
-    return status;
-  }
+  absl::Status status;
   for (const auto& service_packet : _servicePackets) {
     status = _graph->SetServicePacket(*service_packet.first, service_packet.second);
     if (!status.ok()) {
       return status;
     }
   }
+  status = _graph->Initialize(_config);
+  if (!status.ok()) {
+    return status;
+  }
   status = _graph->StartRun(_inputSidePackets, _streamHeaders);
   if (!status.ok()) {
     return status;
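Note: the reordering moves graph initialization after the service-packet loop, presumably so the packets are already registered when Initialize() runs. A self-contained sketch of the early-return sequencing with a stand-in graph type (hypothetical names, not the MediaPipe API):

#include <iostream>
#include <utility>
#include <vector>

#include "absl/status/status.h"

struct FakeGraph {  // stand-in for the wrapped calculator graph
  absl::Status SetServicePacket(int service, int packet) {
    return absl::OkStatus();
  }
  absl::Status Initialize() { return absl::OkStatus(); }
  absl::Status StartRun() { return absl::OkStatus(); }
};

absl::Status PerformStart(FakeGraph& graph,
                          const std::vector<std::pair<int, int>>& services) {
  absl::Status status;
  for (const auto& [service, packet] : services) {
    status = graph.SetServicePacket(service, packet);
    if (!status.ok()) return status;
  }
  // Initialize() now runs only after all service packets are in place.
  status = graph.Initialize();
  if (!status.ok()) return status;
  return graph.StartRun();
}

int main() {
  FakeGraph graph;
  std::cout << PerformStart(graph, {{1, 2}}).ok() << "\n";  // prints 1
  return 0;
}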
@@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
   auto custom_gestures_classifier_options_proto =
       std::make_unique<components::processors::proto::ClassifierOptions>(
           components::processors::ConvertClassifierOptionsToProto(
-              &(options->canned_gestures_classifier_options)));
+              &(options->custom_gestures_classifier_options)));
   hand_gesture_recognizer_graph_options
       ->mutable_custom_gesture_classifier_graph_options()
       ->mutable_classifier_options()
-      ->Swap(canned_gestures_classifier_options_proto.get());
+      ->Swap(custom_gestures_classifier_options_proto.get());
   return options_proto;
 }
@@ -38,4 +38,3 @@ objc_library(
         "-std=c++17",
     ],
 )
-
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #import <Foundation/Foundation.h>
+
 #include "mediapipe/tasks/cc/common.h"
 
 NS_ASSUME_NONNULL_BEGIN
@@ -56,6 +57,7 @@ extern NSString *const MPPTasksErrorDomain;
  * @param status absl::Status.
  * @param error Pointer to the memory location where the created error should be saved. If `nil`,
  * no error will be saved.
+ * @return YES when there is no error, NO otherwise.
  */
 + (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error;
@@ -20,7 +20,6 @@
 
 #include "absl/status/status.h"  // from @com_google_absl
 #include "absl/strings/cord.h"   // from @com_google_absl
-
 #include "mediapipe/tasks/cc/common.h"
 
 /** Error domain of MediaPipe task library errors. */
@@ -96,8 +95,8 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
   // appropriate MPPTasksErrorCode in default cases. Note:
   // The mapping to absl::Status::code() is done to generate a more specific error code than
   // MPPTasksErrorCodeError in cases when the payload can't be mapped to
-  // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn returned
-  // without modification by Mediapipe cc library methods.
+  // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
+  // returned without modification by Mediapipe cc library methods.
   if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) {
     switch (status.code()) {
       case absl::StatusCode::kInternal:
@@ -13,13 +13,14 @@
 // limitations under the License.
 
 #import <Foundation/Foundation.h>
+
 #include <string>
 
 NS_ASSUME_NONNULL_BEGIN
 
 @interface NSString (Helpers)
 
-@property(readonly) std::string cppString;
+@property(readonly, nonatomic) std::string cppString;
 
 + (NSString *)stringWithCppString:(std::string)text;
@@ -21,4 +21,3 @@ objc_library(
     srcs = ["sources/MPPClassifierOptions.m"],
     hdrs = ["sources/MPPClassifierOptions.h"],
 )
-
@@ -22,29 +22,34 @@ NS_ASSUME_NONNULL_BEGIN
 NS_SWIFT_NAME(ClassifierOptions)
 @interface MPPClassifierOptions : NSObject <NSCopying>
 
-/** The locale to use for display names specified through the TFLite Model
+/**
+ * The locale to use for display names specified through the TFLite Model
  * Metadata, if any. Defaults to English.
  */
 @property(nonatomic, copy) NSString *displayNamesLocale;
 
-/** The maximum number of top-scored classification results to return. If < 0,
+/**
+ * The maximum number of top-scored classification results to return. If < 0,
  * all available results will be returned. If 0, an invalid argument error is
 * returned.
 */
 @property(nonatomic) NSInteger maxResults;
 
-/** Score threshold to override the one provided in the model metadata (if any).
+/**
+ * Score threshold to override the one provided in the model metadata (if any).
  * Results below this value are rejected.
 */
 @property(nonatomic) float scoreThreshold;
 
-/** The allowlist of category names. If non-empty, detection results whose
+/**
+ * The allowlist of category names. If non-empty, detection results whose
  * category name is not in this set will be filtered out. Duplicate or unknown
  * category names are ignored. Mutually exclusive with categoryDenylist.
 */
 @property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
 
-/** The denylist of category names. If non-empty, detection results whose
+/**
+ * The denylist of category names. If non-empty, detection results whose
  * category name is in this set will be filtered out. Duplicate or unknown
  * category names are ignored. Mutually exclusive with categoryAllowlist.
 */
@@ -19,8 +19,8 @@
 - (instancetype)init {
   self = [super init];
   if (self) {
-    self.maxResults = -1;
-    self.scoreThreshold = 0;
+    _maxResults = -1;
+    _scoreThreshold = 0;
   }
   return self;
 }
@@ -22,8 +22,7 @@ objc_library(
     hdrs = ["sources/MPPClassifierOptions+Helpers.h"],
     deps = [
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
-        "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
-    ]
+        "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
+    ],
 )
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
+
 #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
 
 NS_ASSUME_NONNULL_BEGIN
@@ -29,7 +29,6 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
   }
 
   classifierOptionsProto->set_max_results((int)self.maxResults);
-
   classifierOptionsProto->set_score_threshold(self.scoreThreshold);
 
   for (NSString *category in self.categoryAllowlist) {
@@ -54,14 +54,14 @@ objc_library(
         "-std=c++17",
     ],
     deps = [
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:calculator_options_cc_proto",
-        "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
         ":MPPTaskOptions",
         ":MPPTaskOptionsProtocol",
+        "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_options_cc_proto",
+        "//mediapipe/tasks/ios/common:MPPCommon",
         "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
-        "//mediapipe/tasks/ios/common:MPPCommon",
     ],
 )
@@ -83,9 +83,13 @@ objc_library(
     name = "MPPTaskRunner",
     srcs = ["sources/MPPTaskRunner.mm"],
     hdrs = ["sources/MPPTaskRunner.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
     deps = [
-        "//mediapipe/tasks/cc/core:task_runner",
         "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/tasks/cc/core:task_runner",
         "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
     ],
 )
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #import <Foundation/Foundation.h>
+
 #include "mediapipe/framework/calculator.pb.h"
+
 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
 
@@ -59,7 +61,7 @@ NS_ASSUME_NONNULL_BEGIN
 /**
  * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
  */
-- (mediapipe::CalculatorGraphConfig)generateGraphConfig;
+- (::mediapipe::CalculatorGraphConfig)generateGraphConfig;
 
 - (instancetype)init NS_UNAVAILABLE;
@@ -24,7 +24,6 @@
 namespace {
 using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
 using Node = ::mediapipe::CalculatorGraphConfig::Node;
-using ::mediapipe::CalculatorOptions;
 using ::mediapipe::FlowLimiterCalculatorOptions;
 using ::mediapipe::InputStreamInfo;
 }  // namespace
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #import <Foundation/Foundation.h>
+
 #include "mediapipe/framework/calculator_options.pb.h"
 
 NS_ASSUME_NONNULL_BEGIN
@@ -25,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN
 /**
  * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
  */
-- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;
+- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto;
 
 @end
 
@@ -22,7 +22,6 @@ NS_ASSUME_NONNULL_BEGIN
 /**
  * This class is used to create and call appropriate methods on the C++ Task Runner.
  */
-
 @interface MPPTaskRunner : NSObject
 
 /**
@@ -35,11 +34,10 @@ NS_ASSUME_NONNULL_BEGIN
 - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
                                         error:(NSError **)error NS_DESIGNATED_INITIALIZER;
 
-- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)
-    process:(const mediapipe::tasks::core::PacketMap &)packetMap
-      error:(NSError **)error;
+- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
+    (const mediapipe::tasks::core::PacketMap &)packetMap;
 
-- (void)close;
+- (absl::Status)close;
 
 - (instancetype)init NS_UNAVAILABLE;
@@ -17,7 +17,6 @@
 
 namespace {
 using ::mediapipe::CalculatorGraphConfig;
-using ::mediapipe::Packet;
 using ::mediapipe::tasks::core::PacketMap;
 using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
 }  // namespace
@@ -49,8 +48,8 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
   return _cppTaskRunner->Process(packetMap);
 }
 
-- (void)close {
-  _cppTaskRunner->Close();
+- (absl::Status)close {
+  return _cppTaskRunner->Close();
 }
 
 @end
@@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
    *
    * @param options The options for the audio classifier.
    */
-  override async setOptions(options: AudioClassifierOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: AudioClassifierOptions): Promise<void> {
     this.options.setClassifierOptions(convertClassifierOptionsToProto(
         options, this.options.getClassifierOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(AUDIO_STREAM);
     graphConfig.addInputStream(SAMPLE_RATE_STREAM);
@@ -79,7 +79,8 @@ describe('AudioClassifier', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     audioClassifier = new AudioClassifierFake();
-    await audioClassifier.setOptions({});  // Initialize graph
+    await audioClassifier.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
@@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
    *
    * @param options The options for the audio embedder.
    */
-  override async setOptions(options: AudioEmbedderOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: AudioEmbedderOptions): Promise<void> {
    this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
        options, this.options.getEmbedderOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(AUDIO_STREAM);
     graphConfig.addInputStream(SAMPLE_RATE_STREAM);
@@ -70,7 +70,8 @@ describe('AudioEmbedder', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     audioEmbedder = new AudioEmbedderFake();
-    await audioEmbedder.setOptions({});  // Initialize graph
+    await audioEmbedder.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', () => {
@@ -103,29 +103,3 @@ jasmine_node_test(
     name = "embedder_options_test",
     deps = [":embedder_options_test_lib"],
 )
-
-mediapipe_ts_library(
-    name = "base_options",
-    srcs = [
-        "base_options.ts",
-    ],
-    deps = [
-        "//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
-        "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
-        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
-        "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
-        "//mediapipe/tasks/web/core",
-    ],
-)
-
-mediapipe_ts_library(
-    name = "base_options_test_lib",
-    testonly = True,
-    srcs = ["base_options.test.ts"],
-    deps = [":base_options"],
-)
-
-jasmine_node_test(
-    name = "base_options_test",
-    deps = [":base_options_test_lib"],
-)
@@ -1,127 +0,0 @@
-/**
- * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import 'jasmine';
-
-// Placeholder for internal dependency on encodeByteArray
-// Placeholder for internal dependency on trusted resource URL builder
-
-import {convertBaseOptionsToProto} from './base_options';
-
-describe('convertBaseOptionsToProto()', () => {
-  const mockBytes = new Uint8Array([0, 1, 2, 3]);
-  const mockBytesResult = {
-    modelAsset: {
-      fileContent: Buffer.from(mockBytes).toString('base64'),
-      fileName: undefined,
-      fileDescriptorMeta: undefined,
-      filePointerMeta: undefined,
-    },
-    useStreamMode: false,
-    acceleration: {
-      xnnpack: undefined,
-      gpu: undefined,
-      tflite: {},
-    },
-  };
-
-  let fetchSpy: jasmine.Spy;
-
-  beforeEach(() => {
-    fetchSpy = jasmine.createSpy().and.callFake(async url => {
-      expect(url).toEqual('foo');
-      return {
-        arrayBuffer: () => mockBytes.buffer,
-      } as unknown as Response;
-    });
-    global.fetch = fetchSpy;
-  });
-
-  it('verifies that at least one model asset option is provided', async () => {
-    await expectAsync(convertBaseOptionsToProto({}))
-        .toBeRejectedWithError(
-            /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
-  });
-
-  it('verifies that no more than one model asset option is provided', async () => {
-    await expectAsync(convertBaseOptionsToProto({
-      modelAssetPath: `foo`,
-      modelAssetBuffer: new Uint8Array([])
-    }))
-        .toBeRejectedWithError(
-            /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
-  });
-
-  it('downloads model', async () => {
-    const baseOptionsProto = await convertBaseOptionsToProto({
-      modelAssetPath: `foo`,
-    });
-
-    expect(fetchSpy).toHaveBeenCalled();
-    expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
-  });
-
-  it('does not download model when bytes are provided', async () => {
-    const baseOptionsProto = await convertBaseOptionsToProto({
-      modelAssetBuffer: new Uint8Array(mockBytes),
-    });
-
-    expect(fetchSpy).not.toHaveBeenCalled();
-    expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
-  });
-
-  it('can enable CPU delegate', async () => {
-    const baseOptionsProto = await convertBaseOptionsToProto({
-      modelAssetBuffer: new Uint8Array(mockBytes),
-      delegate: 'CPU',
-    });
-    expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
-  });
-
-  it('can enable GPU delegate', async () => {
-    const baseOptionsProto = await convertBaseOptionsToProto({
-      modelAssetBuffer: new Uint8Array(mockBytes),
-      delegate: 'GPU',
-    });
-    expect(baseOptionsProto.toObject()).toEqual({
-      ...mockBytesResult,
-      acceleration: {
-        xnnpack: undefined,
-        gpu: {
-          useAdvancedGpuApi: false,
-          api: 0,
-          allowPrecisionLoss: true,
-          cachedKernelPath: undefined,
-          serializedModelDir: undefined,
-          modelToken: undefined,
-          usage: 2,
-        },
-        tflite: undefined,
-      },
-    });
-  });
-
-  it('can reset delegate', async () => {
-    let baseOptionsProto = await convertBaseOptionsToProto({
-      modelAssetBuffer: new Uint8Array(mockBytes),
-      delegate: 'GPU',
-    });
-    // Clear backend
-    baseOptionsProto =
-        await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto);
-    expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
-  });
-});
@@ -1,80 +0,0 @@
-/**
- * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
-import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
-import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
-import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
-import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
-
-// The OSS JS API does not support the builder pattern.
-// tslint:disable:jspb-use-builder-pattern
-
-/**
- * Converts a BaseOptions API object to its Protobuf representation.
- * @throws If neither a model assset path or buffer is provided
- */
-export async function convertBaseOptionsToProto(
-    updatedOptions: BaseOptions,
-    currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
-  const result =
-      currentOptions ? currentOptions.clone() : new BaseOptionsProto();
-
-  await configureExternalFile(updatedOptions, result);
-  configureAcceleration(updatedOptions, result);
-
-  return result;
-}
-
-/**
- * Configues the `externalFile` option and validates that a single model is
- * provided.
- */
-async function configureExternalFile(
-    options: BaseOptions, proto: BaseOptionsProto) {
-  const externalFile = proto.getModelAsset() || new ExternalFile();
-  proto.setModelAsset(externalFile);
-
-  if (options.modelAssetPath || options.modelAssetBuffer) {
-    if (options.modelAssetPath && options.modelAssetBuffer) {
-      throw new Error(
-          'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
-    }
-
-    let modelAssetBuffer = options.modelAssetBuffer;
-    if (!modelAssetBuffer) {
-      const response = await fetch(options.modelAssetPath!.toString());
-      modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
-    }
-    externalFile.setFileContent(modelAssetBuffer);
-  }
-
-  if (!externalFile.hasFileContent()) {
-    throw new Error(
-        'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
-  }
-}
-
-/** Configues the `acceleration` option. */
-function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
-  const acceleration = proto.getAcceleration() ?? new Acceleration();
-  if (options.delegate === 'GPU') {
-    acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
-  } else {
-    acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
-  }
-  proto.setAcceleration(acceleration);
-}
@@ -18,8 +18,10 @@ mediapipe_ts_library(
     srcs = ["task_runner.ts"],
     deps = [
         ":core",
+        "//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
+        "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
-        "//mediapipe/tasks/web/components/processors:base_options",
+        "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
         "//mediapipe/web/graph_runner:graph_runner_ts",
         "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
@@ -53,6 +55,7 @@ mediapipe_ts_library(
         "task_runner_test.ts",
     ],
     deps = [
+        ":core",
        ":task_runner",
        ":task_runner_test_utils",
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
@@ -14,9 +14,11 @@
  * limitations under the License.
  */

+import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
+import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
-import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
-import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
+import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
+import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
 import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
 import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@@ -91,14 +93,52 @@ export abstract class TaskRunner {
     this.graphRunner.registerModelResourcesGraphService();
   }

-  /** Configures the shared options of a MediaPipe Task. */
-  async setOptions(options: TaskRunnerOptions): Promise<void> {
-    if (options.baseOptions) {
-      this.baseOptions = await convertBaseOptionsToProto(
-          options.baseOptions, this.baseOptions);
-    }
-  }
+  /** Configures the task with custom options. */
+  abstract setOptions(options: TaskRunnerOptions): Promise<void>;
+
+  /**
+   * Applies the current set of options, including any base options that have
+   * not been processed by the task implementation. The options are applied
+   * synchronously unless a `modelAssetPath` is provided. This ensures that
+   * for most use cases options are applied directly and immediately affect
+   * the next inference.
+   */
+  protected applyOptions(options: TaskRunnerOptions): Promise<void> {
+    const baseOptions: BaseOptions = options.baseOptions || {};
+
+    // Validate that exactly one model is configured
+    if (options.baseOptions?.modelAssetBuffer &&
+        options.baseOptions?.modelAssetPath) {
+      throw new Error(
+          'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
+    } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
+                 options.baseOptions?.modelAssetBuffer ||
+                 options.baseOptions?.modelAssetPath)) {
+      throw new Error(
+          'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
+    }
+
+    this.setAcceleration(baseOptions);
+    if (baseOptions.modelAssetPath) {
+      // We don't use `await` here since we want to apply most settings
+      // synchronously.
+      return fetch(baseOptions.modelAssetPath.toString())
+          .then(response => response.arrayBuffer())
+          .then(buffer => {
+            this.setExternalFile(new Uint8Array(buffer));
+            this.refreshGraph();
+          });
+    } else {
+      // Apply the setting synchronously.
+      this.setExternalFile(baseOptions.modelAssetBuffer);
+      this.refreshGraph();
+      return Promise.resolve();
+    }
+  }
+
+  /** Applies the current options to the MediaPipe graph. */
+  protected abstract refreshGraph(): void;

   /**
    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
    * over the video stream. Will replace the previously running MediaPipe graph,
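The doc comment above carries the key behavioral contract of `applyOptions`: options take effect synchronously unless a `modelAssetPath` forces a download. A minimal standalone sketch of that pattern, built around a hypothetical `loadModel` helper (the name and surrounding setup are illustrative, not part of this change):

    function loadModel(path?: string, buffer?: Uint8Array): Promise<Uint8Array> {
      if (path !== undefined && buffer !== undefined) {
        throw new Error('Cannot set both a model path and a model buffer');
      }
      if (path !== undefined) {
        // Asynchronous branch: callers must await the returned Promise before
        // the new model is guaranteed to be in effect.
        return fetch(path)
            .then(response => response.arrayBuffer())
            .then(data => new Uint8Array(data));
      }
      if (buffer !== undefined) {
        // Synchronous branch: the buffer is applied before the function
        // returns, so even callers that do not await still see the new model.
        return Promise.resolve(buffer);
      }
      throw new Error('Either a model path or a model buffer must be set');
    }

This mirrors why the tests further down can assert on `baseOptions` without awaiting whenever a `modelAssetBuffer` is supplied.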
@@ -140,6 +180,27 @@ export abstract class TaskRunner {
     }
     this.processingErrors = [];
   }

+  /** Configures the `externalFile` option */
+  private setExternalFile(modelAssetBuffer?: Uint8Array): void {
+    const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
+    if (modelAssetBuffer) {
+      externalFile.setFileContent(modelAssetBuffer);
+    }
+    this.baseOptions.setModelAsset(externalFile);
+  }
+
+  /** Configures the `acceleration` option. */
+  private setAcceleration(options: BaseOptions) {
+    const acceleration =
+        this.baseOptions.getAcceleration() ?? new Acceleration();
+    if (options.delegate === 'GPU') {
+      acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
+    } else {
+      acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
+    }
+    this.baseOptions.setAcceleration(acceleration);
+  }
 }
@@ -15,18 +15,22 @@
  */
 import 'jasmine';

+// Placeholder for internal dependency on encodeByteArray
 import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
 import {TaskRunner} from '../../../tasks/web/core/task_runner';
 import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
 import {ErrorListener} from '../../../web/graph_runner/graph_runner';
+// Placeholder for internal dependency on trusted resource URL builder

 import {GraphRunnerImageLib} from './task_runner';
+import {TaskRunnerOptions} from './task_runner_options.d';

 class TaskRunnerFake extends TaskRunner {
-  protected baseOptions = new BaseOptionsProto();
   private errorListener: ErrorListener|undefined;
   private errors: string[] = [];

+  baseOptions = new BaseOptionsProto();
+
   static createFake(): TaskRunnerFake {
     const wasmModule = createSpyWasmModule();
     return new TaskRunnerFake(wasmModule);
@@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
     super.finishProcessing();
   }

+  override refreshGraph(): void {}
+
   override setGraph(graphData: Uint8Array, isBinary: boolean): void {
     super.setGraph(graphData, isBinary);
   }

+  setOptions(options: TaskRunnerOptions): Promise<void> {
+    return this.applyOptions(options);
+  }
+
   private throwErrors(): void {
     expect(this.errorListener).toBeDefined();
     for (const error of this.errors) {
@@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner {
 }

 describe('TaskRunner', () => {
+  const mockBytes = new Uint8Array([0, 1, 2, 3]);
+  const mockBytesResult = {
+    modelAsset: {
+      fileContent: Buffer.from(mockBytes).toString('base64'),
+      fileName: undefined,
+      fileDescriptorMeta: undefined,
+      filePointerMeta: undefined,
+    },
+    useStreamMode: false,
+    acceleration: {
+      xnnpack: undefined,
+      gpu: undefined,
+      tflite: {},
+    },
+  };
+
+  let fetchSpy: jasmine.Spy;
+  let taskRunner: TaskRunnerFake;
+
+  beforeEach(() => {
+    fetchSpy = jasmine.createSpy().and.callFake(async url => {
+      expect(url).toEqual('foo');
+      return {
+        arrayBuffer: () => mockBytes.buffer,
+      } as unknown as Response;
+    });
+    global.fetch = fetchSpy;
+
+    taskRunner = TaskRunnerFake.createFake();
+  });
+
   it('handles errors during graph update', () => {
-    const taskRunner = TaskRunnerFake.createFake();
     taskRunner.enqueueError('Test error');

     expect(() => {
@@ -85,7 +125,6 @@ describe('TaskRunner', () => {
   });

   it('handles errors during graph execution', () => {
-    const taskRunner = TaskRunnerFake.createFake();
     taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);

     taskRunner.enqueueError('Test error');
@@ -96,7 +135,6 @@ describe('TaskRunner', () => {
   });

   it('can handle multiple errors', () => {
-    const taskRunner = TaskRunnerFake.createFake();
     taskRunner.enqueueError('Test error 1');
     taskRunner.enqueueError('Test error 2');

@@ -104,4 +142,106 @@ describe('TaskRunner', () => {
       taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
     }).toThrowError(/Test error 1, Test error 2/);
   });
+
+  it('verifies that at least one model asset option is provided', () => {
+    expect(() => {
+      taskRunner.setOptions({});
+    })
+        .toThrowError(
+            /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
+  });
+
+  it('verifies that no more than one model asset option is provided', () => {
+    expect(() => {
+      taskRunner.setOptions({
+        baseOptions: {
+          modelAssetPath: `foo`,
+          modelAssetBuffer: new Uint8Array([])
+        }
+      });
+    })
+        .toThrowError(
+            /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
+  });
+
+  it('doesn\'t require model once it is configured', async () => {
+    await taskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
+    expect(() => {
+      taskRunner.setOptions({});
+    }).not.toThrowError();
+  });
+
+  it('downloads model', async () => {
+    await taskRunner.setOptions(
+        {baseOptions: {modelAssetPath: `foo`}});
+
+    expect(fetchSpy).toHaveBeenCalled();
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
+
+  it('does not download model when bytes are provided', async () => {
+    await taskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
+
+    expect(fetchSpy).not.toHaveBeenCalled();
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
+
+  it('changes model synchronously when bytes are provided', () => {
+    const resolvedPromise = taskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
+
+    // Check that the change has been applied even though we do not await the
+    // above Promise
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+    return resolvedPromise;
+  });
+
+  it('can enable CPU delegate', async () => {
+    await taskRunner.setOptions({
+      baseOptions: {
+        modelAssetBuffer: new Uint8Array(mockBytes),
+        delegate: 'CPU',
+      }
+    });
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
+
+  it('can enable GPU delegate', async () => {
+    await taskRunner.setOptions({
+      baseOptions: {
+        modelAssetBuffer: new Uint8Array(mockBytes),
+        delegate: 'GPU',
+      }
+    });
+    expect(taskRunner.baseOptions.toObject()).toEqual({
+      ...mockBytesResult,
+      acceleration: {
+        xnnpack: undefined,
+        gpu: {
+          useAdvancedGpuApi: false,
+          api: 0,
+          allowPrecisionLoss: true,
+          cachedKernelPath: undefined,
+          serializedModelDir: undefined,
+          modelToken: undefined,
+          usage: 2,
+        },
+        tflite: undefined,
+      },
+    });
+  });
+
+  it('can reset delegate', async () => {
+    await taskRunner.setOptions({
+      baseOptions: {
+        modelAssetBuffer: new Uint8Array(mockBytes),
+        delegate: 'GPU',
+      }
+    });
+    // Clear backend
+    await taskRunner.setOptions({baseOptions: {delegate: undefined}});
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
 });
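The 'can reset delegate' test above passes because `setAcceleration` treats any delegate other than 'GPU', including `undefined`, as the TFLite/CPU default. A standalone sketch of that mapping, extracted here for illustration only:

    type Delegate = 'CPU'|'GPU'|undefined;

    // Mirrors the branch in setAcceleration(): only an explicit 'GPU' selects
    // the GPU delegate; everything else falls back to TFLite on CPU.
    function selectedDelegate(delegate: Delegate): 'gpu'|'tflite' {
      return delegate === 'GPU' ? 'gpu' : 'tflite';
    }

    console.log(selectedDelegate('GPU'));      // 'gpu'
    console.log(selectedDelegate(undefined));  // 'tflite'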
@@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule {
  * Sets up our equality testing to use a custom float equality checking function
  * to avoid incorrect test results due to minor floating point inaccuracies.
  */
-export function addJasmineCustomFloatEqualityTester() {
+export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) {
   jasmine.addCustomEqualityTester((a, b) => {  // Custom float equality
     if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
-      return Math.abs(a - b) < 5e-8;
+      return Math.abs(a - b) < tolerance;
     }
     return;
   });
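With the new `tolerance` parameter, suites that compare coarser floating-point outputs can widen the comparison. A hypothetical usage, where the suite name and the 1e-6 value are illustrative:

    describe('MyFloatHeavySuite', () => {
      beforeEach(() => {
        // Widen the custom float comparison from the 5e-8 default.
        addJasmineCustomFloatEqualityTester(/* tolerance= */ 1e-6);
      });

      it('treats nearly-equal floats as equal', () => {
        expect(0.3000001).toEqual(0.3);
      });
    });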
@@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
    *
    * @param options The options for the text classifier.
    */
-  override async setOptions(options: TextClassifierOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: TextClassifierOptions): Promise<void> {
     this.options.setClassifierOptions(convertClassifierOptionsToProto(
         options, this.options.getClassifierOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }

   protected override get baseOptions(): BaseOptionsProto {
@@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
   }

   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
@@ -56,7 +56,8 @@ describe('TextClassifier', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     textClassifier = new TextClassifierFake();
-    await textClassifier.setOptions({});  // Initialize graph
+    await textClassifier.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('initializes graph', async () => {
@@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
    *
    * @param options The options for the text embedder.
    */
-  override async setOptions(options: TextEmbedderOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: TextEmbedderOptions): Promise<void> {
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }

   protected override get baseOptions(): BaseOptionsProto {
@@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
   }

   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(EMBEDDINGS_STREAM);
@@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     textEmbedder = new TextEmbedderFake();
-    await textEmbedder.setOptions({});  // Initialize graph
+    await textEmbedder.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('initializes graph', async () => {
@@ -29,6 +29,7 @@ mediapipe_ts_library(
     testonly = True,
     srcs = ["vision_task_runner.test.ts"],
     deps = [
+        ":vision_task_options",
        ":vision_task_runner",
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
@@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
 import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
 import {ImageSource} from '../../../../web/graph_runner/graph_runner';

+import {VisionTaskOptions} from './vision_task_options';
 import {VisionTaskRunner} from './vision_task_runner';

 class VisionTaskRunnerFake extends VisionTaskRunner<void> {
@@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {

   protected override process(): void {}

+  protected override refreshGraph(): void {}
+
+  override setOptions(options: VisionTaskOptions): Promise<void> {
+    return this.applyOptions(options);
+  }
+
   override processImageData(image: ImageSource): void {
     super.processImageData(image);
   }
@@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
 }

 describe('VisionTaskRunner', () => {
-  const streamMode = {
-    modelAsset: undefined,
-    useStreamMode: true,
-    acceleration: undefined,
-  };
-
-  const imageMode = {
-    modelAsset: undefined,
-    useStreamMode: false,
-    acceleration: undefined,
-  };
-
   let visionTaskRunner: VisionTaskRunnerFake;

-  beforeEach(() => {
+  beforeEach(async () => {
     visionTaskRunner = new VisionTaskRunnerFake();
+    await visionTaskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('can enable image mode', async () => {
     await visionTaskRunner.setOptions({runningMode: 'image'});
-    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
+    expect(visionTaskRunner.baseOptions.toObject())
+        .toEqual(jasmine.objectContaining({useStreamMode: false}));
   });

   it('can enable video mode', async () => {
     await visionTaskRunner.setOptions({runningMode: 'video'});
-    expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
+    expect(visionTaskRunner.baseOptions.toObject())
+        .toEqual(jasmine.objectContaining({useStreamMode: true}));
   });

   it('can clear running mode', async () => {
@@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {

     // Clear running mode
     await visionTaskRunner.setOptions({runningMode: undefined});
-    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
+    expect(visionTaskRunner.baseOptions.toObject())
+        .toEqual(jasmine.objectContaining({useStreamMode: false}));
   });

   it('cannot process images with video mode', async () => {
@@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
 /** Base class for all MediaPipe Vision Tasks. */
 export abstract class VisionTaskRunner<T> extends TaskRunner {
   /** Configures the shared options of a vision task. */
-  override async setOptions(options: VisionTaskOptions): Promise<void> {
-    await super.setOptions(options);
+  override applyOptions(options: VisionTaskOptions): Promise<void> {
     if ('runningMode' in options) {
       const useStreamMode =
           !!options.runningMode && options.runningMode !== 'image';
       this.baseOptions.setUseStreamMode(useStreamMode);
     }
+    return super.applyOptions(options);
   }

   /** Sends an image packet to the graph and awaits results. */
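The `useStreamMode` computation above is the entire contract between `runningMode` and the base options. Extracted as a standalone function for illustration (the helper itself is not part of this change):

    type RunningMode = 'image'|'video'|undefined;

    function toUseStreamMode(runningMode: RunningMode): boolean {
      // Only an explicit, non-'image' mode enables stream mode; clearing the
      // option (undefined) falls back to image mode.
      return !!runningMode && runningMode !== 'image';
    }

    console.log(toUseStreamMode('video'));    // true
    console.log(toUseStreamMode('image'));    // false
    console.log(toUseStreamMode(undefined));  // false

This is why the tests above only assert on `useStreamMode` via `jasmine.objectContaining` rather than on a full base-options snapshot.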
@@ -169,9 +169,7 @@ export class GestureRecognizer extends
    *
    * @param options The options for the gesture recognizer.
    */
-  override async setOptions(options: GestureRecognizerOptions): Promise<void> {
-    await super.setOptions(options);
-
+  override setOptions(options: GestureRecognizerOptions): Promise<void> {
     if ('numHands' in options) {
       this.handDetectorGraphOptions.setNumHands(
           options.numHands ?? DEFAULT_NUM_HANDS);
@@ -221,7 +219,7 @@ export class GestureRecognizer extends
           ?.clearClassifierOptions();
     }

-    this.refreshGraph();
+    return this.applyOptions(options);
   }

   /**
@@ -265,6 +263,15 @@ export class GestureRecognizer extends
         NORM_RECT_STREAM, timestamp);
     this.finishProcessing();

+    if (this.gestures.length === 0) {
+      // If no gestures are detected in the image, just return an empty list
+      return {
+        gestures: [],
+        landmarks: [],
+        worldLandmarks: [],
+        handednesses: [],
+      };
+    } else {
     return {
       gestures: this.gestures,
       landmarks: this.landmarks,
@@ -272,6 +279,7 @@ export class GestureRecognizer extends
       handednesses: this.handednesses
     };
+    }
   }

   /** Sets the default values for the graph. */
   private initDefaults(): void {
@@ -285,15 +293,19 @@ export class GestureRecognizer extends
   }

   /** Converts the proto data to a Category[][] structure. */
-  private toJsCategories(data: Uint8Array[]): Category[][] {
+  private toJsCategories(data: Uint8Array[], populateIndex = true):
+      Category[][] {
     const result: Category[][] = [];
     for (const binaryProto of data) {
       const inputList = ClassificationList.deserializeBinary(binaryProto);
       const outputList: Category[] = [];
       for (const classification of inputList.getClassificationList()) {
+        const index = populateIndex && classification.hasIndex() ?
+            classification.getIndex()! :
+            DEFAULT_CATEGORY_INDEX;
         outputList.push({
           score: classification.getScore() ?? 0,
-          index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
+          index,
           categoryName: classification.getLabel() ?? '',
           displayName: classification.getDisplayName() ?? '',
         });
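The new `populateIndex` flag gates whether raw classifier indices survive into the result. A standalone sketch of the rule (the helper is illustrative; `DEFAULT_CATEGORY_INDEX` is -1, matching the updated tests and result docs below):

    const DEFAULT_CATEGORY_INDEX = -1;

    // Mirrors toJsCategories(): an index is kept only when requested and
    // actually present on the classification proto.
    function resolveIndex(populateIndex: boolean, rawIndex?: number): number {
      return populateIndex && rawIndex !== undefined ? rawIndex :
                                                       DEFAULT_CATEGORY_INDEX;
    }

    console.log(resolveIndex(true, 2));   // 2
    console.log(resolveIndex(false, 2));  // -1 (used for hand gestures)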
@@ -342,7 +354,7 @@ export class GestureRecognizer extends
   }

   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(IMAGE_STREAM);
     graphConfig.addInputStream(NORM_RECT_STREAM);
@@ -377,7 +389,10 @@ export class GestureRecognizer extends
         });
     this.graphRunner.attachProtoVectorListener(
         HAND_GESTURES_STREAM, binaryProto => {
-          this.gestures.push(...this.toJsCategories(binaryProto));
+          // Gesture index is not used, because the final gesture result comes
+          // from multiple classifiers.
+          this.gestures.push(
+              ...this.toJsCategories(binaryProto, /* populateIndex= */ false));
         });
     this.graphRunner.attachProtoVectorListener(
         HANDEDNESS_STREAM, binaryProto => {
@@ -17,6 +17,8 @@
 import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';

+export {Category, Landmark, NormalizedLandmark};
+
 /**
  * Represents the gesture recognition results generated by `GestureRecognizer`.
  */
@@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult {
   /** Handedness of detected hands. */
   handednesses: Category[][];

-  /** Recognized hand gestures of detected hands */
+  /**
+   * Recognized hand gestures of detected hands. Note that the index of the
+   * gesture is always -1, because the raw indices from multiple gesture
+   * classifiers cannot consolidate to a meaningful index.
+   */
   gestures: Category[][];
 }
@@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     gestureRecognizer = new GestureRecognizerFake();
-    await gestureRecognizer.setOptions({});  // Initialize graph
+    await gestureRecognizer.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('initializes graph', async () => {
@@ -271,7 +272,7 @@ describe('GestureRecognizer', () => {
     expect(gestures).toEqual({
       'gestures': [[{
         'score': 0.2,
-        'index': 2,
+        'index': -1,
         'categoryName': 'gesture_label',
         'displayName': 'gesture_display_name'
       }]],
@@ -304,4 +305,25 @@ describe('GestureRecognizer', () => {
     // gestures.
     expect(gestures2).toEqual(gestures1);
   });
+
+  it('returns empty results when no gestures are detected', async () => {
+    // Pass the test data to our listener
+    gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(gestureRecognizer);
+      gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+      gestureRecognizer.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+      gestureRecognizer.listeners.get('hand_gestures')!([]);
+    });
+
+    // Invoke the gesture recognizer
+    const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
+    expect(gestures).toEqual({
+      'gestures': [],
+      'landmarks': [],
+      'worldLandmarks': [],
+      'handednesses': []
+    });
+  });
 });
@@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
    *
    * @param options The options for the hand landmarker.
    */
-  override async setOptions(options: HandLandmarkerOptions): Promise<void> {
-    await super.setOptions(options);
-
+  override setOptions(options: HandLandmarkerOptions): Promise<void> {
     // Configure hand detector options.
     if ('numHands' in options) {
       this.handDetectorGraphOptions.setNumHands(
@@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
           options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
     }

-    this.refreshGraph();
+    return this.applyOptions(options);
   }

   /**
@@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
   }

   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(IMAGE_STREAM);
     graphConfig.addInputStream(NORM_RECT_STREAM);
@@ -17,6 +17,8 @@
 import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';

+export {Landmark, NormalizedLandmark, Category};
+
 /**
  * Represents the hand landmark detection results generated by `HandLandmarker`.
  */
@@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     handLandmarker = new HandLandmarkerFake();
-    await handLandmarker.setOptions({});  // Initialize graph
+    await handLandmarker.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('initializes graph', async () => {
@@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
    *
    * @param options The options for the image classifier.
    */
-  override async setOptions(options: ImageClassifierOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: ImageClassifierOptions): Promise<void> {
     this.options.setClassifierOptions(convertClassifierOptionsToProto(
         options, this.options.getClassifierOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }

   /**
@@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
   }

   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
@@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     imageClassifier = new ImageClassifierFake();
-    await imageClassifier.setOptions({});  // Initialize graph
+    await imageClassifier.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('initializes graph', async () => {
@@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
    *
    * @param options The options for the image embedder.
    */
-  override async setOptions(options: ImageEmbedderOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: ImageEmbedderOptions): Promise<void> {
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }

   /**
@@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
   }

   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(EMBEDDINGS_STREAM);
@@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     imageEmbedder = new ImageEmbedderFake();
-    await imageEmbedder.setOptions({});  // Initialize graph
+    await imageEmbedder.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('initializes graph', async () => {
@@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
    *
    * @param options The options for the object detector.
    */
-  override async setOptions(options: ObjectDetectorOptions): Promise<void> {
-    await super.setOptions(options);
-
+  override setOptions(options: ObjectDetectorOptions): Promise<void> {
     // Note that we have to support both JSPB and ProtobufJS, hence we
     // have to explicitly clear the values instead of setting them to
     // `undefined`.
@@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
       this.options.clearCategoryDenylistList();
     }

-    this.refreshGraph();
+    return this.applyOptions(options);
   }

   /**
@@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
   }

   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(DETECTIONS_STREAM);
@@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     objectDetector = new ObjectDetectorFake();
-    await objectDetector.setOptions({});  // Initialize graph
+    await objectDetector.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });

   it('initializes graph', async () => {