Merge branch 'google:master' into image-segmenter-python-impl

Kinar R 2022-10-18 15:13:44 +05:30 committed by GitHub
commit 36ac0689d7
349 changed files with 21692 additions and 1716 deletions

View File

@@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
-RUN pip3 install tensorflow==2.2.0
+RUN pip3 install tensorflow
RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python

View File

@@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
}
```
## Graph Options
It is possible to specify a "graph options" protobuf for a MediaPipe graph
similar to the [`Calculator Options`](calculators.md#calculator-options)
protobuf specified for a MediaPipe calculator. These "graph options" can be
specified where a graph is invoked, and used to populate calculator options and
subgraph options within the graph.
In a CalculatorGraphConfig, graph options can be specified for a subgraph
exactly like calculator options, as shown below:
```
node {
calculator: "FlowLimiterCalculator"
input_stream: "image"
output_stream: "throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] {
max_in_flight: 1
}
}
}
node {
calculator: "FaceDetectionSubgraph"
input_stream: "IMAGE:throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {
tensor_width: 192
tensor_height: 192
}
}
}
```
In a CalculatorGraphConfig, graph options can be accepted and used to populate
calculator options, as shown below:
```
graph_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {}
}
node: {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:multi_backend_image"
node_options: {
[type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
keep_aspect_ratio: true
border_mode: BORDER_ZERO
}
}
option_value: "output_tensor_width:options/tensor_width"
option_value: "output_tensor_height:options/tensor_height"
}
node {
calculator: "InferenceCalculator"
node_options: {
[type.googleapis.com/mediapipe.InferenceCalculatorOptions] {}
}
option_value: "delegate:options/delegate"
option_value: "model_path:options/model_path"
}
```
In this example, the `FaceDetectionSubgraph` accepts the graph options protobuf
`FaceDetectionOptions`. `FaceDetectionOptions` is used to define some field
values in the calculator options `ImageToTensorCalculatorOptions` and some field
values in the subgraph options `InferenceCalculatorOptions`. The field values
are defined using the `option_value:` syntax.
In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and
`option_value:` together define the option values for a calculator such as
`ImageToTensorCalculator`. The `node_options:` field defines a set of literal
constant values using the text protobuf syntax. Each `option_value:` field
defines the value for one protobuf field using information from the enclosing
graph, specifically from field values of the graph options of the enclosing
graph. In the example above, the `option_value:`
`"output_tensor_width:options/tensor_width"` defines the field
`ImageToTensorCalculatorOptions.output_tensor_width` using the value of
`FaceDetectionOptions.tensor_width`.
The syntax of `option_value:` is similar to the syntax of `input_stream:`:
`option_value: "LHS:RHS"`. The LHS identifies a calculator option field and the
RHS identifies a graph option field. More specifically, the LHS and RHS each
consist of a series of protobuf field names, separated by '/', identifying
nested protobuf messages and fields. This is known as the "ProtoPath" syntax.
Nested messages referenced in the LHS or RHS must already be defined in the
enclosing protobuf in order to be traversed using `option_value:`.
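For example, the following minimal sketch illustrates the ProtoPath syntax with
a nested LHS path. It is not taken from the MediaPipe docs: the nested
`delegate/gpu/use_advanced_gpu_api` path and the graph option field
`use_gpu_api` are assumptions made purely for illustration.
```
node {
  calculator: "InferenceCalculator"
  node_options: {
    [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {
      # The nested "delegate" and "gpu" messages are defined here so that the
      # LHS ProtoPath below can traverse them.
      delegate { gpu {} }
    }
  }
  # LHS: '/'-separated path into InferenceCalculatorOptions.
  # RHS: "options/..." refers to a field of the enclosing graph's options
  # (a hypothetical FaceDetectionOptions.use_gpu_api field).
  option_value: "delegate/gpu/use_advanced_gpu_api:options/use_gpu_api"
}
```
When the graph is expanded, the value of the RHS graph option field is copied
into the LHS calculator option field before the node's options are finalized.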
## Cycles

<!-- TODO: add discussion of PreviousLoopbackCalculator -->

View File

@@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow
```bash
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
-  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
+  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt
```
This will open up your webcam as long as it is connected and on. Any errors

View File

@@ -209,11 +209,18 @@ cc_library(
alwayslink = 1,
)
+mediapipe_proto_library(
+name = "rotation_mode_proto",
+srcs = ["rotation_mode.proto"],
+visibility = ["//visibility:public"],
+)
mediapipe_proto_library(
name = "image_transformation_calculator_proto",
srcs = ["image_transformation_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
+":rotation_mode_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
@@ -238,6 +245,7 @@ cc_library(
}),
visibility = ["//visibility:public"],
deps = [
+":rotation_mode_cc_proto",
":image_transformation_calculator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
+#include "mediapipe/calculators/image/rotation_mode.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"

View File

@@ -16,20 +16,10 @@ syntax = "proto2";
package mediapipe;
+import "mediapipe/calculators/image/rotation_mode.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";
-// Counterclockwise rotation.
-message RotationMode {
-enum Mode {
-UNKNOWN = 0;
-ROTATION_0 = 1;
-ROTATION_90 = 2;
-ROTATION_180 = 3;
-ROTATION_270 = 4;
-}
-}
message ImageTransformationCalculatorOptions {
extend CalculatorOptions {
optional ImageTransformationCalculatorOptions ext = 251952830;

View File

@@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
// Counterclockwise rotation.
message RotationMode {
enum Mode {
UNKNOWN = 0;
ROTATION_0 = 1;
ROTATION_90 = 2;
ROTATION_180 = 3;
ROTATION_270 = 4;
}
}

View File

@@ -161,6 +161,98 @@ cc_test(
],
)
mediapipe_proto_library(
name = "bert_preprocessor_calculator_proto",
srcs = ["bert_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bert_preprocessor_calculator",
srcs = ["bert_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":bert_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "bert_preprocessor_calculator_test",
srcs = ["bert_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":bert_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
name = "regex_preprocessor_calculator_proto",
srcs = ["regex_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "regex_preprocessor_calculator",
srcs = ["regex_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":regex_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
mediapipe_proto_library(
name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"],
@@ -320,6 +412,8 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:framework_stable",
+"@org_tensorflow//tensorflow/lite:string_util",
+"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)

View File

@@ -0,0 +1,251 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kTokenizerProcessUnitIndex = 0;
constexpr absl::string_view kInputIdsTensorName = "ids";
constexpr absl::string_view kInputMasksTensorName = "mask";
constexpr absl::string_view kSegmentIdsTensorName = "segment_ids";
constexpr absl::string_view kClassifierToken = "[CLS]";
constexpr absl::string_view kSeparatorToken = "[SEP]";
// Preprocesses input text into three int32 input tensors for a BERT model using
// a tokenizer.
// The associated BERT model is expected to contain input tensors with names:
//
// Tensor | Metadata Name
// ---------------- | --------------
// IDs | "ids"
// Segment IDs | "segment_ids"
// Mask | "mask"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have names corresponding to the above
// metadata names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// This calculator is currently configured for the TextClassifier Task but it
// will eventually be generalized for other Text Tasks.
// TODO: Handle preprocessing for other Text Tasks too.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the BERT model. Used to determine the order of
// the three input Tensors for the BERT model and to extract the metadata to
// construct the tokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the BERT model:
// (1): the token ids of the tokenized input string. A classifier token
// ("[CLS]") will be prepended to the input tokens and a separator
// token ("[SEP]") will be appended to the input tokens.
// (2): the segment ids, which are all 0 for now but will have different
// values to distinguish between different sentences in the input
// text for other Text tasks.
// (3): the input mask ids, which are 1 at each of the input token indices
// and 0 elsewhere.
// The Tensors will have size equal to the max sequence length for the BERT
// model.
//
// Example:
// node {
// calculator: "BertPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.BertPreprocessorCalculatorOptions.ext] {
// bert_max_seq_len: 128
// }
// }
// }
class BertPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
// The max sequence length accepted by the BERT model.
int bert_max_seq_len_ = 2;
// Indices of the three input tensors for the BERT model. They should form the
// set {0, 1, 2}.
int input_ids_tensor_index_ = 0;
int segment_ids_tensor_index_ = 1;
int input_masks_tensor_index_ = 2;
// Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
// This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
// clips the vector of tokens to have length at most `bert_max_seq_len_`.
std::vector<std::string> TokenizeInputText(absl::string_view input_text);
// Processes the `input_tokens` to generate the three input tensors for the
// BERT model.
std::vector<Tensor> GenerateInputTensors(
const std::vector<std::string>& input_tokens);
};
absl::Status BertPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
RET_CHECK_GE(options.bert_max_seq_len(), 2)
<< "bert_max_seq_len must be at least 2";
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::ProcessUnit* tokenizer_metadata =
metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex);
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateTokenizerFromProcessUnit(
tokenizer_metadata, metadata_extractor));
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
input_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputIdsTensorName);
segment_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kSegmentIdsTensorName);
input_masks_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputMasksTensorName);
absl::flat_hash_set<int> tensor_indices = {input_ids_tensor_index_,
segment_ids_tensor_index_,
input_masks_tensor_index_};
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
input_ids_tensor_index_, segment_ids_tensor_index_,
input_masks_tensor_index_));
}
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
bert_max_seq_len_ = options.bert_max_seq_len();
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
kTensorsOut(cc).Send(
GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
return absl::OkStatus();
}
std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
absl::string_view input_text) {
std::string processed_input = std::string(input_text);
absl::AsciiStrToLower(&processed_input);
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(processed_input);
// Offset by 2 to account for [CLS] and [SEP]
int input_tokens_size =
std::min(bert_max_seq_len_,
static_cast<int>(tokenizer_result.subwords.size()) + 2);
std::vector<std::string> input_tokens;
input_tokens.reserve(input_tokens_size);
input_tokens.push_back(std::string(kClassifierToken));
for (int i = 0; i < input_tokens_size - 2; ++i) {
input_tokens.push_back(std::move(tokenizer_result.subwords[i]));
}
input_tokens.push_back(std::string(kSeparatorToken));
return input_tokens;
}
std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
const std::vector<std::string>& input_tokens) {
std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
// Convert tokens back into ids and set mask
for (int i = 0; i < input_tokens.size(); ++i) {
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_masks[i] = 1;
}
// |<--------bert_max_seq_len_--------->|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForBert);
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
}
std::memcpy(input_tensors[input_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_ids.data(), input_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[segment_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
segment_ids.data(), segment_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[input_masks_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_masks.data(), input_masks.size() * sizeof(int32_t));
return input_tensors;
}
MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@@ -13,22 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
-#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
-#include "tensorflow/lite/kernels/register.h"
-namespace mediapipe {
-namespace tasks {
-namespace vision {
-class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver {
-public:
-HandDetectorOpResolver();
-HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete;
-};
-}  // namespace vision
-}  // namespace tasks
-}  // namespace mediapipe
-#endif  // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
+syntax = "proto2";
+package mediapipe;
+import "mediapipe/framework/calculator.proto";
+message BertPreprocessorCalculatorOptions {
+extend mediapipe.CalculatorOptions {
+optional BertPreprocessorCalculatorOptions ext = 462509271;
+}
+// The maximum input sequence length for the calculator's BERT model.
+optional int32 bert_max_seq_len = 1;
+}

View File

@@ -0,0 +1,154 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kBertMaxSeqLen = 128;
constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/bert_text_classifier.tflite";
absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
absl::string_view text, absl::string_view model_path) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "BertPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
options {
[mediapipe.BertPreprocessorCalculatorOptions.ext] {
bert_max_seq_len: $0
}
}
}
)",
kBertMaxSeqLen));
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer = tasks::core::LoadBinaryContent(model_path.data());
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != kNumInputTensorsForBert) {
return absl::InvalidArgumentError(
absl::Substitute("tensor_vec has size $0, expected $1",
tensor_vec.size(), kNumInputTensorsForBert));
}
std::vector<std::vector<int>> results;
for (int i = 0; i < kNumInputTensorsForBert; i++) {
const Tensor& tensor = tensor_vec[i];
if (tensor.element_type() != Tensor::ElementType::kInt32) {
return absl::InvalidArgumentError("Expected tensor element type kInt32");
}
auto* buffer = tensor.GetCpuReadView().buffer<int>();
std::vector<int> buffer_view(buffer, buffer + kBertMaxSeqLen);
results.push_back(buffer_view);
}
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
return results;
}
TEST(BertPreprocessorCalculatorTest, TextClassifierWithBertModel) {
std::vector<std::vector<int>> expected_result = {
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 102}};
// segment_ids
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
// input_masks
expected_result.push_back(std::vector(expected_result[0].size(), 1));
expected_result[2].resize(kBertMaxSeqLen);
// padding input_ids
expected_result[0].resize(kBertMaxSeqLen);
MP_ASSERT_OK_AND_ASSIGN(
std::vector<std::vector<int>> processed_tensor_values,
RunBertPreprocessorCalculator(
"it's a charming and often affecting journey", kTestModelPath));
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
TEST(BertPreprocessorCalculatorTest, LongInput) {
std::stringstream long_input;
long_input
<< "it's a charming and often affecting journey and this is a long";
for (int i = 0; i < kBertMaxSeqLen; ++i) {
long_input << " long";
}
long_input << " movie review";
std::vector<std::vector<int>> expected_result = {
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1998, 2023,
2003, 1037}};
// "long" id
expected_result[0].resize(kBertMaxSeqLen - 1, 2146);
// "[SEP]" id
expected_result[0].push_back(102);
// segment_ids
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
// input_masks
expected_result.push_back(std::vector(kBertMaxSeqLen, 1));
MP_ASSERT_OK_AND_ASSIGN(
std::vector<std::vector<int>> processed_tensor_values,
RunBertPreprocessorCalculator(long_input.str(), kTestModelPath));
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
} // namespace
} // namespace mediapipe

View File

@@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node {
}
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
-const Size size{image->width(), image->height()};
-RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
+RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
                                      options_.output_tensor_height(),
                                      options_.keep_aspect_ratio(), &roi));
@@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node {
}
if (kOutMatrix(cc).IsConnected()) {
std::array<float, 16> matrix;
-GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
+GetRotatedSubRectToRectTransformMatrix(roi, image->width(), image->height(),
                                       /*flip_horizontaly=*/false, &matrix);
kOutMatrix(cc).Send(std::move(matrix));
}
// Lazy initialization of the GPU or CPU converter.
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
-ASSIGN_OR_RETURN(Tensor tensor,
-                 (image->UsesGpu() ? gpu_converter_ : cpu_converter_)
-                     ->Convert(*image, roi, {output_width_, output_height_},
-                               range_min_, range_max_));
+Tensor::ElementType output_tensor_type = GetOutputTensorType(image->UsesGpu());
+Tensor tensor(output_tensor_type,
+              {1, output_height_, output_width_, GetNumOutputChannels(*image)});
+MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
+                       ->Convert(*image, roi, range_min_, range_max_,
+                                 /*tensor_buffer_offset=*/0, tensor));
auto result = std::make_unique<std::vector<Tensor>>();
result->push_back(std::move(tensor));
@@ -292,7 +295,8 @@ class ImageToTensorCalculator : public Node {
}
}
-Tensor::ElementType GetOutputTensorType() {
+Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
+  if (!uses_gpu) {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
@@ -302,6 +306,21 @@ class ImageToTensorCalculator : public Node {
return Tensor::ElementType::kUInt8;
}
}
+// Always use float32 when GPU is enabled.
+return Tensor::ElementType::kFloat32;
+}
+
+int GetNumOutputChannels(const Image& image) {
+#if !MEDIAPIPE_DISABLE_GPU
+#if MEDIAPIPE_METAL_ENABLED
+if (image.UsesGpu()) {
+return 4;
+}
+#endif  // MEDIAPIPE_METAL_ENABLED
+#endif  // !MEDIAPIPE_DISABLE_GPU
+// All of the processors except for Metal expect 3 channels.
+return 3;
+}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
CalculatorContext* cc) {
@@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node {
#if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(
    cpu_converter_,
-    CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
+    CreateOpenCvConverter(cc, GetBorderMode(),
+                          GetOutputTensorType(/*uses_gpu=*/false)));
#else
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
              "MEDIAPIPE_DISABLE_OPENCV is defined.";

View File

@@ -42,13 +42,16 @@ class ImageToTensorConverter {
// @image contains image to extract from.
// @roi describes region of interest within the image to extract (absolute
// values).
-// @output_dims dimensions of output tensor.
// @range_min/max describes the output tensor range image pixels should be
// converted to.
-virtual absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                                       const RotatedRect& roi,
-                                       const Size& output_dims,
-                                       float range_min, float range_max) = 0;
+// @tensor_buffer_offset an integer representing the offset of the tensor
+// buffer the result should be written to.
+// @output_tensor a tensor with a pre-defined shape. "Convert" is
+// responsible for populating the content of the output tensor.
+virtual absl::Status Convert(const mediapipe::Image& input,
+                             const RotatedRect& roi, float range_min,
+                             float range_max, int tensor_buffer_offset,
+                             Tensor& output_tensor) = 0;
};
}  // namespace mediapipe

View File

@@ -264,10 +264,10 @@ class GlProcessor : public ImageToTensorConverter {
});
}
-absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                               const RotatedRect& roi, const Size& output_dims,
-                               float range_min, float range_max) override {
+absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
+                     float range_min, float range_max,
+                     int tensor_buffer_offset, Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@@ -275,46 +275,46 @@ class GlProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
-constexpr int kNumChannels = 3;
-Tensor tensor(Tensor::ElementType::kFloat32,
-              {1, output_dims.height, output_dims.width, kNumChannels});
-MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
-    [this, &tensor, &input, &roi, &output_dims, range_min,
-     range_max]() -> absl::Status {
+const auto& output_shape = output_tensor.shape();
+MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
+MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
+    [this, &output_tensor, &input, &roi, &output_shape, range_min,
+     range_max, tensor_buffer_offset]() -> absl::Status {
constexpr int kRgbaNumChannels = 4;
auto source_texture = gl_helper_.CreateSourceTexture(input);
tflite::gpu::gl::GlTexture input_texture(
    GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
    source_texture.width() * source_texture.height() * kRgbaNumChannels *
        sizeof(uint8_t),
    /*layer=*/0,
    /*owned=*/false);
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(auto transform,
                 GetValueRangeTransformation(kInputImageRangeMin,
                                             kInputImageRangeMax,
                                             range_min, range_max));
-auto buffer_view = tensor.GetOpenGlBufferWriteView();
-tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
-                                 buffer_view.name(), tensor.bytes(),
-                                 /*offset=*/0,
-                                 /*has_ownership=*/false);
+const int output_size = output_tensor.bytes() / output_shape.dims[0];
+auto buffer_view = output_tensor.GetOpenGlBufferWriteView();
+tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
                                 buffer_view.name(), output_size,
+                                 /*offset=*/tensor_buffer_offset,
+                                 /*has_ownership=*/false);
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
    input_texture,
    tflite::gpu::HW(source_texture.height(), source_texture.width()), roi,
    /*flip_horizontaly=*/false, transform.scale, transform.offset,
-   tflite::gpu::HW(output_dims.height, output_dims.width),
+   tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
    command_queue_.get(), &output));
return absl::OkStatus();
}));
-return tensor;
+return absl::OkStatus();
}
~GlProcessor() override {
@@ -326,6 +326,17 @@ class GlProcessor : public ImageToTensorConverter {
}
private:
+absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
+  RET_CHECK_EQ(output_shape.dims.size(), 4)
+      << "Wrong output dims size: " << output_shape.dims.size();
+  RET_CHECK_EQ(output_shape.dims[0], 1)
+      << "Handling batch dimension not equal to 1 is not implemented in this "
+         "converter.";
+  RET_CHECK_EQ(output_shape.dims[3], 3)
+      << "Wrong output channel: " << output_shape.dims[3];
+  return absl::OkStatus();
+}
std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
std::unique_ptr<SubRectExtractorGl> extractor_;
mediapipe::GlCalculatorHelper gl_helper_;

View File

@@ -168,10 +168,10 @@ class GlProcessor : public ImageToTensorConverter {
});
}
-absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                               const RotatedRect& roi, const Size& output_dims,
-                               float range_min, float range_max) override {
+absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
+                     float range_min, float range_max,
+                     int tensor_buffer_offset, Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@@ -179,15 +179,15 @@ class GlProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
-constexpr int kNumChannels = 3;
-Tensor tensor(Tensor::ElementType::kFloat32,
-              Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels});
-MP_RETURN_IF_ERROR(
-    gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims,
-                               range_min, range_max]() -> absl::Status {
+// TODO: support tensor_buffer_offset > 0 scenario.
+RET_CHECK_EQ(tensor_buffer_offset, 0)
+    << "The non-zero tensor_buffer_offset input is not supported yet.";
+const auto& output_shape = output_tensor.shape();
+MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
+MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
+    [this, &output_tensor, &input, &roi, &output_shape, range_min,
+     range_max]() -> absl::Status {
auto input_texture = gl_helper_.CreateSourceTexture(input);
constexpr float kInputImageRangeMin = 0.0f;
@@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter {
GetValueRangeTransformation(kInputImageRangeMin,
                            kInputImageRangeMax,
                            range_min, range_max));
-auto tensor_view = tensor.GetOpenGlTexture2dWriteView();
+auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
                                  /*flip_horizontaly=*/false,
                                  transform.scale, transform.offset,
-                                 output_dims, &tensor_view));
+                                 output_shape, &tensor_view));
return absl::OkStatus();
}));
-return tensor;
+return absl::OkStatus();
}
absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
                            const RotatedRect& sub_rect,
                            bool flip_horizontaly, float alpha, float beta,
-                           const Size& output_dims,
+                           const Tensor::Shape& output_shape,
                            Tensor::OpenGlTexture2dView* output) {
+const int output_height = output_shape.dims[1];
+const int output_width = output_shape.dims[2];
std::array<float, 16> transform_mat;
glDisable(GL_DEPTH_TEST);
glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
-glViewport(0, 0, output_dims.width, output_dims.height);
+glViewport(0, 0, output_width, output_height);
glActiveTexture(GL_TEXTURE0);
glBindTexture(GL_TEXTURE_2D, output->name());
@@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter {
}
private:
+absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
+  RET_CHECK_EQ(output_shape.dims.size(), 4)
+      << "Wrong output dims size: " << output_shape.dims.size();
+  RET_CHECK_EQ(output_shape.dims[0], 1)
+      << "Handling batch dimension not equal to 1 is not implemented in this "
+         "converter.";
+  RET_CHECK_EQ(output_shape.dims[3], 3)
+      << "Wrong output channel: " << output_shape.dims[3];
+  return absl::OkStatus();
+}
mediapipe::GlCalculatorHelper gl_helper_;
bool use_custom_zero_border_ = false;
BorderMode border_mode_ = BorderMode::kReplicate;

View File

@@ -262,7 +262,6 @@ class SubRectExtractorMetal {
RET_CHECK(pipeline_state != nil);
std::string output_type_def;
-MTLPixelFormat pixel_format;
switch (output_format) {
case OutputFormat::kF16C4:
output_type_def = R"(
@@ -348,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter {
return absl::OkStatus();
}
-absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                               const RotatedRect& roi, const Size& output_dims,
-                               float range_min, float range_max) override {
+absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
+                     float range_min, float range_max,
+                     int tensor_buffer_offset, Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@@ -359,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
+RET_CHECK_EQ(tensor_buffer_offset, 0)
+    << "The non-zero tensor_buffer_offset input is not supported yet.";
+const auto& output_shape = output_tensor.shape();
+MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@autoreleasepool {
id<MTLTexture> texture =
    [metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];
-constexpr int kNumChannels = 4;
-Tensor tensor(Tensor::ElementType::kFloat32,
-              Tensor::Shape{1, output_dims.height, output_dims.width,
-                            kNumChannels});
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(
@@ -377,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter {
    range_min, range_max));
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
-const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer);
+const auto& buffer_view = output_tensor.GetMtlBufferWriteView(command_buffer);
MP_RETURN_IF_ERROR(extractor_->Execute(
    texture, roi,
    /*flip_horizontaly=*/false, transform.scale, transform.offset,
-   tflite::gpu::HW(output_dims.height, output_dims.width),
+   tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
    command_buffer, buffer_view.buffer()));
[command_buffer commit];
-return tensor;
+return absl::OkStatus();
}
}
private:
+absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
+  RET_CHECK_EQ(output_shape.dims.size(), 4)
+      << "Wrong output dims size: " << output_shape.dims.size();
+  RET_CHECK_EQ(output_shape.dims[0], 1)
+      << "Handling batch dimension not equal to 1 is not implemented in this "
+         "converter.";
+  RET_CHECK_EQ(output_shape.dims[3], 4)
+      << "Wrong output channel: " << output_shape.dims[3];
+  return absl::OkStatus();
+}
MPPMetalHelper* metal_helper_ = nil;
std::unique_ptr<SubRectExtractorMetal> extractor_;
};

View File

@@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter {
}
}
-absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                               const RotatedRect& roi, const Size& output_dims,
-                               float range_min, float range_max) override {
+absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
+                     float range_min, float range_max,
+                     int tensor_buffer_offset, Tensor& output_tensor) override {
if (input.image_format() != mediapipe::ImageFormat::SRGB &&
input.image_format() != mediapipe::ImageFormat::SRGBA) {
return InvalidArgumentError(
    absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
                 static_cast<uint32_t>(input.image_format())));
}
-auto src = mediapipe::formats::MatView(&input);
-constexpr int kNumChannels = 3;
-Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height,
-                                          output_dims.width, kNumChannels});
-auto buffer_view = tensor.GetCpuWriteView();
+// TODO: Remove the check once tensor_buffer_offset > 0 is
+// supported.
+RET_CHECK_EQ(tensor_buffer_offset, 0)
+    << "The non-zero tensor_buffer_offset input is not supported yet.";
+const auto& output_shape = output_tensor.shape();
+MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
+const int output_height = output_shape.dims[1];
+const int output_width = output_shape.dims[2];
+const int output_channels = output_shape.dims[3];
+auto buffer_view = output_tensor.GetCpuWriteView();
cv::Mat dst;
switch (tensor_type_) {
case Tensor::ElementType::kInt8:
-dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
+dst = cv::Mat(output_height, output_width, mat_type_,
              buffer_view.buffer<int8>());
break;
case Tensor::ElementType::kFloat32:
-dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
+dst = cv::Mat(output_height, output_width, mat_type_,
              buffer_view.buffer<float>());
break;
case Tensor::ElementType::kUInt8:
-dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
+dst = cv::Mat(output_height, output_width, mat_type_,
              buffer_view.buffer<uint8>());
break;
default:
@@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
cv::Mat src_points;
cv::boxPoints(rotated_rect, src_points);
-const float dst_width = output_dims.width;
-const float dst_height = output_dims.height;
+const float dst_width = output_width;
+const float dst_height = output_height;
/* clang-format off */
float dst_corners[8] = {0.0f, dst_height,
                        0.0f, 0.0f,
@@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
                        dst_width, dst_height};
/* clang-format on */
+auto src = mediapipe::formats::MatView(&input);
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
cv::Mat projection_matrix =
    cv::getPerspectiveTransform(src_points, dst_points);
@@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
    /*flags=*/cv::INTER_LINEAR,
    /*borderMode=*/border_mode_);
-if (transformed.channels() > kNumChannels) {
+if (transformed.channels() > output_channels) {
cv::Mat proper_channels_mat;
cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
transformed = proper_channels_mat;
@@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter {
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
                            range_min, range_max));
transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
-return tensor;
+return absl::OkStatus();
}
private:
+absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
+  RET_CHECK_EQ(output_shape.dims.size(), 4)
+      << "Wrong output dims size: " << output_shape.dims.size();
+  RET_CHECK_EQ(output_shape.dims[0], 1)
+      << "Handling batch dimension not equal to 1 is not implemented in this "
+         "converter.";
+  RET_CHECK_EQ(output_shape.dims[3], 3)
+      << "Wrong output channel: " << output_shape.dims[3];
+  return absl::OkStatus();
+}
enum cv::BorderTypes border_mode_;
Tensor::ElementType tensor_type_;
int mat_type_;

View File

@@ -224,9 +224,6 @@ absl::Status InferenceCalculatorMetalImpl::InitInterpreter(
void InferenceCalculatorMetalImpl::AddDelegate(
    CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
-const auto& calculator_opts =
-    cc->Options<mediapipe::InferenceCalculatorOptions>();
// Configure and create the delegate.
TFLGpuDelegateOptions options;
// `enable_quantization` enables the run of sparse models i.e. the models with

View File

@@ -21,8 +21,10 @@
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
+#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
+#include "tensorflow/lite/string_util.h"
namespace mediapipe {
@@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
}
+template <>
+void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
+                                         tflite::Interpreter* interpreter,
+                                         int input_tensor_index) {
+  const char* input_tensor_buffer =
+      input_tensor.GetCpuReadView().buffer<char>();
+  tflite::DynamicBuffer dynamic_buffer;
+  dynamic_buffer.AddString(input_tensor_buffer,
+                           input_tensor.shape().num_elements());
+  dynamic_buffer.WriteToTensorAsVector(
+      interpreter->tensor(interpreter->inputs()[input_tensor_index]));
+}
template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
                                     int output_tensor_index,
@@ -87,12 +102,12 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
break;
}
case TfLiteType::kTfLiteUInt8: {
-CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
-                                     interpreter_.get(), i);
+CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
+                                       interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteInt8: {
-CopyTensorBufferToInterpreter<int8>(input_tensors[i],
-                                    interpreter_.get(), i);
+CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
+                                      interpreter_.get(), i);
break;
}
@@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
interpreter_.get(), i);
break;
}
+case TfLiteType::kTfLiteString: {
+  CopyTensorBufferToInterpreter<char>(input_tensors[i],
+                                      interpreter_.get(), i);
+  break;
+}
+case TfLiteType::kTfLiteBool:
+  // No current use-case for copying MediaPipe Tensors with bool type to
+  // TfLiteTensors.
default:
return absl::InvalidArgumentError(
    absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
                                         &output_tensors.back());
break;
+case TfLiteType::kTfLiteBool:
+  output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
+                              Tensor::QuantizationParameters{1.0f, 0});
+  CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
+                                        &output_tensors.back());
+  break;
+case TfLiteType::kTfLiteString:
+  // No current use-case for copying TfLiteTensors with string type to
+  // MediaPipe Tensors.
default:
return absl::InvalidArgumentError(
    absl::StrCat("Unsupported output tensor type:",

View File

@@ -0,0 +1,174 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
// Preprocesses input text into one int32 input tensor for a text model using
// a RegexTokenizer.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the text model. Used to extract the metadata
// to construct the RegexTokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor which is the text model's input tensor.
// Depending on the tokenizer metadata, the tensor may start with
// the id of the tokenizer's <START> token. The following tensor values will
// be the ids of the tokens of the input text. Any out-of-vocab tokens will
// have the id of the <UNKNOWN> token. The tensor will be padded with the
// <PAD> token id to have size equal to the max sequence length for the text
// model.
//
// Example:
// node {
// calculator: "RegexPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
// max_seq_len: 256
// }
// }
// }
class RegexPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::RegexTokenizer> tokenizer_;
// The max sequence length accepted by the text model.
int max_seq_len_ = 0;
};
absl::Status RegexPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required";
RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive";
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::TensorMetadata* tensor_metadata =
metadata_extractor->GetInputTensorMetadata(0);
if (tensor_metadata == nullptr) {
return absl::InvalidArgumentError("No tensor metadata found");
}
ASSIGN_OR_RETURN(
const auto* tokenizer_metadata,
metadata_extractor->FindFirstProcessUnit(
*tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions));
if (tokenizer_metadata == nullptr) {
return absl::InvalidArgumentError("No tokenizer metadata found");
}
const tflite::RegexTokenizerOptions* regex_tokenizer_options =
tokenizer_metadata->options_as<tflite::RegexTokenizerOptions>();
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateRegexTokenizerFromOptions(
regex_tokenizer_options, metadata_extractor));
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
max_seq_len_ = options.max_seq_len();
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(kTextIn(cc).Get());
int unknown_token_id = 0;
tokenizer_->GetUnknownToken(&unknown_token_id);
int pad_token_id = 0;
tokenizer_->GetPadToken(&pad_token_id);
std::vector<int> input_tokens(max_seq_len_, pad_token_id);
int start_token_id = 0;
int input_token_index = 0;
if (tokenizer_->GetStartToken(&start_token_id)) {
input_tokens[0] = start_token_id;
input_token_index = 1;
}
for (int i = 0; (i < tokenizer_result.subwords.size()) &&
(input_token_index < max_seq_len_);
++i, ++input_token_index) {
const std::string& token = tokenizer_result.subwords[i];
int token_id = 0;
if (tokenizer_->LookupId(token, &token_id)) {
input_tokens[input_token_index] = token_id;
} else {
input_tokens[input_token_index] = unknown_token_id;
}
}
// |<-------sentence_length-------->|
// input_tensor <START>, t1, t2... <PAD>, <PAD>...
// <START> is optional; t1, t2, ... are replaced by <UNKNOWN> if they are
// not found in the tokenizer vocab.
std::vector<Tensor> result;
result.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
kTensorsOut(cc).Send(std::move(result));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe
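Process() above packs token ids into a fixed-length int32 tensor: an optional <START> id, then the looked-up token ids (with out-of-vocab tokens mapped to <UNKNOWN>), padded with the <PAD> id up to max_seq_len. A self-contained sketch of the start/truncate/pad step, detached from MediaPipe and with illustrative names (vocabulary lookup is assumed to have happened already):

#include <vector>

// Sketch only: mirrors the start-token, truncation, and padding behavior of
// the Process() method above for already-resolved token ids.
std::vector<int> PackTokenIds(const std::vector<int>& token_ids,
                              int max_seq_len, int pad_id, int start_id,
                              bool has_start_token) {
  std::vector<int> packed(max_seq_len, pad_id);
  int index = 0;
  if (has_start_token && index < max_seq_len) packed[index++] = start_id;
  for (int i = 0;
       i < static_cast<int>(token_ids.size()) && index < max_seq_len;
       ++i, ++index) {
    packed[index] = token_ids[i];
  }
  return packed;
}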
View File
@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message RegexPreprocessorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional RegexPreprocessorCalculatorOptions ext = 463716697;
}
// The maximum input sequence length for the calculator's text model.
optional int32 max_seq_len = 1;
}
View File
@ -296,7 +296,6 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) {
output_tensors->emplace_back(Tensor::ElementType::kFloat32, output_tensors->emplace_back(Tensor::ElementType::kFloat32,
Tensor::Shape{1, height, width, channels}); Tensor::Shape{1, height, width, channels});
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
id<MTLDevice> device = gpu_helper_.mtlDevice;
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer]; id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TensorConverterCalculatorConvert"; command_buffer.label = @"TensorConverterCalculatorConvert";
id<MTLComputeCommandEncoder> compute_encoder = id<MTLComputeCommandEncoder> compute_encoder =
View File
@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
case Tensor::ElementType::kInt8: case Tensor::ElementType::kInt8:
Dequantize<int8>(input_tensor, &output_tensors->back()); Dequantize<int8>(input_tensor, &output_tensors->back());
break; break;
case Tensor::ElementType::kBool:
Dequantize<bool>(input_tensor, &output_tensors->back());
break;
default: default:
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"Unsupported input tensor type: ", input_tensor.element_type())); "Unsupported input tensor type: ", input_tensor.element_type()));
View File
@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
ValidateResult(GetOutput(), {-1.007874, 0, 1}); ValidateResult(GetOutput(), {-1.007874, 0, 1});
} }
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
std::vector<bool> tensor = {true, false, true};
PushTensor(Tensor::ElementType::kBool, tensor,
Tensor::QuantizationParameters{1.0f, 0});
MP_ASSERT_OK(runner_.Run());
ValidateResult(GetOutput(), {1, 0, 1});
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe
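The new bool branch in the calculator reuses the same dequantization path as the integer types. As a hedged sketch (not the calculator's actual template), the mapping being exercised is output = scale * (input - zero_point); with QuantizationParameters{1.0f, 0}, the bool inputs {true, false, true} come out as {1, 0, 1}, which is what the test above checks.

// Sketch only: the affine dequantization assumed above.
template <typename T>
void DequantizeSketch(const T* input, int num_elements, float scale,
                      int zero_point, float* output) {
  for (int i = 0; i < num_elements; ++i) {
    output[i] = scale * (static_cast<float>(input[i]) - zero_point);
  }
}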
View File
@ -532,7 +532,6 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
detection_classes.data(), detection_classes.data(),
output_detections)); output_detections));
#elif MEDIAPIPE_METAL_ENABLED #elif MEDIAPIPE_METAL_ENABLED
id<MTLDevice> device = gpu_helper_.mtlDevice;
if (!anchors_init_) { if (!anchors_init_) {
if (input_tensors.size() == kNumInputTensorsWithAnchors) { if (input_tensors.size() == kNumInputTensorsWithAnchors) {
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
View File
@ -0,0 +1,167 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <cstring>
#include <string>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr absl::string_view kQueryTextMetadataName = "inp_text";
constexpr absl::string_view kResponseContextMetadataName = "res_context";
constexpr absl::string_view kResponseTextMetadataName = "res_text";
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
// Preprocesses input text into three kTfLiteString input tensors for a
// Universal Sentence Encoder (USE) model.
//
// The associated USE model is expected to contain input tensors with metadata
// names:
//
// Tensor | Metadata Name
// ---------------- | ------------------
// Query text | "inp_text"
// Response context | "res_context"
// Response text | "res_text"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have metadata names corresponding to the
// above names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// Inputs:
// TEXT - std::string
// The text to be embedded.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the USE model. Used to determine the order of
// the three input Tensors for the USE model.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the USE model. The tensors
// fit a question-answering setting and store a query text, a response
// context, and a response text. This calculator will just be preprocessing
// a single input text that will be stored in the response text tensor. The
// query text and response context tensors will store empty strings.
//
// Example:
// node {
// calculator: "UniversalSentenceEncoderPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// }
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
// Indices of the three input tensors for the USE model. They should form the
// set {0, 1, 2}.
int query_text_tensor_index_ = 0;
int response_context_tensor_index_ = 1;
int response_text_tensor_index_ = 2;
// Tensor shapes for the model's input tensors.
// The query text and response context tensors will only hold the empty
// string, so their tensors will have shape [0], but the Universal Sentence
// Encoder model's input signature requires them to be present. The response
// text tensor will store the embedding text and have shape
// [embedding_text_len].
std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_;
};
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
query_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kQueryTextMetadataName);
response_context_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseContextMetadataName);
response_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseTextMetadataName);
absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
{query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_});
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_));
}
return absl::OkStatus();
}
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
CalculatorContext* cc) {
absl::string_view text = kTextIn(cc).Get();
const int text_len = static_cast<int>(text.length());
tensor_shapes_[response_text_tensor_index_] = text_len;
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
}
std::memcpy(
input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_context_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_text_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
text.data(), text_len * sizeof(char));
kTensorsOut(cc).Send(std::move(input_tensors));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe
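Each tensor produced above is a kChar tensor whose bytes are the raw string contents (empty for the query and response-context slots). A small sketch of reading such a tensor back into a std::string, mirroring what the unit test that follows does:

#include <string>

#include "mediapipe/framework/formats/tensor.h"

// Sketch only: copies a kChar tensor's CPU buffer into a string.
std::string CharTensorToString(const mediapipe::Tensor& tensor) {
  auto view = tensor.GetCpuReadView();
  return std::string(view.buffer<char>(),
                     static_cast<size_t>(tensor.shape().num_elements()));
}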
View File
@ -0,0 +1,111 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::IsOkAndHolds;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/"
"universal_sentence_encoder_qa_with_metadata.tflite";
absl::StatusOr<std::vector<std::string>>
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "UniversalSentenceEncoderPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer =
tasks::core::LoadBinaryContent(kTestModelPath.data());
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected $1", tensor_vec.size(),
kNumInputTensorsForUniversalSentenceEncoder));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
Tensor::ElementType::kChar));
}
std::vector<std::string> results;
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
results.push_back(
{tensor_vec[i].GetCpuReadView().buffer<char>(),
static_cast<size_t>(tensor_vec[i].shape().num_elements())});
}
return results;
}
TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
ASSERT_THAT(
RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
}
} // namespace
} // namespace mediapipe
View File
@ -331,6 +331,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
], ],
alwayslink = 1, alwayslink = 1,
View File
@ -499,7 +499,6 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
gpu_data_out_ = absl::make_unique<GPUData>(); gpu_data_out_ = absl::make_unique<GPUData>();
gpu_data_out_->elements = input.height() * input.width() * max_num_channels_; gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
const bool include_alpha = (max_num_channels_ == 4); const bool include_alpha = (max_num_channels_ == 4);
const bool single_channel = (max_num_channels_ == 1);
if (!(format == mediapipe::ImageFormat::GRAY8 || if (!(format == mediapipe::ImageFormat::GRAY8 ||
format == mediapipe::ImageFormat::SRGB || format == mediapipe::ImageFormat::SRGB ||
format == mediapipe::ImageFormat::SRGBA)) format == mediapipe::ImageFormat::SRGBA))
@ -509,6 +508,7 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED #endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
#if MEDIAPIPE_TFLITE_GL_INFERENCE #if MEDIAPIPE_TFLITE_GL_INFERENCE
const bool single_channel = (max_num_channels_ == 1);
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &include_alpha, &input, &single_channel]() -> absl::Status { [this, &include_alpha, &input, &single_channel]() -> absl::Status {
// Device memory. // Device memory.
View File
@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -81,6 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
} }
if (cc->InputSidePackets().HasTag("MODEL_FD")) { if (cc->InputSidePackets().HasTag("MODEL_FD")) {
#ifdef ABSL_HAVE_MMAP
model_packet = cc->InputSidePackets().Tag("MODEL_FD"); model_packet = cc->InputSidePackets().Tag("MODEL_FD");
const auto& model_fd = const auto& model_fd =
model_packet.Get<std::tuple<int, size_t, size_t>>(); model_packet.Get<std::tuple<int, size_t, size_t>>();
@ -89,6 +91,10 @@ class TfLiteModelCalculator : public CalculatorBase {
tflite::DefaultErrorReporter()); tflite::DefaultErrorReporter());
model = tflite::FlatBufferModel::BuildFromAllocation( model = tflite::FlatBufferModel::BuildFromAllocation(
std::move(model_allocation), tflite::DefaultErrorReporter()); std::move(model_allocation), tflite::DefaultErrorReporter());
#else
return absl::FailedPreconditionError(
    "Loading by file descriptor is not supported on this platform.");
#endif
} }
RET_CHECK(model) << "Failed to load TfLite model from blob."; RET_CHECK(model) << "Failed to load TfLite model from blob.";
View File
@ -143,9 +143,7 @@ mediapipe_proto_library(
cc_library( cc_library(
name = "packet_frequency_calculator", name = "packet_frequency_calculator",
srcs = ["packet_frequency_calculator.cc"], srcs = ["packet_frequency_calculator.cc"],
visibility = [ visibility = ["//visibility:public"],
"//visibility:public",
],
deps = [ deps = [
"//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto",
"//mediapipe/calculators/util:packet_frequency_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto",
@ -190,9 +188,7 @@ cc_test(
cc_library( cc_library(
name = "packet_latency_calculator", name = "packet_latency_calculator",
srcs = ["packet_latency_calculator.cc"], srcs = ["packet_latency_calculator.cc"],
visibility = [ visibility = ["//visibility:public"],
"//visibility:public",
],
deps = [ deps = [
"//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:latency_cc_proto",
"//mediapipe/calculators/util:packet_latency_calculator_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto",
View File
@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
text->set_left(label_left_px_); text->set_left(label_left_px_);
text->set_baseline(label_baseline_px + i * label_height_px_); text->set_baseline(label_baseline_px + i * label_height_px_);
text->set_font_face(options_.font_face()); text->set_font_face(options_.font_face());
if (options_.outline_thickness() > 0) {
text->set_outline_thickness(options_.outline_thickness());
if (options_.outline_color_size() > 0) {
*(text->mutable_outline_color()) =
options_.outline_color(i % options_.outline_color_size());
} else {
text->mutable_outline_color()->set_r(0);
text->mutable_outline_color()->set_g(0);
text->mutable_outline_color()->set_b(0);
}
}
} }
cc->Outputs() cc->Outputs()
.Tag(kRenderDataTag) .Tag(kRenderDataTag)
View File
@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions {
// Thickness for drawing the label(s). // Thickness for drawing the label(s).
optional double thickness = 2 [default = 2]; optional double thickness = 2 [default = 2];
// Color of outline around each character, if any. One per label, as with
// color attribute.
repeated Color outline_color = 12;
// Thickness of outline around each character.
optional double outline_thickness = 11;
// The font height in absolute pixels. // The font height in absolute pixels.
optional int32 font_height_px = 3 [default = 50]; optional int32 font_height_px = 3 [default = 50];
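The new fields can be set in a graph config or populated programmatically. Below is a hedged C++ sketch of the latter; the field names come from the proto above, while the include paths, the Color message location, and the concrete values are assumptions for illustration.

#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"
#include "mediapipe/util/color.pb.h"

// Sketch only: requests a 2px white outline around each label character.
mediapipe::LabelsToRenderDataCalculatorOptions MakeOutlinedLabelOptions() {
  mediapipe::LabelsToRenderDataCalculatorOptions options;
  options.set_outline_thickness(2.0);
  mediapipe::Color* outline = options.add_outline_color();
  outline->set_r(255);
  outline->set_g(255);
  outline->set_b(255);
  return options;
}

With outline_thickness set but no outline_color entries, the calculator change above falls back to a black outline.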
View File
@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context; import android.content.Context;
import android.net.Uri; import android.net.Uri;
import android.os.Bundle; import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText; import android.support.v7.widget.AppCompatEditText;
import android.util.AttributeSet; import android.util.AttributeSet;
import android.util.Log; import android.util.Log;
import android.view.inputmethod.EditorInfo; import android.view.inputmethod.EditorInfo;
View File
@ -1685,10 +1685,3 @@ cc_test(
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)
View File
@ -14,15 +14,10 @@ cc_library(
name = "builder", name = "builder",
hdrs = ["builder.h"], hdrs = ["builder.h"],
deps = [ deps = [
":const_str",
":contract",
":node",
":packet",
":port", ":port",
"//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
View File
@ -5,12 +5,7 @@
#include <type_traits> #include <type_traits>
#include "absl/container/btree_map.h" #include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
@ -112,6 +107,17 @@ class MultiPort : public Single {
std::vector<std::unique_ptr<Base>>& vec_; std::vector<std::unique_ptr<Base>>& vec_;
}; };
namespace internal_builder {
template <typename T, typename U>
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>>;
} // namespace internal_builder
template <bool IsSide, typename T = internal::Generic>
class SourceImpl;
// These classes wrap references to the underlying source/destination // These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API. // endpoints, adding type information and the user-visible API.
template <bool IsSide, typename T = internal::Generic> template <bool IsSide, typename T = internal::Generic>
@ -122,16 +128,21 @@ class DestinationImpl {
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec) explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {} : DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
explicit DestinationImpl(DestinationBase* base) : base_(*base) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
template <typename U,
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
DestinationImpl<IsSide, U> Cast() {
return DestinationImpl<IsSide, U>(&base_);
}
private:
DestinationBase& base_; DestinationBase& base_;
template <bool Source_IsSide, typename Source_T>
friend class SourceImpl;
}; };
template <bool IsSide, typename T> template <bool IsSide, typename T>
class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> {
public:
using MultiPort<DestinationImpl<IsSide, T>>::MultiPort;
};
template <bool IsSide, typename T = internal::Generic>
class SourceImpl { class SourceImpl {
public: public:
using Base = SourceBase; using Base = SourceBase;
@ -171,12 +182,8 @@ class SourceImpl {
return AddTarget(dest); return AddTarget(dest);
} }
template <typename U> template <typename U,
struct AllowCast std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>> {};
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
SourceImpl<IsSide, U> Cast() { SourceImpl<IsSide, U> Cast() {
return SourceImpl<IsSide, U>(base_); return SourceImpl<IsSide, U>(base_);
} }
@ -186,12 +193,6 @@ class SourceImpl {
SourceBase* base_; SourceBase* base_;
}; };
template <bool IsSide, typename T>
class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
public:
using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
};
// A source and a destination correspond to an output/input stream on a node, // A source and a destination correspond to an output/input stream on a node,
// and a side source and side destination correspond to an output/input side // and a side source and side destination correspond to an output/input side
// packet. // packet.
@ -201,20 +202,20 @@ class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
template <typename T = internal::Generic> template <typename T = internal::Generic>
using Source = SourceImpl<false, T>; using Source = SourceImpl<false, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSource = MultiSourceImpl<false, T>; using MultiSource = MultiPort<Source<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using SideSource = SourceImpl<true, T>; using SideSource = SourceImpl<true, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSideSource = MultiSourceImpl<true, T>; using MultiSideSource = MultiPort<SideSource<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using Destination = DestinationImpl<false, T>; using Destination = DestinationImpl<false, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using SideDestination = DestinationImpl<true, T>; using SideDestination = DestinationImpl<true, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiDestination = MultiDestinationImpl<false, T>; using MultiDestination = MultiPort<Destination<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSideDestination = MultiDestinationImpl<true, T>; using MultiSideDestination = MultiPort<SideDestination<T>>;
class NodeBase { class NodeBase {
public: public:
@ -439,8 +440,9 @@ class Graph {
// Creates a node of a specific type. Should be used for pure interfaces, // Creates a node of a specific type. Should be used for pure interfaces,
// which do not have a built-in type string. // which do not have a built-in type string.
template <class Calc> template <class Calc>
Node<Calc>& AddNode(const std::string& type) { Node<Calc>& AddNode(absl::string_view type) {
auto node = std::make_unique<Node<Calc>>(type); auto node =
std::make_unique<Node<Calc>>(std::string(type.data(), type.size()));
auto node_p = node.get(); auto node_p = node.get();
nodes_.emplace_back(std::move(node)); nodes_.emplace_back(std::move(node));
return *node_p; return *node_p;
@ -448,16 +450,18 @@ class Graph {
// Creates a generic node, with no compile-time checking of inputs and // Creates a generic node, with no compile-time checking of inputs and
// outputs. This can be used for calculators whose contract is not visible. // outputs. This can be used for calculators whose contract is not visible.
GenericNode& AddNode(const std::string& type) { GenericNode& AddNode(absl::string_view type) {
auto node = std::make_unique<GenericNode>(type); auto node =
std::make_unique<GenericNode>(std::string(type.data(), type.size()));
auto node_p = node.get(); auto node_p = node.get();
nodes_.emplace_back(std::move(node)); nodes_.emplace_back(std::move(node));
return *node_p; return *node_p;
} }
// For legacy PacketGenerators. // For legacy PacketGenerators.
PacketGenerator& AddPacketGenerator(const std::string& type) { PacketGenerator& AddPacketGenerator(absl::string_view type) {
auto node = std::make_unique<PacketGenerator>(type); auto node = std::make_unique<PacketGenerator>(
std::string(type.data(), type.size()));
auto node_p = node.get(); auto node_p = node.get();
packet_gens_.emplace_back(std::move(node)); packet_gens_.emplace_back(std::move(node));
return *node_p; return *node_p;
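Taken together, the builder changes above add Cast<U>() to destination endpoints, replace the MultiSourceImpl/MultiDestinationImpl wrappers with plain MultiPort aliases, and switch AddNode()/AddPacketGenerator() to absl::string_view. A hedged usage sketch follows; the calculator name and tags are illustrative, not a real node.

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"

// Sketch only: builds a config with an untyped node and casts the generic
// output stream before exposing it as a typed graph output.
mediapipe::CalculatorGraphConfig BuildSketchGraph() {
  mediapipe::api2::builder::Graph graph;
  auto& node = graph.AddNode("SomeCalculator");  // absl::string_view overload
  graph.In("INPUT") >> node.In("INPUT");
  node.Out("OUTPUT").Cast<double>() >> graph.Out("OUTPUT").Cast<double>();
  return graph.GetConfig();
}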
View File
@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>(); node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output.SetName("any_type_output"); any_type_output.SetName("any_type_output");
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
CalculatorGraphConfig expected = CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node { node {
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
output_stream: "ANY_OUTPUT:any_type_output" output_stream: "ANY_OUTPUT:any_type_output"
} }
input_stream: "GRAPH_ANY_INPUT:__stream_0" input_stream: "GRAPH_ANY_INPUT:__stream_0"
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
)pb"); )pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
} }
View File
@ -334,13 +334,6 @@ mediapipe_register_type(
deps = [":landmark_cc_proto"], deps = [":landmark_cc_proto"],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)
cc_library( cc_library(
name = "image", name = "image",
srcs = ["image.cc"], srcs = ["image.cc"],
@ -469,6 +462,10 @@ cc_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}), }),
defines = select({
"//mediapipe/framework:android_no_jni": ["MEDIAPIPE_NO_JNI"],
"//conditions:default": [],
}),
linkopts = select({ linkopts = select({
"//mediapipe:ios": [ "//mediapipe:ios": [
"-framework CoreVideo", "-framework CoreVideo",
View File
@ -33,10 +33,3 @@ mediapipe_proto_library(
srcs = ["rasterization.proto"], srcs = ["rasterization.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)
View File
@ -31,11 +31,12 @@
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
#import <Metal/Metal.h> #import <Metal/Metal.h>
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#ifndef MEDIAPIPE_NO_JNI
#if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) #if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#define MEDIAPIPE_TENSOR_USE_AHWB 1 #define MEDIAPIPE_TENSOR_USE_AHWB 1
#endif // __ANDROID_API__ >= 26 || #endif // __ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#endif // MEDIAPIPE_NO_JNI
#ifdef MEDIAPIPE_TENSOR_USE_AHWB #ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <android/hardware_buffer.h> #include <android/hardware_buffer.h>
@ -43,7 +44,6 @@
#include "third_party/GL/gl/include/EGL/egl.h" #include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h" #include "third_party/GL/gl/include/EGL/eglext.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB #endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_context.h"
@ -97,8 +97,8 @@ class Tensor {
kUInt8, kUInt8,
kInt8, kInt8,
kInt32, kInt32,
// TODO: Update the inference runner to handle kTfLiteString. kChar,
kChar kBool
}; };
struct Shape { struct Shape {
Shape() = default; Shape() = default;
@ -330,6 +330,8 @@ class Tensor {
return sizeof(int32_t); return sizeof(int32_t);
case ElementType::kChar: case ElementType::kChar:
return sizeof(char); return sizeof(char);
case ElementType::kBool:
return sizeof(bool);
} }
} }
int bytes() const { return shape_.num_elements() * element_size(); } int bytes() const { return shape_.num_elements() * element_size(); }
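With kBool added to ElementType and element_size(), bool tensors can be allocated and accessed through the usual CPU views. A minimal sketch (sizes and values are illustrative):

#include "mediapipe/framework/formats/tensor.h"

// Sketch only: writes a small kBool tensor on the CPU and checks its size.
void BoolTensorSketch() {
  mediapipe::Tensor tensor(mediapipe::Tensor::ElementType::kBool,
                           mediapipe::Tensor::Shape({3}));
  {
    auto view = tensor.GetCpuWriteView();
    bool* data = view.buffer<bool>();
    data[0] = true;
    data[1] = false;
    data[2] = true;
  }
  // bytes() == num_elements * sizeof(bool), matching element_size() above.
  const int total_bytes = tensor.bytes();
  (void)total_bytes;
}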
View File
@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) { if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
// EGLSync is failed. Use another synchronization method. // EGLSync is failed. Use another synchronization method.
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync. // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
glFinish(); gl_context_->Run([]() { glFinish(); });
} else if (valid_ & kValidAHardwareBuffer) { } else if (valid_ & kValidAHardwareBuffer) {
CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the " CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
"completion function to be set"; "completion function to be set";
View File
@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4}); Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char)); EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
} }
TEST(Cpu, TestMemoryAllocation) { TEST(Cpu, TestMemoryAllocation) {

View File
std::vector<TraceEventType> basic_types = { std::vector<TraceEventType> basic_types = {
{TraceEvent::UNKNOWN, "An uninitialized trace-event."}, {TraceEvent::UNKNOWN, "An uninitialized trace-event."},
{TraceEvent::OPEN, "A call to Calculator::Open.", true, true}, {TraceEvent::OPEN, "A call to Calculator::Open.", true, true},
{TraceEvent::PROCESS, "A call to Calculator::Open.", true, true}, {TraceEvent::PROCESS, "A call to Calculator::Process.", true, true},
{TraceEvent::CLOSE, "A call to Calculator::Close.", true, true}, {TraceEvent::CLOSE, "A call to Calculator::Close.", true, true},
{TraceEvent::NOT_READY, "A calculator cannot process packets yet."}, {TraceEvent::NOT_READY, "A calculator cannot process packets yet."},

View File
name = "executor_util", name = "executor_util",
srcs = ["executor_util.cc"], srcs = ["executor_util.cc"],
hdrs = ["executor_util.h"], hdrs = ["executor_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto",

View File
alwayslink = 1, alwayslink = 1,
) )
MIN_IOS_VERSION = "9.0" # For thread_local. MIN_IOS_VERSION = "11.0"
test_suite( test_suite(
name = "ios", name = "ios",

View File
- (CVMetalTextureCacheRef)mtlTextureCache { - (CVMetalTextureCacheRef)mtlTextureCache {
@synchronized(self) { @synchronized(self) {
if (!_mtlTextureCache) { if (!_mtlTextureCache) {
CVReturn err = CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); CVReturn __unused err =
CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err);
// TODO: register and flush metal caches too. // TODO: register and flush metal caches too.
} }
View File
@ -38,10 +38,6 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
static void EglThreadExitCallback(void* key_value) { static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__)
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
EGL_NO_CONTEXT);
#else
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
// parameter for eglMakeCurrent. This behavior is not portable to all EGL // parameter for eglMakeCurrent. This behavior is not portable to all EGL
// implementations, and should be considered as an undocumented vendor // implementations, and should be considered as an undocumented vendor
@ -49,7 +45,6 @@ static void EglThreadExitCallback(void* key_value) {
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
EGL_NO_SURFACE, EGL_NO_CONTEXT); EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
eglReleaseThread(); eglReleaseThread();
} }
View File
@ -144,14 +144,23 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
context](std::shared_ptr<GlSyncPoint> sync_token) { context](std::shared_ptr<GlSyncPoint> sync_token) {
CHECK_NE(name_, 0); CHECK_NE(name_, 0);
GLuint name_to_delete = name_; GLuint name_to_delete = name_;
context->RunWithoutWaiting([name_to_delete, sync_token]() { context->RunWithoutWaiting([name_to_delete]() {
if (sync_token) { // Note that we do not wait for consumers to be done before deleting the
// TODO: maybe we do not actually have to wait for the // texture. Based on a reading of the GLES 3.0 spec, appendix D:
// consumer sync here. Check docs. // - when a texture is deleted, it is _not_ automatically unbound from
sync_token->WaitOnGpu(); // bind points in other contexts;
} else { // - when a texture is deleted, its name becomes immediately invalid, but
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback"; // the actual object is not deleted until it is no longer in use, i.e.
} // attached to a container object or bound to a context;
// - deleting an object is not an operation that changes its contents;
// - within each context, commands are executed sequentially, so it seems
// like an unbind that follows a command that reads a texture should not
// take effect until the GPU has actually finished executing the
// previous commands.
// The final point is the least explicit in the docs, but it is implied by
// normal single-context behavior. E.g. if you do bind, delete, render,
// unbind, the object is not deleted until the unbind, and it waits for
// the render to finish.
DLOG_IF(ERROR, !glIsTexture(name_to_delete)) DLOG_IF(ERROR, !glIsTexture(name_to_delete))
<< "Deleting invalid texture id: " << name_to_delete; << "Deleting invalid texture id: " << name_to_delete;
glDeleteTextures(1, &name_to_delete); glDeleteTextures(1, &name_to_delete);
@ -185,7 +194,10 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
<< "Updated existing texture which had not been marked for reuse!"; << "Updated existing texture which had not been marked for reuse!";
CHECK(prod_token); CHECK(prod_token);
producer_sync_ = std::move(prod_token); producer_sync_ = std::move(prod_token);
producer_context_ = producer_sync_->GetContext(); const auto& synced_context = producer_sync_->GetContext();
if (synced_context) {
producer_context_ = synced_context;
}
} }
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const { void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {

View File
android_library( android_library(
name = "android_framework_no_mff", name = "android_framework_no_mff",
proguard_specs = [":proguard.pgcfg"], proguard_specs = [":proguard.pgcfg"],
visibility = ["//visibility:public"],
exports = [ exports = [
":android_framework_no_proguard", ":android_framework_no_proguard",
], ],
View File
@ -30,3 +30,10 @@ android_library(
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
# Expose the java source files for building mediapipe AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe:__subpackages__"],
)
View File
@ -89,10 +89,6 @@ def mediapipe_aar(
calculators = calculators, calculators = calculators,
) )
_mediapipe_proto(
name = name + "_proto",
)
native.genrule( native.genrule(
name = name + "_aar_manifest_generator", name = name + "_aar_manifest_generator",
outs = ["AndroidManifest.xml"], outs = ["AndroidManifest.xml"],
@ -115,19 +111,10 @@ EOF
"//mediapipe/java/com/google/mediapipe/components:java_src", "//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src", "//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src", "//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", ] + mediapipe_java_proto_srcs() +
"com/google/mediapipe/formats/proto/ClassificationProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/proto/CalculatorProto.java",
] +
select({ select({
"//conditions:default": [], "//conditions:default": [],
"enable_stats_logging": [ "enable_stats_logging": mediapipe_logging_java_proto_srcs(),
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
],
}), }),
manifest = "AndroidManifest.xml", manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"], proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
@ -177,93 +164,9 @@ EOF
assets_dir = assets_dir, assets_dir = assets_dir,
) )
_aar_with_jni(name, name + "_android_lib") mediapipe_build_aar_with_jni(
name = name,
def _mediapipe_proto(name): android_library = name + "_android_lib",
"""Generates MediaPipe java proto libraries.
Args:
name: the name of the target.
"""
_proto_java_src_generator(
name = "mediapipe_log_extension_proto",
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "mediapipe_logging_enums_proto",
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "calculator_proto",
proto_src = "mediapipe/framework/calculator.proto",
java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
srcs = ["//mediapipe/framework:protos_src"],
)
_proto_java_src_generator(
name = "landmark_proto",
proto_src = "mediapipe/framework/formats/landmark.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
srcs = ["//mediapipe/framework/formats:protos_src"],
)
_proto_java_src_generator(
name = "rasterization_proto",
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
)
_proto_java_src_generator(
name = "location_data_proto",
proto_src = "mediapipe/framework/formats/location_data.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "detection_proto",
proto_src = "mediapipe/framework/formats/detection.proto",
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "classification_proto",
proto_src = "mediapipe/framework/formats/classification.proto",
java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
],
)
def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
native.genrule(
name = name + "_proto_java_src_generator",
srcs = srcs + [
"@com_google_protobuf//:lite_well_known_protos",
],
outs = [java_lite_out],
cmd = "$(location @com_google_protobuf//:protoc) " +
"--proto_path=. --proto_path=$(GENDIR) " +
"--proto_path=$$(pwd)/external/com_google_protobuf/src " +
"--java_out=lite:$(GENDIR) " + proto_src + " && " +
"mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
tools = [
"@com_google_protobuf//:protoc",
],
) )
def _mediapipe_jni(name, gen_libmediapipe, calculators = []): def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
@ -303,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
alwayslink = 1, alwayslink = 1,
) )
def _aar_with_jni(name, android_library): def mediapipe_build_aar_with_jni(name, android_library):
"""Builds MediaPipe AAR with jni.
Args:
name: The bazel target name.
android_library: the android library that contains jni.
"""
# Generates dummy AndroidManifest.xml for dummy apk usage # Generates dummy AndroidManifest.xml for dummy apk usage
# (dummy apk is generated by <name>_dummy_app target below) # (dummy apk is generated by <name>_dummy_app target below)
native.genrule( native.genrule(
@ -314,7 +224,7 @@ cat > $(OUTS) <<EOF
<manifest <manifest
xmlns:android="http://schemas.android.com/apk/res/android" xmlns:android="http://schemas.android.com/apk/res/android"
package="dummy.package.for.so"> package="dummy.package.for.so">
<uses-sdk android:minSdkVersion="21"/> <uses-sdk android:minSdkVersion="24"/>
</manifest> </manifest>
EOF EOF
""", """,
@ -341,7 +251,128 @@ chmod +w $(location :{}.aar)
origdir=$$PWD origdir=$$PWD
cd $$(mktemp -d) cd $$(mktemp -d)
unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*" unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
find lib -name *_dummy_app.so -delete
cp -r lib jni cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name), """.format(android_library, name, name, name, name),
) )
def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
"""Extracts the generated MediaPipe java proto source code from the target.
Args:
target: The java proto lite target to be built and extracted.
src_out: The output java proto src code path.
name: The optional bazel target name.
Returns:
The output java proto src code path.
"""
if not name:
name = target.split(":")[-1] + "_proto_java_src_extractor"
src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
native.genrule(
name = name + "_proto_java_src_extractor",
srcs = [target],
outs = [src_out],
cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
src_out + " $$(dirname $(location " + src_out + "))",
)
return src_out
def mediapipe_java_proto_srcs(name = ""):
"""Extracts the generated MediaPipe framework java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:stream_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StreamHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_factory_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketFactoryProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_generator_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:status_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StatusHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:mediapipe_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:detection_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
))
return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""):
"""Extracts the generated logging-related MediaPipe java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe logging-related java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
))
return proto_src_list
View File
@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)
View File
@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "data_util",
srcs = ["data_util.py"],
srcs_version = "PY3",
)
py_test(
name = "data_util_test",
srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":data_util"],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
srcs_version = "PY3",
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":dataset",
"//mediapipe/model_maker/python/core/utils:test_util",
],
)
py_library(
name = "classification_dataset",
srcs = ["classification_dataset.py"],
srcs_version = "PY3",
deps = [":dataset"],
)
py_test(
name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":classification_dataset"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,47 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
super().__init__(dataset, size)
self.index_to_label = index_to_label
@property
def num_classes(self: ds._DatasetT) -> int:
return len(self.index_to_label)
def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the data set into training and testing sets.
Args:
fraction: float, the fraction of the first returned
sub-dataset in the original data.
Returns:
The two split sub-datasets.
"""
return self._split(fraction, self.index_to_label)
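Editorial note: a minimal usage sketch of the `ClassificationDataset.split` flow above, assuming a small in-memory `tf.data.Dataset`; the label names and split fraction are illustrative only and not part of this change.
```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Four toy examples and two label classes; values are arbitrary.
raw = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = classification_dataset.ClassificationDataset(
    dataset=raw, size=4, index_to_label=['negative', 'positive'])

# 75% / 25% split; both halves keep the index_to_label mapping.
train_data, test_data = data.split(0.75)
print(data.num_classes, len(train_data), len(test_data))  # 2 3 1
```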

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
class ClassificationDataLoaderTest(tf.test.TestCase):
def test_split(self):
class MagicClassificationDataLoader(
classification_dataset.ClassificationDataset):
def __init__(self, dataset, size, index_to_label, value):
super(MagicClassificationDataLoader,
self).__init__(dataset, size, index_to_label)
self.value = value
def split(self, fraction):
return self._split(fraction, self.index_to_label, self.value)
# Some dummy inputs.
magic_value = 42
num_classes = 2
index_to_label = (False, True)
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
magic_value)
# Train/Test data split.
fraction = .25
train_data, test_data = data.split(fraction)
# `split` should return instances of child DataLoader.
self.assertIsInstance(train_data, MagicClassificationDataLoader)
self.assertIsInstance(test_data, MagicClassificationDataLoader)
# Make sure number of entries are right.
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
self.assertLen(train_data, fraction * len(ds))
self.assertLen(test_data, len(ds) - len(train_data))
# Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_to_label, index_to_label)
self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utility library."""
import cv2
import numpy as np
import tensorflow as tf
def load_image(path: str) -> np.ndarray:
"""Loads an image as an RGB numpy array.
Args:
path: input image file absolute path.
Returns:
An RGB image in numpy.ndarray.
"""
tf.compat.v1.logging.info('Loading RGB image %s', path)
# TODO: Replace the OpenCV image load and conversion library with the
# MediaPipe image utility library once it is ready.
image = cv2.imread(path)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

View File

@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
from absl import flags
import tensorflow as tf
from mediapipe.model_maker.python.core.data import data_util
_WORKSPACE = "mediapipe"
_TEST_DATA_DIR = os.path.join(
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
FLAGS = flags.FLAGS
class DataUtilTest(tf.test.TestCase):
def test_load_rgb_image(self):
image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
image_data = data_util.load_image(image_path)
self.assertEqual(image_data.shape, (5184, 3456, 3))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,164 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common dataset for model training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from typing import Callable, Optional, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
class Dataset(object):
"""A generic dataset class for loading model training and evaluation dataset.
For each ML task, such as image classification, text classification etc., a
subclass can be derived from this class to provide task-specific data loading
utilities.
"""
def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
"""Initializes Dataset class.
To build a dataset from raw data, consider using the task-specific utilities,
e.g. from_folder().
Args:
tf_dataset: A tf.data.Dataset object that contains a potentially large set
of elements, where each element is a pair of (input_data, target). The
`input_data` means the raw input data, like an image, a text etc., while
the `target` means the ground truth of the raw input data, e.g. the
classification label of the image etc.
size: The size of the dataset. tf.data.Dataset doesn't support a function
to get the length directly since it's lazy-loaded and may be infinite.
"""
self._dataset = tf_dataset
self._size = size
@property
def size(self) -> Optional[int]:
"""Returns the size of the dataset.
Note that this function may return None because the exact size of the
dataset isn't a necessary parameter to create an instance of this class,
and tf.data.Dataset doesn't support a function to get the length directly
since it's lazy-loaded and may be infinite.
In most cases, however, when an instance of this class is created by helper
functions like 'from_folder', the size of the dataset will be preprocessed,
and this function can return an int representing the size of the dataset.
"""
return self._size
def gen_tf_dataset(self,
batch_size: int = 1,
is_training: bool = False,
shuffle: bool = False,
preprocess: Optional[Callable[..., bool]] = None,
drop_remainder: bool = False) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation.
Args:
batch_size: An integer, the returned dataset will be batched by this size.
is_training: A boolean, when True, the returned dataset will be optionally
shuffled and repeated as an endless dataset.
shuffle: A boolean, when True, the returned dataset will be shuffled to
create randomness during model training.
preprocess: A function taking three arguments in order: feature, label, and
a boolean is_training.
drop_remainder: boolean, whether to drop the final incomplete batch.
Returns:
A TF dataset ready to be consumed by Keras model.
"""
dataset = self._dataset
if preprocess:
preprocess = functools.partial(preprocess, is_training=is_training)
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
if is_training:
if shuffle:
# Shuffle size should be bigger than the batch_size. Otherwise it's only
# shuffling within the batch, which is effectively no shuffling at all.
buffer_size = 3 * batch_size
# But since we are doing shuffle before repeat, it doesn't make sense to
# shuffle more than total available entries.
# TODO: Investigate whether shuffling before or after repeating the
# dataset gives better performance.
# Shuffle after repeat will give a more randomized dataset and mix the
# epoch boundary: https://www.tensorflow.org/guide/data
if self._size:
buffer_size = min(self._size, buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# TODO: Consider converting dataset to distributed dataset
# here.
return dataset
def __len__(self):
"""Returns the number of element of the dataset."""
if self._size is not None:
return self._size
else:
return len(self._dataset)
def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the data set into training and testing sets.
Args:
fraction: A float value that defines the fraction of the first returned
sub-dataset in the original data.
Returns:
The two split sub-datasets.
"""
return self._split(fraction)
def _split(self: _DatasetT, fraction: float,
*args) -> Tuple[_DatasetT, _DatasetT]:
"""Implementation for `split` method and returns sub-class instances.
Child DataLoader classes that require additional constructor arguments
should implement their own `split` method by calling `_split` with all
arguments required by the constructor.
Args:
fraction: A float value that defines the fraction of the first returned
sub-dataset in the original data.
*args: additional arguments passed to the sub-class constructor.
Returns:
The two split sub-datasets.
"""
assert (fraction > 0 and fraction < 1)
dataset = self._dataset
train_size = int(self._size * fraction)
trainset = self.__class__(dataset.take(train_size), train_size, *args)
test_size = self._size - train_size
testset = self.__class__(dataset.skip(train_size), test_size, *args)
return trainset, testset
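Editorial note: for illustration, a short sketch of wrapping a `tf.data.Dataset` in the `Dataset` class above and using `split` and `gen_tf_dataset`; the element values and batch size are made up.
```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds

raw = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = ds.Dataset(raw, size=4)

train_data, test_data = data.split(0.5)  # two Dataset instances, 2 elements each
batched = train_data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
for batch in batched:
  print(batch.shape)                     # (2, 2)
```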

View File

@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import test_util
class DatasetTest(tf.test.TestCase):
def test_split(self):
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, 4)
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 2)
self.assertIsInstance(train_data, ds.Dataset)
self.assertIsInstance(test_data, ds.Dataset)
for i, elem in enumerate(train_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
def test_len(self):
size = 4
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, size)
self.assertLen(data, size)
def test_gen_tf_dataset(self):
input_dim = 8
data = test_util.create_dataset(
data_size=2, input_shape=[input_dim], num_classes=2)
dataset = data.gen_tf_dataset()
self.assertLen(dataset, 2)
for (feature, label) in dataset:
self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())
dataset2 = data.gen_tf_dataset(batch_size=2)
self.assertLen(dataset2, 1)
for (feature, label) in dataset2:
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
self.assertEqual(dataset3.cardinality(), 1)
for (feature, label) in dataset3.take(10):
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
package(
default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
mediapipe_files(srcs = ["test.jpg"])
filegroup(
name = "testdata",
srcs = ["test.jpg"],
)

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training models. Shared across tasks."""
import dataclasses
import tempfile
from typing import Optional
# TODO: Integrate this class into ImageClassifier and other tasks.
@dataclasses.dataclass
class BaseHParams:
"""Hyperparameters used for training models.
A common set of hyperparameters shared by the training jobs of all model
maker tasks.
Attributes:
learning_rate: The learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
steps_per_epoch: An optional integer indicating the number of training steps
per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size.
shuffle: True if the dataset is shuffled before training.
export_dir: The location of the model checkpoint files.
distribution_strategy: A string specifying which Distribution Strategy to
use. Accepted values are 'off', 'one_device', 'mirrored',
'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
insensitive. 'off' means not to use Distribution Strategy; 'tpu' means to
use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy
documentation for more details:
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
num_gpus: How many GPUs to use at each worker with the
DistributionStrategies API. The default is -1, which means utilize all
available GPUs.
tpu: The Cloud TPU to use for training. This should be either the name used
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 URL.
"""
# Parameters for train configuration
learning_rate: float
batch_size: int
epochs: int
steps_per_epoch: Optional[int] = None
# Dataset-related parameters
shuffle: bool = False
# Parameters for model / checkpoint files
export_dir: str = tempfile.mkdtemp()
# Parameters for hardware acceleration
distribution_strategy: str = 'off'
num_gpus: int = -1 # default value of -1 means use all available GPUs
tpu: str = ''
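Editorial note: a hedged construction sketch for `BaseHParams`; the numeric values and export directory are placeholders, and the import path assumes the module sits at mediapipe/model_maker/python/core/hyperparameters.py, as the surrounding BUILD file suggests.
```
from mediapipe.model_maker.python.core import hyperparameters as hp

hparams = hp.BaseHParams(
    learning_rate=0.001,                    # required training fields
    batch_size=32,
    epochs=10,
    shuffle=True,
    export_dir='/tmp/model_maker_export',   # placeholder output directory
    distribution_strategy='off',
    num_gpus=-1)
print(hparams.steps_per_epoch)              # None unless set explicitly
```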

View File

@ -0,0 +1,64 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "custom_model",
srcs = ["custom_model.py"],
srcs_version = "PY3",
deps = [
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
],
)
py_test(
name = "custom_model_test",
srcs = ["custom_model_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":custom_model",
"//mediapipe/model_maker/python/core/utils:test_util",
],
)
py_library(
name = "classifier",
srcs = ["classifier.py"],
srcs_version = "PY3",
deps = [
":custom_model",
"//mediapipe/model_maker/python/core/data:dataset",
],
)
py_test(
name = "classifier_test",
srcs = ["classifier_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":classifier",
"//mediapipe/model_maker/python/core/utils:test_util",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,77 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import Any, List
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
full_train: bool):
"""Initilizes a classifier with its specifications.
Args:
model_spec: Specification for the model.
index_to_label: A list that maps indices to label class names.
shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top
classification layers.
"""
super(Classifier, self).__init__(model_spec, shuffle)
self._index_to_label = index_to_label
self._full_train = full_train
self._num_classes = len(index_to_label)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset.
Args:
data: Evaluation dataset.
batch_size: Number of samples per evaluation step.
Returns:
The loss value and accuracy.
"""
ds = data.gen_tf_dataset(
batch_size, is_training=False, preprocess=self._preprocess)
return self._model.evaluate(ds)
def export_labels(self, export_dir: str, label_filename: str = 'labels.txt'):
"""Exports classification labels into a label file.
Args:
export_dir: The directory to save exported files.
label_filename: File name for the saved labels. The full export path is
{export_dir}/{label_filename}.
"""
if not tf.io.gfile.exists(export_dir):
tf.io.gfile.makedirs(export_dir)
label_filepath = os.path.join(export_dir, label_filename)
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self._index_to_label))

View File

@ -0,0 +1,58 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import test_util
class MockClassifier(classifier.Classifier):
"""A mock class with implementation of abstract methods for testing."""
def train(self, train_data, validation_data=None, **kwargs):
pass
def evaluate(self, data, **kwargs):
pass
class ClassifierTest(tf.test.TestCase):
def setUp(self):
super(ClassifierTest, self).setUp()
index_to_label = ['cat', 'dog']
self.model = MockClassifier(
model_spec=None,
index_to_label=index_to_label,
shuffle=False,
full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
def _check_nonempty_file(self, filepath):
self.assertTrue(os.path.isfile(filepath))
self.assertGreater(os.path.getsize(filepath), 0)
def test_export_labels(self):
export_path = os.path.join(self.get_temp_dir(), 'export/')
self.model.export_labels(export_dir=export_path)
self._check_nonempty_file(os.path.join(export_path, 'labels.txt'))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,85 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface to define a custom model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import os
from typing import Any, Callable, Optional
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
class CustomModel(abc.ABC):
"""The abstract base class that represents a custom TensorFlow model."""
def __init__(self, model_spec: Any, shuffle: bool):
"""Initializes a custom model with model specs and other parameters.
Args:
model_spec: Specification for the model.
shuffle: Whether the training data needs to be shuffled.
"""
self._model_spec = model_spec
self._shuffle = shuffle
self._preprocess = None
self._model = None
@abc.abstractmethod
def evaluate(self, data: dataset.Dataset, **kwargs):
"""Evaluates the model with the provided data."""
return
def summary(self):
"""Prints a summary of the model."""
self._model.summary()
def export_tflite(
self,
export_dir: str,
tflite_filename: str = 'model.tflite',
quantization_config: Optional[quantization.QuantizationConfig] = None,
preprocess: Optional[Callable[..., bool]] = None):
"""Converts the model to requested formats.
Args:
export_dir: The directory to save exported files.
tflite_filename: File name to save tflite model. The full export path is
{export_dir}/{tflite_filename}.
quantization_config: The configuration for model quantization.
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature,
label, and is_training.
"""
if not tf.io.gfile.exists(export_dir):
tf.io.gfile.makedirs(export_dir)
tflite_filepath = os.path.join(export_dir, tflite_filename)
# TODO: Populate metadata to the exported TFLite model.
model_util.export_tflite(
self._model,
tflite_filepath,
quantization_config,
preprocess=preprocess)
tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

View File

@ -0,0 +1,56 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.tasks import custom_model
from mediapipe.model_maker.python.core.utils import test_util
class MockCustomModel(custom_model.CustomModel):
"""A mock class with implementation of abstract methods for testing."""
def train(self, train_data, validation_data=None, **kwargs):
pass
def evaluate(self, data, **kwargs):
pass
class CustomModelTest(tf.test.TestCase):
def setUp(self):
super(CustomModelTest, self).setUp()
self.model = MockCustomModel(model_spec=None, shuffle=False)
self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
def _check_nonempty_file(self, filepath):
self.assertTrue(os.path.isfile(filepath))
self.assertGreater(os.path.getsize(filepath), 0)
def test_export_tflite(self):
export_path = os.path.join(self.get_temp_dir(), 'export/')
self.model.export_tflite(export_dir=export_path)
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,100 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.py"],
srcs_version = "PY3",
deps = [
":model_util",
"//mediapipe/model_maker/python/core/data:dataset",
],
)
py_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.py"],
srcs_version = "PY3",
)
py_test(
name = "image_preprocessing_test",
srcs = ["image_preprocessing_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":image_preprocessing"],
)
py_library(
name = "model_util",
srcs = ["model_util.py"],
srcs_version = "PY3",
deps = [
":quantization",
"//mediapipe/model_maker/python/core/data:dataset",
],
)
py_test(
name = "model_util_test",
srcs = ["model_util_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":model_util",
":quantization",
":test_util",
],
)
py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],
srcs_version = "PY3",
)
py_test(
name = "loss_functions_test",
srcs = ["loss_functions_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":loss_functions"],
)
py_library(
name = "quantization",
srcs = ["quantization.py"],
srcs_version = "PY3",
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
)
py_test(
name = "quantization_test",
srcs = ["quantization_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":quantization",
":test_util",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,228 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
IMAGE_SIZE = 224
CROP_PADDING = 32
class Preprocessor(object):
"""Preprocessor for image classification."""
def __init__(self,
input_shape,
num_classes,
mean_rgb,
stddev_rgb,
use_augmentation=False):
self.input_shape = input_shape
self.num_classes = num_classes
self.mean_rgb = mean_rgb
self.stddev_rgb = stddev_rgb
self.use_augmentation = use_augmentation
def __call__(self, image, label, is_training=True):
if self.use_augmentation:
return self._preprocess_with_augmentation(image, label, is_training)
return self._preprocess_without_augmentation(image, label)
def _preprocess_with_augmentation(self, image, label, is_training):
"""Image preprocessing method with data augmentation."""
image_size = self.input_shape[0]
if is_training:
image = preprocess_for_train(image, image_size)
else:
image = preprocess_for_eval(image, image_size)
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
label = tf.one_hot(label, depth=self.num_classes)
return image, label
# TODO: Changes to preprocess to support batch input.
def _preprocess_without_augmentation(self, image, label):
"""Image preprocessing method without data augmentation."""
image = tf.cast(image, tf.float32)
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
image = tf.compat.v1.image.resize(image, self.input_shape)
label = tf.one_hot(label, depth=self.num_classes)
return image, label
def _distorted_bounding_box_crop(image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0),
max_attempts=100):
"""Generates cropped_image using one of the bboxes randomly distorted.
See `tf.image.sample_distorted_bounding_box` for more documentation.
Args:
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
shape [height, width, channels].
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where
each coordinate is [0, 1) and the coordinates are arranged as `[ymin,
xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image.
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area
of the image must contain at least this fraction of any bounding box
supplied.
aspect_ratio_range: An optional list of `float`s. The cropped area of the
image must have an aspect ratio = width / height within this range.
area_range: An optional list of `float`s. The cropped area of the image must
contain a fraction of the supplied image within this range.
max_attempts: An optional `int`. Number of attempts at generating a cropped
region of the image of the specified constraints. After `max_attempts`
failures, return the entire image.
Returns:
A cropped image `Tensor`
"""
with tf.name_scope('distorted_bounding_box_crop'):
shape = tf.shape(image)
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
shape,
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
max_attempts=max_attempts,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
# Crop the image to the specified bounding box.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
image = tf.image.crop_to_bounding_box(image, offset_y, offset_x,
target_height, target_width)
return image
def _at_least_x_are_equal(a, b, x):
"""At least `x` of `a` and `b` `Tensors` are equal."""
match = tf.equal(a, b)
match = tf.cast(match, tf.int32)
return tf.greater_equal(tf.reduce_sum(match), x)
def _resize_image(image, image_size, method=None):
if method is not None:
tf.compat.v1.logging.info('Use customized resize method {}'.format(method))
return tf.compat.v1.image.resize([image], [image_size, image_size],
method)[0]
tf.compat.v1.logging.info('Use default resize_bicubic.')
return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0]
def _decode_and_random_crop(original_image, image_size, resize_method=None):
"""Makes a random crop of image_size."""
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
image = _distorted_bounding_box_crop(
original_image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(3. / 4, 4. / 3.),
area_range=(0.08, 1.0),
max_attempts=10)
original_shape = tf.shape(original_image)
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
image = tf.cond(bad,
lambda: _decode_and_center_crop(original_image, image_size),
lambda: _resize_image(image, image_size, resize_method))
return image
def _decode_and_center_crop(image, image_size, resize_method=None):
"""Crops to center of image with padding then scales image_size."""
shape = tf.shape(image)
image_height = shape[0]
image_width = shape[1]
padded_center_crop_size = tf.cast(
((image_size / (image_size + CROP_PADDING)) *
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
padded_center_crop_size,
padded_center_crop_size)
image = _resize_image(image, image_size, resize_method)
return image
def _flip(image):
"""Random horizontal image flip."""
image = tf.image.random_flip_left_right(image)
return image
def preprocess_for_train(
image: tf.Tensor,
image_size: int = IMAGE_SIZE,
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
"""Preprocesses the given image for evaluation.
Args:
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
shape [height, width, channels].
image_size: image size.
resize_method: resize method. If none, use bicubic.
Returns:
A preprocessed image `Tensor`.
"""
image = _decode_and_random_crop(image, image_size, resize_method)
image = _flip(image)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
return image
def preprocess_for_eval(
image: tf.Tensor,
image_size: int = IMAGE_SIZE,
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
"""Preprocesses the given image for evaluation.
Args:
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
shape [height, width, channels].
image_size: image size.
resize_method: if None, use bicubic.
Returns:
A preprocessed image `Tensor`.
"""
image = _decode_and_center_crop(image, image_size, resize_method)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
return image
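Editorial note: a sketch, for illustration only, of running the `Preprocessor` above on one image without augmentation; the RGB statistics, input shape, and label are placeholder values.
```
import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import image_preprocessing

preprocessor = image_preprocessing.Preprocessor(
    input_shape=[224, 224],
    num_classes=2,
    mean_rgb=[127.5, 127.5, 127.5],      # placeholder normalization constants
    stddev_rgb=[127.5, 127.5, 127.5],
    use_augmentation=False)

image = tf.constant(
    np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8))
out_image, out_label = preprocessor(image, label=1, is_training=False)
print(out_image.shape, out_label.shape)  # (224, 224, 3) (2,)
```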

View File

@ -0,0 +1,85 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import image_preprocessing
def _get_preprocessed_image(preprocessor, is_training=False):
image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3])
label_placeholder = tf.compat.v1.placeholder(tf.int32, [1])
image_tensor, _ = preprocessor(image_placeholder, label_placeholder,
is_training)
with tf.compat.v1.Session() as sess:
input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3])
image = sess.run(
image_tensor,
feed_dict={
image_placeholder: input_image,
label_placeholder: [0]
})
return image
class PreprocessorTest(tf.test.TestCase):
def test_preprocess_without_augmentation(self):
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
num_classes=2,
mean_rgb=[0.0],
stddev_rgb=[255.0],
use_augmentation=False)
actual_image = np.array([[[0., 0.00392157, 0.00784314],
[0.14117648, 0.14509805, 0.14901961]],
[[0.37647063, 0.3803922, 0.38431376],
[0.5176471, 0.52156866, 0.5254902]]])
image = _get_preprocessed_image(preprocessor)
self.assertTrue(np.allclose(image, actual_image, atol=1e-05))
def test_preprocess_with_augmentation(self):
image_preprocessing.CROP_PADDING = 1
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
num_classes=2,
mean_rgb=[0.0],
stddev_rgb=[255.0],
use_augmentation=True)
# Tests validation image.
actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216],
[0.26666668, 0.27058825, 0.27450982]],
[[0.42352945, 0.427451, 0.43137258],
[0.5176471, 0.52156866, 0.5254902]]])
image = _get_preprocessed_image(preprocessor, is_training=False)
self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05))
# Tests training image.
image1 = _get_preprocessed_image(preprocessor, is_training=True)
image2 = _get_preprocessed_image(preprocessor, is_training=True)
self.assertFalse(np.allclose(image1, image2, atol=1e-05))
self.assertEqual(image1.shape, (2, 2, 3))
self.assertEqual(image2.shape, (2, 2, 3))
if __name__ == '__main__':
tf.compat.v1.disable_eager_execution()
tf.test.main()

View File

@ -0,0 +1,105 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss function utility library."""
from typing import Optional, Sequence
import tensorflow as tf
class FocalLoss(tf.keras.losses.Loss):
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
This class computes the focal loss between labels and prediction. Focal loss
is a weighted loss function that modulates the standard cross-entropy loss
based on how well the neural network performs on a specific example of a
class. The labels should be provided in a `one_hot` vector representation.
There should be `#classes` floating point values per prediction.
The loss is reduced across all samples using 'sum_over_batch_size' reduction
(see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction).
Example usage:
>>> y_true = [[0, 1, 0], [0, 0, 1]]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> gamma = 2
>>> focal_loss = FocalLoss(gamma)
>>> focal_loss(y_true, y_pred).numpy()
0.9326
>>> # Calling with 'sample_weight'.
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
0.6528
Usage with the `compile()` API:
```python
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
```
"""
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
"""Constructor.
Args:
gamma: Focal loss gamma, as described in class docs.
class_weight: A weight to apply to the loss, one for each class. The
weight is applied for each input where the ground truth label matches.
"""
super(tf.keras.losses.Loss, self).__init__()
# Used for clipping min/max values of probability values in y_pred to avoid
# NaNs and Infs in computation.
self._epsilon = 1e-7
# This is a tunable "focusing parameter"; should be >= 0.
# When gamma = 0, the loss returned is the standard categorical
# cross-entropy loss.
self._gamma = gamma
self._class_weight = class_weight
# tf.keras.losses.Loss class implementation requires a Reduction specified
# in self.reduction. To use this reduction, we should use tensorflow's
# compute_weighted_loss function; however, it is only compatible with v1 of
# Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long
# So even though it is specified here, we don't use self.reduction in the
# loss function call.
self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
def __call__(self,
y_true: tf.Tensor,
y_pred: tf.Tensor,
sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor:
if self._class_weight:
class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32)
label = tf.argmax(y_true, axis=1)
loss_weight = tf.gather(class_weight, label)
else:
loss_weight = tf.ones(tf.shape(y_true)[0])
y_true = tf.cast(y_true, y_pred.dtype)
y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon)
batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype)
if sample_weight is None:
sample_weight = tf.constant(1.0)
weight_shape = sample_weight.shape
weight_rank = weight_shape.ndims
y_pred_rank = y_pred.shape.ndims
if y_pred_rank - weight_rank == 1:
sample_weight = tf.expand_dims(sample_weight, [-1])
elif weight_rank != 0:
raise ValueError(f'Unexpected sample_weights, should be either a scalar '
f'or a vector of batch_size:{batch_size.numpy()}')
ce = -tf.math.log(y_pred)
modulating_factor = tf.math.pow(1 - y_pred, self._gamma)
losses = y_true * modulating_factor * ce * sample_weight
losses = losses * loss_weight[:, tf.newaxis]
# By default, this function uses "sum_over_batch_size" reduction for the
# loss per batch.
return tf.reduce_sum(losses) / batch_size
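Editorial note: a purely illustrative sketch of the `class_weight` path described in the constructor comments; with gamma 0 and a uniform prediction the loss reduces to the true class's cross-entropy scaled by its class weight, mirroring the check in the test file that follows.
```
import math

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import loss_functions

class_weight = [1.0, 3.0, 10.0]
y_true = tf.one_hot(1, depth=3)[tf.newaxis, :]   # true class is index 1
y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0    # uniform prediction

loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight)
loss = loss_fn(y_true, y_pred)
print(float(loss), -math.log(1.0 / 3.0) * class_weight[1])  # both ~3.2958
```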

View File

@ -0,0 +1,103 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(testcase_name='no_sample_weight', sample_weight=None),
dict(
testcase_name='with_sample_weight',
sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])))
def test_focal_loss_gamma_0_is_cross_entropy(
self, sample_weight: Optional[tf.Tensor]):
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
0]])
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
tf_cce = tf.keras.losses.CategoricalCrossentropy(
from_logits=False,
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
focal_loss = loss_functions.FocalLoss(gamma=0)
self.assertAllClose(
tf_cce(y_true, y_pred, sample_weight=sample_weight),
focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4)
def test_focal_loss_with_sample_weight(self):
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
0]])
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
focal_loss = loss_functions.FocalLoss(gamma=0)
sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])
self.assertGreater(
focal_loss(y_true=y_true, y_pred=y_pred),
focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight))
@parameterized.named_parameters(
dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])),
dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])),
dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])),
dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])),
dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])),
)
def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor):
y_true = tf.constant([[1, 0]])
focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0)
loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred)
focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5)
loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred)
focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1)
loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred)
focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2)
loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred)
focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5)
loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred)
self.assertGreater(loss_gamma_0, loss_gamma_0p5)
self.assertGreater(loss_gamma_0p5, loss_gamma_1)
self.assertGreater(loss_gamma_1, loss_gamma_2)
self.assertGreater(loss_gamma_2, loss_gamma_5)
@parameterized.named_parameters(
dict(testcase_name='index_0', true_class=0),
dict(testcase_name='index_1', true_class=1),
dict(testcase_name='index_2', true_class=2),
)
def test_focal_loss_class_weight_is_applied(self, true_class: int):
class_weight = [1.0, 3.0, 10.0]
y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0
y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :]
expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class]
loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight)
loss = loss_fn(y_true, y_pred)
self.assertNear(loss, expected_loss, 1e-4)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,272 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for keras models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
# Dependency imports
import numpy as np
import tensorflow as tf
# resources dependency
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import quantization
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
ESTIMITED_STEPS_PER_EPOCH = 1000
def load_keras_model(model_path: str,
compile_on_load: bool = False) -> tf.keras.Model:
"""Loads a tensorflow Keras model from file and returns the Keras model.
Args:
model_path: Relative path to a directory containing model data, such as
<parent_path>/saved_model/.
compile_on_load: Whether the model should be compiled while loading. If
False, the model returned has to be compiled with the appropriate loss
function and custom metrics before running for inference on a test
dataset.
Returns:
A tensorflow Keras model.
"""
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
# with the `model_path` which defines the relative path under mediapipe/, it
# yields the absolute path of the model files directory.
cwd = os.path.dirname(__file__)
base_dir = cwd[:cwd.rfind('mediapipe')]
absolute_path = os.path.join(base_dir, model_path)
return tf.keras.models.load_model(
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
batch_size: Optional[int] = None,
train_data: Optional[dataset.Dataset] = None) -> int:
"""Gets the estimated training steps per epoch.
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
2. Else if we can get the length of training data successfully, returns
`train_data_length // batch_size`.
Args:
steps_per_epoch: int, training steps per epoch.
batch_size: int, batch size.
train_data: training data.
Returns:
Estimated training steps per epoch.
Raises:
ValueError: if both steps_per_epoch and train_data are not set.
"""
if steps_per_epoch is not None:
# steps_per_epoch is set by users manually.
return steps_per_epoch
else:
if train_data is None:
raise ValueError('Input train_data cannot be None.')
# Gets the steps by the length of the training data.
return len(train_data) // batch_size
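# A hypothetical usage sketch (illustration only, not part of the module API):
# how `get_steps_per_epoch` resolves the step count in its two modes, assuming
# `train_data` is a `dataset.Dataset` wrapping 100 samples.
def _example_get_steps_per_epoch(train_data: dataset.Dataset) -> None:
  # An explicit value is returned unchanged.
  assert get_steps_per_epoch(steps_per_epoch=50) == 50
  # Otherwise the estimate is len(train_data) // batch_size, here 100 // 8.
  assert get_steps_per_epoch(
      steps_per_epoch=None, batch_size=8, train_data=train_data) == 12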
def export_tflite(
model: tf.keras.Model,
tflite_filepath: str,
quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None):
"""Converts the model to tflite format and saves it.
Args:
model: model to be converted to tflite.
tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature,
label, and is_training.
"""
if tflite_filepath is None:
raise ValueError(
"TFLite filepath couldn't be None when exporting to tflite.")
with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, 'saved_model')
model.save(save_path, include_optimizer=False, save_format='tf')
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
if quantization_config:
converter = quantization_config.set_converter_with_quantization(
converter, preprocess=preprocess)
converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert()
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
f.write(tflite_model)
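# A hypothetical usage sketch (illustration only, not part of the module API):
# exporting a Keras model to TFLite, first as plain float32 and then with
# dynamic-range quantization. The `model` and `out_path` arguments are
# assumptions for the example.
def _example_export_tflite(model: tf.keras.Model, out_path: str) -> None:
  # Plain float export with the default supported ops.
  export_tflite(model, out_path)
  # Dynamic-range quantized export via QuantizationConfig.
  config = quantization.QuantizationConfig.for_dynamic()
  export_tflite(model, out_path, quantization_config=config)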
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applies a warmup schedule on a given learning rate decay schedule."""
def __init__(self,
initial_learning_rate: float,
decay_schedule_fn: Callable[[Any], Any],
warmup_steps: int,
name: Optional[str] = None):
"""Initializes a new instance of the `WarmUp` class.
Args:
initial_learning_rate: learning rate after the warmup.
      decay_schedule_fn: A function that maps a step to a learning rate. It is
        applied for steps larger than `warmup_steps`.
warmup_steps: Number of steps to do warmup for.
name: TF namescope under which to perform the learning rate calculation.
"""
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
with tf.name_scope(self.name or 'WarmUp') as name:
# Implements linear warmup. i.e., if global_step < warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
return tf.cond(
global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step),
name=name)
def get_config(self) -> Dict[Text, Any]:
return {
'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn,
'warmup_steps': self.warmup_steps,
'name': self.name
}
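# A usage sketch (illustration only, not part of the module API): wrapping a
# cosine decay schedule with `WarmUp`, so the learning rate ramps up linearly
# for the first `warmup_steps` steps and then follows the decay schedule. The
# constants are example values.
def _example_warmup_schedule() -> WarmUp:
  decay_fn = tf.keras.experimental.CosineDecay(
      initial_learning_rate=0.005, decay_steps=10000)
  return WarmUp(
      initial_learning_rate=0.005,
      decay_schedule_fn=decay_fn,
      warmup_steps=1000,
      name='warmup_cosine')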
class LiteRunner(object):
"""A runner to do inference with the TFLite model."""
def __init__(self, tflite_filepath: str):
"""Initializes Lite runner with tflite model file.
Args:
tflite_filepath: File path to the TFLite model.
"""
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
tflite_model = f.read()
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()
def run(
self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]]
) -> Union[List[tf.Tensor], tf.Tensor]:
"""Runs inference with the TFLite model.
Args:
      input_tensors: List / Dict of the input tensors of the TFLite model. The
        order should be the same as the keras model if it's a list. A single
        tensor is also accepted directly if the model has only one input.
Returns:
List of the output tensors for multi-output models, otherwise just
the output tensor. The order should be the same as the keras model.
"""
if not isinstance(input_tensors, list) and not isinstance(
input_tensors, dict):
input_tensors = [input_tensors]
interpreter = self.interpreter
# Reshape inputs
for i, input_detail in enumerate(self.input_details):
input_tensor = _get_input_tensor(
input_tensors=input_tensors,
input_details=self.input_details,
index=i)
interpreter.resize_tensor_input(
input_index=input_detail['index'], tensor_size=input_tensor.shape)
interpreter.allocate_tensors()
# Feed input to the interpreter
for i, input_detail in enumerate(self.input_details):
input_tensor = _get_input_tensor(
input_tensors=input_tensors,
input_details=self.input_details,
index=i)
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
# Quantize the input
scale, zero_point = input_detail['quantization']
input_tensor = input_tensor / scale + zero_point
input_tensor = np.array(input_tensor, dtype=input_detail['dtype'])
interpreter.set_tensor(input_detail['index'], input_tensor)
interpreter.invoke()
output_tensors = []
for output_detail in self.output_details:
output_tensor = interpreter.get_tensor(output_detail['index'])
if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
# Dequantize the output
scale, zero_point = output_detail['quantization']
output_tensor = output_tensor.astype(np.float32)
output_tensor = (output_tensor - zero_point) * scale
output_tensors.append(output_tensor)
if len(output_tensors) == 1:
return output_tensors[0]
return output_tensors
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
"""Returns a `LiteRunner` from file path to TFLite model."""
lite_runner = LiteRunner(tflite_filepath)
return lite_runner
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
tf.Tensor]],
                      input_details: List[Dict[str, Any]],
                      index: int) -> tf.Tensor:
  """Returns the input tensor that maps to `input_details[index]`."""
if isinstance(input_tensors, dict):
# Gets the mapped input tensor.
    input_detail = input_details[index]
for input_tensor_name, input_tensor in input_tensors.items():
if input_tensor_name in input_detail['name']:
return input_tensor
    raise ValueError('Input tensors don\'t contain a tensor that maps to the '
                     'input detail %s' % str(input_detail))
else:
return input_tensors[index]
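# A hypothetical usage sketch (illustration only, not part of the module API):
# running inference on a TFLite file produced by `export_tflite`. The file
# path and input tensor are assumptions for the example.
def _example_lite_inference(tflite_filepath: str,
                            input_tensor: tf.Tensor) -> Any:
  lite_runner = get_lite_runner(tflite_filepath)
  # A single-input model accepts a tensor directly; multi-input models take a
  # list (in Keras input order) or a dict keyed by input name.
  return lite_runner.run(input_tensor)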

View File

@ -0,0 +1,148 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
def test_load_model(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
model.save(saved_model_path)
loaded_model = model_util.load_keras_model(saved_model_path)
input_tensors = test_util.create_random_sample(size=[1, input_dim])
model_output = model.predict_on_batch(input_tensors)
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
self.assertTrue((model_output == loaded_model_output).all())
@parameterized.named_parameters(
dict(
testcase_name='input_only_steps_per_epoch',
steps_per_epoch=1000,
batch_size=None,
train_data=None,
expected_steps_per_epoch=1000),
dict(
testcase_name='input_steps_per_epoch_and_batch_size',
steps_per_epoch=1000,
batch_size=32,
train_data=None,
expected_steps_per_epoch=1000),
dict(
testcase_name='input_steps_per_epoch_batch_size_and_train_data',
steps_per_epoch=1000,
batch_size=32,
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]]),
expected_steps_per_epoch=1000),
dict(
testcase_name='input_batch_size_and_train_data',
steps_per_epoch=None,
batch_size=2,
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]]),
expected_steps_per_epoch=2))
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
expected_steps_per_epoch):
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=steps_per_epoch,
batch_size=batch_size,
train_data=train_data)
self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch)
def test_get_steps_per_epoch_raise_value_error(self):
with self.assertRaises(ValueError):
model_util.get_steps_per_epoch(
steps_per_epoch=None, batch_size=16, train_data=None)
def test_warmup(self):
init_lr = 0.1
warmup_steps = 1000
num_decay_steps = 100
learning_rate_fn = tf.keras.experimental.CosineDecay(
initial_learning_rate=init_lr, decay_steps=num_decay_steps)
warmup_object = model_util.WarmUp(
initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=1000,
name='test')
self.assertEqual(
warmup_object.get_config(), {
'initial_learning_rate': init_lr,
'decay_schedule_fn': learning_rate_fn,
'warmup_steps': warmup_steps,
'name': 'test'
})
def test_export_tflite(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.export_tflite(model, tflite_file)
self._test_tflite(model, tflite_file, input_dim)
@parameterized.named_parameters(
dict(
testcase_name='dynamic_quantize',
config=quantization.QuantizationConfig.for_dynamic(),
model_size=1288),
dict(
testcase_name='int8_quantize',
config=quantization.QuantizationConfig.for_int8(
representative_data=test_util.create_dataset(
data_size=10, input_shape=[16], num_classes=3)),
model_size=1832),
dict(
testcase_name='float16_quantize',
config=quantization.QuantizationConfig.for_float16(),
model_size=1468))
def test_export_tflite_quantized(self, config, model_size):
input_dim = 16
num_classes = 2
max_input_value = 5
model = test_util.build_model([input_dim], num_classes)
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
model_util.export_tflite(model, tflite_file, config)
self._test_tflite(
model, tflite_file, input_dim, max_input_value, atol=1e-00)
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
def _test_tflite(self,
keras_model: tf.keras.Model,
tflite_model_file: str,
input_dim: int,
max_input_value: int = 1000,
atol: float = 1e-04):
random_input = test_util.create_random_sample(
size=[1, input_dim], high=max_input_value)
random_input = tf.convert_to_tensor(random_input)
self.assertTrue(
test_util.is_same_output(
tflite_model_file, keras_model, random_input, atol=atol))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,213 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Libraries for post-training quantization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable, List, Optional, Union
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
DEFAULT_QUANTIZATION_STEPS = 500
def _get_representative_dataset_generator(dataset: tf.data.Dataset,
num_steps: int) -> Callable[[], Any]:
"""Gets a representative dataset generator for post-training quantization.
The generator is to provide a small dataset to calibrate or estimate the
  range, i.e., (min, max) of all floating-point arrays in the model for
quantization. Usually, this is a small subset of a few hundred samples
randomly chosen, in no particular order, from the training or evaluation
dataset. See tf.lite.RepresentativeDataset for more details.
Args:
dataset: Input dataset for extracting representative sub dataset.
num_steps: The number of quantization steps which also reflects the size of
the representative dataset.
Returns:
A representative dataset generator.
"""
def representative_dataset_gen():
"""Generates representative dataset for quantization."""
for data, _ in dataset.take(num_steps):
yield [data]
return representative_dataset_gen
class QuantizationConfig(object):
"""Configuration for post-training quantization.
Refer to
https://www.tensorflow.org/lite/performance/post_training_quantization
for different post-training quantization options.
"""
def __init__(
self,
optimizations: Optional[Union[tf.lite.Optimize,
List[tf.lite.Optimize]]] = None,
representative_data: Optional[ds.Dataset] = None,
quantization_steps: Optional[int] = None,
inference_input_type: Optional[tf.dtypes.DType] = None,
inference_output_type: Optional[tf.dtypes.DType] = None,
supported_ops: Optional[Union[tf.lite.OpsSet,
List[tf.lite.OpsSet]]] = None,
supported_types: Optional[Union[tf.dtypes.DType,
List[tf.dtypes.DType]]] = None,
experimental_new_quantizer: bool = False,
):
"""Constructs QuantizationConfig.
Args:
optimizations: A list of optimizations to apply when converting the model.
If not set, use `[Optimize.DEFAULT]` by default.
representative_data: A representative ds.Dataset for post-training
quantization.
quantization_steps: Number of post-training quantization calibration steps
to run (default to DEFAULT_QUANTIZATION_STEPS).
      inference_input_type: Target data type of real-number input arrays.
        Allows for a different type for input arrays. Defaults to None. If
        set, must be one of `{tf.float32, tf.uint8, tf.int8}`.
      inference_output_type: Target data type of real-number output arrays.
        Allows for a different type for output arrays. Defaults to None. If
        set, must be one of `{tf.float32, tf.uint8, tf.int8}`.
      supported_ops: Set of OpsSet options supported by the device. Used to
        set converter.target_spec.supported_ops.
supported_types: List of types for constant values on the target device.
Supported values are types exported by lite.constants. Frequently, an
optimization choice is driven by the most compact (i.e. smallest) type
in this list (default [constants.FLOAT]).
experimental_new_quantizer: Whether to enable experimental new quantizer.
Raises:
ValueError: if inference_input_type or inference_output_type are set but
not in {tf.float32, tf.uint8, tf.int8}.
"""
if inference_input_type is not None and inference_input_type not in {
tf.float32, tf.uint8, tf.int8
}:
raise ValueError('Unsupported inference_input_type %s' %
inference_input_type)
if inference_output_type is not None and inference_output_type not in {
tf.float32, tf.uint8, tf.int8
}:
raise ValueError('Unsupported inference_output_type %s' %
inference_output_type)
if optimizations is None:
optimizations = [tf.lite.Optimize.DEFAULT]
if not isinstance(optimizations, list):
optimizations = [optimizations]
self.optimizations = optimizations
self.representative_data = representative_data
if self.representative_data is not None and quantization_steps is None:
quantization_steps = DEFAULT_QUANTIZATION_STEPS
self.quantization_steps = quantization_steps
self.inference_input_type = inference_input_type
self.inference_output_type = inference_output_type
if supported_ops is not None and not isinstance(supported_ops, list):
supported_ops = [supported_ops]
self.supported_ops = supported_ops
if supported_types is not None and not isinstance(supported_types, list):
supported_types = [supported_types]
self.supported_types = supported_types
self.experimental_new_quantizer = experimental_new_quantizer
@classmethod
def for_dynamic(cls) -> 'QuantizationConfig':
"""Creates configuration for dynamic range quantization."""
return QuantizationConfig()
@classmethod
def for_int8(
cls,
representative_data: ds.Dataset,
quantization_steps: int = DEFAULT_QUANTIZATION_STEPS,
inference_input_type: tf.dtypes.DType = tf.uint8,
inference_output_type: tf.dtypes.DType = tf.uint8,
supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8
) -> 'QuantizationConfig':
"""Creates configuration for full integer quantization.
Args:
representative_data: Representative data used for post-training
quantization.
quantization_steps: Number of post-training quantization calibration steps
to run.
inference_input_type: Target data type of real-number input arrays.
inference_output_type: Target data type of real-number output arrays.
supported_ops: Set of `tf.lite.OpsSet` options, where each option
represents a set of operators supported by the target device.
Returns:
QuantizationConfig.
"""
return QuantizationConfig(
representative_data=representative_data,
quantization_steps=quantization_steps,
inference_input_type=inference_input_type,
inference_output_type=inference_output_type,
supported_ops=supported_ops)
@classmethod
def for_float16(cls) -> 'QuantizationConfig':
"""Creates configuration for float16 quantization."""
return QuantizationConfig(supported_types=[tf.float16])
def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter,
**kwargs: Any) -> tf.lite.TFLiteConverter:
"""Sets input TFLite converter with quantization configurations.
Args:
converter: input tf.lite.TFLiteConverter.
**kwargs: arguments used by ds.Dataset.gen_tf_dataset.
Returns:
tf.lite.TFLiteConverter with quantization configurations.
"""
converter.optimizations = self.optimizations
if self.representative_data is not None:
tf_ds = self.representative_data.gen_tf_dataset(
batch_size=1, is_training=False, **kwargs)
converter.representative_dataset = tf.lite.RepresentativeDataset(
_get_representative_dataset_generator(tf_ds, self.quantization_steps))
if self.inference_input_type:
converter.inference_input_type = self.inference_input_type
if self.inference_output_type:
converter.inference_output_type = self.inference_output_type
if self.supported_ops:
converter.target_spec.supported_ops = self.supported_ops
if self.supported_types:
converter.target_spec.supported_types = self.supported_types
if self.experimental_new_quantizer is not None:
converter.experimental_new_quantizer = self.experimental_new_quantizer
return converter
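# A usage sketch (illustration only, not part of the module API): building the
# factory configurations and applying one to a TFLiteConverter. The
# `representative_data` dataset and `saved_model_dir` path are assumptions for
# the example.
def _example_quantization(representative_data: ds.Dataset,
                          saved_model_dir: str) -> bytes:
  # Dynamic-range and float16 configs need no calibration data.
  _ = QuantizationConfig.for_dynamic()
  _ = QuantizationConfig.for_float16()
  # Full-integer quantization calibrates on a representative dataset.
  int8_config = QuantizationConfig.for_int8(
      representative_data=representative_data, quantization_steps=100)
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  converter = int8_config.set_converter_with_quantization(converter)
  return converter.convert()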

View File

@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
class QuantizationTest(tf.test.TestCase, parameterized.TestCase):
def test_create_dynamic_quantization_config(self):
config = quantization.QuantizationConfig.for_dynamic()
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertIsNone(config.representative_data)
self.assertIsNone(config.inference_input_type)
self.assertIsNone(config.inference_output_type)
self.assertIsNone(config.supported_ops)
self.assertIsNone(config.supported_types)
self.assertFalse(config.experimental_new_quantizer)
def test_create_int8_quantization_config(self):
representative_data = test_util.create_dataset(
data_size=10, input_shape=[4], num_classes=3)
config = quantization.QuantizationConfig.for_int8(
representative_data=representative_data)
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertEqual(config.inference_input_type, tf.uint8)
self.assertEqual(config.inference_output_type, tf.uint8)
self.assertEqual(config.supported_ops,
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
self.assertFalse(config.experimental_new_quantizer)
def test_set_converter_with_quantization_from_int8_config(self):
representative_data = test_util.create_dataset(
data_size=10, input_shape=[4], num_classes=3)
config = quantization.QuantizationConfig.for_int8(
representative_data=representative_data)
model = test_util.build_model(input_shape=[4], num_classes=3)
saved_model_dir = self.get_temp_dir()
model.save(saved_model_dir)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter = config.set_converter_with_quantization(converter=converter)
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertEqual(config.inference_input_type, tf.uint8)
self.assertEqual(config.inference_output_type, tf.uint8)
self.assertEqual(config.supported_ops,
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8)
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8)
def test_create_float16_quantization_config(self):
config = quantization.QuantizationConfig.for_float16()
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertIsNone(config.representative_data)
self.assertIsNone(config.inference_input_type)
self.assertIsNone(config.inference_output_type)
self.assertIsNone(config.supported_ops)
self.assertEqual(config.supported_types, [tf.float16])
self.assertFalse(config.experimental_new_quantizer)
def test_set_converter_with_quantization_from_float16_config(self):
config = quantization.QuantizationConfig.for_float16()
model = test_util.build_model(input_shape=[4], num_classes=3)
saved_model_dir = self.get_temp_dir()
model.save(saved_model_dir)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter = config.set_converter_with_quantization(converter=converter)
self.assertEqual(config.supported_types, [tf.float16])
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# The input and output are expected to be set to float32 by default.
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32)
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32)
@parameterized.named_parameters(
dict(
testcase_name='invalid_inference_input_type',
inference_input_type=tf.uint8,
inference_output_type=tf.int64),
dict(
testcase_name='invalid_inference_output_type',
inference_input_type=tf.int64,
inference_output_type=tf.float32))
def test_create_quantization_config_failure(self, inference_input_type,
inference_output_type):
with self.assertRaises(ValueError):
_ = quantization.QuantizationConfig(
inference_input_type=inference_input_type,
inference_output_type=inference_output_type)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,94 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test utilities for model maker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import List, Union
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import model_util
def create_dataset(data_size: int,
input_shape: List[int],
num_classes: int,
max_input_value: int = 1000) -> ds.Dataset:
"""Creates and returns a simple `Dataset` object for test."""
features = tf.random.uniform(
shape=[data_size] + input_shape,
minval=0,
maxval=max_input_value,
dtype=tf.float32)
labels = tf.random.uniform(
shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32)
tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = ds.Dataset(tf_dataset, data_size)
return dataset
def create_random_sample(size: Union[int, List[int]],
low: float = 0,
high: float = 1) -> np.ndarray:
"""Creates and returns a random sample with floating point values.
Args:
size: Size of the output multi-dimensional array.
low: Lower boundary of the output values.
high: Higher boundary of the output values.
Returns:
1D array if the size is scalar. Otherwise, N-D array whose dimension equals
input size.
"""
np.random.seed(0)
return np.random.uniform(low=low, high=high, size=size).astype(np.float32)
def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
"""Builds a simple Keras model for test."""
inputs = tf.keras.layers.Input(shape=input_shape)
if len(input_shape) == 3: # Image inputs.
outputs = tf.keras.layers.GlobalAveragePooling2D()(inputs)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(outputs)
elif len(input_shape) == 1: # Text inputs.
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
else:
raise ValueError("Model inputs should be 2D tensor or 4D tensor.")
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
def is_same_output(tflite_file: str,
keras_model: tf.keras.Model,
input_tensors: Union[List[tf.Tensor], tf.Tensor],
atol: float = 1e-04) -> bool:
"""Returns if the output of TFLite model and keras model are identical."""
# Gets output from lite model.
lite_runner = model_util.get_lite_runner(tflite_file)
lite_output = lite_runner.run(input_tensors)
# Gets output from keras model.
keras_output = keras_model.predict_on_batch(input_tensors)
return np.allclose(lite_output, keras_output, atol=atol)
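# A hypothetical sketch (illustration only, not part of the module API): how
# these helpers compose in a conversion check. In a real test the Keras model
# would be the same one the TFLite file was converted from; the `tflite_file`
# path here is an assumption.
def _example_conversion_check(tflite_file: str) -> bool:
  keras_model = build_model(input_shape=[4], num_classes=2)
  sample = tf.convert_to_tensor(create_random_sample(size=[1, 4]))
  return is_same_output(tflite_file, keras_model, sample, atol=1e-4)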

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,111 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python library rule.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "image_classifier_import",
srcs = ["__init__.py"],
deps = [
":dataset",
":hyperparameters",
":image_classifier",
":model_spec",
],
)
py_library(
name = "model_spec",
srcs = ["model_spec.py"],
)
py_test(
name = "model_spec_test",
srcs = ["model_spec_test.py"],
deps = [":model_spec"],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
deps = [":dataset"],
)
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
)
py_library(
name = "train_image_classifier_lib",
srcs = ["train_image_classifier_lib.py"],
deps = [
":hyperparameters",
"//mediapipe/model_maker/python/core/utils:model_util",
],
)
py_library(
name = "image_classifier",
srcs = ["image_classifier.py"],
deps = [
":hyperparameters",
":model_spec",
":train_image_classifier_lib",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:image_preprocessing",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
],
)
py_library(
name = "image_classifier_test_lib",
testonly = 1,
srcs = ["image_classifier_test.py"],
deps = [":image_classifier_import"],
)
py_test(
name = "image_classifier_test",
srcs = ["image_classifier_test.py"],
shard_count = 2,
tags = ["requires-net:external"],
deps = [
":image_classifier_test_lib",
],
)
py_binary(
name = "image_classifier_demo",
srcs = ["image_classifier_demo.py"],
deps = [
":image_classifier_import",
"//mediapipe/model_maker/python/core/utils:quantization",
],
)

View File

@ -0,0 +1,25 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Image Classifier."""
from mediapipe.model_maker.python.vision.image_classifier import dataset
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import model_spec
ImageClassifier = image_classifier.ImageClassifier
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels

View File

@ -0,0 +1,139 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classifier dataset library."""
import os
import random
from typing import List, Optional, Tuple
import tensorflow as tf
import tensorflow_datasets as tfds
from mediapipe.model_maker.python.core.data import classification_dataset
def _load_image(path: str) -> tf.Tensor:
"""Loads image."""
image_raw = tf.io.read_file(path)
image_tensor = tf.cond(
tf.image.is_jpeg(image_raw),
lambda: tf.image.decode_jpeg(image_raw, channels=3),
lambda: tf.image.decode_png(image_raw, channels=3))
return image_tensor
def _create_data(
name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
label_names: List[str]
) -> Optional[classification_dataset.ClassificationDataset]:
"""Creates a Dataset object from tfds data."""
if name not in data:
return None
data = data[name]
data = data.map(lambda a: (a['image'], a['label']))
size = info.splits[name].num_examples
return Dataset(data, size, label_names)
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for image classifier."""
@classmethod
def from_folder(
cls,
dirname: str,
shuffle: bool = True) -> classification_dataset.ClassificationDataset:
"""Loads images and labels from the given directory.
    Assumes the image data of the same label are in the same subdirectory.
Args:
dirname: Name of the directory containing the data files.
      shuffle: boolean, if True, the data is randomly shuffled.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
"""
data_root = os.path.abspath(dirname)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
      raise ValueError('Image dataset is empty.')
if shuffle:
# Random shuffle data.
random.shuffle(all_image_paths)
label_names = sorted(
name for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name)))
all_label_size = len(label_names)
label_to_index = dict(
(name, index) for index, name in enumerate(label_names))
all_image_labels = [
label_to_index[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
autotune = tf.data.AUTOTUNE
image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
# Loads label.
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64))
    # Creates a dataset of (image, label) pairs.
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
tf.compat.v1.logging.info(
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names))
return Dataset(image_label_ds, all_image_size, label_names)
@classmethod
def load_tf_dataset(
cls, name: str
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset]]:
"""Loads data from tensorflow_datasets.
Args:
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
documentation of tfds.load for more details.
Returns:
A tuple of Datasets for the train/validation/test.
Raises:
      ValueError: if the loaded tfds dataset does not contain a 'label'
        feature.
"""
data, info = tfds.load(name, with_info=True)
if 'label' not in info.features:
      raise ValueError('info.features needs to contain a \'label\' key.')
label_names = info.features['label'].names
train_data = _create_data('train', data, info, label_names)
validation_data = _create_data('validation', data, info, label_names)
test_data = _create_data('test', data, info, label_names)
return train_data, validation_data, test_data
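# A hypothetical usage sketch (illustration only, not part of the module API):
# the two ways of building an image classification dataset. The directory
# layout (one subdirectory per label) and the tfds name 'beans' are
# assumptions for the example.
def _example_load_datasets(image_dir: str):
  # From a local folder such as image_dir/daisy/*.jpg and image_dir/tulips/*.jpg.
  folder_data = Dataset.from_folder(image_dir, shuffle=True)
  # From tensorflow_datasets; splits that are absent come back as None.
  train_data, validation_data, test_data = Dataset.load_tf_dataset('beans')
  return folder_data, train_data, validation_data, test_data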

View File

@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision.image_classifier import dataset
def _fill_image(rgb, image_size):
r, g, b = rgb
return np.broadcast_to(
np.array([[[r, g, b]]], dtype=np.uint8),
shape=(image_size, image_size, 3))
def _write_filled_jpeg_file(path, rgb, image_size):
tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size),
'channels_last', 'jpeg')
class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.image_path = os.path.join(self.get_temp_dir(), 'random_image_dir')
if os.path.exists(self.image_path):
return
os.mkdir(self.image_path)
for class_name in ('daisy', 'tulips'):
class_subdir = os.path.join(self.image_path, class_name)
os.mkdir(class_subdir)
_write_filled_jpeg_file(
os.path.join(class_subdir, '0.jpeg'),
[random.uniform(0, 255) for _ in range(3)], 224)
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 2)
for i, elem in enumerate(train_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
def test_from_folder(self):
data = dataset.Dataset.from_folder(self.image_path)
self.assertLen(data, 2)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0:
raw_image_tensor = dataset._load_image(
os.path.join(self.image_path, 'daisy', '0.jpeg'))
else:
raw_image_tensor = dataset._load_image(
os.path.join(self.image_path, 'tulips', '0.jpeg'))
self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())
def test_from_tfds(self):
# TODO: Remove this once tfds download error is fixed.
self.skipTest('Temporarily skip the unittest due to tfds download error.')
train_data, validation_data, test_data = (
dataset.Dataset.from_tfds('beans'))
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(train_data, 1034)
self.assertEqual(train_data.num_classes, 3)
self.assertEqual(train_data.index_to_label,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(validation_data, 133)
self.assertEqual(validation_data.num_classes, 3)
self.assertEqual(validation_data.index_to_label,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(test_data, 128)
self.assertEqual(test_data.num_classes, 3)
self.assertEqual(test_data.index_to_label,
['angular_leaf_spot', 'bean_rust', 'healthy'])
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,74 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training image classification models."""
import dataclasses
import tempfile
from typing import Optional
# TODO: Expose other hyperparameters, e.g. data augmentation
# hyperparameters if requested.
@dataclasses.dataclass
class HParams:
"""The hyperparameters for training image classifiers.
The hyperparameters include:
# Parameters about training data.
do_fine_tuning: If true, the base module is trained together with the
classification layer on top.
    shuffle: A boolean controlling whether to shuffle the dataset. Defaults
      to False.
# Parameters about training configuration
train_epochs: Training will do this many iterations over the dataset.
batch_size: Each training step samples a batch of this many images.
learning_rate: The learning rate to use for gradient descent training.
dropout_rate: The fraction of the input units to drop, used in dropout
layer.
l1_regularizer: A regularizer that applies a L1 regularization penalty.
l2_regularizer: A regularizer that applies a L2 regularization penalty.
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
more details.
do_data_augmentation: A boolean controlling whether the training dataset is
augmented by randomly distorting input images, including random cropping,
flipping, etc. See utils.image_preprocessing documentation for details.
    steps_per_epoch: An optional integer indicating the number of training
      steps per epoch. If not set, the training pipeline calculates the
      default steps per epoch as the training dataset size divided by batch
      size.
decay_samples: Number of training samples used to calculate the decay steps
and create the training optimizer.
    warmup_epochs: Number of warmup epochs for a linearly increasing warmup
      schedule on the learning rate. Used to set up the warmup schedule by
      model_util.WarmUp.
# Parameters about the saved checkpoint
model_dir: The location of model checkpoint files and exported model files.
"""
# Parameters about training data
do_fine_tuning: bool = False
shuffle: bool = False
# Parameters about training configuration
train_epochs: int = 5
batch_size: int = 32
learning_rate: float = 0.005
dropout_rate: float = 0.2
l1_regularizer: float = 0.0
l2_regularizer: float = 0.0001
label_smoothing: float = 0.1
do_data_augmentation: bool = True
steps_per_epoch: Optional[int] = None
decay_samples: int = 10000 * 256
warmup_epochs: int = 2
# Parameters about the saved checkpoint
model_dir: str = tempfile.mkdtemp()
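# A hypothetical usage sketch (illustration only, not part of the module API):
# overriding a few fields while keeping the remaining defaults. The export
# directory is an assumption for the example.
def _example_hparams(export_dir: str) -> HParams:
  return HParams(
      train_epochs=10,
      batch_size=16,
      learning_rate=0.01,
      do_fine_tuning=True,
      model_dir=export_dir)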

View File

@ -0,0 +1,172 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train image classifier model."""
from typing import Any, List, Optional
import tensorflow as tf
import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
hparams: hp.HParams):
"""Initializes ImageClassifier class.
Args:
model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name.
hparams: The hyperparameters for training image classifier.
"""
super(ImageClassifier, self).__init__(
model_spec=model_spec,
index_to_label=index_to_label,
shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning)
self._hparams = hparams
self._preprocess = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape,
num_classes=self._num_classes,
mean_rgb=self._model_spec.mean_rgb,
stddev_rgb=self._model_spec.stddev_rgb,
use_augmentation=hparams.do_data_augmentation)
self._history = None # Training history returned from `keras_model.fit`.
@classmethod
def create(
cls,
model_spec: ms.SupportedModels,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
hparams: Optional[hp.HParams] = None,
) -> 'ImageClassifier':
"""Creates and trains an image classifier.
Loads data and trains the model based on data for image classification.
Args:
model_spec: Specification for the model.
train_data: Training data.
validation_data: Validation data.
hparams: Hyperparameters for training image classifier.
Returns:
An instance based on ImageClassifier.
"""
if hparams is None:
hparams = hp.HParams()
spec = ms.SupportedModels.get(model_spec)
image_classifier = cls(
model_spec=spec,
index_to_label=train_data.index_to_label,
hparams=hparams)
image_classifier._create_model()
tf.compat.v1.logging.info('Training the models...')
image_classifier._train(
train_data=train_data, validation_data=validation_data)
return image_classifier
def _train(self, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset):
"""Trains the model with input train_data.
The training results are recorded by a self._history object returned by
tf.keras.Model.fit().
Args:
train_data: Training data.
validation_data: Validation data.
"""
tf.compat.v1.logging.info('Training the models...')
hparams = self._hparams
if len(train_data) < hparams.batch_size:
      raise ValueError('The size of the train_data (%d) cannot be smaller '
                       'than batch_size (%d). To solve this problem, set '
                       'the batch_size smaller or increase the size of the '
                       'train_data.' % (len(train_data), hparams.batch_size))
train_dataset = train_data.gen_tf_dataset(
batch_size=hparams.batch_size,
is_training=True,
shuffle=self._shuffle,
preprocess=self._preprocess)
hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=hparams.steps_per_epoch,
batch_size=hparams.batch_size,
train_data=train_data)
train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
validation_dataset = validation_data.gen_tf_dataset(
batch_size=hparams.batch_size,
is_training=False,
preprocess=self._preprocess)
# Train the model.
self._history = train_image_classifier_lib.train_model(
model=self._model,
hparams=hparams,
train_ds=train_dataset,
validation_ds=validation_dataset)
def _create_model(self):
"""Creates the classifier model from TFHub pretrained models."""
module_layer = hub.KerasLayer(
handle=self._model_spec.uri, trainable=self._hparams.do_fine_tuning)
image_size = self._model_spec.input_image_shape
self._model = tf.keras.Sequential([
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
tf.keras.layers.Dense(
units=self._num_classes,
activation='softmax',
kernel_regularizer=tf.keras.regularizers.l1_l2(
l1=self._hparams.l1_regularizer,
l2=self._hparams.l2_regularizer))
])
    self._model.summary()  # Prints the model architecture summary.
def export_model(
self,
model_name: str = 'model.tflite',
quantization_config: Optional[quantization.QuantizationConfig] = None):
"""Converts the model to the requested formats and exports to a file.
Args:
      model_name: File name to save the tflite model. The full export path is
        {model_dir}/{model_name}.
quantization_config: The configuration for model quantization.
"""
super().export_tflite(
self._hparams.model_dir,
model_name,
quantization_config,
preprocess=self._preprocess)
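# A hypothetical usage sketch (illustration only, not part of the module API):
# the end-to-end flow this class supports, from datasets to an exported TFLite
# file. The datasets are assumed to be ClassificationDataset instances built
# elsewhere (e.g. by the image_classifier Dataset helpers).
def _example_train_and_export(
    train_data: classification_ds.ClassificationDataset,
    validation_data: classification_ds.ClassificationDataset) -> None:
  model = ImageClassifier.create(
      model_spec=ms.SupportedModels.EFFICIENTNET_LITE0,
      train_data=train_data,
      validation_data=validation_data,
      hparams=hp.HParams(train_epochs=5, batch_size=32))
  model.export_model(
      quantization_config=quantization.QuantizationConfig.for_dynamic())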

View File

@ -0,0 +1,106 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making an image classifier model by MediaPipe Model Maker."""
import os
# Dependency imports
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier
FLAGS = flags.FLAGS
def define_flags() -> None:
"""Define flags for the image classifier model maker demo."""
flags.DEFINE_string('export_dir', None,
'The directory to save exported files.')
flags.DEFINE_string(
'input_data_dir', None,
"""The directory with input training data. If the training data is not
specified, the pipeline will download a default training dataset.""")
flags.DEFINE_enum_class('spec',
image_classifier.SupportedModels.EFFICIENTNET_LITE0,
image_classifier.SupportedModels,
'The image classifier to run.')
flags.DEFINE_enum('quantization', None, ['dynamic', 'int8', 'float16'],
'The quantization method to use when exporting the model.')
flags.mark_flag_as_required('export_dir')
def download_demo_data() -> str:
"""Downloads demo data, and returns directory path."""
data_dir = tf.keras.utils.get_file(
fname='flower_photos.tgz',
origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
return os.path.join(os.path.dirname(data_dir), 'flower_photos') # folder name
def run(data_dir: str, export_dir: str,
model_spec: image_classifier.SupportedModels,
quantization_option: str) -> None:
"""Runs demo."""
data = image_classifier.Dataset.from_folder(data_dir)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
model = image_classifier.ImageClassifier.create(
model_spec=model_spec,
train_data=train_data,
validation_data=validation_data,
hparams=image_classifier.HParams(model_dir=export_dir))
_, acc = model.evaluate(test_data)
print('Test accuracy: %f' % acc)
if quantization_option is None:
quantization_config = None
elif quantization_option == 'dynamic':
quantization_config = quantization.QuantizationConfig.for_dynamic()
elif quantization_option == 'int8':
quantization_config = quantization.QuantizationConfig.for_int8(train_data)
elif quantization_option == 'float16':
quantization_config = quantization.QuantizationConfig.for_float16()
else:
    raise ValueError(f'Quantization: {quantization_option} is not recognized')
model.export_model(quantization_config=quantization_config)
model.export_labels(export_dir)
def main(_) -> None:
logging.set_verbosity(logging.INFO)
if FLAGS.input_data_dir is None:
data_dir = download_demo_data()
else:
data_dir = FLAGS.input_data_dir
export_dir = os.path.expanduser(FLAGS.export_dir)
run(data_dir=data_dir,
export_dir=export_dir,
model_spec=FLAGS.spec,
quantization_option=FLAGS.quantization)
if __name__ == '__main__':
define_flags()
app.run(main)

View File

@ -0,0 +1,122 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision import image_classifier
def _fill_image(rgb, image_size):
r, g, b = rgb
return np.broadcast_to(
np.array([[[r, g, b]]], dtype=np.uint8),
shape=(image_size, image_size, 3))
class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
IMAGE_SIZE = 24
IMAGES_PER_CLASS = 2
CMY_NAMES_AND_RGB_VALUES = (('cyan', (0, 255, 255)),
('magenta', (255, 0, 255)), ('yellow', (255, 255,
0)))
def _gen(self):
for i, (_, rgb) in enumerate(self.CMY_NAMES_AND_RGB_VALUES):
for _ in range(self.IMAGES_PER_CLASS):
yield (_fill_image(rgb, self.IMAGE_SIZE), i)
def _gen_cmy_data(self):
ds = tf.data.Dataset.from_generator(
self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
[self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
['cyan', 'magenta', 'yellow'])
return data
def setUp(self):
super(ImageClassifierTest, self).setUp()
all_data = self._gen_cmy_data()
# Splits data, 90% data for training, 10% for testing
self.train_data, self.test_data = all_data.split(0.9)
@parameterized.named_parameters(
dict(
testcase_name='mobilenet_v2',
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='resnet_50',
model_spec=image_classifier.SupportedModels.RESNET_50,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite0',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite1',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite2',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite3',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite4',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
)
def test_create_and_train_model(self,
model_spec: image_classifier.SupportedModels,
hparams: image_classifier.HParams):
model = image_classifier.ImageClassifier.create(
model_spec=model_spec,
train_data=self.train_data,
hparams=hparams,
validation_data=self.test_data)
self._test_accuracy(model)
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
hparams = image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)
model = image_classifier.ImageClassifier.create(
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
train_data=self.train_data,
hparams=hparams,
validation_data=self.test_data)
self._test_accuracy(model)
def _test_accuracy(self, model, threshold=0.0):
_, accuracy = model.evaluate(self.test_data)
self.assertGreaterEqual(accuracy, threshold)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,104 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classifier model specification."""
import enum
import functools
from typing import List, Optional
class ModelSpec(object):
"""Specification of image classifier model."""
mean_rgb = [0.0]
stddev_rgb = [255.0]
def __init__(self,
uri: str,
input_image_shape: Optional[List[int]] = None,
name: str = ''):
"""Initializes a new instance of the `ImageModelSpec` class.
Args:
uri: str, URI to the pretrained model.
input_image_shape: list of int, input image shape. Default: [224, 224].
name: str, model spec name.
"""
self.uri = uri
self.name = name
if input_image_shape is None:
input_image_shape = [224, 224]
self.input_image_shape = input_image_shape
mobilenet_v2_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
name='mobilenet_v2')
resnet_50_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
name='resnet_50')
efficientnet_lite0_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
name='efficientnet_lite0')
efficientnet_lite1_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
input_image_shape=[240, 240],
name='efficientnet_lite1')
efficientnet_lite2_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
input_image_shape=[260, 260],
name='efficientnet_lite2')
efficientnet_lite3_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
input_image_shape=[280, 280],
name='efficientnet_lite3')
efficientnet_lite4_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
input_image_shape=[300, 300],
name='efficientnet_lite4')
# TODO: Document the exposed models.
@enum.unique
class SupportedModels(enum.Enum):
"""Image classifier model supported by model maker."""
MOBILENET_V2 = mobilenet_v2_spec
RESNET_50 = resnet_50_spec
EFFICIENTNET_LITE0 = efficientnet_lite0_spec
EFFICIENTNET_LITE1 = efficientnet_lite1_spec
EFFICIENTNET_LITE2 = efficientnet_lite2_spec
EFFICIENTNET_LITE3 = efficientnet_lite3_spec
EFFICIENTNET_LITE4 = efficientnet_lite4_spec
@classmethod
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
"""Gets model spec from the input enum and initializes it."""
if spec not in cls:
raise TypeError('Unsupported image classifier spec: {}'.format(spec))
return spec.value()
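
A short sketch of how these specs are intended to be used — the enum lookup and the printed field values come from the definitions above; the custom URI in the second half is hypothetical.

```
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms

# Resolve an enum member to a concrete ModelSpec instance.
spec = ms.SupportedModels.get(ms.SupportedModels.EFFICIENTNET_LITE2)
print(spec.name)               # 'efficientnet_lite2'
print(spec.uri)                # 'https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2'
print(spec.input_image_shape)  # [260, 260]

# A custom spec can also be constructed directly (URI here is hypothetical).
custom = ms.ModelSpec(
    uri='https://tfhub.dev/owner/custom_feature_vector/1',
    input_image_shape=[192, 192],
    name='custom_model')
```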

View File

@ -0,0 +1,93 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, List
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='mobilenet_v2_spec_test',
model_spec=ms.mobilenet_v2_spec,
expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
expected_name='mobilenet_v2',
expected_input_image_shape=[224, 224]),
dict(
testcase_name='resnet_50_spec_test',
model_spec=ms.resnet_50_spec,
expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
expected_name='resnet_50',
expected_input_image_shape=[224, 224]),
dict(
testcase_name='efficientnet_lite0_spec_test',
model_spec=ms.efficientnet_lite0_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
expected_name='efficientnet_lite0',
expected_input_image_shape=[224, 224]),
dict(
testcase_name='efficientnet_lite1_spec_test',
model_spec=ms.efficientnet_lite1_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
expected_name='efficientnet_lite1',
expected_input_image_shape=[240, 240]),
dict(
testcase_name='efficientnet_lite2_spec_test',
model_spec=ms.efficientnet_lite2_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
expected_name='efficientnet_lite2',
expected_input_image_shape=[260, 260]),
dict(
testcase_name='efficientnet_lite3_spec_test',
model_spec=ms.efficientnet_lite3_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
expected_name='efficientnet_lite3',
expected_input_image_shape=[280, 280]),
dict(
testcase_name='efficientnet_lite4_spec_test',
model_spec=ms.efficientnet_lite4_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
expected_name='efficientnet_lite4',
expected_input_image_shape=[300, 300]),
)
  def test_predefined_spec(self, model_spec: Callable[..., ms.ModelSpec],
expected_uri: str, expected_name: str,
expected_input_image_shape: List[int]):
model_spec_obj = model_spec()
self.assertIsInstance(model_spec_obj, ms.ModelSpec)
self.assertEqual(model_spec_obj.uri, expected_uri)
self.assertEqual(model_spec_obj.name, expected_name)
self.assertEqual(model_spec_obj.input_image_shape,
expected_input_image_shape)
def test_create_spec(self):
custom_model_spec = ms.ModelSpec(
uri='https://custom_model',
input_image_shape=[128, 128],
name='custom_model')
self.assertEqual(custom_model_spec.uri, 'https://custom_model')
self.assertEqual(custom_model_spec.name, 'custom_model')
self.assertEqual(custom_model_spec.input_image_shape, [128, 128])
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,103 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to train model."""
import os
from typing import List
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
def _create_optimizer(init_lr: float, decay_steps: int,
warmup_steps: int) -> tf.keras.optimizers.Optimizer:
"""Creates an optimizer with learning rate schedule.
Uses Keras CosineDecay schedule for the learning rate by default.
Args:
init_lr: Initial learning rate.
decay_steps: Number of steps to decay over.
warmup_steps: Number of steps to do warmup for.
Returns:
A tf.keras.optimizers.Optimizer for model training.
"""
learning_rate_fn = tf.keras.experimental.CosineDecay(
initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
if warmup_steps:
learning_rate_fn = model_util.WarmUp(
initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=warmup_steps)
optimizer = tf.keras.optimizers.RMSprop(
learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
return optimizer
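
To make the schedule above concrete, here is a small sketch that samples the cosine decay curve on its own, using hypothetical values (init_lr=0.01, decay_steps=1000, no warmup); in `train_model` these values are derived from the HParams instead.

```
import tensorflow as tf

# Hypothetical schedule parameters, for illustration only.
schedule = tf.keras.experimental.CosineDecay(
    initial_learning_rate=0.01, decay_steps=1000, alpha=0.0)
for step in (0, 250, 500, 1000):
  # Learning rate falls from 0.01 to 0.0: 0.01, ~0.0085, 0.005, 0.0
  print(step, float(schedule(step)))
```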
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
"""Gets default callbacks."""
summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 20 epochs.
checkpoint_path = os.path.join(model_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True, period=20)
return [summary_callback, checkpoint_callback]
def train_model(model: tf.keras.Model, hparams: hp.HParams,
train_ds: tf.data.Dataset,
validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
"""Trains model with the input data and hyperparameters.
Args:
model: Input tf.keras.Model.
hparams: Hyperparameters for training image classifier.
train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
validation_ds: tf.data.Dataset, validation data to be fed in
tf.keras.Model.fit().
Returns:
The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
"""
  # Learning rate scales linearly with the batch size.
learning_rate = hparams.learning_rate * hparams.batch_size / 256
# Get decay steps.
total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
default_decay_steps = hparams.decay_samples // hparams.batch_size
decay_steps = max(total_training_steps, default_decay_steps)
warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
optimizer = _create_optimizer(
init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)
loss = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=hparams.label_smoothing)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
callbacks = _get_default_callbacks(hparams.model_dir)
# Train the model.
return model.fit(
x=train_ds,
epochs=hparams.train_epochs,
steps_per_epoch=hparams.steps_per_epoch,
validation_data=validation_ds,
callbacks=callbacks)
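
As a worked example of the arithmetic in `train_model` above, with hypothetical hyperparameters (learning_rate=0.005, batch_size=32, steps_per_epoch=100, train_epochs=10, decay_samples=2,560,000, warmup_epochs=2):

```
# All values below are hypothetical, for illustration only.
learning_rate = 0.005 * 32 / 256         # 0.000625 -- scaled linearly with batch size
total_training_steps = 100 * 10          # 1000 steps
default_decay_steps = 2_560_000 // 32    # 80000 steps
decay_steps = max(total_training_steps, default_decay_steps)  # 80000
warmup_steps = 2 * 100                   # 200 warmup steps
```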

View File

@ -0,0 +1,6 @@
absl-py
numpy
opencv-contrib-python
tensorflow
tensorflow-datasets
tensorflow-hub

View File

@ -244,7 +244,7 @@
     if ([_session canAddOutput:_depthDataOutput]) {
       [_session addOutput:_depthDataOutput];
-      AVCaptureConnection* connection =
+      AVCaptureConnection* __unused connection =
           [_depthDataOutput connectionWithMediaType:AVMediaTypeDepthData];
       // Set this when we have a handler.
@ -327,7 +327,6 @@
   if (depthData.depthDataType != kCVPixelFormatType_DepthFloat32) {
     depthData = [depthData depthDataByConvertingToDepthDataType:kCVPixelFormatType_DepthFloat32];
   }
-  CVPixelBufferRef depthBuffer = depthData.depthDataMap;
   [self.delegate processDepthData:depthData timestamp:timestamp fromSource:self];
 }

Some files were not shown because too many files have changed in this diff.