mediapipe/mediapipe/python/solution_base_test.py
MediaPipe Team cc6a2f7af6 Project import generated by Copybara.
GitOrigin-RevId: 73d686c40057684f8bfaca285368bf1813f9fc26
2022-03-21 12:12:39 -07:00

397 lines
15 KiB
Python

# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mediapipe.python.solution_base."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from google.protobuf import text_format
from mediapipe.framework import calculator_pb2
from mediapipe.framework.formats import detection_pb2
from mediapipe.python import solution_base
from mediapipe.python.solution_base import PacketDataType
# Single-node graph (node name 'ImageTransformation') that sets the same
# ImageTransformationCalculatorOptions (output_width/output_height = 10) twice:
# once through the proto2 `options` extension field and once through the
# proto3 `node_options` Any field. Tests clear one or both of these fields on
# the parsed config to exercise each calculator-options path of SolutionBase;
# leaving both in place is expected to be rejected when calculator_params are
# supplied.
CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG = """
input_stream: 'image_in'
output_stream: 'image_out'
node {
name: 'ImageTransformation'
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
output_stream: 'IMAGE:image_out'
options: {
[mediapipe.ImageTransformationCalculatorOptions.ext] {
output_width: 10
output_height: 10
}
}
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 10
output_height: 10
}
}
}
"""
class SolutionBaseTest(parameterized.TestCase):
  """Tests for solution_base.SolutionBase."""

  # Image graphs shared by test_solution_process and test_solution_reset.
  # Each entry is (test_name_suffix, text_config, side_inputs). Both tests
  # expect 'image_out' to equal the input image.
  _IMAGE_PASS_THROUGH_GRAPHS = (
      ('graph_without_side_packets', """
input_stream: 'image_in'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
output_stream: 'IMAGE:transformed_image_in'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:transformed_image_in'
output_stream: 'IMAGE:image_out'
}
""", None),
      ('graph_with_side_packets', """
input_stream: 'image_in'
input_side_packet: 'allow_signal'
input_side_packet: 'rotation_degrees'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
output_stream: 'IMAGE:transformed_image_in'
}
node {
calculator: 'GateCalculator'
input_stream: 'transformed_image_in'
input_side_packet: 'ALLOW:allow_signal'
output_stream: 'image_out_to_transform'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_out_to_transform'
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
output_stream: 'IMAGE:image_out'
}""", {
          'allow_signal': True,
          'rotation_degrees': 0
      }),
  )

  def test_invalid_initialization_arguments(self):
    """Checks that exactly one graph source must be given."""
    error_message = (
        "Must provide exactly one of 'binary_graph_path' or 'graph_config'.")
    # Neither source provided.
    with self.assertRaisesRegex(ValueError, error_message):
      solution_base.SolutionBase()
    # Both sources provided.
    with self.assertRaisesRegex(ValueError, error_message):
      solution_base.SolutionBase(
          graph_config=calculator_pb2.CalculatorGraphConfig(),
          binary_graph_path='/tmp/no_such.binarypb')

  @parameterized.named_parameters(
      ('no_graph_input_output_stream', """
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'out'
}
""", RuntimeError, 'does not have a corresponding output stream.'),
      ('calculator_io_mismatch', """
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'in2'
output_stream: 'out'
}
""", ValueError, 'must use matching tags and indexes.'),
      ('unknown_registered_stream_type_name', """
input_stream: 'in'
output_stream: 'out'
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'out'
}
""", RuntimeError, 'Unable to find the type for stream "in".'))
  def test_invalid_config(self, text_config, error_type, error_message):
    """Checks that malformed graph configs are rejected at construction time.

    Args:
      text_config: Text-format CalculatorGraphConfig to build.
      error_type: Exception type SolutionBase is expected to raise.
      error_message: Regex expected to match the exception message.
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(error_type, error_message):
      solution_base.SolutionBase(graph_config=config_proto)

  def test_valid_input_data_type_proto(self):
    """Checks that a proto (DetectionList) input is processed end to end."""
    text_config = """
input_stream: 'input_detections'
output_stream: 'output_detections'
node {
calculator: 'DetectionUniqueIdCalculator'
input_stream: 'DETECTION_LIST:input_detections'
output_stream: 'DETECTION_LIST:output_detections'
}
"""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      input_detections = detection_pb2.DetectionList()
      detection_1 = input_detections.detection.add()
      text_format.Parse('score: 0.5', detection_1)
      detection_2 = input_detections.detection.add()
      text_format.Parse('score: 0.8', detection_2)
      results = solution.process({'input_detections': input_detections})
      self.assertTrue(hasattr(results, 'output_detections'))
      self.assertLen(results.output_detections.detection, 2)
      # Scores pass through unchanged; the calculator assigns ids 1 and 2 in
      # input order.
      expected_detection_1 = detection_pb2.Detection()
      text_format.Parse('score: 0.5, detection_id: 1', expected_detection_1)
      expected_detection_2 = detection_pb2.Detection()
      text_format.Parse('score: 0.8, detection_id: 2', expected_detection_2)
      self.assertEqual(results.output_detections.detection[0],
                       expected_detection_1)
      self.assertEqual(results.output_detections.detection[1],
                       expected_detection_2)

  def test_invalid_input_data_type_proto_vector(self):
    """Checks that PROTO_LIST (vector-of-proto) streams are rejected."""
    text_config = """
