Merge branch 'master' into ios-task
This commit is contained in:
parent
7e0fec7c28
commit
7ce21038bb
3
.github/bot_config.yml
vendored
3
.github/bot_config.yml
vendored
|
@ -15,4 +15,5 @@
|
|||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
- sureshdagooglecom
|
||||
- kuaashish
|
||||
- ayushgdev
|
||||
|
|
|
@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) {
|
|||
}
|
||||
|
||||
// If options.convert_signature_to_tags() is set, will convert letters to
|
||||
// uppercase and replace /'s and -'s with _'s. This enables the standard
|
||||
// uppercase and replace /, -, . and :'s with _'s. This enables the standard
|
||||
// SavedModel classification, regression, and prediction signatures to be used
|
||||
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common
|
||||
// patterns.
|
||||
|
@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag(
|
|||
output.resize(name.length());
|
||||
std::transform(name.begin(), name.end(), output.begin(),
|
||||
[](unsigned char c) { return std::toupper(c); });
|
||||
output = absl::StrReplaceAll(output, {{"/", "_"}});
|
||||
output = absl::StrReplaceAll(output, {{"-", "_"}});
|
||||
output = absl::StrReplaceAll(output, {{".", "_"}});
|
||||
output = absl::StrReplaceAll(
|
||||
output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
|
||||
LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
|
||||
return output;
|
||||
} else {
|
||||
|
|
|
@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions {
|
|||
// The name of the generic signature to load into the mapping from tags to
|
||||
// tensor names.
|
||||
optional string signature_name = 2 [default = "serving_default"];
|
||||
// Whether to convert the signature keys to uppercase as well as switch /'s
|
||||
// and -'s to _'s, which enables common signatures to be used as Tags.
|
||||
// Whether to convert the signature keys to uppercase as well as switch
|
||||
// /, -, . and :'s to _'s, which enables common signatures to be used as Tags.
|
||||
optional bool convert_signature_to_tags = 3 [default = true];
|
||||
// If true, saved_model_path can have multiple exported models in
|
||||
// subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
|
||||
|
|
|
@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) {
|
|||
}
|
||||
|
||||
// If options.convert_signature_to_tags() is set, will convert letters to
|
||||
// uppercase and replace /'s and -'s with _'s. This enables the standard
|
||||
// uppercase and replace /, -, . and :'s with _'s. This enables the standard
|
||||
// SavedModel classification, regression, and prediction signatures to be used
|
||||
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common
|
||||
// patterns.
|
||||
|
@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag(
|
|||
output.resize(name.length());
|
||||
std::transform(name.begin(), name.end(), output.begin(),
|
||||
[](unsigned char c) { return std::toupper(c); });
|
||||
output = absl::StrReplaceAll(output, {{"/", "_"}});
|
||||
output = absl::StrReplaceAll(output, {{"-", "_"}});
|
||||
output = absl::StrReplaceAll(output, {{".", "_"}});
|
||||
output = absl::StrReplaceAll(
|
||||
output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
|
||||
LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
|
||||
return output;
|
||||
} else {
|
||||
|
|
|
@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions {
|
|||
// The name of the generic signature to load into the mapping from tags to
|
||||
// tensor names.
|
||||
optional string signature_name = 2 [default = "serving_default"];
|
||||
// Whether to convert the signature keys to uppercase as well as switch /'s
|
||||
// and -'s to _'s, which enables common signatures to be used as Tags.
|
||||
// Whether to convert the signature keys to uppercase, as well as switch /'s,
|
||||
// -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags.
|
||||
optional bool convert_signature_to_tags = 3 [default = true];
|
||||
// If true, saved_model_path can have multiple exported models in
|
||||
// subdirectories saved_model_path/%08d and the alphabetically last (i.e.,
|
||||
|
|
|
@ -30,6 +30,10 @@ proto_library(
|
|||
|
||||
java_lite_proto_library(
|
||||
name = "autoflip_messages_java_proto_lite",
|
||||
visibility = [
|
||||
"//java/com/google/android/apps/photos:__subpackages__",
|
||||
"//javatests/com/google/android/apps/photos:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":autoflip_messages_proto",
|
||||
],
|
||||
|
|
|
@ -455,7 +455,7 @@ cc_library(
|
|||
],
|
||||
}),
|
||||
deps = [
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"//mediapipe/framework:port",
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/formats/tensor_internal.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
|
@ -434,8 +434,9 @@ class Tensor {
|
|||
mutable bool use_ahwb_ = false;
|
||||
mutable uint64_t ahwb_tracking_key_ = 0;
|
||||
// TODO: Tracks all unique tensors. Can grow to a large number. LRU
|
||||
// can be more predicted.
|
||||
static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_;
|
||||
// (Least Recently Used) would be more predictable.
|
||||
// The value contains the size alignment parameter.
|
||||
static inline absl::flat_hash_map<uint64_t, int> ahwb_usage_track_;
|
||||
// Expects the target SSBO to be already bound.
|
||||
bool AllocateAhwbMapToSsbo() const;
|
||||
bool InsertAhwbToSsboFence() const;
|
||||
|
|
|
@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
|
|||
|
||||
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
|
||||
// Mark current tracking key as Ahwb-use.
|
||||
ahwb_usage_track_.insert(ahwb_tracking_key_);
|
||||
if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_);
|
||||
it != ahwb_usage_track_.end()) {
|
||||
size_alignment = it->second;
|
||||
} else if (ahwb_tracking_key_ != 0) {
|
||||
ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment});
|
||||
}
|
||||
use_ahwb_ = true;
|
||||
|
||||
if (__builtin_available(android 26, *)) {
|
||||
|
@ -458,7 +463,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
|
|||
ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
|
||||
}
|
||||
}
|
||||
use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
|
||||
// Keep flag value if it was set previously.
|
||||
use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);
|
||||
}
|
||||
|
||||
#else // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
|
|
@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
|
|||
};
|
||||
|
||||
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
{
|
||||
// Request Ahwb first to get Ahwb storage allocated internally.
|
||||
auto view = tensor.GetAHardwareBufferWriteView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetWritingFinishedFD(-1, [](bool) { return true; });
|
||||
}
|
||||
RunInGlContext([&tensor] {
|
||||
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
|
||||
auto ssbo_name = ssbo_view.name();
|
||||
|
@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
|
|||
}
|
||||
|
||||
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
|
||||
{
|
||||
// Request Ahwb first to get Ahwb storage allocated internally.
|
||||
auto view = tensor.GetAHardwareBufferWriteView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetReadingFinishedFunc([](bool) { return true; });
|
||||
}
|
||||
RunInGlContext([&tensor] {
|
||||
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
|
||||
auto ssbo_name = ssbo_view.name();
|
||||
|
@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
|
|||
TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
|
||||
// Request the CPU view to get the memory to be allocated.
|
||||
// Request Ahwb view then to transform the storage into Ahwb.
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
{
|
||||
|
@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
|
|||
TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
|
||||
// Request the GPU view to get the ssbo allocated internally.
|
||||
// Request Ahwb view then to transform the storage into Ahwb.
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
RunInGlContext([&tensor] {
|
||||
|
|
|
@ -1,34 +1,28 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/gpu/gpu_test_base.h"
|
||||
#include "testing/base/public/gmock.h"
|
||||
#include "testing/base/public/gunit.h"
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class TensorAhwbTest : public mediapipe::GpuTestBase {
|
||||
public:
|
||||
};
|
||||
|
||||
TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
|
||||
TEST(TensorAhwbTest, TestCpuThenAHWB) {
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
|
||||
{
|
||||
auto ptr = tensor.GetCpuWriteView().buffer<float>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
}
|
||||
{
|
||||
auto ahwb = tensor.GetAHardwareBufferReadView().handle();
|
||||
EXPECT_NE(ahwb, nullptr);
|
||||
auto view = tensor.GetAHardwareBufferReadView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetReadingFinishedFunc([](bool) { return true; });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
|
||||
TEST(TensorAhwbTest, TestAHWBThenCpu) {
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
|
||||
{
|
||||
auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
|
||||
EXPECT_NE(ahwb, nullptr);
|
||||
auto view = tensor.GetAHardwareBufferWriteView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetWritingFinishedFD(-1, [](bool) { return true; });
|
||||
}
|
||||
{
|
||||
auto ptr = tensor.GetCpuReadView().buffer<float>();
|
||||
|
@ -36,21 +30,71 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorAhwbTest, TestCpuThenGl) {
|
||||
RunInGlContext([] {
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
|
||||
TEST(TensorAhwbTest, TestAhwbAlignment) {
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5});
|
||||
{
|
||||
auto ptr = tensor.GetCpuWriteView().buffer<float>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
auto view = tensor.GetAHardwareBufferWriteView(16);
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_Desc desc;
|
||||
AHardwareBuffer_describe(view.handle(), &desc);
|
||||
// sizeof(float) * 5 = 20, the closest aligned to 16 size is 32.
|
||||
EXPECT_EQ(desc.width, 32);
|
||||
}
|
||||
view.SetWritingFinishedFD(-1, [](bool) { return true; });
|
||||
}
|
||||
}
|
||||
|
||||
// Tensor::GetCpuView uses source location mechanism that gives source file name
|
||||
// and line from where the method is called. The function is intended just to
|
||||
// have two calls providing the same source file name and line.
|
||||
auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); }
|
||||
|
||||
// The test checks the tracking mechanism: when a tensor's Cpu view is retrieved
|
||||
// for the first time then the source location is attached to the tensor. If the
|
||||
// Ahwb view is then requested from the tensor, the previously recorded Cpu
|
||||
// view request source location is marked for using Ahwb storage.
|
||||
// When a Cpu view with the same source location (but for the newly allocated
|
||||
// tensor) is requested and the location is marked to use Ahwb storage then the
|
||||
// Ahwb storage is allocated for the CpuView.
|
||||
TEST(TensorAhwbTest, TestTrackingAhwb) {
|
||||
// Create first tensor and request Cpu and then Ahwb view to mark the source
|
||||
// location for Ahwb storage.
|
||||
{
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
|
||||
{
|
||||
auto view = GetCpuView(tensor);
|
||||
EXPECT_NE(view.buffer<float>(), nullptr);
|
||||
}
|
||||
{
|
||||
auto ssbo = tensor.GetOpenGlBufferReadView().name();
|
||||
EXPECT_GT(ssbo, 0);
|
||||
// Align size of the Ahwb by multiple of 16.
|
||||
auto view = tensor.GetAHardwareBufferWriteView(16);
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetReadingFinishedFunc([](bool) { return true; });
|
||||
}
|
||||
}
|
||||
{
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9});
|
||||
{
|
||||
// The second tensor uses the same Cpu view source location so Ahwb
|
||||
// storage is allocated internally.
|
||||
auto view = GetCpuView(tensor);
|
||||
EXPECT_NE(view.buffer<float>(), nullptr);
|
||||
}
|
||||
{
|
||||
// Check the Ahwb size to be aligned to multiple of 16. The alignment is
|
||||
// stored by previous requesting of the Ahwb view.
|
||||
auto view = tensor.GetAHardwareBufferReadView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_Desc desc;
|
||||
AHardwareBuffer_describe(view.handle(), &desc);
|
||||
// sizeof(float) * 9 = 36. The closest aligned size is 48.
|
||||
EXPECT_EQ(desc.width, 48);
|
||||
}
|
||||
view.SetReadingFinishedFunc([](bool) { return true; });
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
|
|
@ -194,6 +194,7 @@ void GraphProfiler::Initialize(
|
|||
"Calculator \"$0\" has already been added.", node_name);
|
||||
}
|
||||
profile_builder_ = std::make_unique<GraphProfileBuilder>(this);
|
||||
graph_id_ = ++next_instance_id_;
|
||||
|
||||
is_initialized_ = true;
|
||||
}
|
||||
|
|
|
@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
|
|||
return validated_graph_;
|
||||
}
|
||||
|
||||
// Gets a numerical identifier for this GraphProfiler object.
|
||||
uint64_t GetGraphId() { return graph_id_; }
|
||||
|
||||
private:
|
||||
// This can be used to add packet info for the input streams to the graph.
|
||||
// It treats the stream defined by |stream_name| as a stream produced by a
|
||||
|
@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this<ProfilingContext> {
|
|||
class GraphProfileBuilder;
|
||||
std::unique_ptr<GraphProfileBuilder> profile_builder_;
|
||||
|
||||
// The globally incrementing identifier for all graphs in a process.
|
||||
static inline std::atomic_int next_instance_id_ = 0;
|
||||
|
||||
// A unique identifier for this object. Only unique within a process.
|
||||
uint64_t graph_id_;
|
||||
|
||||
// For testing.
|
||||
friend GraphProfilerTestPeer;
|
||||
};
|
||||
|
|
|
@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) {
|
|||
"Cannot initialize .* multiple times.");
|
||||
}
|
||||
|
||||
// Tests that graph identifiers are not reused, even after destruction.
|
||||
TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) {
|
||||
auto raw_graph_config = R"(
|
||||
profiler_config {
|
||||
enable_profiler: true
|
||||
}
|
||||
input_stream: "input_stream"
|
||||
node {
|
||||
calculator: "DummyTestCalculator"
|
||||
input_stream: "input_stream"
|
||||
})";
|
||||
const int n_iterations = 100;
|
||||
absl::flat_hash_set<int> seen_ids;
|
||||
for (int i = 0; i < n_iterations; ++i) {
|
||||
std::shared_ptr<ProfilingContext> profiler =
|
||||
std::make_shared<ProfilingContext>();
|
||||
auto graph_config = CreateGraphConfig(raw_graph_config);
|
||||
mediapipe::ValidatedGraphConfig validated_graph;
|
||||
QCHECK_OK(validated_graph.Initialize(graph_config));
|
||||
profiler->Initialize(validated_graph);
|
||||
|
||||
int id = profiler->GetGraphId();
|
||||
ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id)));
|
||||
seen_ids.insert(id);
|
||||
}
|
||||
}
|
||||
// Tests that Pause(), Resume(), and Reset() work.
|
||||
TEST_F(GraphProfilerTestPeer, PauseResumeReset) {
|
||||
InitializeProfilerWithGraphConfig(R"(
|
||||
|
|
|
@ -74,6 +74,9 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView(
|
|||
static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
|
||||
const GlTextureView& view) {
|
||||
CHECK(pixel_buffer);
|
||||
auto ctx = GlContext::GetCurrent().get();
|
||||
if (!ctx) ctx = view.gl_context();
|
||||
ctx->Run([pixel_buffer, &view, ctx] {
|
||||
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
|
||||
CHECK(err == kCVReturnSuccess)
|
||||
<< "CVPixelBufferLockBaseAddress failed: " << err;
|
||||
|
@ -82,34 +85,40 @@ static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
|
|||
uint8_t* pixel_ptr =
|
||||
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
|
||||
if (pixel_format == kCVPixelFormatType_32BGRA) {
|
||||
// TODO: restore previous framebuffer? Move this to helper so we
|
||||
// can use BindFramebuffer?
|
||||
glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx));
|
||||
glViewport(0, 0, view.width(), view.height());
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
|
||||
view.name(), 0);
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
view.target(), view.name(), 0);
|
||||
|
||||
size_t contiguous_bytes_per_row = view.width() * 4;
|
||||
if (bytes_per_row == contiguous_bytes_per_row) {
|
||||
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||
pixel_ptr);
|
||||
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
|
||||
GL_UNSIGNED_BYTE, pixel_ptr);
|
||||
} else {
|
||||
// TODO: use GL_PACK settings for row length. We can expect
|
||||
// GLES 3.0 on iOS now.
|
||||
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
|
||||
view.height());
|
||||
uint8_t* temp_ptr = contiguous_buffer.data();
|
||||
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
|
||||
temp_ptr);
|
||||
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA,
|
||||
GL_UNSIGNED_BYTE, temp_ptr);
|
||||
for (int i = 0; i < view.height(); ++i) {
|
||||
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
|
||||
temp_ptr += contiguous_bytes_per_row;
|
||||
pixel_ptr += bytes_per_row;
|
||||
}
|
||||
}
|
||||
// TODO: restore previous framebuffer?
|
||||
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
|
||||
view.target(), 0, 0);
|
||||
glBindFramebuffer(GL_FRAMEBUFFER, 0);
|
||||
} else {
|
||||
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
|
||||
}
|
||||
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
|
||||
CHECK(err == kCVReturnSuccess)
|
||||
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
|
||||
});
|
||||
}
|
||||
#endif // TARGET_IPHONE_SIMULATOR
|
||||
|
||||
|
|
|
@ -71,9 +71,12 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
|
||||
self.assertTrue(os.path.exists(output_metadata_file))
|
||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
filecmp.clear_cache()
|
||||
self.assertTrue(
|
||||
filecmp.cmp(output_metadata_file,
|
||||
self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
|
||||
filecmp.cmp(
|
||||
output_metadata_file,
|
||||
self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
|
||||
shallow=False))
|
||||
|
||||
def test_create_and_train_bert(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
|
|
|
@ -135,7 +135,10 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
self.assertTrue(os.path.exists(output_metadata_file))
|
||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||
filecmp.clear_cache()
|
||||
self.assertTrue(
|
||||
filecmp.cmp(
|
||||
output_metadata_file, expected_metadata_file, shallow=False))
|
||||
|
||||
def test_continual_training_by_loading_checkpoint(self):
|
||||
mock_stdout = io.StringIO()
|
||||
|
|
|
@ -230,16 +230,17 @@ if ([wrapper.delegate
|
|||
}
|
||||
|
||||
- (absl::Status)performStart {
|
||||
absl::Status status = _graph->Initialize(_config);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
absl::Status status;
|
||||
for (const auto& service_packet : _servicePackets) {
|
||||
status = _graph->SetServicePacket(*service_packet.first, service_packet.second);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
status = _graph->Initialize(_config);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
status = _graph->StartRun(_inputSidePackets, _streamHeaders);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
|
|
|
@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
|
|||
auto custom_gestures_classifier_options_proto =
|
||||
std::make_unique<components::processors::proto::ClassifierOptions>(
|
||||
components::processors::ConvertClassifierOptionsToProto(
|
||||
&(options->canned_gestures_classifier_options)));
|
||||
&(options->custom_gestures_classifier_options)));
|
||||
hand_gesture_recognizer_graph_options
|
||||
->mutable_custom_gesture_classifier_graph_options()
|
||||
->mutable_classifier_options()
|
||||
->Swap(canned_gestures_classifier_options_proto.get());
|
||||
->Swap(custom_gestures_classifier_options_proto.get());
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,4 +38,3 @@ objc_library(
|
|||
"-std=c++17",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
@ -56,6 +57,7 @@ extern NSString *const MPPTasksErrorDomain;
|
|||
* @param status absl::Status.
|
||||
* @param error Pointer to the memory location where the created error should be saved. If `nil`,
|
||||
* no error will be saved.
|
||||
* @return YES when there is no error, NO otherwise.
|
||||
*/
|
||||
+ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error;
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "absl/strings/cord.h" // from @com_google_absl
|
||||
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
|
||||
/** Error domain of MediaPipe task library errors. */
|
||||
|
@ -96,8 +95,8 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
|
|||
// appropriate MPPTasksErrorCode in default cases. Note:
|
||||
// The mapping to absl::Status::code() is done to generate a more specific error code than
|
||||
// MPPTasksErrorCodeError in cases when the payload can't be mapped to
|
||||
// MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn returned
|
||||
// without modification by Mediapipe cc library methods.
|
||||
// MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
|
||||
// returned without modification by Mediapipe cc library methods.
|
||||
if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) {
|
||||
switch (status.code()) {
|
||||
case absl::StatusCode::kInternal:
|
||||
|
|
|
@ -13,13 +13,14 @@
|
|||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface NSString (Helpers)
|
||||
|
||||
@property(readonly) std::string cppString;
|
||||
@property(readonly, nonatomic) std::string cppString;
|
||||
|
||||
+ (NSString *)stringWithCppString:(std::string)text;
|
||||
|
||||
|
|
|
@ -21,4 +21,3 @@ objc_library(
|
|||
srcs = ["sources/MPPClassifierOptions.m"],
|
||||
hdrs = ["sources/MPPClassifierOptions.h"],
|
||||
)
|
||||
|
||||
|
|
|
@ -22,29 +22,34 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
NS_SWIFT_NAME(ClassifierOptions)
|
||||
@interface MPPClassifierOptions : NSObject <NSCopying>
|
||||
|
||||
/** The locale to use for display names specified through the TFLite Model
|
||||
/**
|
||||
* The locale to use for display names specified through the TFLite Model
|
||||
* Metadata, if any. Defaults to English.
|
||||
*/
|
||||
@property(nonatomic, copy) NSString *displayNamesLocale;
|
||||
|
||||
/** The maximum number of top-scored classification results to return. If < 0,
|
||||
/**
|
||||
* The maximum number of top-scored classification results to return. If < 0,
|
||||
* all available results will be returned. If 0, an invalid argument error is
|
||||
* returned.
|
||||
*/
|
||||
@property(nonatomic) NSInteger maxResults;
|
||||
|
||||
/** Score threshold to override the one provided in the model metadata (if any).
|
||||
/**
|
||||
* Score threshold to override the one provided in the model metadata (if any).
|
||||
* Results below this value are rejected.
|
||||
*/
|
||||
@property(nonatomic) float scoreThreshold;
|
||||
|
||||
/** The allowlist of category names. If non-empty, detection results whose
|
||||
/**
|
||||
* The allowlist of category names. If non-empty, detection results whose
|
||||
* category name is not in this set will be filtered out. Duplicate or unknown
|
||||
* category names are ignored. Mutually exclusive with categoryDenylist.
|
||||
*/
|
||||
@property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
|
||||
|
||||
/** The denylist of category names. If non-empty, detection results whose
|
||||
/**
|
||||
* The denylist of category names. If non-empty, detection results whose
|
||||
* category name is in this set will be filtered out. Duplicate or unknown
|
||||
* category names are ignored. Mutually exclusive with categoryAllowlist.
|
||||
*/
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
- (instancetype)init {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
self.maxResults = -1;
|
||||
self.scoreThreshold = 0;
|
||||
_maxResults = -1;
|
||||
_scoreThreshold = 0;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
|
|
@ -22,8 +22,7 @@ objc_library(
|
|||
hdrs = ["sources/MPPClassifierOptions+Helpers.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
]
|
||||
"//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
|
|
@ -29,7 +29,6 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
|
|||
}
|
||||
|
||||
classifierOptionsProto->set_max_results((int)self.maxResults);
|
||||
|
||||
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
|
||||
|
||||
for (NSString *category in self.categoryAllowlist) {
|
||||
|
|
|
@ -54,14 +54,14 @@ objc_library(
|
|||
"-std=c++17",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
|
||||
":MPPTaskOptions",
|
||||
":MPPTaskOptionsProtocol",
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -83,9 +83,13 @@ objc_library(
|
|||
name = "MPPTaskRunner",
|
||||
srcs = ["sources/MPPTaskRunner.mm"],
|
||||
hdrs = ["sources/MPPTaskRunner.h"],
|
||||
copts = [
|
||||
"-ObjC++",
|
||||
"-std=c++17",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -13,7 +13,9 @@
|
|||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
|
||||
|
||||
|
@ -59,7 +61,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
/**
|
||||
* Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
|
||||
*/
|
||||
- (mediapipe::CalculatorGraphConfig)generateGraphConfig;
|
||||
- (::mediapipe::CalculatorGraphConfig)generateGraphConfig;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
namespace {
|
||||
using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
|
||||
using Node = ::mediapipe::CalculatorGraphConfig::Node;
|
||||
using ::mediapipe::CalculatorOptions;
|
||||
using ::mediapipe::FlowLimiterCalculatorOptions;
|
||||
using ::mediapipe::InputStreamInfo;
|
||||
} // namespace
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include "mediapipe/framework/calculator_options.pb.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
@ -25,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
/**
|
||||
* Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
|
||||
*/
|
||||
- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;
|
||||
- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
/**
|
||||
* This class is used to create and call appropriate methods on the C++ Task Runner.
|
||||
*/
|
||||
|
||||
@interface MPPTaskRunner : NSObject
|
||||
|
||||
/**
|
||||
|
@ -35,11 +34,10 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
|
||||
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)
|
||||
process:(const mediapipe::tasks::core::PacketMap &)packetMap
|
||||
error:(NSError **)error;
|
||||
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
|
||||
(const mediapipe::tasks::core::PacketMap &)packetMap;
|
||||
|
||||
- (void)close;
|
||||
- (absl::Status)close;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
namespace {
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Packet;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
||||
} // namespace
|
||||
|
@ -49,8 +48,8 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
|||
return _cppTaskRunner->Process(packetMap);
|
||||
}
|
||||
|
||||
- (void)close {
|
||||
_cppTaskRunner->Close();
|
||||
- (absl::Status)close {
|
||||
return _cppTaskRunner->Close();
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
|||
*
|
||||
* @param options The options for the audio classifier.
|
||||
*/
|
||||
override async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||
options, this.options.getClassifierOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(AUDIO_STREAM);
|
||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||
|
|
|
@ -79,7 +79,8 @@ describe('AudioClassifier', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
audioClassifier = new AudioClassifierFake();
|
||||
await audioClassifier.setOptions({}); // Initialize graph
|
||||
await audioClassifier.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
|||
*
|
||||
* @param options The options for the audio embedder.
|
||||
*/
|
||||
override async setOptions(options: AudioEmbedderOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: AudioEmbedderOptions): Promise<void> {
|
||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||
options, this.options.getEmbedderOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(AUDIO_STREAM);
|
||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||
|
|
|
@ -70,7 +70,8 @@ describe('AudioEmbedder', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
audioEmbedder = new AudioEmbedderFake();
|
||||
await audioEmbedder.setOptions({}); // Initialize graph
|
||||
await audioEmbedder.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', () => {
|
||||
|
|
|
@ -103,29 +103,3 @@ jasmine_node_test(
|
|||
name = "embedder_options_test",
|
||||
deps = [":embedder_options_test_lib"],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "base_options",
|
||||
srcs = [
|
||||
"base_options.ts",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "base_options_test_lib",
|
||||
testonly = True,
|
||||
srcs = ["base_options.test.ts"],
|
||||
deps = [":base_options"],
|
||||
)
|
||||
|
||||
jasmine_node_test(
|
||||
name = "base_options_test",
|
||||
deps = [":base_options_test_lib"],
|
||||
)
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
// Placeholder for internal dependency on trusted resource URL builder
|
||||
|
||||
import {convertBaseOptionsToProto} from './base_options';
|
||||
|
||||
describe('convertBaseOptionsToProto()', () => {
|
||||
const mockBytes = new Uint8Array([0, 1, 2, 3]);
|
||||
const mockBytesResult = {
|
||||
modelAsset: {
|
||||
fileContent: Buffer.from(mockBytes).toString('base64'),
|
||||
fileName: undefined,
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined,
|
||||
},
|
||||
useStreamMode: false,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: undefined,
|
||||
tflite: {},
|
||||
},
|
||||
};
|
||||
|
||||
let fetchSpy: jasmine.Spy;
|
||||
|
||||
beforeEach(() => {
|
||||
fetchSpy = jasmine.createSpy().and.callFake(async url => {
|
||||
expect(url).toEqual('foo');
|
||||
return {
|
||||
arrayBuffer: () => mockBytes.buffer,
|
||||
} as unknown as Response;
|
||||
});
|
||||
global.fetch = fetchSpy;
|
||||
});
|
||||
|
||||
it('verifies that at least one model asset option is provided', async () => {
|
||||
await expectAsync(convertBaseOptionsToProto({}))
|
||||
.toBeRejectedWithError(
|
||||
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
|
||||
});
|
||||
|
||||
it('verifies that no more than one model asset option is provided', async () => {
|
||||
await expectAsync(convertBaseOptionsToProto({
|
||||
modelAssetPath: `foo`,
|
||||
modelAssetBuffer: new Uint8Array([])
|
||||
}))
|
||||
.toBeRejectedWithError(
|
||||
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
|
||||
});
|
||||
|
||||
it('downloads model', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetPath: `foo`,
|
||||
});
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalled();
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('does not download model when bytes are provided', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
});
|
||||
|
||||
expect(fetchSpy).not.toHaveBeenCalled();
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('can enable CPU delegate', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'CPU',
|
||||
});
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('can enable GPU delegate', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
});
|
||||
expect(baseOptionsProto.toObject()).toEqual({
|
||||
...mockBytesResult,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: {
|
||||
useAdvancedGpuApi: false,
|
||||
api: 0,
|
||||
allowPrecisionLoss: true,
|
||||
cachedKernelPath: undefined,
|
||||
serializedModelDir: undefined,
|
||||
modelToken: undefined,
|
||||
usage: 2,
|
||||
},
|
||||
tflite: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('can reset delegate', async () => {
|
||||
let baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
});
|
||||
// Clear backend
|
||||
baseOptionsProto =
|
||||
await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto);
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
});
|
|
@ -1,80 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
|
||||
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
||||
import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/**
|
||||
* Converts a BaseOptions API object to its Protobuf representation.
|
||||
* @throws If neither a model asset path nor buffer is provided
|
||||
*/
|
||||
export async function convertBaseOptionsToProto(
|
||||
updatedOptions: BaseOptions,
|
||||
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
|
||||
const result =
|
||||
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
|
||||
|
||||
await configureExternalFile(updatedOptions, result);
|
||||
configureAcceleration(updatedOptions, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the `externalFile` option and validates that a single model is
|
||||
* provided.
|
||||
*/
|
||||
async function configureExternalFile(
|
||||
options: BaseOptions, proto: BaseOptionsProto) {
|
||||
const externalFile = proto.getModelAsset() || new ExternalFile();
|
||||
proto.setModelAsset(externalFile);
|
||||
|
||||
if (options.modelAssetPath || options.modelAssetBuffer) {
|
||||
if (options.modelAssetPath && options.modelAssetBuffer) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
}
|
||||
|
||||
let modelAssetBuffer = options.modelAssetBuffer;
|
||||
if (!modelAssetBuffer) {
|
||||
const response = await fetch(options.modelAssetPath!.toString());
|
||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
||||
}
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
}
|
||||
|
||||
if (!externalFile.hasFileContent()) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
}
|
||||
}
|
||||
|
||||
/** Configures the `acceleration` option. */
|
||||
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
|
||||
const acceleration = proto.getAcceleration() ?? new Acceleration();
|
||||
if (options.delegate === 'GPU') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
} else {
|
||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
}
|
||||
proto.setAcceleration(acceleration);
|
||||
}
|
|
@ -18,8 +18,10 @@ mediapipe_ts_library(
|
|||
srcs = ["task_runner.ts"],
|
||||
deps = [
|
||||
":core",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/processors:base_options",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
|
||||
|
@ -53,6 +55,7 @@ mediapipe_ts_library(
|
|||
"task_runner_test.ts",
|
||||
],
|
||||
deps = [
|
||||
":core",
|
||||
":task_runner",
|
||||
":task_runner_test_utils",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
|
|
|
@ -14,9 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
|
||||
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
||||
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
|
||||
import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
||||
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
|
||||
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
|
||||
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
||||
|
@ -91,14 +93,52 @@ export abstract class TaskRunner {
|
|||
this.graphRunner.registerModelResourcesGraphService();
|
||||
}
|
||||
|
||||
/** Configures the shared options of a MediaPipe Task. */
|
||||
async setOptions(options: TaskRunnerOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
this.baseOptions = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.baseOptions);
|
||||
/** Configures the task with custom options. */
|
||||
abstract setOptions(options: TaskRunnerOptions): Promise<void>;
|
||||
|
||||
/**
|
||||
* Applies the current set of options, including any base options that have
|
||||
* not been processed by the task implementation. The options are applied
|
||||
* synchronously unless a `modelAssetPath` is provided. This ensures that
|
||||
* for most use cases options are applied directly and immediately affect
|
||||
* the next inference.
|
||||
*/
|
||||
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
|
||||
const baseOptions: BaseOptions = options.baseOptions || {};
|
||||
|
||||
// Validate that exactly one model is configured
|
||||
if (options.baseOptions?.modelAssetBuffer &&
|
||||
options.baseOptions?.modelAssetPath) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
||||
options.baseOptions?.modelAssetBuffer ||
|
||||
options.baseOptions?.modelAssetPath)) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
}
|
||||
|
||||
this.setAcceleration(baseOptions);
|
||||
if (baseOptions.modelAssetPath) {
|
||||
// We don't use `await` here since we want to apply most settings
|
||||
// synchronously.
|
||||
return fetch(baseOptions.modelAssetPath.toString())
|
||||
.then(response => response.arrayBuffer())
|
||||
.then(buffer => {
|
||||
this.setExternalFile(new Uint8Array(buffer));
|
||||
this.refreshGraph();
|
||||
});
|
||||
} else {
|
||||
// Apply the setting synchronously.
|
||||
this.setExternalFile(baseOptions.modelAssetBuffer);
|
||||
this.refreshGraph();
|
||||
return Promise.resolve();
|
||||
}
|
||||
}
|
||||
|
||||
/** Applies the current options to the MediaPipe graph. */
|
||||
protected abstract refreshGraph(): void;
|
||||
|
||||
/**
|
||||
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
||||
* over the video stream. Will replace the previously running MediaPipe graph,
|
||||
|
@ -140,6 +180,27 @@ export abstract class TaskRunner {
|
|||
}
|
||||
this.processingErrors = [];
|
||||
}
|
||||
|
||||
/** Configures the `externalFile` option */
|
||||
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
|
||||
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
|
||||
if (modelAssetBuffer) {
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
}
|
||||
this.baseOptions.setModelAsset(externalFile);
|
||||
}
|
||||
|
||||
/** Configures the `acceleration` option. */
|
||||
private setAcceleration(options: BaseOptions) {
|
||||
const acceleration =
|
||||
this.baseOptions.getAcceleration() ?? new Acceleration();
|
||||
if (options.delegate === 'GPU') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
} else {
|
||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
}
|
||||
this.baseOptions.setAcceleration(acceleration);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -15,18 +15,22 @@
|
|||
*/
|
||||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
||||
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
|
||||
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource URL builder
|
||||
|
||||
import {GraphRunnerImageLib} from './task_runner';
|
||||
import {TaskRunnerOptions} from './task_runner_options.d';
|
||||
|
||||
class TaskRunnerFake extends TaskRunner {
|
||||
protected baseOptions = new BaseOptionsProto();
|
||||
private errorListener: ErrorListener|undefined;
|
||||
private errors: string[] = [];
|
||||
|
||||
baseOptions = new BaseOptionsProto();
|
||||
|
||||
static createFake(): TaskRunnerFake {
|
||||
const wasmModule = createSpyWasmModule();
|
||||
return new TaskRunnerFake(wasmModule);
|
||||
|
@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
|
|||
super.finishProcessing();
|
||||
}
|
||||
|
||||
override refreshGraph(): void {}
|
||||
|
||||
override setGraph(graphData: Uint8Array, isBinary: boolean): void {
|
||||
super.setGraph(graphData, isBinary);
|
||||
}
|
||||
|
||||
setOptions(options: TaskRunnerOptions): Promise<void> {
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
private throwErrors(): void {
|
||||
expect(this.errorListener).toBeDefined();
|
||||
for (const error of this.errors) {
|
||||
|
@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner {
|
|||
}
|
||||
|
||||
describe('TaskRunner', () => {
|
||||
const mockBytes = new Uint8Array([0, 1, 2, 3]);
|
||||
const mockBytesResult = {
|
||||
modelAsset: {
|
||||
fileContent: Buffer.from(mockBytes).toString('base64'),
|
||||
fileName: undefined,
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined,
|
||||
},
|
||||
useStreamMode: false,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: undefined,
|
||||
tflite: {},
|
||||
},
|
||||
};
|
||||
|
||||
let fetchSpy: jasmine.Spy;
|
||||
let taskRunner: TaskRunnerFake;
|
||||
|
||||
beforeEach(() => {
|
||||
fetchSpy = jasmine.createSpy().and.callFake(async url => {
|
||||
expect(url).toEqual('foo');
|
||||
return {
|
||||
arrayBuffer: () => mockBytes.buffer,
|
||||
} as unknown as Response;
|
||||
});
|
||||
global.fetch = fetchSpy;
|
||||
|
||||
taskRunner = TaskRunnerFake.createFake();
|
||||
});
|
||||
|
||||
it('handles errors during graph update', () => {
|
||||
const taskRunner = TaskRunnerFake.createFake();
|
||||
taskRunner.enqueueError('Test error');
|
||||
|
||||
expect(() => {
|
||||
|
@ -85,7 +125,6 @@ describe('TaskRunner', () => {
|
|||
});
|
||||
|
||||
it('handles errors during graph execution', () => {
|
||||
const taskRunner = TaskRunnerFake.createFake();
|
||||
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
||||
|
||||
taskRunner.enqueueError('Test error');
|
||||
|
@ -96,7 +135,6 @@ describe('TaskRunner', () => {
|
|||
});
|
||||
|
||||
it('can handle multiple errors', () => {
|
||||
const taskRunner = TaskRunnerFake.createFake();
|
||||
taskRunner.enqueueError('Test error 1');
|
||||
taskRunner.enqueueError('Test error 2');
|
||||
|
||||
|
@ -104,4 +142,106 @@ describe('TaskRunner', () => {
|
|||
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
||||
}).toThrowError(/Test error 1, Test error 2/);
|
||||
});
|
||||
|
||||
it('verifies that at least one model asset option is provided', () => {
|
||||
expect(() => {
|
||||
taskRunner.setOptions({});
|
||||
})
|
||||
.toThrowError(
|
||||
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
|
||||
});
|
||||
|
||||
it('verifies that no more than one model asset option is provided', () => {
|
||||
expect(() => {
|
||||
taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetPath: `foo`,
|
||||
modelAssetBuffer: new Uint8Array([])
|
||||
}
|
||||
});
|
||||
})
|
||||
.toThrowError(
|
||||
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
|
||||
});
|
||||
|
||||
it('doesn\'t require model once it is configured', async () => {
|
||||
await taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||
expect(() => {
|
||||
taskRunner.setOptions({});
|
||||
}).not.toThrowError();
|
||||
});
|
||||
|
||||
it('downloads model', async () => {
|
||||
await taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetPath: `foo`}});
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalled();
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('does not download model when bytes are provided', async () => {
|
||||
await taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||
|
||||
expect(fetchSpy).not.toHaveBeenCalled();
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('changes model synchronously when bytes are provided', () => {
|
||||
const resolvedPromise = taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||
|
||||
// Check that the change has been applied even though we do not await the
|
||||
// above Promise
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
return resolvedPromise;
|
||||
});
|
||||
|
||||
it('can enable CPU delegate', async () => {
|
||||
await taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'CPU',
|
||||
}
|
||||
});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('can enable GPU delegate', async () => {
|
||||
await taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
}
|
||||
});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual({
|
||||
...mockBytesResult,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: {
|
||||
useAdvancedGpuApi: false,
|
||||
api: 0,
|
||||
allowPrecisionLoss: true,
|
||||
cachedKernelPath: undefined,
|
||||
serializedModelDir: undefined,
|
||||
modelToken: undefined,
|
||||
usage: 2,
|
||||
},
|
||||
tflite: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('can reset delegate', async () => {
|
||||
await taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
}
|
||||
});
|
||||
// Clear backend
|
||||
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule {
|
|||
* Sets up our equality testing to use a custom float equality checking function
|
||||
* to avoid incorrect test results due to minor floating point inaccuracies.
|
||||
*/
|
||||
export function addJasmineCustomFloatEqualityTester() {
|
||||
export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) {
|
||||
jasmine.addCustomEqualityTester((a, b) => { // Custom float equality
|
||||
if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
|
||||
return Math.abs(a - b) < 5e-8;
|
||||
return Math.abs(a - b) < tolerance;
|
||||
}
|
||||
return;
|
||||
});
|
||||
|
|
|
@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
|
|||
*
|
||||
* @param options The options for the text classifier.
|
||||
*/
|
||||
override async setOptions(options: TextClassifierOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: TextClassifierOptions): Promise<void> {
|
||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||
options, this.options.getClassifierOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
|
@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||
|
|
|
@ -56,7 +56,8 @@ describe('TextClassifier', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
textClassifier = new TextClassifierFake();
|
||||
await textClassifier.setOptions({}); // Initialize graph
|
||||
await textClassifier.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
|
|||
*
|
||||
* @param options The options for the text embedder.
|
||||
*/
|
||||
override async setOptions(options: TextEmbedderOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: TextEmbedderOptions): Promise<void> {
|
||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||
options, this.options.getEmbedderOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
|
@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
||||
|
|
|
@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
textEmbedder = new TextEmbedderFake();
|
||||
await textEmbedder.setOptions({}); // Initialize graph
|
||||
await textEmbedder.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -29,6 +29,7 @@ mediapipe_ts_library(
|
|||
testonly = True,
|
||||
srcs = ["vision_task_runner.test.ts"],
|
||||
deps = [
|
||||
":vision_task_options",
|
||||
":vision_task_runner",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||
|
|
|
@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
|
|||
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||
|
||||
import {VisionTaskOptions} from './vision_task_options';
|
||||
import {VisionTaskRunner} from './vision_task_runner';
|
||||
|
||||
class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
||||
|
@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
|||
|
||||
protected override process(): void {}
|
||||
|
||||
protected override refreshGraph(): void {}
|
||||
|
||||
override setOptions(options: VisionTaskOptions): Promise<void> {
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
override processImageData(image: ImageSource): void {
|
||||
super.processImageData(image);
|
||||
}
|
||||
|
@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
|||
}
|
||||
|
||||
describe('VisionTaskRunner', () => {
|
||||
const streamMode = {
|
||||
modelAsset: undefined,
|
||||
useStreamMode: true,
|
||||
acceleration: undefined,
|
||||
};
|
||||
|
||||
const imageMode = {
|
||||
modelAsset: undefined,
|
||||
useStreamMode: false,
|
||||
acceleration: undefined,
|
||||
};
|
||||
|
||||
let visionTaskRunner: VisionTaskRunnerFake;
|
||||
|
||||
beforeEach(() => {
|
||||
beforeEach(async () => {
|
||||
visionTaskRunner = new VisionTaskRunnerFake();
|
||||
await visionTaskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('can enable image mode', async () => {
|
||||
await visionTaskRunner.setOptions({runningMode: 'image'});
|
||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
|
||||
expect(visionTaskRunner.baseOptions.toObject())
|
||||
.toEqual(jasmine.objectContaining({useStreamMode: false}));
|
||||
});
|
||||
|
||||
it('can enable video mode', async () => {
|
||||
await visionTaskRunner.setOptions({runningMode: 'video'});
|
||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
|
||||
expect(visionTaskRunner.baseOptions.toObject())
|
||||
.toEqual(jasmine.objectContaining({useStreamMode: true}));
|
||||
});
|
||||
|
||||
it('can clear running mode', async () => {
|
||||
|
@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {
|
|||
|
||||
// Clear running mode
|
||||
await visionTaskRunner.setOptions({runningMode: undefined});
|
||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
|
||||
expect(visionTaskRunner.baseOptions.toObject())
|
||||
.toEqual(jasmine.objectContaining({useStreamMode: false}));
|
||||
});
|
||||
|
||||
it('cannot process images with video mode', async () => {
|
||||
|
|
|
@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
|
|||
/** Base class for all MediaPipe Vision Tasks. */
|
||||
export abstract class VisionTaskRunner<T> extends TaskRunner {
|
||||
/** Configures the shared options of a vision task. */
|
||||
override async setOptions(options: VisionTaskOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override applyOptions(options: VisionTaskOptions): Promise<void> {
|
||||
if ('runningMode' in options) {
|
||||
const useStreamMode =
|
||||
!!options.runningMode && options.runningMode !== 'image';
|
||||
this.baseOptions.setUseStreamMode(useStreamMode);
|
||||
}
|
||||
return super.applyOptions(options);
|
||||
}
|
||||
|
||||
/** Sends an image packet to the graph and awaits results. */
|
||||
|
|
|
@ -169,9 +169,7 @@ export class GestureRecognizer extends
|
|||
*
|
||||
* @param options The options for the gesture recognizer.
|
||||
*/
|
||||
override async setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
|
||||
override setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||
if ('numHands' in options) {
|
||||
this.handDetectorGraphOptions.setNumHands(
|
||||
options.numHands ?? DEFAULT_NUM_HANDS);
|
||||
|
@ -221,7 +219,7 @@ export class GestureRecognizer extends
|
|||
?.clearClassifierOptions();
|
||||
}
|
||||
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -265,6 +263,15 @@ export class GestureRecognizer extends
|
|||
NORM_RECT_STREAM, timestamp);
|
||||
this.finishProcessing();
|
||||
|
||||
if (this.gestures.length === 0) {
|
||||
// If no gestures are detected in the image, just return an empty list
|
||||
return {
|
||||
gestures: [],
|
||||
landmarks: [],
|
||||
worldLandmarks: [],
|
||||
handednesses: [],
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
gestures: this.gestures,
|
||||
landmarks: this.landmarks,
|
||||
|
@ -272,6 +279,7 @@ export class GestureRecognizer extends
|
|||
handednesses: this.handednesses
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/** Sets the default values for the graph. */
|
||||
private initDefaults(): void {
|
||||
|
@ -285,15 +293,19 @@ export class GestureRecognizer extends
|
|||
}
|
||||
|
||||
/** Converts the proto data to a Category[][] structure. */
|
||||
private toJsCategories(data: Uint8Array[]): Category[][] {
|
||||
private toJsCategories(data: Uint8Array[], populateIndex = true):
|
||||
Category[][] {
|
||||
const result: Category[][] = [];
|
||||
for (const binaryProto of data) {
|
||||
const inputList = ClassificationList.deserializeBinary(binaryProto);
|
||||
const outputList: Category[] = [];
|
||||
for (const classification of inputList.getClassificationList()) {
|
||||
const index = populateIndex && classification.hasIndex() ?
|
||||
classification.getIndex()! :
|
||||
DEFAULT_CATEGORY_INDEX;
|
||||
outputList.push({
|
||||
score: classification.getScore() ?? 0,
|
||||
index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
|
||||
index,
|
||||
categoryName: classification.getLabel() ?? '',
|
||||
displayName: classification.getDisplayName() ?? '',
|
||||
});
|
||||
|
@ -342,7 +354,7 @@ export class GestureRecognizer extends
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||
|
@ -377,7 +389,10 @@ export class GestureRecognizer extends
|
|||
});
|
||||
this.graphRunner.attachProtoVectorListener(
|
||||
HAND_GESTURES_STREAM, binaryProto => {
|
||||
this.gestures.push(...this.toJsCategories(binaryProto));
|
||||
// Gesture index is not used, because the final gesture result comes
|
||||
// from multiple classifiers.
|
||||
this.gestures.push(
|
||||
...this.toJsCategories(binaryProto, /* populateIndex= */ false));
|
||||
});
|
||||
this.graphRunner.attachProtoVectorListener(
|
||||
HANDEDNESS_STREAM, binaryProto => {
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
import {Category} from '../../../../tasks/web/components/containers/category';
|
||||
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
|
||||
|
||||
export {Category, Landmark, NormalizedLandmark};
|
||||
|
||||
/**
|
||||
* Represents the gesture recognition results generated by `GestureRecognizer`.
|
||||
*/
|
||||
|
@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult {
|
|||
/** Handedness of detected hands. */
|
||||
handednesses: Category[][];
|
||||
|
||||
/** Recognized hand gestures of detected hands */
|
||||
/**
|
||||
* Recognized hand gestures of detected hands. Note that the index of the
|
||||
* gesture is always -1, because the raw indices from multiple gesture
|
||||
* classifiers cannot consolidate to a meaningful index.
|
||||
*/
|
||||
gestures: Category[][];
|
||||
}
|
||||
|
|
|
@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
gestureRecognizer = new GestureRecognizerFake();
|
||||
await gestureRecognizer.setOptions({}); // Initialize graph
|
||||
await gestureRecognizer.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
@ -271,7 +272,7 @@ describe('GestureRecognizer', () => {
|
|||
expect(gestures).toEqual({
|
||||
'gestures': [[{
|
||||
'score': 0.2,
|
||||
'index': 2,
|
||||
'index': -1,
|
||||
'categoryName': 'gesture_label',
|
||||
'displayName': 'gesture_display_name'
|
||||
}]],
|
||||
|
@ -304,4 +305,25 @@ describe('GestureRecognizer', () => {
|
|||
// gestures.
|
||||
expect(gestures2).toEqual(gestures1);
|
||||
});
|
||||
|
||||
it('returns empty results when no gestures are detected', async () => {
|
||||
// Pass the test data to our listener
|
||||
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(gestureRecognizer);
|
||||
gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
|
||||
gestureRecognizer.listeners.get('world_hand_landmarks')!
|
||||
(createWorldLandmarks());
|
||||
gestureRecognizer.listeners.get('handedness')!(createHandednesses());
|
||||
gestureRecognizer.listeners.get('hand_gestures')!([]);
|
||||
});
|
||||
|
||||
// Invoke the gesture recognizer
|
||||
const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
|
||||
expect(gestures).toEqual({
|
||||
'gestures': [],
|
||||
'landmarks': [],
|
||||
'worldLandmarks': [],
|
||||
'handednesses': []
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
|||
*
|
||||
* @param options The options for the hand landmarker.
|
||||
*/
|
||||
override async setOptions(options: HandLandmarkerOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
|
||||
override setOptions(options: HandLandmarkerOptions): Promise<void> {
|
||||
// Configure hand detector options.
|
||||
if ('numHands' in options) {
|
||||
this.handDetectorGraphOptions.setNumHands(
|
||||
|
@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
|||
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
|
||||
}
|
||||
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
import {Category} from '../../../../tasks/web/components/containers/category';
|
||||
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
|
||||
|
||||
export {Landmark, NormalizedLandmark, Category};
|
||||
|
||||
/**
|
||||
* Represents the hand landmarks detection results generated by `HandLandmarker`.
|
||||
*/
|
||||
|
|
|
@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
handLandmarker = new HandLandmarkerFake();
|
||||
await handLandmarker.setOptions({}); // Initialize graph
|
||||
await handLandmarker.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
|||
*
|
||||
* @param options The options for the image classifier.
|
||||
*/
|
||||
override async setOptions(options: ImageClassifierOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: ImageClassifierOptions): Promise<void> {
|
||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||
options, this.options.getClassifierOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||
|
|
|
@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
imageClassifier = new ImageClassifierFake();
|
||||
await imageClassifier.setOptions({}); // Initialize graph
|
||||
await imageClassifier.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
|||
*
|
||||
* @param options The options for the image embedder.
|
||||
*/
|
||||
override async setOptions(options: ImageEmbedderOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: ImageEmbedderOptions): Promise<void> {
|
||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||
options, this.options.getEmbedderOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
||||
|
|
|
@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
imageEmbedder = new ImageEmbedderFake();
|
||||
await imageEmbedder.setOptions({}); // Initialize graph
|
||||
await imageEmbedder.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
|||
*
|
||||
* @param options The options for the object detector.
|
||||
*/
|
||||
override async setOptions(options: ObjectDetectorOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
|
||||
override setOptions(options: ObjectDetectorOptions): Promise<void> {
|
||||
// Note that we have to support both JSPB and ProtobufJS, hence we
|
||||
// have to explicitly clear the values instead of setting them to
|
||||
// `undefined`.
|
||||
|
@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
|||
this.options.clearCategoryDenylistList();
|
||||
}
|
||||
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(DETECTIONS_STREAM);
|
||||
|
|
|
@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
objectDetector = new ObjectDetectorFake();
|
||||
await objectDetector.setOptions({}); // Initialize graph
|
||||
await objectDetector.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
Loading…
Reference in New Issue
Block a user