Merge branch 'master' into nguyencse/facemeshioslib

This commit is contained in:
nguyencse 2023-08-02 18:00:22 +07:00
commit 1001ead358
190 changed files with 4258 additions and 1068 deletions

View File

@ -157,22 +157,22 @@ http_archive(
# 2020-08-21
http_archive(
name = "com_github_glog_glog",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
"https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
],
)
http_archive(
name = "com_github_glog_glog_no_gflags",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
build_file = "@//third_party:glog_no_gflags.BUILD",
urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
"https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
],
patches = [
"@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff",
"@//third_party:com_github_glog_glog.diff",
],
patch_args = [
"-p1",

View File

@ -68,30 +68,108 @@ config_setting(
visibility = ["//visibility:public"],
)
# Note: this cannot just match "apple_platform_type": "macos" because that option
# defaults to "macos" even when building on Linux!
alias(
# Generic MacOS.
config_setting(
name = "macos",
actual = select({
":macos_i386": ":macos_i386",
":macos_x86_64": ":macos_x86_64",
":macos_arm64": ":macos_arm64",
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
}),
constraint_values = [
"@platforms//os:macos",
],
visibility = ["//visibility:public"],
)
# Note: this also matches on crosstool_top so that it does not produce ambiguous
# selectors when used together with "android".
# MacOS x86 64-bit.
config_setting(
name = "macos_x86_64",
constraint_values = [
"@platforms//os:macos",
"@platforms//cpu:x86_64",
],
visibility = ["//visibility:public"],
)
# MacOS ARM64.
config_setting(
name = "macos_arm64",
constraint_values = [
"@platforms//os:macos",
"@platforms//cpu:arm64",
],
visibility = ["//visibility:public"],
)
# Generic iOS.
config_setting(
name = "ios",
values = {
"crosstool_top": "@bazel_tools//tools/cpp:toolchain",
"apple_platform_type": "ios",
},
constraint_values = [
"@platforms//os:ios",
],
visibility = ["//visibility:public"],
)
# iOS device ARM32.
config_setting(
name = "ios_armv7",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm",
],
visibility = ["//visibility:public"],
)
# iOS device ARM64.
config_setting(
name = "ios_arm64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64",
],
visibility = ["//visibility:public"],
)
# iOS device ARM64E.
config_setting(
name = "ios_arm64e",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64e",
],
visibility = ["//visibility:public"],
)
# iOS simulator x86 32-bit.
config_setting(
name = "ios_i386",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:x86_32",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)
# iOS simulator x86 64-bit.
config_setting(
name = "ios_x86_64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:x86_64",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)
# iOS simulator ARM64.
config_setting(
name = "ios_sim_arm64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)
# Generic Apple.
alias(
name = "apple",
actual = select({
@ -102,49 +180,6 @@ alias(
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_i386",
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_x86_64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_x86_64",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_arm64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_arm64",
},
visibility = ["//visibility:public"],
)
[
config_setting(
name = arch,
values = {"cpu": arch},
visibility = ["//visibility:public"],
)
for arch in [
"ios_i386",
"ios_x86_64",
"ios_armv7",
"ios_arm64",
"ios_arm64e",
"ios_sim_arm64",
]
]
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},

View File

@ -210,6 +210,23 @@ REGISTER_CALCULATOR(SpectrogramCalculator);
// Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;
namespace {
std::unique_ptr<audio_dsp::WindowFunction> MakeWindowFun(
const SpectrogramCalculatorOptions::WindowType window_type) {
switch (window_type) {
// The cosine window and square root of Hann are equivalent.
case SpectrogramCalculatorOptions::COSINE:
case SpectrogramCalculatorOptions::SQRT_HANN:
return std::make_unique<audio_dsp::CosineWindow>();
case SpectrogramCalculatorOptions::HANN:
return std::make_unique<audio_dsp::HannWindow>();
case SpectrogramCalculatorOptions::HAMMING:
return std::make_unique<audio_dsp::HammingWindow>();
}
return nullptr;
}
} // namespace
absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
SpectrogramCalculatorOptions spectrogram_options =
cc->Options<SpectrogramCalculatorOptions>();
@ -266,28 +283,14 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
output_scale_ = spectrogram_options.output_scale();
std::vector<double> window;
switch (spectrogram_options.window_type()) {
case SpectrogramCalculatorOptions::COSINE:
audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::HANN:
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::HAMMING:
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::SQRT_HANN: {
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
absl::c_transform(window, window.begin(),
[](double x) { return std::sqrt(x); });
break;
}
auto window_fun = MakeWindowFun(spectrogram_options.window_type());
if (window_fun == nullptr) {
return absl::Status(absl::StatusCode::kInvalidArgument,
absl::StrCat("Invalid window type ",
spectrogram_options.window_type()));
}
std::vector<double> window;
window_fun->GetPeriodicSamples(frame_duration_samples_, &window);
// Propagate settings down to the actual Spectrogram object.
spectrogram_generators_.clear();

View File

@ -68,7 +68,7 @@ message SpectrogramCalculatorOptions {
HANN = 0;
HAMMING = 1;
COSINE = 2;
SQRT_HANN = 4;
SQRT_HANN = 4; // Alias of COSINE.
}
optional WindowType window_type = 6 [default = HANN];

View File

@ -381,17 +381,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "clip_detection_vector_size_calculator",
srcs = ["clip_detection_vector_size_calculator.cc"],
deps = [
":clip_vector_size_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
],
alwayslink = 1,
)
cc_test(
name = "clip_vector_size_calculator_test",
srcs = ["clip_vector_size_calculator_test.cc"],

View File

@ -76,4 +76,9 @@ REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator);
// A calculator to process std::vector<mediapipe::Image>.
typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator;
REGISTER_CALCULATOR(BeginLoopImageCalculator);
// A calculator to process std::vector<float>.
typedef BeginLoopCalculator<std::vector<float>> BeginLoopFloatCalculator;
REGISTER_CALCULATOR(BeginLoopFloatCalculator);
} // namespace mediapipe

View File

@ -123,7 +123,10 @@ class PreviousLoopbackCalculator : public Node {
// However, LOOP packet is empty.
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
} else {
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
// Avoids sending leftovers to a stream that's already closed.
if (!kPrevLoop(cc).IsClosed()) {
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
}
}
loop_packets_.pop_front();
main_packet_specs_.pop_front();

View File

@ -135,7 +135,6 @@ cc_library(
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status",
],

View File

@ -112,7 +112,7 @@ class BilateralFilterCalculator : public CalculatorBase {
REGISTER_CALCULATOR(BilateralFilterCalculator);
absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
if (cc->Inputs().HasTag(kInputFrameTag) &&
cc->Inputs().HasTag(kInputFrameTagGpu)) {

View File

@ -110,7 +110,7 @@ REGISTER_CALCULATOR(SegmentationSmoothingCalculator);
absl::Status SegmentationSmoothingCalculator::GetContract(
CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
cc->Inputs().Tag(kCurrentMaskTag).Set<Image>();
cc->Inputs().Tag(kPreviousMaskTag).Set<Image>();

View File

@ -142,7 +142,7 @@ class SetAlphaCalculator : public CalculatorBase {
REGISTER_CALCULATOR(SetAlphaCalculator);
absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
bool use_gpu = false;

View File

@ -38,7 +38,7 @@ std::string FourCCToString(libyuv::FourCC fourcc) {
buf[0] = (fourcc >> 24) & 0xff;
buf[1] = (fourcc >> 16) & 0xff;
buf[2] = (fourcc >> 8) & 0xff;
buf[3] = (fourcc)&0xff;
buf[3] = (fourcc) & 0xff;
buf[4] = 0;
return std::string(buf);
}

View File

@ -282,18 +282,23 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
if (options.has_volume_gain_db()) {
gain_ = pow(10, options.volume_gain_db() / 20.0);
}
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
!kAudioIn(cc).Header().IsEmpty())
<< "Must either specify the time series header of the \"AUDIO\" stream "
"or have the \"SAMPLE_RATE\" stream connected.";
if (!kAudioIn(cc).Header().IsEmpty()) {
mediapipe::TimeSeriesHeader input_header;
MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
kAudioIn(cc).Header(), &input_header));
if (stream_mode_) {
MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
} else {
source_sample_rate_ = input_header.sample_rate();
if (options.has_source_sample_rate()) {
source_sample_rate_ = options.source_sample_rate();
} else {
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
!kAudioIn(cc).Header().IsEmpty())
<< "Must either specify the time series header of the \"AUDIO\" stream "
"or have the \"SAMPLE_RATE\" stream connected.";
if (!kAudioIn(cc).Header().IsEmpty()) {
mediapipe::TimeSeriesHeader input_header;
MP_RETURN_IF_ERROR(
mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
kAudioIn(cc).Header(), &input_header));
if (stream_mode_) {
MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
} else {
source_sample_rate_ = input_header.sample_rate();
}
}
}
AppendZerosToSampleBuffer(padding_samples_before_);

View File

@ -85,4 +85,7 @@ message AudioToTensorCalculatorOptions {
// The volume gain, measured in dB.
// Scale the input audio amplitude by 10^(volume_gain_db/20).
optional double volume_gain_db = 12;
// The source number of samples per second (hertz) of the input audio buffers.
optional double source_sample_rate = 13;
}

View File

@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl
gpu_delegate_options);
absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
bool UseSerializedModel() const { return use_serialized_model_; }
private:
bool use_kernel_caching_ = false;
@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
}
absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
MP_RETURN_IF_ERROR(
on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
return gpu_helper_.RunInGlContext([this]() -> absl::Status {
tflite_gpu_runner_.reset();
return absl::OkStatus();
@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
tflite_gpu_runner_->GetOutputShapes()[i].c};
}
if (on_disk_cache_helper_.UseSerializedModel()) {
tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
}
MP_RETURN_IF_ERROR(
on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
return tflite_gpu_runner_->Build();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
}
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

View File

@ -256,6 +256,7 @@ class TensorsToDetectionsCalculator : public Node {
bool gpu_inited_ = false;
bool gpu_input_ = false;
bool gpu_has_enough_work_groups_ = true;
bool anchors_init_ = false;
};
MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);
@ -291,7 +292,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
auto output_detections = absl::make_unique<std::vector<Detection>>();
bool gpu_processing = false;
if (CanUseGpu()) {
if (CanUseGpu() && gpu_has_enough_work_groups_) {
// Use GPU processing only if at least one input tensor is already on GPU
// (to avoid CPU->GPU overhead).
for (const auto& tensor : *kInTensors(cc)) {
@ -321,11 +322,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
RET_CHECK(!has_custom_box_indices_);
}
if (gpu_processing) {
if (!gpu_inited_) {
MP_RETURN_IF_ERROR(GpuInit(cc));
if (gpu_processing && !gpu_inited_) {
auto status = GpuInit(cc);
if (status.ok()) {
gpu_inited_ = true;
} else if (status.code() == absl::StatusCode::kFailedPrecondition) {
// For initialization error because of hardware limitation, fallback to
// CPU processing.
LOG(WARNING) << status.message();
} else {
// For other error, let the error propagates.
return status;
}
}
if (gpu_processing && gpu_inited_) {
MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
} else {
MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
@ -346,17 +356,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// TODO: Add flexible input tensor size handling.
auto raw_box_tensor =
&input_tensors[tensor_mapping_.detections_tensor_index()];
RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
if (raw_box_tensor->shape().dims.size() == 3) {
// The tensors from CPU inference has dim 3.
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
} else if (raw_box_tensor->shape().dims.size() == 4) {
// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
// we allow tensors with 4 dims.
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_);
} else {
return absl::InvalidArgumentError(
"The dimensions of box Tensor must be 3 or 4.");
}
auto raw_score_tensor =
&input_tensors[tensor_mapping_.scores_tensor_index()];
RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
if (raw_score_tensor->shape().dims.size() == 3) {
// The tensors from CPU inference has dim 3.
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
} else if (raw_score_tensor->shape().dims.size() == 4) {
// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
// we allow tensors with 4 dims.
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_);
RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_);
} else {
return absl::InvalidArgumentError(
"The dimensions of score Tensor must be 3 or 4.");
}
auto raw_box_view = raw_box_tensor->GetCpuReadView();
auto raw_boxes = raw_box_view.buffer<float>();
auto raw_scores_view = raw_score_tensor->GetCpuReadView();
@ -1111,8 +1145,13 @@ void main() {
int max_wg_size; // typically <= 1024
glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
&max_wg_size); // y-dim
CHECK_LT(num_classes_, max_wg_size)
<< "# classes must be < " << max_wg_size;
gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
if (!gpu_has_enough_work_groups_) {
return absl::FailedPreconditionError(absl::StrFormat(
"Hardware limitation: Processing will be done on CPU, because "
"num_classes %d exceeds the max work_group size %d.",
num_classes_, max_wg_size));
}
// TODO support better filtering.
if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(),
@ -1370,7 +1409,13 @@ kernel void scoreKernel(
Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2});
// # filter classes supported is hardware dependent.
int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup;
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
if (!gpu_has_enough_work_groups_) {
return absl::FailedPreconditionError(absl::StrFormat(
"Hardware limitation: Processing will be done on CPU, because "
"num_classes %d exceeds the max work_group size %d.",
num_classes_, max_wg_size));
}
}
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)

