mediapipe/mediapipe/tasks/cc/metadata/metadata_extractor.h
MediaPipe Team 4dc4b19ddb Project import generated by Copybara.
GitOrigin-RevId: 1e13be30e2c6838d4a2ff768a39c414bc80534bb
2022-09-06 21:46:17 +00:00

164 lines
7.3 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_METADATA_METADATA_EXTRACTOR_H_
#define MEDIAPIPE_TASKS_CC_METADATA_METADATA_EXTRACTOR_H_
#include <cstddef>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe {
namespace tasks {
namespace metadata {
// Extracts and provides easy access to the TFLite ModelMetadata [1] and
// corresponding associated files packed into a TFLite FlatBuffer, if any.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
class ModelMetadataExtractor {
 public:
  // Creates a ModelMetadataExtractor from the provided TFLite Model FlatBuffer
  // and returns a pointer to the new object. Ownership is transferred to the
  // caller. Returns an error if the creation failed, which may happen e.g. if
  // the provided buffer is not a valid TFLite FlatBuffer.
  //
  // Warning: Does not take ownership of the provided buffer, which must outlive
  // this object and every view returned by `GetAssociatedFile`.
  //
  // It is recommended to obtain and manage the buffer through an
  // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having
  // to load the entire buffer in memory when provided by path or file
  // descriptor.
  //
  // [1]:
  // mediapipe/tasks/cc/core/external_file_handler.h
  static absl::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
  CreateFromModelBuffer(const char* buffer_data, size_t buffer_size);

  // Returns the pointer to the *first* ProcessUnit with the provided type, or
  // nullptr if none can be found. An error is returned if multiple
  // ProcessUnit-s with the provided type are found.
  static absl::StatusOr<const tflite::ProcessUnit*> FindFirstProcessUnit(
      const tflite::TensorMetadata& tensor_metadata,
      tflite::ProcessUnitOptions type);

  // Returns the name of the *first* associated file with the provided type and
  // (optional) locale in the provided TensorMetadata, or an empty string if no
  // such associated file can be found (which is not necessarily an error: some
  // models have no associated files at all) or its `name` field is unspecified.
  // Note: see `GetAssociatedFile` to read the actual file contents.
  static std::string FindFirstAssociatedFileName(
      const tflite::TensorMetadata& tensor_metadata,
      tflite::AssociatedFileType type,
      absl::string_view locale = absl::string_view());

  // Returns a pointer to the extracted TFLite Model Metadata, or nullptr if no
  // metadata was present in the Model FlatBuffer provided at creation time.
  const tflite::ModelMetadata* GetModelMetadata() const {
    return model_metadata_;
  }

  // Gets the contents of the associated file with the provided name packed into
  // the model metadata. An error is returned if there is no such associated
  // file.
  //
  // The returned view does not own its data: it refers to memory inside the
  // model buffer provided at creation time, so it must not outlive that buffer.
  absl::StatusOr<absl::string_view> GetAssociatedFile(
      const std::string& filename) const;

  // Gets the model version from the model metadata. An error is returned if
  // either the metadata does not exist or no model version is present in it.
  absl::StatusOr<std::string> GetModelVersion() const;

  // Note: all methods below retrieve metadata from the *first* subgraph by
  // default.

  // Gets the metadata for input tensors.
  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
  GetInputTensorMetadata() const;

  // Gets the metadata for the input tensor specified by the given index, or
  // nullptr in case there is no metadata or the index is out of range.
  const tflite::TensorMetadata* GetInputTensorMetadata(int index) const;

  // Gets the count of input tensors with metadata in the metadata FlatBuffer.
  // In particular, 0 is returned when there is no metadata.
  int GetInputTensorCount() const;

  // Gets the metadata for output tensors.
  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
  GetOutputTensorMetadata() const;

  // Gets the metadata for the output tensor specified by the given index, or
  // nullptr in case there is no metadata or the index is out of range.
  const tflite::TensorMetadata* GetOutputTensorMetadata(int index) const;

  // Gets the count of output tensors with metadata in the metadata FlatBuffer.
  // In particular, 0 is returned when there is no metadata.
  int GetOutputTensorCount() const;

  // Gets the input process units from SubgraphMetadata.input_process_units;
  // may be nullptr.
  const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
  GetInputProcessUnits() const;

  // Gets the input process unit specified by the given index, or nullptr in
  // case there is no input process unit or the index is out of range.
  const tflite::ProcessUnit* GetInputProcessUnit(int index) const;

  // Gets the count of input process units. In particular, 0 is returned when
  // there are no input process units.
  int GetInputProcessUnitsCount() const;

  // Gets the output process units from SubgraphMetadata.output_process_units;
  // may be nullptr.
  const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
  GetOutputProcessUnits() const;

  // Gets the output process unit specified by the given index, or nullptr in
  // case there is no output process unit or the index is out of range.
  const tflite::ProcessUnit* GetOutputProcessUnit(int index) const;

  // Gets the count of output process units. In particular, 0 is returned when
  // there are no output process units.
  int GetOutputProcessUnitsCount() const;

 private:
  static constexpr int kDefaultSubgraphIndex = 0;

  // Private default constructor, called from CreateFromModelBuffer().
  ModelMetadataExtractor() = default;

  // Initializes the ModelMetadataExtractor from the provided Model FlatBuffer.
  absl::Status InitFromModelBuffer(const char* buffer_data, size_t buffer_size);

  // Extracts and stores in associated_files_ the associated files (if present)
  // packed into the model FlatBuffer data.
  absl::Status ExtractAssociatedFiles(const char* buffer_data,
                                      size_t buffer_size);

  // Pointer to the TFLite Model object from which to read the ModelMetadata.
  const tflite::Model* model_{nullptr};

  // Pointer to the extracted ModelMetadata, if any.
  const tflite::ModelMetadata* model_metadata_{nullptr};

  // The files associated with the ModelMetadata, as a map with the filename
  // (corresponding to a basename, e.g. "labels.txt") as key and a
  // non-owning view of the file contents as value.
  absl::flat_hash_map<std::string, absl::string_view> associated_files_;
};
} // namespace metadata
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_METADATA_METADATA_EXTRACTOR_H_