Rename embedding postprocessor "configure" method for consistency with classification postprocessor.
PiperOrigin-RevId: 489518257
This commit is contained in:
parent
ac212c1507
commit
e2052a6a51
|
@ -158,9 +158,11 @@ class AudioEmbedderGraph : public core::ModelTaskGraph {
|
||||||
// inference results.
|
// inference results.
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
MP_RETURN_IF_ERROR(
|
||||||
|
components::processors::ConfigureEmbeddingPostprocessingGraph(
|
||||||
model_resources, task_options.embedder_options(),
|
model_resources, task_options.embedder_options(),
|
||||||
&postprocessing.GetOptions<components::processors::proto::
|
&postprocessing
|
||||||
|
.GetOptions<components::processors::proto::
|
||||||
EmbeddingPostprocessingGraphOptions>()));
|
EmbeddingPostprocessingGraphOptions>()));
|
||||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
// Time aggregation is only needed for performing audio embedding on
|
// Time aggregation is only needed for performing audio embedding on
|
||||||
|
|
|
@ -150,7 +150,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::Status ConfigureEmbeddingPostprocessing(
|
absl::Status ConfigureEmbeddingPostprocessingGraph(
|
||||||
const ModelResources& model_resources,
|
const ModelResources& model_resources,
|
||||||
const proto::EmbedderOptions& embedder_options,
|
const proto::EmbedderOptions& embedder_options,
|
||||||
proto::EmbeddingPostprocessingGraphOptions* options) {
|
proto::EmbeddingPostprocessingGraphOptions* options) {
|
||||||
|
@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing(
|
||||||
// timestamp aggregation is required.
|
// timestamp aggregation is required.
|
||||||
//
|
//
|
||||||
// The recommended way of using this graph is through the GraphBuilder API using
|
// The recommended way of using this graph is through the GraphBuilder API using
|
||||||
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
|
// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for
|
||||||
// details.
|
// more details.
|
||||||
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
public:
|
public:
|
||||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||||
|
|
|
@ -58,7 +58,7 @@ namespace processors {
|
||||||
// The embedding result aggregated by timestamp, then by head. Must be
|
// The embedding result aggregated by timestamp, then by head. Must be
|
||||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||||
// timestamp aggregation is required.
|
// timestamp aggregation is required.
|
||||||
absl::Status ConfigureEmbeddingPostprocessing(
|
absl::Status ConfigureEmbeddingPostprocessingGraph(
|
||||||
const tasks::core::ModelResources& model_resources,
|
const tasks::core::ModelResources& model_resources,
|
||||||
const proto::EmbedderOptions& embedder_options,
|
const proto::EmbedderOptions& embedder_options,
|
||||||
proto::EmbeddingPostprocessingGraphOptions* options);
|
proto::EmbeddingPostprocessingGraphOptions* options);
|
||||||
|
|
|
@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||||
options_in.set_l2_normalize(true);
|
options_in.set_l2_normalize(true);
|
||||||
|
|
||||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
|
||||||
&options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
|
@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
|
||||||
options_in.set_quantize(true);
|
options_in.set_quantize(true);
|
||||||
|
|
||||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
|
||||||
&options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
|
@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
options_in.set_l2_normalize(true);
|
options_in.set_l2_normalize(true);
|
||||||
|
|
||||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources,
|
||||||
&options_out));
|
options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
|
@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.processors."
|
"mediapipe.tasks.components.processors."
|
||||||
"EmbeddingPostprocessingGraph");
|
"EmbeddingPostprocessingGraph");
|
||||||
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
|
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
|
||||||
*model_resources, options,
|
*model_resources, options,
|
||||||
&postprocessing
|
&postprocessing
|
||||||
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
|
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
|
||||||
|
|
|
@ -128,9 +128,11 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
|
||||||
// inference results.
|
// inference results.
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
MP_RETURN_IF_ERROR(
|
||||||
|
components::processors::ConfigureEmbeddingPostprocessingGraph(
|
||||||
model_resources, task_options.embedder_options(),
|
model_resources, task_options.embedder_options(),
|
||||||
&postprocessing.GetOptions<components::processors::proto::
|
&postprocessing
|
||||||
|
.GetOptions<components::processors::proto::
|
||||||
EmbeddingPostprocessingGraphOptions>()));
|
EmbeddingPostprocessingGraphOptions>()));
|
||||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
|
||||||
|
|
|
@ -151,9 +151,11 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
||||||
// inference results.
|
// inference results.
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
MP_RETURN_IF_ERROR(
|
||||||
|
components::processors::ConfigureEmbeddingPostprocessingGraph(
|
||||||
model_resources, task_options.embedder_options(),
|
model_resources, task_options.embedder_options(),
|
||||||
&postprocessing.GetOptions<components::processors::proto::
|
&postprocessing
|
||||||
|
.GetOptions<components::processors::proto::
|
||||||
EmbeddingPostprocessingGraphOptions>()));
|
EmbeddingPostprocessingGraphOptions>()));
|
||||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user