Merge branch 'google:master' into image-segmenter-python-impl
This commit is contained in:
commit 36ac0689d7

@@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
RUN pip3 install tensorflow==2.2.0
RUN pip3 install tensorflow
RUN pip3 install tf_slim

RUN ln -s /usr/bin/python3 /usr/bin/python
@@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
}
```

## Graph Options

It is possible to specify a "graph options" protobuf for a MediaPipe graph
similar to the [`Calculator Options`](calculators.md#calculator-options)
protobuf specified for a MediaPipe calculator. These "graph options" can be
specified where a graph is invoked, and used to populate calculator options and
subgraph options within the graph.

In a CalculatorGraphConfig, graph options can be specified for a subgraph
exactly like calculator options, as shown below:

```
node {
  calculator: "FlowLimiterCalculator"
  input_stream: "image"
  output_stream: "throttled_image"
  node_options: {
    [type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] {
      max_in_flight: 1
    }
  }
}

node {
  calculator: "FaceDetectionSubgraph"
  input_stream: "IMAGE:throttled_image"
  node_options: {
    [type.googleapis.com/mediapipe.FaceDetectionOptions] {
      tensor_width: 192
      tensor_height: 192
    }
  }
}
```

In a CalculatorGraphConfig, graph options can be accepted and used to populate
calculator options, as shown below:

```
graph_options: {
  [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
}

node: {
  calculator: "ImageToTensorCalculator"
  input_stream: "IMAGE:multi_backend_image"
  node_options: {
    [type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
      keep_aspect_ratio: true
      border_mode: BORDER_ZERO
    }
  }
  option_value: "output_tensor_width:options/tensor_width"
  option_value: "output_tensor_height:options/tensor_height"
}

node {
  calculator: "InferenceCalculator"
  node_options: {
    [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {}
  }
  option_value: "delegate:options/delegate"
  option_value: "model_path:options/model_path"
}
```

In this example, the `FaceDetectionSubgraph` accepts graph option protobuf
`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field
values in the calculator options `ImageToTensorCalculatorOptions` and some
field values in the subgraph options `InferenceCalculatorOptions`. The field
values are defined using the `option_value:` syntax.

In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and
`option_value:` together define the option values for a calculator such as
`ImageToTensorCalculator`. The `node_options:` field defines a set of literal
constant values using the text protobuf syntax. Each `option_value:` field
defines the value for one protobuf field using information from the enclosing
graph, specifically from field values of the graph options of the enclosing
graph. In the example above, the `option_value:`
`"output_tensor_width:options/tensor_width"` defines the field
`ImageToTensorCalculatorOptions.output_tensor_width` using the value of
`FaceDetectionOptions.tensor_width`.

The syntax of `option_value:` is similar to the syntax of `input_stream:`. The
syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option
field and the RHS identifies a graph option field. More specifically, the LHS
and RHS each consists of a series of protobuf field names identifying nested
protobuf messages and fields separated by '/'. This is known as the "ProtoPath"
syntax. Nested messages that are referenced in the LHS or RHS must already be
defined in the enclosing protobuf in order to be traversed using
`option_value:`.
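Once graph options have been substituted during graph expansion, the resulting
config is run through the usual `CalculatorGraph` entry points. Below is a
minimal sketch (not part of this change) of loading and running a config such
as the first example above; the function name and the assumption that the
pbtxt is available as a string are hypothetical.

```cpp
// Minimal sketch, not from this diff: parse a pbtxt config like the one shown
// above and run it as a CalculatorGraph.
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

absl::Status RunThrottledFaceDetection(const std::string& config_text) {
  CalculatorGraphConfig config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(config_text);
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  // Packets would be added to the "image" input stream here.
  MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
  return graph.WaitUntilDone();
}

}  // namespace mediapipe
```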

## Cycles

<!-- TODO: add discussion of PreviousLoopbackCalculator -->
@@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow

```bash
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt
```

This will open up your webcam as long as it is connected and on. Any errors
@@ -209,11 +209,18 @@ cc_library(
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "rotation_mode_proto",
    srcs = ["rotation_mode.proto"],
    visibility = ["//visibility:public"],
)

mediapipe_proto_library(
    name = "image_transformation_calculator_proto",
    srcs = ["image_transformation_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        ":rotation_mode_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/gpu:scale_mode_proto",

@@ -238,6 +245,7 @@ cc_library(
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":rotation_mode_cc_proto",
        ":image_transformation_calculator_cc_proto",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:timestamp",
@@ -13,6 +13,7 @@
// limitations under the License.

#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/calculators/image/rotation_mode.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
@@ -16,20 +16,10 @@ syntax = "proto2";

package mediapipe;

import "mediapipe/calculators/image/rotation_mode.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";

// Counterclockwise rotation.
message RotationMode {
  enum Mode {
    UNKNOWN = 0;
    ROTATION_0 = 1;
    ROTATION_90 = 2;
    ROTATION_180 = 3;
    ROTATION_270 = 4;
  }
}

message ImageTransformationCalculatorOptions {
  extend CalculatorOptions {
    optional ImageTransformationCalculatorOptions ext = 251952830;
mediapipe/calculators/image/rotation_mode.proto (new file, 28 lines)

@@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

// Counterclockwise rotation.
message RotationMode {
  enum Mode {
    UNKNOWN = 0;
    ROTATION_0 = 1;
    ROTATION_90 = 2;
    ROTATION_180 = 3;
    ROTATION_270 = 4;
  }
}
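The new `RotationMode` enum is consumed from the generated C++ code (see the
`rotation_mode.pb.h` include added above). A small sketch of how calling code
might map rotation degrees onto the enum follows; the helper itself is
hypothetical and not part of this change.

```cpp
// Hypothetical helper, not part of this diff: map a counterclockwise rotation
// in degrees onto the generated RotationMode::Mode enum.
#include "mediapipe/calculators/image/rotation_mode.pb.h"

namespace mediapipe {

RotationMode::Mode RotationModeFromDegrees(int degrees) {
  switch (degrees) {
    case 0:
      return RotationMode::ROTATION_0;
    case 90:
      return RotationMode::ROTATION_90;
    case 180:
      return RotationMode::ROTATION_180;
    case 270:
      return RotationMode::ROTATION_270;
    default:
      return RotationMode::UNKNOWN;
  }
}

}  // namespace mediapipe
```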
@ -161,6 +161,98 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "bert_preprocessor_calculator_proto",
|
||||
srcs = ["bert_preprocessor_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bert_preprocessor_calculator",
|
||||
srcs = ["bert_preprocessor_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":bert_preprocessor_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/cc/text/tokenizers:tokenizer",
|
||||
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "bert_preprocessor_calculator_test",
|
||||
srcs = ["bert_preprocessor_calculator_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":bert_preprocessor_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "regex_preprocessor_calculator_proto",
|
||||
srcs = ["regex_preprocessor_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "regex_preprocessor_calculator",
|
||||
srcs = ["regex_preprocessor_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":regex_preprocessor_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer",
|
||||
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "inference_calculator_proto",
|
||||
srcs = ["inference_calculator.proto"],
|
||||
|
@ -320,6 +412,8 @@ cc_library(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite:string_util",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
],
|
||||
)
|
||||
|
|
251
mediapipe/calculators/tensor/bert_preprocessor_calculator.cc
Normal file
251
mediapipe/calculators/tensor/bert_preprocessor_calculator.cc
Normal file
|
@ -0,0 +1,251 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h"
|
||||
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
|
||||
constexpr int kNumInputTensorsForBert = 3;
|
||||
constexpr int kTokenizerProcessUnitIndex = 0;
|
||||
constexpr absl::string_view kInputIdsTensorName = "ids";
|
||||
constexpr absl::string_view kInputMasksTensorName = "mask";
|
||||
constexpr absl::string_view kSegmentIdsTensorName = "segment_ids";
|
||||
constexpr absl::string_view kClassifierToken = "[CLS]";
|
||||
constexpr absl::string_view kSeparatorToken = "[SEP]";
|
||||
|
||||
// Preprocesses input text into three int32 input tensors for a BERT model using
|
||||
// a tokenizer.
|
||||
// The associated BERT model is expected to contain input tensors with names:
|
||||
//
|
||||
// Tensor | Metadata Name
|
||||
// ---------------- | --------------
|
||||
// IDs | "ids"
|
||||
// Segment IDs | "segment_ids"
|
||||
// Mask | "mask"
|
||||
//
|
||||
// This calculator will return an error if the model does not have three input
|
||||
// tensors or if the tensors do not have names corresponding to the above
|
||||
// metadata names in some order. Additional details regarding these input
|
||||
// tensors are given in the Calculator "Outputs" section below.
|
||||
//
|
||||
// This calculator is currently configured for the TextClassifier Task but it
|
||||
// will eventually be generalized for other Text Tasks.
|
||||
// TODO: Handle preprocessing for other Text Tasks too.
|
||||
//
|
||||
// Inputs:
|
||||
// TEXT - std::string
|
||||
// The input text.
|
||||
// Side Inputs:
|
||||
// METADATA_EXTRACTOR - ModelMetadataExtractor
|
||||
// The metadata extractor for the BERT model. Used to determine the order of
|
||||
// the three input Tensors for the BERT model and to extract the metadata to
|
||||
// construct the tokenizer.
|
||||
//
|
||||
// Outputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// Vector containing the three input Tensors for the BERT model:
|
||||
// (1): the token ids of the tokenized input string. A classifier token
|
||||
// ("[CLS]") will be prepended to the input tokens and a separator
|
||||
// token ("[SEP]") will be appended to the input tokens.
|
||||
// (2): the segment ids, which are all 0 for now but will have different
|
||||
// values to distinguish between different sentences in the input
|
||||
// text for other Text tasks.
|
||||
// (3): the input mask ids, which are 1 at each of the input token indices
|
||||
// and 0 elsewhere.
|
||||
// The Tensors will have size equal to the max sequence length for the BERT
|
||||
// model.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "BertPreprocessorCalculator"
|
||||
// input_stream: "TEXT:text"
|
||||
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||
// output_stream: "TENSORS:tensors"
|
||||
// options {
|
||||
// [mediapipe.BertPreprocessorCalculatorOptions.ext] {
|
||||
// bert_max_seq_len: 128
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class BertPreprocessorCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::string> kTextIn{"TEXT"};
|
||||
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
|
||||
"METADATA_EXTRACTOR"};
|
||||
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
|
||||
// The max sequence length accepted by the BERT model.
|
||||
int bert_max_seq_len_ = 2;
|
||||
// Indices of the three input tensors for the BERT model. They should form the
|
||||
// set {0, 1, 2}.
|
||||
int input_ids_tensor_index_ = 0;
|
||||
int segment_ids_tensor_index_ = 1;
|
||||
int input_masks_tensor_index_ = 2;
|
||||
|
||||
// Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
|
||||
// This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
|
||||
// clips the vector of tokens to have length at most `bert_max_seq_len_`.
|
||||
std::vector<std::string> TokenizeInputText(absl::string_view input_text);
|
||||
// Processes the `input_tokens` to generate the three input tensors for the
|
||||
// BERT model.
|
||||
std::vector<Tensor> GenerateInputTensors(
|
||||
const std::vector<std::string>& input_tokens);
|
||||
};
|
||||
|
||||
absl::Status BertPreprocessorCalculator::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
|
||||
RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
|
||||
RET_CHECK_GE(options.bert_max_seq_len(), 2)
|
||||
<< "bert_max_seq_len must be at least 2";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
|
||||
const ModelMetadataExtractor* metadata_extractor =
|
||||
&kMetadataExtractorSideIn(cc).Get();
|
||||
const tflite::ProcessUnit* tokenizer_metadata =
|
||||
metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex);
|
||||
ASSIGN_OR_RETURN(tokenizer_,
|
||||
tasks::text::tokenizers::CreateTokenizerFromProcessUnit(
|
||||
tokenizer_metadata, metadata_extractor));
|
||||
|
||||
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
|
||||
input_ids_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kInputIdsTensorName);
|
||||
segment_ids_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kSegmentIdsTensorName);
|
||||
input_masks_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kInputMasksTensorName);
|
||||
absl::flat_hash_set<int> tensor_indices = {input_ids_tensor_index_,
|
||||
segment_ids_tensor_index_,
|
||||
input_masks_tensor_index_};
|
||||
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
|
||||
input_ids_tensor_index_, segment_ids_tensor_index_,
|
||||
input_masks_tensor_index_));
|
||||
}
|
||||
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
|
||||
bert_max_seq_len_ = options.bert_max_seq_len();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
|
||||
kTensorsOut(cc).Send(
|
||||
GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
|
||||
absl::string_view input_text) {
|
||||
std::string processed_input = std::string(input_text);
|
||||
absl::AsciiStrToLower(&processed_input);
|
||||
|
||||
tasks::text::tokenizers::TokenizerResult tokenizer_result =
|
||||
tokenizer_->Tokenize(processed_input);
|
||||
|
||||
// Offset by 2 to account for [CLS] and [SEP]
|
||||
int input_tokens_size =
|
||||
std::min(bert_max_seq_len_,
|
||||
static_cast<int>(tokenizer_result.subwords.size()) + 2);
|
||||
std::vector<std::string> input_tokens;
|
||||
input_tokens.reserve(input_tokens_size);
|
||||
input_tokens.push_back(std::string(kClassifierToken));
|
||||
for (int i = 0; i < input_tokens_size - 2; ++i) {
|
||||
input_tokens.push_back(std::move(tokenizer_result.subwords[i]));
|
||||
}
|
||||
input_tokens.push_back(std::string(kSeparatorToken));
|
||||
return input_tokens;
|
||||
}
|
||||
|
||||
std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
|
||||
const std::vector<std::string>& input_tokens) {
|
||||
std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
|
||||
std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
|
||||
std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
|
||||
// Convert tokens back into ids and set mask
|
||||
for (int i = 0; i < input_tokens.size(); ++i) {
|
||||
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
|
||||
input_masks[i] = 1;
|
||||
}
|
||||
// |<--------bert_max_seq_len_--------->|
|
||||
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
|
||||
// segment_ids 0 0 0... 0 0 0 0... 0
|
||||
// input_masks 1 1 1... 1 1 0 0... 0
|
||||
|
||||
std::vector<Tensor> input_tensors;
|
||||
input_tensors.reserve(kNumInputTensorsForBert);
|
||||
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
|
||||
input_tensors.push_back(
|
||||
{Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
|
||||
}
|
||||
std::memcpy(input_tensors[input_ids_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
.buffer<int32_t>(),
|
||||
input_ids.data(), input_ids.size() * sizeof(int32_t));
|
||||
std::memcpy(input_tensors[segment_ids_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
.buffer<int32_t>(),
|
||||
segment_ids.data(), segment_ids.size() * sizeof(int32_t));
|
||||
std::memcpy(input_tensors[input_masks_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
.buffer<int32_t>(),
|
||||
input_masks.data(), input_masks.size() * sizeof(int32_t));
|
||||
return input_tensors;
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@@ -13,22 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
syntax = "proto2";

#include "tensorflow/lite/kernels/register.h"
package mediapipe;

namespace mediapipe {
namespace tasks {
namespace vision {
class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver {
 public:
  HandDetectorOpResolver();
  HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete;
};
import "mediapipe/framework/calculator.proto";

}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
message BertPreprocessorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional BertPreprocessorCalculatorOptions ext = 462509271;
  }

#endif  // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
  // The maximum input sequence length for the calculator's BERT model.
  optional int32 bert_max_seq_len = 1;
}
@ -0,0 +1,154 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
constexpr int kNumInputTensorsForBert = 3;
|
||||
constexpr int kBertMaxSeqLen = 128;
|
||||
constexpr absl::string_view kTestModelPath =
|
||||
"mediapipe/tasks/testdata/text/bert_text_classifier.tflite";
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
|
||||
absl::string_view text, absl::string_view model_path) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(R"(
|
||||
input_stream: "text"
|
||||
output_stream: "tensors"
|
||||
node {
|
||||
calculator: "BertPreprocessorCalculator"
|
||||
input_stream: "TEXT:text"
|
||||
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||
output_stream: "TENSORS:tensors"
|
||||
options {
|
||||
[mediapipe.BertPreprocessorCalculatorOptions.ext] {
|
||||
bert_max_seq_len: $0
|
||||
}
|
||||
}
|
||||
}
|
||||
)",
|
||||
kBertMaxSeqLen));
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensors", &graph_config, &output_packets);
|
||||
|
||||
std::string model_buffer = tasks::core::LoadBinaryContent(model_path.data());
|
||||
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
|
||||
ModelMetadataExtractor::CreateFromModelBuffer(
|
||||
model_buffer.data(), model_buffer.size()));
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_RETURN_IF_ERROR(graph.Initialize(
|
||||
graph_config,
|
||||
{{"metadata_extractor",
|
||||
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
|
||||
MP_RETURN_IF_ERROR(graph.StartRun({}));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"text", MakePacket<std::string>(text).At(Timestamp(0))));
|
||||
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
|
||||
|
||||
if (output_packets.size() != 1) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"output_packets has size $0, expected 1", output_packets.size()));
|
||||
}
|
||||
const std::vector<Tensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<Tensor>>();
|
||||
if (tensor_vec.size() != kNumInputTensorsForBert) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::Substitute("tensor_vec has size $0, expected $1",
|
||||
tensor_vec.size(), kNumInputTensorsForBert));
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> results;
|
||||
for (int i = 0; i < kNumInputTensorsForBert; i++) {
|
||||
const Tensor& tensor = tensor_vec[i];
|
||||
if (tensor.element_type() != Tensor::ElementType::kInt32) {
|
||||
return absl::InvalidArgumentError("Expected tensor element type kInt32");
|
||||
}
|
||||
auto* buffer = tensor.GetCpuReadView().buffer<int>();
|
||||
std::vector<int> buffer_view(buffer, buffer + kBertMaxSeqLen);
|
||||
results.push_back(buffer_view);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
|
||||
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
|
||||
return results;
|
||||
}
|
||||
|
||||
TEST(BertPreprocessorCalculatorTest, TextClassifierWithBertModel) {
|
||||
std::vector<std::vector<int>> expected_result = {
|
||||
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 102}};
|
||||
// segment_ids
|
||||
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
|
||||
// input_masks
|
||||
expected_result.push_back(std::vector(expected_result[0].size(), 1));
|
||||
expected_result[2].resize(kBertMaxSeqLen);
|
||||
// padding input_ids
|
||||
expected_result[0].resize(kBertMaxSeqLen);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::vector<std::vector<int>> processed_tensor_values,
|
||||
RunBertPreprocessorCalculator(
|
||||
"it's a charming and often affecting journey", kTestModelPath));
|
||||
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
|
||||
}
|
||||
|
||||
TEST(BertPreprocessorCalculatorTest, LongInput) {
|
||||
std::stringstream long_input;
|
||||
long_input
|
||||
<< "it's a charming and often affecting journey and this is a long";
|
||||
for (int i = 0; i < kBertMaxSeqLen; ++i) {
|
||||
long_input << " long";
|
||||
}
|
||||
long_input << " movie review";
|
||||
std::vector<std::vector<int>> expected_result = {
|
||||
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1998, 2023,
|
||||
2003, 1037}};
|
||||
// "long" id
|
||||
expected_result[0].resize(kBertMaxSeqLen - 1, 2146);
|
||||
// "[SEP]" id
|
||||
expected_result[0].push_back(102);
|
||||
// segment_ids
|
||||
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
|
||||
// input_masks
|
||||
expected_result.push_back(std::vector(kBertMaxSeqLen, 1));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
std::vector<std::vector<int>> processed_tensor_values,
|
||||
RunBertPreprocessorCalculator(long_input.str(), kTestModelPath));
|
||||
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
|
||||
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
|
||||
const Size size{image->width(), image->height()};
|
||||
RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
|
||||
|
||||
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
|
||||
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
|
||||
options_.output_tensor_height(),
|
||||
options_.keep_aspect_ratio(), &roi));
|
||||
|
@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
if (kOutMatrix(cc).IsConnected()) {
|
||||
std::array<float, 16> matrix;
|
||||
GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
|
||||
/*flip_horizontaly=*/false,
|
||||
&matrix);
|
||||
GetRotatedSubRectToRectTransformMatrix(
|
||||
roi, image->width(), image->height(),
|
||||
/*flip_horizontaly=*/false, &matrix);
|
||||
kOutMatrix(cc).Send(std::move(matrix));
|
||||
}
|
||||
|
||||
// Lazy initialization of the GPU or CPU converter.
|
||||
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
|
||||
|
||||
ASSIGN_OR_RETURN(Tensor tensor,
|
||||
(image->UsesGpu() ? gpu_converter_ : cpu_converter_)
|
||||
->Convert(*image, roi, {output_width_, output_height_},
|
||||
range_min_, range_max_));
|
||||
Tensor::ElementType output_tensor_type =
|
||||
GetOutputTensorType(image->UsesGpu());
|
||||
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
|
||||
GetNumOutputChannels(*image)});
|
||||
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
|
||||
->Convert(*image, roi, range_min_, range_max_,
|
||||
/*tensor_buffer_offset=*/0, tensor));
|
||||
|
||||
auto result = std::make_unique<std::vector<Tensor>>();
|
||||
result->push_back(std::move(tensor));
|
||||
|
@ -292,7 +295,8 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor::ElementType GetOutputTensorType() {
|
||||
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
|
||||
if (!uses_gpu) {
|
||||
if (is_float_output_) {
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
|
@ -302,6 +306,21 @@ class ImageToTensorCalculator : public Node {
|
|||
return Tensor::ElementType::kUInt8;
|
||||
}
|
||||
}
|
||||
// Always use float32 when GPU is enabled.
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
|
||||
int GetNumOutputChannels(const Image& image) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
if (image.UsesGpu()) {
|
||||
return 4;
|
||||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
// All of the processors except for Metal expect 3 channels.
|
||||
return 3;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
CalculatorContext* cc) {
|
||||
|
@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node {
|
|||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
ASSIGN_OR_RETURN(
|
||||
cpu_converter_,
|
||||
CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
|
||||
CreateOpenCvConverter(cc, GetBorderMode(),
|
||||
GetOutputTensorType(/*uses_gpu=*/false)));
|
||||
#else
|
||||
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
|
||||
"MEDIAPIPE_DISABLE_OPENCV is defined.";
|
||||
|
|
|
@@ -42,13 +42,16 @@ class ImageToTensorConverter {
  // @image contains image to extract from.
  // @roi describes region of interest within the image to extract (absolute
  // values).
  // @output_dims dimensions of output tensor.
  // @range_min/max describes output tensor range image pixels should be
  // converted to.
  virtual absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
                                         const RotatedRect& roi,
                                         const Size& output_dims,
                                         float range_min, float range_max) = 0;
  // @tensor_buffer_offset an integer representing the offset of the tensor
  // buffer the result should be written to.
  // @output_tensor a tensor with pre-defined shape. The "Convert" is
  // responsible for populating the content into the output tensor.
  virtual absl::Status Convert(const mediapipe::Image& input,
                               const RotatedRect& roi, float range_min,
                               float range_max, int tensor_buffer_offset,
                               Tensor& output_tensor) = 0;
};

}  // namespace mediapipe
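The interface change inverts ownership of the output: the caller now allocates
the tensor and the converter writes into it in place. The following fragment is
a condensed sketch of the calling pattern from the `ImageToTensorCalculator`
changes elsewhere in this diff (member names such as `gpu_converter_`,
`output_width_`, and `GetNumOutputChannels` come from that calculator, not from
the interface itself).

```cpp
// Condensed from the ImageToTensorCalculator::Process changes in this diff:
// the caller pre-allocates the output tensor, then asks the converter to fill
// it starting at offset 0.
Tensor::ElementType output_tensor_type = GetOutputTensorType(image->UsesGpu());
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
                                   GetNumOutputChannels(*image)});
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
                       ->Convert(*image, roi, range_min_, range_max_,
                                 /*tensor_buffer_offset=*/0, tensor));
```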
|
||||
|
|
|
@ -264,10 +264,10 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
});
|
||||
}
|
||||
|
||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
||||
const RotatedRect& roi,
|
||||
const Size& output_dims, float range_min,
|
||||
float range_max) override {
|
||||
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||
float range_min, float range_max,
|
||||
int tensor_buffer_offset,
|
||||
Tensor& output_tensor) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
|
@ -275,46 +275,46 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
const auto& output_shape = output_tensor.shape();
|
||||
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||
|
||||
constexpr int kNumChannels = 3;
|
||||
Tensor tensor(Tensor::ElementType::kFloat32,
|
||||
{1, output_dims.height, output_dims.width, kNumChannels});
|
||||
|
||||
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi,
|
||||
&output_dims, range_min,
|
||||
range_max]() -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
|
||||
[this, &output_tensor, &input, &roi, &output_shape, range_min,
|
||||
range_max, tensor_buffer_offset]() -> absl::Status {
|
||||
constexpr int kRgbaNumChannels = 4;
|
||||
auto source_texture = gl_helper_.CreateSourceTexture(input);
|
||||
tflite::gpu::gl::GlTexture input_texture(
|
||||
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
|
||||
source_texture.width() * source_texture.height() * kRgbaNumChannels *
|
||||
sizeof(uint8_t),
|
||||
source_texture.width() * source_texture.height() *
|
||||
kRgbaNumChannels * sizeof(uint8_t),
|
||||
/*layer=*/0,
|
||||
/*owned=*/false);
|
||||
|
||||
constexpr float kInputImageRangeMin = 0.0f;
|
||||
constexpr float kInputImageRangeMax = 1.0f;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto transform,
|
||||
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
|
||||
ASSIGN_OR_RETURN(auto transform,
|
||||
GetValueRangeTransformation(kInputImageRangeMin,
|
||||
kInputImageRangeMax,
|
||||
range_min, range_max));
|
||||
|
||||
auto buffer_view = tensor.GetOpenGlBufferWriteView();
|
||||
const int output_size = output_tensor.bytes() / output_shape.dims[0];
|
||||
auto buffer_view = output_tensor.GetOpenGlBufferWriteView();
|
||||
tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
|
||||
buffer_view.name(), tensor.bytes(),
|
||||
/*offset=*/0,
|
||||
buffer_view.name(), output_size,
|
||||
/*offset=*/tensor_buffer_offset,
|
||||
/*has_ownership=*/false);
|
||||
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
|
||||
input_texture,
|
||||
tflite::gpu::HW(source_texture.height(), source_texture.width()), roi,
|
||||
tflite::gpu::HW(source_texture.height(), source_texture.width()),
|
||||
roi,
|
||||
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
||||
tflite::gpu::HW(output_dims.height, output_dims.width),
|
||||
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
|
||||
command_queue_.get(), &output));
|
||||
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
|
||||
return tensor;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
~GlProcessor() override {
|
||||
|
@ -326,6 +326,17 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
}
|
||||
|
||||
private:
|
||||
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||
"converter.";
|
||||
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||
<< "Wrong output channel: " << output_shape.dims[3];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
|
||||
std::unique_ptr<SubRectExtractorGl> extractor_;
|
||||
mediapipe::GlCalculatorHelper gl_helper_;
|
||||
|
|
|
@ -168,10 +168,10 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
});
|
||||
}
|
||||
|
||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
||||
const RotatedRect& roi,
|
||||
const Size& output_dims, float range_min,
|
||||
float range_max) override {
|
||||
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||
float range_min, float range_max,
|
||||
int tensor_buffer_offset,
|
||||
Tensor& output_tensor) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
|
@ -179,15 +179,15 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
// TODO: support tensor_buffer_offset > 0 scenario.
|
||||
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
||||
const auto& output_shape = output_tensor.shape();
|
||||
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||
|
||||
constexpr int kNumChannels = 3;
|
||||
Tensor tensor(
|
||||
Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels});
|
||||
|
||||
MP_RETURN_IF_ERROR(
|
||||
gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims,
|
||||
range_min, range_max]() -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
|
||||
[this, &output_tensor, &input, &roi, &output_shape, range_min,
|
||||
range_max]() -> absl::Status {
|
||||
auto input_texture = gl_helper_.CreateSourceTexture(input);
|
||||
|
||||
constexpr float kInputImageRangeMin = 0.0f;
|
||||
|
@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
GetValueRangeTransformation(kInputImageRangeMin,
|
||||
kInputImageRangeMax,
|
||||
range_min, range_max));
|
||||
auto tensor_view = tensor.GetOpenGlTexture2dWriteView();
|
||||
auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
|
||||
MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
|
||||
/*flip_horizontaly=*/false,
|
||||
transform.scale, transform.offset,
|
||||
output_dims, &tensor_view));
|
||||
output_shape, &tensor_view));
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
|
||||
return tensor;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
|
||||
const RotatedRect& sub_rect,
|
||||
bool flip_horizontaly, float alpha, float beta,
|
||||
const Size& output_dims,
|
||||
const Tensor::Shape& output_shape,
|
||||
Tensor::OpenGlTexture2dView* output) {
|
||||
const int output_height = output_shape.dims[1];
|
||||
const int output_width = output_shape.dims[2];
|
||||
std::array<float, 16> transform_mat;
|
||||
|
||||
glDisable(GL_DEPTH_TEST);
|
||||
glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
|
||||
glViewport(0, 0, output_dims.width, output_dims.height);
|
||||
glViewport(0, 0, output_width, output_height);
|
||||
|
||||
glActiveTexture(GL_TEXTURE0);
|
||||
glBindTexture(GL_TEXTURE_2D, output->name());
|
||||
|
@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
}
|
||||
|
||||
private:
|
||||
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||
"converter.";
|
||||
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||
<< "Wrong output channel: " << output_shape.dims[3];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
mediapipe::GlCalculatorHelper gl_helper_;
|
||||
bool use_custom_zero_border_ = false;
|
||||
BorderMode border_mode_ = BorderMode::kReplicate;
|
||||
|
|
|
@ -262,7 +262,6 @@ class SubRectExtractorMetal {
|
|||
RET_CHECK(pipeline_state != nil);
|
||||
|
||||
std::string output_type_def;
|
||||
MTLPixelFormat pixel_format;
|
||||
switch (output_format) {
|
||||
case OutputFormat::kF16C4:
|
||||
output_type_def = R"(
|
||||
|
@ -348,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
||||
const RotatedRect& roi,
|
||||
const Size& output_dims, float range_min,
|
||||
float range_max) override {
|
||||
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||
float range_min, float range_max,
|
||||
int tensor_buffer_offset,
|
||||
Tensor& output_tensor) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
|
@ -359,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter {
|
|||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
||||
const auto& output_shape = output_tensor.shape();
|
||||
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||
|
||||
@autoreleasepool {
|
||||
id<MTLTexture> texture =
|
||||
[metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];
|
||||
|
||||
constexpr int kNumChannels = 4;
|
||||
Tensor tensor(Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{1, output_dims.height, output_dims.width,
|
||||
kNumChannels});
|
||||
|
||||
constexpr float kInputImageRangeMin = 0.0f;
|
||||
constexpr float kInputImageRangeMax = 1.0f;
|
||||
ASSIGN_OR_RETURN(
|
||||
|
@ -377,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter {
|
|||
range_min, range_max));
|
||||
|
||||
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
|
||||
const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer);
|
||||
const auto& buffer_view =
|
||||
output_tensor.GetMtlBufferWriteView(command_buffer);
|
||||
MP_RETURN_IF_ERROR(extractor_->Execute(
|
||||
texture, roi,
|
||||
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
||||
tflite::gpu::HW(output_dims.height, output_dims.width),
|
||||
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
|
||||
command_buffer, buffer_view.buffer()));
|
||||
[command_buffer commit];
|
||||
return tensor;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||
"converter.";
|
||||
RET_CHECK_EQ(output_shape.dims[3], 4)
|
||||
<< "Wrong output channel: " << output_shape.dims[3];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
MPPMetalHelper* metal_helper_ = nil;
|
||||
std::unique_ptr<SubRectExtractorMetal> extractor_;
|
||||
};
|
||||
|
|
|
@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
||||
const RotatedRect& roi,
|
||||
const Size& output_dims, float range_min,
|
||||
float range_max) override {
|
||||
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||
float range_min, float range_max,
|
||||
int tensor_buffer_offset,
|
||||
Tensor& output_tensor) override {
|
||||
if (input.image_format() != mediapipe::ImageFormat::SRGB &&
|
||||
input.image_format() != mediapipe::ImageFormat::SRGBA) {
|
||||
return InvalidArgumentError(
|
||||
absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.image_format())));
|
||||
}
|
||||
auto src = mediapipe::formats::MatView(&input);
|
||||
// TODO: Remove the check once tensor_buffer_offset > 0 is
|
||||
// supported.
|
||||
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
||||
const auto& output_shape = output_tensor.shape();
|
||||
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||
|
||||
constexpr int kNumChannels = 3;
|
||||
Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height,
|
||||
output_dims.width, kNumChannels});
|
||||
auto buffer_view = tensor.GetCpuWriteView();
|
||||
const int output_height = output_shape.dims[1];
|
||||
const int output_width = output_shape.dims[2];
|
||||
const int output_channels = output_shape.dims[3];
|
||||
auto buffer_view = output_tensor.GetCpuWriteView();
|
||||
cv::Mat dst;
|
||||
switch (tensor_type_) {
|
||||
case Tensor::ElementType::kInt8:
|
||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
||||
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||
buffer_view.buffer<int8>());
|
||||
break;
|
||||
case Tensor::ElementType::kFloat32:
|
||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
||||
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||
buffer_view.buffer<float>());
|
||||
break;
|
||||
case Tensor::ElementType::kUInt8:
|
||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
||||
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||
buffer_view.buffer<uint8>());
|
||||
break;
|
||||
default:
|
||||
|
@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
cv::Mat src_points;
|
||||
cv::boxPoints(rotated_rect, src_points);
|
||||
|
||||
const float dst_width = output_dims.width;
|
||||
const float dst_height = output_dims.height;
|
||||
const float dst_width = output_width;
|
||||
const float dst_height = output_height;
|
||||
/* clang-format off */
|
||||
float dst_corners[8] = {0.0f, dst_height,
|
||||
0.0f, 0.0f,
|
||||
|
@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
dst_width, dst_height};
|
||||
/* clang-format on */
|
||||
|
||||
auto src = mediapipe::formats::MatView(&input);
|
||||
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
|
||||
cv::Mat projection_matrix =
|
||||
cv::getPerspectiveTransform(src_points, dst_points);
|
||||
|
@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
/*flags=*/cv::INTER_LINEAR,
|
||||
/*borderMode=*/border_mode_);
|
||||
|
||||
if (transformed.channels() > kNumChannels) {
|
||||
if (transformed.channels() > output_channels) {
|
||||
cv::Mat proper_channels_mat;
|
||||
cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
|
||||
transformed = proper_channels_mat;
|
||||
|
@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
|
||||
range_min, range_max));
|
||||
transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
|
||||
return tensor;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||
"converter.";
|
||||
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||
<< "Wrong output channel: " << output_shape.dims[3];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
enum cv::BorderTypes border_mode_;
|
||||
Tensor::ElementType tensor_type_;
|
||||
int mat_type_;
|
||||
|
|
|
@ -224,9 +224,6 @@ absl::Status InferenceCalculatorMetalImpl::InitInterpreter(
|
|||
|
||||
void InferenceCalculatorMetalImpl::AddDelegate(
|
||||
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
|
||||
const auto& calculator_opts =
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>();
|
||||
|
||||
// Configure and create the delegate.
|
||||
TFLGpuDelegateOptions options;
|
||||
// `enable_quantization` enables the run of sparse models i.e. the models with
|
||||
|
|
|
@ -21,8 +21,10 @@
|
|||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
|||
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
|
||||
tflite::Interpreter* interpreter,
|
||||
int input_tensor_index) {
|
||||
const char* input_tensor_buffer =
|
||||
input_tensor.GetCpuReadView().buffer<char>();
|
||||
tflite::DynamicBuffer dynamic_buffer;
|
||||
dynamic_buffer.AddString(input_tensor_buffer,
|
||||
input_tensor.shape().num_elements());
|
||||
dynamic_buffer.WriteToTensorAsVector(
|
||||
interpreter->tensor(interpreter->inputs()[input_tensor_index]));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
||||
int output_tensor_index,
|
||||
|
@ -87,12 +102,12 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
|||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteUInt8: {
|
||||
CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
|
||||
CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt8: {
|
||||
CopyTensorBufferToInterpreter<int8>(input_tensors[i],
|
||||
CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
|
@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
|||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteString: {
|
||||
CopyTensorBufferToInterpreter<char>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteBool:
|
||||
// No current use-case for copying MediaPipe Tensors with bool type to
|
||||
// TfLiteTensors.
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported input tensor type:", input_tensor_type));
|
||||
|
@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
|||
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteBool:
|
||||
output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
|
||||
Tensor::QuantizationParameters{1.0f, 0});
|
||||
CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteString:
|
||||
// No current use-case for copying TfLiteTensors with string type to
|
||||
// MediaPipe Tensors.
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported output tensor type:",
|
||||
|
|
174
mediapipe/calculators/tensor/regex_preprocessor_calculator.cc
Normal file
174
mediapipe/calculators/tensor/regex_preprocessor_calculator.cc
Normal file
|
@ -0,0 +1,174 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h"
|
||||
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
|
||||
// Preprocesses input text into one int32 input tensor for a text model using
|
||||
// a RegexTokenizer.
|
||||
//
|
||||
// Inputs:
|
||||
// TEXT - std::string
|
||||
// The input text.
|
||||
// Side Inputs:
|
||||
// METADATA_EXTRACTOR - ModelMetadataExtractor
|
||||
// The metadata extractor for the text model. Used to extract the metadata
|
||||
// to construct the RegexTokenizer.
|
||||
//
|
||||
// Outputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// Vector containing a single Tensor which is the text model's input tensor.
|
||||
// Depending on the tokenizer metadata, the tensor may start with
|
||||
// the id of the tokenizer's <START> token. The following tensor values will
|
||||
// be the ids of the tokens of the input text. Any out-of-vocab tokens will
|
||||
// have the id of the <UNKNOWN> token. The tensor will be padded with the
|
||||
// <PAD> token id to have size equal to the max sequence length for the text
|
||||
// model.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "RegexPreprocessorCalculator"
|
||||
// input_stream: "TEXT:text"
|
||||
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||
// output_stream: "TENSORS:tensors"
|
||||
// options {
|
||||
// [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
|
||||
// max_seq_len: 256
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class RegexPreprocessorCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::string> kTextIn{"TEXT"};
|
||||
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
|
||||
"METADATA_EXTRACTOR"};
|
||||
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<tasks::text::tokenizers::RegexTokenizer> tokenizer_;
|
||||
// The max sequence length accepted by the text model.
|
||||
int max_seq_len_ = 0;
|
||||
};
|
||||
|
||||
absl::Status RegexPreprocessorCalculator::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
|
||||
RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required";
|
||||
RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) {
|
||||
const ModelMetadataExtractor* metadata_extractor =
|
||||
&kMetadataExtractorSideIn(cc).Get();
|
||||
const tflite::TensorMetadata* tensor_metadata =
|
||||
metadata_extractor->GetInputTensorMetadata(0);
|
||||
if (tensor_metadata == nullptr) {
|
||||
return absl::InvalidArgumentError("No tensor metadata found");
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* tokenizer_metadata,
|
||||
metadata_extractor->FindFirstProcessUnit(
|
||||
*tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions));
|
||||
if (tokenizer_metadata == nullptr) {
|
||||
return absl::InvalidArgumentError("No tokenizer metadata found");
|
||||
}
|
||||
const tflite::RegexTokenizerOptions* regex_tokenizer_options =
|
||||
tokenizer_metadata->options_as<tflite::RegexTokenizerOptions>();
|
||||
ASSIGN_OR_RETURN(tokenizer_,
|
||||
tasks::text::tokenizers::CreateRegexTokenizerFromOptions(
|
||||
regex_tokenizer_options, metadata_extractor));
|
||||
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
|
||||
max_seq_len_ = options.max_seq_len();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
|
||||
tasks::text::tokenizers::TokenizerResult tokenizer_result =
|
||||
tokenizer_->Tokenize(kTextIn(cc).Get());
|
||||
|
||||
int unknown_token_id = 0;
|
||||
tokenizer_->GetUnknownToken(&unknown_token_id);
|
||||
int pad_token_id = 0;
|
||||
tokenizer_->GetPadToken(&pad_token_id);
|
||||
|
||||
std::vector<int> input_tokens(max_seq_len_, pad_token_id);
|
||||
int start_token_id = 0;
|
||||
int input_token_index = 0;
|
||||
if (tokenizer_->GetStartToken(&start_token_id)) {
|
||||
input_tokens[0] = start_token_id;
|
||||
input_token_index = 1;
|
||||
}
|
||||
|
||||
for (int i = 0; (i < tokenizer_result.subwords.size()) &&
|
||||
(input_token_index < max_seq_len_);
|
||||
++i, ++input_token_index) {
|
||||
const std::string& token = tokenizer_result.subwords[i];
|
||||
int token_id = 0;
|
||||
if (tokenizer_->LookupId(token, &token_id)) {
|
||||
input_tokens[input_token_index] = token_id;
|
||||
} else {
|
||||
input_tokens[input_token_index] = unknown_token_id;
|
||||
}
|
||||
}
|
||||
|
||||
// |<-------sentence_length-------->|
|
||||
// input_tensor <START>, t1, t2... <PAD>, <PAD>...
|
||||
// <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's
|
||||
// not found in the tokenizer vocab.
|
||||
std::vector<Tensor> result;
|
||||
result.push_back(
|
||||
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
|
||||
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
|
||||
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
|
||||
kTensorsOut(cc).Send(std::move(result));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
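A minimal sketch of exercising this calculator in a standalone CalculatorGraph, following the same pattern as the UniversalSentenceEncoderPreprocessorCalculator test later in this change (the max_seq_len value and the way the metadata extractor is obtained are placeholders, and includes are as in that test file):

```cpp
// Sketch only: assumes `metadata_extractor` was created from the text model's
// metadata with ModelMetadataExtractor::CreateFromModelBuffer, as in the
// preprocessor tests in this change.
absl::StatusOr<std::vector<mediapipe::Tensor>> RunRegexPreprocessor(
    mediapipe::tasks::metadata::ModelMetadataExtractor metadata_extractor,
    const std::string& text) {
  auto graph_config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        input_stream: "text"
        output_stream: "tensors"
        node {
          calculator: "RegexPreprocessorCalculator"
          input_stream: "TEXT:text"
          input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
          output_stream: "TENSORS:tensors"
          options {
            [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
              max_seq_len: 256
            }
          }
        }
      )pb");
  std::vector<mediapipe::Packet> output_packets;
  mediapipe::tool::AddVectorSink("tensors", &graph_config, &output_packets);

  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(
      graph_config,
      {{"metadata_extractor",
        mediapipe::MakePacket<
            mediapipe::tasks::metadata::ModelMetadataExtractor>(
            std::move(metadata_extractor))}}));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text", mediapipe::MakePacket<std::string>(text).At(
                  mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
  RET_CHECK_EQ(output_packets.size(), 1);
  return output_packets[0].Get<std::vector<mediapipe::Tensor>>();
}
```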
|
|
@ -0,0 +1,29 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message RegexPreprocessorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional RegexPreprocessorCalculatorOptions ext = 463716697;
  }

  // The maximum input sequence length for the calculator's text model.
  optional int32 max_seq_len = 1;
}
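For completeness, the generated options API is the usual proto2 one; a short sketch of setting the field that UpdateContract checks in the calculator above:

```cpp
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"

mediapipe::RegexPreprocessorCalculatorOptions MakeRegexPreprocessorOptions() {
  mediapipe::RegexPreprocessorCalculatorOptions options;
  // UpdateContract rejects the node if max_seq_len is unset or <= 0.
  options.set_max_seq_len(256);
  return options;
}
```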
|
|
@ -296,7 +296,6 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) {
|
|||
output_tensors->emplace_back(Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{1, height, width, channels});
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||
command_buffer.label = @"TensorConverterCalculatorConvert";
|
||||
id<MTLComputeCommandEncoder> compute_encoder =
|
||||
|
|
|
@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
|
|||
case Tensor::ElementType::kInt8:
|
||||
Dequantize<int8>(input_tensor, &output_tensors->back());
|
||||
break;
|
||||
case Tensor::ElementType::kBool:
|
||||
Dequantize<bool>(input_tensor, &output_tensors->back());
|
||||
break;
|
||||
default:
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Unsupported input tensor type: ", input_tensor.element_type()));
|
||||
|
|
|
@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
|
|||
ValidateResult(GetOutput(), {-1.007874, 0, 1});
|
||||
}
|
||||
|
||||
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
|
||||
std::vector<bool> tensor = {true, false, true};
|
||||
PushTensor(Tensor::ElementType::kBool, tensor,
|
||||
Tensor::QuantizationParameters{1.0f, 0});
|
||||
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
|
||||
ValidateResult(GetOutput(), {1, 0, 1});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -532,7 +532,6 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
detection_classes.data(),
|
||||
output_detections));
|
||||
#elif MEDIAPIPE_METAL_ENABLED
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
if (!anchors_init_) {
|
||||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
|
||||
constexpr absl::string_view kQueryTextMetadataName = "inp_text";
|
||||
constexpr absl::string_view kResponseContextMetadataName = "res_context";
|
||||
constexpr absl::string_view kResponseTextMetadataName = "res_text";
|
||||
|
||||
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
|
||||
|
||||
// Preprocesses input text into three kTfLiteString input tensors for a
|
||||
// Universal Sentence Encoder (USE) model.
|
||||
//
|
||||
// The associated USE model is expected to contain input tensors with metadata
|
||||
// names:
|
||||
//
|
||||
// Tensor | Metadata Name
|
||||
// ---------------- | ------------------
|
||||
// Query text | "inp_text"
|
||||
// Response context | "res_context"
|
||||
// Response text | "res_text"
|
||||
//
|
||||
// This calculator will return an error if the model does not have three input
|
||||
// tensors or if the tensors do not have metadata names corresponding to the
|
||||
// above names in some order. Additional details regarding these input
|
||||
// tensors are given in the Calculator "Outputs" section below.
|
||||
//
|
||||
// Inputs:
|
||||
// TEXT - std::string
|
||||
// The text to be embedded.
|
||||
// Side Inputs:
|
||||
// METADATA_EXTRACTOR - ModelMetadataExtractor
|
||||
// The metadata extractor for the USE model. Used to determine the order of
|
||||
// the three input Tensors for the USE model.
|
||||
//
|
||||
// Outputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// Vector containing the three input Tensors for the USE model. The tensors
|
||||
// fit a question-answering setting and store a query text, a response
|
||||
// context, and a response text. This calculator will just be preprocessing
|
||||
// a single input text that will be stored in the response text tensor. The
|
||||
// query text and response context tensors will store empty strings.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "UniversalSentenceEncoderPreprocessorCalculator"
|
||||
// input_stream: "TEXT:text"
|
||||
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||
// output_stream: "TENSORS:tensors"
|
||||
// }
|
||||
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::string> kTextIn{"TEXT"};
|
||||
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
|
||||
"METADATA_EXTRACTOR"};
|
||||
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Indices of the three input tensors for the USE model. They should form the
|
||||
// set {0, 1, 2}.
|
||||
int query_text_tensor_index_ = 0;
|
||||
int response_context_tensor_index_ = 1;
|
||||
int response_text_tensor_index_ = 2;
|
||||
|
||||
// Tensor shapes for the model's input tensors.
|
||||
// The query text and response context tensors will only hold the empty
|
||||
// string, so their tensors will have shape [0], but the Universal Sentence
|
||||
// Encoder model's input signature requires them to be present. The response
|
||||
// text tensor will store the embedding text and have shape
|
||||
// [embedding_text_len].
|
||||
std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_;
|
||||
};
|
||||
|
||||
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
const ModelMetadataExtractor* metadata_extractor =
|
||||
&kMetadataExtractorSideIn(cc).Get();
|
||||
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
|
||||
query_text_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kQueryTextMetadataName);
|
||||
response_context_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kResponseContextMetadataName);
|
||||
response_text_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kResponseTextMetadataName);
|
||||
|
||||
absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
|
||||
{query_text_tensor_index_, response_context_tensor_index_,
|
||||
response_text_tensor_index_});
|
||||
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
|
||||
query_text_tensor_index_, response_context_tensor_index_,
|
||||
response_text_tensor_index_));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
absl::string_view text = kTextIn(cc).Get();
|
||||
const int text_len = static_cast<int>(text.length());
|
||||
tensor_shapes_[response_text_tensor_index_] = text_len;
|
||||
|
||||
std::vector<Tensor> input_tensors;
|
||||
input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
|
||||
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
|
||||
input_tensors.push_back(
|
||||
{Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
|
||||
}
|
||||
|
||||
std::memcpy(
|
||||
input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
|
||||
"", 0);
|
||||
std::memcpy(input_tensors[response_context_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
.buffer<char>(),
|
||||
"", 0);
|
||||
std::memcpy(input_tensors[response_text_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
.buffer<char>(),
|
||||
text.data(), text_len * sizeof(char));
|
||||
kTensorsOut(cc).Send(std::move(input_tensors));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/options_map.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::IsOkAndHolds;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
|
||||
|
||||
constexpr absl::string_view kTestModelPath =
|
||||
"mediapipe/tasks/testdata/text/"
|
||||
"universal_sentence_encoder_qa_with_metadata.tflite";
|
||||
|
||||
absl::StatusOr<std::vector<std::string>>
|
||||
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "text"
|
||||
output_stream: "tensors"
|
||||
node {
|
||||
calculator: "UniversalSentenceEncoderPreprocessorCalculator"
|
||||
input_stream: "TEXT:text"
|
||||
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||
output_stream: "TENSORS:tensors"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensors", &graph_config, &output_packets);
|
||||
|
||||
std::string model_buffer =
|
||||
tasks::core::LoadBinaryContent(kTestModelPath.data());
|
||||
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
|
||||
ModelMetadataExtractor::CreateFromModelBuffer(
|
||||
model_buffer.data(), model_buffer.size()));
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_RETURN_IF_ERROR(graph.Initialize(
|
||||
graph_config,
|
||||
{{"metadata_extractor",
|
||||
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
|
||||
MP_RETURN_IF_ERROR(graph.StartRun({}));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"text", MakePacket<std::string>(text).At(Timestamp(0))));
|
||||
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
|
||||
|
||||
if (output_packets.size() != 1) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"output_packets has size $0, expected 1", output_packets.size()));
|
||||
}
|
||||
|
||||
const std::vector<Tensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<Tensor>>();
|
||||
if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"tensor_vec has size $0, expected $1", tensor_vec.size(),
|
||||
kNumInputTensorsForUniversalSentenceEncoder));
|
||||
}
|
||||
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
|
||||
Tensor::ElementType::kChar));
|
||||
}
|
||||
std::vector<std::string> results;
|
||||
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
|
||||
results.push_back(
|
||||
{tensor_vec[i].GetCpuReadView().buffer<char>(),
|
||||
static_cast<size_t>(tensor_vec[i].shape().num_elements())});
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
|
||||
ASSERT_THAT(
|
||||
RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
|
||||
IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -331,6 +331,7 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
|
|
@ -499,7 +499,6 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
|
|||
gpu_data_out_ = absl::make_unique<GPUData>();
|
||||
gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
|
||||
const bool include_alpha = (max_num_channels_ == 4);
|
||||
const bool single_channel = (max_num_channels_ == 1);
|
||||
if (!(format == mediapipe::ImageFormat::GRAY8 ||
|
||||
format == mediapipe::ImageFormat::SRGB ||
|
||||
format == mediapipe::ImageFormat::SRGBA))
|
||||
|
@ -509,6 +508,7 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
|
|||
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
|
||||
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
const bool single_channel = (max_num_channels_ == 1);
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &include_alpha, &input, &single_channel]() -> absl::Status {
|
||||
// Device memory.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -81,6 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
|
||||
#ifdef ABSL_HAVE_MMAP
|
||||
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
|
||||
const auto& model_fd =
|
||||
model_packet.Get<std::tuple<int, size_t, size_t>>();
|
||||
|
@ -89,6 +91,10 @@ class TfLiteModelCalculator : public CalculatorBase {
|
|||
tflite::DefaultErrorReporter());
|
||||
model = tflite::FlatBufferModel::BuildFromAllocation(
|
||||
std::move(model_allocation), tflite::DefaultErrorReporter());
|
||||
#else
return absl::FailedPreconditionError(
    "Loading by file descriptor is not supported on this platform.");
#endif
|
||||
}
|
||||
|
||||
RET_CHECK(model) << "Failed to load TfLite model from blob.";
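For reference, the MODEL_FD side packet read above is a (file descriptor, offset, length) tuple; a hedged sketch of producing one (the path argument and error handling are placeholders):

```cpp
#include <fcntl.h>
#include <sys/stat.h>

#include <cstddef>
#include <tuple>

#include "mediapipe/framework/packet.h"

// Builds the MODEL_FD side packet consumed above: file descriptor, byte
// offset of the model within the file, and model size in bytes.
mediapipe::Packet MakeModelFdPacket(const char* model_path) {
  const int fd = open(model_path, O_RDONLY);
  struct stat st;
  fstat(fd, &st);
  return mediapipe::MakePacket<std::tuple<int, size_t, size_t>>(
      fd, size_t{0}, static_cast<size_t>(st.st_size));
}
```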
|
||||
|
|
|
@ -143,9 +143,7 @@ mediapipe_proto_library(
|
|||
cc_library(
|
||||
name = "packet_frequency_calculator",
|
||||
srcs = ["packet_frequency_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/util:packet_frequency_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:packet_frequency_cc_proto",
|
||||
|
@ -190,9 +188,7 @@ cc_test(
|
|||
cc_library(
|
||||
name = "packet_latency_calculator",
|
||||
srcs = ["packet_latency_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/util:latency_cc_proto",
|
||||
"//mediapipe/calculators/util:packet_latency_calculator_cc_proto",
|
||||
|
|
|
@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
text->set_left(label_left_px_);
|
||||
text->set_baseline(label_baseline_px + i * label_height_px_);
|
||||
text->set_font_face(options_.font_face());
|
||||
if (options_.outline_thickness() > 0) {
|
||||
text->set_outline_thickness(options_.outline_thickness());
|
||||
if (options_.outline_color_size() > 0) {
|
||||
*(text->mutable_outline_color()) =
|
||||
options_.outline_color(i % options_.outline_color_size());
|
||||
} else {
|
||||
text->mutable_outline_color()->set_r(0);
|
||||
text->mutable_outline_color()->set_g(0);
|
||||
text->mutable_outline_color()->set_b(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag(kRenderDataTag)
|
||||
|
|
|
@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions {
|
|||
  // Thickness for drawing the label(s).
  optional double thickness = 2 [default = 2];

  // Color of outline around each character, if any. One per label, as with
  // color attribute.
  repeated Color outline_color = 12;

  // Thickness of outline around each character.
  optional double outline_thickness = 11;

  // The font height in absolute pixels.
  optional int32 font_height_px = 3 [default = 50];
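A short sketch of filling in the new outline fields from C++ (the include path is assumed; the r/g/b setters match the calculator hunk above):

```cpp
#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"

mediapipe::LabelsToRenderDataCalculatorOptions MakeOutlinedLabelOptions() {
  mediapipe::LabelsToRenderDataCalculatorOptions options;
  options.set_outline_thickness(2.0);
  // One outline color per label; the calculator cycles through them.
  auto* outline = options.add_outline_color();
  outline->set_r(255);
  outline->set_g(255);
  outline->set_b(255);
  return options;
}
```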
|
|
|
@ -18,7 +18,7 @@ import android.content.ClipDescription;
|
|||
import android.content.Context;
|
||||
import android.net.Uri;
|
||||
import android.os.Bundle;
|
||||
import androidx.appcompat.widget.AppCompatEditText;
|
||||
import android.support.v7.widget.AppCompatEditText;
|
||||
import android.util.AttributeSet;
|
||||
import android.util.Log;
|
||||
import android.view.inputmethod.EditorInfo;
|
||||
|
|
|
@ -1685,10 +1685,3 @@ cc_test(
|
|||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
# Expose the proto source files for building mediapipe AAR.
|
||||
filegroup(
|
||||
name = "protos_src",
|
||||
srcs = glob(["*.proto"]),
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
|
|
@ -14,15 +14,10 @@ cc_library(
|
|||
name = "builder",
|
||||
hdrs = ["builder.h"],
|
||||
deps = [
|
||||
":const_str",
|
||||
":contract",
|
||||
":node",
|
||||
":packet",
|
||||
":port",
|
||||
"//mediapipe/framework:calculator_base",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"@com_google_absl//absl/container:btree",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -5,12 +5,7 @@
|
|||
#include <type_traits>
|
||||
|
||||
#include "absl/container/btree_map.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/api2/const_str.h"
|
||||
#include "mediapipe/framework/api2/contract.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_base.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
|
@ -112,6 +107,17 @@ class MultiPort : public Single {
|
|||
std::vector<std::unique_ptr<Base>>& vec_;
|
||||
};
|
||||
|
||||
namespace internal_builder {
|
||||
|
||||
template <typename T, typename U>
|
||||
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
|
||||
!std::is_same_v<T, U>>;
|
||||
|
||||
} // namespace internal_builder
|
||||
|
||||
template <bool IsSide, typename T = internal::Generic>
|
||||
class SourceImpl;
|
||||
|
||||
// These classes wrap references to the underlying source/destination
|
||||
// endpoints, adding type information and the user-visible API.
|
||||
template <bool IsSide, typename T = internal::Generic>
|
||||
|
@ -122,16 +128,21 @@ class DestinationImpl {
|
|||
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
|
||||
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
|
||||
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
|
||||
|
||||
template <typename U,
|
||||
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
|
||||
DestinationImpl<IsSide, U> Cast() {
|
||||
return DestinationImpl<IsSide, U>(&base_);
|
||||
}
|
||||
|
||||
private:
|
||||
DestinationBase& base_;
|
||||
|
||||
template <bool Source_IsSide, typename Source_T>
|
||||
friend class SourceImpl;
|
||||
};
|
||||
|
||||
template <bool IsSide, typename T>
|
||||
class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> {
|
||||
public:
|
||||
using MultiPort<DestinationImpl<IsSide, T>>::MultiPort;
|
||||
};
|
||||
|
||||
template <bool IsSide, typename T = internal::Generic>
|
||||
class SourceImpl {
|
||||
public:
|
||||
using Base = SourceBase;
|
||||
|
@ -171,12 +182,8 @@ class SourceImpl {
|
|||
return AddTarget(dest);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
struct AllowCast
|
||||
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
|
||||
!std::is_same_v<T, U>> {};
|
||||
|
||||
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
|
||||
template <typename U,
|
||||
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
|
||||
SourceImpl<IsSide, U> Cast() {
|
||||
return SourceImpl<IsSide, U>(base_);
|
||||
}
|
||||
|
@ -186,12 +193,6 @@ class SourceImpl {
|
|||
SourceBase* base_;
|
||||
};
|
||||
|
||||
template <bool IsSide, typename T>
|
||||
class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
|
||||
public:
|
||||
using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
|
||||
};
|
||||
|
||||
// A source and a destination correspond to an output/input stream on a node,
|
||||
// and a side source and side destination correspond to an output/input side
|
||||
// packet.
|
||||
|
@ -201,20 +202,20 @@ class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
|
|||
template <typename T = internal::Generic>
|
||||
using Source = SourceImpl<false, T>;
|
||||
template <typename T = internal::Generic>
|
||||
using MultiSource = MultiSourceImpl<false, T>;
|
||||
using MultiSource = MultiPort<Source<T>>;
|
||||
template <typename T = internal::Generic>
|
||||
using SideSource = SourceImpl<true, T>;
|
||||
template <typename T = internal::Generic>
|
||||
using MultiSideSource = MultiSourceImpl<true, T>;
|
||||
using MultiSideSource = MultiPort<SideSource<T>>;
|
||||
|
||||
template <typename T = internal::Generic>
|
||||
using Destination = DestinationImpl<false, T>;
|
||||
template <typename T = internal::Generic>
|
||||
using SideDestination = DestinationImpl<true, T>;
|
||||
template <typename T = internal::Generic>
|
||||
using MultiDestination = MultiDestinationImpl<false, T>;
|
||||
using MultiDestination = MultiPort<Destination<T>>;
|
||||
template <typename T = internal::Generic>
|
||||
using MultiSideDestination = MultiDestinationImpl<true, T>;
|
||||
using MultiSideDestination = MultiPort<SideDestination<T>>;
|
||||
|
||||
class NodeBase {
|
||||
public:
|
||||
|
@ -439,8 +440,9 @@ class Graph {
|
|||
// Creates a node of a specific type. Should be used for pure interfaces,
|
||||
// which do not have a built-in type string.
|
||||
template <class Calc>
|
||||
Node<Calc>& AddNode(const std::string& type) {
|
||||
auto node = std::make_unique<Node<Calc>>(type);
|
||||
Node<Calc>& AddNode(absl::string_view type) {
|
||||
auto node =
|
||||
std::make_unique<Node<Calc>>(std::string(type.data(), type.size()));
|
||||
auto node_p = node.get();
|
||||
nodes_.emplace_back(std::move(node));
|
||||
return *node_p;
|
||||
|
@ -448,16 +450,18 @@ class Graph {
|
|||
|
||||
// Creates a generic node, with no compile-time checking of inputs and
|
||||
// outputs. This can be used for calculators whose contract is not visible.
|
||||
GenericNode& AddNode(const std::string& type) {
|
||||
auto node = std::make_unique<GenericNode>(type);
|
||||
GenericNode& AddNode(absl::string_view type) {
|
||||
auto node =
|
||||
std::make_unique<GenericNode>(std::string(type.data(), type.size()));
|
||||
auto node_p = node.get();
|
||||
nodes_.emplace_back(std::move(node));
|
||||
return *node_p;
|
||||
}
|
||||
|
||||
// For legacy PacketGenerators.
|
||||
PacketGenerator& AddPacketGenerator(const std::string& type) {
|
||||
auto node = std::make_unique<PacketGenerator>(type);
|
||||
PacketGenerator& AddPacketGenerator(absl::string_view type) {
|
||||
auto node = std::make_unique<PacketGenerator>(
|
||||
std::string(type.data(), type.size()));
|
||||
auto node_p = node.get();
|
||||
packet_gens_.emplace_back(std::move(node));
|
||||
return *node_p;
|
||||
|
|
|
@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
|||
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
|
||||
any_type_output.SetName("any_type_output");
|
||||
|
||||
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
|
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
|||
output_stream: "ANY_OUTPUT:any_type_output"
|
||||
}
|
||||
input_stream: "GRAPH_ANY_INPUT:__stream_0"
|
||||
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
|
|
@ -334,13 +334,6 @@ mediapipe_register_type(
|
|||
deps = [":landmark_cc_proto"],
|
||||
)
|
||||
|
||||
# Expose the proto source files for building mediapipe AAR.
|
||||
filegroup(
|
||||
name = "protos_src",
|
||||
srcs = glob(["*.proto"]),
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "image",
|
||||
srcs = ["image.cc"],
|
||||
|
@ -469,6 +462,10 @@ cc_library(
|
|||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
defines = select({
|
||||
"//mediapipe/framework:android_no_jni": ["MEDIAPIPE_NO_JNI"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = select({
|
||||
"//mediapipe:ios": [
|
||||
"-framework CoreVideo",
|
||||
|
|
|
@ -33,10 +33,3 @@ mediapipe_proto_library(
|
|||
srcs = ["rasterization.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Expose the proto source files for building mediapipe AAR.
|
||||
filegroup(
|
||||
name = "protos_src",
|
||||
srcs = glob(["*.proto"]),
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
|
|
@ -31,11 +31,12 @@
|
|||
#if MEDIAPIPE_METAL_ENABLED
|
||||
#import <Metal/Metal.h>
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
|
||||
#ifndef MEDIAPIPE_NO_JNI
|
||||
#if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
|
||||
#define MEDIAPIPE_TENSOR_USE_AHWB 1
|
||||
#endif // __ANDROID_API__ >= 26 ||
|
||||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
|
||||
#endif // MEDIAPIPE_NO_JNI
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#include <android/hardware_buffer.h>
|
||||
|
@ -43,7 +44,6 @@
|
|||
#include "third_party/GL/gl/include/EGL/egl.h"
|
||||
#include "third_party/GL/gl/include/EGL/eglext.h"
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
#include "mediapipe/gpu/gl_context.h"
|
||||
|
@ -97,8 +97,8 @@ class Tensor {
|
|||
kUInt8,
|
||||
kInt8,
|
||||
kInt32,
|
||||
// TODO: Update the inference runner to handle kTfLiteString.
|
||||
kChar
|
||||
kChar,
|
||||
kBool
|
||||
};
|
||||
struct Shape {
|
||||
Shape() = default;
|
||||
|
@ -330,6 +330,8 @@ class Tensor {
|
|||
return sizeof(int32_t);
|
||||
case ElementType::kChar:
|
||||
return sizeof(char);
|
||||
case ElementType::kBool:
|
||||
return sizeof(bool);
|
||||
}
|
||||
}
|
||||
int bytes() const { return shape_.num_elements() * element_size(); }
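With kBool included in element_size(), a bool tensor sizes and reads/writes like the other element types; a small sketch mirroring the new TestDataTypes case further down (assuming the templated buffer&lt;T&gt;() view accessor used elsewhere in this change):

```cpp
#include "mediapipe/framework/formats/tensor.h"

void BoolTensorSketch() {
  mediapipe::Tensor t_bool(mediapipe::Tensor::ElementType::kBool,
                           mediapipe::Tensor::Shape{2, 3});
  // bytes() == num_elements() * sizeof(bool), exactly as the new test checks.
  auto view = t_bool.GetCpuWriteView();
  bool* data = view.buffer<bool>();
  for (int i = 0; i < t_bool.shape().num_elements(); ++i) {
    data[i] = (i % 2 == 0);
  }
}
```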
|
||||
|
|
|
@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
|
|||
if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
|
||||
// EGLSync is failed. Use another synchronization method.
|
||||
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
|
||||
glFinish();
|
||||
gl_context_->Run([]() { glFinish(); });
|
||||
} else if (valid_ & kValidAHardwareBuffer) {
|
||||
CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
|
||||
"completion function to be set";
|
||||
|
|
|
@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
|
|||
|
||||
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
|
||||
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
|
||||
|
||||
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
|
||||
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
|
||||
}
|
||||
|
||||
TEST(Cpu, TestMemoryAllocation) {
|
||||
|
|
|
@ -64,7 +64,7 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
|
|||
std::vector<TraceEventType> basic_types = {
|
||||
{TraceEvent::UNKNOWN, "An uninitialized trace-event."},
|
||||
{TraceEvent::OPEN, "A call to Calculator::Open.", true, true},
|
||||
{TraceEvent::PROCESS, "A call to Calculator::Open.", true, true},
|
||||
{TraceEvent::PROCESS, "A call to Calculator::Process.", true, true},
|
||||
{TraceEvent::CLOSE, "A call to Calculator::Close.", true, true},
|
||||
|
||||
{TraceEvent::NOT_READY, "A calculator cannot process packets yet."},
|
||||
|
|
|
@ -150,7 +150,7 @@ cc_library(
|
|||
name = "executor_util",
|
||||
srcs = ["executor_util.cc"],
|
||||
hdrs = ["executor_util.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:mediapipe_options_cc_proto",
|
||||
|
|
|
@ -1050,7 +1050,7 @@ objc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
MIN_IOS_VERSION = "9.0" # For thread_local.
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
test_suite(
|
||||
name = "ios",
|
||||
|
|
|
@ -111,7 +111,8 @@ typedef CVOpenGLESTextureCacheRef CVTextureCacheType;
|
|||
- (CVMetalTextureCacheRef)mtlTextureCache {
|
||||
@synchronized(self) {
|
||||
if (!_mtlTextureCache) {
|
||||
CVReturn err = CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
|
||||
CVReturn __unused err =
|
||||
CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
|
||||
NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err);
|
||||
// TODO: register and flush metal caches too.
|
||||
}
|
||||
|
|
|
@ -38,10 +38,6 @@ static pthread_key_t egl_release_thread_key;
|
|||
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
|
||||
|
||||
static void EglThreadExitCallback(void* key_value) {
|
||||
#if defined(__ANDROID__)
|
||||
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
|
||||
EGL_NO_CONTEXT);
|
||||
#else
|
||||
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
|
||||
// parameter for eglMakeCurrent. This behavior is not portable to all EGL
|
||||
// implementations, and should be considered as an undocumented vendor
|
||||
|
@ -49,7 +45,6 @@ static void EglThreadExitCallback(void* key_value) {
|
|||
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
|
||||
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
|
||||
EGL_NO_SURFACE, EGL_NO_CONTEXT);
|
||||
#endif
|
||||
eglReleaseThread();
|
||||
}
|
||||
|
||||
|
|
|
@ -144,14 +144,23 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
|
|||
context](std::shared_ptr<GlSyncPoint> sync_token) {
|
||||
CHECK_NE(name_, 0);
|
||||
GLuint name_to_delete = name_;
|
||||
context->RunWithoutWaiting([name_to_delete, sync_token]() {
|
||||
if (sync_token) {
|
||||
// TODO: maybe we do not actually have to wait for the
|
||||
// consumer sync here. Check docs.
|
||||
sync_token->WaitOnGpu();
|
||||
} else {
|
||||
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback";
|
||||
}
|
||||
context->RunWithoutWaiting([name_to_delete]() {
|
||||
// Note that we do not wait for consumers to be done before deleting the
|
||||
// texture. Based on a reading of the GLES 3.0 spec, appendix D:
|
||||
// - when a texture is deleted, it is _not_ automatically unbound from
|
||||
// bind points in other contexts;
|
||||
// - when a texture is deleted, its name becomes immediately invalid, but
|
||||
// the actual object is not deleted until it is no longer in use, i.e.
|
||||
// attached to a container object or bound to a context;
|
||||
// - deleting an object is not an operation that changes its contents;
|
||||
// - within each context, commands are executed sequentially, so it seems
|
||||
// like an unbind that follows a command that reads a texture should not
|
||||
// take effect until the GPU has actually finished executing the
|
||||
// previous commands.
|
||||
// The final point is the least explicit in the docs, but it is implied by
|
||||
// normal single-context behavior. E.g. if you do bind, delete, render,
|
||||
// unbind, the object is not deleted until the unbind, and it waits for
|
||||
// the render to finish.
|
||||
DLOG_IF(ERROR, !glIsTexture(name_to_delete))
|
||||
<< "Deleting invalid texture id: " << name_to_delete;
|
||||
glDeleteTextures(1, &name_to_delete);
|
||||
|
@ -185,7 +194,10 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
|
|||
<< "Updated existing texture which had not been marked for reuse!";
|
||||
CHECK(prod_token);
|
||||
producer_sync_ = std::move(prod_token);
|
||||
producer_context_ = producer_sync_->GetContext();
|
||||
const auto& synced_context = producer_sync_->GetContext();
|
||||
if (synced_context) {
|
||||
producer_context_ = synced_context;
|
||||
}
|
||||
}
|
||||
|
||||
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
|
||||
|
|
|
@ -34,6 +34,7 @@ android_library(
|
|||
android_library(
|
||||
name = "android_framework_no_mff",
|
||||
proguard_specs = [":proguard.pgcfg"],
|
||||
visibility = ["//visibility:public"],
|
||||
exports = [
|
||||
":android_framework_no_proguard",
|
||||
],
|
||||
|
|
|
@ -30,3 +30,10 @@ android_library(
|
|||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
# Expose the java source files for building mediapipe AAR.
|
||||
filegroup(
|
||||
name = "java_src",
|
||||
srcs = glob(["*.java"]),
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
|
|
@ -89,10 +89,6 @@ def mediapipe_aar(
|
|||
calculators = calculators,
|
||||
)
|
||||
|
||||
_mediapipe_proto(
|
||||
name = name + "_proto",
|
||||
)
|
||||
|
||||
native.genrule(
|
||||
name = name + "_aar_manifest_generator",
|
||||
outs = ["AndroidManifest.xml"],
|
||||
|
@ -115,19 +111,10 @@ EOF
|
|||
"//mediapipe/java/com/google/mediapipe/components:java_src",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:java_src",
|
||||
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
|
||||
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
|
||||
"com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
"com/google/mediapipe/formats/proto/DetectionProto.java",
|
||||
"com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
"com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
"com/google/mediapipe/proto/CalculatorProto.java",
|
||||
] +
|
||||
] + mediapipe_java_proto_srcs() +
|
||||
select({
|
||||
"//conditions:default": [],
|
||||
"enable_stats_logging": [
|
||||
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
|
||||
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
|
||||
],
|
||||
"enable_stats_logging": mediapipe_logging_java_proto_srcs(),
|
||||
}),
|
||||
manifest = "AndroidManifest.xml",
|
||||
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
|
||||
|
@ -177,93 +164,9 @@ EOF
|
|||
assets_dir = assets_dir,
|
||||
)
|
||||
|
||||
_aar_with_jni(name, name + "_android_lib")
|
||||
|
||||
def _mediapipe_proto(name):
|
||||
"""Generates MediaPipe java proto libraries.
|
||||
|
||||
Args:
|
||||
name: the name of the target.
|
||||
"""
|
||||
_proto_java_src_generator(
|
||||
name = "mediapipe_log_extension_proto",
|
||||
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
|
||||
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
|
||||
srcs = ["//mediapipe/util/analytics:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "mediapipe_logging_enums_proto",
|
||||
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
|
||||
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
|
||||
srcs = ["//mediapipe/util/analytics:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "calculator_proto",
|
||||
proto_src = "mediapipe/framework/calculator.proto",
|
||||
java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
|
||||
srcs = ["//mediapipe/framework:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "landmark_proto",
|
||||
proto_src = "mediapipe/framework/formats/landmark.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
srcs = ["//mediapipe/framework/formats:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "rasterization_proto",
|
||||
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
|
||||
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "location_data_proto",
|
||||
proto_src = "mediapipe/framework/formats/location_data.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
srcs = [
|
||||
"//mediapipe/framework/formats:protos_src",
|
||||
"//mediapipe/framework/formats/annotation:protos_src",
|
||||
],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "detection_proto",
|
||||
proto_src = "mediapipe/framework/formats/detection.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
|
||||
srcs = [
|
||||
"//mediapipe/framework/formats:protos_src",
|
||||
"//mediapipe/framework/formats/annotation:protos_src",
|
||||
],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "classification_proto",
|
||||
proto_src = "mediapipe/framework/formats/classification.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
srcs = [
|
||||
"//mediapipe/framework/formats:protos_src",
|
||||
],
|
||||
)
|
||||
|
||||
def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
|
||||
native.genrule(
|
||||
name = name + "_proto_java_src_generator",
|
||||
srcs = srcs + [
|
||||
"@com_google_protobuf//:lite_well_known_protos",
|
||||
],
|
||||
outs = [java_lite_out],
|
||||
cmd = "$(location @com_google_protobuf//:protoc) " +
|
||||
"--proto_path=. --proto_path=$(GENDIR) " +
|
||||
"--proto_path=$$(pwd)/external/com_google_protobuf/src " +
|
||||
"--java_out=lite:$(GENDIR) " + proto_src + " && " +
|
||||
"mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
|
||||
tools = [
|
||||
"@com_google_protobuf//:protoc",
|
||||
],
|
||||
mediapipe_build_aar_with_jni(
|
||||
name = name,
|
||||
android_library = name + "_android_lib",
|
||||
)
|
||||
|
||||
def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
|
||||
|
@ -303,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
def _aar_with_jni(name, android_library):
|
||||
def mediapipe_build_aar_with_jni(name, android_library):
|
||||
"""Builds MediaPipe AAR with jni.
|
||||
|
||||
Args:
|
||||
name: The bazel target name.
|
||||
android_library: the android library that contains jni.
|
||||
"""
|
||||
|
||||
# Generates dummy AndroidManifest.xml for dummy apk usage
|
||||
# (dummy apk is generated by <name>_dummy_app target below)
|
||||
native.genrule(
|
||||
|
@ -314,7 +224,7 @@ cat > $(OUTS) <<EOF
|
|||
<manifest
|
||||
xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="dummy.package.for.so">
|
||||
<uses-sdk android:minSdkVersion="21"/>
|
||||
<uses-sdk android:minSdkVersion="24"/>
|
||||
</manifest>
|
||||
EOF
|
||||
""",
|
||||
|
@ -341,7 +251,128 @@ chmod +w $(location :{}.aar)
|
|||
origdir=$$PWD
|
||||
cd $$(mktemp -d)
|
||||
unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
|
||||
find lib -name *_dummy_app.so -delete
|
||||
cp -r lib jni
|
||||
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
|
||||
""".format(android_library, name, name, name, name),
|
||||
)
|
||||
|
||||
def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
|
||||
"""Extracts the generated MediaPipe java proto source code from the target.
|
||||
|
||||
Args:
|
||||
target: The java proto lite target to be built and extracted.
|
||||
src_out: The output java proto src code path.
|
||||
name: The optional bazel target name.
|
||||
|
||||
Returns:
|
||||
The output java proto src code path.
|
||||
"""
|
||||
|
||||
if not name:
|
||||
name = target.split(":")[-1] + "_proto_java_src_extractor"
|
||||
src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
|
||||
native.genrule(
|
||||
name = name + "_proto_java_src_extractor",
|
||||
srcs = [target],
|
||||
outs = [src_out],
|
||||
cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
|
||||
src_out + " $$(dirname $(location " + src_out + "))",
|
||||
)
|
||||
return src_out
|
||||
|
||||
def mediapipe_java_proto_srcs(name = ""):
|
||||
"""Extracts the generated MediaPipe framework java proto source code.
|
||||
|
||||
Args:
|
||||
name: The optional bazel target name.
|
||||
|
||||
Returns:
|
||||
The list of the extracted MediaPipe java proto source code.
|
||||
"""
|
||||
|
||||
proto_src_list = []
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:calculator_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:stream_handler_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/StreamHandlerProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:packet_factory_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/PacketFactoryProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:packet_generator_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:status_handler_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/StatusHandlerProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:mediapipe_options_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:detection_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
))
|
||||
return proto_src_list
|
||||
|
||||
def mediapipe_logging_java_proto_srcs(name = ""):
|
||||
"""Extracts the generated logging-related MediaPipe java proto source code.
|
||||
|
||||
Args:
|
||||
name: The optional bazel target name.
|
||||
|
||||
Returns:
|
||||
The list of the extracted MediaPipe logging-related java proto source code.
|
||||
"""
|
||||
|
||||
proto_src_list = []
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
|
||||
))
|
||||
return proto_src_list
|
||||
|
|
mediapipe/model_maker/BUILD (new file, 22 lines)
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//mediapipe/model_maker/...",
|
||||
],
|
||||
)
|
mediapipe/model_maker/python/BUILD (new file, 22 lines)
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//mediapipe/model_maker/...",
|
||||
],
|
||||
)
mediapipe/model_maker/python/core/BUILD (new file, 26 lines)
@@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library compatibility macro.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
)
mediapipe/model_maker/python/core/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
mediapipe/model_maker/python/core/data/BUILD (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "data_util",
|
||||
srcs = ["data_util.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "data_util_test",
|
||||
srcs = ["data_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/data/testdata"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":data_util"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "classification_dataset",
|
||||
srcs = ["classification_dataset.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "classification_dataset_test",
|
||||
srcs = ["classification_dataset_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":classification_dataset"],
|
||||
)
mediapipe/model_maker/python/core/data/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
mediapipe/model_maker/python/core/data/classification_dataset.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common classification dataset library."""
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
|
||||
|
||||
class ClassificationDataset(ds.Dataset):
|
||||
"""DataLoader for classification models."""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
|
||||
super().__init__(dataset, size)
|
||||
self.index_to_label = index_to_label
|
||||
|
||||
@property
|
||||
def num_classes(self: ds._DatasetT) -> int:
|
||||
return len(self.index_to_label)
|
||||
|
||||
def split(self: ds._DatasetT,
|
||||
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
|
||||
"""Splits dataset into two sub-datasets with the given fraction.
|
||||
|
||||
Primarily used for splitting the data set into training and testing sets.
|
||||
|
||||
Args:
|
||||
fraction: float, the fraction of the original data that goes into the
  first returned sub-dataset.
|
||||
|
||||
Returns:
|
||||
The two split sub-datasets.
|
||||
"""
|
||||
return self._split(fraction, self.index_to_label)
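A minimal usage sketch of the `ClassificationDataset.split` API added above; the toy tensors, labels, and the 0.75 fraction are illustrative values, not part of this change:

```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Toy data: four feature vectors with two label classes (illustrative only).
features = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = classification_dataset.ClassificationDataset(
    dataset=features, size=4, index_to_label=['negative', 'positive'])

train_data, test_data = data.split(0.75)
print(data.num_classes)                 # 2
print(len(train_data), len(test_data))  # 3 1
```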
mediapipe/model_maker/python/core/data/classification_dataset_test.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
|
||||
|
||||
class ClassificationDataLoaderTest(tf.test.TestCase):
|
||||
|
||||
def test_split(self):
|
||||
|
||||
class MagicClassificationDataLoader(
|
||||
classification_dataset.ClassificationDataset):
|
||||
|
||||
def __init__(self, dataset, size, index_to_label, value):
|
||||
super(MagicClassificationDataLoader,
|
||||
self).__init__(dataset, size, index_to_label)
|
||||
self.value = value
|
||||
|
||||
def split(self, fraction):
|
||||
return self._split(fraction, self.index_to_label, self.value)
|
||||
|
||||
# Some dummy inputs.
|
||||
magic_value = 42
|
||||
num_classes = 2
|
||||
index_to_label = (False, True)
|
||||
|
||||
# Create data loader from sample data.
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
|
||||
magic_value)
|
||||
|
||||
# Train/Test data split.
|
||||
fraction = .25
|
||||
train_data, test_data = data.split(fraction)
|
||||
|
||||
# `split` should return instances of child DataLoader.
|
||||
self.assertIsInstance(train_data, MagicClassificationDataLoader)
|
||||
self.assertIsInstance(test_data, MagicClassificationDataLoader)
|
||||
|
||||
# Make sure number of entries are right.
|
||||
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
|
||||
self.assertLen(train_data, fraction * len(ds))
|
||||
self.assertLen(test_data, len(ds) - len(train_data))
|
||||
|
||||
# Make sure attributes propagated correctly.
|
||||
self.assertEqual(train_data.num_classes, num_classes)
|
||||
self.assertEqual(test_data.index_to_label, index_to_label)
|
||||
self.assertEqual(train_data.value, magic_value)
|
||||
self.assertEqual(test_data.value, magic_value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
mediapipe/model_maker/python/core/data/data_util.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Data utility library."""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def load_image(path: str) -> np.ndarray:
|
||||
"""Loads an image as an RGB numpy array.
|
||||
|
||||
Args:
|
||||
path: input image file absolute path.
|
||||
|
||||
Returns:
|
||||
An RGB image in numpy.ndarray.
|
||||
"""
|
||||
tf.compat.v1.logging.info('Loading RGB image %s', path)
|
||||
# TODO Replace the OpenCV image load and conversion library by
|
||||
# MediaPipe image utility library once it is ready.
|
||||
image = cv2.imread(path)
|
||||
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
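A short, hedged sketch of calling the `data_util.load_image` helper defined above; the file path is hypothetical:

```
from mediapipe.model_maker.python.core.data import data_util

# Hypothetical absolute path to a JPEG on disk.
rgb_image = data_util.load_image('/tmp/example.jpg')
print(rgb_image.shape)  # (height, width, 3), channels in RGB order
```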
mediapipe/model_maker/python/core/data/data_util_test.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
# Dependency imports
|
||||
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import data_util
|
||||
|
||||
_WORKSPACE = "mediapipe"
|
||||
_TEST_DATA_DIR = os.path.join(
|
||||
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class DataUtilTest(tf.test.TestCase):
|
||||
|
||||
def test_load_rgb_image(self):
|
||||
image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
|
||||
image_data = data_util.load_image(image_path)
|
||||
self.assertEqual(image_data.shape, (5184, 3456, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
mediapipe/model_maker/python/core/data/dataset.py (new file, 164 lines)
@@ -0,0 +1,164 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common dataset for model training and evaluation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
from typing import Callable, Optional, Tuple, TypeVar
|
||||
|
||||
# Dependency imports
|
||||
import tensorflow as tf
|
||||
|
||||
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
"""A generic dataset class for loading model training and evaluation dataset.
|
||||
|
||||
For each ML task, such as image classification, text classification etc., a
|
||||
subclass can be derived from this class to provide task-specific data loading
|
||||
utilities.
|
||||
"""
|
||||
|
||||
def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
|
||||
"""Initializes Dataset class.
|
||||
|
||||
To build dataset from raw data, consider using the task specific utilities,
|
||||
e.g. from_folder().
|
||||
|
||||
Args:
|
||||
tf_dataset: A tf.data.Dataset object that contains a potentially large set
|
||||
of elements, where each element is a pair of (input_data, target). The
|
||||
`input_data` means the raw input data, like an image, a text etc., while
|
||||
the `target` means the ground truth of the raw input data, e.g. the
|
||||
classification label of the image etc.
|
||||
size: The size of the dataset. tf.data.Dataset doesn't support a function
|
||||
to get the length directly since it's lazy-loaded and may be infinite.
|
||||
"""
|
||||
self._dataset = tf_dataset
|
||||
self._size = size
|
||||
|
||||
@property
|
||||
def size(self) -> Optional[int]:
|
||||
"""Returns the size of the dataset.
|
||||
|
||||
Note that this function may return None because the exact size of the
|
||||
dataset isn't a necessary parameter to create an instance of this class,
|
||||
and tf.data.Dataset doesn't support a function to get the length directly
|
||||
since it's lazy-loaded and may be infinite.
|
||||
In most cases, however, when an instance of this class is created by helper
|
||||
functions like 'from_folder', the size of the dataset will be preprocessed,
|
||||
and this function can return an int representing the size of the dataset.
|
||||
"""
|
||||
return self._size
|
||||
|
||||
def gen_tf_dataset(self,
|
||||
batch_size: int = 1,
|
||||
is_training: bool = False,
|
||||
shuffle: bool = False,
|
||||
preprocess: Optional[Callable[..., bool]] = None,
|
||||
drop_remainder: bool = False) -> tf.data.Dataset:
|
||||
"""Generates a batched tf.data.Dataset for training/evaluation.
|
||||
|
||||
Args:
|
||||
batch_size: An integer, the returned dataset will be batched by this size.
|
||||
is_training: A boolean, when True, the returned dataset will be optionally
|
||||
shuffled and repeated as an endless dataset.
|
||||
shuffle: A boolean, when True, the returned dataset will be shuffled to
|
||||
create randomness during model training.
|
||||
preprocess: A function taking three arguments in order, feature, label and
|
||||
boolean is_training.
|
||||
drop_remainder: A boolean, whether to drop the final batch if it is incomplete.
|
||||
|
||||
Returns:
|
||||
A TF dataset ready to be consumed by Keras model.
|
||||
"""
|
||||
dataset = self._dataset
|
||||
|
||||
if preprocess:
|
||||
preprocess = functools.partial(preprocess, is_training=is_training)
|
||||
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
|
||||
if is_training:
|
||||
if shuffle:
|
||||
# Shuffle size should be bigger than the batch_size. Otherwise it's only
|
||||
# shuffling within the batch, which equals to not having shuffle.
|
||||
buffer_size = 3 * batch_size
|
||||
# But since we are doing shuffle before repeat, it doesn't make sense to
|
||||
# shuffle more than total available entries.
|
||||
# TODO: Investigate if shuffling before / after repeat
|
||||
# dataset can get a better performance?
|
||||
# Shuffle after repeat will give a more randomized dataset and mix the
|
||||
# epoch boundary: https://www.tensorflow.org/guide/data
|
||||
if self._size:
|
||||
buffer_size = min(self._size, buffer_size)
|
||||
dataset = dataset.shuffle(buffer_size=buffer_size)
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
|
||||
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
||||
# TODO: Consider converting dataset to distributed dataset
|
||||
# here.
|
||||
return dataset
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of element of the dataset."""
|
||||
if self._size is not None:
|
||||
return self._size
|
||||
else:
|
||||
return len(self._dataset)
|
||||
|
||||
def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
|
||||
"""Splits dataset into two sub-datasets with the given fraction.
|
||||
|
||||
Primarily used for splitting the data set into training and testing sets.
|
||||
|
||||
Args:
|
||||
fraction: A float value that defines the fraction of the original data
  that goes into the first returned sub-dataset.
|
||||
|
||||
Returns:
|
||||
The two split sub-datasets.
|
||||
"""
|
||||
return self._split(fraction)
|
||||
|
||||
def _split(self: _DatasetT, fraction: float,
|
||||
*args) -> Tuple[_DatasetT, _DatasetT]:
|
||||
"""Implementation for `split` method and returns sub-class instances.
|
||||
|
||||
Child DataLoader classes, if requires additional constructor arguments,
|
||||
should implement their own `split` method by calling `_split` with all
|
||||
arguments to the constructor.
|
||||
|
||||
Args:
|
||||
fraction: A float value that defines the fraction of the original data
  that goes into the first returned sub-dataset.
|
||||
*args: additional arguments passed to the sub-class constructor.
|
||||
|
||||
Returns:
|
||||
The two split sub-datasets.
|
||||
"""
|
||||
assert (fraction > 0 and fraction < 1)
|
||||
|
||||
dataset = self._dataset
|
||||
|
||||
train_size = int(self._size * fraction)
|
||||
trainset = self.__class__(dataset.take(train_size), train_size, *args)
|
||||
|
||||
test_size = self._size - train_size
|
||||
testset = self.__class__(dataset.skip(train_size), test_size, *args)
|
||||
|
||||
return trainset, testset
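A hedged sketch of how the `Dataset` class above is meant to be consumed; the toy tensors and the preprocess function are made up for illustration:

```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds


def _preprocess(feature, label, is_training):
  # Matches the (feature, label, is_training) contract expected by
  # gen_tf_dataset; this toy version only rescales the feature values.
  del is_training  # unused in this sketch
  return tf.cast(feature, tf.float32) / 255.0, label


raw = tf.data.Dataset.from_tensor_slices(
    ([[0, 1], [1, 1], [0, 0], [1, 0]], [0, 1, 0, 1]))
data = ds.Dataset(raw, size=4)

train_data, test_data = data.split(0.5)
train_ds = train_data.gen_tf_dataset(
    batch_size=2, is_training=True, shuffle=True, preprocess=_preprocess)
for features, labels in train_ds:
  print(features.shape, labels.shape)  # (2, 2) (2,)
```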
mediapipe/model_maker/python/core/data/dataset_test.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase):
|
||||
|
||||
def test_split(self):
|
||||
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]])
|
||||
data = ds.Dataset(dataset, 4)
|
||||
train_data, test_data = data.split(0.5)
|
||||
|
||||
self.assertLen(train_data, 2)
|
||||
self.assertIsInstance(train_data, ds.Dataset)
|
||||
self.assertIsInstance(test_data, ds.Dataset)
|
||||
for i, elem in enumerate(train_data.gen_tf_dataset()):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
||||
|
||||
self.assertLen(test_data, 2)
|
||||
for i, elem in enumerate(test_data.gen_tf_dataset()):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
||||
|
||||
def test_len(self):
|
||||
size = 4
|
||||
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]])
|
||||
data = ds.Dataset(dataset, size)
|
||||
self.assertLen(data, size)
|
||||
|
||||
def test_gen_tf_dataset(self):
|
||||
input_dim = 8
|
||||
data = test_util.create_dataset(
|
||||
data_size=2, input_shape=[input_dim], num_classes=2)
|
||||
|
||||
dataset = data.gen_tf_dataset()
|
||||
self.assertLen(dataset, 2)
|
||||
for (feature, label) in dataset:
|
||||
self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
|
||||
self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())
|
||||
|
||||
dataset2 = data.gen_tf_dataset(batch_size=2)
|
||||
self.assertLen(dataset2, 1)
|
||||
for (feature, label) in dataset2:
|
||||
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
|
||||
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
|
||||
|
||||
dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
|
||||
self.assertEqual(dataset3.cardinality(), 1)
|
||||
for (feature, label) in dataset3.take(10):
|
||||
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
|
||||
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
mediapipe/model_maker/python/core/data/testdata/BUILD (new file, vendored, 30 lines)
@@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//mediapipe/framework/tool:mediapipe_files.bzl",
|
||||
"mediapipe_files",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
mediapipe_files(srcs = ["test.jpg"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = ["test.jpg"],
|
||||
)
mediapipe/model_maker/python/core/hyperparameters.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Hyperparameters for training models. Shared across tasks."""
|
||||
|
||||
import dataclasses
|
||||
import tempfile
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# TODO: Integrate this class into ImageClassifier and other tasks.
|
||||
@dataclasses.dataclass
|
||||
class BaseHParams:
|
||||
"""Hyperparameters used for training models.
|
||||
|
||||
A common set of hyperparameters shared by the training jobs of all model
|
||||
maker tasks.
|
||||
|
||||
Attributes:
|
||||
learning_rate: The learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
steps_per_epoch: An optional integer indicating the number of training steps
|
||||
per epoch. If not set, the training pipeline calculates the default steps
|
||||
per epoch as the training dataset size divided by batch size.
|
||||
shuffle: True if the dataset is shuffled before training.
|
||||
export_dir: The location of the model checkpoint files.
|
||||
distribution_strategy: A string specifying which Distribution Strategy to
|
||||
use. Accepted values are 'off', 'one_device', 'mirrored',
|
||||
'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
|
||||
insensitive. 'off' means not to use Distribution Strategy; 'tpu' means to
|
||||
use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy
|
||||
documentation for more details:
|
||||
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
|
||||
num_gpus: How many GPUs to use at each worker with the
|
||||
DistributionStrategies API. The default is -1, which means utilize all
|
||||
available GPUs.
|
||||
tpu: The Cloud TPU to use for training. This should be either the name used
|
||||
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
|
||||
"""
|
||||
|
||||
# Parameters for train configuration
|
||||
learning_rate: float
|
||||
batch_size: int
|
||||
epochs: int
|
||||
steps_per_epoch: Optional[int] = None
|
||||
|
||||
# Dataset-related parameters
|
||||
shuffle: bool = False
|
||||
|
||||
# Parameters for model / checkpoint files
|
||||
export_dir: str = tempfile.mkdtemp()
|
||||
|
||||
# Parameters for hardware acceleration
|
||||
distribution_strategy: str = 'off'
|
||||
num_gpus: int = -1 # default value of -1 means use all available GPUs
|
||||
tpu: str = ''
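A hedged example of constructing the `BaseHParams` dataclass defined above; the specific values are illustrative, not recommended defaults:

```
from mediapipe.model_maker.python.core import hyperparameters as hp

hparams = hp.BaseHParams(
    learning_rate=0.001,
    batch_size=32,
    epochs=10,
    shuffle=True,
    distribution_strategy='mirrored')  # or 'off', 'tpu', ...
print(hparams.export_dir)  # defaults to a tempfile.mkdtemp() directory
```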
mediapipe/model_maker/python/core/tasks/BUILD (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "custom_model",
|
||||
srcs = ["custom_model.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "custom_model_test",
|
||||
srcs = ["custom_model_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":custom_model",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "classifier",
|
||||
srcs = ["classifier.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":custom_model",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "classifier_test",
|
||||
srcs = ["classifier_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
],
|
||||
)
mediapipe/model_maker/python/core/tasks/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
mediapipe/model_maker/python/core/tasks/classifier.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Custom classifier."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.tasks import custom_model
|
||||
|
||||
|
||||
class Classifier(custom_model.CustomModel):
|
||||
"""An abstract base class that represents a TensorFlow classifier."""
|
||||
|
||||
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
|
||||
full_train: bool):
|
||||
"""Initilizes a classifier with its specifications.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
index_to_label: A list that maps each index to its label class name.
|
||||
shuffle: Whether the dataset should be shuffled.
|
||||
full_train: If true, train the model end-to-end including the backbone
|
||||
and the classification layers on top. Otherwise, only train the top
|
||||
classification layers.
|
||||
"""
|
||||
super(Classifier, self).__init__(model_spec, shuffle)
|
||||
self._index_to_label = index_to_label
|
||||
self._full_train = full_train
|
||||
self._num_classes = len(index_to_label)
|
||||
|
||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||
"""Evaluates the classifier with the provided evaluation dataset.
|
||||
|
||||
Args:
|
||||
data: Evaluation dataset.
|
||||
batch_size: Number of samples per evaluation step.
|
||||
|
||||
Returns:
|
||||
The loss value and accuracy.
|
||||
"""
|
||||
ds = data.gen_tf_dataset(
|
||||
batch_size, is_training=False, preprocess=self._preprocess)
|
||||
return self._model.evaluate(ds)
|
||||
|
||||
def export_labels(self, export_dir: str, label_filename: str = 'labels.txt'):
|
||||
"""Exports classification labels into a label file.
|
||||
|
||||
Args:
|
||||
export_dir: The directory to save exported files.
|
||||
label_filename: File name to save the labels to. The full export path is
|
||||
{export_dir}/{label_filename}.
|
||||
"""
|
||||
if not tf.io.gfile.exists(export_dir):
|
||||
tf.io.gfile.makedirs(export_dir)
|
||||
|
||||
label_filepath = os.path.join(export_dir, label_filename)
|
||||
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
|
||||
with tf.io.gfile.GFile(label_filepath, 'w') as f:
|
||||
f.write('\n'.join(self._index_to_label))
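A minimal sketch, assuming a hypothetical task subclass, of how `Classifier.export_labels` could be exercised (mirroring the pattern used by the unit test below); `ToyClassifier` and the export path are made up for illustration:

```
from mediapipe.model_maker.python.core.tasks import classifier


class ToyClassifier(classifier.Classifier):
  """Hypothetical subclass; a real task would also implement training."""

  def train(self, train_data, validation_data=None, **kwargs):
    pass  # training loop omitted in this sketch


model = ToyClassifier(
    model_spec=None, index_to_label=['cat', 'dog'], shuffle=False,
    full_train=False)
model.export_labels(export_dir='/tmp/toy_classifier')  # writes labels.txt
```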
mediapipe/model_maker/python/core/tasks/classifier_test.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class MockClassifier(classifier.Classifier):
|
||||
"""A mock class with implementation of abstract methods for testing."""
|
||||
|
||||
def train(self, train_data, validation_data=None, **kwargs):
|
||||
pass
|
||||
|
||||
def evaluate(self, data, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ClassifierTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ClassifierTest, self).setUp()
|
||||
index_to_label = ['cat', 'dog']
|
||||
self.model = MockClassifier(
|
||||
model_spec=None,
|
||||
index_to_label=index_to_label,
|
||||
shuffle=False,
|
||||
full_train=False)
|
||||
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||
|
||||
def _check_nonempty_file(self, filepath):
|
||||
self.assertTrue(os.path.isfile(filepath))
|
||||
self.assertGreater(os.path.getsize(filepath), 0)
|
||||
|
||||
def test_export_labels(self):
|
||||
export_path = os.path.join(self.get_temp_dir(), 'export/')
|
||||
self.model.export_labels(export_dir=export_path)
|
||||
self._check_nonempty_file(os.path.join(export_path, 'labels.txt'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
mediapipe/model_maker/python/core/tasks/custom_model.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Interface to define a custom model."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
|
||||
|
||||
class CustomModel(abc.ABC):
|
||||
"""The abstract base class that represents a custom TensorFlow model."""
|
||||
|
||||
def __init__(self, model_spec: Any, shuffle: bool):
|
||||
"""Initializes a custom model with model specs and other parameters.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
shuffle: Whether the training data needs to be shuffled.
|
||||
"""
|
||||
self._model_spec = model_spec
|
||||
self._shuffle = shuffle
|
||||
self._preprocess = None
|
||||
self._model = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def evaluate(self, data: dataset.Dataset, **kwargs):
|
||||
"""Evaluates the model with the provided data."""
|
||||
return
|
||||
|
||||
def summary(self):
|
||||
"""Prints a summary of the model."""
|
||||
self._model.summary()
|
||||
|
||||
def export_tflite(
|
||||
self,
|
||||
export_dir: str,
|
||||
tflite_filename: str = 'model.tflite',
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
preprocess: Optional[Callable[..., bool]] = None):
|
||||
"""Converts the model to requested formats.
|
||||
|
||||
Args:
|
||||
export_dir: The directory to save exported files.
|
||||
tflite_filename: File name to save tflite model. The full export path is
|
||||
{export_dir}/{tflite_filename}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
preprocess: A callable to preprocess the representative dataset for
|
||||
quantization. The callable takes three arguments in order: feature,
|
||||
label, and is_training.
|
||||
"""
|
||||
if not tf.io.gfile.exists(export_dir):
|
||||
tf.io.gfile.makedirs(export_dir)
|
||||
|
||||
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
||||
# TODO: Populate metadata to the exported TFLite model.
|
||||
model_util.export_tflite(
|
||||
self._model,
|
||||
tflite_filepath,
|
||||
quantization_config,
|
||||
preprocess=preprocess)
|
||||
tf.compat.v1.logging.info(
|
||||
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
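A hedged sketch of `CustomModel.export_tflite`, following the same pattern as `custom_model_test.py` below; the Keras model and output directory are illustrative, and the actual conversion is delegated to `model_util.export_tflite`, which is not part of this hunk:

```
import tensorflow as tf

from mediapipe.model_maker.python.core.tasks import custom_model


class ToyModel(custom_model.CustomModel):
  """Hypothetical subclass used only to illustrate export_tflite."""

  def evaluate(self, data, **kwargs):
    pass


toy = ToyModel(model_spec=None, shuffle=False)
# Assign a small Keras model directly, as the unit test below does.
toy._model = tf.keras.Sequential(
    [tf.keras.layers.Dense(2, input_shape=(4,), activation='softmax')])
toy.export_tflite(export_dir='/tmp/toy_model')  # writes model.tflite
```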
mediapipe/model_maker/python/core/tasks/custom_model_test.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.tasks import custom_model
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class MockCustomModel(custom_model.CustomModel):
|
||||
"""A mock class with implementation of abstract methods for testing."""
|
||||
|
||||
def train(self, train_data, validation_data=None, **kwargs):
|
||||
pass
|
||||
|
||||
def evaluate(self, data, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class CustomModelTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(CustomModelTest, self).setUp()
|
||||
self.model = MockCustomModel(model_spec=None, shuffle=False)
|
||||
self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||
|
||||
def _check_nonempty_file(self, filepath):
|
||||
self.assertTrue(os.path.isfile(filepath))
|
||||
self.assertGreater(os.path.getsize(filepath), 0)
|
||||
|
||||
def test_export_tflite(self):
|
||||
export_path = os.path.join(self.get_temp_dir(), 'export/')
|
||||
self.model.export_tflite(export_dir=export_path)
|
||||
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
100
mediapipe/model_maker/python/core/utils/BUILD
Normal file
100
mediapipe/model_maker/python/core/utils/BUILD
Normal file
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_util",
|
||||
testonly = 1,
|
||||
srcs = ["test_util.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":model_util",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_preprocessing",
|
||||
srcs = ["image_preprocessing.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_preprocessing_test",
|
||||
srcs = ["image_preprocessing_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":image_preprocessing"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_util",
|
||||
srcs = ["model_util.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":quantization",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_util_test",
|
||||
srcs = ["model_util_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":model_util",
|
||||
":quantization",
|
||||
":test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions",
|
||||
srcs = ["loss_functions.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "loss_functions_test",
|
||||
srcs = ["loss_functions_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":loss_functions"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "quantization",
|
||||
srcs = ["quantization.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "quantization_test",
|
||||
srcs = ["quantization_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":quantization",
|
||||
":test_util",
|
||||
],
|
||||
)
mediapipe/model_maker/python/core/utils/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
mediapipe/model_maker/python/core/utils/image_preprocessing.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ImageNet preprocessing."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
import tensorflow as tf
|
||||
|
||||
IMAGE_SIZE = 224
|
||||
CROP_PADDING = 32
|
||||
|
||||
|
||||
class Preprocessor(object):
|
||||
"""Preprocessor for image classification."""
|
||||
|
||||
def __init__(self,
|
||||
input_shape,
|
||||
num_classes,
|
||||
mean_rgb,
|
||||
stddev_rgb,
|
||||
use_augmentation=False):
|
||||
self.input_shape = input_shape
|
||||
self.num_classes = num_classes
|
||||
self.mean_rgb = mean_rgb
|
||||
self.stddev_rgb = stddev_rgb
|
||||
self.use_augmentation = use_augmentation
|
||||
|
||||
def __call__(self, image, label, is_training=True):
|
||||
if self.use_augmentation:
|
||||
return self._preprocess_with_augmentation(image, label, is_training)
|
||||
return self._preprocess_without_augmentation(image, label)
|
||||
|
||||
def _preprocess_with_augmentation(self, image, label, is_training):
|
||||
"""Image preprocessing method with data augmentation."""
|
||||
image_size = self.input_shape[0]
|
||||
if is_training:
|
||||
image = preprocess_for_train(image, image_size)
|
||||
else:
|
||||
image = preprocess_for_eval(image, image_size)
|
||||
|
||||
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
|
||||
label = tf.one_hot(label, depth=self.num_classes)
|
||||
return image, label
|
||||
|
||||
# TODO: Changes to preprocess to support batch input.
|
||||
def _preprocess_without_augmentation(self, image, label):
|
||||
"""Image preprocessing method without data augmentation."""
|
||||
image = tf.cast(image, tf.float32)
|
||||
|
||||
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
|
||||
image = tf.compat.v1.image.resize(image, self.input_shape)
|
||||
label = tf.one_hot(label, depth=self.num_classes)
|
||||
return image, label
|
||||
|
||||
|
||||
def _distorted_bounding_box_crop(image,
|
||||
bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=(0.75, 1.33),
|
||||
area_range=(0.05, 1.0),
|
||||
max_attempts=100):
|
||||
"""Generates cropped_image using one of the bboxes randomly distorted.
|
||||
|
||||
See `tf.image.sample_distorted_bounding_box` for more documentation.
|
||||
|
||||
Args:
|
||||
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||
shape [height, width, channels].
|
||||
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where
|
||||
each coordinate is [0, 1) and the coordinates are arranged as `[ymin,
|
||||
xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image.
|
||||
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area
|
||||
of the image must contain at least this fraction of any bounding box
|
||||
supplied.
|
||||
aspect_ratio_range: An optional list of `float`s. The cropped area of the
|
||||
image must have an aspect ratio = width / height within this range.
|
||||
area_range: An optional list of `float`s. The cropped area of the image must
|
||||
contain a fraction of the supplied image within this range.
|
||||
max_attempts: An optional `int`. Number of attempts at generating a cropped
|
||||
region of the image of the specified constraints. After `max_attempts`
|
||||
failures, return the entire image.
|
||||
|
||||
Returns:
|
||||
A cropped image `Tensor`
|
||||
"""
|
||||
with tf.name_scope('distorted_bounding_box_crop'):
|
||||
shape = tf.shape(image)
|
||||
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||
shape,
|
||||
bounding_boxes=bbox,
|
||||
min_object_covered=min_object_covered,
|
||||
aspect_ratio_range=aspect_ratio_range,
|
||||
area_range=area_range,
|
||||
max_attempts=max_attempts,
|
||||
use_image_if_no_bounding_boxes=True)
|
||||
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||
|
||||
# Crop the image to the specified bounding box.
|
||||
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||
image = tf.image.crop_to_bounding_box(image, offset_y, offset_x,
|
||||
target_height, target_width)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _at_least_x_are_equal(a, b, x):
|
||||
"""At least `x` of `a` and `b` `Tensors` are equal."""
|
||||
match = tf.equal(a, b)
|
||||
match = tf.cast(match, tf.int32)
|
||||
return tf.greater_equal(tf.reduce_sum(match), x)
|
||||
|
||||
|
||||
def _resize_image(image, image_size, method=None):
|
||||
if method is not None:
|
||||
tf.compat.v1.logging.info('Use customized resize method {}'.format(method))
|
||||
return tf.compat.v1.image.resize([image], [image_size, image_size],
|
||||
method)[0]
|
||||
tf.compat.v1.logging.info('Use default resize_bicubic.')
|
||||
return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0]
|
||||
|
||||
|
||||
def _decode_and_random_crop(original_image, image_size, resize_method=None):
|
||||
"""Makes a random crop of image_size."""
|
||||
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
|
||||
image = _distorted_bounding_box_crop(
|
||||
original_image,
|
||||
bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=(3. / 4, 4. / 3.),
|
||||
area_range=(0.08, 1.0),
|
||||
max_attempts=10)
|
||||
original_shape = tf.shape(original_image)
|
||||
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
|
||||
|
||||
image = tf.cond(bad,
|
||||
lambda: _decode_and_center_crop(original_image, image_size),
|
||||
lambda: _resize_image(image, image_size, resize_method))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _decode_and_center_crop(image, image_size, resize_method=None):
|
||||
"""Crops to center of image with padding then scales image_size."""
|
||||
shape = tf.shape(image)
|
||||
image_height = shape[0]
|
||||
image_width = shape[1]
|
||||
|
||||
padded_center_crop_size = tf.cast(
|
||||
((image_size / (image_size + CROP_PADDING)) *
|
||||
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
|
||||
|
||||
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
|
||||
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
|
||||
image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
|
||||
padded_center_crop_size,
|
||||
padded_center_crop_size)
|
||||
image = _resize_image(image, image_size, resize_method)
|
||||
return image
|
||||
|
||||
|
||||
def _flip(image):
|
||||
"""Random horizontal image flip."""
|
||||
image = tf.image.random_flip_left_right(image)
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_for_train(
|
||||
image: tf.Tensor,
|
||||
image_size: int = IMAGE_SIZE,
|
||||
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
|
||||
"""Preprocesses the given image for evaluation.
|
||||
|
||||
Args:
|
||||
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||
shape [height, width, channels].
|
||||
image_size: image size.
|
||||
resize_method: resize method. If none, use bicubic.
|
||||
|
||||
Returns:
|
||||
A preprocessed image `Tensor`.
|
||||
"""
|
||||
image = _decode_and_random_crop(image, image_size, resize_method)
|
||||
image = _flip(image)
|
||||
image = tf.reshape(image, [image_size, image_size, 3])
|
||||
|
||||
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_for_eval(
|
||||
image: tf.Tensor,
|
||||
image_size: int = IMAGE_SIZE,
|
||||
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
|
||||
"""Preprocesses the given image for evaluation.
|
||||
|
||||
Args:
|
||||
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||
shape [height, width, channels].
|
||||
image_size: image size.
|
||||
resize_method: if None, use bicubic.
|
||||
|
||||
Returns:
|
||||
A preprocessed image `Tensor`.
|
||||
"""
|
||||
image = _decode_and_center_crop(image, image_size, resize_method)
|
||||
image = tf.reshape(image, [image_size, image_size, 3])
|
||||
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||
return image
|
mediapipe/model_maker/python/core/utils/image_preprocessing_test.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import image_preprocessing
|
||||
|
||||
|
||||
def _get_preprocessed_image(preprocessor, is_training=False):
|
||||
image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3])
|
||||
label_placeholder = tf.compat.v1.placeholder(tf.int32, [1])
|
||||
image_tensor, _ = preprocessor(image_placeholder, label_placeholder,
|
||||
is_training)
|
||||
|
||||
with tf.compat.v1.Session() as sess:
|
||||
input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3])
|
||||
image = sess.run(
|
||||
image_tensor,
|
||||
feed_dict={
|
||||
image_placeholder: input_image,
|
||||
label_placeholder: [0]
|
||||
})
|
||||
return image
|
||||
|
||||
|
||||
class PreprocessorTest(tf.test.TestCase):
|
||||
|
||||
def test_preprocess_without_augmentation(self):
|
||||
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
|
||||
num_classes=2,
|
||||
mean_rgb=[0.0],
|
||||
stddev_rgb=[255.0],
|
||||
use_augmentation=False)
|
||||
actual_image = np.array([[[0., 0.00392157, 0.00784314],
|
||||
[0.14117648, 0.14509805, 0.14901961]],
|
||||
[[0.37647063, 0.3803922, 0.38431376],
|
||||
[0.5176471, 0.52156866, 0.5254902]]])
|
||||
|
||||
image = _get_preprocessed_image(preprocessor)
|
||||
self.assertTrue(np.allclose(image, actual_image, atol=1e-05))
|
||||
|
||||
def test_preprocess_with_augmentation(self):
|
||||
image_preprocessing.CROP_PADDING = 1
|
||||
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
|
||||
num_classes=2,
|
||||
mean_rgb=[0.0],
|
||||
stddev_rgb=[255.0],
|
||||
use_augmentation=True)
|
||||
# Tests validation image.
|
||||
actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216],
|
||||
[0.26666668, 0.27058825, 0.27450982]],
|
||||
[[0.42352945, 0.427451, 0.43137258],
|
||||
[0.5176471, 0.52156866, 0.5254902]]])
|
||||
|
||||
image = _get_preprocessed_image(preprocessor, is_training=False)
|
||||
self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05))
|
||||
|
||||
# Tests training image.
|
||||
image1 = _get_preprocessed_image(preprocessor, is_training=True)
|
||||
image2 = _get_preprocessed_image(preprocessor, is_training=True)
|
||||
self.assertFalse(np.allclose(image1, image2, atol=1e-05))
|
||||
self.assertEqual(image1.shape, (2, 2, 3))
|
||||
self.assertEqual(image2.shape, (2, 2, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
tf.test.main()
|
mediapipe/model_maker/python/core/utils/loss_functions.py (new file, 105 lines)

@@ -0,0 +1,105 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss function utility library."""

from typing import Optional, Sequence

import tensorflow as tf


class FocalLoss(tf.keras.losses.Loss):
  """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).

  This class computes the focal loss between labels and prediction. Focal loss
  is a weighted loss function that modulates the standard cross-entropy loss
  based on how well the neural network performs on a specific example of a
  class. The labels should be provided in a `one_hot` vector representation.
  There should be `#classes` floating point values per prediction.
  The loss is reduced across all samples using 'sum_over_batch_size' reduction
  (see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction).

  Example usage:
  >>> y_true = [[0, 1, 0], [0, 0, 1]]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> gamma = 2
  >>> focal_loss = FocalLoss(gamma)
  >>> focal_loss(y_true, y_pred).numpy()
  0.9326

  >>> # Calling with 'sample_weight'.
  >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
  0.6528

  Usage with the `compile()` API:
  ```python
  model.compile(optimizer='sgd', loss=FocalLoss(gamma))
  ```

  """

  def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
    """Constructor.

    Args:
      gamma: Focal loss gamma, as described in class docs.
      class_weight: A weight to apply to the loss, one for each class. The
        weight is applied for each input where the ground truth label matches.
    """
    super(tf.keras.losses.Loss, self).__init__()
    # Used for clipping min/max values of probability values in y_pred to avoid
    # NaNs and Infs in computation.
    self._epsilon = 1e-7
    # This is a tunable "focusing parameter"; should be >= 0.
    # When gamma = 0, the loss returned is the standard categorical
    # cross-entropy loss.
    self._gamma = gamma
    self._class_weight = class_weight
    # tf.keras.losses.Loss class implementation requires a Reduction specified
    # in self.reduction. To use this reduction, we should use tensorflow's
    # compute_weighted_loss function however it is only compatible with v1 of
    # Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long
    # So even though it is specified here, we don't use self.reduction in the
    # loss function call.
    self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE

  def __call__(self,
               y_true: tf.Tensor,
               y_pred: tf.Tensor,
               sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor:
    if self._class_weight:
      class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32)
      label = tf.argmax(y_true, axis=1)
      loss_weight = tf.gather(class_weight, label)
    else:
      loss_weight = tf.ones(tf.shape(y_true)[0])
    y_true = tf.cast(y_true, y_pred.dtype)
    y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon)
    batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype)
    if sample_weight is None:
      sample_weight = tf.constant(1.0)
    weight_shape = sample_weight.shape
    weight_rank = weight_shape.ndims
    y_pred_rank = y_pred.shape.ndims
    if y_pred_rank - weight_rank == 1:
      sample_weight = tf.expand_dims(sample_weight, [-1])
    elif weight_rank != 0:
      raise ValueError(f'Unexpected sample_weight, should be either a scalar '
                       f'or a vector of batch_size: {batch_size.numpy()}')
    ce = -tf.math.log(y_pred)
    modulating_factor = tf.math.pow(1 - y_pred, self._gamma)
    losses = y_true * modulating_factor * ce * sample_weight
    losses = losses * loss_weight[:, tf.newaxis]
    # By default, this function uses "sum_over_batch_size" reduction for the
    # loss per batch.
    return tf.reduce_sum(losses) / batch_size
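To make the `class_weight` behaviour above concrete, a small hand-checkable sketch follows; the weights and probabilities are arbitrary illustration values, not taken from this change.

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import loss_functions

# With gamma=0 the modulating factor is 1, so the focal loss reduces to
# class-weighted cross-entropy averaged over the batch:
#   loss = (-log(0.8) * 1.0 + -log(0.6) * 2.0) / 2  ~= 0.62
loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=[1.0, 2.0])
y_true = tf.constant([[1.0, 0.0], [0.0, 1.0]])
y_pred = tf.constant([[0.8, 0.2], [0.4, 0.6]])
loss = loss_fn(y_true, y_pred)
```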
mediapipe/model_maker/python/core/utils/loss_functions_test.py (new file, 103 lines)

@@ -0,0 +1,103 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||
|
||||
|
||||
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name='no_sample_weight', sample_weight=None),
|
||||
dict(
|
||||
testcase_name='with_sample_weight',
|
||||
sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])))
|
||||
def test_focal_loss_gamma_0_is_cross_entropy(
|
||||
self, sample_weight: Optional[tf.Tensor]):
|
||||
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
|
||||
0]])
|
||||
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
|
||||
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
|
||||
|
||||
tf_cce = tf.keras.losses.CategoricalCrossentropy(
|
||||
from_logits=False,
|
||||
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
|
||||
focal_loss = loss_functions.FocalLoss(gamma=0)
|
||||
self.assertAllClose(
|
||||
tf_cce(y_true, y_pred, sample_weight=sample_weight),
|
||||
focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4)
|
||||
|
||||
def test_focal_loss_with_sample_weight(self):
|
||||
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
|
||||
0]])
|
||||
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
|
||||
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
|
||||
|
||||
focal_loss = loss_functions.FocalLoss(gamma=0)
|
||||
|
||||
sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])
|
||||
|
||||
self.assertGreater(
|
||||
focal_loss(y_true=y_true, y_pred=y_pred),
|
||||
focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])),
|
||||
dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])),
|
||||
dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])),
|
||||
dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])),
|
||||
dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])),
|
||||
)
|
||||
def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor):
|
||||
y_true = tf.constant([[1, 0]])
|
||||
|
||||
focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0)
|
||||
loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred)
|
||||
focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5)
|
||||
loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred)
|
||||
focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1)
|
||||
loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred)
|
||||
focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2)
|
||||
loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred)
|
||||
focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5)
|
||||
loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred)
|
||||
|
||||
self.assertGreater(loss_gamma_0, loss_gamma_0p5)
|
||||
self.assertGreater(loss_gamma_0p5, loss_gamma_1)
|
||||
self.assertGreater(loss_gamma_1, loss_gamma_2)
|
||||
self.assertGreater(loss_gamma_2, loss_gamma_5)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name='index_0', true_class=0),
|
||||
dict(testcase_name='index_1', true_class=1),
|
||||
dict(testcase_name='index_2', true_class=2),
|
||||
)
|
||||
def test_focal_loss_class_weight_is_applied(self, true_class: int):
|
||||
class_weight = [1.0, 3.0, 10.0]
|
||||
y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0
|
||||
y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :]
|
||||
expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class]
|
||||
|
||||
loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight)
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
self.assertNear(loss, expected_loss, 1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
mediapipe/model_maker/python/core/utils/model_util.py (new file, 272 lines)

@@ -0,0 +1,272 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for keras models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
# resources dependency
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
|
||||
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||
ESTIMITED_STEPS_PER_EPOCH = 1000
|
||||
|
||||
|
||||
def load_keras_model(model_path: str,
|
||||
compile_on_load: bool = False) -> tf.keras.Model:
|
||||
"""Loads a tensorflow Keras model from file and returns the Keras model.
|
||||
|
||||
Args:
|
||||
model_path: Relative path to a directory containing model data, such as
|
||||
<parent_path>/saved_model/.
|
||||
compile_on_load: Whether the model should be compiled while loading. If
|
||||
False, the model returned has to be compiled with the appropriate loss
|
||||
function and custom metrics before running for inference on a test
|
||||
dataset.
|
||||
|
||||
Returns:
|
||||
A tensorflow Keras model.
|
||||
"""
|
||||
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
# with the `model_path`, which defines the relative path under mediapipe/, it
# yields the absolute path of the model files directory.
|
||||
cwd = os.path.dirname(__file__)
|
||||
base_dir = cwd[:cwd.rfind('mediapipe')]
|
||||
absolute_path = os.path.join(base_dir, model_path)
|
||||
return tf.keras.models.load_model(
|
||||
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
|
||||
|
||||
|
||||
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
train_data: Optional[dataset.Dataset] = None) -> int:
|
||||
"""Gets the estimated training steps per epoch.
|
||||
|
||||
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
|
||||
2. Else if we can get the length of training data successfully, returns
|
||||
`train_data_length // batch_size`.
|
||||
|
||||
Args:
|
||||
steps_per_epoch: int, training steps per epoch.
|
||||
batch_size: int, batch size.
|
||||
train_data: training data.
|
||||
|
||||
Returns:
|
||||
Estimated training steps per epoch.
|
||||
|
||||
Raises:
|
||||
ValueError: if both steps_per_epoch and train_data are not set.
|
||||
"""
|
||||
if steps_per_epoch is not None:
|
||||
# steps_per_epoch is set by users manually.
|
||||
return steps_per_epoch
|
||||
else:
|
||||
if train_data is None:
|
||||
raise ValueError('Input train_data cannot be None.')
|
||||
# Gets the steps by the length of the training data.
|
||||
return len(train_data) // batch_size
|
||||
|
||||
|
||||
def export_tflite(
|
||||
model: tf.keras.Model,
|
||||
tflite_filepath: str,
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
supported_ops: Tuple[tf.lite.OpsSet,
|
||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
||||
preprocess: Optional[Callable[..., bool]] = None):
|
||||
"""Converts the model to tflite format and saves it.
|
||||
|
||||
Args:
|
||||
model: model to be converted to tflite.
|
||||
tflite_filepath: File path to save tflite model.
|
||||
quantization_config: Configuration for post-training quantization.
|
||||
supported_ops: A list of supported ops in the converted TFLite file.
|
||||
preprocess: A callable to preprocess the representative dataset for
|
||||
quantization. The callable takes three arguments in order: feature,
|
||||
label, and is_training.
|
||||
"""
|
||||
if tflite_filepath is None:
|
||||
raise ValueError(
|
||||
"TFLite filepath couldn't be None when exporting to tflite.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, 'saved_model')
|
||||
model.save(save_path, include_optimizer=False, save_format='tf')
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
||||
|
||||
if quantization_config:
|
||||
converter = quantization_config.set_converter_with_quantization(
|
||||
converter, preprocess=preprocess)
|
||||
|
||||
converter.target_spec.supported_ops = supported_ops
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
|
||||
f.write(tflite_model)
|
||||
|
||||
|
||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||
"""Applies a warmup schedule on a given learning rate decay schedule."""
|
||||
|
||||
def __init__(self,
|
||||
initial_learning_rate: float,
|
||||
decay_schedule_fn: Callable[[Any], Any],
|
||||
warmup_steps: int,
|
||||
name: Optional[str] = None):
|
||||
"""Initializes a new instance of the `WarmUp` class.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: learning rate after the warmup.
|
||||
decay_schedule_fn: A function maps step to learning rate. Will be applied
|
||||
for values of step larger than 'warmup_steps'.
|
||||
warmup_steps: Number of steps to do warmup for.
|
||||
name: TF namescope under which to perform the learning rate calculation.
|
||||
"""
|
||||
super(WarmUp, self).__init__()
|
||||
self.initial_learning_rate = initial_learning_rate
|
||||
self.warmup_steps = warmup_steps
|
||||
self.decay_schedule_fn = decay_schedule_fn
|
||||
self.name = name
|
||||
|
||||
def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
|
||||
with tf.name_scope(self.name or 'WarmUp') as name:
|
||||
# Implements linear warmup. i.e., if global_step < warmup_steps, the
|
||||
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
||||
global_step_float = tf.cast(step, tf.float32)
|
||||
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
|
||||
warmup_percent_done = global_step_float / warmup_steps_float
|
||||
warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
|
||||
return tf.cond(
|
||||
global_step_float < warmup_steps_float,
|
||||
lambda: warmup_learning_rate,
|
||||
lambda: self.decay_schedule_fn(step),
|
||||
name=name)
|
||||
|
||||
def get_config(self) -> Dict[Text, Any]:
|
||||
return {
|
||||
'initial_learning_rate': self.initial_learning_rate,
|
||||
'decay_schedule_fn': self.decay_schedule_fn,
|
||||
'warmup_steps': self.warmup_steps,
|
||||
'name': self.name
|
||||
}
|
||||
|
||||
|
||||
class LiteRunner(object):
|
||||
"""A runner to do inference with the TFLite model."""
|
||||
|
||||
def __init__(self, tflite_filepath: str):
|
||||
"""Initializes Lite runner with tflite model file.
|
||||
|
||||
Args:
|
||||
tflite_filepath: File path to the TFLite model.
|
||||
"""
|
||||
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
|
||||
tflite_model = f.read()
|
||||
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
self.interpreter.allocate_tensors()
|
||||
self.input_details = self.interpreter.get_input_details()
|
||||
self.output_details = self.interpreter.get_output_details()
|
||||
|
||||
def run(
|
||||
self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]]
|
||||
) -> Union[List[tf.Tensor], tf.Tensor]:
|
||||
"""Runs inference with the TFLite model.
|
||||
|
||||
Args:
|
||||
input_tensors: List / Dict of the input tensors of the TFLite model. The
|
||||
order should be the same as the keras model if it's a list. It also
|
||||
accepts tensor directly if the model has only 1 input.
|
||||
|
||||
Returns:
|
||||
List of the output tensors for multi-output models, otherwise just
|
||||
the output tensor. The order should be the same as the keras model.
|
||||
"""
|
||||
|
||||
if not isinstance(input_tensors, list) and not isinstance(
|
||||
input_tensors, dict):
|
||||
input_tensors = [input_tensors]
|
||||
|
||||
interpreter = self.interpreter
|
||||
|
||||
# Reshape inputs
|
||||
for i, input_detail in enumerate(self.input_details):
|
||||
input_tensor = _get_input_tensor(
|
||||
input_tensors=input_tensors,
|
||||
input_details=self.input_details,
|
||||
index=i)
|
||||
interpreter.resize_tensor_input(
|
||||
input_index=input_detail['index'], tensor_size=input_tensor.shape)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
# Feed input to the interpreter
|
||||
for i, input_detail in enumerate(self.input_details):
|
||||
input_tensor = _get_input_tensor(
|
||||
input_tensors=input_tensors,
|
||||
input_details=self.input_details,
|
||||
index=i)
|
||||
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
||||
# Quantize the input
|
||||
scale, zero_point = input_detail['quantization']
|
||||
input_tensor = input_tensor / scale + zero_point
|
||||
input_tensor = np.array(input_tensor, dtype=input_detail['dtype'])
|
||||
interpreter.set_tensor(input_detail['index'], input_tensor)
|
||||
|
||||
interpreter.invoke()
|
||||
|
||||
output_tensors = []
|
||||
for output_detail in self.output_details:
|
||||
output_tensor = interpreter.get_tensor(output_detail['index'])
|
||||
if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
||||
# Dequantize the output
|
||||
scale, zero_point = output_detail['quantization']
|
||||
output_tensor = output_tensor.astype(np.float32)
|
||||
output_tensor = (output_tensor - zero_point) * scale
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
if len(output_tensors) == 1:
|
||||
return output_tensors[0]
|
||||
return output_tensors
|
||||
|
||||
|
||||
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
|
||||
"""Returns a `LiteRunner` from file path to TFLite model."""
|
||||
lite_runner = LiteRunner(tflite_filepath)
|
||||
return lite_runner
|
||||
|
||||
|
||||
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
|
||||
tf.Tensor]],
|
||||
input_details: Dict[str, Any], index: int) -> tf.Tensor:
|
||||
"""Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
|
||||
if isinstance(input_tensors, dict):
|
||||
# Gets the mapped input tensor.
|
||||
input_detail = input_details
|
||||
for input_tensor_name, input_tensor in input_tensors.items():
|
||||
if input_tensor_name in input_detail['name']:
|
||||
return input_tensor
|
||||
raise ValueError('Input tensors don\'t contain a tensor that maps to the '
                 'input detail %s' % str(input_detail))
|
||||
else:
|
||||
return input_tensors[index]
|
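A hedged sketch of how `export_tflite`, `get_lite_runner` and `LiteRunner.run` compose end to end; the toy Keras model and the /tmp path are illustrative assumptions, not part of this change.

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util

# Illustrative only: export a tiny Keras model and check that the TFLite
# interpreter reproduces its outputs through LiteRunner.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model_util.export_tflite(model, '/tmp/example_model.tflite')

runner = model_util.get_lite_runner('/tmp/example_model.tflite')
sample = tf.constant([[0.1, 0.2, 0.3, 0.4]])
lite_out = runner.run(sample)          # single-output model -> single tensor
keras_out = model.predict_on_batch(sample.numpy())
# lite_out and keras_out should agree to within float tolerance.
```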
mediapipe/model_maker/python/core/utils/model_util_test.py (new file, 148 lines)

@@ -0,0 +1,148 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_load_model(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
model.save(saved_model_path)
|
||||
loaded_model = model_util.load_keras_model(saved_model_path)
|
||||
|
||||
input_tensors = test_util.create_random_sample(size=[1, input_dim])
|
||||
model_output = model.predict_on_batch(input_tensors)
|
||||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||
self.assertTrue((model_output == loaded_model_output).all())
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='input_only_steps_per_epoch',
|
||||
steps_per_epoch=1000,
|
||||
batch_size=None,
|
||||
train_data=None,
|
||||
expected_steps_per_epoch=1000),
|
||||
dict(
|
||||
testcase_name='input_steps_per_epoch_and_batch_size',
|
||||
steps_per_epoch=1000,
|
||||
batch_size=32,
|
||||
train_data=None,
|
||||
expected_steps_per_epoch=1000),
|
||||
dict(
|
||||
testcase_name='input_steps_per_epoch_batch_size_and_train_data',
|
||||
steps_per_epoch=1000,
|
||||
batch_size=32,
|
||||
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]]),
|
||||
expected_steps_per_epoch=1000),
|
||||
dict(
|
||||
testcase_name='input_batch_size_and_train_data',
|
||||
steps_per_epoch=None,
|
||||
batch_size=2,
|
||||
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]]),
|
||||
expected_steps_per_epoch=2))
|
||||
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
|
||||
expected_steps_per_epoch):
|
||||
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
batch_size=batch_size,
|
||||
train_data=train_data)
|
||||
self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch)
|
||||
|
||||
def test_get_steps_per_epoch_raise_value_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=None, batch_size=16, train_data=None)
|
||||
|
||||
def test_warmup(self):
|
||||
init_lr = 0.1
|
||||
warmup_steps = 1000
|
||||
num_decay_steps = 100
|
||||
learning_rate_fn = tf.keras.experimental.CosineDecay(
|
||||
initial_learning_rate=init_lr, decay_steps=num_decay_steps)
|
||||
warmup_object = model_util.WarmUp(
|
||||
initial_learning_rate=init_lr,
|
||||
decay_schedule_fn=learning_rate_fn,
|
||||
warmup_steps=1000,
|
||||
name='test')
|
||||
self.assertEqual(
|
||||
warmup_object.get_config(), {
|
||||
'initial_learning_rate': init_lr,
|
||||
'decay_schedule_fn': learning_rate_fn,
|
||||
'warmup_steps': warmup_steps,
|
||||
'name': 'test'
|
||||
})
|
||||
|
||||
def test_export_tflite(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.export_tflite(model, tflite_file)
|
||||
self._test_tflite(model, tflite_file, input_dim)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='dynamic_quantize',
|
||||
config=quantization.QuantizationConfig.for_dynamic(),
|
||||
model_size=1288),
|
||||
dict(
|
||||
testcase_name='int8_quantize',
|
||||
config=quantization.QuantizationConfig.for_int8(
|
||||
representative_data=test_util.create_dataset(
|
||||
data_size=10, input_shape=[16], num_classes=3)),
|
||||
model_size=1832),
|
||||
dict(
|
||||
testcase_name='float16_quantize',
|
||||
config=quantization.QuantizationConfig.for_float16(),
|
||||
model_size=1468))
|
||||
def test_export_tflite_quantized(self, config, model_size):
|
||||
input_dim = 16
|
||||
num_classes = 2
|
||||
max_input_value = 5
|
||||
model = test_util.build_model([input_dim], num_classes)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
||||
|
||||
model_util.export_tflite(model, tflite_file, config)
|
||||
self._test_tflite(
|
||||
model, tflite_file, input_dim, max_input_value, atol=1e-00)
|
||||
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
||||
|
||||
def _test_tflite(self,
|
||||
keras_model: tf.keras.Model,
|
||||
tflite_model_file: str,
|
||||
input_dim: int,
|
||||
max_input_value: int = 1000,
|
||||
atol: float = 1e-04):
|
||||
random_input = test_util.create_random_sample(
|
||||
size=[1, input_dim], high=max_input_value)
|
||||
random_input = tf.convert_to_tensor(random_input)
|
||||
|
||||
self.assertTrue(
|
||||
test_util.is_same_output(
|
||||
tflite_model_file, keras_model, random_input, atol=atol))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
mediapipe/model_maker/python/core/utils/quantization.py (new file, 213 lines)

@@ -0,0 +1,213 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Libraries for post-training quantization."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
|
||||
DEFAULT_QUANTIZATION_STEPS = 500
|
||||
|
||||
|
||||
def _get_representative_dataset_generator(dataset: tf.data.Dataset,
|
||||
num_steps: int) -> Callable[[], Any]:
|
||||
"""Gets a representative dataset generator for post-training quantization.
|
||||
|
||||
The generator is to provide a small dataset to calibrate or estimate the
|
||||
range, i.e, (min, max) of all floating-point arrays in the model for
|
||||
quantization. Usually, this is a small subset of a few hundred samples
|
||||
randomly chosen, in no particular order, from the training or evaluation
|
||||
dataset. See tf.lite.RepresentativeDataset for more details.
|
||||
|
||||
Args:
|
||||
dataset: Input dataset for extracting representative sub dataset.
|
||||
num_steps: The number of quantization steps which also reflects the size of
|
||||
the representative dataset.
|
||||
|
||||
Returns:
|
||||
A representative dataset generator.
|
||||
"""
|
||||
|
||||
def representative_dataset_gen():
|
||||
"""Generates representative dataset for quantization."""
|
||||
for data, _ in dataset.take(num_steps):
|
||||
yield [data]
|
||||
|
||||
return representative_dataset_gen
|
||||
|
||||
|
||||
class QuantizationConfig(object):
|
||||
"""Configuration for post-training quantization.
|
||||
|
||||
Refer to
|
||||
https://www.tensorflow.org/lite/performance/post_training_quantization
|
||||
for different post-training quantization options.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizations: Optional[Union[tf.lite.Optimize,
|
||||
List[tf.lite.Optimize]]] = None,
|
||||
representative_data: Optional[ds.Dataset] = None,
|
||||
quantization_steps: Optional[int] = None,
|
||||
inference_input_type: Optional[tf.dtypes.DType] = None,
|
||||
inference_output_type: Optional[tf.dtypes.DType] = None,
|
||||
supported_ops: Optional[Union[tf.lite.OpsSet,
|
||||
List[tf.lite.OpsSet]]] = None,
|
||||
supported_types: Optional[Union[tf.dtypes.DType,
|
||||
List[tf.dtypes.DType]]] = None,
|
||||
experimental_new_quantizer: bool = False,
|
||||
):
|
||||
"""Constructs QuantizationConfig.
|
||||
|
||||
Args:
|
||||
optimizations: A list of optimizations to apply when converting the model.
|
||||
If not set, use `[Optimize.DEFAULT]` by default.
|
||||
representative_data: A representative ds.Dataset for post-training
|
||||
quantization.
|
||||
quantization_steps: Number of post-training quantization calibration steps
|
||||
to run (defaults to DEFAULT_QUANTIZATION_STEPS).
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
for a different type for input arrays. Defaults to None. If set, must be one
of `{tf.float32, tf.uint8, tf.int8}`.
|
||||
inference_output_type: Target data type of real-number output arrays.
|
||||
Allows for a different type for output arrays. Defaults to None. If set,
|
||||
must be `{tf.float32, tf.uint8, tf.int8}`.
|
||||
supported_ops: Set of OpsSet options supported by the device. Used to set
`converter.target_spec.supported_ops`.
|
||||
supported_types: List of types for constant values on the target device.
|
||||
Supported values are types exported by lite.constants. Frequently, an
|
||||
optimization choice is driven by the most compact (i.e. smallest) type
|
||||
in this list (default [constants.FLOAT]).
|
||||
experimental_new_quantizer: Whether to enable experimental new quantizer.
|
||||
|
||||
Raises:
|
||||
ValueError: if inference_input_type or inference_output_type are set but
|
||||
not in {tf.float32, tf.uint8, tf.int8}.
|
||||
"""
|
||||
if inference_input_type is not None and inference_input_type not in {
|
||||
tf.float32, tf.uint8, tf.int8
|
||||
}:
|
||||
raise ValueError('Unsupported inference_input_type %s' %
|
||||
inference_input_type)
|
||||
if inference_output_type is not None and inference_output_type not in {
|
||||
tf.float32, tf.uint8, tf.int8
|
||||
}:
|
||||
raise ValueError('Unsupported inference_output_type %s' %
|
||||
inference_output_type)
|
||||
|
||||
if optimizations is None:
|
||||
optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
if not isinstance(optimizations, list):
|
||||
optimizations = [optimizations]
|
||||
self.optimizations = optimizations
|
||||
|
||||
self.representative_data = representative_data
|
||||
if self.representative_data is not None and quantization_steps is None:
|
||||
quantization_steps = DEFAULT_QUANTIZATION_STEPS
|
||||
self.quantization_steps = quantization_steps
|
||||
|
||||
self.inference_input_type = inference_input_type
|
||||
self.inference_output_type = inference_output_type
|
||||
|
||||
if supported_ops is not None and not isinstance(supported_ops, list):
|
||||
supported_ops = [supported_ops]
|
||||
self.supported_ops = supported_ops
|
||||
|
||||
if supported_types is not None and not isinstance(supported_types, list):
|
||||
supported_types = [supported_types]
|
||||
self.supported_types = supported_types
|
||||
|
||||
self.experimental_new_quantizer = experimental_new_quantizer
|
||||
|
||||
@classmethod
|
||||
def for_dynamic(cls) -> 'QuantizationConfig':
|
||||
"""Creates configuration for dynamic range quantization."""
|
||||
return QuantizationConfig()
|
||||
|
||||
@classmethod
|
||||
def for_int8(
|
||||
cls,
|
||||
representative_data: ds.Dataset,
|
||||
quantization_steps: int = DEFAULT_QUANTIZATION_STEPS,
|
||||
inference_input_type: tf.dtypes.DType = tf.uint8,
|
||||
inference_output_type: tf.dtypes.DType = tf.uint8,
|
||||
supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8
|
||||
) -> 'QuantizationConfig':
|
||||
"""Creates configuration for full integer quantization.
|
||||
|
||||
Args:
|
||||
representative_data: Representative data used for post-training
|
||||
quantization.
|
||||
quantization_steps: Number of post-training quantization calibration steps
|
||||
to run.
|
||||
inference_input_type: Target data type of real-number input arrays.
|
||||
inference_output_type: Target data type of real-number output arrays.
|
||||
supported_ops: Set of `tf.lite.OpsSet` options, where each option
|
||||
represents a set of operators supported by the target device.
|
||||
|
||||
Returns:
|
||||
QuantizationConfig.
|
||||
"""
|
||||
return QuantizationConfig(
|
||||
representative_data=representative_data,
|
||||
quantization_steps=quantization_steps,
|
||||
inference_input_type=inference_input_type,
|
||||
inference_output_type=inference_output_type,
|
||||
supported_ops=supported_ops)
|
||||
|
||||
@classmethod
|
||||
def for_float16(cls) -> 'QuantizationConfig':
|
||||
"""Creates configuration for float16 quantization."""
|
||||
return QuantizationConfig(supported_types=[tf.float16])
|
||||
|
||||
def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter,
|
||||
**kwargs: Any) -> tf.lite.TFLiteConverter:
|
||||
"""Sets input TFLite converter with quantization configurations.
|
||||
|
||||
Args:
|
||||
converter: input tf.lite.TFLiteConverter.
|
||||
**kwargs: arguments used by ds.Dataset.gen_tf_dataset.
|
||||
|
||||
Returns:
|
||||
tf.lite.TFLiteConverter with quantization configurations.
|
||||
"""
|
||||
converter.optimizations = self.optimizations
|
||||
|
||||
if self.representative_data is not None:
|
||||
tf_ds = self.representative_data.gen_tf_dataset(
|
||||
batch_size=1, is_training=False, **kwargs)
|
||||
converter.representative_dataset = tf.lite.RepresentativeDataset(
|
||||
_get_representative_dataset_generator(tf_ds, self.quantization_steps))
|
||||
|
||||
if self.inference_input_type:
|
||||
converter.inference_input_type = self.inference_input_type
|
||||
if self.inference_output_type:
|
||||
converter.inference_output_type = self.inference_output_type
|
||||
if self.supported_ops:
|
||||
converter.target_spec.supported_ops = self.supported_ops
|
||||
if self.supported_types:
|
||||
converter.target_spec.supported_types = self.supported_types
|
||||
|
||||
if self.experimental_new_quantizer is not None:
|
||||
converter.experimental_new_quantizer = self.experimental_new_quantizer
|
||||
return converter
|
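A hedged sketch of wiring a `QuantizationConfig` into the `export_tflite` helper from `model_util.py` above, using the test utilities from this change to build a tiny model and representative dataset; the /tmp path and step count are placeholders.

```python
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util

# Illustrative only: full-integer quantization of a toy model, driven by a
# small random representative dataset.
model = test_util.build_model(input_shape=[4], num_classes=3)
representative_data = test_util.create_dataset(
    data_size=10, input_shape=[4], num_classes=3)
config = quantization.QuantizationConfig.for_int8(
    representative_data=representative_data, quantization_steps=10)
model_util.export_tflite(
    model, '/tmp/model_int8.tflite', quantization_config=config)
```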
mediapipe/model_maker/python/core/utils/quantization_test.py (new file, 108 lines)

@@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class QuantizationTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_create_dynamic_quantization_config(self):
|
||||
config = quantization.QuantizationConfig.for_dynamic()
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertIsNone(config.representative_data)
|
||||
self.assertIsNone(config.inference_input_type)
|
||||
self.assertIsNone(config.inference_output_type)
|
||||
self.assertIsNone(config.supported_ops)
|
||||
self.assertIsNone(config.supported_types)
|
||||
self.assertFalse(config.experimental_new_quantizer)
|
||||
|
||||
def test_create_int8_quantization_config(self):
|
||||
representative_data = test_util.create_dataset(
|
||||
data_size=10, input_shape=[4], num_classes=3)
|
||||
config = quantization.QuantizationConfig.for_int8(
|
||||
representative_data=representative_data)
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertEqual(config.inference_input_type, tf.uint8)
|
||||
self.assertEqual(config.inference_output_type, tf.uint8)
|
||||
self.assertEqual(config.supported_ops,
|
||||
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
|
||||
self.assertFalse(config.experimental_new_quantizer)
|
||||
|
||||
def test_set_converter_with_quantization_from_int8_config(self):
|
||||
representative_data = test_util.create_dataset(
|
||||
data_size=10, input_shape=[4], num_classes=3)
|
||||
config = quantization.QuantizationConfig.for_int8(
|
||||
representative_data=representative_data)
|
||||
model = test_util.build_model(input_shape=[4], num_classes=3)
|
||||
saved_model_dir = self.get_temp_dir()
|
||||
model.save(saved_model_dir)
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
converter = config.set_converter_with_quantization(converter=converter)
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertEqual(config.inference_input_type, tf.uint8)
|
||||
self.assertEqual(config.inference_output_type, tf.uint8)
|
||||
self.assertEqual(config.supported_ops,
|
||||
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
|
||||
tflite_model = converter.convert()
|
||||
interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8)
|
||||
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8)
|
||||
|
||||
def test_create_float16_quantization_config(self):
|
||||
config = quantization.QuantizationConfig.for_float16()
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertIsNone(config.representative_data)
|
||||
self.assertIsNone(config.inference_input_type)
|
||||
self.assertIsNone(config.inference_output_type)
|
||||
self.assertIsNone(config.supported_ops)
|
||||
self.assertEqual(config.supported_types, [tf.float16])
|
||||
self.assertFalse(config.experimental_new_quantizer)
|
||||
|
||||
def test_set_converter_with_quantization_from_float16_config(self):
|
||||
config = quantization.QuantizationConfig.for_float16()
|
||||
model = test_util.build_model(input_shape=[4], num_classes=3)
|
||||
saved_model_dir = self.get_temp_dir()
|
||||
model.save(saved_model_dir)
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
converter = config.set_converter_with_quantization(converter=converter)
|
||||
self.assertEqual(config.supported_types, [tf.float16])
|
||||
tflite_model = converter.convert()
|
||||
interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
# The input and output are expected to be set to float32 by default.
|
||||
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32)
|
||||
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='invalid_inference_input_type',
|
||||
inference_input_type=tf.uint8,
|
||||
inference_output_type=tf.int64),
|
||||
dict(
|
||||
testcase_name='invalid_inference_output_type',
|
||||
inference_input_type=tf.int64,
|
||||
inference_output_type=tf.float32))
|
||||
def test_create_quantization_config_failure(self, inference_input_type,
|
||||
inference_output_type):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = quantization.QuantizationConfig(
|
||||
inference_input_type=inference_input_type,
|
||||
inference_output_type=inference_output_type)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
mediapipe/model_maker/python/core/utils/test_util.py (new file, 94 lines)

@@ -0,0 +1,94 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test utilities for model maker."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
|
||||
|
||||
def create_dataset(data_size: int,
|
||||
input_shape: List[int],
|
||||
num_classes: int,
|
||||
max_input_value: int = 1000) -> ds.Dataset:
|
||||
"""Creates and returns a simple `Dataset` object for test."""
|
||||
features = tf.random.uniform(
|
||||
shape=[data_size] + input_shape,
|
||||
minval=0,
|
||||
maxval=max_input_value,
|
||||
dtype=tf.float32)
|
||||
|
||||
labels = tf.random.uniform(
|
||||
shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32)
|
||||
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
|
||||
dataset = ds.Dataset(tf_dataset, data_size)
|
||||
return dataset
|
||||
|
||||
|
||||
def create_random_sample(size: Union[int, List[int]],
|
||||
low: float = 0,
|
||||
high: float = 1) -> np.ndarray:
|
||||
"""Creates and returns a random sample with floating point values.
|
||||
|
||||
Args:
|
||||
size: Size of the output multi-dimensional array.
|
||||
low: Lower boundary of the output values.
|
||||
high: Higher boundary of the output values.
|
||||
|
||||
Returns:
|
||||
1D array if the size is scalar. Otherwise, N-D array whose dimension equals
|
||||
input size.
|
||||
"""
|
||||
np.random.seed(0)
|
||||
return np.random.uniform(low=low, high=high, size=size).astype(np.float32)
|
||||
|
||||
|
||||
def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
|
||||
"""Builds a simple Keras model for test."""
|
||||
inputs = tf.keras.layers.Input(shape=input_shape)
|
||||
if len(input_shape) == 3: # Image inputs.
|
||||
outputs = tf.keras.layers.GlobalAveragePooling2D()(inputs)
|
||||
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(outputs)
|
||||
elif len(input_shape) == 1: # Text inputs.
|
||||
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
|
||||
else:
|
||||
raise ValueError("Model inputs should be 2D tensor or 4D tensor.")
|
||||
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
return model
|
||||
|
||||
|
||||
def is_same_output(tflite_file: str,
|
||||
keras_model: tf.keras.Model,
|
||||
input_tensors: Union[List[tf.Tensor], tf.Tensor],
|
||||
atol: float = 1e-04) -> bool:
|
||||
"""Returns if the output of TFLite model and keras model are identical."""
|
||||
# Gets output from lite model.
|
||||
lite_runner = model_util.get_lite_runner(tflite_file)
|
||||
lite_output = lite_runner.run(input_tensors)
|
||||
|
||||
# Gets output from keras model.
|
||||
keras_output = keras_model.predict_on_batch(input_tensors)
|
||||
|
||||
return np.allclose(lite_output, keras_output, atol=atol)
|
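For illustration, the helpers in this file compose into a minimal round-trip check alongside `model_util.export_tflite`; this is a sketch only and the /tmp path is a placeholder.

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import test_util

# Build a throwaway model, export it, and confirm that the TFLite file
# reproduces the Keras outputs on a fixed random sample.
model = test_util.build_model(input_shape=[8], num_classes=3)
sample = tf.convert_to_tensor(test_util.create_random_sample(size=[1, 8]))
model_util.export_tflite(model, '/tmp/tiny.tflite')
assert test_util.is_same_output('/tmp/tiny.tflite', model, sample)
```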
mediapipe/model_maker/python/vision/BUILD (new file, 19 lines)

@@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

licenses(["notice"])
mediapipe/model_maker/python/vision/__init__.py (new file, 13 lines)

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
mediapipe/model_maker/python/vision/image_classifier/BUILD (new file, 111 lines)

@@ -0,0 +1,111 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python library rule.
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_classifier_import",
|
||||
srcs = ["__init__.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":image_classifier",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_spec",
|
||||
srcs = ["model_spec.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_spec_test",
|
||||
srcs = ["model_spec_test.py"],
|
||||
deps = [":model_spec"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "train_image_classifier_lib",
|
||||
srcs = ["train_image_classifier_lib.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_classifier",
|
||||
srcs = ["image_classifier.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_spec",
|
||||
":train_image_classifier_lib",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:image_preprocessing",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_classifier_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["image_classifier_test.py"],
|
||||
deps = [":image_classifier_import"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_classifier_test",
|
||||
srcs = ["image_classifier_test.py"],
|
||||
shard_count = 2,
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":image_classifier_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "image_classifier_demo",
|
||||
srcs = ["image_classifier_demo.py"],
|
||||
deps = [
|
||||
":image_classifier_import",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
],
|
||||
)
|
@@ -0,0 +1,25 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Image Classifier."""

from mediapipe.model_maker.python.vision.image_classifier import dataset
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import model_spec

ImageClassifier = image_classifier.ImageClassifier
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels
mediapipe/model_maker/python/vision/image_classifier/dataset.py (new file, 139 lines)

@@ -0,0 +1,139 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image classifier dataset library."""
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
|
||||
|
||||
def _load_image(path: str) -> tf.Tensor:
|
||||
"""Loads image."""
|
||||
image_raw = tf.io.read_file(path)
|
||||
image_tensor = tf.cond(
|
||||
tf.image.is_jpeg(image_raw),
|
||||
lambda: tf.image.decode_jpeg(image_raw, channels=3),
|
||||
lambda: tf.image.decode_png(image_raw, channels=3))
|
||||
return image_tensor
|
||||
|
||||
|
||||
def _create_data(
|
||||
name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
|
||||
label_names: List[str]
|
||||
) -> Optional[classification_dataset.ClassificationDataset]:
|
||||
"""Creates a Dataset object from tfds data."""
|
||||
if name not in data:
|
||||
return None
|
||||
data = data[name]
|
||||
data = data.map(lambda a: (a['image'], a['label']))
|
||||
size = info.splits[name].num_examples
|
||||
return Dataset(data, size, label_names)
|
||||
|
||||
|
||||
class Dataset(classification_dataset.ClassificationDataset):
|
||||
"""Dataset library for image classifier."""
|
||||
|
||||
@classmethod
|
||||
def from_folder(
|
||||
cls,
|
||||
dirname: str,
|
||||
shuffle: bool = True) -> classification_dataset.ClassificationDataset:
|
||||
"""Loads images and labels from the given directory.
|
||||
|
||||
Assumes the image data of the same label are in the same subdirectory.
|
||||
|
||||
Args:
|
||||
dirname: Name of the directory containing the data files.
|
||||
shuffle: boolean, if shuffle, random shuffle data.
|
||||
|
||||
Returns:
|
||||
Dataset containing images and labels and other related info.
|
||||
|
||||
Raises:
|
||||
ValueError: if the input data directory is empty.
|
||||
"""
|
||||
data_root = os.path.abspath(dirname)
|
||||
|
||||
# Assumes the image data of the same label are in the same subdirectory,
|
||||
# gets image path and label names.
|
||||
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||
all_image_size = len(all_image_paths)
|
||||
if all_image_size == 0:
|
||||
raise ValueError('Image size is zero')
|
||||
|
||||
if shuffle:
|
||||
# Random shuffle data.
|
||||
random.shuffle(all_image_paths)
|
||||
|
||||
label_names = sorted(
|
||||
name for name in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, name)))
|
||||
all_label_size = len(label_names)
|
||||
label_to_index = dict(
|
||||
(name, index) for index, name in enumerate(label_names))
|
||||
all_image_labels = [
|
||||
label_to_index[os.path.basename(os.path.dirname(path))]
|
||||
for path in all_image_paths
|
||||
]
|
||||
|
||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||
|
||||
autotune = tf.data.AUTOTUNE
|
||||
image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
|
||||
|
||||
# Loads label.
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||
tf.cast(all_image_labels, tf.int64))
|
||||
|
||||
# Creates a dataset if (image, label) pairs.
|
||||
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||
|
||||
tf.compat.v1.logging.info(
|
||||
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
||||
all_label_size, ', '.join(label_names))
|
||||
return Dataset(image_label_ds, all_image_size, label_names)
|
||||
|
||||
@classmethod
|
||||
def load_tf_dataset(
|
||||
cls, name: str
|
||||
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
|
||||
Optional[classification_dataset.ClassificationDataset],
|
||||
Optional[classification_dataset.ClassificationDataset]]:
|
||||
"""Loads data from tensorflow_datasets.
|
||||
|
||||
Args:
|
||||
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
|
||||
documentation of tfds.load for more details.
|
||||
|
||||
Returns:
|
||||
A tuple of Datasets for the train/validation/test.
|
||||
|
||||
Raises:
|
||||
ValueError: if the input tf dataset does not have train/validation/test
|
||||
labels.
|
||||
"""
|
||||
data, info = tfds.load(name, with_info=True)
|
||||
if 'label' not in info.features:
|
||||
raise ValueError('info.features need to contain \'label\' key.')
|
||||
label_names = info.features['label'].names
|
||||
|
||||
train_data = _create_data('train', data, info, label_names)
|
||||
validation_data = _create_data('validation', data, info, label_names)
|
||||
test_data = _create_data('test', data, info, label_names)
|
||||
return train_data, validation_data, test_data
|
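As a rough usage sketch of the dataset API above (not part of this commit; `flower_photos/` is a placeholder directory with one subdirectory per label):

```python
# Hypothetical usage of Dataset.from_folder and split; paths are placeholders.
from mediapipe.model_maker.python.vision import image_classifier

data = image_classifier.Dataset.from_folder('flower_photos')
train_data, rest_data = data.split(0.8)             # 80% for training
validation_data, test_data = rest_data.split(0.5)   # 10% validation, 10% test
print(data.num_classes, data.index_to_label)
```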
@@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.vision.image_classifier import dataset


def _fill_image(rgb, image_size):
  r, g, b = rgb
  return np.broadcast_to(
      np.array([[[r, g, b]]], dtype=np.uint8),
      shape=(image_size, image_size, 3))


def _write_filled_jpeg_file(path, rgb, image_size):
  tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size),
                                        'channels_last', 'jpeg')


class DatasetTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.image_path = os.path.join(self.get_temp_dir(), 'random_image_dir')
    if os.path.exists(self.image_path):
      return
    os.mkdir(self.image_path)
    for class_name in ('daisy', 'tulips'):
      class_subdir = os.path.join(self.image_path, class_name)
      os.mkdir(class_subdir)
      _write_filled_jpeg_file(
          os.path.join(class_subdir, '0.jpeg'),
          [random.uniform(0, 255) for _ in range(3)], 224)

  def test_split(self):
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
    train_data, test_data = data.split(0.5)

    self.assertLen(train_data, 2)
    for i, elem in enumerate(train_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 1])).all())
    self.assertEqual(train_data.num_classes, 2)
    self.assertEqual(train_data.index_to_label, ['pos', 'neg'])

    self.assertLen(test_data, 2)
    for i, elem in enumerate(test_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 0])).all())
    self.assertEqual(test_data.num_classes, 2)
    self.assertEqual(test_data.index_to_label, ['pos', 'neg'])

  def test_from_folder(self):
    data = dataset.Dataset.from_folder(self.image_path)

    self.assertLen(data, 2)
    self.assertEqual(data.num_classes, 2)
    self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
    for image, label in data.gen_tf_dataset():
      self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
      if label.numpy() == 0:
        raw_image_tensor = dataset._load_image(
            os.path.join(self.image_path, 'daisy', '0.jpeg'))
      else:
        raw_image_tensor = dataset._load_image(
            os.path.join(self.image_path, 'tulips', '0.jpeg'))
      self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())

  def test_from_tfds(self):
    # TODO: Remove this once tfds download error is fixed.
    self.skipTest('Temporarily skip the unittest due to tfds download error.')
    train_data, validation_data, test_data = (
        dataset.Dataset.from_tfds('beans'))
    self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(train_data, 1034)
    self.assertEqual(train_data.num_classes, 3)
    self.assertEqual(train_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(validation_data, 133)
    self.assertEqual(validation_data.num_classes, 3)
    self.assertEqual(validation_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(test_data, 128)
    self.assertEqual(test_data.num_classes, 3)
    self.assertEqual(test_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])


if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,74 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training image classification models."""

import dataclasses
import tempfile
from typing import Optional


# TODO: Expose other hyperparameters, e.g. data augmentation
# hyperparameters if requested.
@dataclasses.dataclass
class HParams:
  """The hyperparameters for training image classifiers.

  The hyperparameters include:
    # Parameters about training data.
    do_fine_tuning: If true, the base module is trained together with the
      classification layer on top.
    shuffle: A boolean controlling whether to shuffle the dataset. Defaults to
      False.

    # Parameters about training configuration
    train_epochs: Training will do this many iterations over the dataset.
    batch_size: Each training step samples a batch of this many images.
    learning_rate: The learning rate to use for gradient descent training.
    dropout_rate: The fraction of the input units to drop, used in the dropout
      layer.
    l1_regularizer: A regularizer that applies an L1 regularization penalty.
    l2_regularizer: A regularizer that applies an L2 regularization penalty.
    label_smoothing: Amount of label smoothing to apply. See tf.keras.losses
      for more details.
    do_data_augmentation: A boolean controlling whether the training dataset is
      augmented by randomly distorting input images, including random cropping,
      flipping, etc. See utils.image_preprocessing documentation for details.
    steps_per_epoch: An optional integer indicating the number of training
      steps per epoch. If not set, the training pipeline calculates the default
      steps per epoch as the training dataset size divided by batch size.
    decay_samples: Number of training samples used to calculate the decay steps
      and create the training optimizer.
    warmup_epochs: Number of epochs for a linearly increasing warmup schedule
      on the learning rate. Used to set up the warmup schedule by
      model_util.WarmUp.

    # Parameters about the saved checkpoint
    model_dir: The location of the model checkpoint files and exported model
      files.
  """
  # Parameters about training data
  do_fine_tuning: bool = False
  shuffle: bool = False
  # Parameters about training configuration
  train_epochs: int = 5
  batch_size: int = 32
  learning_rate: float = 0.005
  dropout_rate: float = 0.2
  l1_regularizer: float = 0.0
  l2_regularizer: float = 0.0001
  label_smoothing: float = 0.1
  do_data_augmentation: bool = True
  steps_per_epoch: Optional[int] = None
  decay_samples: int = 10000 * 256
  warmup_epochs: int = 2

  # Parameters about the saved checkpoint
  model_dir: str = tempfile.mkdtemp()
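A short sketch of overriding a few of these defaults (illustrative only; the export directory is a placeholder):

```python
# Illustrative HParams override; '/tmp/image_classifier' is a placeholder path.
from mediapipe.model_maker.python.vision import image_classifier

hparams = image_classifier.HParams(
    train_epochs=10,
    batch_size=16,
    do_fine_tuning=True,
    model_dir='/tmp/image_classifier')
```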
@@ -0,0 +1,172 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train image classifier model."""

from typing import Any, List, Optional

import tensorflow as tf
import tensorflow_hub as hub

from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib


class ImageClassifier(classifier.Classifier):
  """ImageClassifier for building image classification model."""

  def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
               hparams: hp.HParams):
    """Initializes ImageClassifier class.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that maps from index to label class name.
      hparams: The hyperparameters for training image classifier.
    """
    super(ImageClassifier, self).__init__(
        model_spec=model_spec,
        index_to_label=index_to_label,
        shuffle=hparams.shuffle,
        full_train=hparams.do_fine_tuning)
    self._hparams = hparams
    self._preprocess = image_preprocessing.Preprocessor(
        input_shape=self._model_spec.input_image_shape,
        num_classes=self._num_classes,
        mean_rgb=self._model_spec.mean_rgb,
        stddev_rgb=self._model_spec.stddev_rgb,
        use_augmentation=hparams.do_data_augmentation)
    self._history = None  # Training history returned from `keras_model.fit`.

  @classmethod
  def create(
      cls,
      model_spec: ms.SupportedModels,
      train_data: classification_ds.ClassificationDataset,
      validation_data: classification_ds.ClassificationDataset,
      hparams: Optional[hp.HParams] = None,
  ) -> 'ImageClassifier':
    """Creates and trains an image classifier.

    Loads data and trains the model based on data for image classification.

    Args:
      model_spec: Specification for the model.
      train_data: Training data.
      validation_data: Validation data.
      hparams: Hyperparameters for training image classifier.

    Returns:
      An instance based on ImageClassifier.
    """
    if hparams is None:
      hparams = hp.HParams()

    spec = ms.SupportedModels.get(model_spec)
    image_classifier = cls(
        model_spec=spec,
        index_to_label=train_data.index_to_label,
        hparams=hparams)

    image_classifier._create_model()

    tf.compat.v1.logging.info('Training the models...')
    image_classifier._train(
        train_data=train_data, validation_data=validation_data)

    return image_classifier

  def _train(self, train_data: classification_ds.ClassificationDataset,
             validation_data: classification_ds.ClassificationDataset):
    """Trains the model with input train_data.

    The training results are recorded by a self._history object returned by
    tf.keras.Model.fit().

    Args:
      train_data: Training data.
      validation_data: Validation data.
    """

    tf.compat.v1.logging.info('Training the models...')
    hparams = self._hparams
    if len(train_data) < hparams.batch_size:
      raise ValueError('The size of the train_data (%d) can\'t be smaller '
                       'than batch_size (%d). To solve this problem, set '
                       'the batch_size smaller or increase the size of the '
                       'train_data.' % (len(train_data), hparams.batch_size))

    train_dataset = train_data.gen_tf_dataset(
        batch_size=hparams.batch_size,
        is_training=True,
        shuffle=self._shuffle,
        preprocess=self._preprocess)
    hparams.steps_per_epoch = model_util.get_steps_per_epoch(
        steps_per_epoch=hparams.steps_per_epoch,
        batch_size=hparams.batch_size,
        train_data=train_data)
    train_dataset = train_dataset.take(count=hparams.steps_per_epoch)

    validation_dataset = validation_data.gen_tf_dataset(
        batch_size=hparams.batch_size,
        is_training=False,
        preprocess=self._preprocess)

    # Train the model.
    self._history = train_image_classifier_lib.train_model(
        model=self._model,
        hparams=hparams,
        train_ds=train_dataset,
        validation_ds=validation_dataset)

  def _create_model(self):
    """Creates the classifier model from TFHub pretrained models."""
    module_layer = hub.KerasLayer(
        handle=self._model_spec.uri, trainable=self._hparams.do_fine_tuning)

    image_size = self._model_spec.input_image_shape

    self._model = tf.keras.Sequential([
        tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
        tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
        tf.keras.layers.Dense(
            units=self._num_classes,
            activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l1_l2(
                l1=self._hparams.l1_regularizer,
                l2=self._hparams.l2_regularizer))
    ])
    print(self._model.summary())

  def export_model(
      self,
      model_name: str = 'model.tflite',
      quantization_config: Optional[quantization.QuantizationConfig] = None):
    """Converts the model to the requested formats and exports to a file.

    Args:
      model_name: File name to save the tflite model. The full export path is
        {model_dir}/{model_name}.
      quantization_config: The configuration for model quantization.
    """
    super().export_tflite(
        self._hparams.model_dir,
        model_name,
        quantization_config,
        preprocess=self._preprocess)
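Condensed from the demo script that follows, a minimal end-to-end sketch of this class (data paths and the export directory are placeholders):

```python
# Minimal create/evaluate/export sketch; mirrors the demo script below.
from mediapipe.model_maker.python.vision import image_classifier

data = image_classifier.Dataset.from_folder('flower_photos')  # placeholder path
train_data, test_data = data.split(0.9)

model = image_classifier.ImageClassifier.create(
    model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
    train_data=train_data,
    validation_data=test_data,
    hparams=image_classifier.HParams(model_dir='/tmp/image_classifier'))

loss, acc = model.evaluate(test_data)
model.export_model()  # writes model.tflite under hparams.model_dir
```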
@@ -0,0 +1,106 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making an image classifier model by MediaPipe Model Maker."""

import os

# Dependency imports

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier

FLAGS = flags.FLAGS


def define_flags() -> None:
  """Define flags for the image classifier model maker demo."""
  flags.DEFINE_string('export_dir', None,
                      'The directory to save exported files.')
  flags.DEFINE_string(
      'input_data_dir', None,
      """The directory with input training data. If the training data is not
      specified, the pipeline will download a default training dataset.""")
  flags.DEFINE_enum_class('spec',
                          image_classifier.SupportedModels.EFFICIENTNET_LITE0,
                          image_classifier.SupportedModels,
                          'The image classifier to run.')
  flags.DEFINE_enum('quantization', None, ['dynamic', 'int8', 'float16'],
                    'The quantization method to use when exporting the model.')
  flags.mark_flag_as_required('export_dir')


def download_demo_data() -> str:
  """Downloads demo data, and returns directory path."""
  data_dir = tf.keras.utils.get_file(
      fname='flower_photos.tgz',
      origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
      extract=True)
  return os.path.join(os.path.dirname(data_dir), 'flower_photos')  # folder name


def run(data_dir: str, export_dir: str,
        model_spec: image_classifier.SupportedModels,
        quantization_option: str) -> None:
  """Runs demo."""
  data = image_classifier.Dataset.from_folder(data_dir)
  train_data, rest_data = data.split(0.8)
  validation_data, test_data = rest_data.split(0.5)

  model = image_classifier.ImageClassifier.create(
      model_spec=model_spec,
      train_data=train_data,
      validation_data=validation_data,
      hparams=image_classifier.HParams(model_dir=export_dir))

  _, acc = model.evaluate(test_data)
  print('Test accuracy: %f' % acc)

  if quantization_option is None:
    quantization_config = None
  elif quantization_option == 'dynamic':
    quantization_config = quantization.QuantizationConfig.for_dynamic()
  elif quantization_option == 'int8':
    quantization_config = quantization.QuantizationConfig.for_int8(train_data)
  elif quantization_option == 'float16':
    quantization_config = quantization.QuantizationConfig.for_float16()
  else:
    raise ValueError(f'Quantization: {quantization_option} is not recognized')

  model.export_model(quantization_config=quantization_config)
  model.export_labels(export_dir)


def main(_) -> None:
  logging.set_verbosity(logging.INFO)

  if FLAGS.input_data_dir is None:
    data_dir = download_demo_data()
  else:
    data_dir = FLAGS.input_data_dir

  export_dir = os.path.expanduser(FLAGS.export_dir)
  run(data_dir=data_dir,
      export_dir=export_dir,
      model_spec=FLAGS.spec,
      quantization_option=FLAGS.quantization)


if __name__ == '__main__':
  define_flags()
  app.run(main)
@@ -0,0 +1,122 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.vision import image_classifier


def _fill_image(rgb, image_size):
  r, g, b = rgb
  return np.broadcast_to(
      np.array([[[r, g, b]]], dtype=np.uint8),
      shape=(image_size, image_size, 3))


class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
  IMAGE_SIZE = 24
  IMAGES_PER_CLASS = 2
  CMY_NAMES_AND_RGB_VALUES = (('cyan', (0, 255, 255)),
                              ('magenta', (255, 0, 255)),
                              ('yellow', (255, 255, 0)))

  def _gen(self):
    for i, (_, rgb) in enumerate(self.CMY_NAMES_AND_RGB_VALUES):
      for _ in range(self.IMAGES_PER_CLASS):
        yield (_fill_image(rgb, self.IMAGE_SIZE), i)

  def _gen_cmy_data(self):
    ds = tf.data.Dataset.from_generator(
        self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
            [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
    data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
                                    ['cyan', 'magenta', 'yellow'])
    return data

  def setUp(self):
    super(ImageClassifierTest, self).setUp()
    all_data = self._gen_cmy_data()
    # Splits data, 90% data for training, 10% for testing
    self.train_data, self.test_data = all_data.split(0.9)

  @parameterized.named_parameters(
      dict(
          testcase_name='mobilenet_v2',
          model_spec=image_classifier.SupportedModels.MOBILENET_V2,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='resnet_50',
          model_spec=image_classifier.SupportedModels.RESNET_50,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite0',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite1',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite2',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite3',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite4',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
  )
  def test_create_and_train_model(self,
                                  model_spec: image_classifier.SupportedModels,
                                  hparams: image_classifier.HParams):
    model = image_classifier.ImageClassifier.create(
        model_spec=model_spec,
        train_data=self.train_data,
        hparams=hparams,
        validation_data=self.test_data)
    self._test_accuracy(model)

  def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
    hparams = image_classifier.HParams(
        train_epochs=1, batch_size=1, shuffle=True)
    model = image_classifier.ImageClassifier.create(
        model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
        train_data=self.train_data,
        hparams=hparams,
        validation_data=self.test_data)
    self._test_accuracy(model)

  def _test_accuracy(self, model, threshold=0.0):
    _, accuracy = model.evaluate(self.test_data)
    self.assertGreaterEqual(accuracy, threshold)


if __name__ == '__main__':
  # Load compressed models from tensorflow_hub
  os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
  tf.test.main()
@@ -0,0 +1,104 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classifier model specification."""

import enum
import functools
from typing import List, Optional


class ModelSpec(object):
  """Specification of image classifier model."""

  mean_rgb = [0.0]
  stddev_rgb = [255.0]

  def __init__(self,
               uri: str,
               input_image_shape: Optional[List[int]] = None,
               name: str = ''):
    """Initializes a new instance of the `ModelSpec` class.

    Args:
      uri: str, URI to the pretrained model.
      input_image_shape: list of int, input image shape. Default: [224, 224].
      name: str, model spec name.
    """
    self.uri = uri
    self.name = name

    if input_image_shape is None:
      input_image_shape = [224, 224]
    self.input_image_shape = input_image_shape


mobilenet_v2_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
    name='mobilenet_v2')

resnet_50_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
    name='resnet_50')

efficientnet_lite0_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
    name='efficientnet_lite0')

efficientnet_lite1_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
    input_image_shape=[240, 240],
    name='efficientnet_lite1')

efficientnet_lite2_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
    input_image_shape=[260, 260],
    name='efficientnet_lite2')

efficientnet_lite3_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
    input_image_shape=[280, 280],
    name='efficientnet_lite3')

efficientnet_lite4_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
    input_image_shape=[300, 300],
    name='efficientnet_lite4')


# TODO: Document the exposed models.
@enum.unique
class SupportedModels(enum.Enum):
  """Image classifier model supported by model maker."""
  MOBILENET_V2 = mobilenet_v2_spec
  RESNET_50 = resnet_50_spec
  EFFICIENTNET_LITE0 = efficientnet_lite0_spec
  EFFICIENTNET_LITE1 = efficientnet_lite1_spec
  EFFICIENTNET_LITE2 = efficientnet_lite2_spec
  EFFICIENTNET_LITE3 = efficientnet_lite3_spec
  EFFICIENTNET_LITE4 = efficientnet_lite4_spec

  @classmethod
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
    """Gets model spec from the input enum and initializes it."""
    if spec not in cls:
      raise TypeError('Unsupported image classifier spec: {}'.format(spec))

    return spec.value()
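A small illustrative sketch of how a SupportedModels member resolves to an initialized ModelSpec (expected values match the spec test below):

```python
# Illustrative: resolving an enum member into a ModelSpec instance.
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms

spec = ms.SupportedModels.get(ms.SupportedModels.EFFICIENTNET_LITE2)
print(spec.name)               # 'efficientnet_lite2'
print(spec.input_image_shape)  # [260, 260]
```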
@@ -0,0 +1,93 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from typing import Callable, List
from absl.testing import parameterized
import tensorflow as tf

from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms


class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name='mobilenet_v2_spec_test',
          model_spec=ms.mobilenet_v2_spec,
          expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
          expected_name='mobilenet_v2',
          expected_input_image_shape=[224, 224]),
      dict(
          testcase_name='resnet_50_spec_test',
          model_spec=ms.resnet_50_spec,
          expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
          expected_name='resnet_50',
          expected_input_image_shape=[224, 224]),
      dict(
          testcase_name='efficientnet_lite0_spec_test',
          model_spec=ms.efficientnet_lite0_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
          expected_name='efficientnet_lite0',
          expected_input_image_shape=[224, 224]),
      dict(
          testcase_name='efficientnet_lite1_spec_test',
          model_spec=ms.efficientnet_lite1_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
          expected_name='efficientnet_lite1',
          expected_input_image_shape=[240, 240]),
      dict(
          testcase_name='efficientnet_lite2_spec_test',
          model_spec=ms.efficientnet_lite2_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
          expected_name='efficientnet_lite2',
          expected_input_image_shape=[260, 260]),
      dict(
          testcase_name='efficientnet_lite3_spec_test',
          model_spec=ms.efficientnet_lite3_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
          expected_name='efficientnet_lite3',
          expected_input_image_shape=[280, 280]),
      dict(
          testcase_name='efficientnet_lite4_spec_test',
          model_spec=ms.efficientnet_lite4_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
          expected_name='efficientnet_lite4',
          expected_input_image_shape=[300, 300]),
  )
  def test_predefined_spec(self, model_spec: Callable[..., ms.ModelSpec],
                           expected_uri: str, expected_name: str,
                           expected_input_image_shape: List[int]):
    model_spec_obj = model_spec()
    self.assertIsInstance(model_spec_obj, ms.ModelSpec)
    self.assertEqual(model_spec_obj.uri, expected_uri)
    self.assertEqual(model_spec_obj.name, expected_name)
    self.assertEqual(model_spec_obj.input_image_shape,
                     expected_input_image_shape)

  def test_create_spec(self):
    custom_model_spec = ms.ModelSpec(
        uri='https://custom_model',
        input_image_shape=[128, 128],
        name='custom_model')
    self.assertEqual(custom_model_spec.uri, 'https://custom_model')
    self.assertEqual(custom_model_spec.name, 'custom_model')
    self.assertEqual(custom_model_spec.input_image_shape, [128, 128])


if __name__ == '__main__':
  # Load compressed models from tensorflow_hub
  os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
  tf.test.main()
@@ -0,0 +1,103 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to train model."""

import os
from typing import List

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp


def _create_optimizer(init_lr: float, decay_steps: int,
                      warmup_steps: int) -> tf.keras.optimizers.Optimizer:
  """Creates an optimizer with learning rate schedule.

  Uses Keras CosineDecay schedule for the learning rate by default.

  Args:
    init_lr: Initial learning rate.
    decay_steps: Number of steps to decay over.
    warmup_steps: Number of steps to do warmup for.

  Returns:
    A tf.keras.optimizers.Optimizer for model training.
  """
  learning_rate_fn = tf.keras.experimental.CosineDecay(
      initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
  if warmup_steps:
    learning_rate_fn = model_util.WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=learning_rate_fn,
        warmup_steps=warmup_steps)
  optimizer = tf.keras.optimizers.RMSprop(
      learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)

  return optimizer


def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
  """Gets default callbacks."""
  summary_dir = os.path.join(model_dir, 'summaries')
  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
  # Save checkpoint every 20 epochs.
  checkpoint_path = os.path.join(model_dir, 'checkpoint')
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      checkpoint_path, save_weights_only=True, period=20)
  return [summary_callback, checkpoint_callback]


def train_model(model: tf.keras.Model, hparams: hp.HParams,
                train_ds: tf.data.Dataset,
                validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
  """Trains model with the input data and hyperparameters.

  Args:
    model: Input tf.keras.Model.
    hparams: Hyperparameters for training image classifier.
    train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
    validation_ds: tf.data.Dataset, validation data to be fed in
      tf.keras.Model.fit().

  Returns:
    The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
  """

  # Learning rate is linear to batch size.
  learning_rate = hparams.learning_rate * hparams.batch_size / 256

  # Get decay steps.
  total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
  default_decay_steps = hparams.decay_samples // hparams.batch_size
  decay_steps = max(total_training_steps, default_decay_steps)

  warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
  optimizer = _create_optimizer(
      init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)

  loss = tf.keras.losses.CategoricalCrossentropy(
      label_smoothing=hparams.label_smoothing)
  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
  callbacks = _get_default_callbacks(hparams.model_dir)

  # Train the model.
  return model.fit(
      x=train_ds,
      epochs=hparams.train_epochs,
      steps_per_epoch=hparams.steps_per_epoch,
      validation_data=validation_ds,
      callbacks=callbacks)
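To make the schedule arithmetic above concrete, a small worked sketch using the default HParams values (the steps_per_epoch value is assumed for illustration):

```python
# Worked example of the learning-rate/decay/warmup arithmetic in train_model.
batch_size = 32                 # HParams default
base_learning_rate = 0.005      # HParams default
train_epochs = 5                # HParams default
decay_samples = 10000 * 256     # HParams default
warmup_epochs = 2               # HParams default
steps_per_epoch = 100           # assumed example value

learning_rate = base_learning_rate * batch_size / 256          # 0.000625
total_training_steps = steps_per_epoch * train_epochs          # 500
default_decay_steps = decay_samples // batch_size              # 80000
decay_steps = max(total_training_steps, default_decay_steps)   # 80000
warmup_steps = warmup_epochs * steps_per_epoch                 # 200
```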
mediapipe/model_maker/requirements.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
absl-py
numpy
opencv-contrib-python
tensorflow
tensorflow-datasets
tensorflow-hub
@@ -244,7 +244,7 @@
  if ([_session canAddOutput:_depthDataOutput]) {
    [_session addOutput:_depthDataOutput];

    AVCaptureConnection* connection =
    AVCaptureConnection* __unused connection =
        [_depthDataOutput connectionWithMediaType:AVMediaTypeDepthData];

    // Set this when we have a handler.

@@ -327,7 +327,6 @@
  if (depthData.depthDataType != kCVPixelFormatType_DepthFloat32) {
    depthData = [depthData depthDataByConvertingToDepthDataType:kCVPixelFormatType_DepthFloat32];
  }
  CVPixelBufferRef depthBuffer = depthData.depthDataMap;
  [self.delegate processDepthData:depthData timestamp:timestamp fromSource:self];
}
Some files were not shown because too many files have changed in this diff.