Add GetInputImageTensorSpecs into BaseVisionTaskApi for tasks api users to get input image tensor specifications.

PiperOrigin-RevId: 514650593
This commit is contained in:
Jiuqiang Tang 2023-03-07 00:41:41 -08:00 committed by Copybara-Service
parent 2f2a74da6a
commit dbd6d72696
5 changed files with 73 additions and 2 deletions

View File

@ -38,10 +38,12 @@ cc_library(
":image_processing_options", ":image_processing_options",
":running_mode", ":running_mode",
"//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:rect", "//mediapipe/tasks/cc/components/containers:rect",
"//mediapipe/tasks/cc/core:base_task_api", "//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -25,12 +25,14 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -44,6 +46,42 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner, explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner,
RunningMode running_mode) RunningMode running_mode)
: BaseTaskApi(std::move(runner)), running_mode_(running_mode) {} : BaseTaskApi(std::move(runner)), running_mode_(running_mode) {}
virtual ~BaseVisionTaskApi() {}
virtual absl::StatusOr<ImageTensorSpecs> GetInputImageTensorSpecs() {
ImageTensorSpecs image_tensor_specs;
bool found_image_to_tensor_calculator = false;
for (auto& node : runner_->GetGraphConfig().node()) {
if (node.calculator() == "ImageToTensorCalculator") {
if (!found_image_to_tensor_calculator) {
found_image_to_tensor_calculator = true;
} else {
return absl::Status(CreateStatusWithPayload(
absl::StatusCode::kFailedPrecondition,
absl::StrCat(
"The graph has more than one ImageToTensorCalculator.")));
}
mediapipe::ImageToTensorCalculatorOptions options =
node.options().GetExtension(
mediapipe::ImageToTensorCalculatorOptions::ext);
image_tensor_specs.image_width = options.output_tensor_width();
image_tensor_specs.image_height = options.output_tensor_height();
image_tensor_specs.color_space =
tflite::ColorSpaceType::ColorSpaceType_RGB;
if (options.has_output_tensor_uint_range()) {
image_tensor_specs.tensor_type = tflite::TensorType_UINT8;
} else if (options.has_output_tensor_float_range()) {
image_tensor_specs.tensor_type = tflite::TensorType_FLOAT32;
}
}
}
if (!found_image_to_tensor_calculator) {
return absl::Status(CreateStatusWithPayload(
absl::StatusCode::kNotFound,
absl::StrCat("The graph doesn't contain ImageToTensorCalculator.")));
}
return image_tensor_specs;
}
protected: protected:
// A synchronous method to process single image inputs. // A synchronous method to process single image inputs.

View File

@ -130,7 +130,7 @@ struct ObjectDetectorOptions {
// //
// [1]: // [1]:
// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456 // https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { class ObjectDetector : public tasks::vision::core::BaseVisionTaskApi {
public: public:
using BaseVisionTaskApi::BaseVisionTaskApi; using BaseVisionTaskApi::BaseVisionTaskApi;

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -298,6 +299,36 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInLiveStreamMode) {
MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
} }
TEST_F(CreateFromOptionsTest, InputTensorSpecsForMobileSsdModel) {
auto options = std::make_unique<ObjectDetectorOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
object_detector->GetInputImageTensorSpecs());
EXPECT_EQ(image_tensor_specs.image_width, 300);
EXPECT_EQ(image_tensor_specs.image_height, 300);
EXPECT_EQ(image_tensor_specs.color_space,
tflite::ColorSpaceType::ColorSpaceType_RGB);
EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
}
TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
auto options = std::make_unique<ObjectDetectorOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kEfficientDetWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
object_detector->GetInputImageTensorSpecs());
EXPECT_EQ(image_tensor_specs.image_width, 320);
EXPECT_EQ(image_tensor_specs.image_height, 320);
EXPECT_EQ(image_tensor_specs.color_space,
tflite::ColorSpaceType::ColorSpaceType_RGB);
EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
}
// TODO: Add NumThreadsTest back after having an // TODO: Add NumThreadsTest back after having an
// "acceleration configuration" field in the ObjectDetectorOptions. // "acceleration configuration" field in the ObjectDetectorOptions.

View File

@ -94,7 +94,7 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
// Build ImageTensorSpec from model resources. The tflite model must contain // Build ImageTensorSpec from model resources. The tflite model must contain
// single subgraph with single input tensor. // single subgraph with single input tensor.
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs( absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
const core::ModelResources& model_resources); const tasks::core::ModelResources& model_resources);
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks