/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
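
// Unit tests for ConfigureClassificationPostprocessing() and the
// "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" that
// consumes the resulting ClassificationPostprocessingOptions.
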
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"

#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/graph_runner.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/util/label_map.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

namespace mediapipe {
namespace tasks {
namespace components {
namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::ClassifierOptions;
using ::mediapipe::tasks::core::ModelResources;
using ::testing::HasSubstr;
using ::testing::proto::Approximately;
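
// Test data directory and model files used in these tests.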
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
constexpr char kQuantizedImageClassifierWithMetadata[] =
"vision/mobilenet_v1_0.25_224_quant.tflite";
constexpr char kQuantizedImageClassifierWithDummyScoreCalibration[] =
"vision/mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
constexpr char kQuantizedImageClassifierWithoutMetadata[] =
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] =
"audio/two_heads.tflite";
constexpr char kTestModelResourcesTag[] = "test_model_resources";
constexpr int kMobileNetNumClasses = 1001;
constexpr int kTwoHeadsNumClasses[] = {521, 5};
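
// Tags and names of the graph input/output streams.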
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTensorsName[] = "tensors";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationResultName[] = "classification_result";

// Helper function to get ModelResources.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
absl::string_view model_name) {
auto external_file = std::make_unique<core::proto::ExternalFile>();
external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name));
return ModelResources::Create(kTestModelResourcesTag,
std::move(external_file));
}
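
// Tests for ConfigureClassificationPostprocessing(), which validates the
// incoming ClassifierOptions and fills the ClassificationPostprocessingOptions
// used to configure the postprocessing subgraph.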
class ConfigureTest : public tflite_shims::testing::Test {};

TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
options_in.set_max_results(0);
ClassificationPostprocessingOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
}

TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
options_in.add_category_allowlist("foo");
options_in.add_category_denylist("bar");
ClassificationPostprocessingOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
}

TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
options_in.add_category_allowlist("foo");
ClassificationPostprocessingOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(
status.message(),
HasSubstr("requires labels to be present in the TFLite Model Metadata"));
}

TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
}
classification_aggregation_options {}
has_quantized_outputs: true
)pb")));
}

TEST_F(ConfigureTest, SucceedsWithMaxResults) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
options_in.set_max_results(3);
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: 3
sort_by_descending_score: true
}
classification_aggregation_options {}
has_quantized_outputs: true
)pb")));
}

TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
options_in.set_score_threshold(0.5);
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: 0.5
top_k: -1
sort_by_descending_score: true
}
classification_aggregation_options {}
has_quantized_outputs: true
)pb")));
}

TEST_F(ConfigureTest, SucceedsWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
// Check label map size and first two elements.
EXPECT_EQ(
options_out.tensors_to_classifications_options(0).label_items_size(),
kMobileNetNumClasses);
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(0),
EqualsProto(R"pb(name: "background")pb"));
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(1),
EqualsProto(R"pb(name: "tench")pb"));
// Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
}
classification_aggregation_options {
head_names: "probability"
}
has_quantized_outputs: true
)pb")));
}

TEST_F(ConfigureTest, SucceedsWithAllowlist) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
options_in.add_category_allowlist("tench");
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
// Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
allow_classes: 1
}
classification_aggregation_options {
head_names: "probability"
}
has_quantized_outputs: true
)pb")));
}

TEST_F(ConfigureTest, SucceedsWithDenylist) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
options_in.add_category_denylist("background");
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
// Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
ignore_classes: 0
}
classification_aggregation_options {
head_names: "probability"
}
has_quantized_outputs: true
)pb")));
}

TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(
kQuantizedImageClassifierWithDummyScoreCalibration));
ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
// Check label map size and first two elements.
EXPECT_EQ(
options_out.tensors_to_classifications_options(0).label_items_size(),
kMobileNetNumClasses);
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(0),
EqualsProto(R"pb(name: "background")pb"));
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(1),
EqualsProto(R"pb(name: "tench")pb"));
// Clear label map.
options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items();
// Check sigmoids size and first element.
EXPECT_EQ(options_out.score_calibration_options_size(), 1);
auto score_calibration_options =
options_out.score_calibration_options().at(0);
EXPECT_EQ(score_calibration_options.sigmoids_size(), kMobileNetNumClasses);
EXPECT_THAT(score_calibration_options.sigmoids(0),
EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb"));
options_out.mutable_score_calibration_options()->at(0).clear_sigmoids();
// Compare the rest of the options.
EXPECT_THAT(
options_out,
Approximately(EqualsProto(
R"pb(score_calibration_options {
key: 0
value { score_transformation: IDENTITY default_score: 0.5 }
}
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
}
classification_aggregation_options { head_names: "probability" }
has_quantized_outputs: true
)pb")));
}

TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
// Check label map sizes and first two elements.
EXPECT_EQ(
options_out.tensors_to_classifications_options(0).label_items_size(),
kTwoHeadsNumClasses[0]);
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(0),
EqualsProto(R"pb(name: "Speech")pb"));
EXPECT_THAT(
options_out.tensors_to_classifications_options(0).label_items().at(1),
EqualsProto(R"pb(name: "Child speech, kid speaking")pb"));
EXPECT_EQ(
options_out.tensors_to_classifications_options(1).label_items_size(),
kTwoHeadsNumClasses[1]);
EXPECT_THAT(
options_out.tensors_to_classifications_options(1).label_items().at(0),
EqualsProto(R"pb(name: "Red Crossbill")pb"));
EXPECT_THAT(
options_out.tensors_to_classifications_options(1).label_items().at(1),
EqualsProto(R"pb(name: "White-breasted Wood-Wren")pb"));
// Clear label maps and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0)
->clear_label_items();
options_out.mutable_tensors_to_classifications_options(1)
->clear_label_items();
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
}
tensors_to_classifications_options {
min_score_threshold: -3.4028235e+38
top_k: -1
sort_by_descending_score: true
}
classification_aggregation_options {
head_names: "yamnet_classification"
head_names: "bird_classification"
}
has_quantized_outputs: false
)pb")));
}
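
// Test fixture that runs the ClassificationPostprocessingSubgraph on synthetic
// input tensors and checks the ClassificationResult it produces.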
class PostprocessingTest : public tflite_shims::testing::Test {
protected:
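  // Builds a graph containing a single ClassificationPostprocessingSubgraph
  // configured for the given model and options, optionally connects the
  // TIMESTAMPS input stream, starts the graph, and returns a poller for the
  // CLASSIFICATION_RESULT output stream.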
absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const ClassifierOptions& options,
bool connect_timestamps = false) {
ASSIGN_OR_RETURN(auto model_resources,
CreateModelResourcesForModel(model_name));
Graph graph;
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
*model_resources, options,
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
postprocessing.In(kTensorsTag);
if (connect_timestamps) {
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
kTimestampsName) >>
postprocessing.In(kTimestampsTag);
}
postprocessing.Out(kClassificationResultTag)
.SetName(kClassificationResultName) >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kClassificationResultName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
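
  // Appends a 1 x N tensor with the given values, element type and optional
  // quantization parameters to the tensors that will be sent to the graph on
  // the next call to Run().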
template <typename T>
void AddTensor(
const std::vector<T>& tensor, const Tensor::ElementType& element_type,
const Tensor::QuantizationParameters& quantization_parameters = {}) {
tensors_->emplace_back(element_type,
Tensor::Shape{1, static_cast<int>(tensor.size())},
quantization_parameters);
auto view = tensors_->back().GetCpuWriteView();
T* buffer = view.buffer<T>();
std::copy(tensor.begin(), tensor.end(), buffer);
}
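
  // Sends the pending input tensors to the graph at the given timestamp, and
  // optionally sends the aggregation timestamps on the TIMESTAMPS stream.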
absl::Status Run(
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
int timestamp = 0) {
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
// Reset tensors for future calls.
tensors_ = absl::make_unique<std::vector<Tensor>>();
if (aggregation_timestamps.has_value()) {
auto packet = absl::make_unique<std::vector<Timestamp>>();
for (const auto& timestamp : *aggregation_timestamps) {
packet->emplace_back(Timestamp(timestamp));
}
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
}
return absl::OkStatus();
}
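
  // Waits until the graph is idle, closes the input streams, and returns the
  // single ClassificationResult packet output by the graph.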
absl::StatusOr<ClassificationResult> GetClassificationResult(
OutputStreamPoller& poller) {
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
Packet packet;
if (!poller.Next(&packet)) {
return absl::InternalError("Unable to get output packet");
}
auto result = packet.Get<ClassificationResult>();
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
return result;
}
private:
CalculatorGraph calculator_graph_;
std::unique_ptr<std::vector<Tensor>> tensors_ =
absl::make_unique<std::vector<Tensor>>();
};

TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
// Build graph.
ClassifierOptions options;
options.set_max_results(3);
options.set_score_threshold(0.5);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
tensor[1] = 18;
tensor[2] = 16;
// Send tensors and get results.
AddTensor(tensor, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(results, EqualsProto(R"pb(classifications {
entries {
categories { index: 1 score: 0.8 }
categories { index: 2 score: 0.6 }
timestamp_ms: 0
}
})pb"));
}

TEST_F(PostprocessingTest, SucceedsWithMetadata) {
// Build graph.
ClassifierOptions options;
options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
tensor[1] = 12;
tensor[2] = 14;
tensor[3] = 16;
tensor[4] = 18;
// Send tensors and get results.
AddTensor(tensor, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(
results,
EqualsProto(
R"pb(classifications {
entries {
categories {
index: 4
score: 0.8
category_name: "tiger shark"
}
categories {
index: 3
score: 0.6
category_name: "great white shark"
}
categories { index: 2 score: 0.4 category_name: "goldfish" }
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}

TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
// Build graph.
ClassifierOptions options;
options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
tensor[1] = 12;
tensor[2] = 14;
tensor[3] = 16;
tensor[4] = 18;
// Send tensors and get results.
AddTensor(tensor, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(results, EqualsProto(
R"pb(classifications {
entries {
categories {
index: 4
score: 0.6899744811
category_name: "tiger shark"
}
categories {
index: 3
score: 0.6456563062
category_name: "great white shark"
}
categories {
index: 2
score: 0.5986876601
category_name: "goldfish"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}

TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
// Build graph.
ClassifierOptions options;
options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options));
// Build input tensors.
std::vector<float> tensor_0(kTwoHeadsNumClasses[0], 0);
tensor_0[1] = 0.2;
tensor_0[2] = 0.4;
tensor_0[3] = 0.6;
std::vector<float> tensor_1(kTwoHeadsNumClasses[1], 0);
tensor_1[1] = 0.2;
tensor_1[2] = 0.4;
tensor_1[3] = 0.6;
// Send tensors and get results.
AddTensor(tensor_0, Tensor::ElementType::kFloat32);
AddTensor(tensor_1, Tensor::ElementType::kFloat32);
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
EXPECT_THAT(results, EqualsProto(
R"pb(classifications {
entries {
categories {
index: 3
score: 0.6
category_name: "Narration, monologue"
}
categories {
index: 2
score: 0.4
category_name: "Conversation"
}
timestamp_ms: 0
}
head_index: 0
head_name: "yamnet_classification"
}
classifications {
entries {
categories {
index: 3
score: 0.6
category_name: "Azara\'s Spinetail"
}
categories {
index: 2
score: 0.4
category_name: "House Sparrow"
}
timestamp_ms: 0
}
head_index: 1
head_name: "bird_classification"
})pb"));
}

TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
// Build graph.
ClassifierOptions options;
options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
/*connect_timestamps=*/true));
// Build input tensors.
std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
tensor_0[1] = 12;
tensor_0[2] = 14;
tensor_0[3] = 16;
std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
tensor_1[5] = 12;
tensor_1[6] = 14;
tensor_1[7] = 16;
// Send tensors and get results.
AddTensor(tensor_0, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
AddTensor(tensor_1, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run(
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
/*timestamp=*/1000));
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(
results,
EqualsProto(
R"pb(classifications {
entries {
categories {
index: 3
score: 0.6
category_name: "great white shark"
}
categories { index: 2 score: 0.4 category_name: "goldfish" }
timestamp_ms: 0
}
entries {
categories { index: 7 score: 0.6 category_name: "stingray" }
categories {
index: 6
score: 0.4
category_name: "electric ray"
}
timestamp_ms: 1
}
head_index: 0
head_name: "probability"
})pb"));
}
} // namespace
} // namespace components
} // namespace tasks
} // namespace mediapipe