Rename embedding postprocessor "configure" method for consistency with classification postprocessor.
PiperOrigin-RevId: 489518257
This commit is contained in:
parent
ac212c1507
commit
e2052a6a51
|
@ -158,10 +158,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph {
|
|||
// inference results.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureEmbeddingPostprocessingGraph(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
// Time aggregation is only needed for performing audio embedding on
|
||||
// audio files. Disables timestamp aggregation by not connecting the
|
||||
|
|
|
@ -150,7 +150,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
|
|||
|
||||
} // namespace
|
||||
|
||||
absl::Status ConfigureEmbeddingPostprocessing(
|
||||
absl::Status ConfigureEmbeddingPostprocessingGraph(
|
||||
const ModelResources& model_resources,
|
||||
const proto::EmbedderOptions& embedder_options,
|
||||
proto::EmbeddingPostprocessingGraphOptions* options) {
|
||||
|
@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing(
|
|||
// timestamp aggregation is required.
|
||||
//
|
||||
// The recommended way of using this graph is through the GraphBuilder API using
|
||||
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
|
||||
// details.
|
||||
// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for
|
||||
// more details.
|
||||
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||
public:
|
||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||
|
|
|
@ -58,7 +58,7 @@ namespace processors {
|
|||
// The embedding result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
absl::Status ConfigureEmbeddingPostprocessing(
|
||||
absl::Status ConfigureEmbeddingPostprocessingGraph(
|
||||
const tasks::core::ModelResources& model_resources,
|
||||
const proto::EmbedderOptions& embedder_options,
|
||||
proto::EmbeddingPostprocessingGraphOptions* options);
|
||||
|
|
|
@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
|||
options_in.set_l2_normalize(true);
|
||||
|
||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||
&options_out));
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
|
||||
options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(
|
||||
options_out,
|
||||
|
@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
|
|||
options_in.set_quantize(true);
|
||||
|
||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||
&options_out));
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
|
||||
options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(
|
||||
options_out,
|
||||
|
@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
|||
options_in.set_l2_normalize(true);
|
||||
|
||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||
&options_out));
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
|
||||
options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(
|
||||
options_out,
|
||||
|
@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
|
|||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors."
|
||||
"EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
|
||||
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
|
||||
*model_resources, options,
|
||||
&postprocessing
|
||||
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
|
||||
|
|
|
@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
|
|||
// inference results.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureEmbeddingPostprocessingGraph(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Outputs the embedding result.
|
||||
|
|
|
@ -151,10 +151,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
|||
// inference results.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureEmbeddingPostprocessingGraph(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Outputs the embedding results.
|
||||
|
|
Loading…
Reference in New Issue
Block a user