Internal change
PiperOrigin-RevId: 524345939
This commit is contained in:
parent
257fa01b68
commit
54d208aa5c
|
@ -401,8 +401,8 @@ cc_library_with_tflite(
|
||||||
hdrs = ["inference_calculator.h"],
|
hdrs = ["inference_calculator.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"//mediapipe/util/tflite:tflite_model_loader",
|
"//mediapipe/util/tflite:tflite_model_loader",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":inference_calculator_cc_proto",
|
":inference_calculator_cc_proto",
|
||||||
|
@ -506,7 +506,7 @@ cc_library_with_tflite(
|
||||||
name = "tflite_delegate_ptr",
|
name = "tflite_delegate_ptr",
|
||||||
hdrs = ["tflite_delegate_ptr.h"],
|
hdrs = ["tflite_delegate_ptr.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -517,8 +517,8 @@ cc_library_with_tflite(
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
":tflite_delegate_ptr",
|
":tflite_delegate_ptr",
|
||||||
"//mediapipe/util/tflite:tflite_model_loader",
|
"//mediapipe/util/tflite:tflite_model_loader",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":inference_runner",
|
":inference_runner",
|
||||||
|
@ -546,8 +546,8 @@ cc_library(
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
|
|
|
@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
|
||||||
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
||||||
}
|
}
|
||||||
return PacketAdopting<tflite::OpResolver>(
|
return PacketAdopting<tflite::OpResolver>(
|
||||||
std::make_unique<tflite_shims::ops::builtin::
|
std::make_unique<
|
||||||
BuiltinOpResolverWithoutDefaultDelegates>());
|
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
|
||||||
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
||||||
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
||||||
// migration.
|
// migration.
|
||||||
static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
|
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
|
||||||
Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||||
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
||||||
"OP_RESOLVER"};
|
"OP_RESOLVER"};
|
||||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#if defined(MEDIAPIPE_ANDROID)
|
#if defined(MEDIAPIPE_ANDROID)
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#endif // ANDROID
|
#endif // ANDROID
|
||||||
|
|
|
@ -22,9 +22,9 @@
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
#include "tensorflow/lite/c/c_api_types.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
|
#include "tensorflow/lite/interpreter_builder.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||||
|
@ -33,8 +33,8 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using Interpreter = ::tflite_shims::Interpreter;
|
using Interpreter = ::tflite::Interpreter;
|
||||||
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
|
using InterpreterBuilder = ::tflite::InterpreterBuilder;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||||
|
|
|
@ -23,8 +23,8 @@
|
||||||
#include "mediapipe/calculators/tensor/tflite_delegate_ptr.h"
|
#include "mediapipe/calculators/tensor/tflite_delegate_ptr.h"
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||||
|
#include "tensorflow/lite/c/c_api_types.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
#include "tensorflow/lite/c/c_api_types.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -157,7 +157,7 @@ void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
|
@ -270,7 +270,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ClassifyTest : public tflite_shims::testing::Test {};
|
class ClassifyTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ClassifyTest, Succeeds) {
|
TEST_F(ClassifyTest, Succeeds) {
|
||||||
auto audio_buffer = GetAudioData(k16kTestWavFilename);
|
auto audio_buffer = GetAudioData(k16kTestWavFilename);
|
||||||
|
@ -467,7 +467,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ClassifyAsyncTest : public tflite_shims::testing::Test {};
|
class ClassifyAsyncTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ClassifyAsyncTest, Succeeds) {
|
TEST_F(ClassifyAsyncTest, Succeeds) {
|
||||||
constexpr int kSampleRateHz = 48000;
|
constexpr int kSampleRateHz = 48000;
|
||||||
|
|
|
@ -36,7 +36,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -66,7 +66,7 @@ Matrix GetAudioData(absl::string_view filename) {
|
||||||
return matrix_mapping.matrix();
|
return matrix_mapping.matrix();
|
||||||
}
|
}
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||||
auto audio_embedder =
|
auto audio_embedder =
|
||||||
|
@ -124,7 +124,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class EmbedTest : public tflite_shims::testing::Test {};
|
class EmbedTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
|
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
|
||||||
auto options = std::make_unique<AudioEmbedderOptions>();
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
@ -187,7 +187,7 @@ TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
|
||||||
MP_EXPECT_OK(audio_embedder->Close());
|
MP_EXPECT_OK(audio_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class EmbedAsyncTest : public tflite_shims::testing::Test {
|
class EmbedAsyncTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
|
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
|
||||||
int sample_rate_hz,
|
int sample_rate_hz,
|
||||||
|
|
|
@ -47,7 +47,7 @@ cc_test_with_tflite(
|
||||||
data = ["//mediapipe/tasks/testdata/audio:test_models"],
|
data = ["//mediapipe/tasks/testdata/audio:test_models"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":audio_tensor_specs",
|
":audio_tensor_specs",
|
||||||
|
|
|
@ -34,7 +34,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -52,7 +52,7 @@ constexpr char kModelWithMetadata[] =
|
||||||
"yamnet_audio_classifier_with_metadata.tflite";
|
"yamnet_audio_classifier_with_metadata.tflite";
|
||||||
constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite";
|
constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite";
|
||||||
|
|
||||||
class AudioTensorSpecsTest : public tflite_shims::testing::Test {};
|
class AudioTensorSpecsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(AudioTensorSpecsTest,
|
TEST_F(AudioTensorSpecsTest,
|
||||||
BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) {
|
BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) {
|
||||||
|
|
|
@ -63,7 +63,7 @@ cc_test(
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -232,6 +232,6 @@ cc_test(
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,7 +33,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -66,8 +66,7 @@ ClassificationList MakeClassificationList(int class_index) {
|
||||||
class_index));
|
class_index));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ClassificationAggregationCalculatorTest
|
class ClassificationAggregationCalculatorTest : public tflite::testing::Test {
|
||||||
: public tflite_shims::testing::Test {
|
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
bool connect_timestamps = false) {
|
bool connect_timestamps = false) {
|
||||||
|
|
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -52,7 +52,7 @@ constexpr char kTimestampsName[] = "timestamps_in";
|
||||||
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||||
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
||||||
|
|
||||||
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
|
class EmbeddingAggregationCalculatorTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
|
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
|
|
@ -49,7 +49,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -101,7 +101,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
||||||
std::move(external_file));
|
std::move(external_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
class ConfigureTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -417,7 +417,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
||||||
)pb")));
|
)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
class PostprocessingTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||||
|
|
|
@ -39,7 +39,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -86,7 +86,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
||||||
std::move(external_file));
|
std::move(external_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
class ConfigureTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -153,7 +153,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
has_quantized_outputs: false)pb")));
|
has_quantized_outputs: false)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
class PostprocessingTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
absl::string_view model_name, const proto::EmbedderOptions& options,
|
absl::string_view model_name, const proto::EmbedderOptions& options,
|
||||||
|
|
|
@ -37,7 +37,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -125,7 +125,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
||||||
return TaskRunner::Create(graph.GetConfig());
|
return TaskRunner::Create(graph.GetConfig());
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
class ConfigureTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -128,9 +128,9 @@ cc_library_with_tflite(
|
||||||
srcs = ["model_resources.cc"],
|
srcs = ["model_resources.cc"],
|
||||||
hdrs = ["model_resources.h"],
|
hdrs = ["model_resources.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:verifier",
|
"@org_tensorflow//tensorflow/lite/tools:verifier",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":external_file_handler",
|
":external_file_handler",
|
||||||
|
@ -159,9 +159,9 @@ cc_test_with_tflite(
|
||||||
],
|
],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
":model_resources",
|
":model_resources",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":utils",
|
":utils",
|
||||||
|
@ -186,7 +186,7 @@ cc_library_with_tflite(
|
||||||
hdrs = ["model_resources_cache.h"],
|
hdrs = ["model_resources_cache.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
":model_resources",
|
":model_resources",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":model_asset_bundle_resources",
|
":model_asset_bundle_resources",
|
||||||
|
@ -233,7 +233,7 @@ cc_test_with_tflite(
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":model_resources_cache",
|
":model_resources_cache",
|
||||||
":model_resources_calculator",
|
":model_resources_calculator",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
@ -284,7 +284,7 @@ cc_test_with_tflite(
|
||||||
":task_runner",
|
":task_runner",
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":model_resources_cache",
|
":model_resources_cache",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:pass_through_calculator",
|
"//mediapipe/calculators/core:pass_through_calculator",
|
||||||
|
|
|
@ -37,8 +37,8 @@ limitations under the License.
|
||||||
#include "mediapipe/util/tflite/error_reporter.h"
|
#include "mediapipe/util/tflite/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model_builder.h"
|
#include "tensorflow/lite/model_builder.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
|
#include "tensorflow/lite/tools/verifier.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -52,7 +52,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
|
|
||||||
bool ModelResources::Verifier::Verify(const char* data, int length,
|
bool ModelResources::Verifier::Verify(const char* data, int length,
|
||||||
tflite::ErrorReporter* reporter) {
|
tflite::ErrorReporter* reporter) {
|
||||||
return tflite_shims::Verify(data, length, reporter);
|
return tflite::Verify(data, length, reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
ModelResources::ModelResources(const std::string& tag,
|
ModelResources::ModelResources(const std::string& tag,
|
||||||
|
@ -124,7 +124,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
|
||||||
// and that it uses only operators that are supported by the OpResolver
|
// and that it uses only operators that are supported by the OpResolver
|
||||||
// that was passed to the ModelResources constructor, and then builds
|
// that was passed to the ModelResources constructor, and then builds
|
||||||
// the model from the buffer.
|
// the model from the buffer.
|
||||||
auto model = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer(
|
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||||
buffer_data, buffer_size, &verifier_, &error_reporter_);
|
buffer_data, buffer_size, &verifier_, &error_reporter_);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
static constexpr char kInvalidFlatbufferMessage[] =
|
static constexpr char kInvalidFlatbufferMessage[] =
|
||||||
|
@ -151,8 +151,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
|
||||||
}
|
}
|
||||||
|
|
||||||
model_packet_ = MakePacket<ModelPtr>(
|
model_packet_ = MakePacket<ModelPtr>(
|
||||||
model.release(),
|
model.release(), [](tflite::FlatBufferModel* model) { delete model; });
|
||||||
[](tflite_shims::FlatBufferModel* model) { delete model; });
|
|
||||||
ASSIGN_OR_RETURN(auto model_metadata_extractor,
|
ASSIGN_OR_RETURN(auto model_metadata_extractor,
|
||||||
metadata::ModelMetadataExtractor::CreateFromModelBuffer(
|
metadata::ModelMetadataExtractor::CreateFromModelBuffer(
|
||||||
buffer_data, buffer_size));
|
buffer_data, buffer_size));
|
||||||
|
|
|
@ -32,10 +32,10 @@ limitations under the License.
|
||||||
#include "mediapipe/util/tflite/error_reporter.h"
|
#include "mediapipe/util/tflite/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model_builder.h"
|
#include "tensorflow/lite/model_builder.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
|
#include "tensorflow/lite/tools/verifier.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -51,8 +51,8 @@ class ModelResources {
|
||||||
public:
|
public:
|
||||||
// Represents a TfLite model as a FlatBuffer.
|
// Represents a TfLite model as a FlatBuffer.
|
||||||
using ModelPtr =
|
using ModelPtr =
|
||||||
std::unique_ptr<tflite_shims::FlatBufferModel,
|
std::unique_ptr<tflite::FlatBufferModel,
|
||||||
std::function<void(tflite_shims::FlatBufferModel*)>>;
|
std::function<void(tflite::FlatBufferModel*)>>;
|
||||||
|
|
||||||
// Takes the ownership of the provided ExternalFile proto and creates
|
// Takes the ownership of the provided ExternalFile proto and creates
|
||||||
// ModelResources from the proto and an op resolver object. A non-empty tag
|
// ModelResources from the proto and an op resolver object. A non-empty tag
|
||||||
|
@ -61,7 +61,7 @@ class ModelResources {
|
||||||
static absl::StatusOr<std::unique_ptr<ModelResources>> Create(
|
static absl::StatusOr<std::unique_ptr<ModelResources>> Create(
|
||||||
const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
|
const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver =
|
std::unique_ptr<tflite::OpResolver> op_resolver =
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
|
|
||||||
// Takes the ownership of the provided ExternalFile proto and creates
|
// Takes the ownership of the provided ExternalFile proto and creates
|
||||||
// ModelResources from the proto and an op resolver mediapipe packet. A
|
// ModelResources from the proto and an op resolver mediapipe packet. A
|
||||||
|
|
|
@ -30,7 +30,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -124,7 +124,7 @@ void RunGraphWithGraphService(std::unique_ptr<ModelResources> model_resources,
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class ModelResourcesCalculatorTest : public tflite_shims::testing::Test {};
|
class ModelResourcesCalculatorTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) {
|
TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) {
|
||||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
|
|
@ -38,9 +38,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -116,7 +116,7 @@ void CheckModelResourcesPackets(const ModelResources* model_resources) {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class ModelResourcesTest : public tflite_shims::testing::Test {};
|
class ModelResourcesTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ModelResourcesTest, CreateFromBinaryContent) {
|
TEST_F(ModelResourcesTest, CreateFromBinaryContent) {
|
||||||
auto model_file = std::make_unique<proto::ExternalFile>();
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
@ -211,7 +211,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsFromFile) {
|
||||||
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
||||||
tflite::MutableOpResolver resolver;
|
tflite::MutableOpResolver resolver;
|
||||||
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||||
::tflite_shims::ops::builtin::Register_ADD());
|
::tflite::ops::builtin::Register_ADD());
|
||||||
resolver.AddCustom(kCustomOpName,
|
resolver.AddCustom(kCustomOpName,
|
||||||
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsPacket) {
|
||||||
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
||||||
tflite::MutableOpResolver resolver;
|
tflite::MutableOpResolver resolver;
|
||||||
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||||
::tflite_shims::ops::builtin::Register_ADD());
|
::tflite::ops::builtin::Register_ADD());
|
||||||
resolver.AddCustom(kCustomOpName,
|
resolver.AddCustom(kCustomOpName,
|
||||||
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -112,7 +112,7 @@ CalculatorGraphConfig GetModelSidePacketsToStreamPacketsGraphConfig(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class TaskRunnerTest : public tflite_shims::testing::Test {};
|
class TaskRunnerTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) {
|
TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) {
|
||||||
CalculatorGraphConfig proto = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
CalculatorGraphConfig proto = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
|
|
@ -89,7 +89,7 @@ cc_test(
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:cord",
|
"@com_google_absl//absl/strings:cord",
|
||||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::text::text_classifier {
|
namespace mediapipe::tasks::text::text_classifier {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -87,7 +87,7 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual,
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class TextClassifierTest : public tflite_shims::testing::Test {};
|
class TextClassifierTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
||||||
auto options = std::make_unique<TextClassifierOptions>();
|
auto options = std::make_unique<TextClassifierOptions>();
|
||||||
|
|
|
@ -91,6 +91,6 @@ cc_test(
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::text::text_embedder {
|
namespace mediapipe::tasks::text::text_embedder {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -49,7 +49,7 @@ using ::mediapipe::file::JoinPath;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
|
||||||
class EmbedderTest : public tflite_shims::testing::Test {};
|
class EmbedderTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(EmbedderTest, FailsWithMissingModel) {
|
TEST_F(EmbedderTest, FailsWithMissingModel) {
|
||||||
auto text_embedder =
|
auto text_embedder =
|
||||||
|
|
|
@ -81,6 +81,6 @@ cc_test(
|
||||||
"@com_google_absl//absl/flags:flag",
|
"@com_google_absl//absl/flags:flag",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::text::utils {
|
namespace mediapipe::tasks::text::utils {
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class TextModelUtilsTest : public tflite_shims::testing::Test {};
|
class TextModelUtilsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
|
TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
|
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
|
||||||
|
|
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -105,7 +105,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
||||||
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
class FaceBlendshapesTest : public tflite_shims::testing::Test {};
|
class FaceBlendshapesTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(FaceBlendshapesTest, SmokeTest) {
|
TEST_F(FaceBlendshapesTest, SmokeTest) {
|
||||||
// Prepare graph inputs.
|
// Prepare graph inputs.
|
||||||
|
|
|
@ -43,7 +43,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -137,7 +137,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
||||||
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
class HandLandmarkerTest : public tflite_shims::testing::Test {};
|
class HandLandmarkerTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(HandLandmarkerTest, Succeeds) {
|
TEST_F(HandLandmarkerTest, Succeeds) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -41,7 +41,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
|
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
|
|
@ -146,7 +146,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleHandTaskRunner(
|
||||||
|
|
||||||
return TaskRunner::Create(
|
return TaskRunner::Create(
|
||||||
graph.GetConfig(),
|
graph.GetConfig(),
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to create a Multi Hand Landmark TaskRunner.
|
// Helper function to create a Multi Hand Landmark TaskRunner.
|
||||||
|
@ -188,7 +188,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiHandTaskRunner(
|
||||||
|
|
||||||
return TaskRunner::Create(
|
return TaskRunner::Create(
|
||||||
graph.GetConfig(),
|
graph.GetConfig(),
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
||||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -148,7 +148,7 @@ class MobileNetQuantizedOpResolverMissingOps
|
||||||
const MobileNetQuantizedOpResolverMissingOps& r) = delete;
|
const MobileNetQuantizedOpResolverMissingOps& r) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateTest : public tflite_shims::testing::Test {};
|
class CreateTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
||||||
auto options = std::make_unique<ImageClassifierOptions>();
|
auto options = std::make_unique<ImageClassifierOptions>();
|
||||||
|
@ -265,7 +265,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -605,7 +605,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
||||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -707,7 +707,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
|
||||||
MP_ASSERT_OK(image_classifier->Close());
|
MP_ASSERT_OK(image_classifier->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -30,9 +30,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -103,7 +103,7 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
delete;
|
delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateTest : public tflite_shims::testing::Test {};
|
class CreateTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
||||||
auto options = std::make_unique<ImageEmbedderOptions>();
|
auto options = std::make_unique<ImageEmbedderOptions>();
|
||||||
|
@ -181,7 +181,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -410,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -494,7 +494,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK(image_embedder->Close());
|
MP_ASSERT_OK(image_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -180,7 +180,7 @@ class DeepLabOpResolver : public ::tflite::MutableOpResolver {
|
||||||
DeepLabOpResolver(const DeepLabOpResolver& r) = delete;
|
DeepLabOpResolver(const DeepLabOpResolver& r) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
public:
|
public:
|
||||||
|
@ -268,7 +268,7 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -521,7 +521,7 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -581,7 +581,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK(segmenter->Close());
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
#include "testing/base/public/gmock.h"
|
#include "testing/base/public/gmock.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -124,7 +124,7 @@ MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
|
||||||
similarity_threshold;
|
similarity_threshold;
|
||||||
}
|
}
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
public:
|
public:
|
||||||
|
@ -261,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||||
info) { return info.param.test_name; });
|
info) { return info.param.test_name; });
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
// TODO: fix this unit test after image segmenter handled post
|
// TODO: fix this unit test after image segmenter handled post
|
||||||
// processing correctly with rotated image.
|
// processing correctly with rotated image.
|
||||||
|
|
|
@ -43,9 +43,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -159,7 +159,7 @@ class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver {
|
||||||
MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete;
|
MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
|
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
|
||||||
auto options = std::make_unique<ObjectDetectorOptions>();
|
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||||
|
@ -332,7 +332,7 @@ TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
|
||||||
// TODO: Add NumThreadsTest back after having an
|
// TODO: Add NumThreadsTest back after having an
|
||||||
// "acceleration configuration" field in the ObjectDetectorOptions.
|
// "acceleration configuration" field in the ObjectDetectorOptions.
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
@ -618,7 +618,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
@ -673,7 +673,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
|
|
@ -143,7 +143,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSinglePoseTaskRunner(
|
||||||
|
|
||||||
return TaskRunner::Create(
|
return TaskRunner::Create(
|
||||||
graph.GetConfig(),
|
graph.GetConfig(),
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to create a Multi Pose Landmark TaskRunner.
|
// Helper function to create a Multi Pose Landmark TaskRunner.
|
||||||
|
@ -189,7 +189,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiPoseTaskRunner(
|
||||||
|
|
||||||
return TaskRunner::Create(
|
return TaskRunner::Create(
|
||||||
graph.GetConfig(),
|
graph.GetConfig(),
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
||||||
|
|
|
@ -50,7 +50,7 @@ cc_test_with_tflite(
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
":image_tensor_specs",
|
":image_tensor_specs",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
|
|
@ -35,7 +35,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -69,7 +69,7 @@ constexpr char kMobileNetMetadata[] =
|
||||||
constexpr char kMobileNetQuantizedPartialMetadata[] =
|
constexpr char kMobileNetQuantizedPartialMetadata[] =
|
||||||
"mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite";
|
"mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite";
|
||||||
|
|
||||||
class ImageTensorSpecsTest : public tflite_shims::testing::Test {};
|
class ImageTensorSpecsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageTensorSpecsTest, BuildInputImageTensorSpecsWorks) {
|
TEST_F(ImageTensorSpecsTest, BuildInputImageTensorSpecsWorks) {
|
||||||
auto model_file = std::make_unique<core::proto::ExternalFile>();
|
auto model_file = std::make_unique<core::proto::ExternalFile>();
|
||||||
|
|
|
@ -28,7 +28,7 @@ cc_library_with_tflite(
|
||||||
],
|
],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"//mediapipe/tasks/cc/core:model_resources_cache",
|
"//mediapipe/tasks/cc/core:model_resources_cache",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h"
|
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h"
|
||||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using ::mediapipe::tasks::core::kModelResourcesCacheService;
|
using ::mediapipe::tasks::core::kModelResourcesCacheService;
|
||||||
|
|
|
@ -125,7 +125,7 @@ cc_library_with_tflite(
|
||||||
srcs = ["tflite_model_loader.cc"],
|
srcs = ["tflite_model_loader.cc"],
|
||||||
hdrs = ["tflite_model_loader.h"],
|
hdrs = ["tflite_model_loader.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
],
|
],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
using FlatBufferModel = ::tflite_shims::FlatBufferModel;
|
using FlatBufferModel = ::tflite::FlatBufferModel;
|
||||||
|
|
||||||
absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
|
absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
|
||||||
const std::string& path) {
|
const std::string& path) {
|
||||||
|
|
|
@ -22,13 +22,13 @@
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
// Represents a TfLite model as a FlatBuffer.
|
// Represents a TfLite model as a FlatBuffer.
|
||||||
using TfLiteModelPtr =
|
using TfLiteModelPtr =
|
||||||
std::unique_ptr<tflite_shims::FlatBufferModel,
|
std::unique_ptr<tflite::FlatBufferModel,
|
||||||
std::function<void(tflite_shims::FlatBufferModel*)>>;
|
std::function<void(tflite::FlatBufferModel*)>>;
|
||||||
|
|
||||||
class TfLiteModelLoader {
|
class TfLiteModelLoader {
|
||||||
public:
|
public:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user