Update base audio/vision tasks api to suuport proto3 graph options.
PiperOrigin-RevId: 538661975
This commit is contained in:
		
							parent
							
								
									a7cd7b9a32
								
							
						
					
					
						commit
						943445fba8
					
				| 
						 | 
				
			
			@ -43,6 +43,7 @@ cc_library(
 | 
			
		|||
        ":base_audio_task_api",
 | 
			
		||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
			
		||||
        "//mediapipe/framework:calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,6 +27,7 @@ limitations under the License.
 | 
			
		|||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "mediapipe/framework/calculator.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
 | 
			
		||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
| 
						 | 
				
			
			@ -60,13 +61,8 @@ class AudioTaskApiFactory {
 | 
			
		|||
            "Task graph config should only contain one task subgraph node.",
 | 
			
		||||
            MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
      } else {
 | 
			
		||||
        if (!node.options().HasExtension(Options::ext)) {
 | 
			
		||||
          return CreateStatusWithPayload(
 | 
			
		||||
              absl::StatusCode::kInvalidArgument,
 | 
			
		||||
              absl::StrCat(node.calculator(),
 | 
			
		||||
                           " is missing the required task options field."),
 | 
			
		||||
              MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
        }
 | 
			
		||||
        MP_RETURN_IF_ERROR(
 | 
			
		||||
            tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
 | 
			
		||||
        found_task_subgraph = true;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -81,7 +81,6 @@ class TaskApiFactory {
 | 
			
		|||
    return std::make_unique<T>(std::move(runner));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  template <typename Options>
 | 
			
		||||
  static absl::Status CheckHasValidOptions(
 | 
			
		||||
      const CalculatorGraphConfig::Node& node) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,6 +43,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:rect",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:base_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_runner",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +59,7 @@ cc_library(
 | 
			
		|||
        ":base_vision_task_api",
 | 
			
		||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
			
		||||
        "//mediapipe/framework:calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,6 +26,7 @@ limitations under the License.
 | 
			
		|||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "mediapipe/framework/calculator.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 | 
			
		||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -60,13 +61,8 @@ class VisionTaskApiFactory {
 | 
			
		|||
            "Task graph config should only contain one task subgraph node.",
 | 
			
		||||
            MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
      } else {
 | 
			
		||||
        if (!node.options().HasExtension(Options::ext)) {
 | 
			
		||||
          return CreateStatusWithPayload(
 | 
			
		||||
              absl::StatusCode::kInvalidArgument,
 | 
			
		||||
              absl::StrCat(node.calculator(),
 | 
			
		||||
                           " is missing the required task options field."),
 | 
			
		||||
              MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
        }
 | 
			
		||||
        MP_RETURN_IF_ERROR(
 | 
			
		||||
            tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
 | 
			
		||||
        found_task_subgraph = true;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user