# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mediapipe.python.solution_base."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format
from mediapipe.framework import calculator_pb2
from mediapipe.framework.formats import detection_pb2
from mediapipe.python import solution_base
from mediapipe.python.solution_base import PacketDataType

# Graph whose single 'ImageTransformation' node sets its output size through
# BOTH the proto2 `options` extension and the proto3 `node_options` field.
# Individual tests clear one or both of these fields to exercise each path.
CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG = """
  input_stream: 'image_in'
  output_stream: 'image_out'
  node {
    name: 'ImageTransformation'
    calculator: 'ImageTransformationCalculator'
    input_stream: 'IMAGE:image_in'
    output_stream: 'IMAGE:image_out'
    options: {
      [mediapipe.ImageTransformationCalculatorOptions.ext] {
        output_width: 10
        output_height: 10
      }
    }
    node_options: {
      [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
        output_width: 10
        output_height: 10
      }
    }
  }
"""

# Two chained ImageTransformationCalculator nodes and no side packets.
# Shared (as an immutable string) by test_solution_process and
# test_solution_reset so the two parameterizations cannot drift apart.
_TWO_TRANSFORM_GRAPH_CONFIG = """
  input_stream: 'image_in'
  output_stream: 'image_out'
  node {
    calculator: 'ImageTransformationCalculator'
    input_stream: 'IMAGE:image_in'
    output_stream: 'IMAGE:transformed_image_in'
  }
  node {
    calculator: 'ImageTransformationCalculator'
    input_stream: 'IMAGE:transformed_image_in'
    output_stream: 'IMAGE:image_out'
  }
"""

# The same two transformation nodes with a GateCalculator between them. The
# gate reads the 'allow_signal' side packet. Both transformation nodes read
# the 'rotation_degrees' side packet.
_GATED_TRANSFORM_GRAPH_CONFIG = """
  input_stream: 'image_in'
  input_side_packet: 'allow_signal'
  input_side_packet: 'rotation_degrees'
  output_stream: 'image_out'
  node {
    calculator: 'ImageTransformationCalculator'
    input_stream: 'IMAGE:image_in'
    input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
    output_stream: 'IMAGE:transformed_image_in'
  }
  node {
    calculator: 'GateCalculator'
    input_stream: 'transformed_image_in'
    input_side_packet: 'ALLOW:allow_signal'
    output_stream: 'image_out_to_transform'
  }
  node {
    calculator: 'ImageTransformationCalculator'
    input_stream: 'IMAGE:image_out_to_transform'
    input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
    output_stream: 'IMAGE:image_out'
  }"""


