Add GetInputImageTensorSpecs into BaseVisionTaskApi for tasks api users to get input image tensor specifications.
PiperOrigin-RevId: 514650593
This commit is contained in:
parent
2f2a74da6a
commit
dbd6d72696
|
@ -38,10 +38,12 @@ cc_library(
|
|||
":image_processing_options",
|
||||
":running_mode",
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:rect",
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
@ -25,12 +25,14 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -44,6 +46,42 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
|
|||
explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner,
|
||||
RunningMode running_mode)
|
||||
: BaseTaskApi(std::move(runner)), running_mode_(running_mode) {}
|
||||
virtual ~BaseVisionTaskApi() {}
|
||||
|
||||
virtual absl::StatusOr<ImageTensorSpecs> GetInputImageTensorSpecs() {
|
||||
ImageTensorSpecs image_tensor_specs;
|
||||
bool found_image_to_tensor_calculator = false;
|
||||
for (auto& node : runner_->GetGraphConfig().node()) {
|
||||
if (node.calculator() == "ImageToTensorCalculator") {
|
||||
if (!found_image_to_tensor_calculator) {
|
||||
found_image_to_tensor_calculator = true;
|
||||
} else {
|
||||
return absl::Status(CreateStatusWithPayload(
|
||||
absl::StatusCode::kFailedPrecondition,
|
||||
absl::StrCat(
|
||||
"The graph has more than one ImageToTensorCalculator.")));
|
||||
}
|
||||
mediapipe::ImageToTensorCalculatorOptions options =
|
||||
node.options().GetExtension(
|
||||
mediapipe::ImageToTensorCalculatorOptions::ext);
|
||||
image_tensor_specs.image_width = options.output_tensor_width();
|
||||
image_tensor_specs.image_height = options.output_tensor_height();
|
||||
image_tensor_specs.color_space =
|
||||
tflite::ColorSpaceType::ColorSpaceType_RGB;
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
image_tensor_specs.tensor_type = tflite::TensorType_UINT8;
|
||||
} else if (options.has_output_tensor_float_range()) {
|
||||
image_tensor_specs.tensor_type = tflite::TensorType_FLOAT32;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!found_image_to_tensor_calculator) {
|
||||
return absl::Status(CreateStatusWithPayload(
|
||||
absl::StatusCode::kNotFound,
|
||||
absl::StrCat("The graph doesn't contain ImageToTensorCalculator.")));
|
||||
}
|
||||
return image_tensor_specs;
|
||||
}
|
||||
|
||||
protected:
|
||||
// A synchronous method to process single image inputs.
|
||||
|
|
|
@ -130,7 +130,7 @@ struct ObjectDetectorOptions {
|
|||
//
|
||||
// [1]:
|
||||
// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
|
||||
class ObjectDetector : public tasks::vision::core::BaseVisionTaskApi {
|
||||
public:
|
||||
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -298,6 +299,36 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInLiveStreamMode) {
|
|||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||
}
|
||||
|
||||
TEST_F(CreateFromOptionsTest, InputTensorSpecsForMobileSsdModel) {
|
||||
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
|
||||
object_detector->GetInputImageTensorSpecs());
|
||||
EXPECT_EQ(image_tensor_specs.image_width, 300);
|
||||
EXPECT_EQ(image_tensor_specs.image_height, 300);
|
||||
EXPECT_EQ(image_tensor_specs.color_space,
|
||||
tflite::ColorSpaceType::ColorSpaceType_RGB);
|
||||
EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
|
||||
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kEfficientDetWithMetadata);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
|
||||
object_detector->GetInputImageTensorSpecs());
|
||||
EXPECT_EQ(image_tensor_specs.image_width, 320);
|
||||
EXPECT_EQ(image_tensor_specs.image_height, 320);
|
||||
EXPECT_EQ(image_tensor_specs.color_space,
|
||||
tflite::ColorSpaceType::ColorSpaceType_RGB);
|
||||
EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
|
||||
}
|
||||
|
||||
// TODO: Add NumThreadsTest back after having an
|
||||
// "acceleration configuration" field in the ObjectDetectorOptions.
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
|||
// Build ImageTensorSpec from model resources. The tflite model must contain
|
||||
// single subgraph with single input tensor.
|
||||
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
||||
const core::ModelResources& model_resources);
|
||||
const tasks::core::ModelResources& model_resources);
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
|
|
Loading…
Reference in New Issue
Block a user