Changed labels to be a property
This commit is contained in:
		
							parent
							
								
									d621df8046
								
							
						
					
					
						commit
						1cb404bea1
					
				|  | @ -247,14 +247,16 @@ class ImageSegmenterTest(parameterized.TestCase): | |||
|           ) | ||||
|       ) | ||||
| 
 | ||||
|   def test_get_labels_succeeds(self): | ||||
|   def test_labels_succeeds(self): | ||||
|     expected_labels = _EXPECTED_LABELS | ||||
|     base_options = _BaseOptions(model_asset_path=self.model_path) | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK) | ||||
|         base_options=base_options, output_category_mask=True, | ||||
|         output_confidence_masks=False | ||||
|     ) | ||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||
|       # Performs image segmentation on the input. | ||||
|       actual_labels = segmenter.get_labels() | ||||
|       actual_labels = segmenter.labels | ||||
|       self.assertListEqual(actual_labels, expected_labels) | ||||
| 
 | ||||
|   def test_missing_result_callback(self): | ||||
|  |  | |||
|  | @ -151,7 +151,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|       Exception if there is an error during finding TensorsToSegmentationCalculator. | ||||
|     :return: | ||||
|     """ | ||||
|     self.labels = [] | ||||
|     self._labels = [] | ||||
|     graph_config = self._runner.get_graph_config() | ||||
|     found_tensors_to_segmentation = False | ||||
| 
 | ||||
|  | @ -170,7 +170,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|           for i in range(len(options.label_items)): | ||||
|             if i not in options.label_items: | ||||
|               raise Exception(f"The labelmap has no expected key: {i}.") | ||||
|             self.labels.append(options.label_items[i].name) | ||||
|             self._labels.append(options.label_items[i].name) | ||||
| 
 | ||||
|   @classmethod | ||||
|   def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter': | ||||
|  | @ -271,19 +271,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|         packets_callback if options.result_callback else None, | ||||
|     ) | ||||
| 
 | ||||
|   def get_labels(self): | ||||
|     """ Get the category label list of the ImageSegmenter can recognize. | ||||
| 
 | ||||
|     For CATEGORY_MASK type, the index in the category mask corresponds to the | ||||
|     category in the label list. | ||||
|     For CONFIDENCE_MASK type, the output mask list at index corresponds to the | ||||
|     category in the label list. | ||||
| 
 | ||||
|     If there is no label map provided in the model file, empty label list is | ||||
|     returned. | ||||
|     """ | ||||
|     return self.labels | ||||
| 
 | ||||
|   def segment( | ||||
|       self, | ||||
|       image: image_module.Image, | ||||
|  | @ -427,3 +414,17 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|             normalized_rect.to_pb2() | ||||
|         ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), | ||||
|     }) | ||||
| 
 | ||||
|   @property | ||||
|   def labels(self) -> List[str]: | ||||
|     """ Get the category label list of the ImageSegmenter can recognize. | ||||
| 
 | ||||
|     For CATEGORY_MASK type, the index in the category mask corresponds to the | ||||
|     category in the label list. | ||||
|     For CONFIDENCE_MASK type, the output mask list at index corresponds to the | ||||
|     category in the label list. | ||||
| 
 | ||||
|     If there is no label map provided in the model file, empty label list is | ||||
|     returned. | ||||
|     """ | ||||
|     return self._labels | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user