View File

@ -406,8 +406,13 @@ cc_library(
alwayslink = 1,
)
# This dependency removed tensorflow_jellyfish_deps and xprofilez_with_server because they failed
# Boq conformance test. Weigh your use case to see if this will work for you.
# This dependency removed the following 3 targets because they failed Boq conformance test:
#
# tensorflow_jellyfish_deps
# jfprof_lib
# xprofilez_with_server
#
# If you need them plz consider tensorflow_inference_calculator_no_envelope_loader.
cc_library(
name = "tensorflow_inference_calculator_for_boq",
srcs = ["tensorflow_inference_calculator.cc"],
@ -927,7 +932,6 @@ cc_test(
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/formats:location_opencv",
"//mediapipe/framework/port:gtest_main",

View File

@ -164,8 +164,8 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
}
CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag))
RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag))
<< "Neither the output stream nor the output side packet is set to "
"output the sequence example.";
if (cc->Outputs().HasTag(kSequenceExampleTag)) {

View File

@ -23,7 +23,6 @@
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/location_opencv.h"
#include "mediapipe/framework/port/gmock.h"
@ -96,7 +95,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2);
@ -139,7 +139,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2);
@ -378,7 +379,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
Adopt(input_sequence.release());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
auto image_ptr =
@ -410,7 +412,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);
@ -618,7 +621,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
}
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(width);
@ -767,7 +771,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);
@ -813,7 +818,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);
@ -970,7 +976,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2);
@ -1021,7 +1028,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
int height = 2;

View File

@ -172,7 +172,7 @@ class AnnotationOverlayCalculator : public CalculatorBase {
REGISTER_CALCULATOR(AnnotationOverlayCalculator);
absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
bool use_gpu = false;
@ -189,13 +189,13 @@ absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
#if !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kGpuBufferTag)) {
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
CHECK(cc->Outputs().HasTag(kGpuBufferTag));
RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag));
use_gpu = true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kImageFrameTag)) {
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
CHECK(cc->Outputs().HasTag(kImageFrameTag));
RET_CHECK(cc->Outputs().HasTag(kImageFrameTag));
}
// Data streams to render.

View File

@ -322,27 +322,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
options_.presence_threshold(), options_.connection_color(), thickness,
/*normalized=*/false, render_data.get());
}
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const Landmark& landmark = landmarks.landmark(i);
if (options_.render_landmarks()) {
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const Landmark& landmark = landmarks.landmark(i);
if (!IsLandmarkVisibleAndPresent<Landmark>(
landmark, options_.utilize_visibility(),
options_.visibility_threshold(), options_.utilize_presence(),
options_.presence_threshold())) {
continue;
}
if (!IsLandmarkVisibleAndPresent<Landmark>(
landmark, options_.utilize_visibility(),
options_.visibility_threshold(), options_.utilize_presence(),
options_.presence_threshold())) {
continue;
}
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
options_.min_depth_circle_thickness(),
options_.max_depth_circle_thickness());
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
landmark_data_render,
options_.min_depth_circle_thickness(),
options_.max_depth_circle_thickness());
}
auto* landmark_data = landmark_data_render->mutable_point();
landmark_data->set_normalized(false);
landmark_data->set_x(landmark.x());
landmark_data->set_y(landmark.y());
}
auto* landmark_data = landmark_data_render->mutable_point();
landmark_data->set_normalized(false);
landmark_data->set_x(landmark.x());
landmark_data->set_y(landmark.y());
}
}
@ -368,27 +371,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
options_.presence_threshold(), options_.connection_color(), thickness,
/*normalized=*/true, render_data.get());
}
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const NormalizedLandmark& landmark = landmarks.landmark(i);
if (options_.render_landmarks()) {
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const NormalizedLandmark& landmark = landmarks.landmark(i);
if (!IsLandmarkVisibleAndPresent<NormalizedLandmark>(
landmark, options_.utilize_visibility(),
options_.visibility_threshold(), options_.utilize_presence(),
options_.presence_threshold())) {
continue;
}
if (!IsLandmarkVisibleAndPresent<NormalizedLandmark>(
landmark, options_.utilize_visibility(),
options_.visibility_threshold(), options_.utilize_presence(),
options_.presence_threshold())) {
continue;
}
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
options_.min_depth_circle_thickness(),
options_.max_depth_circle_thickness());
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
landmark_data_render,
options_.min_depth_circle_thickness(),
options_.max_depth_circle_thickness());
}
auto* landmark_data = landmark_data_render->mutable_point();
landmark_data->set_normalized(true);
landmark_data->set_x(landmark.x());
landmark_data->set_y(landmark.y());
}
auto* landmark_data = landmark_data_render->mutable_point();
landmark_data->set_normalized(true);
landmark_data->set_x(landmark.x());
landmark_data->set_y(landmark.y());
}
}

View File

@ -32,6 +32,10 @@ message LandmarksToRenderDataCalculatorOptions {
// Color of the landmarks.
optional Color landmark_color = 2;
// Whether to render landmarks as points.
optional bool render_landmarks = 14 [default = true];
// Color of the connections.
optional Color connection_color = 3;

View File

@ -124,7 +124,7 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
int center_row = out_lms.landmark(lm_index).y() * hm_height;
// Point is outside of the image let's keep it intact.
if (center_col < 0 || center_col >= hm_width || center_row < 0 ||
center_col >= hm_height) {
center_row >= hm_height) {
continue;
}

View File

@ -130,7 +130,6 @@ cc_library(
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util",
],
@ -341,7 +340,6 @@ cc_test(
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:test_util",
"@com_google_absl//absl/flags:flag",
],
)
@ -367,7 +365,6 @@ cc_test(
"//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:test_util",
"@com_google_absl//absl/flags:flag",
],
)
@ -451,7 +448,6 @@ cc_test(
"//mediapipe/framework/tool:test_util",
"//mediapipe/util/tracking:box_tracker_cc_proto",
"//mediapipe/util/tracking:tracking_cc_proto",
"@com_google_absl//absl/flags:flag",
],
)

View File

@ -1,6 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "facedetectioncpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "facedetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "faceeffect",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "facemeshgpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "handdetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "handtrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "helloworld",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "holistictrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "iristrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "objectdetectioncpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "objectdetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "objectdetectiontrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "posetrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
alias(
name = "selfiesegmentationgpu",

View File

@ -44,6 +44,9 @@ bzl_library(
"encode_binary_proto.bzl",
],
visibility = ["//visibility:public"],
deps = [
"@bazel_skylib//lib:paths",
],
)
alias(

View File

@ -64,58 +64,13 @@ class CalculatorBaseFactoryFor<
namespace api2 {
namespace internal {
// Defining a member of this type causes P to be ODR-used, which forces its
// instantiation if it's a static member of a template.
// Previously we depended on the pointer's value to determine whether the size
// of a character array is 0 or 1, forcing it to be instantiated so the
// compiler can determine the object's layout. But using it as a template
// argument is more compact.
template <auto* P>
struct ForceStaticInstantiation {
#ifdef _MSC_VER
// Just having it as the template argument does not count as a use for
// MSVC.
static constexpr bool Use() { return P != nullptr; }
char force_static[Use()];
#endif // _MSC_VER
};
MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(
NodeRegistrator, mediapipe::CalculatorBaseRegistry, T::kCalculatorName,
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>)
// Helper template for forcing the definition of a static registration token.
template <typename T>
struct NodeRegistrationStatic {
static NoDestructor<mediapipe::RegistrationToken> registration;
static mediapipe::RegistrationToken Make() {
return mediapipe::CalculatorBaseRegistry::Register(
T::kCalculatorName,
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>,
__FILE__, __LINE__);
}
using RequireStatics = ForceStaticInstantiation<&registration>;
};
// Static members of template classes can be defined in the header.
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
NodeRegistrationStatic<T>::registration(NodeRegistrationStatic<T>::Make());
template <typename T>
struct SubgraphRegistrationImpl {
static NoDestructor<mediapipe::RegistrationToken> registration;
static mediapipe::RegistrationToken Make() {
return mediapipe::SubgraphRegistry::Register(
T::kCalculatorName, absl::make_unique<T>, __FILE__, __LINE__);
}
using RequireStatics = ForceStaticInstantiation<&registration>;
};
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
SubgraphRegistrationImpl<T>::registration(
SubgraphRegistrationImpl<T>::Make());
MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(SubgraphRegistrator,
mediapipe::SubgraphRegistry,
T::kCalculatorName, absl::make_unique<T>)
} // namespace internal
@ -128,14 +83,7 @@ template <class Impl = void>
class RegisteredNode;
template <class Impl>
class RegisteredNode : public Node {
private:
// The member below triggers instantiation of the registration static.
// Note that the constructor of calculator subclasses is only invoked through
// the registration token, and so we cannot simply use the static in the
// constructor.
typename internal::NodeRegistrationStatic<Impl>::RequireStatics register_;
};
class RegisteredNode : public Node, private internal::NodeRegistrator<Impl> {};
// No-op version for backwards compatibility.
template <>
@ -217,31 +165,27 @@ class NodeImpl : public RegisteredNode<Impl>, public Intf {
// TODO: verify that the subgraph config fully implements the
// declared interface.
template <class Intf, class Impl>
class SubgraphImpl : public Subgraph, public Intf {
private:
typename internal::SubgraphRegistrationImpl<Impl>::RequireStatics register_;
};
class SubgraphImpl : public Subgraph,
public Intf,
private internal::SubgraphRegistrator<Impl> {};
// This macro is used to register a calculator that does not use automatic
// registration. Deprecated.
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(calculator_registration, \
__LINE__)(mediapipe::CalculatorBaseRegistry::Register( \
Impl::kCalculatorName, \
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>, \
__FILE__, __LINE__))
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \
mediapipe::CalculatorBaseRegistry, calculator_registration, \
Impl::kCalculatorName, \
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>)
// This macro is used to register a non-split-contract calculator. Deprecated.
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)
// This macro is used to define a subgraph that does not use automatic
// registration. Deprecated.
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(subgraph_registration, \
__LINE__)(mediapipe::SubgraphRegistry::Register( \
Impl::kCalculatorName, absl::make_unique<Impl>, __FILE__, __LINE__))
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \
mediapipe::SubgraphRegistry, subgraph_registration, \
Impl::kCalculatorName, absl::make_unique<Impl>)
} // namespace api2
} // namespace mediapipe