class SolutionBaseTest(parameterized.TestCase):
  """Tests construction, input validation and processing of SolutionBase."""

  def test_invalid_initialization_arguments(self):
    """Passing neither or both graph sources raises ValueError."""
    expected_error = (
        'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.')
    # Neither source given.
    with self.assertRaisesRegex(ValueError, expected_error):
      solution_base.SolutionBase()
    # Both sources given.
    with self.assertRaisesRegex(ValueError, expected_error):
      solution_base.SolutionBase(
          graph_config=calculator_pb2.CalculatorGraphConfig(),
          binary_graph_path='/tmp/no_such.binarypb')

  @parameterized.named_parameters(
      ('no_graph_input_output_stream', """
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          output_stream: 'out'
        }
      """, RuntimeError, 'does not have a corresponding output stream.'),
      ('calculator_io_mismatch', """
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          input_stream: 'in2'
          output_stream: 'out'
        }
      """, ValueError, 'must use matching tags and indexes.'),
      ('unknown_registered_stream_type_name', """
        input_stream: 'in'
        output_stream: 'out'
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          output_stream: 'out'
        }
      """, RuntimeError, 'Unable to find the type for stream \"in\".'))
  def test_invalid_config(self, text_config, error_type, error_message):
    """Malformed graph configs are rejected when SolutionBase is built."""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(error_type, error_message):
      solution_base.SolutionBase(graph_config=config_proto)

  def test_valid_input_data_type_proto(self):
    """A DetectionList proto input is processed and returned as a proto."""
    text_config = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTION_LIST:input_detections'
        output_stream: 'DETECTION_LIST:output_detections'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      input_detections = detection_pb2.DetectionList()
      detection_1 = input_detections.detection.add()
      text_format.Parse('score: 0.5', detection_1)
      detection_2 = input_detections.detection.add()
      text_format.Parse('score: 0.8', detection_2)
      results = solution.process({'input_detections': input_detections})
      self.assertTrue(hasattr(results, 'output_detections'))
      self.assertLen(results.output_detections.detection, 2)
      # The scores are passed through, and detection_id 1 and 2 are expected
      # to be assigned in input order.
      expected_detection_1 = detection_pb2.Detection()
      text_format.Parse('score: 0.5, detection_id: 1', expected_detection_1)
      expected_detection_2 = detection_pb2.Detection()
      text_format.Parse('score: 0.8, detection_id: 2', expected_detection_2)
      self.assertEqual(results.output_detections.detection[0],
                       expected_detection_1)
      self.assertEqual(results.output_detections.detection[1],
                       expected_detection_2)

  def test_invalid_input_data_type_proto_vector(self):
    """Streams typed as a list of protos cannot be fed via process()."""
    text_config = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTIONS:input_detections'
        output_stream: 'DETECTIONS:output_detections'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      detection = detection_pb2.Detection()
      text_format.Parse('score: 0.5', detection)
      with self.assertRaisesRegex(
          NotImplementedError,
          'SolutionBase can only process non-audio and non-proto-list data. '
          'PROTO_LIST type is not supported.'):
        solution.process({'input_detections': detection})

  def test_invalid_input_image_data(self):
    """A 4-channel image is rejected with a ValueError."""
    text_config = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      with self.assertRaisesRegex(
          ValueError, 'Input image must contain three channel rgb data.'):
        # Shape (3, 3, 4): the last axis has four channels, not three.
        solution.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4))

  @parameterized.named_parameters(
      ('graph_without_side_packets', _TWO_TRANSFORM_GRAPH_CONFIG, None),
      ('graph_with_side_packets', _GATED_TRANSFORM_GRAPH_CONFIG, {
          'allow_signal': True,
          'rotation_degrees': 0
      }))
  def test_solution_process(self, text_config, side_inputs):
    """An image passes through the graph unchanged, with or without side inputs."""
    self._process_and_verify(
        config_proto=text_format.Parse(text_config,
                                       calculator_pb2.CalculatorGraphConfig()),
        side_inputs=side_inputs)

  def test_invalid_calculator_options(self):
    """Setting an options field on a GateCalculator node raises ValueError."""
    text_config = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        name: 'SignalGate'
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        output_stream: 'IMAGE:image_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(
        ValueError,
        'Modifying the calculator options of SignalGate is not supported.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={'SignalGate.invalid_field': 'I am invalid'})

  def test_calculator_has_both_options_and_node_options(self):
    """Params are rejected for a node that sets both options and node_options."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(ValueError,
                                'has both options and node_options fields.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={
              'ImageTransformation.output_width': 0,
              'ImageTransformation.output_height': 0
          })

  def test_modifying_calculator_proto2_options(self):
    """calculator_params can override values in the proto2 `options` field."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test proto2 options only, remove the proto3 node_options field from the
    # graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('node_options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_modifying_calculator_proto3_node_options(self):
    """calculator_params can override values in the proto3 `node_options` field."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test proto3 node options only, remove the proto2 options field from the
    # graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_adding_calculator_options(self):
    """calculator_params work on a node that has no options field at all."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test a calculator with no options field, remove both proto2 options and
    # proto3 node_options fields from the graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('options')
    config_proto.node[0].ClearField('node_options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  @parameterized.named_parameters(
      ('graph_without_side_packets', _TWO_TRANSFORM_GRAPH_CONFIG, None),
      ('graph_with_side_packets', _GATED_TRANSFORM_GRAPH_CONFIG, {
          'allow_signal': True,
          'rotation_degrees': 0
      }))
  def test_solution_reset(self, text_config, side_inputs):
    """Output stays correct across repeated process()/reset() cycles."""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto, side_inputs=side_inputs) as solution:
      for _ in range(20):
        outputs = solution.process(input_image)
        np.testing.assert_array_equal(outputs.image_out, input_image)
        solution.reset()

  def test_solution_stream_type_hints(self):
    """A stream typed by the graph is fed according to the given type hint."""
    text_config = """
      input_stream: 'union_type_image_in'
      output_stream: 'image_type_out'
      node {
        calculator: 'ToImageCalculator'
        input_stream: 'IMAGE:union_type_image_in'
        output_stream: 'IMAGE:image_type_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    # Both hint types are expected to yield an output equal to the input.
    for type_hint in (PacketDataType.IMAGE, PacketDataType.IMAGE_FRAME):
      with self.subTest(type_hint=type_hint):
        with solution_base.SolutionBase(
            graph_config=config_proto,
            stream_type_hints={'union_type_image_in': type_hint}) as solution:
          for _ in range(20):
            outputs = solution.process(input_image)
            np.testing.assert_array_equal(outputs.image_type_out, input_image)

  def _process_and_verify(self,
                          config_proto,
                          side_inputs=None,
                          calculator_params=None):
    """Runs one 3x3 RGB image through the graph and checks it comes back unchanged.

    The image is fed once as a bare ndarray and once as a
    {'image_in': ndarray} dict, so both input forms of process() are covered.

    Args:
      config_proto: CalculatorGraphConfig with an 'image_in' input stream and
        an 'image_out' output stream.
      side_inputs: Optional side-input dict forwarded to SolutionBase.
      calculator_params: Optional calculator-params dict forwarded to
        SolutionBase.
    """
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto,
        side_inputs=side_inputs,
        calculator_params=calculator_params) as solution:
      outputs = solution.process(input_image)
      outputs2 = solution.process({'image_in': input_image})
      np.testing.assert_array_equal(outputs.image_out, input_image)
      np.testing.assert_array_equal(outputs2.image_out, input_image)


if __name__ == '__main__':
  absltest.main()