Add GetInputImageTensorSpecs into BaseVisionTaskApi for tasks api users to get input image tensor specifications.
PiperOrigin-RevId: 514650593
This commit is contained in:
		
							parent
							
								
									2f2a74da6a
								
							
						
					
					
						commit
						dbd6d72696
					
				| 
						 | 
				
			
			@ -38,10 +38,12 @@ cc_library(
 | 
			
		|||
        ":image_processing_options",
 | 
			
		||||
        ":running_mode",
 | 
			
		||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
			
		||||
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:rect",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:base_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_runner",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,12 +25,14 @@ limitations under the License.
 | 
			
		|||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_runner.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
| 
						 | 
				
			
			@ -44,6 +46,42 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
 | 
			
		|||
  explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner,
 | 
			
		||||
                             RunningMode running_mode)
 | 
			
		||||
      : BaseTaskApi(std::move(runner)), running_mode_(running_mode) {}
 | 
			
		||||
  virtual ~BaseVisionTaskApi() {}
 | 
			
		||||
 | 
			
		||||
  virtual absl::StatusOr<ImageTensorSpecs> GetInputImageTensorSpecs() {
 | 
			
		||||
    ImageTensorSpecs image_tensor_specs;
 | 
			
		||||
    bool found_image_to_tensor_calculator = false;
 | 
			
		||||
    for (auto& node : runner_->GetGraphConfig().node()) {
 | 
			
		||||
      if (node.calculator() == "ImageToTensorCalculator") {
 | 
			
		||||
        if (!found_image_to_tensor_calculator) {
 | 
			
		||||
          found_image_to_tensor_calculator = true;
 | 
			
		||||
        } else {
 | 
			
		||||
          return absl::Status(CreateStatusWithPayload(
 | 
			
		||||
              absl::StatusCode::kFailedPrecondition,
 | 
			
		||||
              absl::StrCat(
 | 
			
		||||
                  "The graph has more than one ImageToTensorCalculator.")));
 | 
			
		||||
        }
 | 
			
		||||
        mediapipe::ImageToTensorCalculatorOptions options =
 | 
			
		||||
            node.options().GetExtension(
 | 
			
		||||
                mediapipe::ImageToTensorCalculatorOptions::ext);
 | 
			
		||||
        image_tensor_specs.image_width = options.output_tensor_width();
 | 
			
		||||
        image_tensor_specs.image_height = options.output_tensor_height();
 | 
			
		||||
        image_tensor_specs.color_space =
 | 
			
		||||
            tflite::ColorSpaceType::ColorSpaceType_RGB;
 | 
			
		||||
        if (options.has_output_tensor_uint_range()) {
 | 
			
		||||
          image_tensor_specs.tensor_type = tflite::TensorType_UINT8;
 | 
			
		||||
        } else if (options.has_output_tensor_float_range()) {
 | 
			
		||||
          image_tensor_specs.tensor_type = tflite::TensorType_FLOAT32;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    if (!found_image_to_tensor_calculator) {
 | 
			
		||||
      return absl::Status(CreateStatusWithPayload(
 | 
			
		||||
          absl::StatusCode::kNotFound,
 | 
			
		||||
          absl::StrCat("The graph doesn't contain ImageToTensorCalculator.")));
 | 
			
		||||
    }
 | 
			
		||||
    return image_tensor_specs;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  // A synchronous method to process single image inputs.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -130,7 +130,7 @@ struct ObjectDetectorOptions {
 | 
			
		|||
//
 | 
			
		||||
// [1]:
 | 
			
		||||
// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
 | 
			
		||||
class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
 | 
			
		||||
class ObjectDetector : public tasks::vision::core::BaseVisionTaskApi {
 | 
			
		||||
 public:
 | 
			
		||||
  using BaseVisionTaskApi::BaseVisionTaskApi;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -39,6 +39,7 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/tasks/cc/components/containers/rect.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 | 
			
		||||
#include "tensorflow/lite/c/common.h"
 | 
			
		||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -298,6 +299,36 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInLiveStreamMode) {
 | 
			
		|||
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(CreateFromOptionsTest, InputTensorSpecsForMobileSsdModel) {
 | 
			
		||||
  auto options = std::make_unique<ObjectDetectorOptions>();
 | 
			
		||||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
 | 
			
		||||
                          ObjectDetector::Create(std::move(options)));
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
 | 
			
		||||
                          object_detector->GetInputImageTensorSpecs());
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.image_width, 300);
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.image_height, 300);
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.color_space,
 | 
			
		||||
            tflite::ColorSpaceType::ColorSpaceType_RGB);
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
 | 
			
		||||
  auto options = std::make_unique<ObjectDetectorOptions>();
 | 
			
		||||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kEfficientDetWithMetadata);
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
 | 
			
		||||
                          ObjectDetector::Create(std::move(options)));
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
 | 
			
		||||
                          object_detector->GetInputImageTensorSpecs());
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.image_width, 320);
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.image_height, 320);
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.color_space,
 | 
			
		||||
            tflite::ColorSpaceType::ColorSpaceType_RGB);
 | 
			
		||||
  EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: Add NumThreadsTest back after having an
 | 
			
		||||
// "acceleration configuration" field in the ObjectDetectorOptions.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -94,7 +94,7 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
 | 
			
		|||
// Build ImageTensorSpec from model resources. The tflite model must contain
 | 
			
		||||
// single subgraph with single input tensor.
 | 
			
		||||
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
 | 
			
		||||
    const core::ModelResources& model_resources);
 | 
			
		||||
    const tasks::core::ModelResources& model_resources);
 | 
			
		||||
 | 
			
		||||
}  // namespace vision
 | 
			
		||||
}  // namespace tasks
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user