// Copyright 2019 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MEDIAPIPE_UTIL_TIME_SERIES_TEST_UTIL_H_ #define MEDIAPIPE_UTIL_TIME_SERIES_TEST_UTIL_H_ #include #include #include #include "Eigen/Core" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/util/time_series_util.h" namespace mediapipe { // Base class for testing Calculators that operate on TimeSeries inputs. // Subclasses that do not need a special options proto should inherit from // the subclass BasicTimeSeriesCalculatorTestBase. // // This class handles calculators that accept one or more input streams // specified either by indices or by tags and produce one or more output // streams, again either specified by indices or tags. // The default is to use one input stream and one output stream, specified by // index. To use more streams by index, set num_input_streams_ or // num_output_streams for the number of input or output streams, respectively. // These have to be set before calling InitializeGraph(). To use one or more // streams by tag, set input_stream_tags_ or output_stream_tags_ before calling // InitializeGraph(), for example: // input_stream_tags_ = {"MATRIX", "FRAMES"}; // output_stream_tags_ = {"MATRIX"}; // InitializeGraph(); // These options are exclusive since mediapipe requires calculators to use // either indices or tags, but not both. template class TimeSeriesCalculatorTest : public ::testing::Test { protected: // Sentinal value which can be used to tell methods like // PopulateHeader to ignore certain fields. static constexpr int kUnset = 0; TimeSeriesCalculatorTest() : num_side_packets_(0), num_input_streams_(1), num_output_streams_(1), input_packet_rate_(kUnset), num_input_samples_(kUnset), audio_sample_rate_(kUnset) {} // Makes the input stream names used in CalculatorRunner runner_. // If tags are used, that is, input_stream_tags_ is not empty, it returns // names of the form: // :_, // :_, etc. // For the index format, returns names of the form _0, // _1, etc. std::vector MakeInputStreamNames(const std::string& base_name) { if (!input_stream_tags_.empty()) { return MakeNames(base_name, input_stream_tags_); } else { return MakeNames(base_name, num_input_streams_); } return std::vector(); } // Same as MakeInputStreamNames, but for output streams. std::vector MakeOutputStreamNames(const std::string& base_name) { if (!output_stream_tags_.empty()) { return MakeNames(base_name, output_stream_tags_); } else { return MakeNames(base_name, num_output_streams_); } return std::vector(); } // Makes names used in CalculatorRunner runner_ that use the tag format. Tags // must be capitalized. Returns names of the form // :_, // :_, etc. std::vector MakeNames(const std::string& base_name, const std::vector& tags) { std::vector base_names; std::vector ids; for (const std::string& tag : tags) { const std::string tagged_base_name = absl::StrCat(tag, ":", base_name); base_names.push_back(tagged_base_name); std::string id; id.reserve(tag.size()); for (std::size_t i = 0; i < tag.size(); ++i) { id += std::tolower(tag[i]); } ids.push_back(id); } const std::vector names = MakeNames(base_names, ids); return names; } // Makes names used in CalculatorRunner runner_ that use the index format. // Total is the number of names to create. Returns names of the form // _0, _1, ..., _. std::vector MakeNames(const std::string& base_name, const int total) { std::vector base_names; std::vector ids; for (int i = 0; i < total; ++i) { const std::string id = absl::StrCat(i); ids.push_back(id); base_names.push_back(base_name); } const std::vector names = MakeNames(base_names, ids); return names; } // Makes names used in CalculatorRunner runner_. Returns names of the form // _, _, etc. std::vector MakeNames(const std::vector& base_names, const std::vector& ids) { CHECK_EQ(base_names.size(), ids.size()); std::vector names; for (int i = 0; i < base_names.size(); ++i) { const std::string name_template = R"($0_$1)"; const std::string& base_name = base_names[i]; const std::string& id = ids[i]; const std::string name = absl::Substitute(name_template, base_name, id); names.push_back(name); } return names; } // Makes the CalculatorGraphConfig used to initialize CalculatorRunner // runner_. If no options are needed, pass the empty string for options. CalculatorGraphConfig::Node MakeNodeConfig(const std::string& calculator_name, const int num_side_packets, const CalculatorOptions& options) { CalculatorGraphConfig::Node node_config; node_config.set_calculator(calculator_name); CalculatorOptions* node_config_options = node_config.mutable_options(); *node_config_options = options; const std::string input_stream_base_name = "input_stream"; const std::vector input_stream_names = MakeInputStreamNames(input_stream_base_name); for (const std::string& input_stream_name : input_stream_names) { node_config.add_input_stream(input_stream_name); } const std::string input_side_packet_base_name = "input_side_packet"; const std::vector input_side_packet_names = MakeNames(input_side_packet_base_name, num_side_packets); for (const std::string& input_side_packet_name : input_side_packet_names) { node_config.add_input_side_packet(input_side_packet_name); } const std::string output_stream_base_name = "output_stream"; const std::vector output_stream_names = MakeOutputStreamNames(output_stream_base_name); for (const std::string& output_stream_name : output_stream_names) { node_config.add_output_stream(output_stream_name); } return node_config; } void InitializeGraph(const CalculatorOptions& options) { if (num_external_inputs_ != -1) { LOG(WARNING) << "Use num_side_packets_ instead of num_external_inputs_."; num_side_packets_ = num_external_inputs_; } if (!input_stream_tags_.empty()) { num_input_streams_ = input_stream_tags_.size(); } if (!output_stream_tags_.empty()) { num_output_streams_ = output_stream_tags_.size(); } const CalculatorGraphConfig::Node node_config = MakeNodeConfig(calculator_name_, num_side_packets_, options); runner_.reset(new CalculatorRunner(node_config)); } void InitializeGraph() { CalculatorOptions options; FillOptionsExtension(&options); InitializeGraph(options); } // Provide an alternative to InitializeGraph for tests that want the // options not to be set. void InitializeGraphWithoutOptions() { CalculatorOptions options; // Left empty. InitializeGraph(options); } // This is broken out into a separate function to facilitate the // NoOptions specialization defined below. void FillOptionsExtension(CalculatorOptions* options) { options->MutableExtension(OptionsClass::ext)->MergeFrom(options_); } void PopulateHeader(TimeSeriesHeader* header) { header->set_num_channels(num_input_channels_); header->set_sample_rate(input_sample_rate_); if (num_input_samples_ != kUnset) { header->set_num_samples(num_input_samples_); } if (input_packet_rate_ != kUnset) { header->set_packet_rate(input_packet_rate_); } if (audio_sample_rate_ != kUnset) { header->set_audio_sample_rate(audio_sample_rate_); } } std::unique_ptr CreateInputHeader() { std::unique_ptr header(new TimeSeriesHeader); PopulateHeader(header.get()); return header; } void FillInputHeader(const size_t input_index = 0) { runner_->MutableInputs()->Index(input_index).header = Adopt(CreateInputHeader().release()); } void FillInputHeader(const std::string& input_tag) { runner_->MutableInputs()->Tag(input_tag).header = Adopt(CreateInputHeader().release()); } template std::unique_ptr CreateInputHeaderWithExtension( const TimeSeriesHeaderExtensionClass& extension) { auto header = CreateInputHeader(); time_series_util::SetExtensionInHeader(extension, header.get()); return header; } template void FillInputHeaderWithExtension( const TimeSeriesHeaderExtensionClass& extension, const size_t input_index = 0) { auto header = CreateInputHeaderWithExtension(extension); runner_->MutableInputs()->Index(input_index).header = Adopt(header.release()); } template void FillInputHeaderWithExtension( const TimeSeriesHeaderExtensionClass& extension, const std::string& input_tag) { auto header = CreateInputHeaderWithExtension(extension); runner_->MutableInputs()->Tag(input_tag).header = Adopt(header.release()); } // Takes ownership of payload. template void AppendInputPacket(const T* payload, const Timestamp timestamp, const size_t input_index = 0) { runner_->MutableInputs() ->Index(input_index) .packets.push_back(Adopt(payload).At(timestamp)); } // Overload to allow explicit conversion from int64 to Timestamp template void AppendInputPacket(const T* payload, const int64 timestamp, const size_t input_index = 0) { AppendInputPacket(payload, Timestamp(timestamp), input_index); } template void AppendInputPacket(const T* payload, const Timestamp timestamp, const std::string& input_tag) { runner_->MutableInputs()->Tag(input_tag).packets.push_back( Adopt(payload).At(timestamp)); } template void AppendInputPacket(const T* payload, const int64 timestamp, const std::string& input_tag) { AppendInputPacket(payload, Timestamp(timestamp), input_tag); } absl::Status RunGraph() { return runner_->Run(); } bool HasInputHeader(const size_t input_index = 0) const { return input(input_index) .header.template ValidateAsType() .ok(); } bool HasOutputHeader() const { return output().header.template ValidateAsType().ok(); } template void ExpectOutputHeaderEquals(const Proto& expected, const size_t output_index = 0) const { EXPECT_THAT(output(output_index).header.template Get(), mediapipe::EqualsProto(expected)); } void ExpectOutputHeaderEqualsInputHeader( const size_t input_index = 0, const size_t output_index = 0) const { EXPECT_THAT( output(output_index).header.template Get(), mediapipe::EqualsProto( input(input_index).header.template Get())); } void ExpectOutputHeaderEqualsInputHeader( const std::string& input_tag, const size_t output_index = 0) const { EXPECT_THAT(output(output_index).header.template Get(), mediapipe::EqualsProto( input(input_tag).header.template Get())); } void ExpectOutputHeaderEqualsInputHeader( const size_t input_index, const std::string& output_tag) const { EXPECT_THAT( output(output_tag).header.template Get(), mediapipe::EqualsProto( input(input_index).header.template Get())); } void ExpectOutputHeaderEqualsInputHeader( const std::string& input_tag, const std::string& output_tag) const { EXPECT_THAT(output(output_tag).header.template Get(), mediapipe::EqualsProto( input(input_tag).header.template Get())); } void ExpectApproximatelyEqual(const Matrix& expected, const Matrix& actual) const { const float kPrecision = 1e-6; EXPECT_TRUE(actual.isApprox(expected, kPrecision)) << "Expected: " << expected << ", but got: " << actual; } const CalculatorRunner::StreamContents& input( const size_t input_index = 0) const { return runner_->MutableInputs()->Index(input_index); } const CalculatorRunner::StreamContents& input( const std::string& input_tag) const { return runner_->MutableInputs()->Tag(input_tag); } const CalculatorRunner::StreamContents& output( const size_t output_index = 0) const { return runner_->Outputs().Index(output_index); } const CalculatorRunner::StreamContents& output( const std::string& output_tag) const { return runner_->Outputs().Tag(output_tag); } // Caller takes ownership of the return value. static Matrix* NewRandomMatrix(int num_channels, int num_samples) { // TODO: Fix a consistent lack of random seed setting in tests. auto matrix = new Matrix; matrix->setRandom(num_channels, num_samples); return matrix; } std::string calculator_name_; OptionsClass options_; int num_side_packets_; int num_input_streams_; std::vector input_stream_tags_; int num_output_streams_; std::vector output_stream_tags_; // TODO For backwards compatibility, remove after all clients // are updated. int num_external_inputs_ = -1; int num_input_channels_; double input_sample_rate_; // If this is non-zero, it will be used to set the packet_rate field // of the header proto. double input_packet_rate_; // If this is non-zero, it will be used to set the num_samples field // of the header proto. int num_input_samples_; // If this is non-zero, it will be used to set the audio_sample_rate field // of the header proto. double audio_sample_rate_; std::unique_ptr runner_; }; template class MultiStreamTimeSeriesCalculatorTest : public TimeSeriesCalculatorTest { protected: void FillInputHeader() { std::unique_ptr header( new MultiStreamTimeSeriesHeader); PopulateHeader(header.get()); this->runner_->MutableInputs()->Index(0).header = Adopt(header.release()); } template void FillInputHeaderWithExtension( const TimeSeriesHeaderExtensionClass& extension) { std::unique_ptr header( new MultiStreamTimeSeriesHeader); PopulateHeader(header.get()); time_series_util::SetExtensionInHeader( extension, header->mutable_time_series_header()); this->runner_->MutableInputs()->Index(0).header = Adopt(header.release()); } // Takes ownership of input_vector. void AppendInputPacket(const std::vector* input_vector, const Timestamp timestamp) { this->runner_->MutableInputs()->Index(0).packets.push_back( Adopt(input_vector).At(timestamp)); } // Overload to allow explicit conversion from int64 to Timestamp void AppendInputPacket(const std::vector* input_vector, const int64 timestamp) { AppendInputPacket(input_vector, Timestamp(timestamp)); } template void ExpectOutputHeaderEquals(const StringOrProto& expected) const { EXPECT_THAT( this->output().header.template Get(), mediapipe::EqualsProto(expected)); } void ExpectOutputHeaderEqualsInputHeader() const { ExpectOutputHeaderEquals( this->input().header.template Get()); } int num_input_streams_; private: void PopulateHeader(MultiStreamTimeSeriesHeader* header) { TimeSeriesCalculatorTest::PopulateHeader( header->mutable_time_series_header()); header->set_num_streams(num_input_streams_); } }; struct NoOptions {}; template <> void TimeSeriesCalculatorTest::FillOptionsExtension( CalculatorOptions* options) {} // Base class for testing basic time series calculators, which are calculators // that take no options. class BasicTimeSeriesCalculatorTestBase : public TimeSeriesCalculatorTest { protected: TimeSeriesHeader ParseTextFormat(const std::string& text_format) { TimeSeriesHeader header = ParseTextProtoOrDie(text_format); return header; } void Test(const TimeSeriesHeader& input_header, const std::vector& input_packets, const TimeSeriesHeader& expected_output_header, const std::vector& expected_output_packets) { InitializeGraph(); runner_->MutableInputs()->Index(0).header = Adopt(new TimeSeriesHeader(input_header)); for (int i = 0; i < input_packets.size(); ++i) { const Timestamp timestamp(i * Timestamp::kTimestampUnitsPerSecond); AppendInputPacket(new Matrix(input_packets[i]), timestamp); } MP_ASSERT_OK(RunGraph()); ExpectOutputHeaderEquals(expected_output_header); EXPECT_EQ(input().packets.size(), output().packets.size()); ASSERT_EQ(output().packets.size(), expected_output_packets.size()); for (int i = 0; i < output().packets.size(); ++i) { EXPECT_EQ(input().packets[i].Timestamp(), output().packets[i].Timestamp()); ExpectApproximatelyEqual(expected_output_packets[i], output().packets[i].Get()); } } }; } // namespace mediapipe #endif // MEDIAPIPE_UTIL_TIME_SERIES_TEST_UTIL_H_