Rename embedding postprocessor "configure" method for consistency with classification postprocessor.

PiperOrigin-RevId: 489518257
commit e2052a6a51
parent ac212c1507
Author:    MediaPipe Team
Date:      2022-11-18 11:11:22 -08:00
Committer: Copybara-Service
6 changed files with 29 additions and 23 deletions
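
The entry point keeps its signature; only the name changes. A before/after sketch, reflecting the header change below:

    // Before:
    absl::Status ConfigureEmbeddingPostprocessing(
        const tasks::core::ModelResources& model_resources,
        const proto::EmbedderOptions& embedder_options,
        proto::EmbeddingPostprocessingGraphOptions* options);

    // After:
    absl::Status ConfigureEmbeddingPostprocessingGraph(
        const tasks::core::ModelResources& model_resources,
        const proto::EmbedderOptions& embedder_options,
        proto::EmbeddingPostprocessingGraphOptions* options);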


@@ -158,10 +158,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph {
     // inference results.
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
-        model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions<components::processors::proto::
-                                       EmbeddingPostprocessingGraphOptions>()));
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureEmbeddingPostprocessingGraph(
+            model_resources, task_options.embedder_options(),
+            &postprocessing
+                 .GetOptions<components::processors::proto::
+                                 EmbeddingPostprocessingGraphOptions>()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
     // Time aggregation is only needed for performing audio embedding on
     // audio files. Disables timestamp aggregation by not connecting the


@@ -150,7 +150,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
 }  // namespace

-absl::Status ConfigureEmbeddingPostprocessing(
+absl::Status ConfigureEmbeddingPostprocessingGraph(
     const ModelResources& model_resources,
     const proto::EmbedderOptions& embedder_options,
     proto::EmbeddingPostprocessingGraphOptions* options) {

@@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing(
 // timestamp aggregation is required.
 //
 // The recommended way of using this graph is through the GraphBuilder API using
-// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
-// details.
+// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for
+// more details.
 class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
  public:
   absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(


@@ -58,7 +58,7 @@ namespace processors {
 // The embedding result aggregated by timestamp, then by head. Must be
 // connected if the TIMESTAMPS input is connected, as it signals that
 // timestamp aggregation is required.
-absl::Status ConfigureEmbeddingPostprocessing(
+absl::Status ConfigureEmbeddingPostprocessingGraph(
     const tasks::core::ModelResources& model_resources,
     const proto::EmbedderOptions& embedder_options,
     proto::EmbeddingPostprocessingGraphOptions* options);
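
A minimal caller sketch under the new name, mirroring the call sites updated in this commit; the graph, inference, model_resources, and task_options objects are assumed to be set up as in the embedder task graphs.

    // Sketch only (surrounding objects assumed from the task graphs above).
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
    MP_RETURN_IF_ERROR(
        components::processors::ConfigureEmbeddingPostprocessingGraph(
            model_resources, task_options.embedder_options(),
            &postprocessing.GetOptions<components::processors::proto::
                                           EmbeddingPostprocessingGraphOptions>()));
    inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);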


@@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
   options_in.set_l2_normalize(true);

   proto::EmbeddingPostprocessingGraphOptions options_out;
-  MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
-                                                &options_out));
+  MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
+                                                     options_in, &options_out));

   EXPECT_THAT(
       options_out,

@@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
   options_in.set_quantize(true);

   proto::EmbeddingPostprocessingGraphOptions options_out;
-  MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
-                                                &options_out));
+  MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
+                                                     options_in, &options_out));

   EXPECT_THAT(
       options_out,

@@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
   options_in.set_l2_normalize(true);

   proto::EmbeddingPostprocessingGraphOptions options_out;
-  MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
-                                                &options_out));
+  MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
+                                                     options_in, &options_out));

   EXPECT_THAT(
       options_out,

@@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors."
         "EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
+    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
         *model_resources, options,
         &postprocessing
             .GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));


@@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
     // inference results.
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
-        model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions<components::processors::proto::
-                                       EmbeddingPostprocessingGraphOptions>()));
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureEmbeddingPostprocessingGraph(
+            model_resources, task_options.embedder_options(),
+            &postprocessing
+                 .GetOptions<components::processors::proto::
+                                 EmbeddingPostprocessingGraphOptions>()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

     // Outputs the embedding result.


@@ -151,10 +151,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
     // inference results.
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
-        model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions<components::processors::proto::
-                                       EmbeddingPostprocessingGraphOptions>()));
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureEmbeddingPostprocessingGraph(
+            model_resources, task_options.embedder_options(),
+            &postprocessing
+                 .GetOptions<components::processors::proto::
+                                 EmbeddingPostprocessingGraphOptions>()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

     // Outputs the embedding results.