input_stream: 'input_detections'
output_stream: 'output_detections'
node {
calculator: 'DetectionUniqueIdCalculator'
input_stream: 'DETECTIONS:input_detections'
output_stream: 'DETECTIONS:output_detections'
}
"""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      detection = detection_pb2.Detection()
      text_format.Parse('score: 0.5', detection)
      with self.assertRaisesRegex(
          NotImplementedError,
          'SolutionBase can only process non-audio and non-proto-list data. '
          'PROTO_LIST type is not supported.'):
        solution.process({'input_detections': detection})

  def test_invalid_input_image_data(self):
    """Checks that a 4-channel image is rejected as input."""
    text_config = """
input_stream: 'image_in'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
output_stream: 'IMAGE:transformed_image_in'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:transformed_image_in'
output_stream: 'IMAGE:image_out'
}
"""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      with self.assertRaisesRegex(
          ValueError, 'Input image must contain three channel rgb data.'):
        solution.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4))

  @parameterized.named_parameters(*_IMAGE_PASS_THROUGH_GRAPHS)
  def test_solution_process(self, text_config, side_inputs):
    """Checks that an image round-trips through the graph unchanged."""
    self._process_and_verify(
        config_proto=text_format.Parse(text_config,
                                       calculator_pb2.CalculatorGraphConfig()),
        side_inputs=side_inputs)

  def test_invalid_calculator_options(self):
    """Checks that options of an unsupported calculator cannot be modified."""
    text_config = """
input_stream: 'image_in'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
output_stream: 'IMAGE:transformed_image_in'
}
node {
name: 'SignalGate'
calculator: 'GateCalculator'
input_stream: 'transformed_image_in'
input_side_packet: 'ALLOW:allow_signal'
output_stream: 'image_out_to_transform'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_out_to_transform'
output_stream: 'IMAGE:image_out'
}
"""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(
        ValueError,
        'Modifying the calculator options of SignalGate is not supported.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={'SignalGate.invalid_field': 'I am invalid'})

  def test_calculator_has_both_options_and_node_options(self):
    """Checks that a node setting both options fields cannot be modified."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(ValueError,
                                'has both options and node_options fields.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={
              'ImageTransformation.output_width': 0,
              'ImageTransformation.output_height': 0
          })

  def test_modifying_calculator_proto2_options(self):
    """Checks that calculator_params override proto2 `options`."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test proto2 options only, remove the proto3 node_options field from the
    # graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('node_options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_modifying_calculator_proto3_node_options(self):
    """Checks that calculator_params override proto3 `node_options`."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test proto3 node options only, remove the proto2 options field from the
    # graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_adding_calculator_options(self):
    """Checks that calculator_params can add options to a node that has none."""
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test a calculator with no options field, remove both proto2 options and
    # proto3 node_options fields from the graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('options')
    config_proto.node[0].ClearField('node_options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  @parameterized.named_parameters(*_IMAGE_PASS_THROUGH_GRAPHS)
  def test_solution_reset(self, text_config, side_inputs):
    """Checks that the solution keeps working across repeated reset() calls."""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto, side_inputs=side_inputs) as solution:
      # Alternate process() and reset() so every iteration after the first
      # runs on a freshly reset graph.
      for _ in range(20):
        outputs = solution.process(input_image)
        np.testing.assert_array_equal(input_image, outputs.image_out)
        solution.reset()

  def test_solution_stream_type_hints(self):
    """Checks that stream_type_hints select the input packet type."""
    text_config = """
input_stream: 'union_type_image_in'
output_stream: 'image_type_out'
node {
calculator: 'ToImageCalculator'
input_stream: 'IMAGE:union_type_image_in'
output_stream: 'IMAGE:image_type_out'
}
"""
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    # Presumably the graph input's packet type is ambiguous (the stream name
    # suggests a union type), so the hint decides which packet SolutionBase
    # builds. Both hints are expected to round-trip the image unchanged.
    for type_hint in (PacketDataType.IMAGE, PacketDataType.IMAGE_FRAME):
      with self.subTest(type_hint=type_hint):
        with solution_base.SolutionBase(
            graph_config=config_proto,
            stream_type_hints={'union_type_image_in': type_hint}) as solution:
          for _ in range(20):
            outputs = solution.process(input_image)
            np.testing.assert_array_equal(input_image, outputs.image_type_out)

  def _process_and_verify(self,
                          config_proto,
                          side_inputs=None,
                          calculator_params=None):
    """Runs a 3x3x3 uint8 image through the graph and checks it is unchanged.

    The image is fed twice: once as a bare ndarray and once in a dict keyed
    by the 'image_in' input stream name. Both outputs on 'image_out' must
    equal the input.

    Args:
      config_proto: CalculatorGraphConfig with an 'image_in' input stream and
        an 'image_out' output stream.
      side_inputs: Optional dict of input side packet values.
      calculator_params: Optional dict of 'Node.field' -> value overrides.
    """
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto,
        side_inputs=side_inputs,
        calculator_params=calculator_params) as solution:
      outputs = solution.process(input_image)
      outputs2 = solution.process({'image_in': input_image})
      np.testing.assert_array_equal(input_image, outputs.image_out)
      np.testing.assert_array_equal(input_image, outputs2.image_out)
if __name__ == '__main__':
  # Hand control to absl's test runner, which discovers and runs the tests.
  absltest.main()