Add GetInputImageTensorSpecs into BaseVisionTaskApi for tasks api users to get input image tensor specifications.
PiperOrigin-RevId: 514650593
This commit is contained in:
		
							parent
							
								
									2f2a74da6a
								
							
						
					
					
						commit
						dbd6d72696
					
				| 
						 | 
					@ -38,10 +38,12 @@ cc_library(
 | 
				
			||||||
        ":image_processing_options",
 | 
					        ":image_processing_options",
 | 
				
			||||||
        ":running_mode",
 | 
					        ":running_mode",
 | 
				
			||||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
					        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
 | 
				
			||||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
					        "//mediapipe/framework/formats:rect_cc_proto",
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/containers:rect",
 | 
					        "//mediapipe/tasks/cc/components/containers:rect",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core:base_task_api",
 | 
					        "//mediapipe/tasks/cc/core:base_task_api",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core:task_runner",
 | 
					        "//mediapipe/tasks/cc/core:task_runner",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
				
			||||||
        "@com_google_absl//absl/status",
 | 
					        "@com_google_absl//absl/status",
 | 
				
			||||||
        "@com_google_absl//absl/status:statusor",
 | 
					        "@com_google_absl//absl/status:statusor",
 | 
				
			||||||
        "@com_google_absl//absl/strings",
 | 
					        "@com_google_absl//absl/strings",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,12 +25,14 @@ limitations under the License.
 | 
				
			||||||
#include "absl/status/status.h"
 | 
					#include "absl/status/status.h"
 | 
				
			||||||
#include "absl/status/statusor.h"
 | 
					#include "absl/status/statusor.h"
 | 
				
			||||||
#include "absl/strings/str_cat.h"
 | 
					#include "absl/strings/str_cat.h"
 | 
				
			||||||
 | 
					#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
 | 
				
			||||||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
					#include "mediapipe/framework/formats/rect.pb.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
 | 
					#include "mediapipe/tasks/cc/components/containers/rect.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
 | 
					#include "mediapipe/tasks/cc/core/base_task_api.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
 | 
					#include "mediapipe/tasks/cc/core/task_runner.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 | 
					#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 | 
					#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mediapipe {
 | 
					namespace mediapipe {
 | 
				
			||||||
namespace tasks {
 | 
					namespace tasks {
 | 
				
			||||||
| 
						 | 
					@ -44,6 +46,42 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
 | 
				
			||||||
  explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner,
 | 
					  explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner,
 | 
				
			||||||
                             RunningMode running_mode)
 | 
					                             RunningMode running_mode)
 | 
				
			||||||
      : BaseTaskApi(std::move(runner)), running_mode_(running_mode) {}
 | 
					      : BaseTaskApi(std::move(runner)), running_mode_(running_mode) {}
 | 
				
			||||||
 | 
					  virtual ~BaseVisionTaskApi() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual absl::StatusOr<ImageTensorSpecs> GetInputImageTensorSpecs() {
 | 
				
			||||||
 | 
					    ImageTensorSpecs image_tensor_specs;
 | 
				
			||||||
 | 
					    bool found_image_to_tensor_calculator = false;
 | 
				
			||||||
 | 
					    for (auto& node : runner_->GetGraphConfig().node()) {
 | 
				
			||||||
 | 
					      if (node.calculator() == "ImageToTensorCalculator") {
 | 
				
			||||||
 | 
					        if (!found_image_to_tensor_calculator) {
 | 
				
			||||||
 | 
					          found_image_to_tensor_calculator = true;
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					          return absl::Status(CreateStatusWithPayload(
 | 
				
			||||||
 | 
					              absl::StatusCode::kFailedPrecondition,
 | 
				
			||||||
 | 
					              absl::StrCat(
 | 
				
			||||||
 | 
					                  "The graph has more than one ImageToTensorCalculator.")));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        mediapipe::ImageToTensorCalculatorOptions options =
 | 
				
			||||||
 | 
					            node.options().GetExtension(
 | 
				
			||||||
 | 
					                mediapipe::ImageToTensorCalculatorOptions::ext);
 | 
				
			||||||
 | 
					        image_tensor_specs.image_width = options.output_tensor_width();
 | 
				
			||||||
 | 
					        image_tensor_specs.image_height = options.output_tensor_height();
 | 
				
			||||||
 | 
					        image_tensor_specs.color_space =
 | 
				
			||||||
 | 
					            tflite::ColorSpaceType::ColorSpaceType_RGB;
 | 
				
			||||||
 | 
					        if (options.has_output_tensor_uint_range()) {
 | 
				
			||||||
 | 
					          image_tensor_specs.tensor_type = tflite::TensorType_UINT8;
 | 
				
			||||||
 | 
					        } else if (options.has_output_tensor_float_range()) {
 | 
				
			||||||
 | 
					          image_tensor_specs.tensor_type = tflite::TensorType_FLOAT32;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (!found_image_to_tensor_calculator) {
 | 
				
			||||||
 | 
					      return absl::Status(CreateStatusWithPayload(
 | 
				
			||||||
 | 
					          absl::StatusCode::kNotFound,
 | 
				
			||||||
 | 
					          absl::StrCat("The graph doesn't contain ImageToTensorCalculator.")));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return image_tensor_specs;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 protected:
 | 
					 protected:
 | 
				
			||||||
  // A synchronous method to process single image inputs.
 | 
					  // A synchronous method to process single image inputs.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -130,7 +130,7 @@ struct ObjectDetectorOptions {
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// [1]:
 | 
					// [1]:
 | 
				
			||||||
// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
 | 
					// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
 | 
				
			||||||
class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
 | 
					class ObjectDetector : public tasks::vision::core::BaseVisionTaskApi {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  using BaseVisionTaskApi::BaseVisionTaskApi;
 | 
					  using BaseVisionTaskApi::BaseVisionTaskApi;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,6 +39,7 @@ limitations under the License.
 | 
				
			||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
 | 
					#include "mediapipe/tasks/cc/components/containers/rect.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 | 
					#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 | 
					#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 | 
					#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 | 
				
			||||||
#include "tensorflow/lite/c/common.h"
 | 
					#include "tensorflow/lite/c/common.h"
 | 
				
			||||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
					#include "tensorflow/lite/core/api/op_resolver.h"
 | 
				
			||||||
| 
						 | 
					@ -298,6 +299,36 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInLiveStreamMode) {
 | 
				
			||||||
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
 | 
					                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(CreateFromOptionsTest, InputTensorSpecsForMobileSsdModel) {
 | 
				
			||||||
 | 
					  auto options = std::make_unique<ObjectDetectorOptions>();
 | 
				
			||||||
 | 
					  options->base_options.model_asset_path =
 | 
				
			||||||
 | 
					      JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
 | 
				
			||||||
 | 
					  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
 | 
				
			||||||
 | 
					                          ObjectDetector::Create(std::move(options)));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
 | 
				
			||||||
 | 
					                          object_detector->GetInputImageTensorSpecs());
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.image_width, 300);
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.image_height, 300);
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.color_space,
 | 
				
			||||||
 | 
					            tflite::ColorSpaceType::ColorSpaceType_RGB);
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
 | 
				
			||||||
 | 
					  auto options = std::make_unique<ObjectDetectorOptions>();
 | 
				
			||||||
 | 
					  options->base_options.model_asset_path =
 | 
				
			||||||
 | 
					      JoinPath("./", kTestDataDirectory, kEfficientDetWithMetadata);
 | 
				
			||||||
 | 
					  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
 | 
				
			||||||
 | 
					                          ObjectDetector::Create(std::move(options)));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
 | 
				
			||||||
 | 
					                          object_detector->GetInputImageTensorSpecs());
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.image_width, 320);
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.image_height, 320);
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.color_space,
 | 
				
			||||||
 | 
					            tflite::ColorSpaceType::ColorSpaceType_RGB);
 | 
				
			||||||
 | 
					  EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO: Add NumThreadsTest back after having an
 | 
					// TODO: Add NumThreadsTest back after having an
 | 
				
			||||||
// "acceleration configuration" field in the ObjectDetectorOptions.
 | 
					// "acceleration configuration" field in the ObjectDetectorOptions.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -94,7 +94,7 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
 | 
				
			||||||
// Build ImageTensorSpec from model resources. The tflite model must contain
 | 
					// Build ImageTensorSpec from model resources. The tflite model must contain
 | 
				
			||||||
// single subgraph with single input tensor.
 | 
					// single subgraph with single input tensor.
 | 
				
			||||||
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
 | 
					absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
 | 
				
			||||||
    const core::ModelResources& model_resources);
 | 
					    const tasks::core::ModelResources& model_resources);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace vision
 | 
					}  // namespace vision
 | 
				
			||||||
}  // namespace tasks
 | 
					}  // namespace tasks
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user