# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mediapipe.python._framework_bindings.calculator_graph."""

from absl.testing import absltest

from google.protobuf import text_format
from mediapipe.framework import calculator_pb2
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import calculator_graph
from mediapipe.python._framework_bindings import validated_graph_config

CalculatorGraph = calculator_graph.CalculatorGraph
ValidatedGraphConfig = validated_graph_config.ValidatedGraphConfig


class GraphTest(absltest.TestCase):
  """Exercises construction, running and error reporting of CalculatorGraph."""

  def test_invalid_binary_graph_file(self):
    """A nonexistent binary graph path raises FileNotFoundError."""
    with self.assertRaisesRegex(
        FileNotFoundError,
        '(No such file or directory|The path does not exist)'):
      CalculatorGraph(binary_graph_path='/tmp/abc.binarypb')

  def test_invalid_node_config(self):
    """A PassThroughCalculator with mismatched input/output counts is rejected."""
    # Two input streams but only one output stream.
    text_config = """
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    config_proto = calculator_pb2.CalculatorGraphConfig()
    text_format.Parse(text_config, config_proto)
    with self.assertRaisesRegex(
        ValueError,
        'Input and output streams to PassThroughCalculator must use matching tags and indexes.'
    ):
      CalculatorGraph(graph_config=config_proto)

  def test_invalid_calculator_type(self):
    """An unregistered calculator name raises RuntimeError at construction."""
    text_config = """
      node {
        calculator: 'SomeUnknownCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    config_proto = calculator_pb2.CalculatorGraphConfig()
    text_format.Parse(text_config, config_proto)
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to find Calculator "SomeUnknownCalculator"'):
      CalculatorGraph(graph_config=config_proto)

  def test_graph_initialized_with_proto_config(self):
    """A graph built from a CalculatorGraphConfig proto passes packets through."""
    text_config = """
      max_queue_size: 1
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    config_proto = calculator_pb2.CalculatorGraphConfig()
    text_format.Parse(text_config, config_proto)
    hello_world_packet = packet_creator.create_string('hello world')
    out = []
    # Construct the graph exactly once; a previous version built an extra,
    # unused graph here that was immediately discarded.
    graph = CalculatorGraph(graph_config=config_proto)
    graph.observe_output_stream('out', lambda _, packet: out.append(packet))
    graph.start_run()
    # Timestamps are supplied both via the `timestamp` keyword and `.at()`.
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet, timestamp=0)
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet.at(1))
    graph.close()
    self.assertEqual(
        graph.graph_input_stream_add_mode,
        calculator_graph.GraphInputStreamAddMode.WAIT_TILL_NOT_FULL)
    self.assertEqual(graph.max_queue_size, 1)
    self.assertFalse(graph.has_error())
    self.assertLen(out, 2)
    self.assertEqual(out[0].timestamp, 0)
    self.assertEqual(out[1].timestamp, 1)
    self.assertEqual(packet_getter.get_str(out[0]), 'hello world')
    self.assertEqual(packet_getter.get_str(out[1]), 'hello world')

  def test_graph_initialized_with_text_config(self):
    """A graph built directly from a text-format config passes packets through."""
    text_config = """
      max_queue_size: 1
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    hello_world_packet = packet_creator.create_string('hello world')
    out = []
    graph = CalculatorGraph(graph_config=text_config)
    graph.observe_output_stream('out', lambda _, packet: out.append(packet))
    graph.start_run()
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet.at(0))
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet, timestamp=1)
    graph.close()
    self.assertEqual(
        graph.graph_input_stream_add_mode,
        calculator_graph.GraphInputStreamAddMode.WAIT_TILL_NOT_FULL)
    self.assertEqual(graph.max_queue_size, 1)
    self.assertFalse(graph.has_error())
    self.assertLen(out, 2)
    self.assertEqual(out[0].timestamp, 0)
    self.assertEqual(out[1].timestamp, 1)
    self.assertEqual(packet_getter.get_str(out[0]), 'hello world')
    self.assertEqual(packet_getter.get_str(out[1]), 'hello world')

  def test_graph_validation_and_initialization(self):
    """A graph built from a ValidatedGraphConfig passes packets through."""
    text_config = """
      max_queue_size: 1
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    hello_world_packet = packet_creator.create_string('hello world')
    out = []
    validated_graph = ValidatedGraphConfig()
    self.assertFalse(validated_graph.initialized())
    validated_graph.initialize(graph_config=text_config)
    self.assertTrue(validated_graph.initialized())

    graph = CalculatorGraph(validated_graph_config=validated_graph)
    graph.observe_output_stream('out', lambda _, packet: out.append(packet))
    graph.start_run()
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet.at(0))
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet, timestamp=1)
    graph.close()
    self.assertEqual(
        graph.graph_input_stream_add_mode,
        calculator_graph.GraphInputStreamAddMode.WAIT_TILL_NOT_FULL)
    self.assertEqual(graph.max_queue_size, 1)
    self.assertFalse(graph.has_error())
    self.assertLen(out, 2)
    self.assertEqual(out[0].timestamp, 0)
    self.assertEqual(out[1].timestamp, 1)
    self.assertEqual(packet_getter.get_str(out[0]), 'hello world')
    self.assertEqual(packet_getter.get_str(out[1]), 'hello world')

  def test_insert_packets_with_same_timestamp(self):
    """Re-sending a packet at an already-used timestamp surfaces a ValueError."""
    text_config = """
      max_queue_size: 1
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    config_proto = calculator_pb2.CalculatorGraphConfig()
    text_format.Parse(text_config, config_proto)
    hello_world_packet = packet_creator.create_string('hello world')
    out = []
    graph = CalculatorGraph(graph_config=config_proto)
    graph.observe_output_stream('out', lambda _, packet: out.append(packet))
    graph.start_run()
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet.at(0))
    graph.wait_until_idle()
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet.at(0))
    # The error for the duplicate timestamp is reported when waiting, not when
    # the packet is added.
    with self.assertRaisesRegex(
        ValueError, 'Current minimum expected timestamp is 1 but received 0.'):
      graph.wait_until_idle()

  def test_side_packet_graph(self):
    """An input side packet is converted and exposed as an output side packet."""
    text_config = """
      node {
        calculator: 'StringToUint64Calculator'
        input_side_packet: "string"
        output_side_packet: "number"
      }
    """
    config_proto = calculator_pb2.CalculatorGraphConfig()
    text_format.Parse(text_config, config_proto)
    graph = CalculatorGraph(graph_config=config_proto)
    graph.start_run(
        input_side_packets={'string': packet_creator.create_string('42')})
    graph.wait_until_done()
    self.assertFalse(graph.has_error())
    self.assertEqual(
        packet_getter.get_uint(graph.get_output_side_packet('number')), 42)

  def test_sequence_input(self):
    """A long run of packets arrives in order with matching timestamps."""
    text_config = """
      max_queue_size: 1
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    hello_world_packet = packet_creator.create_string('hello world')
    out = []
    graph = CalculatorGraph(graph_config=text_config)
    graph.observe_output_stream('out', lambda _, packet: out.append(packet))
    graph.start_run()
    sequence_size = 1000
    for i in range(sequence_size):
      graph.add_packet_to_input_stream(
          stream='in', packet=hello_world_packet, timestamp=i)
    graph.wait_until_idle()
    self.assertLen(out, sequence_size)
    for i, packet in enumerate(out):
      self.assertEqual(packet.timestamp, i)
      self.assertEqual(packet_getter.get_str(packet), 'hello world')


if __name__ == '__main__':
  absltest.main()