Rename embedding postprocessor "configure" method for consistency with classification postprocessor.

PiperOrigin-RevId: 489518257
This commit is contained in:
MediaPipe Team 2022-11-18 11:11:22 -08:00 committed by Copybara-Service
parent ac212c1507
commit e2052a6a51
6 changed files with 29 additions and 23 deletions

View File

@ -158,10 +158,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph {
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
MP_RETURN_IF_ERROR(
components::processors::ConfigureEmbeddingPostprocessingGraph(
model_resources, task_options.embedder_options(),
&postprocessing
.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio embedding on
// audio files. Disables timestamp aggregation by not connecting the

View File

@ -150,7 +150,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
} // namespace
absl::Status ConfigureEmbeddingPostprocessing(
absl::Status ConfigureEmbeddingPostprocessingGraph(
const ModelResources& model_resources,
const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options) {
@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing(
// timestamp aggregation is required.
//
// The recommended way of using this graph is through the GraphBuilder API using
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
// details.
// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for
// more details.
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(

View File

@ -58,7 +58,7 @@ namespace processors {
// The embedding result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
absl::Status ConfigureEmbeddingPostprocessing(
absl::Status ConfigureEmbeddingPostprocessingGraph(
const tasks::core::ModelResources& model_resources,
const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options);

View File

@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
options_in.set_l2_normalize(true);
proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out));
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
options_in, &options_out));
EXPECT_THAT(
options_out,
@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
options_in.set_quantize(true);
proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out));
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
options_in, &options_out));
EXPECT_THAT(
options_out,
@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
options_in.set_l2_normalize(true);
proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out));
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
options_in, &options_out));
EXPECT_THAT(
options_out,
@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors."
"EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
*model_resources, options,
&postprocessing
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));

View File

@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
MP_RETURN_IF_ERROR(
components::processors::ConfigureEmbeddingPostprocessingGraph(
model_resources, task_options.embedder_options(),
&postprocessing
.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding result.

View File

@ -151,10 +151,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
MP_RETURN_IF_ERROR(
components::processors::ConfigureEmbeddingPostprocessingGraph(
model_resources, task_options.embedder_options(),
&postprocessing
.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding results.