View File

@ -183,8 +183,7 @@ TEST(CalculatorTest, CreateByNameWhitelisted) {
CalculatorBaseRegistry::Register(
"::mediapipe::test_ns::whitelisted_ns::DeadCalculator",
absl::make_unique<internal::CalculatorBaseFactoryFor<
mediapipe::test_ns::whitelisted_ns::DeadCalculator>>,
__FILE__, __LINE__);
mediapipe::test_ns::whitelisted_ns::DeadCalculator>>);
// A whitelisted calculator can be found in its own namespace.
MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( //

View File

@ -16,7 +16,6 @@
#define MEDIAPIPE_DEPS_REGISTRATION_H_
#include <algorithm>
#include <cstdint>
#include <functional>
#include <string>
#include <tuple>
@ -145,6 +144,23 @@ template <typename T>
struct WrapStatusOr<absl::StatusOr<T>> {
using type = absl::StatusOr<T>;
};
// Defining a member of this type causes P to be ODR-used, which forces its
// instantiation if it's a static member of a template.
// Previously we depended on the pointer's value to determine whether the size
// of a character array is 0 or 1, forcing it to be instantiated so the
// compiler can determine the object's layout. But using it as a template
// argument is more compact.
template <auto* P>
struct ForceStaticInstantiation {
#ifdef _MSC_VER
// Just having it as the template argument does not count as a use for
// MSVC.
static constexpr bool Use() { return P != nullptr; }
char force_static[Use()];
#endif // _MSC_VER
};
} // namespace registration_internal
class NamespaceAllowlist {
@ -162,8 +178,7 @@ class FunctionRegistry {
FunctionRegistry(const FunctionRegistry&) = delete;
FunctionRegistry& operator=(const FunctionRegistry&) = delete;
RegistrationToken Register(absl::string_view name, Function func,
std::string filename, uint64_t line)
RegistrationToken Register(absl::string_view name, Function func)
ABSL_LOCKS_EXCLUDED(lock_) {
std::string normalized_name = GetNormalizedName(name);
absl::WriterMutexLock lock(&lock_);
@ -173,21 +188,10 @@ class FunctionRegistry {
}
if (functions_.insert(std::make_pair(normalized_name, std::move(func)))
.second) {
#ifndef NDEBUG
locations_.emplace(normalized_name,
std::make_pair(std::move(filename), line));
#endif
return RegistrationToken(
[this, normalized_name]() { Unregister(normalized_name); });
}
#ifndef NDEBUG
LOG(FATAL) << "Function with name " << name << " already registered."
<< " First registration at "
<< locations_.at(normalized_name).first << ":"
<< locations_.at(normalized_name).second;
#else
LOG(FATAL) << "Function with name " << name << " already registered.";
#endif
return RegistrationToken([]() {});
}
@ -316,11 +320,6 @@ class FunctionRegistry {
private:
mutable absl::Mutex lock_;
absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
#ifndef NDEBUG
// Stores filename and line number for useful debug log.
absl::flat_hash_map<std::string, std::pair<std::string, uint32_t>> locations_
ABSL_GUARDED_BY(lock_);
#endif
// For names included in NamespaceAllowlist, strips the namespace.
std::string GetAdjustedName(absl::string_view name) {
@ -351,10 +350,8 @@ class GlobalFactoryRegistry {
public:
static RegistrationToken Register(absl::string_view name,
typename Functions::Function func,
std::string filename, uint64_t line) {
return functions()->Register(name, std::move(func), std::move(filename),
line);
typename Functions::Function func) {
return functions()->Register(name, std::move(func));
}
// Invokes the specified factory function and returns the result.
@ -414,12 +411,77 @@ class GlobalFactoryRegistry {
#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...) \
static auto* REGISTRY_STATIC_VAR(registration_##name, __LINE__) = \
new mediapipe::RegistrationToken( \
RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
RegistryType::Register(#name, __VA_ARGS__))
#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \
name, ...) \
static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \
new mediapipe::RegistrationToken( \
RegistryType::Register(name, __VA_ARGS__))
// TODO: migrate to the above.
#define REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, name, ...) \
static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \
new mediapipe::RegistrationToken( \
RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
RegistryType::Register(#name, __VA_ARGS__))
// Defines a utility registrator class which can be used to automatically
// register factory functions.
//
// Example:
// === Defining a registry ================================================
//
// class Component {};
//
// using ComponentRegistry = GlobalFactoryRegistry<std::unique_ptr<Component>>;
//
// === Defining a registrator =============================================
//
// MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(ComponentRegistrator,
// ComponentRegistry, T::kName,
// absl::make_unique<T>);
//
// === Defining and registering a new component. ==========================
//
// class MyComponent : public Component,
// private ComponentRegistrator<MyComponent> {
// public:
// static constexpr char kName[] = "MyComponent";
// ...
// };
//
// NOTE:
// - MyComponent is automatically registered in ComponentRegistry by
// "MyComponent" name.
// - Every component is require to provide its name (T::kName here.)
#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType, \
name, ...) \
template <typename T> \
struct Internal##RegistratorName { \
static NoDestructor<mediapipe::RegistrationToken> registration; \
\
static mediapipe::RegistrationToken Make() { \
return RegistryType::Register(name, __VA_ARGS__); \
} \
\
using RequireStatics = \
registration_internal::ForceStaticInstantiation<&registration>; \
}; \
/* Static members of template classes can be defined in the header. */ \
template <typename T> \
NoDestructor<mediapipe::RegistrationToken> \
Internal##RegistratorName<T>::registration( \
Internal##RegistratorName<T>::Make()); \
\
template <typename T> \
class RegistratorName { \
private: \
/* The member below triggers instantiation of the registration static. */ \
/* Note that the constructor of calculator subclasses is only invoked */ \
/* through the registration token, and so we cannot simply use the */ \
/* static in theconstructor. */ \
typename Internal##RegistratorName<T>::RequireStatics register_; \
};
} // namespace mediapipe

View File

@ -37,29 +37,33 @@ Args:
output: The desired name of the output file. Optional.
"""
load("@bazel_skylib//lib:paths.bzl", "paths")
PROTOC = "@com_google_protobuf//:protoc"
def _canonicalize_proto_path_oss(all_protos, genfile_path):
"""For the protos from external repository, canonicalize the proto path and the file name.
def _canonicalize_proto_path_oss(f):
if not f.root.path:
return struct(
proto_path = ".",
file_name = f.short_path,
)
Returns:
Proto path list and proto source file list.
"""
proto_paths = []
proto_file_names = []
for s in all_protos.to_list():
if s.path.startswith(genfile_path):
repo_name, _, file_name = s.path[len(genfile_path + "/external/"):].partition("/")
# `f.path` looks like "<genfiles>/external/<repo>/(_virtual_imports/<library>/)?<file_name>"
repo_name, _, file_name = f.path[len(paths.join(f.root.path, "external") + "/"):].partition("/")
if file_name.startswith("_virtual_imports/"):
# This is a virtual import; move "_virtual_imports/<library>" from `repo_name` to `file_name`.
repo_name = paths.join(repo_name, *file_name.split("/", 2)[:2])
file_name = file_name.split("/", 2)[-1]
return struct(
proto_path = paths.join(f.root.path, "external", repo_name),
file_name = file_name,
)
# handle virtual imports
if file_name.startswith("_virtual_imports"):
repo_name = repo_name + "/" + "/".join(file_name.split("/", 2)[:2])
file_name = file_name.split("/", 2)[-1]
proto_paths.append(genfile_path + "/external/" + repo_name)
proto_file_names.append(file_name)
else:
proto_file_names.append(s.path)
return ([" --proto_path=" + path for path in proto_paths], proto_file_names)
def _map_root_path(f):
return _canonicalize_proto_path_oss(f).proto_path
def _map_short_path(f):
return _canonicalize_proto_path_oss(f).file_name
def _get_proto_provider(dep):
"""Get the provider for protocol buffers from a dependnecy.
@ -90,24 +94,35 @@ def _encode_binary_proto_impl(ctx):
sibling = textpb,
)
path_list, file_list = _canonicalize_proto_path_oss(all_protos, ctx.genfiles_dir.path)
args = ctx.actions.args()
args.add(textpb)
args.add(binarypb)
args.add(ctx.executable._proto_compiler)
args.add(ctx.attr.message_type, format = "--encode=%s")
args.add("--proto_path=.")
args.add_all(
all_protos,
map_each = _map_root_path,
format_each = "--proto_path=%s",
uniquify = True,
)
args.add_all(
all_protos,
map_each = _map_short_path,
uniquify = True,
)
# Note: the combination of absolute_paths and proto_path, as well as the exact
# order of gendir before ., is needed for the proto compiler to resolve
# import statements that reference proto files produced by a genrule.
ctx.actions.run_shell(
tools = all_protos.to_list() + [textpb, ctx.executable._proto_compiler],
outputs = [binarypb],
command = " ".join(
[
ctx.executable._proto_compiler.path,
"--encode=" + ctx.attr.message_type,
"--proto_path=" + ctx.genfiles_dir.path,
"--proto_path=" + ctx.bin_dir.path,
"--proto_path=.",
] + path_list + file_list +
["<", textpb.path, ">", binarypb.path],
tools = depset(
direct = [textpb, ctx.executable._proto_compiler],
transitive = [all_protos],
),
outputs = [binarypb],
command = "${@:3} < $1 > $2",
arguments = [args],
mnemonic = "EncodeProto",
)

View File

@ -19,7 +19,7 @@ package mediapipe;
// Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of
// the joint and its visibility.
message Joint {
// Joint rotation in 6D contineous representation ordered as
// Joint rotation in 6D continuous representation ordered as
// [a1, b1, a2, b2, a3, b3].
//
// Such representation is more sutable for NN model training and can be

View File

@ -15,7 +15,7 @@ def mediapipe_cc_test(
platforms = ["linux", "android", "ios", "wasm"],
exclude_platforms = None,
# ios_unit_test arguments
ios_minimum_os_version = "11.0",
ios_minimum_os_version = "12.0",
# android_cc_test arguments
open_gl_driver = None,
emulator_mini_boot = True,

View File

@ -466,8 +466,7 @@ struct MessageRegistrationImpl {
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder,
__FILE__, __LINE__));
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
// For non-Message payloads, this does nothing.
template <typename T, typename Enable = void>

View File

@ -261,8 +261,8 @@ cc_library(
)
cc_library(
name = "opencv_highgui",
hdrs = ["opencv_highgui_inc.h"],
name = "opencv_photo",
hdrs = ["opencv_photo_inc.h"],
deps = [
":opencv_core",
"//third_party:opencv",
@ -297,6 +297,15 @@ cc_library(
],
)
cc_library(
name = "opencv_highgui",
hdrs = ["opencv_highgui_inc.h"],
deps = [
":opencv_core",
"//third_party:opencv",
],
)
cc_library(
name = "opencv_videoio",
hdrs = ["opencv_videoio_inc.h"],

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
#define MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
#ifndef MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
#define MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
#include <opencv2/core/version.hpp>
@ -25,4 +25,4 @@
#include <opencv2/highgui.hpp>
#endif
#endif // MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
#endif // MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors.
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,15 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#ifndef MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
#define MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "third_party/OpenCV/photo.hpp"
namespace mediapipe {
typedef ClipVectorSizeCalculator<::mediapipe::Detection>
ClipDetectionVectorSizeCalculator;
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
} // namespace mediapipe
#endif // MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors.
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

View File

@ -48,6 +48,18 @@ class MuxInputStreamHandler : public InputStreamHandler {
: InputStreamHandler(std::move(tag_map), cc_manager, options,
calculator_run_in_parallel) {}
private:
CollectionItemId GetControlStreamId() const {
return input_stream_managers_.EndId() - 1;
}
void RemoveOutdatedDataPackets(Timestamp timestamp) {
const CollectionItemId control_stream_id = GetControlStreamId();
for (CollectionItemId id = input_stream_managers_.BeginId();
id < control_stream_id; ++id) {
input_stream_managers_.Get(id)->ErasePacketsEarlierThan(timestamp);
}
}
protected:
// In MuxInputStreamHandler, a node is "ready" if:
// - the control stream is done (need to call Close() in this case), or
@ -58,9 +70,15 @@ class MuxInputStreamHandler : public InputStreamHandler {
absl::MutexLock lock(&input_streams_mutex_);
const auto& control_stream =
input_stream_managers_.Get(input_stream_managers_.EndId() - 1);
input_stream_managers_.Get(GetControlStreamId());
bool empty;
*min_stream_timestamp = control_stream->MinTimestampOrBound(&empty);
// Data streams may contain some outdated packets which failed to be popped
// out during "FillInputSet". (This handler doesn't sync input streams,
// hence "FillInputSet" can be triggerred before every input stream is
// filled with packets corresponding to the same timestamp.)
RemoveOutdatedDataPackets(*min_stream_timestamp);
if (empty) {
if (*min_stream_timestamp == Timestamp::Done()) {
// Calculator is done if the control input stream is done.
@ -78,11 +96,6 @@ class MuxInputStreamHandler : public InputStreamHandler {
const auto& data_stream = input_stream_managers_.Get(
input_stream_managers_.BeginId() + control_value);
// Data stream may contain some outdated packets which failed to be popped
// out during "FillInputSet". (This handler doesn't sync input streams,
// hence "FillInputSet" can be triggerred before every input stream is
// filled with packets corresponding to the same timestamp.)
data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
if (empty) {
if (stream_timestamp <= *min_stream_timestamp) {
@ -111,8 +124,7 @@ class MuxInputStreamHandler : public InputStreamHandler {
CHECK(input_set);
absl::MutexLock lock(&input_streams_mutex_);
const CollectionItemId control_stream_id =
input_stream_managers_.EndId() - 1;
const CollectionItemId control_stream_id = GetControlStreamId();
auto& control_stream = input_stream_managers_.Get(control_stream_id);
int num_packets_dropped = 0;
bool stream_is_done = false;
@ -140,15 +152,8 @@ class MuxInputStreamHandler : public InputStreamHandler {
AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet),
stream_is_done);
// Discard old packets on other streams.
// Note that control_stream_id is the last valid id.
auto next_timestamp = input_timestamp.NextAllowedInStream();
for (CollectionItemId id = input_stream_managers_.BeginId();
id < control_stream_id; ++id) {
if (id == data_stream_id) continue;
auto& other_stream = input_stream_managers_.Get(id);
other_stream->ErasePacketsEarlierThan(next_timestamp);
}
// Discard old packets on data streams.
RemoveOutdatedDataPackets(input_timestamp.NextAllowedInStream());
}
private:

View File

@ -645,5 +645,41 @@ TEST(MuxInputStreamHandlerTest,
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, RemovesUnusedDataStreamPackets) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:input0"
input_stream: "INPUT:1:input1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
)pb");
config.set_max_queue_size(1);
config.set_report_deadlock(true);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add two delayed packets to the deselected input. They should be discarded
// instead of triggering the deadlock detection (max_queue_size = 1).
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
} // namespace
} // namespace mediapipe

View File

@ -64,13 +64,13 @@ GraphRegistry::GraphRegistry(
void GraphRegistry::Register(
const std::string& type_name,
std::function<std::unique_ptr<Subgraph>()> factory) {
local_factories_.Register(type_name, factory, __FILE__, __LINE__);
local_factories_.Register(type_name, factory);
}
// TODO: Remove this convenience function.
void GraphRegistry::Register(const std::string& type_name,
const CalculatorGraphConfig& config) {
Register(type_name, [config] {
local_factories_.Register(type_name, [config] {
auto result = absl::make_unique<ProtoSubgraph>(config);
return std::unique_ptr<Subgraph>(result.release());
});
@ -79,7 +79,7 @@ void GraphRegistry::Register(const std::string& type_name,
// TODO: Remove this convenience function.
void GraphRegistry::Register(const std::string& type_name,
const CalculatorGraphTemplate& templ) {
Register(type_name, [templ] {
local_factories_.Register(type_name, [templ] {
auto result = absl::make_unique<TemplateSubgraph>(templ);
return std::unique_ptr<Subgraph>(result.release());
});

View File

@ -228,7 +228,9 @@ absl::Status CompareAndSaveImageOutput(
auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
options.max_alpha_diff, options.max_avg_diff,
diff_img);
ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
if (diff_img) {
ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
}
return status;
}

View File

@ -1121,7 +1121,7 @@ objc_library(
alwayslink = 1,
)
MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"
test_suite(
name = "ios",

View File

@ -109,9 +109,8 @@ absl::Status GlContext::CreateContext(
}
MP_RETURN_IF_ERROR(status);
LOG(INFO) << "Successfully created a WebGL context with major version "
<< gl_major_version_ << " and handle " << context_;
VLOG(1) << "Successfully created a WebGL context with major version "
<< gl_major_version_ << " and handle " << context_;
return absl::OkStatus();
}

View File

@ -104,6 +104,7 @@ class GlScalerCalculator : public CalculatorBase {
bool vertical_flip_output_;
bool horizontal_flip_output_;
FrameScaleMode scale_mode_ = FrameScaleMode::kStretch;
bool use_nearest_neighbor_interpolation_ = false;
};
REGISTER_CALCULATOR(GlScalerCalculator);
@ -186,7 +187,8 @@ absl::Status GlScalerCalculator::Open(CalculatorContext* cc) {
scale_mode_ =
FrameScaleModeFromProto(options.scale_mode(), FrameScaleMode::kStretch);
}
use_nearest_neighbor_interpolation_ =
options.use_nearest_neighbor_interpolation();
if (HasTagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) {
const auto& dimensions =
TagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)
@ -297,6 +299,11 @@ absl::Status GlScalerCalculator::Process(CalculatorContext* cc) {
glBindTexture(src2.target(), src2.name());
}
if (use_nearest_neighbor_interpolation_) {
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
}
MP_RETURN_IF_ERROR(renderer->GlRender(
src1.width(), src1.height(), dst.width(), dst.height(), scale_mode_,
rotation_, horizontal_flip_output_, vertical_flip_output_,

View File

@ -19,7 +19,7 @@ package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";
// Next id: 8.
// Next id: 9.
message GlScalerCalculatorOptions {
extend CalculatorOptions {
optional GlScalerCalculatorOptions ext = 166373014;
@ -39,4 +39,7 @@ message GlScalerCalculatorOptions {
// Flip the output texture horizontally. This is applied after rotation.
optional bool flip_horizontal = 5;
optional ScaleMode.Mode scale_mode = 6;
// Whether to use nearest neighbor interpolation. Default to use linear
// interpolation.
optional bool use_nearest_neighbor_interpolation = 8 [default = false];
}

View File

@ -100,6 +100,10 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
#endif // TARGET_OS_OSX
}},
{GpuBufferFormat::kOneComponent8Alpha,
{
{GL_ALPHA, GL_ALPHA, GL_UNSIGNED_BYTE, 1},
}},
{GpuBufferFormat::kOneComponent8Red,
{
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
@ -221,6 +225,7 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
case GpuBufferFormat::kRGBA32:
// TODO: this likely maps to ImageFormat::SRGBA
case GpuBufferFormat::kGrayHalf16:
case GpuBufferFormat::kOneComponent8Alpha:
case GpuBufferFormat::kOneComponent8Red:
case GpuBufferFormat::kTwoComponent8:
case GpuBufferFormat::kTwoComponentHalf16:

View File

@ -43,6 +43,7 @@ enum class GpuBufferFormat : uint32_t {
kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'),
kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'),
kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'),
kOneComponent8Alpha = MEDIAPIPE_FOURCC('A', '0', '0', '8'),
kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'),
kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'),
kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'),
@ -101,6 +102,7 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
return kCVPixelFormatType_OneComponent32Float;
case GpuBufferFormat::kOneComponent8:
return kCVPixelFormatType_OneComponent8;
case GpuBufferFormat::kOneComponent8Alpha:
case GpuBufferFormat::kOneComponent8Red:
return -1;
case GpuBufferFormat::kTwoComponent8:

View File

@ -78,17 +78,21 @@ public class AppTextureFrame implements TextureFrame {
* Use {@link waitUntilReleasedWithGpuSync} whenever possible.
*/
public void waitUntilReleased() throws InterruptedException {
GlSyncToken tokenToRelease = null;
synchronized (this) {
while (inUse && releaseSyncToken == null) {
wait();
}
if (releaseSyncToken != null) {
releaseSyncToken.waitOnCpu();
releaseSyncToken.release();
tokenToRelease = releaseSyncToken;
inUse = false;
releaseSyncToken = null;
}
}
if (tokenToRelease != null) {
tokenToRelease.waitOnCpu();
tokenToRelease.release();
}
}
/**
@ -98,17 +102,21 @@ public class AppTextureFrame implements TextureFrame {
* TextureFrame.
*/
public void waitUntilReleasedWithGpuSync() throws InterruptedException {
GlSyncToken tokenToRelease = null;
synchronized (this) {
while (inUse && releaseSyncToken == null) {
wait();
}
if (releaseSyncToken != null) {
releaseSyncToken.waitOnGpu();
releaseSyncToken.release();
tokenToRelease = releaseSyncToken;
inUse = false;
releaseSyncToken = null;
}
}
if (tokenToRelease != null) {
tokenToRelease.waitOnGpu();
tokenToRelease.release();
}
}
/**

View File

@ -239,7 +239,7 @@ public final class PacketGetter {
/**
* Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer
* array has the the same size of image list packet, and assumes the output buffer stores pixels
* array has the same size of image list packet, and assumes the output buffer stores pixels
* contiguously. It returns false if this assumption does not hold.
*
* <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of

View File

@ -24,6 +24,7 @@ package_group(
package_group(
name = "1p_client",
packages = [
"//cloud/ml/applications/vision/model_garden/model_oss/mediapipe/...",
"//research/privacy/learning/fl_eval/pcvr/...",
],
)

View File

@ -57,3 +57,14 @@ py_test(
srcs = ["classification_dataset_test.py"],
deps = [":classification_dataset"],
)
py_library(
name = "cache_files",
srcs = ["cache_files.py"],
)
py_test(
name = "cache_files_test",
srcs = ["cache_files_test.py"],
deps = [":cache_files"],
)

View File

@ -0,0 +1,112 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common TFRecord cache files library."""
import dataclasses
import os
import tempfile
from typing import Any, Mapping, Sequence
import tensorflow as tf
import yaml
# Suffix of the meta data file name.
METADATA_FILE_SUFFIX = '_metadata.yaml'
@dataclasses.dataclass(frozen=True)
class TFRecordCacheFiles:
"""TFRecordCacheFiles dataclass to store and load cached TFRecord files.
Attributes:
cache_prefix_filename: The cache prefix filename. This is usually provided
as a hash of the original data source to avoid different data sources
resulting in the same cache file.
cache_dir: The cache directory to save TFRecord and metadata file. When
cache_dir is None, a temporary folder will be created and will not be
removed automatically after training which makes it can be used later.
num_shards: Number of shards for output tfrecord files.
"""
cache_prefix_filename: str = 'cache_prefix'
cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp)
num_shards: int = 1
def __post_init__(self):
if not tf.io.gfile.exists(self.cache_dir):
tf.io.gfile.makedirs(self.cache_dir)
if not self.cache_prefix_filename:
raise ValueError('cache_prefix_filename cannot be empty.')
if self.num_shards <= 0:
raise ValueError(
f'num_shards must be greater than 0, got {self.num_shards}'
)
@property
def cache_prefix(self) -> str:
"""The cache prefix including the cache directory and the cache prefix filename."""
return os.path.join(self.cache_dir, self.cache_prefix_filename)
@property
def tfrecord_files(self) -> Sequence[str]:
"""The TFRecord files."""
tfrecord_files = [
self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards)
for i in range(self.num_shards)
]
return tfrecord_files
@property
def metadata_file(self) -> str:
"""The metadata file."""
return self.cache_prefix + METADATA_FILE_SUFFIX
def get_writers(self) -> Sequence[tf.io.TFRecordWriter]:
"""Gets an array of TFRecordWriter objects.
Note that these writers should each be closed using .close() when done.
Returns:
Array of TFRecordWriter objects
"""
return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
def save_metadata(self, metadata):
"""Writes metadata to file.
Args:
metadata: A dictionary of metadata content to write. Exact format is
dependent on the specific dataset, but typically includes a 'size' and
'label_names' entry.
"""
with tf.io.gfile.GFile(self.metadata_file, 'w') as f:
yaml.dump(metadata, f)
def load_metadata(self) -> Mapping[Any, Any]:
"""Reads metadata from file.
Returns:
Dictionary object containing metadata
"""
if not tf.io.gfile.exists(self.metadata_file):
return {}
with tf.io.gfile.GFile(self.metadata_file, 'r') as f:
metadata = yaml.load(f, Loader=yaml.FullLoader)
return metadata
def is_cached(self) -> bool:
"""Checks whether this CacheFiles is already cached."""
all_cached_files = list(self.tfrecord_files) + [self.metadata_file]
return all(tf.io.gfile.exists(f) for f in all_cached_files)

View File

@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files
class CacheFilesTest(tf.test.TestCase):
def test_tfrecord_cache_files(self):
cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='tfrecord',
cache_dir='/tmp/cache_dir',
num_shards=2,
)
self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord')
self.assertEqual(
cf.metadata_file,
'/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX,
)
expected_tfrecord_files = [
'/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2)
for i in range(2)
]
self.assertEqual(cf.tfrecord_files, expected_tfrecord_files)
# Writing TFRecord Files
self.assertFalse(cf.is_cached())
for tfrecord_file in cf.tfrecord_files:
self.assertFalse(tf.io.gfile.exists(tfrecord_file))
writers = cf.get_writers()
for writer in writers:
writer.close()
for tfrecord_file in cf.tfrecord_files:
self.assertTrue(tf.io.gfile.exists(tfrecord_file))
self.assertFalse(cf.is_cached())
# Writing Metadata Files
original_metadata = {'size': 10, 'label_names': ['label1', 'label2']}
cf.save_metadata(original_metadata)
self.assertTrue(cf.is_cached())
metadata = cf.load_metadata()
self.assertEqual(metadata, original_metadata)
def test_recordio_cache_files_error(self):
with self.assertRaisesRegex(
ValueError, 'cache_prefix_filename cannot be empty'
):
cache_files.TFRecordCacheFiles(
cache_prefix_filename='',
cache_dir='/tmp/cache_dir',
num_shards=2,
)
with self.assertRaisesRegex(
ValueError, 'num_shards must be greater than 0, got 0'
):
cache_files.TFRecordCacheFiles(
cache_prefix_filename='tfrecord',
cache_dir='/tmp/cache_dir',
num_shards=0,
)
if __name__ == '__main__':
tf.test.main()

View File

@ -13,7 +13,7 @@
# limitations under the License.
"""Common classification dataset library."""
from typing import List, Tuple
from typing import List, Optional, Tuple
import tensorflow as tf
@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""Dataset Loader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str]):
def __init__(
self,
dataset: tf.data.Dataset,
label_names: List[str],
size: Optional[int] = None,
):
super().__init__(dataset, size)
self._label_names = label_names

View File

@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase):
value: A value variable stored by the mock dataset class for testing.
"""
def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str], value: Any):
super().__init__(dataset=dataset, size=size, label_names=label_names)
def __init__(
self,
dataset: tf.data.Dataset,
label_names: List[str],
value: Any,
size: int,
):
super().__init__(dataset=dataset, label_names=label_names, size=size)
self.value = value
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase):
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataset(
dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
dataset=ds, label_names=label_names, value=magic_value, size=len(ds)
)
# Train/Test data split.
fraction = .25

View File

@ -56,15 +56,14 @@ class Dataset(object):
def size(self) -> Optional[int]:
"""Returns the size of the dataset.
Note that this function may return None becuase the exact size of the
dataset isn't a necessary parameter to create an instance of this class,
and tf.data.Dataset donesn't support a function to get the length directly
since it's lazy-loaded and may be infinite.
In most cases, however, when an instance of this class is created by helper
functions like 'from_folder', the size of the dataset will be preprocessed,
and this function can return an int representing the size of the dataset.
Same functionality as calling __len__. See the __len__ method definition for
more information.
Raises:
TypeError if self._size is not set and the cardinality of self._dataset
is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
"""
return self._size
return self.__len__()
def gen_tf_dataset(
self,
@ -116,8 +115,22 @@ class Dataset(object):
# here.
return dataset
def __len__(self):
"""Returns the number of element of the dataset."""
def __len__(self) -> int:
"""Returns the number of element of the dataset.
If size is not set, this method will fallback to using the __len__ method
of the tf.data.Dataset in self._dataset. Calling __len__ on a
tf.data.Dataset instance may throw a TypeError because the dataset may
be lazy-loaded with an unknown size or have infinite size.
In most cases, however, when an instance of this class is created by helper
functions like 'from_folder', the size of the dataset will be preprocessed,
and the _size instance variable will be already set.
Raises:
TypeError if self._size is not set and the cardinality of self._dataset
is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
"""
if self._size is not None:
return self._size
else:
@ -152,15 +165,25 @@ class Dataset(object):
Returns:
The splitted two sub datasets.
Raises:
ValueError: if the provided fraction is not between 0 and 1.
ValueError: if this dataset does not have a set size.
"""
assert (fraction > 0 and fraction < 1)
if not (fraction > 0 and fraction < 1):
raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}')
if not self._size:
raise ValueError(
'Dataset size unknown. Cannot split the dataset when '
'the size is unknown.'
)
dataset = self._dataset
train_size = int(self._size * fraction)
trainset = self.__class__(dataset.take(train_size), train_size, *args)
trainset = self.__class__(dataset.take(train_size), *args, size=train_size)
test_size = self._size - train_size
testset = self.__class__(dataset.skip(train_size), test_size, *args)
testset = self.__class__(dataset.skip(train_size), *args, size=test_size)
return trainset, testset

View File

@ -15,7 +15,7 @@
import dataclasses
import tempfile
from typing import Optional
from typing import Mapping, Optional
import tensorflow as tf
@ -36,6 +36,8 @@ class BaseHParams:
steps_per_epoch: An optional integer indicate the number of training steps
per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size.
class_weights: An optional mapping of indices to weights for weighting the
loss function during training.
shuffle: True if the dataset is shuffled before training.
export_dir: The location of the model checkpoint files.
distribution_strategy: A string specifying which Distribution Strategy to
@ -57,6 +59,7 @@ class BaseHParams:
batch_size: int
epochs: int
steps_per_epoch: Optional[int] = None
class_weights: Optional[Mapping[int, float]] = None
# Dataset-related parameters
shuffle: bool = False

View File

@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel):
# dataset is exhausted even if there are epochs remaining.
steps_per_epoch=None,
validation_data=validation_dataset,
callbacks=self._callbacks)
callbacks=self._callbacks,
class_weight=self._hparams.class_weights,
)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset.

View File

@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
"""
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
"""Constructor.
"""Initializes FocalLoss.
Args:
gamma: Focal loss gamma, as described in class docs.
@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
return tf.reduce_sum(losses) / batch_size
class SparseFocalLoss(FocalLoss):
"""Sparse implementation of Focal Loss.
This is the same as FocalLoss, except the labels are expected to be class ids
instead of 1-hot encoded vectors. See FocalLoss class documentation defined
in this same file for more details.
Example usage:
>>> y_true = [1, 2]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> gamma = 2
>>> focal_loss = SparseFocalLoss(gamma, 3)
>>> focal_loss(y_true, y_pred).numpy()
0.9326
>>> # Calling with 'sample_weight'.
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
0.6528
"""
def __init__(
self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
):
"""Initializes SparseFocalLoss.
Args:
gamma: Focal loss gamma, as described in class docs.
num_classes: Number of classes.
class_weight: A weight to apply to the loss, one for each class. The
weight is applied for each input where the ground truth label matches.
"""
super().__init__(gamma, class_weight=class_weight)
self._num_classes = num_classes
def __call__(
self,
y_true: tf.Tensor,
y_pred: tf.Tensor,
sample_weight: Optional[tf.Tensor] = None,
) -> tf.Tensor:
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
y_true_one_hot = tf.one_hot(y_true, self._num_classes)
return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
@dataclasses.dataclass
class PerceptualLossWeight:
"""The weight for each perceptual loss.

View File

@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(loss, expected_loss, 1e-4)
class SparseFocalLossTest(tf.test.TestCase):
def test_sparse_focal_loss_matches_focal_loss(self):
num_classes = 2
y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
y_true = tf.constant([1, 0])
y_true_one_hot = tf.one_hot(y_true, num_classes)
for gamma in [0.0, 0.5, 1.0]:
expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
loss_fn = loss_functions.SparseFocalLoss(
gamma=gamma, num_classes=num_classes
)
expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
loss = loss_fn(y_true, y_pred)
self.assertNear(loss, expected_loss, 1e-4)
class MockPerceptualLoss(loss_functions.PerceptualLoss):
"""A mock class with implementation of abstract methods for testing."""

View File

@ -46,13 +46,17 @@ class BertModelSpec:
"""
downloaded_files: file_util.DownloadedFiles
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=3,
batch_size=32,
learning_rate=3e-5,
distribution_strategy='mirrored')
model_options: bert_model_options.BertModelOptions = (
bert_model_options.BertModelOptions())
hparams: hp.BaseHParams = dataclasses.field(
default_factory=lambda: hp.BaseHParams(
epochs=3,
batch_size=32,
learning_rate=3e-5,
distribution_strategy='mirrored',
)
)
model_options: bert_model_options.BertModelOptions = dataclasses.field(
default_factory=bert_model_options.BertModelOptions
)
do_lower_case: bool = True
tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
package(default_visibility = ["//mediapipe:__subpackages__"])
@ -76,7 +76,10 @@ py_test(
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
deps = [
"//mediapipe/model_maker/python/core/data:cache_files",
"//mediapipe/model_maker/python/core/data:classification_dataset",
],
)
py_test(
@ -88,7 +91,10 @@ py_test(
py_library(
name = "preprocessor",
srcs = ["preprocessor.py"],
deps = [":dataset"],
deps = [
":dataset",
"//mediapipe/model_maker/python/core/data:cache_files",
],
)
py_test(
@ -99,6 +105,7 @@ py_test(
":dataset",
":model_spec",
":preprocessor",
"//mediapipe/model_maker/python/core/data:cache_files",
],
)
@ -124,6 +131,7 @@ py_library(
":text_classifier_options",
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:metrics",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
@ -147,6 +155,7 @@ py_test(
],
deps = [
":text_classifier_import",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -15,11 +15,15 @@
import csv
import dataclasses
import hashlib
import os
import random
import tempfile
from typing import List, Optional, Sequence
from typing import Optional, Sequence
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.core.data import classification_dataset
@ -46,21 +50,49 @@ class CSVParameters:
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for text classifier."""
def __init__(
self,
dataset: tf.data.Dataset,
label_names: List[str],
tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
size: Optional[int] = None,
):
super().__init__(dataset, label_names, size)
if not tfrecord_cache_files:
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename="tfrecord", num_shards=1
)
self.tfrecord_cache_files = tfrecord_cache_files
@classmethod
def from_csv(cls,
filename: str,
csv_params: CSVParameters,
shuffle: bool = True) -> "Dataset":
def from_csv(
cls,
filename: str,
csv_params: CSVParameters,
shuffle: bool = True,
cache_dir: Optional[str] = None,
num_shards: int = 1,
) -> "Dataset":
"""Loads text with labels from a CSV file.
Args:
filename: Name of the CSV file.
csv_params: Parameters used for reading the CSV file.
shuffle: If True, randomly shuffle the data.
cache_dir: Optional parameter to specify where to store the preprocessed
dataset. Only used for BERT models.
num_shards: Optional parameter for num shards of the preprocessed dataset.
Note that using more than 1 shard will reorder the dataset. Only used
for BERT models.
Returns:
Dataset containing (text, label) pairs and other related info.
"""
if cache_dir is None:
cache_dir = tempfile.mkdtemp()
# calculate hash for cache based off of files
hasher = hashlib.md5()
hasher.update(os.path.basename(filename).encode("utf-8"))
with tf.io.gfile.GFile(filename, "r") as f:
reader = csv.DictReader(
f,
@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
quotechar=csv_params.quotechar)
lines = list(reader)
for line in lines:
hasher.update(str(line).encode("utf-8"))
if shuffle:
random.shuffle(lines)
@ -81,8 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
index_by_label[line[csv_params.label_column]] for line in lines
]
label_index_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(label_indices, tf.int64))
tf.cast(label_indices, tf.int64)
)
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
hasher.update(str(num_shards).encode("utf-8"))
cache_prefix_filename = hasher.hexdigest()
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename, cache_dir, num_shards
)
return Dataset(
dataset=text_label_ds, size=len(texts), label_names=label_names)
dataset=text_label_ds,
label_names=label_names,
tfrecord_cache_files=tfrecord_cache_files,
size=len(texts),
)

View File

@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
train_data, test_data = data.split(0.5)
expected_train_data = [b'good', b'bad']
expected_test_data = [b'neutral', b'odd']

View File

@ -15,7 +15,7 @@
import dataclasses
import enum
from typing import Union
from typing import Sequence, Union
from mediapipe.model_maker.python.core import hyperparameters as hp
@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
optimizer: Optimizer to use for training. Only supported values are "adamw"
and "lamb".
end_learning_rate: End learning rate for linear decay. Defaults to 0.
batch_size: Batch size for training. Defaults to 48.
epochs: Number of training iterations over the dataset. Defaults to 2.
optimizer: Optimizer to use for training. Supported values are defined in
BertOptimizer enum: ADAMW and LAMB.
weight_decay: Weight decay of the optimizer. Defaults to 0.01.
desired_precisions: If specified, adds a RecallAtPrecision metric per
desired_precisions[i] entry which tracks the recall given the constraint
on precision. Only supported for binary classification.
desired_recalls: If specified, adds a PrecisionAtRecall metric per
desired_recalls[i] entry which tracks the precision given the constraint
on recall. Only supported for binary classification.
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
value to 0. Defaults to 2.0.
"""
learning_rate: float = 3e-5
end_learning_rate: float = 0.0
batch_size: int = 48
epochs: int = 2
optimizer: BertOptimizer = BertOptimizer.ADAMW
weight_decay: float = 0.01
desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
gamma: float = 2.0
HParams = Union[BertHParams, AverageWordEmbeddingHParams]

View File

@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec:
"""
# `learning_rate` is unused for the average word embedding model
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0
hparams: hp.AverageWordEmbeddingHParams = dataclasses.field(
default_factory=lambda: hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0
)
)
model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field(
default_factory=mo.AverageWordEmbeddingModelOptions
)
model_options: mo.AverageWordEmbeddingModelOptions = (
mo.AverageWordEmbeddingModelOptions())
name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial(
@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
inherited from the BertModelSpec.
"""
hparams: hp.BertHParams = hp.BertHParams()
hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
mobilebert_classifier_spec = functools.partial(
@ -76,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='MobileBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
)
exbert_classifier_spec = functools.partial(
@ -90,11 +88,6 @@ exbert_classifier_spec = functools.partial(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='ExBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
)

View File

@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual(
model_spec_obj.tflite_input_name, {
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0'
})
model_spec_obj.tflite_input_name,
{
'ids': 'serving_default_input_word_ids:0',
'mask': 'serving_default_input_mask:0',
'segment_ids': 'serving_default_input_type_ids:0',
},
)
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.BertModelOptions(

View File

@ -15,14 +15,15 @@
"""Preprocessors for text classification."""
import collections
import hashlib
import os
import re
import tempfile
from typing import Mapping, Sequence, Tuple, Union
import tensorflow as tf
import tensorflow_hub
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
@ -75,19 +76,20 @@ def _decode_record(
return bert_features, example["label_ids"]
def _single_file_dataset(
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
def _tfrecord_dataset(
tfrecord_files: Sequence[str],
name_to_features: Mapping[str, tf.io.FixedLenFeature],
) -> tf.data.TFRecordDataset:
"""Creates a single-file dataset to be passed for BERT custom training.
Args:
input_file: Filepath for the dataset.
tfrecord_files: Filepaths for the dataset.
name_to_features: Maps record keys to feature types.
Returns:
Dataset containing BERT model input features and labels.
"""
d = tf.data.TFRecordDataset(input_file)
d = tf.data.TFRecordDataset(tfrecord_files)
d = d.map(
lambda record: _decode_record(record, name_to_features),
num_parallel_calls=tf.data.AUTOTUNE)
@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
seq_len: Length of the input sequence to the model.
vocab_file: File containing the BERT vocab.
tokenizer: BERT tokenizer.
model_name: Name of the model provided by the model_spec. Used to associate
cached files with specific Bert model vocab.
"""
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
def __init__(
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
):
self._seq_len = seq_len
# Vocab filepath is tied to the BERT module's URI.
self._vocab_file = os.path.join(
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
do_lower_case)
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
)
self._do_lower_case = do_lower_case
self._tokenizer = tokenization.FullTokenizer(
self._vocab_file, self._do_lower_case
)
self._model_name = model_name
def _get_name_to_features(self):
"""Gets the dictionary mapping record keys to feature types."""
@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
"""Returns the vocab file of the BertClassifierPreprocessor."""
return self._vocab_file
def _get_tfrecord_cache_files(
self, ds_cache_files
) -> cache_files_lib.TFRecordCacheFiles:
"""Helper to regenerate cache prefix filename using preprocessor info.
We need to update the dataset cache_prefix cache because the actual cached
dataset depends on the preprocessor parameters such as model_name, seq_len,
and do_lower_case in addition to the raw dataset parameters which is already
included in the ds_cache_files.cache_prefix_filename
Specifically, the new cache_prefix_filename used by the preprocessor will
be a hash generated from the following:
1. cache_prefix_filename of the initial raw dataset
2. model_name
3. seq_len
4. do_lower_case
Args:
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
Returns:
A new TFRecordCacheFiles object which incorporates the preprocessor
parameters.
"""
hasher = hashlib.md5()
hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
hasher.update(self._model_name.encode("utf-8"))
hasher.update(str(self._seq_len).encode("utf-8"))
hasher.update(str(self._do_lower_case).encode("utf-8"))
cache_prefix_filename = hasher.hexdigest()
return cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename,
ds_cache_files.cache_dir,
ds_cache_files.num_shards,
)
def preprocess(
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
self, dataset: text_classifier_ds.Dataset
) -> text_classifier_ds.Dataset:
"""Preprocesses data into input for a BERT-based classifier.
Args:
@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
Returns:
Dataset containing (bert_features, label) data.
"""
examples = []
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
_validate_text_and_label(text, label)
examples.append(
classifier_data_lib.InputExample(
guid=str(index),
text_a=text.numpy()[0].decode("utf-8"),
text_b=None,
# InputExample expects the label name rather than the int ID
label=dataset.label_names[label.numpy()[0]]))
ds_cache_files = dataset.tfrecord_cache_files
# Get new tfrecord_cache_files by including preprocessor information.
tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
if not tfrecord_cache_files.is_cached():
print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
writers = tfrecord_cache_files.get_writers()
size = 0
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
_validate_text_and_label(text, label)
example = classifier_data_lib.InputExample(
guid=str(index),
text_a=text.numpy()[0].decode("utf-8"),
text_b=None,
# InputExample expects the label name rather than the int ID
# label=dataset.label_names[label.numpy()[0]])
label=label.numpy()[0],
)
feature = classifier_data_lib.convert_single_example(
index, example, None, self._seq_len, self._tokenizer
)
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
classifier_data_lib.file_based_convert_examples_to_features(
examples=examples,
label_list=dataset.label_names,
max_seq_length=self._seq_len,
tokenizer=self._tokenizer,
output_file=tfrecord_file)
preprocessed_ds = _single_file_dataset(tfrecord_file,
self._get_name_to_features())
def create_int_feature(values):
f = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values))
)
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature([feature.label_id])
tf_example = tf.train.Example(
features=tf.train.Features(feature=features)
)
writers[index % len(writers)].write(tf_example.SerializeToString())
size = index + 1
for writer in writers:
writer.close()
metadata = {"size": size, "label_names": dataset.label_names}
tfrecord_cache_files.save_metadata(metadata)
else:
print(
f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
)
metadata = tfrecord_cache_files.load_metadata()
size = metadata["size"]
label_names = metadata["label_names"]
preprocessed_ds = _tfrecord_dataset(
tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
)
return text_classifier_ds.Dataset(
dataset=preprocessed_ds,
size=dataset.size,
label_names=dataset.label_names)
size=size,
label_names=label_names,
tfrecord_cache_files=tfrecord_cache_files,
)
TextClassifierPreprocessor = (
Union[BertClassifierPreprocessor,
AverageWordEmbeddingClassifierPreprocessor])
TextClassifierPreprocessor = Union[
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
]

View File

@ -13,14 +13,17 @@
# limitations under the License.
import csv
import io
import os
import tempfile
from unittest import mock as unittest_mock
import mock
import numpy as np
import numpy.testing as npt
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = []
@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0])
self.assertSameElements(
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
)
for feature in features.values():
self.assertEqual(feature.shape, [1, 5])
input_masks.append(features['input_mask'].numpy()[0])
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
[0, 0, 0, 0, 0])
npt.assert_array_equal(
features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
)
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
)
self.assertEqual(labels, [1, 0])
def test_bert_preprocessor_cache(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file,
csv_params=self.CSV_PARAMS_,
cache_dir=self.get_temp_dir(),
)
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
ds_cache_files = dataset.tfrecord_cache_files
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
ds_cache_files
)
self.assertFalse(preprocessed_cache_files.is_cached())
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
self.assertTrue(preprocessed_cache_files.is_cached())
self.assertEqual(
preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
)
# The second time running preprocessor, it should load from cache directly
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
_ = bert_preprocessor.preprocess(dataset)
self.assertEqual(
mock_stdout.getvalue(),
'Using existing cache files at'
f' {preprocessed_cache_files.cache_prefix}\n',
)
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=seq_len,
do_lower_case=do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
return new_cf.cache_prefix_filename
def test_bert_get_tfrecord_cache_files(self):
# Test to ensure regenerated cache_files have different prefixes
all_cf_prefixes = set()
cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
new_cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='new_cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
# Each item of all_cf_prefixes should be unique, so 7 total.
self.assertLen(all_cf_prefixes, 7)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -16,8 +16,8 @@
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
@ -27,8 +27,8 @@
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {

View File

@ -24,6 +24,7 @@ import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
):
text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data,
options,
train_data.label_names))
text_classifier = _BertClassifier.create_bert_classifier(
train_data, validation_data, options
)
elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
text_classifier = (
_AverageWordEmbeddingClassifier
.create_average_word_embedding_classifier(train_data, validation_data,
options,
train_data.label_names))
text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
train_data, validation_data, options
)
else:
raise ValueError(f"Unknown model {options.supported_model}")
@ -166,28 +164,8 @@ class TextClassifier(classifier.Classifier):
processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
additional_metrics = []
if desired_precisions and len(data.label_names) == 2:
for precision in desired_precisions:
additional_metrics.append(
metrics.BinarySparseRecallAtPrecision(
precision, name=f"recall_at_precision_{precision}"
)
)
if desired_recalls and len(data.label_names) == 2:
for recall in desired_recalls:
additional_metrics.append(
metrics.BinarySparsePrecisionAtRecall(
recall, name=f"precision_at_recall_{recall}"
)
)
metric_functions = self._metric_functions + additional_metrics
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=metric_functions,
)
return self._model.evaluate(dataset)
with self._hparams.get_strategy().scope():
return self._model.evaluate(dataset)
def export_model(
self,
@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
@classmethod
def create_average_word_embedding_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
) -> "_AverageWordEmbeddingClassifier":
"""Creates, trains, and returns an Average Word Embedding classifier.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns:
An Average Word Embedding classifier.
@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options
with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
self._loss_function = loss_functions.SparseFocalLoss(
self._hparams.gamma, self._num_classes
)
self._metric_functions = self._create_metrics()
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod
def create_bert_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_BertClassifier":
) -> "_BertClassifier":
"""Creates, trains, and returns a BERT-based classifier.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns:
A BERT-based classifier.
@ -435,9 +411,59 @@ class _BertClassifier(TextClassifier):
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.downloaded_files.get_path(),
model_name=self._model_spec.name,
)
return (self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data))
return (
self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data),
)
def _create_metrics(self):
"""Creates metrics for training and evaluation.
The default metrics are accuracy, precision, and recall.
For binary classification tasks only (num_classes=2):
Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
the desired_presisions and desired_recalls fields in BertHParams.
Returns:
A list of tf.keras.Metric subclasses which can be used with model.compile
"""
metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
if self._num_classes == 2:
if self._hparams.desired_precisions:
for desired_precision in self._hparams.desired_precisions:
metric_functions.append(
metrics.BinarySparseRecallAtPrecision(
desired_precision,
name=f"recall_at_precision_{desired_precision}",
num_thresholds=1000,
)
)
if self._hparams.desired_recalls:
for desired_recall in self._hparams.desired_recalls:
metric_functions.append(
metrics.BinarySparseRecallAtPrecision(
desired_recall,
name=f"precision_at_recall_{desired_recall}",
num_thresholds=1000,
)
)
else:
if self._hparams.desired_precisions or self._hparams.desired_recalls:
raise ValueError(
"desired_recalls and desired_precisions parameters are binary"
" metrics and not supported for num_classes > 2. Found"
f" num_classes: {self._num_classes}"
)
return metric_functions
def _create_model(self):
"""Creates a BERT-based classifier model.
@ -447,11 +473,20 @@ class _BertClassifier(TextClassifier):
"""
encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_word_ids",
),
input_mask=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_mask",
),
input_type_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_type_ids",
),
)
encoder = hub.KerasLayer(
self._model_spec.downloaded_files.get_path(),
@ -493,16 +528,21 @@ class _BertClassifier(TextClassifier):
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr,
decay_steps=total_steps,
end_learning_rate=0.0,
power=1.0)
end_learning_rate=self._hparams.end_learning_rate,
power=1.0,
)
if warmup_steps:
lr_schedule = model_util.WarmUp(
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
warmup_steps=warmup_steps,
)
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
lr_schedule,
weight_decay=self._hparams.weight_decay,
epsilon=1e-6,
global_clipnorm=1.0,
)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"]
@ -510,7 +550,7 @@ class _BertClassifier(TextClassifier):
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB(
lr_schedule,
weight_decay_rate=0.01,
weight_decay_rate=self._hparams.weight_decay,
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0,

View File

@ -84,8 +84,8 @@ def run(data_dir,
options)
# Gets evaluation results.
_, acc = model.evaluate(validation_data)
print('Eval accuracy: %f' % acc)
metrics = model.evaluate(validation_data)
print('Eval accuracy: %f' % metrics[1])
model.export_model(quantization_config=quantization_config)
model.export_labels(export_dir=options.hparams.export_dir)

View File

@ -16,17 +16,17 @@ import csv
import filecmp
import os
import tempfile
import unittest
from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils
@unittest.skip('b/275624089')
class TextClassifierTest(tf.test.TestCase):
class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifier.create(train_data, validation_data,
options))
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
metrics = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
# Test export_model
average_word_embedding_classifier.export_model()
@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
filecmp.cmp(
output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
shallow=False))
shallow=False,
)
)
def test_create_and_train_bert(self):
@parameterized.named_parameters(
# Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
# dict(
# testcase_name='mobilebert',
# supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
# ),
dict(
testcase_name='exbert',
supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
),
)
def test_create_and_train_bert(self, supported_model):
train_data, validation_data = self._get_data()
options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
supported_model=supported_model,
model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2
),
@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)
_, accuracy = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
metrics = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
# Test export_model
bert_classifier.export_model()
@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
)
def test_label_mismatch(self):
options = (
text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
options = text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
)
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
with self.assertRaisesRegex(
ValueError,
'Training data label names .* not equal to validation data label names'
'Training data label names .* not equal to validation data label names',
):
text_classifier.TextClassifier.create(train_data, validation_data,
options)
text_classifier.TextClassifier.create(
train_data, validation_data, options
)
def test_options_mismatch(self):
train_data, validation_data = self._get_data()
avg_options = (
text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
model_options=text_classifier.AverageWordEmbeddingModelOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
avg_options)
avg_options = text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
model_options=text_classifier.AverageWordEmbeddingModelOptions(),
)
with self.assertRaisesWithLiteralMatch(
ValueError,
'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.EXBERT_CLASSIFIER',
):
text_classifier.TextClassifier.create(
train_data, validation_data, avg_options
)
bert_options = (
text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=text_classifier.BertModelOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
bert_options)
bert_options = text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
),
model_options=text_classifier.BertModelOptions(),
)
with self.assertRaisesWithLiteralMatch(
ValueError,
'Expected a Bert Classifier(MobileBERT or EXBERT), got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
):
text_classifier.TextClassifier.create(
train_data, validation_data, bert_options
)
def test_bert_loss_and_metrics_creation(self):
train_data, validation_data = self._get_data()
supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
hparams = text_classifier.BertHParams(
desired_recalls=[0.2],
desired_precisions=[0.9],
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off',
gamma=3.5,
)
options = text_classifier.TextClassifierOptions(
supported_model=supported_model, hparams=hparams
)
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options
)
loss_fn = bert_classifier._loss_function
self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
self.assertEqual(loss_fn._gamma, 3.5)
self.assertEqual(loss_fn._num_classes, 2)
metric_names = [m.name for m in bert_classifier._metric_functions]
expected_metric_names = [
'accuracy',
'recall',
'precision',
'precision_at_recall_0.2',
'recall_at_precision_0.9',
]
self.assertCountEqual(metric_names, expected_metric_names)
# Non-binary data
tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
with self.assertRaisesWithLiteralMatch(
ValueError,
'desired_recalls and desired_precisions parameters are binary metrics'
' and not supported for num_classes > 2. Found num_classes: 3',
):
text_classifier.TextClassifier.create(data, data, options)
if __name__ == '__main__':

View File

@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset):
', '.join(label_names),
)
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names
dataset=image_label_ds,
label_names=label_names,
size=all_image_size,
)

View File

@ -13,7 +13,7 @@
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict binary and library compatibility macro.
licenses(["notice"])

View File

@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset):
len(valid_hand_data), len(label_names), ','.join(label_names)))
return Dataset(
dataset=hand_embedding_label_ds,
label_names=label_names,
size=len(valid_hand_data),
label_names=label_names)
)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python library rule.
licenses(["notice"])

View File

@ -15,28 +15,12 @@
import os
import random
from typing import List, Optional
import tensorflow as tf
import tensorflow_datasets as tfds
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
def _create_data(
name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
label_names: List[str]
) -> Optional[classification_dataset.ClassificationDataset]:
"""Creates a Dataset object from tfds data."""
if name not in data:
return None
data = data[name]
data = data.map(lambda a: (a['image'], a['label']))
size = info.splits[name].num_examples
return Dataset(data, size, label_names)
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for image classifier."""
@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names))
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names)
dataset=image_label_ds, label_names=label_names, size=all_image_size
)

View File

@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase):
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2)

View File

@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
ds = tf.data.Dataset.from_generator(
self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
[self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
['cyan', 'magenta', 'yellow'])
data = image_classifier.Dataset(
ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3
)
return data
def setUp(self):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
@ -54,6 +54,7 @@ py_library(
srcs = ["dataset.py"],
deps = [
":dataset_util",
"//mediapipe/model_maker/python/core/data:cache_files",
"//mediapipe/model_maker/python/core/data:classification_dataset",
],
)
@ -73,6 +74,7 @@ py_test(
py_library(
name = "dataset_util",
srcs = ["dataset_util.py"],
deps = ["//mediapipe/model_maker/python/core/data:cache_files"],
)
py_test(

View File

@ -16,8 +16,8 @@
from typing import Optional
import tensorflow as tf
import yaml
from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.object_detector import dataset_util
from official.vision.dataloaders import tf_example_decoder
@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset):
ValueError: If the label_name for id 0 is set to something other than
the 'background' class.
"""
cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
if not dataset_util.is_cached(cache_files):
tfrecord_cache_files = dataset_util.get_cache_files_coco(
data_dir, cache_dir
)
if not tfrecord_cache_files.is_cached():
label_map = dataset_util.get_label_map_coco(data_dir)
cache_writer = dataset_util.COCOCacheFilesWriter(
label_map=label_map, max_num_images=max_num_images
)
cache_writer.write_files(cache_files, data_dir)
return cls.from_cache(cache_files.cache_prefix)
cache_writer.write_files(tfrecord_cache_files, data_dir)
return cls.from_cache(tfrecord_cache_files)
@classmethod
def from_pascal_voc_folder(
@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset):
Raises:
ValueError: if the input data directory is empty.
"""
cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
if not dataset_util.is_cached(cache_files):
tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc(
data_dir, cache_dir
)
if not tfrecord_cache_files.is_cached():
label_map = dataset_util.get_label_map_pascal_voc(data_dir)
cache_writer = dataset_util.PascalVocCacheFilesWriter(
label_map=label_map, max_num_images=max_num_images
)
cache_writer.write_files(cache_files, data_dir)
cache_writer.write_files(tfrecord_cache_files, data_dir)
return cls.from_cache(cache_files.cache_prefix)
return cls.from_cache(tfrecord_cache_files)
@classmethod
def from_cache(cls, cache_prefix: str) -> 'Dataset':
def from_cache(
cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles
) -> 'Dataset':
"""Loads the TFRecord data from cache.
Args:
cache_prefix: The cache prefix including the cache directory and the cache
prefix filename, e.g: '/tmp/cache/train'.
tfrecord_cache_files: The TFRecordCacheFiles object containing the already
cached TFRecord and metadata files.
Returns:
ObjectDetectorDataset object.
Raises:
ValueError if tfrecord_cache_files are not already cached.
"""
# Get TFRecord Files
tfrecord_file_pattern = cache_prefix + '*.tfrecord'
matched_files = tf.io.gfile.glob(tfrecord_file_pattern)
if not matched_files:
raise ValueError('TFRecord files are empty.')
if not tfrecord_cache_files.is_cached():
raise ValueError(
'Cache files must be already cached to use the from_cache method.'
)
# Load meta_data.
meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
if not tf.io.gfile.exists(meta_data_file):
raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
with tf.io.gfile.GFile(meta_data_file, 'r') as f:
meta_data = yaml.load(f, Loader=yaml.FullLoader)
metadata = tfrecord_cache_files.load_metadata()
dataset = tf.data.TFRecordDataset(matched_files)
dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files)
decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)
label_map = meta_data['label_map']
label_map = metadata['label_map']
label_names = [label_map[k] for k in sorted(label_map.keys())]
return Dataset(
dataset=dataset, size=meta_data['size'], label_names=label_names
dataset=dataset, label_names=label_names, size=metadata['size']
)

View File

@ -15,25 +15,20 @@
import abc
import collections
import dataclasses
import hashlib
import json
import math
import os
import tempfile
from typing import Any, Dict, List, Mapping, Optional, Sequence
from typing import Any, Dict, List, Mapping, Optional
import xml.etree.ElementTree as ET
import tensorflow as tf
import yaml
from mediapipe.model_maker.python.core.data import cache_files
from official.vision.data import tfrecord_lib
# Suffix of the meta data file name.
META_DATA_FILE_SUFFIX = '_meta_data.yaml'
def _xml_get(node: ET.Element, name: str) -> ET.Element:
"""Gets a named child from an XML Element node.
@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str:
return os.path.basename(os.path.abspath(data_dir))
@dataclasses.dataclass(frozen=True)
class CacheFiles:
"""Cache files for object detection."""
cache_prefix: str
tfrecord_files: Sequence[str]
meta_data_file: str
def _get_cache_files(
cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
) -> CacheFiles:
) -> cache_files.TFRecordCacheFiles:
"""Creates an object of CacheFiles class.
Args:
@ -96,28 +82,16 @@ def _get_cache_files(
An object of CacheFiles class.
"""
cache_dir = _get_cache_dir_or_create(cache_dir)
# The cache prefix including the cache directory and the cache prefix
# filename, e.g: '/tmp/cache/train'.
cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
tf.compat.v1.logging.info(
'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
% (cache_dir, cache_prefix_filename, cache_prefix)
)
# Cached files including the TFRecord files and the meta data file.
tfrecord_files = [
cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
for i in range(num_shards)
]
meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
return CacheFiles(
cache_prefix=cache_prefix,
tfrecord_files=tuple(tfrecord_files),
meta_data_file=meta_data_file,
return cache_files.TFRecordCacheFiles(
cache_prefix_filename=cache_prefix_filename,
cache_dir=cache_dir,
num_shards=num_shards,
)
def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
def get_cache_files_coco(
data_dir: str, cache_dir: str
) -> cache_files.TFRecordCacheFiles:
"""Creates an object of CacheFiles class using a COCO formatted dataset.
Args:
@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
def get_cache_files_pascal_voc(
data_dir: str, cache_dir: str
) -> cache_files.TFRecordCacheFiles:
"""Gets an object of CacheFiles using a PASCAL VOC formatted dataset.
Args:
@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
def is_cached(cache_files: CacheFiles) -> bool:
"""Checks whether cache files are already cached."""
all_cached_files = list(cache_files.tfrecord_files) + [
cache_files.meta_data_file
]
return all(tf.io.gfile.exists(path) for path in all_cached_files)
class CacheFilesWriter(abc.ABC):
"""CacheFilesWriter class to write the cached files."""
@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC):
self.label_map = label_map
self.max_num_images = max_num_images
def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
"""Writes TFRecord and meta_data files.
def write_files(
self,
tfrecord_cache_files: cache_files.TFRecordCacheFiles,
*args,
**kwargs,
) -> None:
"""Writes TFRecord and metadata files.
Args:
cache_files: CacheFiles object including a list of TFRecord files and the
meta data yaml file to save the meta_data including data size and
label_map.
tfrecord_cache_files: TFRecordCacheFiles object including a list of
TFRecord files and the meta data yaml file to save the metadata
including data size and label_map.
*args: Non-keyword of parameters used in the `_get_example` method.
**kwargs: Keyword parameters used in the `_get_example` method.
"""
writers = [
tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
]
writers = tfrecord_cache_files.get_writers()
# Writes tf.Example into TFRecord files.
size = 0
@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC):
for writer in writers:
writer.close()
# Writes meta_data into meta_data_file.
meta_data = {'size': size, 'label_map': self.label_map}
with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
yaml.dump(meta_data, f)
# Writes metadata into metadata_file.
metadata = {'size': size, 'label_map': self.label_map}
tfrecord_cache_files.save_metadata(metadata)
@abc.abstractmethod
def _get_example(self, *args, **kwargs):

View File

@ -19,7 +19,6 @@ import shutil
from unittest import mock as unittest_mock
import tensorflow as tf
import yaml
from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import dataset_util
@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase):
def _assert_cache_files_equal(self, cf1, cf2):
self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)
self.assertEqual(cf1.num_shards, cf2.num_shards)
def _assert_cache_files_not_equal(self, cf1, cf2):
self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)
def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertEqual(
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
)
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
def test_matching_get_cache_files_coco(self):
cache_dir = self.create_tempdir()
@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertEqual(
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
)
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
def test_matching_get_cache_files_pascal_voc(self):
cache_dir = self.create_tempdir()
@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase):
cache_files = dataset_util.get_cache_files_coco(
tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
)
self.assertFalse(dataset_util.is_cached(cache_files))
self.assertFalse(cache_files.is_cached())
with open(cache_files.tfrecord_files[0], 'w') as f:
f.write('test')
self.assertFalse(dataset_util.is_cached(cache_files))
with open(cache_files.meta_data_file, 'w') as f:
self.assertFalse(cache_files.is_cached())
with open(cache_files.metadata_file, 'w') as f:
f.write('test')
self.assertTrue(dataset_util.is_cached(cache_files))
self.assertTrue(cache_files.is_cached())
def test_get_label_map_coco(self):
coco_dir = tasks_test_utils.get_test_data_path('coco_data')
@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)
# Checks the meta_data file
self.assertTrue(os.path.isfile(cache_files.meta_data_file))
self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
# Size is 3 because some examples are skipped for having poor bboxes
self.assertEqual(meta_data_dict['size'], expected_size)
# Checks the metadata file
self.assertTrue(os.path.isfile(cache_files.metadata_file))
self.assertGreater(os.path.getsize(cache_files.metadata_file), 0)
metadata_dict = cache_files.load_metadata()
self.assertEqual(metadata_dict['size'], expected_size)
def test_coco_cache_files_writer(self):
tempdir = self.create_tempdir()

Some files were not shown because too many files have changed in this diff Show More