Internal change
PiperOrigin-RevId: 485553891
This commit is contained in:
parent
475e6b4fd5
commit
aab5f84aae
|
@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
|||
constexpr char kScoresTag[] = "SCORES";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||
|
||||
// Struct holding the different output streams produced by the graph.
|
||||
struct ClassificationPostprocessingOutputStreams {
|
||||
Source<ClassificationResult> classification_result;
|
||||
Source<ClassificationResult> classifications;
|
||||
Source<std::vector<ClassificationResult>> timestamped_classifications;
|
||||
};
|
||||
|
||||
// Performs sanity checks on provided ClassifierOptions.
|
||||
absl::Status SanityCheckClassifierOptions(
|
||||
|
@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph(
|
|||
// TENSORS - std::vector<Tensor>
|
||||
// The output tensors of an InferenceCalculator.
|
||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||
// The collection of timestamps that a single ClassificationResult should
|
||||
// aggregate. This is mostly useful for classifiers working on time series,
|
||||
// e.g. audio or video classification.
|
||||
// The collection of the timestamps that this calculator should aggregate.
|
||||
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
|
||||
// output is used for results. Otherwise as no timestamp aggregation is
|
||||
// required the CLASSIFICATIONS output is used for results.
|
||||
//
|
||||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The output aggregated classification results.
|
||||
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||
// The classification results aggregated by head. Must be connected if the
|
||||
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||
// aggregation is not required.
|
||||
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
|
||||
// The classification result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
// // TODO: remove output once migration is over.
|
||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||
// The aggregated classification result.
|
||||
//
|
||||
// The recommended way of using this graph is through the GraphBuilder API
|
||||
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
|
||||
|
@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
mediapipe::SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto classification_result_out,
|
||||
auto output_streams,
|
||||
BuildClassificationPostprocessing(
|
||||
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||
classification_result_out >>
|
||||
output_streams.classification_result >>
|
||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
||||
output_streams.classifications >>
|
||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||
output_streams.timestamped_classifications >>
|
||||
graph[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
// Adds an on-device classification postprocessing graph into the provided
|
||||
// builder::Graph instance. The classification postprocessing graph takes
|
||||
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
|
||||
// stream containing the output classification results (ClassificationResult).
|
||||
// tensors (std::vector<mediapipe::Tensor>) and optional timestamps
|
||||
// (std::vector<Timestamp>) as input and returns two output streams:
|
||||
// - classification results aggregated by classifier head as a
|
||||
// ClassificationResult proto, used when no timestamps are passed in
|
||||
// the graph,
|
||||
// - classification results aggregated by timestamp then by classifier head
|
||||
// as a std::vector<ClassificationResult>, used when timestamps are passed
|
||||
// in the graph.
|
||||
//
|
||||
// options: the on-device ClassificationPostprocessingGraphOptions.
|
||||
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
|
||||
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
||||
// timestamps that a single ClassificationResult should aggregate.
|
||||
// timestamps that should be used to aggregate classification results.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<ClassificationResult>>
|
||||
absl::StatusOr<ClassificationPostprocessingOutputStreams>
|
||||
BuildClassificationPostprocessing(
|
||||
const proto::ClassificationPostprocessingGraphOptions& options,
|
||||
Source<std::vector<Tensor>> tensors_in,
|
||||
|
@ -505,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
timestamps_in >> result_aggregation.In(kTimestampsTag);
|
||||
|
||||
// Connects output.
|
||||
return result_aggregation[Output<ClassificationResult>(
|
||||
kClassificationResultTag)];
|
||||
ClassificationPostprocessingOutputStreams output_streams{
|
||||
/*classification_result=*/result_aggregation
|
||||
[Output<ClassificationResult>(kClassificationResultTag)],
|
||||
/*classifications=*/
|
||||
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
|
||||
/*timestamped_classifications=*/
|
||||
result_aggregation[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)]};
|
||||
return output_streams;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -45,12 +45,22 @@ namespace processors {
|
|||
// TENSORS - std::vector<Tensor>
|
||||
// The output tensors of an InferenceCalculator.
|
||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||
// The collection of timestamps that a single ClassificationResult should
|
||||
// aggregate. This is mostly useful for classifiers working on time series,
|
||||
// e.g. audio or video classification.
|
||||
// The collection of the timestamps that this calculator should aggregate.
|
||||
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
|
||||
// output is used for results. Otherwise as no timestamp aggregation is
|
||||
// required the CLASSIFICATIONS output is used for results.
|
||||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The output aggregated classification results.
|
||||
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||
// The classification results aggregated by head. Must be connected if the
|
||||
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||
// aggregation is not required.
|
||||
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
|
||||
// The classification result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
// // TODO: remove output once migration is over.
|
||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||
// The aggregated classification result.
|
||||
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||
const tasks::core::ModelResources& model_resources,
|
||||
const proto::ClassifierOptions& classifier_options,
|
||||
|
|
|
@ -38,6 +38,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/output_stream_poller.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
|
@ -64,6 +65,7 @@ using ::mediapipe::file::JoinPath;
|
|||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Pointwise;
|
||||
using ::testing::proto::Approximately;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
||||
|
@ -86,6 +88,11 @@ constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
|||
constexpr char kTimestampsName[] = "timestamps";
|
||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
||||
constexpr char kClassificationResultName[] = "classification_result";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kClassificationsName[] = "classifications";
|
||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||
constexpr char kTimestampedClassificationsName[] =
|
||||
"timestamped_classifications";
|
||||
|
||||
// Helper function to get ModelResources.
|
||||
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
||||
|
@ -413,6 +420,316 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
|||
}
|
||||
|
||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||
bool connect_timestamps = false) {
|
||||
ASSIGN_OR_RETURN(auto model_resources,
|
||||
CreateModelResourcesForModel(model_name));
|
||||
|
||||
Graph graph;
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options,
|
||||
&postprocessing
|
||||
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
||||
postprocessing.In(kTensorsTag);
|
||||
if (connect_timestamps) {
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||
kTimestampsName) >>
|
||||
postprocessing.In(kTimestampsTag);
|
||||
postprocessing.Out(kTimestampedClassificationsTag)
|
||||
.SetName(kTimestampedClassificationsName) >>
|
||||
graph[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)];
|
||||
} else {
|
||||
postprocessing.Out(kClassificationsTag).SetName(kClassificationsName) >>
|
||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||
if (connect_timestamps) {
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kTimestampedClassificationsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kClassificationsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AddTensor(
|
||||
const std::vector<T>& tensor, const Tensor::ElementType& element_type,
|
||||
const Tensor::QuantizationParameters& quantization_parameters = {}) {
|
||||
tensors_->emplace_back(element_type,
|
||||
Tensor::Shape{1, static_cast<int>(tensor.size())},
|
||||
quantization_parameters);
|
||||
auto view = tensors_->back().GetCpuWriteView();
|
||||
T* buffer = view.buffer<T>();
|
||||
std::copy(tensor.begin(), tensor.end(), buffer);
|
||||
}
|
||||
|
||||
absl::Status Run(
|
||||
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
|
||||
int timestamp = 0) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
|
||||
// Reset tensors for future calls.
|
||||
tensors_ = absl::make_unique<std::vector<Tensor>>();
|
||||
if (aggregation_timestamps.has_value()) {
|
||||
auto packet = absl::make_unique<std::vector<Timestamp>>();
|
||||
for (const auto& timestamp : *aggregation_timestamps) {
|
||||
packet->emplace_back(Timestamp(timestamp));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||
|
||||
Packet packet;
|
||||
if (!poller.Next(&packet)) {
|
||||
return absl::InternalError("Unable to get output packet");
|
||||
}
|
||||
auto result = packet.Get<T>();
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CalculatorGraph calculator_graph_;
|
||||
std::unique_ptr<std::vector<Tensor>> tensors_ =
|
||||
absl::make_unique<std::vector<Tensor>>();
|
||||
};
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
options.set_score_threshold(0.5);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller,
|
||||
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
||||
tensor[1] = 18;
|
||||
tensor[2] = 16;
|
||||
|
||||
// Send tensors and get results.
|
||||
AddTensor(tensor, Tensor::ElementType::kUInt8,
|
||||
/*quantization_parameters=*/{0.1, 10});
|
||||
MP_ASSERT_OK(Run());
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
GetResult<ClassificationResult>(poller));
|
||||
|
||||
// Validate results.
|
||||
EXPECT_THAT(results,
|
||||
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
classifications {
|
||||
head_index: 0
|
||||
classification_list {
|
||||
classification { index: 1 score: 0.8 }
|
||||
classification { index: 2 score: 0.6 }
|
||||
}
|
||||
}
|
||||
)pb")));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
||||
tensor[1] = 12;
|
||||
tensor[2] = 14;
|
||||
tensor[3] = 16;
|
||||
tensor[4] = 18;
|
||||
|
||||
// Send tensors and get results.
|
||||
AddTensor(tensor, Tensor::ElementType::kUInt8,
|
||||
/*quantization_parameters=*/{0.1, 10});
|
||||
MP_ASSERT_OK(Run());
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
GetResult<ClassificationResult>(poller));
|
||||
|
||||
// Validate results.
|
||||
EXPECT_THAT(
|
||||
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "probability"
|
||||
classification_list {
|
||||
classification { index: 4 score: 0.8 label: "tiger shark" }
|
||||
classification { index: 3 score: 0.6 label: "great white shark" }
|
||||
classification { index: 2 score: 0.4 label: "goldfish" }
|
||||
}
|
||||
}
|
||||
)pb")));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller,
|
||||
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
||||
tensor[1] = 12;
|
||||
tensor[2] = 14;
|
||||
tensor[3] = 16;
|
||||
tensor[4] = 18;
|
||||
|
||||
// Send tensors and get results.
|
||||
AddTensor(tensor, Tensor::ElementType::kUInt8,
|
||||
/*quantization_parameters=*/{0.1, 10});
|
||||
MP_ASSERT_OK(Run());
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
GetResult<ClassificationResult>(poller));
|
||||
|
||||
// Validate results.
|
||||
EXPECT_THAT(
|
||||
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "probability"
|
||||
classification_list {
|
||||
classification { index: 4 score: 0.6899744811 label: "tiger shark" }
|
||||
classification {
|
||||
index: 3
|
||||
score: 0.6456563062
|
||||
label: "great white shark"
|
||||
}
|
||||
classification { index: 2 score: 0.5986876601 label: "goldfish" }
|
||||
}
|
||||
}
|
||||
)pb")));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller,
|
||||
BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options));
|
||||
// Build input tensors.
|
||||
std::vector<float> tensor_0(kTwoHeadsNumClasses[0], 0);
|
||||
tensor_0[1] = 0.2;
|
||||
tensor_0[2] = 0.4;
|
||||
tensor_0[3] = 0.6;
|
||||
std::vector<float> tensor_1(kTwoHeadsNumClasses[1], 0);
|
||||
tensor_1[1] = 0.2;
|
||||
tensor_1[2] = 0.4;
|
||||
tensor_1[3] = 0.6;
|
||||
|
||||
// Send tensors and get results.
|
||||
AddTensor(tensor_0, Tensor::ElementType::kFloat32);
|
||||
AddTensor(tensor_1, Tensor::ElementType::kFloat32);
|
||||
MP_ASSERT_OK(Run());
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
GetResult<ClassificationResult>(poller));
|
||||
|
||||
// Validate results.
|
||||
EXPECT_THAT(
|
||||
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "yamnet_classification"
|
||||
classification_list {
|
||||
classification { index: 3 score: 0.6 label: "Narration, monologue" }
|
||||
classification { index: 2 score: 0.4 label: "Conversation" }
|
||||
}
|
||||
}
|
||||
classifications {
|
||||
head_index: 1
|
||||
head_name: "bird_classification"
|
||||
classification_list {
|
||||
classification { index: 3 score: 0.6 label: "Azara\'s Spinetail" }
|
||||
classification { index: 2 score: 0.4 label: "House Sparrow" }
|
||||
}
|
||||
}
|
||||
)pb")));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
||||
/*connect_timestamps=*/true));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
|
||||
tensor_0[1] = 12;
|
||||
tensor_0[2] = 14;
|
||||
tensor_0[3] = 16;
|
||||
std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
|
||||
tensor_1[5] = 12;
|
||||
tensor_1[6] = 14;
|
||||
tensor_1[7] = 16;
|
||||
|
||||
// Send tensors and get results.
|
||||
AddTensor(tensor_0, Tensor::ElementType::kUInt8,
|
||||
/*quantization_parameters=*/{0.1, 10});
|
||||
MP_ASSERT_OK(Run());
|
||||
AddTensor(tensor_1, Tensor::ElementType::kUInt8,
|
||||
/*quantization_parameters=*/{0.1, 10});
|
||||
MP_ASSERT_OK(Run(
|
||||
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
|
||||
/*timestamp=*/1000));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
GetResult<std::vector<ClassificationResult>>(poller));
|
||||
|
||||
// Validate results.
|
||||
EXPECT_THAT(
|
||||
results,
|
||||
Pointwise(
|
||||
EqualsProto(),
|
||||
{ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 0
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "probability"
|
||||
classification_list {
|
||||
classification {
|
||||
index: 3
|
||||
score: 0.6
|
||||
label: "great white shark"
|
||||
}
|
||||
classification { index: 2 score: 0.4 label: "goldfish" }
|
||||
}
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 1
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "probability"
|
||||
classification_list {
|
||||
classification { index: 7 score: 0.6 label: "stingray" }
|
||||
classification { index: 6 score: 0.4 label: "electric ray" }
|
||||
}
|
||||
})pb")}));
|
||||
}
|
||||
|
||||
// TODO: remove these tests once migration is over.
|
||||
class LegacyPostprocessingTest : public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||
|
@ -496,7 +813,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
|
|||
absl::make_unique<std::vector<Tensor>>();
|
||||
};
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
||||
TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
|
@ -525,7 +842,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
|||
})pb"));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||
TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
|
@ -568,7 +885,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
|||
})pb"));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||
TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
|
@ -614,7 +931,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
|||
})pb"));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
||||
TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
|
@ -674,7 +991,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
|||
})pb"));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||
TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) {
|
||||
// Build graph.
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
|
|
Loading…
Reference in New Issue
Block a user