Internal change
PiperOrigin-RevId: 485553891
parent 475e6b4fd5
commit aab5f84aae
@@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
 constexpr char kScoresTag[] = "SCORES";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kTimestampsTag[] = "TIMESTAMPS";
+constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
+
+// Struct holding the different output streams produced by the graph.
+struct ClassificationPostprocessingOutputStreams {
+  Source<ClassificationResult> classification_result;
+  Source<ClassificationResult> classifications;
+  Source<std::vector<ClassificationResult>> timestamped_classifications;
+};
 
 // Performs sanity checks on provided ClassifierOptions.
 absl::Status SanityCheckClassifierOptions(
@@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph(
 //   TENSORS - std::vector<Tensor>
 //     The output tensors of an InferenceCalculator.
 //   TIMESTAMPS - std::vector<Timestamp> @Optional
-//     The collection of timestamps that a single ClassificationResult should
-//     aggregate. This is mostly useful for classifiers working on time series,
-//     e.g. audio or video classification.
+//     The collection of the timestamps that this calculator should aggregate.
+//     This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
+//     output is used for results. Otherwise as no timestamp aggregation is
+//     required the CLASSIFICATIONS output is used for results.
+//
 // Outputs:
-//   CLASSIFICATION_RESULT - ClassificationResult
-//     The output aggregated classification results.
+//   CLASSIFICATIONS - ClassificationResult @Optional
+//     The classification results aggregated by head. Must be connected if the
+//     TIMESTAMPS input is not connected, as it signals that timestamp
+//     aggregation is not required.
+//   TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
+//     The classification result aggregated by timestamp, then by head. Must be
+//     connected if the TIMESTAMPS input is connected, as it signals that
+//     timestamp aggregation is required.
+// TODO: remove output once migration is over.
+//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
+//     The aggregated classification result.
 //
 // The recommended way of using this graph is through the GraphBuilder API
 // using the 'ConfigureClassificationPostprocessingGraph()' function. See header
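Note: the following is a usage sketch, not part of this change, illustrating the contract documented above. It assumes the mediapipe::api2::builder aliases and reuses the subgraph name and tag strings that appear in the tests added later in this commit; the include for the ClassificationResult proto is omitted, and the node options still have to be filled in via ConfigureClassificationPostprocessingGraph() (see the header change below).

#include <vector>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/timestamp.h"
// ClassificationResult proto include omitted; see the test-file includes below.

using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Tensor;
using ::mediapipe::Timestamp;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Input;
using ::mediapipe::api2::builder::Output;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

// Wires the postprocessing subgraph from a client graph. When
// `with_timestamps` is false, only TENSORS is connected and per-head results
// are read from CLASSIFICATIONS; otherwise TIMESTAMPS is also connected and
// per-timestamp results are read from TIMESTAMPED_CLASSIFICATIONS.
CalculatorGraphConfig BuildClientGraph(bool with_timestamps) {
  Graph graph;
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors."
      "ClassificationPostprocessingGraph");
  graph[Input<std::vector<Tensor>>("TENSORS")] >> postprocessing.In("TENSORS");
  if (with_timestamps) {
    graph[Input<std::vector<Timestamp>>("TIMESTAMPS")] >>
        postprocessing.In("TIMESTAMPS");
    postprocessing.Out("TIMESTAMPED_CLASSIFICATIONS") >>
        graph[Output<std::vector<ClassificationResult>>(
            "TIMESTAMPED_CLASSIFICATIONS")];
  } else {
    postprocessing.Out("CLASSIFICATIONS") >>
        graph[Output<ClassificationResult>("CLASSIFICATIONS")];
  }
  return graph.GetConfig();
}
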
@@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
       mediapipe::SubgraphContext* sc) override {
     Graph graph;
     ASSIGN_OR_RETURN(
-        auto classification_result_out,
+        auto output_streams,
         BuildClassificationPostprocessing(
             sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
             graph[Input<std::vector<Tensor>>(kTensorsTag)],
             graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
-    classification_result_out >>
+    output_streams.classification_result >>
         graph[Output<ClassificationResult>(kClassificationResultTag)];
+    output_streams.classifications >>
+        graph[Output<ClassificationResult>(kClassificationsTag)];
+    output_streams.timestamped_classifications >>
+        graph[Output<std::vector<ClassificationResult>>(
+            kTimestampedClassificationsTag)];
     return graph.GetConfig();
   }
 
  private:
   // Adds an on-device classification postprocessing graph into the provided
   // builder::Graph instance. The classification postprocessing graph takes
-  // tensors (std::vector<mediapipe::Tensor>) as input and returns one output
-  // stream containing the output classification results (ClassificationResult).
+  // tensors (std::vector<mediapipe::Tensor>) and optional timestamps
+  // (std::vector<Timestamp>) as input and returns two output streams:
+  // - classification results aggregated by classifier head as a
+  //   ClassificationResult proto, used when no timestamps are passed in
+  //   the graph,
+  // - classification results aggregated by timestamp then by classifier head
+  //   as a std::vector<ClassificationResult>, used when timestamps are passed
+  //   in the graph.
   //
   // options: the on-device ClassificationPostprocessingGraphOptions.
   // tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
   // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
-  //   timestamps that a single ClassificationResult should aggregate.
+  //   timestamps that should be used to aggregate classification results.
   // graph: the mediapipe builder::Graph instance to be updated.
-  absl::StatusOr<Source<ClassificationResult>>
+  absl::StatusOr<ClassificationPostprocessingOutputStreams>
   BuildClassificationPostprocessing(
       const proto::ClassificationPostprocessingGraphOptions& options,
       Source<std::vector<Tensor>> tensors_in,
@@ -505,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
       timestamps_in >> result_aggregation.In(kTimestampsTag);
 
     // Connects output.
-    return result_aggregation[Output<ClassificationResult>(
-        kClassificationResultTag)];
+    ClassificationPostprocessingOutputStreams output_streams{
+        /*classification_result=*/result_aggregation
+            [Output<ClassificationResult>(kClassificationResultTag)],
+        /*classifications=*/
+        result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
+        /*timestamped_classifications=*/
+        result_aggregation[Output<std::vector<ClassificationResult>>(
+            kTimestampedClassificationsTag)]};
+    return output_streams;
   }
 };
 
@@ -45,12 +45,22 @@ namespace processors {
 //   TENSORS - std::vector<Tensor>
 //     The output tensors of an InferenceCalculator.
 //   TIMESTAMPS - std::vector<Timestamp> @Optional
-//     The collection of timestamps that a single ClassificationResult should
-//     aggregate. This is mostly useful for classifiers working on time series,
-//     e.g. audio or video classification.
+//     The collection of the timestamps that this calculator should aggregate.
+//     This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
+//     output is used for results. Otherwise as no timestamp aggregation is
+//     required the CLASSIFICATIONS output is used for results.
 // Outputs:
-//   CLASSIFICATION_RESULT - ClassificationResult
-//     The output aggregated classification results.
+//   CLASSIFICATIONS - ClassificationResult @Optional
+//     The classification results aggregated by head. Must be connected if the
+//     TIMESTAMPS input is not connected, as it signals that timestamp
+//     aggregation is not required.
+//   TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
+//     The classification result aggregated by timestamp, then by head. Must be
+//     connected if the TIMESTAMPS input is connected, as it signals that
+//     timestamp aggregation is required.
+// TODO: remove output once migration is over.
+//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
+//     The aggregated classification result.
 absl::Status ConfigureClassificationPostprocessingGraph(
     const tasks::core::ModelResources& model_resources,
     const proto::ClassifierOptions& classifier_options,
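Note: a configuration sketch, not part of this change, mirroring the test fixture added below. It assumes this header's namespace is already opened (so proto:: and ConfigureClassificationPostprocessingGraph resolve as declared above) and that a tasks::core::ModelResources instance was created elsewhere; includes and error handling are elided, and the option values are examples only.

// Adds the postprocessing subgraph node to `graph` and fills its options
// (score calibration, label maps, thresholds, ...) from the model metadata.
absl::Status AddConfiguredPostprocessing(
    const tasks::core::ModelResources& model_resources,
    mediapipe::api2::builder::Graph& graph) {
  proto::ClassifierOptions classifier_options;
  classifier_options.set_max_results(3);        // example value
  classifier_options.set_score_threshold(0.5);  // example value

  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors."
      "ClassificationPostprocessingGraph");
  return ConfigureClassificationPostprocessingGraph(
      model_resources, classifier_options,
      &postprocessing
           .GetOptions<proto::ClassificationPostprocessingGraphOptions>());
}
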
@@ -38,6 +38,7 @@ limitations under the License.
 #include "mediapipe/framework/output_stream_poller.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/timestamp.h"
 #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
@@ -64,6 +65,7 @@ using ::mediapipe::file::JoinPath;
 using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 using ::mediapipe::tasks::core::ModelResources;
 using ::testing::HasSubstr;
+using ::testing::Pointwise;
 using ::testing::proto::Approximately;
 
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@@ -86,6 +88,11 @@ constexpr char kTimestampsTag[] = "TIMESTAMPS";
 constexpr char kTimestampsName[] = "timestamps";
 constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
 constexpr char kClassificationResultName[] = "classification_result";
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kClassificationsName[] = "classifications";
+constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
+constexpr char kTimestampedClassificationsName[] =
+    "timestamped_classifications";
 
 // Helper function to get ModelResources.
 absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
@@ -413,6 +420,316 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
 }
 
 class PostprocessingTest : public tflite_shims::testing::Test {
+ protected:
+  absl::StatusOr<OutputStreamPoller> BuildGraph(
+      absl::string_view model_name, const proto::ClassifierOptions& options,
+      bool connect_timestamps = false) {
+    ASSIGN_OR_RETURN(auto model_resources,
+                     CreateModelResourcesForModel(model_name));
+
+    Graph graph;
+    auto& postprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors."
+        "ClassificationPostprocessingGraph");
+    MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
+        *model_resources, options,
+        &postprocessing
+             .GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
+    graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
+        postprocessing.In(kTensorsTag);
+    if (connect_timestamps) {
+      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
+          kTimestampsName) >>
+          postprocessing.In(kTimestampsTag);
+      postprocessing.Out(kTimestampedClassificationsTag)
+              .SetName(kTimestampedClassificationsName) >>
+          graph[Output<std::vector<ClassificationResult>>(
+              kTimestampedClassificationsTag)];
+    } else {
+      postprocessing.Out(kClassificationsTag).SetName(kClassificationsName) >>
+          graph[Output<ClassificationResult>(kClassificationsTag)];
+    }
+
+    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
+    if (connect_timestamps) {
+      ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
+                                        kTimestampedClassificationsName));
+      MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
+      return poller;
+    }
+    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
+                                      kClassificationsName));
+    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
+    return poller;
+  }
+
+  template <typename T>
+  void AddTensor(
+      const std::vector<T>& tensor, const Tensor::ElementType& element_type,
+      const Tensor::QuantizationParameters& quantization_parameters = {}) {
+    tensors_->emplace_back(element_type,
+                           Tensor::Shape{1, static_cast<int>(tensor.size())},
+                           quantization_parameters);
+    auto view = tensors_->back().GetCpuWriteView();
+    T* buffer = view.buffer<T>();
+    std::copy(tensor.begin(), tensor.end(), buffer);
+  }
+
+  absl::Status Run(
+      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
+      int timestamp = 0) {
+    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
+        kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
+    // Reset tensors for future calls.
+    tensors_ = absl::make_unique<std::vector<Tensor>>();
+    if (aggregation_timestamps.has_value()) {
+      auto packet = absl::make_unique<std::vector<Timestamp>>();
+      for (const auto& timestamp : *aggregation_timestamps) {
+        packet->emplace_back(Timestamp(timestamp));
+      }
+      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
+          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
+    }
+    return absl::OkStatus();
+  }
+
+  template <typename T>
+  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
+    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
+
+    Packet packet;
+    if (!poller.Next(&packet)) {
+      return absl::InternalError("Unable to get output packet");
+    }
+    auto result = packet.Get<T>();
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
+    return result;
+  }
+
+ private:
+  CalculatorGraph calculator_graph_;
+  std::unique_ptr<std::vector<Tensor>> tensors_ =
+      absl::make_unique<std::vector<Tensor>>();
+};
+
+TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
+  // Build graph.
+  proto::ClassifierOptions options;
+  options.set_max_results(3);
+  options.set_score_threshold(0.5);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto poller,
+      BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
+  // Build input tensors.
+  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
+  tensor[1] = 18;
+  tensor[2] = 16;
+
+  // Send tensors and get results.
+  AddTensor(tensor, Tensor::ElementType::kUInt8,
+            /*quantization_parameters=*/{0.1, 10});
+  MP_ASSERT_OK(Run());
+  MP_ASSERT_OK_AND_ASSIGN(auto results,
+                          GetResult<ClassificationResult>(poller));
+
+  // Validate results.
+  EXPECT_THAT(results,
+              EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
+                classifications {
+                  head_index: 0
+                  classification_list {
+                    classification { index: 1 score: 0.8 }
+                    classification { index: 2 score: 0.6 }
+                  }
+                }
+              )pb")));
+}
+
+TEST_F(PostprocessingTest, SucceedsWithMetadata) {
+  // Build graph.
+  proto::ClassifierOptions options;
+  options.set_max_results(3);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
+  // Build input tensors.
+  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
+  tensor[1] = 12;
+  tensor[2] = 14;
+  tensor[3] = 16;
+  tensor[4] = 18;
+
+  // Send tensors and get results.
+  AddTensor(tensor, Tensor::ElementType::kUInt8,
+            /*quantization_parameters=*/{0.1, 10});
+  MP_ASSERT_OK(Run());
+  MP_ASSERT_OK_AND_ASSIGN(auto results,
+                          GetResult<ClassificationResult>(poller));
+
+  // Validate results.
+  EXPECT_THAT(
+      results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
+        classifications {
+          head_index: 0
+          head_name: "probability"
+          classification_list {
+            classification { index: 4 score: 0.8 label: "tiger shark" }
+            classification { index: 3 score: 0.6 label: "great white shark" }
+            classification { index: 2 score: 0.4 label: "goldfish" }
+          }
+        }
+      )pb")));
+}
+
+TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
+  // Build graph.
+  proto::ClassifierOptions options;
+  options.set_max_results(3);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto poller,
+      BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
+  // Build input tensors.
+  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
+  tensor[1] = 12;
+  tensor[2] = 14;
+  tensor[3] = 16;
+  tensor[4] = 18;
+
+  // Send tensors and get results.
+  AddTensor(tensor, Tensor::ElementType::kUInt8,
+            /*quantization_parameters=*/{0.1, 10});
+  MP_ASSERT_OK(Run());
+  MP_ASSERT_OK_AND_ASSIGN(auto results,
+                          GetResult<ClassificationResult>(poller));
+
+  // Validate results.
+  EXPECT_THAT(
+      results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
+        classifications {
+          head_index: 0
+          head_name: "probability"
+          classification_list {
+            classification { index: 4 score: 0.6899744811 label: "tiger shark" }
+            classification {
+              index: 3
+              score: 0.6456563062
+              label: "great white shark"
+            }
+            classification { index: 2 score: 0.5986876601 label: "goldfish" }
+          }
+        }
+      )pb")));
+}
+
+TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
+  // Build graph.
+  proto::ClassifierOptions options;
+  options.set_max_results(2);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto poller,
+      BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options));
+  // Build input tensors.
+  std::vector<float> tensor_0(kTwoHeadsNumClasses[0], 0);
+  tensor_0[1] = 0.2;
+  tensor_0[2] = 0.4;
+  tensor_0[3] = 0.6;
+  std::vector<float> tensor_1(kTwoHeadsNumClasses[1], 0);
+  tensor_1[1] = 0.2;
+  tensor_1[2] = 0.4;
+  tensor_1[3] = 0.6;
+
+  // Send tensors and get results.
+  AddTensor(tensor_0, Tensor::ElementType::kFloat32);
+  AddTensor(tensor_1, Tensor::ElementType::kFloat32);
+  MP_ASSERT_OK(Run());
+  MP_ASSERT_OK_AND_ASSIGN(auto results,
+                          GetResult<ClassificationResult>(poller));
+
+  // Validate results.
+  EXPECT_THAT(
+      results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
+        classifications {
+          head_index: 0
+          head_name: "yamnet_classification"
+          classification_list {
+            classification { index: 3 score: 0.6 label: "Narration, monologue" }
+            classification { index: 2 score: 0.4 label: "Conversation" }
+          }
+        }
+        classifications {
+          head_index: 1
+          head_name: "bird_classification"
+          classification_list {
+            classification { index: 3 score: 0.6 label: "Azara\'s Spinetail" }
+            classification { index: 2 score: 0.4 label: "House Sparrow" }
+          }
+        }
+      )pb")));
+}
+
+TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
+  // Build graph.
+  proto::ClassifierOptions options;
+  options.set_max_results(2);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
+                              /*connect_timestamps=*/true));
+  // Build input tensors.
+  std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
+  tensor_0[1] = 12;
+  tensor_0[2] = 14;
+  tensor_0[3] = 16;
+  std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
+  tensor_1[5] = 12;
+  tensor_1[6] = 14;
+  tensor_1[7] = 16;
+
+  // Send tensors and get results.
+  AddTensor(tensor_0, Tensor::ElementType::kUInt8,
+            /*quantization_parameters=*/{0.1, 10});
+  MP_ASSERT_OK(Run());
+  AddTensor(tensor_1, Tensor::ElementType::kUInt8,
+            /*quantization_parameters=*/{0.1, 10});
+  MP_ASSERT_OK(Run(
+      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
+      /*timestamp=*/1000));
+
+  MP_ASSERT_OK_AND_ASSIGN(auto results,
+                          GetResult<std::vector<ClassificationResult>>(poller));
+
+  // Validate results.
+  EXPECT_THAT(
+      results,
+      Pointwise(
+          EqualsProto(),
+          {ParseTextProtoOrDie<ClassificationResult>(R"pb(
+             timestamp_ms: 0
+             classifications {
+               head_index: 0
+               head_name: "probability"
+               classification_list {
+                 classification {
+                   index: 3
+                   score: 0.6
+                   label: "great white shark"
+                 }
+                 classification { index: 2 score: 0.4 label: "goldfish" }
+               }
+             })pb"),
+           ParseTextProtoOrDie<ClassificationResult>(R"pb(
+             timestamp_ms: 1
+             classifications {
+               head_index: 0
+               head_name: "probability"
+               classification_list {
+                 classification { index: 7 score: 0.6 label: "stingray" }
+                 classification { index: 6 score: 0.4 label: "electric ray" }
+               }
+             })pb")}));
+}
+
+// TODO: remove these tests once migration is over.
+class LegacyPostprocessingTest : public tflite_shims::testing::Test {
  protected:
   absl::StatusOr<OutputStreamPoller> BuildGraph(
       absl::string_view model_name, const proto::ClassifierOptions& options,
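Note: the expected scores in the uint8 tests above follow from standard affine dequantization with the quantization parameters passed to AddTensor(), i.e. real_value = scale * (quantized_value - zero_point) with scale = 0.1 and zero_point = 10 (assumed field order of Tensor::QuantizationParameters). A standalone arithmetic check, not part of this change:

#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.1f;   // first quantization parameter used in the tests
  const int zero_point = 10;  // second quantization parameter used in the tests
  const std::uint8_t quantized[] = {18, 16, 14, 12};
  for (std::uint8_t q : quantized) {
    // 18 -> 0.8, 16 -> 0.6, 14 -> 0.4, 12 -> 0.2, which matches the scores the
    // tests expect for the corresponding class indices (lowest ones fall
    // outside max_results and are dropped).
    std::printf("%d -> %.1f\n", static_cast<int>(q),
                scale * (static_cast<int>(q) - zero_point));
  }
  return 0;
}
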
@@ -496,7 +813,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
       absl::make_unique<std::vector<Tensor>>();
 };
 
-TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
+TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) {
   // Build graph.
   proto::ClassifierOptions options;
   options.set_max_results(3);
@@ -525,7 +842,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
   })pb"));
 }
 
-TEST_F(PostprocessingTest, SucceedsWithMetadata) {
+TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) {
   // Build graph.
   proto::ClassifierOptions options;
   options.set_max_results(3);
@@ -568,7 +885,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
   })pb"));
 }
 
-TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
+TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) {
   // Build graph.
   proto::ClassifierOptions options;
   options.set_max_results(3);
@@ -614,7 +931,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
   })pb"));
 }
 
-TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
+TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) {
   // Build graph.
   proto::ClassifierOptions options;
   options.set_max_results(2);
@@ -674,7 +991,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
   })pb"));
 }
 
-TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
+TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) {
   // Build graph.
   proto::ClassifierOptions options;
   options.set_max_results(2);
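Note: per the deprecation TODO in the documentation hunks above, callers should stop reading the CLASSIFICATION_RESULT output. A migration sketch, not part of this change; it assumes the api2 builder aliases from the earlier sketch and takes the node type as a template parameter rather than naming the builder's node class:

// `postprocessing` is the node returned by graph.AddNode(...) for the
// ClassificationPostprocessingGraph subgraph.
template <typename NodeT>
void MigrateOutputs(Graph& graph, NodeT& postprocessing, bool with_timestamps) {
  // Before (deprecated, to be removed once migration is over):
  //   postprocessing.Out("CLASSIFICATION_RESULT") >>
  //       graph[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
  if (with_timestamps) {
    // Per-timestamp aggregation, one ClassificationResult per timestamp.
    postprocessing.Out("TIMESTAMPED_CLASSIFICATIONS") >>
        graph[Output<std::vector<ClassificationResult>>(
            "TIMESTAMPED_CLASSIFICATIONS")];
  } else {
    // Same per-head aggregation as the deprecated output, new tag.
    postprocessing.Out("CLASSIFICATIONS") >>
        graph[Output<ClassificationResult>("CLASSIFICATIONS")];
  }
}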