Add GetInputImageTensorSpecs into BaseVisionTaskApi for tasks api users to get input image tensor specifications.
PiperOrigin-RevId: 514650593
This commit is contained in:
parent
2f2a74da6a
commit
dbd6d72696
|
@ -38,10 +38,12 @@ cc_library(
|
||||||
":image_processing_options",
|
":image_processing_options",
|
||||||
":running_mode",
|
":running_mode",
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:rect",
|
"//mediapipe/tasks/cc/components/containers:rect",
|
||||||
"//mediapipe/tasks/cc/core:base_task_api",
|
"//mediapipe/tasks/cc/core:base_task_api",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -25,12 +25,14 @@ limitations under the License.
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -44,6 +46,42 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
|
||||||
explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner,
|
explicit BaseVisionTaskApi(std::unique_ptr<tasks::core::TaskRunner> runner,
|
||||||
RunningMode running_mode)
|
RunningMode running_mode)
|
||||||
: BaseTaskApi(std::move(runner)), running_mode_(running_mode) {}
|
: BaseTaskApi(std::move(runner)), running_mode_(running_mode) {}
|
||||||
|
virtual ~BaseVisionTaskApi() {}
|
||||||
|
|
||||||
|
virtual absl::StatusOr<ImageTensorSpecs> GetInputImageTensorSpecs() {
|
||||||
|
ImageTensorSpecs image_tensor_specs;
|
||||||
|
bool found_image_to_tensor_calculator = false;
|
||||||
|
for (auto& node : runner_->GetGraphConfig().node()) {
|
||||||
|
if (node.calculator() == "ImageToTensorCalculator") {
|
||||||
|
if (!found_image_to_tensor_calculator) {
|
||||||
|
found_image_to_tensor_calculator = true;
|
||||||
|
} else {
|
||||||
|
return absl::Status(CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kFailedPrecondition,
|
||||||
|
absl::StrCat(
|
||||||
|
"The graph has more than one ImageToTensorCalculator.")));
|
||||||
|
}
|
||||||
|
mediapipe::ImageToTensorCalculatorOptions options =
|
||||||
|
node.options().GetExtension(
|
||||||
|
mediapipe::ImageToTensorCalculatorOptions::ext);
|
||||||
|
image_tensor_specs.image_width = options.output_tensor_width();
|
||||||
|
image_tensor_specs.image_height = options.output_tensor_height();
|
||||||
|
image_tensor_specs.color_space =
|
||||||
|
tflite::ColorSpaceType::ColorSpaceType_RGB;
|
||||||
|
if (options.has_output_tensor_uint_range()) {
|
||||||
|
image_tensor_specs.tensor_type = tflite::TensorType_UINT8;
|
||||||
|
} else if (options.has_output_tensor_float_range()) {
|
||||||
|
image_tensor_specs.tensor_type = tflite::TensorType_FLOAT32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found_image_to_tensor_calculator) {
|
||||||
|
return absl::Status(CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kNotFound,
|
||||||
|
absl::StrCat("The graph doesn't contain ImageToTensorCalculator.")));
|
||||||
|
}
|
||||||
|
return image_tensor_specs;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// A synchronous method to process single image inputs.
|
// A synchronous method to process single image inputs.
|
||||||
|
|
|
@ -130,7 +130,7 @@ struct ObjectDetectorOptions {
|
||||||
//
|
//
|
||||||
// [1]:
|
// [1]:
|
||||||
// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||||
class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
|
class ObjectDetector : public tasks::vision::core::BaseVisionTaskApi {
|
||||||
public:
|
public:
|
||||||
using BaseVisionTaskApi::BaseVisionTaskApi;
|
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
@ -298,6 +299,36 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInLiveStreamMode) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CreateFromOptionsTest, InputTensorSpecsForMobileSsdModel) {
|
||||||
|
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||||
|
ObjectDetector::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
|
||||||
|
object_detector->GetInputImageTensorSpecs());
|
||||||
|
EXPECT_EQ(image_tensor_specs.image_width, 300);
|
||||||
|
EXPECT_EQ(image_tensor_specs.image_height, 300);
|
||||||
|
EXPECT_EQ(image_tensor_specs.color_space,
|
||||||
|
tflite::ColorSpaceType::ColorSpaceType_RGB);
|
||||||
|
EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
|
||||||
|
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kEfficientDetWithMetadata);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||||
|
ObjectDetector::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto image_tensor_specs,
|
||||||
|
object_detector->GetInputImageTensorSpecs());
|
||||||
|
EXPECT_EQ(image_tensor_specs.image_width, 320);
|
||||||
|
EXPECT_EQ(image_tensor_specs.image_height, 320);
|
||||||
|
EXPECT_EQ(image_tensor_specs.color_space,
|
||||||
|
tflite::ColorSpaceType::ColorSpaceType_RGB);
|
||||||
|
EXPECT_EQ(image_tensor_specs.tensor_type, tflite::TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Add NumThreadsTest back after having an
|
// TODO: Add NumThreadsTest back after having an
|
||||||
// "acceleration configuration" field in the ObjectDetectorOptions.
|
// "acceleration configuration" field in the ObjectDetectorOptions.
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,7 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
||||||
// Build ImageTensorSpec from model resources. The tflite model must contain
|
// Build ImageTensorSpec from model resources. The tflite model must contain
|
||||||
// single subgraph with single input tensor.
|
// single subgraph with single input tensor.
|
||||||
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
||||||
const core::ModelResources& model_resources);
|
const tasks::core::ModelResources& model_resources);
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
Loading…
Reference in New Issue
Block a user