Switch MediaPipe Tasks Python and Java base layer to use MediaPipeBuiltinOpResolver by default.

PiperOrigin-RevId: 477927852
This commit is contained in:
Jiuqiang Tang 2022-09-30 08:32:42 +00:00 committed by Sebastian Schmidt
parent 3816951b8c
commit af2ad1abbe
5 changed files with 8 additions and 4 deletions

View File

@ -28,10 +28,10 @@ cc_library_with_tflite(
], ],
tflite_deps = [ tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_resources_cache",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
], ],
deps = [ deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
] + select({ ] + select({
"//conditions:default": ["//third_party/java/jdk:jni"], "//conditions:default": ["//third_party/java/jdk:jni"],
"//mediapipe:android": [], "//mediapipe:android": [],

View File

@ -34,10 +34,10 @@ cc_library_with_tflite(
}), }),
tflite_deps = [ tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_resources_cache",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
], ],
deps = [ deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
] + select({ ] + select({
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:android": [], "//mediapipe:android": [],

View File

@ -17,11 +17,13 @@
#include <utility> #include <utility>
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h" #include "tensorflow/lite/core/shims/cc/kernels/register.h"
namespace { namespace {
using ::mediapipe::tasks::core::kModelResourcesCacheService; using ::mediapipe::tasks::core::kModelResourcesCacheService;
using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver;
using ::mediapipe::tasks::core::ModelResourcesCache; using ::mediapipe::tasks::core::ModelResourcesCache;
using HandleType = std::shared_ptr<ModelResourcesCache>*; using HandleType = std::shared_ptr<ModelResourcesCache>*;
} // namespace } // namespace
@ -29,7 +31,7 @@ using HandleType = std::shared_ptr<ModelResourcesCache>*;
JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD( JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) { nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) {
auto ptr = std::make_shared<ModelResourcesCache>( auto ptr = std::make_shared<ModelResourcesCache>(
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); absl::make_unique<MediaPipeBuiltinOpResolver>());
HandleType handle = new std::shared_ptr<ModelResourcesCache>(std::move(ptr)); HandleType handle = new std::shared_ptr<ModelResourcesCache>(std::move(ptr));
return reinterpret_cast<jlong>(handle); return reinterpret_cast<jlong>(handle);
} }

View File

@ -27,6 +27,7 @@ pybind_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/python/pybind:util", "//mediapipe/python/pybind:util",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",

View File

@ -16,6 +16,7 @@
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/python/pybind/util.h" #include "mediapipe/python/pybind/util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h" #include "pybind11_protobuf/native_proto_caster.h"
@ -75,7 +76,7 @@ mode) or not (synchronous mode).)doc");
} }
auto task_runner = TaskRunner::Create( auto task_runner = TaskRunner::Create(
std::move(graph_config), std::move(graph_config),
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
std::move(callback)); std::move(callback));
RaisePyErrorIfNotOk(task_runner.status()); RaisePyErrorIfNotOk(task_runner.status());
return std::move(*task_runner); return std::move(*task_runner);