Merge branch 'master' into nguyencse/facemeshioslib

nguyencse 2023-08-02 18:00:22 +07:00
commit 1001ead358
190 changed files with 4258 additions and 1068 deletions


@@ -157,22 +157,22 @@ http_archive(
 # 2020-08-21
 http_archive(
     name = "com_github_glog_glog",
-    strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
-    sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
+    strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
+    sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
     urls = [
-        "https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
+        "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
     ],
 )
 http_archive(
     name = "com_github_glog_glog_no_gflags",
-    strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
-    sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
+    strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
+    sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
     build_file = "@//third_party:glog_no_gflags.BUILD",
     urls = [
-        "https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
+        "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
     ],
     patches = [
-        "@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff",
+        "@//third_party:com_github_glog_glog.diff",
     ],
     patch_args = [
         "-p1",


@@ -68,30 +68,108 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
-# Note: this cannot just match "apple_platform_type": "macos" because that option
-# defaults to "macos" even when building on Linux!
-alias(
+# Generic MacOS.
+config_setting(
     name = "macos",
-    actual = select({
-        ":macos_i386": ":macos_i386",
-        ":macos_x86_64": ":macos_x86_64",
-        ":macos_arm64": ":macos_arm64",
-        "//conditions:default": ":macos_i386",  # Arbitrarily chosen from above.
-    }),
+    constraint_values = [
+        "@platforms//os:macos",
+    ],
     visibility = ["//visibility:public"],
 )
 
-# Note: this also matches on crosstool_top so that it does not produce ambiguous
-# selectors when used together with "android".
+# MacOS x86 64-bit.
+config_setting(
+    name = "macos_x86_64",
+    constraint_values = [
+        "@platforms//os:macos",
+        "@platforms//cpu:x86_64",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# MacOS ARM64.
+config_setting(
+    name = "macos_arm64",
+    constraint_values = [
+        "@platforms//os:macos",
+        "@platforms//cpu:arm64",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# Generic iOS.
 config_setting(
     name = "ios",
-    values = {
-        "crosstool_top": "@bazel_tools//tools/cpp:toolchain",
-        "apple_platform_type": "ios",
-    },
+    constraint_values = [
+        "@platforms//os:ios",
+    ],
     visibility = ["//visibility:public"],
 )
 
+# iOS device ARM32.
+config_setting(
+    name = "ios_armv7",
+    constraint_values = [
+        "@platforms//os:ios",
+        "@platforms//cpu:arm",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# iOS device ARM64.
+config_setting(
+    name = "ios_arm64",
+    constraint_values = [
+        "@platforms//os:ios",
+        "@platforms//cpu:arm64",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# iOS device ARM64E.
+config_setting(
+    name = "ios_arm64e",
+    constraint_values = [
+        "@platforms//os:ios",
+        "@platforms//cpu:arm64e",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# iOS simulator x86 32-bit.
+config_setting(
+    name = "ios_i386",
+    constraint_values = [
+        "@platforms//os:ios",
+        "@platforms//cpu:x86_32",
+        "@build_bazel_apple_support//constraints:simulator",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# iOS simulator x86 64-bit.
+config_setting(
+    name = "ios_x86_64",
+    constraint_values = [
+        "@platforms//os:ios",
+        "@platforms//cpu:x86_64",
+        "@build_bazel_apple_support//constraints:simulator",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# iOS simulator ARM64.
+config_setting(
+    name = "ios_sim_arm64",
+    constraint_values = [
+        "@platforms//os:ios",
+        "@platforms//cpu:arm64",
+        "@build_bazel_apple_support//constraints:simulator",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# Generic Apple.
 alias(
     name = "apple",
     actual = select({
@@ -102,49 +180,6 @@ alias(
     visibility = ["//visibility:public"],
 )
 
-config_setting(
-    name = "macos_i386",
-    values = {
-        "apple_platform_type": "macos",
-        "cpu": "darwin",
-    },
-    visibility = ["//visibility:public"],
-)
-
-config_setting(
-    name = "macos_x86_64",
-    values = {
-        "apple_platform_type": "macos",
-        "cpu": "darwin_x86_64",
-    },
-    visibility = ["//visibility:public"],
-)
-
-config_setting(
-    name = "macos_arm64",
-    values = {
-        "apple_platform_type": "macos",
-        "cpu": "darwin_arm64",
-    },
-    visibility = ["//visibility:public"],
-)
-
-[
-    config_setting(
-        name = arch,
-        values = {"cpu": arch},
-        visibility = ["//visibility:public"],
-    )
-    for arch in [
-        "ios_i386",
-        "ios_x86_64",
-        "ios_armv7",
-        "ios_arm64",
-        "ios_arm64e",
-        "ios_sim_arm64",
-    ]
-]
-
 config_setting(
     name = "windows",
     values = {"cpu": "x64_windows"},


@@ -210,6 +210,23 @@ REGISTER_CALCULATOR(SpectrogramCalculator);
 // Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
 const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;
 
+namespace {
+std::unique_ptr<audio_dsp::WindowFunction> MakeWindowFun(
+    const SpectrogramCalculatorOptions::WindowType window_type) {
+  switch (window_type) {
+    // The cosine window and square root of Hann are equivalent.
+    case SpectrogramCalculatorOptions::COSINE:
+    case SpectrogramCalculatorOptions::SQRT_HANN:
+      return std::make_unique<audio_dsp::CosineWindow>();
+    case SpectrogramCalculatorOptions::HANN:
+      return std::make_unique<audio_dsp::HannWindow>();
+    case SpectrogramCalculatorOptions::HAMMING:
+      return std::make_unique<audio_dsp::HammingWindow>();
+  }
+  return nullptr;
+}
+}  // namespace
+
 absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
   SpectrogramCalculatorOptions spectrogram_options =
       cc->Options<SpectrogramCalculatorOptions>();
@@ -266,28 +283,14 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
   output_scale_ = spectrogram_options.output_scale();
 
+  auto window_fun = MakeWindowFun(spectrogram_options.window_type());
+  if (window_fun == nullptr) {
+    return absl::Status(absl::StatusCode::kInvalidArgument,
+                        absl::StrCat("Invalid window type ",
+                                     spectrogram_options.window_type()));
+  }
   std::vector<double> window;
-  switch (spectrogram_options.window_type()) {
-    case SpectrogramCalculatorOptions::COSINE:
-      audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_,
-                                                   &window);
-      break;
-    case SpectrogramCalculatorOptions::HANN:
-      audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
-                                                 &window);
-      break;
-    case SpectrogramCalculatorOptions::HAMMING:
-      audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
-                                                    &window);
-      break;
-    case SpectrogramCalculatorOptions::SQRT_HANN: {
-      audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
-                                                 &window);
-      absl::c_transform(window, window.begin(),
-                        [](double x) { return std::sqrt(x); });
-      break;
-    }
-  }
+  window_fun->GetPeriodicSamples(frame_duration_samples_, &window);
 
   // Propagate settings down to the actual Spectrogram object.
   spectrogram_generators_.clear();
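Why COSINE and SQRT_HANN can share one implementation (a quick check, assuming audio_dsp's periodic window definitions): the periodic Hann window is

    w_hann[n] = (1/2) * (1 - cos(2*pi*n/N)) = sin^2(pi*n/N),

so taking the square root elementwise gives sqrt(w_hann[n]) = sin(pi*n/N), which is exactly the cosine (sine-lobe) window. The deleted SQRT_HANN branch computed the same samples the long way, by generating Hann and applying an elementwise std::sqrt.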


@@ -68,7 +68,7 @@ message SpectrogramCalculatorOptions {
     HANN = 0;
     HAMMING = 1;
     COSINE = 2;
-    SQRT_HANN = 4;
+    SQRT_HANN = 4;  // Alias of COSINE.
   }
 
   optional WindowType window_type = 6 [default = HANN];


@@ -381,17 +381,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "clip_detection_vector_size_calculator",
-    srcs = ["clip_detection_vector_size_calculator.cc"],
-    deps = [
-        ":clip_vector_size_calculator",
-        "//mediapipe/framework:calculator_framework",
-        "//mediapipe/framework/formats:detection_cc_proto",
-    ],
-    alwayslink = 1,
-)
-
 cc_test(
     name = "clip_vector_size_calculator_test",
     srcs = ["clip_vector_size_calculator_test.cc"],


@@ -76,4 +76,9 @@ REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator);
 // A calculator to process std::vector<mediapipe::Image>.
 typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator;
 REGISTER_CALCULATOR(BeginLoopImageCalculator);
 
+// A calculator to process std::vector<float>.
+typedef BeginLoopCalculator<std::vector<float>> BeginLoopFloatCalculator;
+REGISTER_CALCULATOR(BeginLoopFloatCalculator);
+
 }  // namespace mediapipe


@@ -123,8 +123,11 @@ class PreviousLoopbackCalculator : public Node {
           // However, LOOP packet is empty.
           kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
         } else {
+          // Avoids sending leftovers to a stream that's already closed.
+          if (!kPrevLoop(cc).IsClosed()) {
             kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
+          }
         }
         loop_packets_.pop_front();
         main_packet_specs_.pop_front();
       }


@@ -135,7 +135,6 @@ cc_library(
     deps = [
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:image_frame_opencv",
-        "//mediapipe/framework/port:opencv_imgcodecs",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/framework/port:status",
     ],


@@ -112,7 +112,7 @@ class BilateralFilterCalculator : public CalculatorBase {
 REGISTER_CALCULATOR(BilateralFilterCalculator);
 
 absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) {
-  CHECK_GE(cc->Inputs().NumEntries(), 1);
+  RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
 
   if (cc->Inputs().HasTag(kInputFrameTag) &&
       cc->Inputs().HasTag(kInputFrameTagGpu)) {
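This CHECK-to-RET_CHECK swap (repeated in several calculators below) is more than style: CHECK_GE aborts the whole process on failure, while MediaPipe's RET_CHECK_GE returns an absl::Status from GetContract(), so a misconfigured graph fails validation with an error the caller can handle instead of crashing.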


@@ -110,7 +110,7 @@ REGISTER_CALCULATOR(SegmentationSmoothingCalculator);
 
 absl::Status SegmentationSmoothingCalculator::GetContract(
     CalculatorContract* cc) {
-  CHECK_GE(cc->Inputs().NumEntries(), 1);
+  RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
 
   cc->Inputs().Tag(kCurrentMaskTag).Set<Image>();
   cc->Inputs().Tag(kPreviousMaskTag).Set<Image>();


@@ -142,7 +142,7 @@ class SetAlphaCalculator : public CalculatorBase {
 REGISTER_CALCULATOR(SetAlphaCalculator);
 
 absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) {
-  CHECK_GE(cc->Inputs().NumEntries(), 1);
+  RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
 
   bool use_gpu = false;


@@ -282,13 +282,17 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
   if (options.has_volume_gain_db()) {
     gain_ = pow(10, options.volume_gain_db() / 20.0);
   }
+  if (options.has_source_sample_rate()) {
+    source_sample_rate_ = options.source_sample_rate();
+  } else {
     RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
               !kAudioIn(cc).Header().IsEmpty())
        << "Must either specify the time series header of the \"AUDIO\" stream "
           "or have the \"SAMPLE_RATE\" stream connected.";
    if (!kAudioIn(cc).Header().IsEmpty()) {
      mediapipe::TimeSeriesHeader input_header;
-     MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
-         kAudioIn(cc).Header(), &input_header));
+     MP_RETURN_IF_ERROR(
+         mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
+             kAudioIn(cc).Header(), &input_header));
      if (stream_mode_) {
        MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
@@ -296,6 +300,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
        source_sample_rate_ = input_header.sample_rate();
      }
    }
+  }
   AppendZerosToSampleBuffer(padding_samples_before_);
   if (options.has_fft_size()) {
     RET_CHECK(IsValidFftSize(options.fft_size()))


@@ -85,4 +85,7 @@ message AudioToTensorCalculatorOptions {
   // The volume gain, measured in dB.
   // Scale the input audio amplitude by 10^(volume_gain_db/20).
   optional double volume_gain_db = 12;
+
+  // The source number of samples per second (hertz) of the input audio buffers.
+  optional double source_sample_rate = 13;
 }
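For reference, the conversion used by volume_gain_db (it matches the pow(10, volume_gain_db / 20.0) call in Open() above):

    gain = 10^(volume_gain_db / 20)

so +20 dB scales amplitude by a factor of 10, and -6 dB by roughly 0.5 (10^(-6/20) is about 0.501).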


@@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl
         gpu_delegate_options);
     absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
     absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
+    bool UseSerializedModel() const { return use_serialized_model_; }
 
    private:
     bool use_kernel_caching_ = false;
@@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
 }
 
 absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
-  MP_RETURN_IF_ERROR(
-      on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
   return gpu_helper_.RunInGlContext([this]() -> absl::Status {
     tflite_gpu_runner_.reset();
     return absl::OkStatus();
@@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
         tflite_gpu_runner_->GetOutputShapes()[i].c};
   }
 
+  if (on_disk_cache_helper_.UseSerializedModel()) {
+    tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
+  }
+
   MP_RETURN_IF_ERROR(
       on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
-  return tflite_gpu_runner_->Build();
+  MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
+  return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
 }
 
 #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
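Net effect of the last two hunks: saving the GPU caches moves from Close() to InitTFLiteGPURunner(), immediately after a successful Build(), so a freshly built kernel/serialized-model cache is persisted as soon as it exists rather than at teardown; and when the serialized-model path is enabled, the runner is explicitly told to initialize OpenCL from that serialized model before the on-disk caches are read.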


@@ -256,6 +256,7 @@ class TensorsToDetectionsCalculator : public Node {
   bool gpu_inited_ = false;
   bool gpu_input_ = false;
+  bool gpu_has_enough_work_groups_ = true;
   bool anchors_init_ = false;
 };
 MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);
@@ -291,7 +292,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
 absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
   auto output_detections = absl::make_unique<std::vector<Detection>>();
   bool gpu_processing = false;
-  if (CanUseGpu()) {
+  if (CanUseGpu() && gpu_has_enough_work_groups_) {
     // Use GPU processing only if at least one input tensor is already on GPU
     // (to avoid CPU->GPU overhead).
     for (const auto& tensor : *kInTensors(cc)) {
@@ -321,11 +322,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
       RET_CHECK(!has_custom_box_indices_);
   }
 
-  if (gpu_processing) {
-    if (!gpu_inited_) {
-      MP_RETURN_IF_ERROR(GpuInit(cc));
+  if (gpu_processing && !gpu_inited_) {
+    auto status = GpuInit(cc);
+    if (status.ok()) {
       gpu_inited_ = true;
+    } else if (status.code() == absl::StatusCode::kFailedPrecondition) {
+      // For an initialization error caused by a hardware limitation, fall
+      // back to CPU processing.
+      LOG(WARNING) << status.message();
+    } else {
+      // For any other error, let the error propagate.
+      return status;
     }
+  }
+
+  if (gpu_processing && gpu_inited_) {
     MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
   } else {
     MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
@@ -346,17 +356,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
   // TODO: Add flexible input tensor size handling.
   auto raw_box_tensor =
       &input_tensors[tensor_mapping_.detections_tensor_index()];
-  RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
   RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+  if (raw_box_tensor->shape().dims.size() == 3) {
+    // The tensors from CPU inference have dim 3.
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+  } else if (raw_box_tensor->shape().dims.size() == 4) {
+    // The tensors from GPU inference have dim 4. For GPU-CPU fallback
+    // support, we allow tensors with 4 dims.
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_);
+  } else {
+    return absl::InvalidArgumentError(
+        "The dimensions of box Tensor must be 3 or 4.");
+  }
   auto raw_score_tensor =
       &input_tensors[tensor_mapping_.scores_tensor_index()];
-  RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+  if (raw_score_tensor->shape().dims.size() == 3) {
+    // The tensors from CPU inference have dim 3.
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+  } else if (raw_score_tensor->shape().dims.size() == 4) {
+    // The tensors from GPU inference have dim 4. For GPU-CPU fallback
+    // support, we allow tensors with 4 dims.
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_);
+  } else {
+    return absl::InvalidArgumentError(
+        "The dimensions of score Tensor must be 3 or 4.");
+  }
   auto raw_box_view = raw_box_tensor->GetCpuReadView();
   auto raw_boxes = raw_box_view.buffer<float>();
   auto raw_scores_view = raw_score_tensor->GetCpuReadView();
@@ -1111,8 +1145,13 @@ void main() {
     int max_wg_size;  // typically <= 1024
     glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
                     &max_wg_size);  // y-dim
-    CHECK_LT(num_classes_, max_wg_size)
-        << "# classes must be < " << max_wg_size;
+    gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
+    if (!gpu_has_enough_work_groups_) {
+      return absl::FailedPreconditionError(absl::StrFormat(
+          "Hardware limitation: Processing will be done on CPU, because "
+          "num_classes %d exceeds the max work_group size %d.",
+          num_classes_, max_wg_size));
+    }
 
     // TODO support better filtering.
     if (class_index_set_.is_allowlist) {
       CHECK_EQ(class_index_set_.values.size(),
@@ -1370,7 +1409,13 @@ kernel void scoreKernel(
         Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2});
     // # filter classes supported is hardware dependent.
     int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup;
-    CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
+    gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
+    if (!gpu_has_enough_work_groups_) {
+      return absl::FailedPreconditionError(absl::StrFormat(
+          "Hardware limitation: Processing will be done on CPU, because "
+          "num_classes %d exceeds the max work_group size %d.",
+          num_classes_, max_wg_size));
+    }
   }
 #endif  // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
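The CPU-fallback contract above, in isolation: GpuInit() now reports the work-group limit as kFailedPrecondition instead of CHECK-crashing, and Process() downgrades exactly that status code to a CPU fallback while re-raising everything else. A minimal standalone sketch of the same pattern (hypothetical names; it uses absl::Status directly, not the calculator API):

#include <iostream>

#include "absl/status/status.h"

// Stand-in for GpuInit(): fails with kFailedPrecondition when the hardware
// cannot run the shader (e.g. too many classes for the work-group size).
absl::Status InitGpuPath(int num_classes, int max_wg_size) {
  if (num_classes >= max_wg_size) {
    return absl::FailedPreconditionError("num_classes exceeds work-group size");
  }
  return absl::OkStatus();
}

// Stand-in for Process(): tolerates the hardware limitation, propagates
// every other initialization error to the caller.
absl::Status RunOnce(int num_classes, int max_wg_size) {
  bool use_gpu = true;
  absl::Status status = InitGpuPath(num_classes, max_wg_size);
  if (!status.ok()) {
    if (status.code() != absl::StatusCode::kFailedPrecondition) return status;
    use_gpu = false;  // Hardware limitation: fall back to the CPU path.
  }
  std::cout << (use_gpu ? "GPU path\n" : "CPU path\n");
  return absl::OkStatus();
}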


@@ -406,8 +406,13 @@ cc_library(
     alwayslink = 1,
 )
 
-# This dependency removed tensorflow_jellyfish_deps and xprofilez_with_server because they failed
-# Boq conformance test. Weigh your use case to see if this will work for you.
+# This dependency removed the following 3 targets because they failed the Boq conformance test:
+#
+#   tensorflow_jellyfish_deps
+#   jfprof_lib
+#   xprofilez_with_server
+#
+# If you need them, please consider tensorflow_inference_calculator_no_envelope_loader.
 cc_library(
     name = "tensorflow_inference_calculator_for_boq",
     srcs = ["tensorflow_inference_calculator.cc"],
@@ -927,7 +932,6 @@ cc_test(
         "//mediapipe/framework:timestamp",
         "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:image_frame",
-        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/formats:location",
        "//mediapipe/framework/formats:location_opencv",
        "//mediapipe/framework/port:gtest_main",


@@ -164,7 +164,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
     }
   }
 
-  CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
-        cc->OutputSidePackets().HasTag(kSequenceExampleTag))
+  RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
+            cc->OutputSidePackets().HasTag(kSequenceExampleTag))
       << "Neither the output stream nor the output side packet is set to "
          "output the sequence example.";


@@ -23,7 +23,6 @@
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image_frame.h"
-#include "mediapipe/framework/formats/image_frame_opencv.h"
 #include "mediapipe/framework/formats/location.h"
 #include "mediapipe/framework/formats/location_opencv.h"
 #include "mediapipe/framework/port/gmock.h"
@@ -96,7 +95,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
   mpms::SetClipMediaId(test_video_id, input_sequence.get());
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   encoded_image.set_width(2);
@@ -139,7 +139,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
   mpms::SetClipMediaId(test_video_id, input_sequence.get());
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   encoded_image.set_width(2);
@@ -378,7 +379,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
       Adopt(input_sequence.release());
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   auto image_ptr =
@@ -410,7 +412,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   std::string test_flow_string(bytes.begin(), bytes.end());
   OpenCvImageEncoderCalculatorResults encoded_flow;
   encoded_flow.set_encoded_image(test_flow_string);
@@ -618,7 +621,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
   }
   cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   encoded_image.set_width(width);
@@ -767,7 +771,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   std::string test_flow_string(bytes.begin(), bytes.end());
   OpenCvImageEncoderCalculatorResults encoded_flow;
   encoded_flow.set_encoded_image(test_flow_string);
@@ -813,7 +818,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
   mpms::SetClipMediaId(test_video_id, input_sequence.get());
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   std::string test_flow_string(bytes.begin(), bytes.end());
   OpenCvImageEncoderCalculatorResults encoded_flow;
   encoded_flow.set_encoded_image(test_flow_string);
@@ -970,7 +976,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
   auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   encoded_image.set_width(2);
@@ -1021,7 +1028,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
   auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   int height = 2;
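About the repeated imencode change: cv::imencode takes its params argument as flag/value pairs, so the old bare {80} (a quality value with no flag) appears to have relied on lenient handling in older OpenCV releases, while {cv::IMWRITE_HDR_COMPRESSION, 1} is a well-formed pair. Since these tests only need deterministic encoded bytes, the specific flag matters less than the params list being valid.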


@@ -172,7 +172,7 @@ class AnnotationOverlayCalculator : public CalculatorBase {
 REGISTER_CALCULATOR(AnnotationOverlayCalculator);
 
 absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
-  CHECK_GE(cc->Inputs().NumEntries(), 1);
+  RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
 
   bool use_gpu = false;
@@ -189,13 +189,13 @@ absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
 #if !MEDIAPIPE_DISABLE_GPU
   if (cc->Inputs().HasTag(kGpuBufferTag)) {
     cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
-    CHECK(cc->Outputs().HasTag(kGpuBufferTag));
+    RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag));
     use_gpu = true;
   }
 #endif  // !MEDIAPIPE_DISABLE_GPU
 
   if (cc->Inputs().HasTag(kImageFrameTag)) {
     cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
-    CHECK(cc->Outputs().HasTag(kImageFrameTag));
+    RET_CHECK(cc->Outputs().HasTag(kImageFrameTag));
   }
 
   // Data streams to render.


@@ -322,6 +322,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
           options_.presence_threshold(), options_.connection_color(), thickness,
           /*normalized=*/false, render_data.get());
     }
+    if (options_.render_landmarks()) {
       for (int i = 0; i < landmarks.landmark_size(); ++i) {
         const Landmark& landmark = landmarks.landmark(i);
@@ -335,7 +336,8 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
         auto* landmark_data_render = AddPointRenderData(
             options_.landmark_color(), thickness, render_data.get());
         if (visualize_depth) {
-          SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
+          SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
+                                 landmark_data_render,
                                  options_.min_depth_circle_thickness(),
                                  options_.max_depth_circle_thickness());
         }
@@ -345,6 +347,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
         landmark_data->set_y(landmark.y());
       }
     }
+    }
 
   if (cc->Inputs().HasTag(kNormLandmarksTag)) {
     const NormalizedLandmarkList& landmarks =
@@ -368,6 +371,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
           options_.presence_threshold(), options_.connection_color(), thickness,
           /*normalized=*/true, render_data.get());
     }
+    if (options_.render_landmarks()) {
       for (int i = 0; i < landmarks.landmark_size(); ++i) {
         const NormalizedLandmark& landmark = landmarks.landmark(i);
@@ -381,7 +385,8 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
         auto* landmark_data_render = AddPointRenderData(
             options_.landmark_color(), thickness, render_data.get());
         if (visualize_depth) {
-          SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
+          SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
+                                 landmark_data_render,
                                  options_.min_depth_circle_thickness(),
                                  options_.max_depth_circle_thickness());
         }
@@ -391,6 +396,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
         landmark_data->set_y(landmark.y());
       }
     }
+    }
 
   cc->Outputs()
       .Tag(kRenderDataTag)


@@ -32,6 +32,10 @@ message LandmarksToRenderDataCalculatorOptions {
   // Color of the landmarks.
   optional Color landmark_color = 2;
 
+  // Whether to render landmarks as points.
+  optional bool render_landmarks = 14 [default = true];
+
   // Color of the connections.
   optional Color connection_color = 3;


@@ -124,7 +124,7 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
     int center_row = out_lms.landmark(lm_index).y() * hm_height;
     // Point is outside of the image let's keep it intact.
     if (center_col < 0 || center_col >= hm_width || center_row < 0 ||
-        center_col >= hm_height) {
+        center_row >= hm_height) {
       continue;
     }
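A one-character fix with real effect: the old condition tested center_col against hm_height, so on non-square heat maps a landmark whose row fell past the bottom edge could slip through this bounds check and be refined with out-of-range indices, while valid points could be wrongly skipped.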


@@ -130,7 +130,6 @@ cc_library(
         "//mediapipe/framework/formats:video_stream_header",
         "//mediapipe/framework/port:opencv_imgproc",
         "//mediapipe/framework/port:opencv_video",
-        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:status_util",
     ],
@@ -341,7 +340,6 @@ cc_test(
        "//mediapipe/framework/port:opencv_core",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:test_util",
-        "@com_google_absl//absl/flags:flag",
     ],
 )
@@ -367,7 +365,6 @@ cc_test(
        "//mediapipe/framework/port:opencv_video",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:test_util",
-        "@com_google_absl//absl/flags:flag",
     ],
 )
@@ -451,7 +448,6 @@ cc_test(
        "//mediapipe/framework/tool:test_util",
        "//mediapipe/util/tracking:box_tracker_cc_proto",
        "//mediapipe/util/tracking:tracking_cc_proto",
-        "@com_google_absl//absl/flags:flag",
     ],
 )


@@ -1,6 +1,6 @@
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip
 networkTimeout=10000
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists


@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "facedetectioncpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "facedetectiongpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "faceeffect",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "facemeshgpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "handdetectiongpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "handtrackinggpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "helloworld",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "holistictrackinggpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "iristrackinggpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "objectdetectioncpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "objectdetectiongpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "objectdetectiontrackinggpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "posetrackinggpu",

@@ -24,7 +24,7 @@ load(
 licenses(["notice"])
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 alias(
     name = "selfiesegmentationgpu",


@@ -44,6 +44,9 @@ bzl_library(
         "encode_binary_proto.bzl",
     ],
     visibility = ["//visibility:public"],
+    deps = [
+        "@bazel_skylib//lib:paths",
+    ],
 )
 
 alias(


@@ -64,58 +64,13 @@ class CalculatorBaseFactoryFor<
 namespace api2 {
 namespace internal {
 
-// Defining a member of this type causes P to be ODR-used, which forces its
-// instantiation if it's a static member of a template.
-// Previously we depended on the pointer's value to determine whether the size
-// of a character array is 0 or 1, forcing it to be instantiated so the
-// compiler can determine the object's layout. But using it as a template
-// argument is more compact.
-template <auto* P>
-struct ForceStaticInstantiation {
-#ifdef _MSC_VER
-  // Just having it as the template argument does not count as a use for
-  // MSVC.
-  static constexpr bool Use() { return P != nullptr; }
-  char force_static[Use()];
-#endif  // _MSC_VER
-};
+MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(
+    NodeRegistrator, mediapipe::CalculatorBaseRegistry, T::kCalculatorName,
+    absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>)
 
-// Helper template for forcing the definition of a static registration token.
-template <typename T>
-struct NodeRegistrationStatic {
-  static NoDestructor<mediapipe::RegistrationToken> registration;
-
-  static mediapipe::RegistrationToken Make() {
-    return mediapipe::CalculatorBaseRegistry::Register(
-        T::kCalculatorName,
-        absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>,
-        __FILE__, __LINE__);
-  }
-
-  using RequireStatics = ForceStaticInstantiation<&registration>;
-};
-
-// Static members of template classes can be defined in the header.
-template <typename T>
-NoDestructor<mediapipe::RegistrationToken>
-    NodeRegistrationStatic<T>::registration(NodeRegistrationStatic<T>::Make());
-
-template <typename T>
-struct SubgraphRegistrationImpl {
-  static NoDestructor<mediapipe::RegistrationToken> registration;
-
-  static mediapipe::RegistrationToken Make() {
-    return mediapipe::SubgraphRegistry::Register(
-        T::kCalculatorName, absl::make_unique<T>, __FILE__, __LINE__);
-  }
-
-  using RequireStatics = ForceStaticInstantiation<&registration>;
-};
-
-template <typename T>
-NoDestructor<mediapipe::RegistrationToken>
-    SubgraphRegistrationImpl<T>::registration(
-        SubgraphRegistrationImpl<T>::Make());
+MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(SubgraphRegistrator,
+                                      mediapipe::SubgraphRegistry,
+                                      T::kCalculatorName, absl::make_unique<T>)
 
 }  // namespace internal
@@ -128,14 +83,7 @@ template <class Impl = void>
 class RegisteredNode;
 
 template <class Impl>
-class RegisteredNode : public Node {
- private:
-  // The member below triggers instantiation of the registration static.
-  // Note that the constructor of calculator subclasses is only invoked through
-  // the registration token, and so we cannot simply use the static in the
-  // constructor.
-  typename internal::NodeRegistrationStatic<Impl>::RequireStatics register_;
-};
+class RegisteredNode : public Node, private internal::NodeRegistrator<Impl> {};
 
 // No-op version for backwards compatibility.
 template <>
@@ -217,20 +165,17 @@ class NodeImpl : public RegisteredNode<Impl>, public Intf {
 // TODO: verify that the subgraph config fully implements the
 // declared interface.
 template <class Intf, class Impl>
-class SubgraphImpl : public Subgraph, public Intf {
- private:
-  typename internal::SubgraphRegistrationImpl<Impl>::RequireStatics register_;
-};
+class SubgraphImpl : public Subgraph,
+                     public Intf,
+                     private internal::SubgraphRegistrator<Impl> {};
 
 // This macro is used to register a calculator that does not use automatic
 // registration. Deprecated.
 #define MEDIAPIPE_NODE_IMPLEMENTATION(Impl)                                    \
-  static mediapipe::NoDestructor<mediapipe::RegistrationToken>                 \
-      REGISTRY_STATIC_VAR(calculator_registration,                             \
-                          __LINE__)(mediapipe::CalculatorBaseRegistry::Register( \
-          Impl::kCalculatorName,                                               \
-          absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>, \
-          __FILE__, __LINE__))
+  MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(                               \
+      mediapipe::CalculatorBaseRegistry, calculator_registration,              \
+      Impl::kCalculatorName,                                                   \
+      absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>)
 
 // This macro is used to register a non-split-contract calculator. Deprecated.
 #define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)
@@ -238,10 +183,9 @@ class SubgraphImpl : public Subgraph, public Intf {
 // This macro is used to define a subgraph that does not use automatic
 // registration. Deprecated.
 #define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl)                                \
-  static mediapipe::NoDestructor<mediapipe::RegistrationToken>                 \
-      REGISTRY_STATIC_VAR(subgraph_registration,                               \
-                          __LINE__)(mediapipe::SubgraphRegistry::Register(     \
-          Impl::kCalculatorName, absl::make_unique<Impl>, __FILE__, __LINE__))
+  MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(                               \
+      mediapipe::SubgraphRegistry, subgraph_registration,                      \
+      Impl::kCalculatorName, absl::make_unique<Impl>)
 
 }  // namespace api2
 }  // namespace mediapipe


@@ -183,8 +183,7 @@ TEST(CalculatorTest, CreateByNameWhitelisted) {
   CalculatorBaseRegistry::Register(
       "::mediapipe::test_ns::whitelisted_ns::DeadCalculator",
       absl::make_unique<internal::CalculatorBaseFactoryFor<
-          mediapipe::test_ns::whitelisted_ns::DeadCalculator>>,
-      __FILE__, __LINE__);
+          mediapipe::test_ns::whitelisted_ns::DeadCalculator>>);
 
   // A whitelisted calculator can be found in its own namespace.
   MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace(  //


@@ -16,7 +16,6 @@
 #define MEDIAPIPE_DEPS_REGISTRATION_H_
 
 #include <algorithm>
-#include <cstdint>
 #include <functional>
 #include <string>
 #include <tuple>
@@ -145,6 +144,23 @@ template <typename T>
 struct WrapStatusOr<absl::StatusOr<T>> {
   using type = absl::StatusOr<T>;
 };
+
+// Defining a member of this type causes P to be ODR-used, which forces its
+// instantiation if it's a static member of a template.
+// Previously we depended on the pointer's value to determine whether the size
+// of a character array is 0 or 1, forcing it to be instantiated so the
+// compiler can determine the object's layout. But using it as a template
+// argument is more compact.
+template <auto* P>
+struct ForceStaticInstantiation {
+#ifdef _MSC_VER
+  // Just having it as the template argument does not count as a use for
+  // MSVC.
+  static constexpr bool Use() { return P != nullptr; }
+  char force_static[Use()];
+#endif  // _MSC_VER
+};
+
 }  // namespace registration_internal
 
 class NamespaceAllowlist {
@@ -162,8 +178,7 @@ class FunctionRegistry {
   FunctionRegistry(const FunctionRegistry&) = delete;
   FunctionRegistry& operator=(const FunctionRegistry&) = delete;
 
-  RegistrationToken Register(absl::string_view name, Function func,
-                             std::string filename, uint64_t line)
+  RegistrationToken Register(absl::string_view name, Function func)
       ABSL_LOCKS_EXCLUDED(lock_) {
     std::string normalized_name = GetNormalizedName(name);
     absl::WriterMutexLock lock(&lock_);
@@ -173,21 +188,10 @@ class FunctionRegistry {
     }
     if (functions_.insert(std::make_pair(normalized_name, std::move(func)))
             .second) {
-#ifndef NDEBUG
-      locations_.emplace(normalized_name,
-                         std::make_pair(std::move(filename), line));
-#endif
       return RegistrationToken(
           [this, normalized_name]() { Unregister(normalized_name); });
     }
-#ifndef NDEBUG
-    LOG(FATAL) << "Function with name " << name << " already registered."
-               << " First registration at "
-               << locations_.at(normalized_name).first << ":"
-               << locations_.at(normalized_name).second;
-#else
     LOG(FATAL) << "Function with name " << name << " already registered.";
-#endif
     return RegistrationToken([]() {});
   }
@@ -316,11 +320,6 @@ class FunctionRegistry {
  private:
   mutable absl::Mutex lock_;
   absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
-#ifndef NDEBUG
-  // Stores filename and line number for useful debug log.
-  absl::flat_hash_map<std::string, std::pair<std::string, uint32_t>> locations_
-      ABSL_GUARDED_BY(lock_);
-#endif
 
   // For names included in NamespaceAllowlist, strips the namespace.
   std::string GetAdjustedName(absl::string_view name) {
@@ -351,10 +350,8 @@ class GlobalFactoryRegistry {
  public:
   static RegistrationToken Register(absl::string_view name,
-                                    typename Functions::Function func,
-                                    std::string filename, uint64_t line) {
-    return functions()->Register(name, std::move(func), std::move(filename),
-                                 line);
+                                    typename Functions::Function func) {
+    return functions()->Register(name, std::move(func));
   }
 
   // Invokes the specified factory function and returns the result.
@@ -414,12 +411,77 @@ class GlobalFactoryRegistry {
 #define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...) \
   static auto* REGISTRY_STATIC_VAR(registration_##name, __LINE__) =  \
       new mediapipe::RegistrationToken(                              \
-          RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
+          RegistryType::Register(#name, __VA_ARGS__))
+
+#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \
+                                                      name, ...)              \
+  static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) =                      \
+      new mediapipe::RegistrationToken(                                       \
+          RegistryType::Register(name, __VA_ARGS__))
 
+// TODO: migrate to the above.
 #define REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, name, ...) \
   static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) =                       \
       new mediapipe::RegistrationToken(                                        \
-          RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
+          RegistryType::Register(#name, __VA_ARGS__))
+
+// Defines a utility registrator class which can be used to automatically
+// register factory functions.
+//
+// Example:
+// === Defining a registry ================================================
+//
+// class Component {};
+//
+// using ComponentRegistry = GlobalFactoryRegistry<std::unique_ptr<Component>>;
+//
+// === Defining a registrator =============================================
+//
+// MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(ComponentRegistrator,
+//                                       ComponentRegistry, T::kName,
+//                                       absl::make_unique<T>);
+//
+// === Defining and registering a new component. ==========================
+//
+// class MyComponent : public Component,
+//                     private ComponentRegistrator<MyComponent> {
+//  public:
+//   static constexpr char kName[] = "MyComponent";
+//   ...
+// };
+//
+// NOTE:
+// - MyComponent is automatically registered in ComponentRegistry by
+//   "MyComponent" name.
+// - Every component is required to provide its name (T::kName here.)
+#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType,  \
+                                              name, ...)                      \
+  template <typename T>                                                       \
+  struct Internal##RegistratorName {                                          \
+    static NoDestructor<mediapipe::RegistrationToken> registration;           \
+                                                                              \
+    static mediapipe::RegistrationToken Make() {                              \
+      return RegistryType::Register(name, __VA_ARGS__);                       \
+    }                                                                         \
+                                                                              \
+    using RequireStatics =                                                    \
+        registration_internal::ForceStaticInstantiation<&registration>;       \
+  };                                                                          \
+  /* Static members of template classes can be defined in the header. */      \
+  template <typename T>                                                       \
+  NoDestructor<mediapipe::RegistrationToken>                                  \
+      Internal##RegistratorName<T>::registration(                             \
+          Internal##RegistratorName<T>::Make());                              \
+                                                                              \
+  template <typename T>                                                       \
+  class RegistratorName {                                                     \
+   private:                                                                   \
+    /* The member below triggers instantiation of the registration static. */ \
+    /* Note that the constructor of calculator subclasses is only invoked */  \
+    /* through the registration token, and so we cannot simply use the */     \
+    /* static in the constructor. */                                          \
+    typename Internal##RegistratorName<T>::RequireStatics register_;          \
+  };
 
 }  // namespace mediapipe
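The ForceStaticInstantiation trick that MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE relies on is subtle enough to deserve a standalone illustration. Below is a deliberately simplified sketch (invented names; the "registry" is just a printout, and NoDestructor/RegistrationToken are dropped): because &registration appears as a non-type template argument in a member's type, instantiating the class ODR-uses the static member, which forces its definition and its initializer to be emitted.

#include <iostream>

// Mirrors registration_internal::ForceStaticInstantiation: naming &P as a
// template argument ODR-uses P.
template <auto* P>
struct ForceStaticInstantiation {};

template <typename T>
struct AutoRegistrator {
  // Stand-in for RegistryType::Register(name, ...): runs once per T.
  static inline int registration = [] {
    std::cout << "registered: " << T::kName << "\n";
    return 0;
  }();
  // Declaring a member of this type is what drags `registration` in; without
  // it, the static member of the template would never be instantiated.
  ForceStaticInstantiation<&registration> force_;
};

struct MyComponent : private AutoRegistrator<MyComponent> {
  static constexpr char kName[] = "MyComponent";
};

int main() {
  // "registered: MyComponent" is printed during static initialization on
  // typical implementations, before this line runs, even though nothing
  // here touches the registry explicitly.
  MyComponent component;
  (void)component;
  return 0;
}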


@ -37,29 +37,33 @@ Args:
output: The desired name of the output file. Optional. output: The desired name of the output file. Optional.
""" """
load("@bazel_skylib//lib:paths.bzl", "paths")
PROTOC = "@com_google_protobuf//:protoc" PROTOC = "@com_google_protobuf//:protoc"
def _canonicalize_proto_path_oss(all_protos, genfile_path): def _canonicalize_proto_path_oss(f):
"""For the protos from external repository, canonicalize the proto path and the file name. if not f.root.path:
return struct(
proto_path = ".",
file_name = f.short_path,
)
Returns: # `f.path` looks like "<genfiles>/external/<repo>/(_virtual_imports/<library>/)?<file_name>"
Proto path list and proto source file list. repo_name, _, file_name = f.path[len(paths.join(f.root.path, "external") + "/"):].partition("/")
""" if file_name.startswith("_virtual_imports/"):
proto_paths = [] # This is a virtual import; move "_virtual_imports/<library>" from `repo_name` to `file_name`.
proto_file_names = [] repo_name = paths.join(repo_name, *file_name.split("/", 2)[:2])
for s in all_protos.to_list():
if s.path.startswith(genfile_path):
repo_name, _, file_name = s.path[len(genfile_path + "/external/"):].partition("/")
# handle virtual imports
if file_name.startswith("_virtual_imports"):
repo_name = repo_name + "/" + "/".join(file_name.split("/", 2)[:2])
file_name = file_name.split("/", 2)[-1] file_name = file_name.split("/", 2)[-1]
proto_paths.append(genfile_path + "/external/" + repo_name) return struct(
proto_file_names.append(file_name) proto_path = paths.join(f.root.path, "external", repo_name),
else: file_name = file_name,
proto_file_names.append(s.path) )
return ([" --proto_path=" + path for path in proto_paths], proto_file_names)
def _map_root_path(f):
return _canonicalize_proto_path_oss(f).proto_path
def _map_short_path(f):
return _canonicalize_proto_path_oss(f).file_name
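For reference, a minimal Python sketch (not Starlark) of the canonicalization above, using a hypothetical generated path; the real rule operates on Bazel File objects rather than plain strings:

import os

def canonicalize(root_path, path):
    """Splits a generated external proto path into (proto_path, file_name)."""
    if not root_path:
        # A source file: import it relative to the workspace root.
        return ".", path
    prefix = os.path.join(root_path, "external") + "/"
    repo_name, _, file_name = path[len(prefix):].partition("/")
    if file_name.startswith("_virtual_imports/"):
        # Move "_virtual_imports/<library>" from the file name into the root.
        repo_name = os.path.join(repo_name, *file_name.split("/", 2)[:2])
        file_name = file_name.split("/", 2)[-1]
    return os.path.join(root_path, "external", repo_name), file_name

# Hypothetical virtual-import path from an external repository:
print(canonicalize(
    "bazel-out/bin",
    "bazel-out/bin/external/some_repo/_virtual_imports/lib/a/b.proto"))
# -> ('bazel-out/bin/external/some_repo/_virtual_imports/lib', 'a/b.proto')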
def _get_proto_provider(dep): def _get_proto_provider(dep):
"""Get the provider for protocol buffers from a dependnecy. """Get the provider for protocol buffers from a dependnecy.
@ -90,24 +94,35 @@ def _encode_binary_proto_impl(ctx):
sibling = textpb, sibling = textpb,
) )
path_list, file_list = _canonicalize_proto_path_oss(all_protos, ctx.genfiles_dir.path) args = ctx.actions.args()
args.add(textpb)
args.add(binarypb)
args.add(ctx.executable._proto_compiler)
args.add(ctx.attr.message_type, format = "--encode=%s")
args.add("--proto_path=.")
args.add_all(
all_protos,
map_each = _map_root_path,
format_each = "--proto_path=%s",
uniquify = True,
)
args.add_all(
all_protos,
map_each = _map_short_path,
uniquify = True,
)
# Note: the combination of absolute_paths and proto_path, as well as the exact # Note: the combination of absolute_paths and proto_path, as well as the exact
# order of gendir before ., is needed for the proto compiler to resolve # order of gendir before ., is needed for the proto compiler to resolve
# import statements that reference proto files produced by a genrule. # import statements that reference proto files produced by a genrule.
ctx.actions.run_shell( ctx.actions.run_shell(
tools = all_protos.to_list() + [textpb, ctx.executable._proto_compiler], tools = depset(
outputs = [binarypb], direct = [textpb, ctx.executable._proto_compiler],
command = " ".join( transitive = [all_protos],
[
ctx.executable._proto_compiler.path,
"--encode=" + ctx.attr.message_type,
"--proto_path=" + ctx.genfiles_dir.path,
"--proto_path=" + ctx.bin_dir.path,
"--proto_path=.",
] + path_list + file_list +
["<", textpb.path, ">", binarypb.path],
), ),
outputs = [binarypb],
command = "${@:3} < $1 > $2",
arguments = [args],
mnemonic = "EncodeProto", mnemonic = "EncodeProto",
) )

View File

@ -19,7 +19,7 @@ package mediapipe;
// Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of // Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of
// the joint and its visibility. // the joint and its visibility.
message Joint { message Joint {
// Joint rotation in 6D contineous representation ordered as // Joint rotation in 6D continuous representation ordered as
// [a1, b1, a2, b2, a3, b3]. // [a1, b1, a2, b2, a3, b3].
// //
// Such representation is more suitable for NN model training and can be // Such representation is more suitable for NN model training and can be

View File

@ -15,7 +15,7 @@ def mediapipe_cc_test(
platforms = ["linux", "android", "ios", "wasm"], platforms = ["linux", "android", "ios", "wasm"],
exclude_platforms = None, exclude_platforms = None,
# ios_unit_test arguments # ios_unit_test arguments
ios_minimum_os_version = "11.0", ios_minimum_os_version = "12.0",
# android_cc_test arguments # android_cc_test arguments
open_gl_driver = None, open_gl_driver = None,
emulator_mini_boot = True, emulator_mini_boot = True,

View File

@ -466,8 +466,7 @@ struct MessageRegistrationImpl {
template <typename T> template <typename T>
NoDestructor<mediapipe::RegistrationToken> NoDestructor<mediapipe::RegistrationToken>
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register( MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder, T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
__FILE__, __LINE__));
// For non-Message payloads, this does nothing. // For non-Message payloads, this does nothing.
template <typename T, typename Enable = void> template <typename T, typename Enable = void>

View File

@ -261,8 +261,8 @@ cc_library(
) )
cc_library( cc_library(
name = "opencv_highgui", name = "opencv_photo",
hdrs = ["opencv_highgui_inc.h"], hdrs = ["opencv_photo_inc.h"],
deps = [ deps = [
":opencv_core", ":opencv_core",
"//third_party:opencv", "//third_party:opencv",
@ -297,6 +297,15 @@ cc_library(
], ],
) )
cc_library(
name = "opencv_highgui",
hdrs = ["opencv_highgui_inc.h"],
deps = [
":opencv_core",
"//third_party:opencv",
],
)
cc_library( cc_library(
name = "opencv_videoio", name = "opencv_videoio",
hdrs = ["opencv_videoio_inc.h"], hdrs = ["opencv_videoio_inc.h"],

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2023 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ #ifndef MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
#define MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ #define MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
#include <opencv2/core/version.hpp> #include <opencv2/core/version.hpp>
@ -25,4 +25,4 @@
#include <opencv2/highgui.hpp> #include <opencv2/highgui.hpp>
#endif #endif
#endif // MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ #endif // MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2022 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2023 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,15 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <vector> #ifndef MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
#define MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
#include "mediapipe/calculators/core/clip_vector_size_calculator.h" #include "third_party/OpenCV/photo.hpp"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe { #endif // MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
typedef ClipVectorSizeCalculator<::mediapipe::Detection>
ClipDetectionVectorSizeCalculator;
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
} // namespace mediapipe

View File

@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors. // Copyright 2022 The MediaPipe Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.

View File

@ -48,6 +48,18 @@ class MuxInputStreamHandler : public InputStreamHandler {
: InputStreamHandler(std::move(tag_map), cc_manager, options, : InputStreamHandler(std::move(tag_map), cc_manager, options,
calculator_run_in_parallel) {} calculator_run_in_parallel) {}
private:
CollectionItemId GetControlStreamId() const {
return input_stream_managers_.EndId() - 1;
}
void RemoveOutdatedDataPackets(Timestamp timestamp) {
const CollectionItemId control_stream_id = GetControlStreamId();
for (CollectionItemId id = input_stream_managers_.BeginId();
id < control_stream_id; ++id) {
input_stream_managers_.Get(id)->ErasePacketsEarlierThan(timestamp);
}
}
protected: protected:
// In MuxInputStreamHandler, a node is "ready" if: // In MuxInputStreamHandler, a node is "ready" if:
// - the control stream is done (need to call Close() in this case), or // - the control stream is done (need to call Close() in this case), or
@ -58,9 +70,15 @@ class MuxInputStreamHandler : public InputStreamHandler {
absl::MutexLock lock(&input_streams_mutex_); absl::MutexLock lock(&input_streams_mutex_);
const auto& control_stream = const auto& control_stream =
input_stream_managers_.Get(input_stream_managers_.EndId() - 1); input_stream_managers_.Get(GetControlStreamId());
bool empty; bool empty;
*min_stream_timestamp = control_stream->MinTimestampOrBound(&empty); *min_stream_timestamp = control_stream->MinTimestampOrBound(&empty);
// Data streams may contain some outdated packets which failed to be popped
// out during "FillInputSet". (This handler doesn't sync input streams,
// hence "FillInputSet" can be triggerred before every input stream is
// filled with packets corresponding to the same timestamp.)
RemoveOutdatedDataPackets(*min_stream_timestamp);
if (empty) { if (empty) {
if (*min_stream_timestamp == Timestamp::Done()) { if (*min_stream_timestamp == Timestamp::Done()) {
// Calculator is done if the control input stream is done. // Calculator is done if the control input stream is done.
@ -78,11 +96,6 @@ class MuxInputStreamHandler : public InputStreamHandler {
const auto& data_stream = input_stream_managers_.Get( const auto& data_stream = input_stream_managers_.Get(
input_stream_managers_.BeginId() + control_value); input_stream_managers_.BeginId() + control_value);
// Data stream may contain some outdated packets which failed to be popped
// out during "FillInputSet". (This handler doesn't sync input streams,
// hence "FillInputSet" can be triggerred before every input stream is
// filled with packets corresponding to the same timestamp.)
data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty); Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
if (empty) { if (empty) {
if (stream_timestamp <= *min_stream_timestamp) { if (stream_timestamp <= *min_stream_timestamp) {
@ -111,8 +124,7 @@ class MuxInputStreamHandler : public InputStreamHandler {
CHECK(input_set); CHECK(input_set);
absl::MutexLock lock(&input_streams_mutex_); absl::MutexLock lock(&input_streams_mutex_);
const CollectionItemId control_stream_id = const CollectionItemId control_stream_id = GetControlStreamId();
input_stream_managers_.EndId() - 1;
auto& control_stream = input_stream_managers_.Get(control_stream_id); auto& control_stream = input_stream_managers_.Get(control_stream_id);
int num_packets_dropped = 0; int num_packets_dropped = 0;
bool stream_is_done = false; bool stream_is_done = false;
@ -140,15 +152,8 @@ class MuxInputStreamHandler : public InputStreamHandler {
AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet), AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet),
stream_is_done); stream_is_done);
// Discard old packets on other streams. // Discard old packets on data streams.
// Note that control_stream_id is the last valid id. RemoveOutdatedDataPackets(input_timestamp.NextAllowedInStream());
auto next_timestamp = input_timestamp.NextAllowedInStream();
for (CollectionItemId id = input_stream_managers_.BeginId();
id < control_stream_id; ++id) {
if (id == data_stream_id) continue;
auto& other_stream = input_stream_managers_.Get(id);
other_stream->ErasePacketsEarlierThan(next_timestamp);
}
} }
private: private:

View File

@ -645,5 +645,41 @@ TEST(MuxInputStreamHandlerTest,
MP_ASSERT_OK(graph.WaitUntilDone()); MP_ASSERT_OK(graph.WaitUntilDone());
} }
TEST(MuxInputStreamHandlerTest, RemovesUnusedDataStreamPackets) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:input0"
input_stream: "INPUT:1:input1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
)pb");
config.set_max_queue_size(1);
config.set_report_deadlock(true);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add two delayed packets to the deselected input. They should be discarded
// instead of triggering the deadlock detection (max_queue_size = 1).
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -64,13 +64,13 @@ GraphRegistry::GraphRegistry(
void GraphRegistry::Register( void GraphRegistry::Register(
const std::string& type_name, const std::string& type_name,
std::function<std::unique_ptr<Subgraph>()> factory) { std::function<std::unique_ptr<Subgraph>()> factory) {
local_factories_.Register(type_name, factory, __FILE__, __LINE__); local_factories_.Register(type_name, factory);
} }
// TODO: Remove this convenience function. // TODO: Remove this convenience function.
void GraphRegistry::Register(const std::string& type_name, void GraphRegistry::Register(const std::string& type_name,
const CalculatorGraphConfig& config) { const CalculatorGraphConfig& config) {
Register(type_name, [config] { local_factories_.Register(type_name, [config] {
auto result = absl::make_unique<ProtoSubgraph>(config); auto result = absl::make_unique<ProtoSubgraph>(config);
return std::unique_ptr<Subgraph>(result.release()); return std::unique_ptr<Subgraph>(result.release());
}); });
@ -79,7 +79,7 @@ void GraphRegistry::Register(const std::string& type_name,
// TODO: Remove this convenience function. // TODO: Remove this convenience function.
void GraphRegistry::Register(const std::string& type_name, void GraphRegistry::Register(const std::string& type_name,
const CalculatorGraphTemplate& templ) { const CalculatorGraphTemplate& templ) {
Register(type_name, [templ] { local_factories_.Register(type_name, [templ] {
auto result = absl::make_unique<TemplateSubgraph>(templ); auto result = absl::make_unique<TemplateSubgraph>(templ);
return std::unique_ptr<Subgraph>(result.release()); return std::unique_ptr<Subgraph>(result.release());
}); });

View File

@ -228,7 +228,9 @@ absl::Status CompareAndSaveImageOutput(
auto status = CompareImageFrames(**expected, actual, options.max_color_diff, auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
options.max_alpha_diff, options.max_avg_diff, options.max_alpha_diff, options.max_avg_diff,
diff_img); diff_img);
if (diff_img) {
ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff")); ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
}
return status; return status;
} }

View File

@ -1121,7 +1121,7 @@ objc_library(
alwayslink = 1, alwayslink = 1,
) )
MIN_IOS_VERSION = "11.0" MIN_IOS_VERSION = "12.0"
test_suite( test_suite(
name = "ios", name = "ios",

View File

@ -109,9 +109,8 @@ absl::Status GlContext::CreateContext(
} }
MP_RETURN_IF_ERROR(status); MP_RETURN_IF_ERROR(status);
LOG(INFO) << "Successfully created a WebGL context with major version " VLOG(1) << "Successfully created a WebGL context with major version "
<< gl_major_version_ << " and handle " << context_; << gl_major_version_ << " and handle " << context_;
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -104,6 +104,7 @@ class GlScalerCalculator : public CalculatorBase {
bool vertical_flip_output_; bool vertical_flip_output_;
bool horizontal_flip_output_; bool horizontal_flip_output_;
FrameScaleMode scale_mode_ = FrameScaleMode::kStretch; FrameScaleMode scale_mode_ = FrameScaleMode::kStretch;
bool use_nearest_neighbor_interpolation_ = false;
}; };
REGISTER_CALCULATOR(GlScalerCalculator); REGISTER_CALCULATOR(GlScalerCalculator);
@ -186,7 +187,8 @@ absl::Status GlScalerCalculator::Open(CalculatorContext* cc) {
scale_mode_ = scale_mode_ =
FrameScaleModeFromProto(options.scale_mode(), FrameScaleMode::kStretch); FrameScaleModeFromProto(options.scale_mode(), FrameScaleMode::kStretch);
} }
use_nearest_neighbor_interpolation_ =
options.use_nearest_neighbor_interpolation();
if (HasTagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) { if (HasTagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) {
const auto& dimensions = const auto& dimensions =
TagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1) TagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)
@ -297,6 +299,11 @@ absl::Status GlScalerCalculator::Process(CalculatorContext* cc) {
glBindTexture(src2.target(), src2.name()); glBindTexture(src2.target(), src2.name());
} }
if (use_nearest_neighbor_interpolation_) {
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
}
MP_RETURN_IF_ERROR(renderer->GlRender( MP_RETURN_IF_ERROR(renderer->GlRender(
src1.width(), src1.height(), dst.width(), dst.height(), scale_mode_, src1.width(), src1.height(), dst.width(), dst.height(), scale_mode_,
rotation_, horizontal_flip_output_, vertical_flip_output_, rotation_, horizontal_flip_output_, vertical_flip_output_,

View File

@ -19,7 +19,7 @@ package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto"; import "mediapipe/gpu/scale_mode.proto";
// Next id: 8. // Next id: 9.
message GlScalerCalculatorOptions { message GlScalerCalculatorOptions {
extend CalculatorOptions { extend CalculatorOptions {
optional GlScalerCalculatorOptions ext = 166373014; optional GlScalerCalculatorOptions ext = 166373014;
@ -39,4 +39,7 @@ message GlScalerCalculatorOptions {
// Flip the output texture horizontally. This is applied after rotation. // Flip the output texture horizontally. This is applied after rotation.
optional bool flip_horizontal = 5; optional bool flip_horizontal = 5;
optional ScaleMode.Mode scale_mode = 6; optional ScaleMode.Mode scale_mode = 6;
// Whether to use nearest-neighbor interpolation. Defaults to linear
// interpolation.
optional bool use_nearest_neighbor_interpolation = 8 [default = false];
} }

View File

@ -100,6 +100,10 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
#endif // TARGET_OS_OSX #endif // TARGET_OS_OSX
}}, }},
{GpuBufferFormat::kOneComponent8Alpha,
{
{GL_ALPHA, GL_ALPHA, GL_UNSIGNED_BYTE, 1},
}},
{GpuBufferFormat::kOneComponent8Red, {GpuBufferFormat::kOneComponent8Red,
{ {
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
@ -221,6 +225,7 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
case GpuBufferFormat::kRGBA32: case GpuBufferFormat::kRGBA32:
// TODO: this likely maps to ImageFormat::SRGBA // TODO: this likely maps to ImageFormat::SRGBA
case GpuBufferFormat::kGrayHalf16: case GpuBufferFormat::kGrayHalf16:
case GpuBufferFormat::kOneComponent8Alpha:
case GpuBufferFormat::kOneComponent8Red: case GpuBufferFormat::kOneComponent8Red:
case GpuBufferFormat::kTwoComponent8: case GpuBufferFormat::kTwoComponent8:
case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kTwoComponentHalf16:

View File

@ -43,6 +43,7 @@ enum class GpuBufferFormat : uint32_t {
kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'), kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'),
kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'), kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'),
kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'), kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'),
kOneComponent8Alpha = MEDIAPIPE_FOURCC('A', '0', '0', '8'),
kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'), kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'),
kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'), kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'),
kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'), kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'),
@ -101,6 +102,7 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
return kCVPixelFormatType_OneComponent32Float; return kCVPixelFormatType_OneComponent32Float;
case GpuBufferFormat::kOneComponent8: case GpuBufferFormat::kOneComponent8:
return kCVPixelFormatType_OneComponent8; return kCVPixelFormatType_OneComponent8;
case GpuBufferFormat::kOneComponent8Alpha:
case GpuBufferFormat::kOneComponent8Red: case GpuBufferFormat::kOneComponent8Red:
return -1; return -1;
case GpuBufferFormat::kTwoComponent8: case GpuBufferFormat::kTwoComponent8:

View File

@ -78,17 +78,21 @@ public class AppTextureFrame implements TextureFrame {
* Use {@link waitUntilReleasedWithGpuSync} whenever possible. * Use {@link waitUntilReleasedWithGpuSync} whenever possible.
*/ */
public void waitUntilReleased() throws InterruptedException { public void waitUntilReleased() throws InterruptedException {
GlSyncToken tokenToRelease = null;
synchronized (this) { synchronized (this) {
while (inUse && releaseSyncToken == null) { while (inUse && releaseSyncToken == null) {
wait(); wait();
} }
if (releaseSyncToken != null) { if (releaseSyncToken != null) {
releaseSyncToken.waitOnCpu(); tokenToRelease = releaseSyncToken;
releaseSyncToken.release();
inUse = false; inUse = false;
releaseSyncToken = null; releaseSyncToken = null;
} }
} }
if (tokenToRelease != null) {
tokenToRelease.waitOnCpu();
tokenToRelease.release();
}
} }
/** /**
@ -98,17 +102,21 @@ public class AppTextureFrame implements TextureFrame {
* TextureFrame. * TextureFrame.
*/ */
public void waitUntilReleasedWithGpuSync() throws InterruptedException { public void waitUntilReleasedWithGpuSync() throws InterruptedException {
GlSyncToken tokenToRelease = null;
synchronized (this) { synchronized (this) {
while (inUse && releaseSyncToken == null) { while (inUse && releaseSyncToken == null) {
wait(); wait();
} }
if (releaseSyncToken != null) { if (releaseSyncToken != null) {
releaseSyncToken.waitOnGpu(); tokenToRelease = releaseSyncToken;
releaseSyncToken.release();
inUse = false; inUse = false;
releaseSyncToken = null; releaseSyncToken = null;
} }
} }
if (tokenToRelease != null) {
tokenToRelease.waitOnGpu();
tokenToRelease.release();
}
} }
/** /**

View File

@ -239,7 +239,7 @@ public final class PacketGetter {
/** /**
* Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer * Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer
* array has the the same size of image list packet, and assumes the output buffer stores pixels * array has the same size of image list packet, and assumes the output buffer stores pixels
* contiguously. It returns false if this assumption does not hold. * contiguously. It returns false if this assumption does not hold.
* *
* <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of * <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of

View File

@ -24,6 +24,7 @@ package_group(
package_group( package_group(
name = "1p_client", name = "1p_client",
packages = [ packages = [
"//cloud/ml/applications/vision/model_garden/model_oss/mediapipe/...",
"//research/privacy/learning/fl_eval/pcvr/...", "//research/privacy/learning/fl_eval/pcvr/...",
], ],
) )

View File

@ -57,3 +57,14 @@ py_test(
srcs = ["classification_dataset_test.py"], srcs = ["classification_dataset_test.py"],
deps = [":classification_dataset"], deps = [":classification_dataset"],
) )
py_library(
name = "cache_files",
srcs = ["cache_files.py"],
)
py_test(
name = "cache_files_test",
srcs = ["cache_files_test.py"],
deps = [":cache_files"],
)

View File

@ -0,0 +1,112 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common TFRecord cache files library."""
import dataclasses
import os
import tempfile
from typing import Any, Mapping, Sequence
import tensorflow as tf
import yaml
# Suffix of the meta data file name.
METADATA_FILE_SUFFIX = '_metadata.yaml'
@dataclasses.dataclass(frozen=True)
class TFRecordCacheFiles:
"""TFRecordCacheFiles dataclass to store and load cached TFRecord files.
Attributes:
cache_prefix_filename: The cache prefix filename. This is usually provided
as a hash of the original data source to avoid different data sources
resulting in the same cache file.
cache_dir: The cache directory to save TFRecord and metadata files. When
cache_dir is None, a temporary folder is created; it is not removed
automatically after training, so it can be reused later.
num_shards: Number of shards for output tfrecord files.
"""
cache_prefix_filename: str = 'cache_prefix'
cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp)
num_shards: int = 1
def __post_init__(self):
if not tf.io.gfile.exists(self.cache_dir):
tf.io.gfile.makedirs(self.cache_dir)
if not self.cache_prefix_filename:
raise ValueError('cache_prefix_filename cannot be empty.')
if self.num_shards <= 0:
raise ValueError(
f'num_shards must be greater than 0, got {self.num_shards}'
)
@property
def cache_prefix(self) -> str:
"""The cache prefix including the cache directory and the cache prefix filename."""
return os.path.join(self.cache_dir, self.cache_prefix_filename)
@property
def tfrecord_files(self) -> Sequence[str]:
"""The TFRecord files."""
tfrecord_files = [
self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards)
for i in range(self.num_shards)
]
return tfrecord_files
@property
def metadata_file(self) -> str:
"""The metadata file."""
return self.cache_prefix + METADATA_FILE_SUFFIX
def get_writers(self) -> Sequence[tf.io.TFRecordWriter]:
"""Gets an array of TFRecordWriter objects.
Note that these writers should each be closed using .close() when done.
Returns:
Array of TFRecordWriter objects
"""
return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
def save_metadata(self, metadata):
"""Writes metadata to file.
Args:
metadata: A dictionary of metadata content to write. Exact format is
dependent on the specific dataset, but typically includes a 'size' and
'label_names' entry.
"""
with tf.io.gfile.GFile(self.metadata_file, 'w') as f:
yaml.dump(metadata, f)
def load_metadata(self) -> Mapping[Any, Any]:
"""Reads metadata from file.
Returns:
Dictionary object containing metadata
"""
if not tf.io.gfile.exists(self.metadata_file):
return {}
with tf.io.gfile.GFile(self.metadata_file, 'r') as f:
metadata = yaml.load(f, Loader=yaml.FullLoader)
return metadata
def is_cached(self) -> bool:
"""Checks whether this CacheFiles is already cached."""
all_cached_files = list(self.tfrecord_files) + [self.metadata_file]
return all(tf.io.gfile.exists(f) for f in all_cached_files)
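A brief usage sketch of the TFRecordCacheFiles API above; the cache directory and the metadata values are placeholders:

import tensorflow as tf

from mediapipe.model_maker.python.core.data import cache_files

cf = cache_files.TFRecordCacheFiles(
    cache_prefix_filename='example', cache_dir='/tmp/example_cache',
    num_shards=2)
if not cf.is_cached():
  for writer in cf.get_writers():
    writer.write(tf.train.Example().SerializeToString())  # placeholder record
    writer.close()
  cf.save_metadata({'size': 2, 'label_names': ['a', 'b']})

dataset = tf.data.TFRecordDataset(cf.tfrecord_files)  # read the shards back
metadata = cf.load_metadata()  # {'size': 2, 'label_names': ['a', 'b']}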

View File

@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files
class CacheFilesTest(tf.test.TestCase):
def test_tfrecord_cache_files(self):
cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='tfrecord',
cache_dir='/tmp/cache_dir',
num_shards=2,
)
self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord')
self.assertEqual(
cf.metadata_file,
'/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX,
)
expected_tfrecord_files = [
'/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2)
for i in range(2)
]
self.assertEqual(cf.tfrecord_files, expected_tfrecord_files)
# Writing TFRecord Files
self.assertFalse(cf.is_cached())
for tfrecord_file in cf.tfrecord_files:
self.assertFalse(tf.io.gfile.exists(tfrecord_file))
writers = cf.get_writers()
for writer in writers:
writer.close()
for tfrecord_file in cf.tfrecord_files:
self.assertTrue(tf.io.gfile.exists(tfrecord_file))
self.assertFalse(cf.is_cached())
# Writing Metadata Files
original_metadata = {'size': 10, 'label_names': ['label1', 'label2']}
cf.save_metadata(original_metadata)
self.assertTrue(cf.is_cached())
metadata = cf.load_metadata()
self.assertEqual(metadata, original_metadata)
def test_recordio_cache_files_error(self):
with self.assertRaisesRegex(
ValueError, 'cache_prefix_filename cannot be empty'
):
cache_files.TFRecordCacheFiles(
cache_prefix_filename='',
cache_dir='/tmp/cache_dir',
num_shards=2,
)
with self.assertRaisesRegex(
ValueError, 'num_shards must be greater than 0, got 0'
):
cache_files.TFRecordCacheFiles(
cache_prefix_filename='tfrecord',
cache_dir='/tmp/cache_dir',
num_shards=0,
)
if __name__ == '__main__':
tf.test.main()

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Common classification dataset library.""" """Common classification dataset library."""
from typing import List, Tuple from typing import List, Optional, Tuple
import tensorflow as tf import tensorflow as tf
@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset): class ClassificationDataset(ds.Dataset):
"""Dataset Loader for classification models.""" """Dataset Loader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, def __init__(
label_names: List[str]): self,
dataset: tf.data.Dataset,
label_names: List[str],
size: Optional[int] = None,
):
super().__init__(dataset, size) super().__init__(dataset, size)
self._label_names = label_names self._label_names = label_names

View File

@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase):
value: A value variable stored by the mock dataset class for testing. value: A value variable stored by the mock dataset class for testing.
""" """
def __init__(self, dataset: tf.data.Dataset, size: int, def __init__(
label_names: List[str], value: Any): self,
super().__init__(dataset=dataset, size=size, label_names=label_names) dataset: tf.data.Dataset,
label_names: List[str],
value: Any,
size: int,
):
super().__init__(dataset=dataset, label_names=label_names, size=size)
self.value = value self.value = value
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]: def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase):
# Create data loader from sample data. # Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataset( data = MagicClassificationDataset(
dataset=ds, size=len(ds), label_names=label_names, value=magic_value) dataset=ds, label_names=label_names, value=magic_value, size=len(ds)
)
# Train/Test data split. # Train/Test data split.
fraction = .25 fraction = .25

View File

@ -56,15 +56,14 @@ class Dataset(object):
def size(self) -> Optional[int]: def size(self) -> Optional[int]:
"""Returns the size of the dataset. """Returns the size of the dataset.
Note that this function may return None becuase the exact size of the Same functionality as calling __len__. See the __len__ method definition for
dataset isn't a necessary parameter to create an instance of this class, more information.
and tf.data.Dataset donesn't support a function to get the length directly
since it's lazy-loaded and may be infinite. Raises:
In most cases, however, when an instance of this class is created by helper TypeError if self._size is not set and the cardinality of self._dataset
functions like 'from_folder', the size of the dataset will be preprocessed, is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
and this function can return an int representing the size of the dataset.
""" """
return self._size return self.__len__()
def gen_tf_dataset( def gen_tf_dataset(
self, self,
@ -116,8 +115,22 @@ class Dataset(object):
# here. # here.
return dataset return dataset
def __len__(self): def __len__(self) -> int:
"""Returns the number of element of the dataset.""" """Returns the number of element of the dataset.
If size is not set, this method will fallback to using the __len__ method
of the tf.data.Dataset in self._dataset. Calling __len__ on a
tf.data.Dataset instance may throw a TypeError because the dataset may
be lazy-loaded with an unknown size or have infinite size.
In most cases, however, when an instance of this class is created by helper
functions like 'from_folder', the size of the dataset will be preprocessed,
and the _size instance variable will be already set.
Raises:
TypeError if self._size is not set and the cardinality of self._dataset
is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
"""
if self._size is not None: if self._size is not None:
return self._size return self._size
else: else:
@ -152,15 +165,25 @@ class Dataset(object):
Returns: Returns:
The two split sub-datasets. The two split sub-datasets.
Raises:
ValueError: if the provided fraction is not between 0 and 1.
ValueError: if this dataset does not have a set size.
""" """
assert (fraction > 0 and fraction < 1) if not (fraction > 0 and fraction < 1):
raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}')
if not self._size:
raise ValueError(
'Dataset size unknown. Cannot split the dataset when '
'the size is unknown.'
)
dataset = self._dataset dataset = self._dataset
train_size = int(self._size * fraction) train_size = int(self._size * fraction)
trainset = self.__class__(dataset.take(train_size), train_size, *args) trainset = self.__class__(dataset.take(train_size), *args, size=train_size)
test_size = self._size - train_size test_size = self._size - train_size
testset = self.__class__(dataset.skip(train_size), test_size, *args) testset = self.__class__(dataset.skip(train_size), *args, size=test_size)
return trainset, testset return trainset, testset
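A short sketch of the size/split semantics documented above, assuming the base Dataset constructor is Dataset(dataset, size=None); values are illustrative:

import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds_lib

data = ds_lib.Dataset(
    tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]), size=4)
assert len(data) == 4                      # explicit size is used directly

train, test = data.split(0.5)              # 2 and 2 elements
assert len(train) == 2 and len(test) == 2

unsized = ds_lib.Dataset(
    tf.data.Dataset.from_tensor_slices([1, 2]).repeat())  # no size set
try:
  len(unsized)         # falls back to tf.data.Dataset.__len__
except TypeError:
  pass                 # raised for infinite/unknown cardinality
try:
  unsized.split(0.5)
except ValueError:
  pass                 # split requires a known size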

View File

@ -15,7 +15,7 @@
import dataclasses import dataclasses
import tempfile import tempfile
from typing import Optional from typing import Mapping, Optional
import tensorflow as tf import tensorflow as tf
@ -36,6 +36,8 @@ class BaseHParams:
steps_per_epoch: An optional integer indicate the number of training steps steps_per_epoch: An optional integer indicate the number of training steps
per epoch. If not set, the training pipeline calculates the default steps per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size. per epoch as the training dataset size divided by batch size.
class_weights: An optional mapping of indices to weights for weighting the
loss function during training.
shuffle: True if the dataset is shuffled before training. shuffle: True if the dataset is shuffled before training.
export_dir: The location of the model checkpoint files. export_dir: The location of the model checkpoint files.
distribution_strategy: A string specifying which Distribution Strategy to distribution_strategy: A string specifying which Distribution Strategy to
@ -57,6 +59,7 @@ class BaseHParams:
batch_size: int batch_size: int
epochs: int epochs: int
steps_per_epoch: Optional[int] = None steps_per_epoch: Optional[int] = None
class_weights: Optional[Mapping[int, float]] = None
# Dataset-related parameters # Dataset-related parameters
shuffle: bool = False shuffle: bool = False
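An illustrative use of the new class_weights hyperparameter above; the weight values are arbitrary:

from mediapipe.model_maker.python.core import hyperparameters as hp

params = hp.BaseHParams(
    learning_rate=1e-3,
    batch_size=32,
    epochs=10,
    # Weight class index 1 three times as heavily as class 0; this mapping is
    # forwarded to model.fit(class_weight=...) by the classifier.
    class_weights={0: 1.0, 1: 3.0},
)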

View File

@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel):
# dataset is exhausted even if there are epochs remaining. # dataset is exhausted even if there are epochs remaining.
steps_per_epoch=None, steps_per_epoch=None,
validation_data=validation_dataset, validation_data=validation_dataset,
callbacks=self._callbacks) callbacks=self._callbacks,
class_weight=self._hparams.class_weights,
)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset. """Evaluates the classifier with the provided evaluation dataset.

View File

@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
""" """
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None): def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
"""Constructor. """Initializes FocalLoss.
Args: Args:
gamma: Focal loss gamma, as described in class docs. gamma: Focal loss gamma, as described in class docs.
@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
return tf.reduce_sum(losses) / batch_size return tf.reduce_sum(losses) / batch_size
class SparseFocalLoss(FocalLoss):
"""Sparse implementation of Focal Loss.
This is the same as FocalLoss, except the labels are expected to be class ids
instead of 1-hot encoded vectors. See FocalLoss class documentation defined
in this same file for more details.
Example usage:
>>> y_true = [1, 2]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> gamma = 2
>>> focal_loss = SparseFocalLoss(gamma, 3)
>>> focal_loss(y_true, y_pred).numpy()
0.9326
>>> # Calling with 'sample_weight'.
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
0.6528
"""
def __init__(
self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
):
"""Initializes SparseFocalLoss.
Args:
gamma: Focal loss gamma, as described in class docs.
num_classes: Number of classes.
class_weight: A weight to apply to the loss, one for each class. The
weight is applied for each input where the ground truth label matches.
"""
super().__init__(gamma, class_weight=class_weight)
self._num_classes = num_classes
def __call__(
self,
y_true: tf.Tensor,
y_pred: tf.Tensor,
sample_weight: Optional[tf.Tensor] = None,
) -> tf.Tensor:
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
y_true_one_hot = tf.one_hot(y_true, self._num_classes)
return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
@dataclasses.dataclass @dataclasses.dataclass
class PerceptualLossWeight: class PerceptualLossWeight:
"""The weight for each perceptual loss. """The weight for each perceptual loss.

View File

@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(loss, expected_loss, 1e-4) self.assertNear(loss, expected_loss, 1e-4)
class SparseFocalLossTest(tf.test.TestCase):
def test_sparse_focal_loss_matches_focal_loss(self):
num_classes = 2
y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
y_true = tf.constant([1, 0])
y_true_one_hot = tf.one_hot(y_true, num_classes)
for gamma in [0.0, 0.5, 1.0]:
expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
loss_fn = loss_functions.SparseFocalLoss(
gamma=gamma, num_classes=num_classes
)
expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
loss = loss_fn(y_true, y_pred)
self.assertNear(loss, expected_loss, 1e-4)
class MockPerceptualLoss(loss_functions.PerceptualLoss): class MockPerceptualLoss(loss_functions.PerceptualLoss):
"""A mock class with implementation of abstract methods for testing.""" """A mock class with implementation of abstract methods for testing."""

View File

@ -46,13 +46,17 @@ class BertModelSpec:
""" """
downloaded_files: file_util.DownloadedFiles downloaded_files: file_util.DownloadedFiles
hparams: hp.BaseHParams = hp.BaseHParams( hparams: hp.BaseHParams = dataclasses.field(
default_factory=lambda: hp.BaseHParams(
epochs=3, epochs=3,
batch_size=32, batch_size=32,
learning_rate=3e-5, learning_rate=3e-5,
distribution_strategy='mirrored') distribution_strategy='mirrored',
model_options: bert_model_options.BertModelOptions = ( )
bert_model_options.BertModelOptions()) )
model_options: bert_model_options.BertModelOptions = dataclasses.field(
default_factory=bert_model_options.BertModelOptions
)
do_lower_case: bool = True do_lower_case: bool = True
tflite_input_name: Dict[str, str] = dataclasses.field( tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict test compatibility macro.
package(default_visibility = ["//mediapipe:__subpackages__"]) package(default_visibility = ["//mediapipe:__subpackages__"])
@ -76,7 +76,10 @@ py_test(
py_library( py_library(
name = "dataset", name = "dataset",
srcs = ["dataset.py"], srcs = ["dataset.py"],
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], deps = [
"//mediapipe/model_maker/python/core/data:cache_files",
"//mediapipe/model_maker/python/core/data:classification_dataset",
],
) )
py_test( py_test(
@ -88,7 +91,10 @@ py_test(
py_library( py_library(
name = "preprocessor", name = "preprocessor",
srcs = ["preprocessor.py"], srcs = ["preprocessor.py"],
deps = [":dataset"], deps = [
":dataset",
"//mediapipe/model_maker/python/core/data:cache_files",
],
) )
py_test( py_test(
@ -99,6 +105,7 @@ py_test(
":dataset", ":dataset",
":model_spec", ":model_spec",
":preprocessor", ":preprocessor",
"//mediapipe/model_maker/python/core/data:cache_files",
], ],
) )
@ -124,6 +131,7 @@ py_library(
":text_classifier_options", ":text_classifier_options",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:metrics",
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/core/utils:quantization",
@ -147,6 +155,7 @@ py_test(
], ],
deps = [ deps = [
":text_classifier_import", ":text_classifier_import",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -15,11 +15,15 @@
import csv import csv
import dataclasses import dataclasses
import hashlib
import os
import random import random
import tempfile
from typing import List, Optional, Sequence
from typing import Optional, Sequence
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.core.data import classification_dataset
@ -46,21 +50,49 @@ class CSVParameters:
class Dataset(classification_dataset.ClassificationDataset): class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for text classifier.""" """Dataset library for text classifier."""
def __init__(
self,
dataset: tf.data.Dataset,
label_names: List[str],
tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
size: Optional[int] = None,
):
super().__init__(dataset, label_names, size)
if not tfrecord_cache_files:
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename="tfrecord", num_shards=1
)
self.tfrecord_cache_files = tfrecord_cache_files
@classmethod @classmethod
def from_csv(cls, def from_csv(
cls,
filename: str, filename: str,
csv_params: CSVParameters, csv_params: CSVParameters,
shuffle: bool = True) -> "Dataset": shuffle: bool = True,
cache_dir: Optional[str] = None,
num_shards: int = 1,
) -> "Dataset":
"""Loads text with labels from a CSV file. """Loads text with labels from a CSV file.
Args: Args:
filename: Name of the CSV file. filename: Name of the CSV file.
csv_params: Parameters used for reading the CSV file. csv_params: Parameters used for reading the CSV file.
shuffle: If True, randomly shuffle the data. shuffle: If True, randomly shuffle the data.
cache_dir: Optional parameter to specify where to store the preprocessed
dataset. Only used for BERT models.
num_shards: Optional parameter for the number of shards of the
preprocessed dataset. Note that using more than 1 shard will reorder
the dataset. Only used for BERT models.
Returns: Returns:
Dataset containing (text, label) pairs and other related info. Dataset containing (text, label) pairs and other related info.
""" """
if cache_dir is None:
cache_dir = tempfile.mkdtemp()
# Calculate a hash for the cache based on the file name and contents.
hasher = hashlib.md5()
hasher.update(os.path.basename(filename).encode("utf-8"))
with tf.io.gfile.GFile(filename, "r") as f: with tf.io.gfile.GFile(filename, "r") as f:
reader = csv.DictReader( reader = csv.DictReader(
f, f,
@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
quotechar=csv_params.quotechar) quotechar=csv_params.quotechar)
lines = list(reader) lines = list(reader)
for line in lines:
hasher.update(str(line).encode("utf-8"))
if shuffle: if shuffle:
random.shuffle(lines) random.shuffle(lines)
@ -81,8 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
index_by_label[line[csv_params.label_column]] for line in lines index_by_label[line[csv_params.label_column]] for line in lines
] ]
label_index_ds = tf.data.Dataset.from_tensor_slices( label_index_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(label_indices, tf.int64)) tf.cast(label_indices, tf.int64)
)
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds)) text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
hasher.update(str(num_shards).encode("utf-8"))
cache_prefix_filename = hasher.hexdigest()
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename, cache_dir, num_shards
)
return Dataset( return Dataset(
dataset=text_label_ds, size=len(texts), label_names=label_names) dataset=text_label_ds,
label_names=label_names,
tfrecord_cache_files=tfrecord_cache_files,
size=len(texts),
)
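A hypothetical end-to-end example of the CSV loading and caching flow above, assuming CSVParameters takes text_column and label_column; the file name, columns, and cache directory are placeholders:

from mediapipe.model_maker.python.text.text_classifier import dataset

csv_params = dataset.CSVParameters(text_column="text", label_column="label")
data = dataset.Dataset.from_csv(
    filename="reviews.csv",
    csv_params=csv_params,
    cache_dir="/tmp/text_cache",  # where preprocessed TFRecords go (BERT only)
    num_shards=2,                 # more than 1 shard may reorder the dataset
)
# The cache prefix is an md5 digest over the CSV file name, every row, and
# num_shards, so re-running with identical inputs reuses the same cache files.
print(data.tfrecord_cache_files.cache_prefix)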

View File

@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
def test_split(self): def test_split(self):
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd']) ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
data = dataset.Dataset(ds, 4, ['pos', 'neg']) data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
train_data, test_data = data.split(0.5) train_data, test_data = data.split(0.5)
expected_train_data = [b'good', b'bad'] expected_train_data = [b'good', b'bad']
expected_test_data = [b'neutral', b'odd'] expected_test_data = [b'neutral', b'odd']

View File

@ -15,7 +15,7 @@
import dataclasses import dataclasses
import enum import enum
from typing import Union from typing import Sequence, Union
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core import hyperparameters as hp
@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):
Attributes: Attributes:
learning_rate: Learning rate to use for gradient descent training. learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training. end_learning_rate: End learning rate for linear decay. Defaults to 0.
epochs: Number of training iterations over the dataset. batch_size: Batch size for training. Defaults to 48.
optimizer: Optimizer to use for training. Only supported values are "adamw" epochs: Number of training iterations over the dataset. Defaults to 2.
and "lamb". optimizer: Optimizer to use for training. Supported values are defined in
BertOptimizer enum: ADAMW and LAMB.
weight_decay: Weight decay of the optimizer. Defaults to 0.01.
desired_precisions: If specified, adds a RecallAtPrecision metric per
desired_precisions[i] entry which tracks the recall given the constraint
on precision. Only supported for binary classification.
desired_recalls: If specified, adds a PrecisionAtRecall metric per
desired_recalls[i] entry which tracks the precision given the constraint
on recall. Only supported for binary classification.
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
value to 0. Defaults to 2.0.
""" """
learning_rate: float = 3e-5 learning_rate: float = 3e-5
end_learning_rate: float = 0.0
batch_size: int = 48 batch_size: int = 48
epochs: int = 2 epochs: int = 2
optimizer: BertOptimizer = BertOptimizer.ADAMW optimizer: BertOptimizer = BertOptimizer.ADAMW
weight_decay: float = 0.01
desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
gamma: float = 2.0
HParams = Union[BertHParams, AverageWordEmbeddingHParams] HParams = Union[BertHParams, AverageWordEmbeddingHParams]
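An illustrative construction of the expanded BertHParams above, assuming it lives in the text_classifier hyperparameters module; the values are arbitrary:

from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp

hparams = hp.BertHParams(
    learning_rate=3e-5,
    end_learning_rate=0.0,       # target of the linear learning-rate decay
    batch_size=48,
    epochs=2,
    optimizer=hp.BertOptimizer.LAMB,
    weight_decay=0.01,
    desired_precisions=[0.95],   # adds RecallAtPrecision(0.95); binary only
    gamma=0.0,                   # 0 disables focal loss (plain cross entropy)
)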

View File

@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec:
""" """
# `learning_rate` is unused for the average word embedding model # `learning_rate` is unused for the average word embedding model
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams( hparams: hp.AverageWordEmbeddingHParams = dataclasses.field(
default_factory=lambda: hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0 epochs=10, batch_size=32, learning_rate=0
) )
model_options: mo.AverageWordEmbeddingModelOptions = ( )
mo.AverageWordEmbeddingModelOptions()) model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field(
default_factory=mo.AverageWordEmbeddingModelOptions
)
name: str = 'AverageWordEmbedding' name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial( average_word_embedding_classifier_spec = functools.partial(
@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
inherited from the BertModelSpec. inherited from the BertModelSpec.
""" """
hparams: hp.BertHParams = hp.BertHParams() hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
mobilebert_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial(
@ -76,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
), ),
name='MobileBert', name='MobileBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
) )
exbert_classifier_spec = functools.partial( exbert_classifier_spec = functools.partial(
@ -90,11 +88,6 @@ exbert_classifier_spec = functools.partial(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
), ),
name='ExBert', name='ExBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
) )

View File

@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path())) self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
self.assertTrue(model_spec_obj.do_lower_case) self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual( self.assertEqual(
model_spec_obj.tflite_input_name, { model_spec_obj.tflite_input_name,
'ids': 'serving_default_input_1:0', {
'mask': 'serving_default_input_3:0', 'ids': 'serving_default_input_word_ids:0',
'segment_ids': 'serving_default_input_2:0' 'mask': 'serving_default_input_mask:0',
}) 'segment_ids': 'serving_default_input_type_ids:0',
},
)
self.assertEqual( self.assertEqual(
model_spec_obj.model_options, model_spec_obj.model_options,
classifier_model_options.BertModelOptions( classifier_model_options.BertModelOptions(

View File

@ -15,14 +15,15 @@
"""Preprocessors for text classification.""" """Preprocessors for text classification."""
import collections import collections
import hashlib
import os import os
import re import re
import tempfile
from typing import Mapping, Sequence, Tuple, Union from typing import Mapping, Sequence, Tuple, Union
import tensorflow as tf import tensorflow as tf
import tensorflow_hub import tensorflow_hub
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization from official.nlp.tools import tokenization
@ -75,19 +76,20 @@ def _decode_record(
return bert_features, example["label_ids"] return bert_features, example["label_ids"]
def _single_file_dataset( def _tfrecord_dataset(
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature] tfrecord_files: Sequence[str],
name_to_features: Mapping[str, tf.io.FixedLenFeature],
) -> tf.data.TFRecordDataset: ) -> tf.data.TFRecordDataset:
"""Creates a single-file dataset to be passed for BERT custom training. """Creates a single-file dataset to be passed for BERT custom training.
Args: Args:
input_file: Filepath for the dataset. tfrecord_files: Filepaths for the dataset.
name_to_features: Maps record keys to feature types. name_to_features: Maps record keys to feature types.
Returns: Returns:
Dataset containing BERT model input features and labels. Dataset containing BERT model input features and labels.
""" """
d = tf.data.TFRecordDataset(input_file) d = tf.data.TFRecordDataset(tfrecord_files)
d = d.map( d = d.map(
lambda record: _decode_record(record, name_to_features), lambda record: _decode_record(record, name_to_features),
num_parallel_calls=tf.data.AUTOTUNE) num_parallel_calls=tf.data.AUTOTUNE)
@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
seq_len: Length of the input sequence to the model. seq_len: Length of the input sequence to the model.
vocab_file: File containing the BERT vocab. vocab_file: File containing the BERT vocab.
tokenizer: BERT tokenizer. tokenizer: BERT tokenizer.
model_name: Name of the model provided by the model_spec. Used to associate
cached files with specific Bert model vocab.
""" """
def __init__(self, seq_len: int, do_lower_case: bool, uri: str): def __init__(
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
):
self._seq_len = seq_len self._seq_len = seq_len
# Vocab filepath is tied to the BERT module's URI. # Vocab filepath is tied to the BERT module's URI.
self._vocab_file = os.path.join( self._vocab_file = os.path.join(
tensorflow_hub.resolve(uri), "assets", "vocab.txt") tensorflow_hub.resolve(uri), "assets", "vocab.txt"
self._tokenizer = tokenization.FullTokenizer(self._vocab_file, )
do_lower_case) self._do_lower_case = do_lower_case
self._tokenizer = tokenization.FullTokenizer(
self._vocab_file, self._do_lower_case
)
self._model_name = model_name
def _get_name_to_features(self): def _get_name_to_features(self):
"""Gets the dictionary mapping record keys to feature types.""" """Gets the dictionary mapping record keys to feature types."""
@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
"""Returns the vocab file of the BertClassifierPreprocessor.""" """Returns the vocab file of the BertClassifierPreprocessor."""
return self._vocab_file return self._vocab_file
def _get_tfrecord_cache_files(
self, ds_cache_files: cache_files_lib.TFRecordCacheFiles
) -> cache_files_lib.TFRecordCacheFiles:
"""Helper to regenerate cache prefix filename using preprocessor info.
We need to update the dataset cache_prefix cache because the actual cached
dataset depends on the preprocessor parameters such as model_name, seq_len,
and do_lower_case in addition to the raw dataset parameters which is already
included in the ds_cache_files.cache_prefix_filename
Specifically, the new cache_prefix_filename used by the preprocessor will
be a hash generated from the following:
1. cache_prefix_filename of the initial raw dataset
2. model_name
3. seq_len
4. do_lower_case
Args:
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
Returns:
A new TFRecordCacheFiles object which incorporates the preprocessor
parameters.
"""
hasher = hashlib.md5()
hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
hasher.update(self._model_name.encode("utf-8"))
hasher.update(str(self._seq_len).encode("utf-8"))
hasher.update(str(self._do_lower_case).encode("utf-8"))
cache_prefix_filename = hasher.hexdigest()
return cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename,
ds_cache_files.cache_dir,
ds_cache_files.num_shards,
)
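
The key property of the derivation above is determinism: the same raw-dataset prefix plus the same preprocessor settings always hash to the same cache prefix, and changing any field invalidates the cache. A standalone sketch of the same idea; names and values are illustrative:

    import hashlib

    def cache_prefix(raw_prefix: str, model_name: str, seq_len: int, do_lower_case: bool) -> str:
        hasher = hashlib.md5()
        # Updates are sequential, so a fixed field order keeps the key stable.
        for field in (raw_prefix, model_name, str(seq_len), str(do_lower_case)):
            hasher.update(field.encode("utf-8"))
        return hasher.hexdigest()

    # Same inputs -> same prefix; any changed field -> a different prefix.
    assert cache_prefix("ds", "exbert", 128, True) == cache_prefix("ds", "exbert", 128, True)
    assert cache_prefix("ds", "exbert", 128, True) != cache_prefix("ds", "exbert", 256, True)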
def preprocess( def preprocess(
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset: self, dataset: text_classifier_ds.Dataset
) -> text_classifier_ds.Dataset:
"""Preprocesses data into input for a BERT-based classifier. """Preprocesses data into input for a BERT-based classifier.
Args: Args:
@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
Returns: Returns:
Dataset containing (bert_features, label) data. Dataset containing (bert_features, label) data.
""" """
examples = [] ds_cache_files = dataset.tfrecord_cache_files
# Get new tfrecord_cache_files by including preprocessor information.
tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
if not tfrecord_cache_files.is_cached():
print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
writers = tfrecord_cache_files.get_writers()
size = 0
for index, (text, label) in enumerate(dataset.gen_tf_dataset()): for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
_validate_text_and_label(text, label) _validate_text_and_label(text, label)
examples.append( example = classifier_data_lib.InputExample(
classifier_data_lib.InputExample(
guid=str(index), guid=str(index),
text_a=text.numpy()[0].decode("utf-8"), text_a=text.numpy()[0].decode("utf-8"),
text_b=None, text_b=None,
# InputExample expects the label name rather than the int ID # Use the int label ID; label names are stored in the cache metadata.
label=dataset.label_names[label.numpy()[0]])) label=label.numpy()[0],
)
feature = classifier_data_lib.convert_single_example(
index, example, None, self._seq_len, self._tokenizer
)
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord") def create_int_feature(values):
classifier_data_lib.file_based_convert_examples_to_features( f = tf.train.Feature(
examples=examples, int64_list=tf.train.Int64List(value=list(values))
label_list=dataset.label_names, )
max_seq_length=self._seq_len, return f
tokenizer=self._tokenizer,
output_file=tfrecord_file) features = collections.OrderedDict()
preprocessed_ds = _single_file_dataset(tfrecord_file, features["input_ids"] = create_int_feature(feature.input_ids)
self._get_name_to_features()) features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature([feature.label_id])
tf_example = tf.train.Example(
features=tf.train.Features(feature=features)
)
writers[index % len(writers)].write(tf_example.SerializeToString())
size = index + 1
for writer in writers:
writer.close()
metadata = {"size": size, "label_names": dataset.label_names}
tfrecord_cache_files.save_metadata(metadata)
else:
print(
f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
)
metadata = tfrecord_cache_files.load_metadata()
size = metadata["size"]
label_names = metadata["label_names"]
preprocessed_ds = _tfrecord_dataset(
tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
)
return text_classifier_ds.Dataset( return text_classifier_ds.Dataset(
dataset=preprocessed_ds, dataset=preprocessed_ds,
size=dataset.size, size=size,
label_names=dataset.label_names) label_names=label_names,
tfrecord_cache_files=tfrecord_cache_files,
)
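
The write path above distributes examples round-robin across the shard writers (writers[index % len(writers)]), which keeps shards roughly balanced with no extra bookkeeping. A toy standalone illustration; the shard count and records are invented:

    import tensorflow as tf

    num_shards = 3
    paths = [f"/tmp/cache-{i:05d}-of-{num_shards:05d}.tfrecord" for i in range(num_shards)]
    writers = [tf.io.TFRecordWriter(p) for p in paths]

    for index in range(10):
        feature = {"x": tf.train.Feature(int64_list=tf.train.Int64List(value=[index]))}
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # The record index picks the shard.
        writers[index % len(writers)].write(example.SerializeToString())

    for writer in writers:
        writer.close()

    # Reading all shards back interleaves transparently.
    print(sum(1 for _ in tf.data.TFRecordDataset(paths)))  # -> 10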
TextClassifierPreprocessor = ( TextClassifierPreprocessor = Union[
Union[BertClassifierPreprocessor, BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
AverageWordEmbeddingClassifierPreprocessor]) ]

View File

@ -13,14 +13,17 @@
# limitations under the License. # limitations under the License.
import csv import csv
import io
import os import os
import tempfile import tempfile
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
import mock
import numpy as np import numpy as np
import numpy.testing as npt import numpy.testing as npt
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor from mediapipe.model_maker.python.text.text_classifier import preprocessor
@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
csv_file = self._get_csv_file() csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv( dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_) filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor( bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5, seq_len=5,
do_lower_case=bert_spec.do_lower_case, do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(), uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
) )
preprocessed_dataset = bert_preprocessor.preprocess(dataset) preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = [] labels = []
@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
self.assertEqual(label.shape, [1]) self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0]) labels.append(label.numpy()[0])
self.assertSameElements( self.assertSameElements(
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']) features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
)
for feature in features.values(): for feature in features.values():
self.assertEqual(feature.shape, [1, 5]) self.assertEqual(feature.shape, [1, 5])
input_masks.append(features['input_mask'].numpy()[0]) input_masks.append(features['input_mask'].numpy()[0])
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
[0, 0, 0, 0, 0])
npt.assert_array_equal( npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])) features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
)
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
)
self.assertEqual(labels, [1, 0]) self.assertEqual(labels, [1, 0])
def test_bert_preprocessor_cache(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file,
csv_params=self.CSV_PARAMS_,
cache_dir=self.get_temp_dir(),
)
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
ds_cache_files = dataset.tfrecord_cache_files
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
ds_cache_files
)
self.assertFalse(preprocessed_cache_files.is_cached())
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
self.assertTrue(preprocessed_cache_files.is_cached())
self.assertEqual(
preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
)
# The second time the preprocessor runs, it should load directly from the cache.
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
_ = bert_preprocessor.preprocess(dataset)
self.assertEqual(
mock_stdout.getvalue(),
'Using existing cache files at'
f' {preprocessed_cache_files.cache_prefix}\n',
)
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=seq_len,
do_lower_case=do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
return new_cf.cache_prefix_filename
def test_bert_get_tfrecord_cache_files(self):
# Test to ensure regenerated cache_files have different prefixes
all_cf_prefixes = set()
cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
new_cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='new_cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
# Each item of all_cf_prefixes should be unique, so 7 total.
self.assertLen(all_cf_prefixes, 7)
if __name__ == '__main__': if __name__ == '__main__':
# Load compressed models from tensorflow_hub # Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main() tf.test.main()

View File

@ -16,8 +16,8 @@
} }
}, },
{ {
"name": "mask", "name": "segment_ids",
"description": "Mask with 1 for real tokens and 0 for padding tokens.", "description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {
@ -27,8 +27,8 @@
} }
}, },
{ {
"name": "segment_ids", "name": "mask",
"description": "0 for the first sequence, 1 for the second sequence if exists.", "description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {

View File

@ -24,6 +24,7 @@ import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
): ):
text_classifier = ( text_classifier = _BertClassifier.create_bert_classifier(
_BertClassifier.create_bert_classifier(train_data, validation_data, train_data, validation_data, options
options, )
train_data.label_names))
elif (options.supported_model == elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
text_classifier = ( text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
_AverageWordEmbeddingClassifier train_data, validation_data, options
.create_average_word_embedding_classifier(train_data, validation_data, )
options,
train_data.label_names))
else: else:
raise ValueError(f"Unknown model {options.supported_model}") raise ValueError(f"Unknown model {options.supported_model}")
@ -166,27 +164,7 @@ class TextClassifier(classifier.Classifier):
processed_data = self._text_preprocessor.preprocess(data) processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
additional_metrics = [] with self._hparams.get_strategy().scope():
if desired_precisions and len(data.label_names) == 2:
for precision in desired_precisions:
additional_metrics.append(
metrics.BinarySparseRecallAtPrecision(
precision, name=f"recall_at_precision_{precision}"
)
)
if desired_recalls and len(data.label_names) == 2:
for recall in desired_recalls:
additional_metrics.append(
metrics.BinarySparsePrecisionAtRecall(
recall, name=f"precision_at_recall_{recall}"
)
)
metric_functions = self._metric_functions + additional_metrics
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=metric_functions,
)
return self._model.evaluate(dataset) return self._model.evaluate(dataset)
def export_model( def export_model(
@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
@classmethod @classmethod
def create_average_word_embedding_classifier( def create_average_word_embedding_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions, options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier": ) -> "_AverageWordEmbeddingClassifier":
"""Creates, trains, and returns an Average Word Embedding classifier. """Creates, trains, and returns an Average Word Embedding classifier.
Args: Args:
train_data: Training data. train_data: Training data.
validation_data: Validation data. validation_data: Validation data.
options: Options for creating and training the text classifier. options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns: Returns:
An Average Word Embedding classifier. An Average Word Embedding classifier.
@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options self._model_options = model_options
with self._hparams.get_strategy().scope(): with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() self._loss_function = loss_functions.SparseFocalLoss(
self._metric_functions = [ self._hparams.gamma, self._num_classes
tf.keras.metrics.SparseCategoricalAccuracy( )
"test_accuracy", dtype=tf.float32 self._metric_functions = self._create_metrics()
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
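
loss_functions.SparseFocalLoss is not shown in this diff; the sketch below is a rough assumption about what a sparse focal loss with a gamma parameter typically computes (cross entropy scaled by (1 - p_t) ** gamma so easy examples contribute less), not the shipped implementation:

    import tensorflow as tf

    def sparse_focal_loss(y_true, logits, gamma: float, num_classes: int):
        # p_t is the predicted probability of the true class.
        probs = tf.nn.softmax(logits, axis=-1)
        one_hot = tf.one_hot(tf.cast(y_true, tf.int32), depth=num_classes)
        p_t = tf.reduce_sum(one_hot * probs, axis=-1)
        ce = -tf.math.log(tf.clip_by_value(p_t, 1e-7, 1.0))
        # Down-weight well-classified examples by (1 - p_t) ** gamma.
        return tf.reduce_mean(tf.pow(1.0 - p_t, gamma) * ce)

    logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
    labels = tf.constant([0, 2])
    print(sparse_focal_loss(labels, logits, gamma=2.0, num_classes=3).numpy())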
@classmethod @classmethod
def create_bert_classifier( def create_bert_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions, options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_BertClassifier": ) -> "_BertClassifier":
"""Creates, trains, and returns a BERT-based classifier. """Creates, trains, and returns a BERT-based classifier.
Args: Args:
train_data: Training data. train_data: Training data.
validation_data: Validation data. validation_data: Validation data.
options: Options for creating and training the text classifier. options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns: Returns:
A BERT-based classifier. A BERT-based classifier.
@ -435,9 +411,59 @@ class _BertClassifier(TextClassifier):
seq_len=self._model_options.seq_len, seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case, do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.downloaded_files.get_path(), uri=self._model_spec.downloaded_files.get_path(),
model_name=self._model_spec.name,
) )
return (self._text_preprocessor.preprocess(train_data), return (
self._text_preprocessor.preprocess(validation_data)) self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data),
)
def _create_metrics(self):
"""Creates metrics for training and evaluation.
The default metrics are accuracy, precision, and recall.
For binary classification tasks only (num_classes=2):
Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
the desired_precisions and desired_recalls fields in BertHParams.
Returns:
A list of tf.keras.Metric subclasses which can be used with model.compile.
"""
metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
if self._num_classes == 2:
if self._hparams.desired_precisions:
for desired_precision in self._hparams.desired_precisions:
metric_functions.append(
metrics.BinarySparseRecallAtPrecision(
desired_precision,
name=f"recall_at_precision_{desired_precision}",
num_thresholds=1000,
)
)
if self._hparams.desired_recalls:
for desired_recall in self._hparams.desired_recalls:
metric_functions.append(
metrics.BinarySparsePrecisionAtRecall(
desired_recall,
name=f"precision_at_recall_{desired_recall}",
num_thresholds=1000,
)
)
else:
if self._hparams.desired_precisions or self._hparams.desired_recalls:
raise ValueError(
"desired_recalls and desired_precisions parameters are binary"
" metrics and not supported for num_classes > 2. Found"
f" num_classes: {self._num_classes}"
)
return metric_functions
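
The BinarySparse* metrics above are MediaPipe wrappers that are not part of this diff; presumably they adapt sparse labels to the stock Keras operating-point metrics (an assumption). For intuition, the dense-label Keras analogue:

    import tensorflow as tf

    # PrecisionAtRecall reports the best precision achievable while holding
    # recall at or above the target; RecallAtPrecision is the mirror image.
    m = tf.keras.metrics.PrecisionAtRecall(recall=0.2, num_thresholds=1000)
    m.update_state([0, 1, 1, 0, 1], [0.1, 0.8, 0.6, 0.4, 0.3])
    print(float(m.result()))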
def _create_model(self): def _create_model(self):
"""Creates a BERT-based classifier model. """Creates a BERT-based classifier model.
@ -447,11 +473,20 @@ class _BertClassifier(TextClassifier):
""" """
encoder_inputs = dict( encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input( input_word_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32), shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_word_ids",
),
input_mask=tf.keras.layers.Input( input_mask=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32), shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_mask",
),
input_type_ids=tf.keras.layers.Input( input_type_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32), shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_type_ids",
),
) )
encoder = hub.KerasLayer( encoder = hub.KerasLayer(
self._model_spec.downloaded_files.get_path(), self._model_spec.downloaded_files.get_path(),
@ -493,16 +528,21 @@ class _BertClassifier(TextClassifier):
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr, initial_learning_rate=initial_lr,
decay_steps=total_steps, decay_steps=total_steps,
end_learning_rate=0.0, end_learning_rate=self._hparams.end_learning_rate,
power=1.0) power=1.0,
)
if warmup_steps: if warmup_steps:
lr_schedule = model_util.WarmUp( lr_schedule = model_util.WarmUp(
initial_learning_rate=initial_lr, initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule, decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps) warmup_steps=warmup_steps,
)
if self._hparams.optimizer == hp.BertOptimizer.ADAMW: if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW( self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0 lr_schedule,
weight_decay=self._hparams.weight_decay,
epsilon=1e-6,
global_clipnorm=1.0,
) )
self._optimizer.exclude_from_weight_decay( self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"] var_names=["LayerNorm", "layer_norm", "bias"]
@ -510,7 +550,7 @@ class _BertClassifier(TextClassifier):
elif self._hparams.optimizer == hp.BertOptimizer.LAMB: elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB( self._optimizer = tfa_optimizers.LAMB(
lr_schedule, lr_schedule,
weight_decay_rate=0.01, weight_decay_rate=self._hparams.weight_decay,
epsilon=1e-6, epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0, global_clipnorm=1.0,
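
For intuition about the schedule assembled above: the learning rate ramps linearly during warmup, then hands off to the polynomial decay toward end_learning_rate. model_util.WarmUp is not shown in this diff, so the warmup below is a plain reimplementation of the usual pattern under assumed hyperparameters:

    import tensorflow as tf

    total_steps, warmup_steps, initial_lr = 1000, 100, 3e-5
    decay = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=total_steps,
        end_learning_rate=0.0,
        power=1.0,
    )

    def lr_at(step: int) -> float:
        # Linear ramp from 0 to initial_lr, then the decay schedule takes over.
        if step < warmup_steps:
            return initial_lr * step / warmup_steps
        return float(decay(step))

    for step in (0, 50, 100, 500, 1000):
        print(step, lr_at(step))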

View File

@ -84,8 +84,8 @@ def run(data_dir,
options) options)
# Gets evaluation results. # Gets evaluation results.
_, acc = model.evaluate(validation_data) metrics = model.evaluate(validation_data)
print('Eval accuracy: %f' % acc) print('Eval accuracy: %f' % metrics[1])
model.export_model(quantization_config=quantization_config) model.export_model(quantization_config=quantization_config)
model.export_labels(export_dir=options.hparams.export_dir) model.export_labels(export_dir=options.hparams.export_dir)
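
Context for the metrics[1] indexing above: Keras model.evaluate returns the scalar loss followed by the compiled metrics in order, so index 1 is the first metric (accuracy here). A toy standalone check with an invented model and data:

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(2, input_shape=(4,), activation="softmax")]
    )
    model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    x = np.random.rand(8, 4).astype("float32")
    y = np.random.randint(0, 2, size=(8,))
    results = model.evaluate(x, y, verbose=0)
    print(results)  # [loss, accuracy] -> results[1] is accuracy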

View File

@ -16,17 +16,17 @@ import csv
import filecmp import filecmp
import os import os
import tempfile import tempfile
import unittest
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.text import text_classifier from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@unittest.skip('b/275624089') class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = ( _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
test_utils.get_test_data_path('average_word_embedding_metadata.json')) test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifier.create(train_data, validation_data, text_classifier.TextClassifier.create(train_data, validation_data,
options)) options))
_, accuracy = average_word_embedding_classifier.evaluate(validation_data) metrics = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0) self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
# Test export_model # Test export_model
average_word_embedding_classifier.export_model() average_word_embedding_classifier.export_model()
@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
filecmp.cmp( filecmp.cmp(
output_metadata_file, output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE, self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
shallow=False)) shallow=False,
)
)
def test_create_and_train_bert(self): @parameterized.named_parameters(
# Skipping mobilebert because of OSS test timeout/flakiness: b/275624089
# dict(
# testcase_name='mobilebert',
# supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
# ),
dict(
testcase_name='exbert',
supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
),
)
def test_create_and_train_bert(self, supported_model):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
options = text_classifier.TextClassifierOptions( options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, supported_model=supported_model,
model_options=text_classifier.BertModelOptions( model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2 do_fine_tuning=False, seq_len=2
), ),
@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
bert_classifier = text_classifier.TextClassifier.create( bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options) train_data, validation_data, options)
_, accuracy = bert_classifier.evaluate(validation_data) metrics = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0) self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
# Test export_model # Test export_model
bert_classifier.export_model() bert_classifier.export_model()
@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
) )
def test_label_mismatch(self): def test_label_mismatch(self):
options = ( options = text_classifier.TextClassifierOptions(
text_classifier.TextClassifierOptions( supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
supported_model=( )
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo']) train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar']) validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
'Training data label names .* not equal to validation data label names' 'Training data label names .* not equal to validation data label names',
): ):
text_classifier.TextClassifier.create(train_data, validation_data, text_classifier.TextClassifier.create(
options) train_data, validation_data, options
)
def test_options_mismatch(self): def test_options_mismatch(self):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
avg_options = ( avg_options = text_classifier.TextClassifierOptions(
text_classifier.TextClassifierOptions( supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
supported_model=( model_options=text_classifier.AverageWordEmbeddingModelOptions(),
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), )
model_options=text_classifier.AverageWordEmbeddingModelOptions())) with self.assertRaisesWithLiteralMatch(
with self.assertRaisesRegex( ValueError,
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'): ' SupportedModels.EXBERT_CLASSIFIER',
text_classifier.TextClassifier.create(train_data, validation_data, ):
avg_options) text_classifier.TextClassifier.create(
train_data, validation_data, avg_options
)
bert_options = ( bert_options = text_classifier.TextClassifierOptions(
text_classifier.TextClassifierOptions( supported_model=(
supported_model=(text_classifier.SupportedModels text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
.AVERAGE_WORD_EMBEDDING_CLASSIFIER), ),
model_options=text_classifier.BertModelOptions())) model_options=text_classifier.BertModelOptions(),
with self.assertRaisesRegex( )
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' with self.assertRaisesWithLiteralMatch(
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'): ValueError,
text_classifier.TextClassifier.create(train_data, validation_data, 'Expected a Bert Classifier(MobileBERT or EXBERT), got'
bert_options) ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
):
text_classifier.TextClassifier.create(
train_data, validation_data, bert_options
)
def test_bert_loss_and_metrics_creation(self):
train_data, validation_data = self._get_data()
supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
hparams = text_classifier.BertHParams(
desired_recalls=[0.2],
desired_precisions=[0.9],
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off',
gamma=3.5,
)
options = text_classifier.TextClassifierOptions(
supported_model=supported_model, hparams=hparams
)
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options
)
loss_fn = bert_classifier._loss_function
self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
self.assertEqual(loss_fn._gamma, 3.5)
self.assertEqual(loss_fn._num_classes, 2)
metric_names = [m.name for m in bert_classifier._metric_functions]
expected_metric_names = [
'accuracy',
'recall',
'precision',
'precision_at_recall_0.2',
'recall_at_precision_0.9',
]
self.assertCountEqual(metric_names, expected_metric_names)
# Non-binary data
tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
with self.assertRaisesWithLiteralMatch(
ValueError,
'desired_recalls and desired_precisions parameters are binary metrics'
' and not supported for num_classes > 2. Found num_classes: 3',
):
text_classifier.TextClassifier.create(data, data, options)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset):
', '.join(label_names), ', '.join(label_names),
) )
return Dataset( return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names dataset=image_label_ds,
label_names=label_names,
size=all_image_size,
) )

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict binary and library compatibility macro.
licenses(["notice"]) licenses(["notice"])

View File

@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset):
len(valid_hand_data), len(label_names), ','.join(label_names))) len(valid_hand_data), len(label_names), ','.join(label_names)))
return Dataset( return Dataset(
dataset=hand_embedding_label_ds, dataset=hand_embedding_label_ds,
label_names=label_names,
size=len(valid_hand_data), size=len(valid_hand_data),
label_names=label_names) )

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python library rule. # Placeholder for internal Python library rule.
licenses(["notice"]) licenses(["notice"])

View File

@ -15,28 +15,12 @@
import os import os
import random import random
from typing import List, Optional
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds
from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils from mediapipe.model_maker.python.vision.core import image_utils
def _create_data(
name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
label_names: List[str]
) -> Optional[classification_dataset.ClassificationDataset]:
"""Creates a Dataset object from tfds data."""
if name not in data:
return None
data = data[name]
data = data.map(lambda a: (a['image'], a['label']))
size = info.splits[name].num_examples
return Dataset(data, size, label_names)
class Dataset(classification_dataset.ClassificationDataset): class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for image classifier.""" """Dataset library for image classifier."""
@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names)) all_label_size, ', '.join(label_names))
return Dataset( return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names) dataset=image_label_ds, label_names=label_names, size=all_image_size
)

View File

@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase):
def test_split(self): def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg']) data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
train_data, test_data = data.split(fraction=0.5) train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2) self.assertLen(train_data, 2)

View File

@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
ds = tf.data.Dataset.from_generator( ds = tf.data.Dataset.from_generator(
self._gen, (tf.uint8, tf.int64), (tf.TensorShape( self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
[self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([]))) [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3, data = image_classifier.Dataset(
['cyan', 'magenta', 'yellow']) ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3
)
return data return data
def setUp(self): def setUp(self):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict test compatibility macro.
licenses(["notice"]) licenses(["notice"])
@ -54,6 +54,7 @@ py_library(
srcs = ["dataset.py"], srcs = ["dataset.py"],
deps = [ deps = [
":dataset_util", ":dataset_util",
"//mediapipe/model_maker/python/core/data:cache_files",
"//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/data:classification_dataset",
], ],
) )
@ -73,6 +74,7 @@ py_test(
py_library( py_library(
name = "dataset_util", name = "dataset_util",
srcs = ["dataset_util.py"], srcs = ["dataset_util.py"],
deps = ["//mediapipe/model_maker/python/core/data:cache_files"],
) )
py_test( py_test(

View File

@ -16,8 +16,8 @@
from typing import Optional from typing import Optional
import tensorflow as tf import tensorflow as tf
import yaml
from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.object_detector import dataset_util from mediapipe.model_maker.python.vision.object_detector import dataset_util
from official.vision.dataloaders import tf_example_decoder from official.vision.dataloaders import tf_example_decoder
@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset):
ValueError: If the label_name for id 0 is set to something other than ValueError: If the label_name for id 0 is set to something other than
the 'background' class. the 'background' class.
""" """
cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir) tfrecord_cache_files = dataset_util.get_cache_files_coco(
if not dataset_util.is_cached(cache_files): data_dir, cache_dir
)
if not tfrecord_cache_files.is_cached():
label_map = dataset_util.get_label_map_coco(data_dir) label_map = dataset_util.get_label_map_coco(data_dir)
cache_writer = dataset_util.COCOCacheFilesWriter( cache_writer = dataset_util.COCOCacheFilesWriter(
label_map=label_map, max_num_images=max_num_images label_map=label_map, max_num_images=max_num_images
) )
cache_writer.write_files(cache_files, data_dir) cache_writer.write_files(tfrecord_cache_files, data_dir)
return cls.from_cache(cache_files.cache_prefix) return cls.from_cache(tfrecord_cache_files)
@classmethod @classmethod
def from_pascal_voc_folder( def from_pascal_voc_folder(
@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset):
Raises: Raises:
ValueError: if the input data directory is empty. ValueError: if the input data directory is empty.
""" """
cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir) tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc(
if not dataset_util.is_cached(cache_files): data_dir, cache_dir
)
if not tfrecord_cache_files.is_cached():
label_map = dataset_util.get_label_map_pascal_voc(data_dir) label_map = dataset_util.get_label_map_pascal_voc(data_dir)
cache_writer = dataset_util.PascalVocCacheFilesWriter( cache_writer = dataset_util.PascalVocCacheFilesWriter(
label_map=label_map, max_num_images=max_num_images label_map=label_map, max_num_images=max_num_images
) )
cache_writer.write_files(cache_files, data_dir) cache_writer.write_files(tfrecord_cache_files, data_dir)
return cls.from_cache(cache_files.cache_prefix) return cls.from_cache(tfrecord_cache_files)
@classmethod @classmethod
def from_cache(cls, cache_prefix: str) -> 'Dataset': def from_cache(
cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles
) -> 'Dataset':
"""Loads the TFRecord data from cache. """Loads the TFRecord data from cache.
Args: Args:
cache_prefix: The cache prefix including the cache directory and the cache tfrecord_cache_files: The TFRecordCacheFiles object containing the already
prefix filename, e.g: '/tmp/cache/train'. cached TFRecord and metadata files.
Returns: Returns:
ObjectDetectorDataset object. ObjectDetectorDataset object.
Raises:
ValueError: If tfrecord_cache_files is not already cached.
""" """
# Get TFRecord Files if not tfrecord_cache_files.is_cached():
tfrecord_file_pattern = cache_prefix + '*.tfrecord' raise ValueError(
matched_files = tf.io.gfile.glob(tfrecord_file_pattern) 'Cache files must already be cached to use the from_cache method.'
if not matched_files: )
raise ValueError('TFRecord files are empty.')
# Load meta_data. metadata = tfrecord_cache_files.load_metadata()
meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
if not tf.io.gfile.exists(meta_data_file):
raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
with tf.io.gfile.GFile(meta_data_file, 'r') as f:
meta_data = yaml.load(f, Loader=yaml.FullLoader)
dataset = tf.data.TFRecordDataset(matched_files) dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files)
decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False) decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE) dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)
label_map = meta_data['label_map'] label_map = metadata['label_map']
label_names = [label_map[k] for k in sorted(label_map.keys())] label_names = [label_map[k] for k in sorted(label_map.keys())]
return Dataset( return Dataset(
dataset=dataset, size=meta_data['size'], label_names=label_names dataset=dataset, label_names=label_names, size=metadata['size']
) )
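
A hypothetical end-to-end call flow for the refactored loader, based only on the signatures shown in this diff; the data and cache paths are placeholders:

    from mediapipe.model_maker.python.vision.object_detector import dataset
    from mediapipe.model_maker.python.vision.object_detector import dataset_util

    cache = dataset_util.get_cache_files_coco("path/to/coco_data", "/tmp/od_cache")
    if cache.is_cached():
        # from_cache now takes the TFRecordCacheFiles object, not a prefix string.
        data = dataset.Dataset.from_cache(cache)
        print(data.label_names)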

View File

@ -15,25 +15,20 @@
import abc import abc
import collections import collections
import dataclasses
import hashlib import hashlib
import json import json
import math import math
import os import os
import tempfile import tempfile
from typing import Any, Dict, List, Mapping, Optional, Sequence from typing import Any, Dict, List, Mapping, Optional
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import tensorflow as tf import tensorflow as tf
import yaml
from mediapipe.model_maker.python.core.data import cache_files
from official.vision.data import tfrecord_lib from official.vision.data import tfrecord_lib
# Suffix of the meta data file name.
META_DATA_FILE_SUFFIX = '_meta_data.yaml'
def _xml_get(node: ET.Element, name: str) -> ET.Element: def _xml_get(node: ET.Element, name: str) -> ET.Element:
"""Gets a named child from an XML Element node. """Gets a named child from an XML Element node.
@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str:
return os.path.basename(os.path.abspath(data_dir)) return os.path.basename(os.path.abspath(data_dir))
@dataclasses.dataclass(frozen=True)
class CacheFiles:
"""Cache files for object detection."""
cache_prefix: str
tfrecord_files: Sequence[str]
meta_data_file: str
def _get_cache_files( def _get_cache_files(
cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10 cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
) -> CacheFiles: ) -> cache_files.TFRecordCacheFiles:
"""Creates an object of CacheFiles class. """Creates an object of CacheFiles class.
Args: Args:
@ -96,28 +82,16 @@ def _get_cache_files(
An object of CacheFiles class. An object of TFRecordCacheFiles.
""" """
cache_dir = _get_cache_dir_or_create(cache_dir) cache_dir = _get_cache_dir_or_create(cache_dir)
# The cache prefix including the cache directory and the cache prefix return cache_files.TFRecordCacheFiles(
# filename, e.g: '/tmp/cache/train'. cache_prefix_filename=cache_prefix_filename,
cache_prefix = os.path.join(cache_dir, cache_prefix_filename) cache_dir=cache_dir,
tf.compat.v1.logging.info( num_shards=num_shards,
'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
% (cache_dir, cache_prefix_filename, cache_prefix)
)
# Cached files including the TFRecord files and the meta data file.
tfrecord_files = [
cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
for i in range(num_shards)
]
meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
return CacheFiles(
cache_prefix=cache_prefix,
tfrecord_files=tuple(tfrecord_files),
meta_data_file=meta_data_file,
) )
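
cache_files.TFRecordCacheFiles itself is not part of this diff. Judging from its call sites here (cache_prefix, tfrecord_files, metadata_file, is_cached, get_writers, save_metadata, load_metadata) and the filenames asserted in the tests, an interface along these lines is implied; the sketch below is a reconstruction under those assumptions, not the shipped class:

    import dataclasses
    import os

    import tensorflow as tf
    import yaml

    @dataclasses.dataclass(frozen=True)
    class TFRecordCacheFilesSketch:
        cache_prefix_filename: str
        cache_dir: str
        num_shards: int = 10

        @property
        def cache_prefix(self) -> str:
            return os.path.join(self.cache_dir, self.cache_prefix_filename)

        @property
        def tfrecord_files(self):
            # Matches the shard naming asserted in dataset_util_test.
            return [
                f"{self.cache_prefix}-{i:05d}-of-{self.num_shards:05d}.tfrecord"
                for i in range(self.num_shards)
            ]

        @property
        def metadata_file(self) -> str:
            return self.cache_prefix + "_metadata.yaml"

        def is_cached(self) -> bool:
            files = list(self.tfrecord_files) + [self.metadata_file]
            return all(tf.io.gfile.exists(path) for path in files)

        def get_writers(self):
            return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]

        def save_metadata(self, metadata) -> None:
            with tf.io.gfile.GFile(self.metadata_file, "w") as f:
                yaml.dump(metadata, f)

        def load_metadata(self):
            with tf.io.gfile.GFile(self.metadata_file, "r") as f:
                return yaml.load(f, Loader=yaml.FullLoader)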
def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles: def get_cache_files_coco(
data_dir: str, cache_dir: str
) -> cache_files.TFRecordCacheFiles:
"""Creates an object of CacheFiles class using a COCO formatted dataset. """Creates an object of CacheFiles class using a COCO formatted dataset.
Args: Args:
@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles: def get_cache_files_pascal_voc(
data_dir: str, cache_dir: str
) -> cache_files.TFRecordCacheFiles:
"""Gets an object of CacheFiles using a PASCAL VOC formatted dataset. """Gets an object of CacheFiles using a PASCAL VOC formatted dataset.
Args: Args:
@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
def is_cached(cache_files: CacheFiles) -> bool:
"""Checks whether cache files are already cached."""
all_cached_files = list(cache_files.tfrecord_files) + [
cache_files.meta_data_file
]
return all(tf.io.gfile.exists(path) for path in all_cached_files)
class CacheFilesWriter(abc.ABC): class CacheFilesWriter(abc.ABC):
"""CacheFilesWriter class to write the cached files.""" """CacheFilesWriter class to write the cached files."""
@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC):
self.label_map = label_map self.label_map = label_map
self.max_num_images = max_num_images self.max_num_images = max_num_images
def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None: def write_files(
"""Writes TFRecord and meta_data files. self,
tfrecord_cache_files: cache_files.TFRecordCacheFiles,
*args,
**kwargs,
) -> None:
"""Writes TFRecord and metadata files.
Args: Args:
cache_files: CacheFiles object including a list of TFRecord files and the tfrecord_cache_files: TFRecordCacheFiles object including the list of
meta_data yaml file to save the meta_data including data size and TFRecord files and the metadata yaml file used to save the metadata,
label_map. i.e. the data size and label_map.
*args: Non-keyword of parameters used in the `_get_example` method. *args: Non-keyword of parameters used in the `_get_example` method.
**kwargs: Keyword parameters used in the `_get_example` method. **kwargs: Keyword parameters used in the `_get_example` method.
""" """
writers = [ writers = tfrecord_cache_files.get_writers()
tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
]
# Writes tf.Example into TFRecord files. # Writes tf.Example into TFRecord files.
size = 0 size = 0
@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC):
for writer in writers: for writer in writers:
writer.close() writer.close()
# Writes meta_data into meta_data_file. # Writes metadata into metadata_file.
meta_data = {'size': size, 'label_map': self.label_map} metadata = {'size': size, 'label_map': self.label_map}
with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f: tfrecord_cache_files.save_metadata(metadata)
yaml.dump(meta_data, f)
@abc.abstractmethod @abc.abstractmethod
def _get_example(self, *args, **kwargs): def _get_example(self, *args, **kwargs):

View File

@ -19,7 +19,6 @@ import shutil
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
import tensorflow as tf import tensorflow as tf
import yaml
from mediapipe.model_maker.python.vision.core import test_utils from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import dataset_util from mediapipe.model_maker.python.vision.object_detector import dataset_util
@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase):
def _assert_cache_files_equal(self, cf1, cf2): def _assert_cache_files_equal(self, cf1, cf2):
self.assertEqual(cf1.cache_prefix, cf2.cache_prefix) self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files) self.assertEqual(cf1.num_shards, cf2.num_shards)
self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)
def _assert_cache_files_not_equal(self, cf1, cf2): def _assert_cache_files_not_equal(self, cf1, cf2):
self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix) self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)
def _get_cache_files_and_assert_neq_fn(self, cache_files_fn): def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
def get_cache_files_and_assert_neq(cf, data_dir, cache_dir): def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertEqual( self.assertEqual(
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
) )
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
def test_matching_get_cache_files_coco(self): def test_matching_get_cache_files_coco(self):
cache_dir = self.create_tempdir() cache_dir = self.create_tempdir()
@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertEqual( self.assertEqual(
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
) )
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
def test_matching_get_cache_files_pascal_voc(self): def test_matching_get_cache_files_pascal_voc(self):
cache_dir = self.create_tempdir() cache_dir = self.create_tempdir()
@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase):
cache_files = dataset_util.get_cache_files_coco( cache_files = dataset_util.get_cache_files_coco(
tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
) )
self.assertFalse(dataset_util.is_cached(cache_files)) self.assertFalse(cache_files.is_cached())
with open(cache_files.tfrecord_files[0], 'w') as f: with open(cache_files.tfrecord_files[0], 'w') as f:
f.write('test') f.write('test')
self.assertFalse(dataset_util.is_cached(cache_files)) self.assertFalse(cache_files.is_cached())
with open(cache_files.meta_data_file, 'w') as f: with open(cache_files.metadata_file, 'w') as f:
f.write('test') f.write('test')
self.assertTrue(dataset_util.is_cached(cache_files)) self.assertTrue(cache_files.is_cached())
def test_get_label_map_coco(self): def test_get_label_map_coco(self):
coco_dir = tasks_test_utils.get_test_data_path('coco_data') coco_dir = tasks_test_utils.get_test_data_path('coco_data')
@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0])) self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0) self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)
# Checks the meta_data file # Checks the metadata file
self.assertTrue(os.path.isfile(cache_files.meta_data_file)) self.assertTrue(os.path.isfile(cache_files.metadata_file))
self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0) self.assertGreater(os.path.getsize(cache_files.metadata_file), 0)
with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f: metadata_dict = cache_files.load_metadata()
meta_data_dict = yaml.load(f, Loader=yaml.FullLoader) self.assertEqual(metadata_dict['size'], expected_size)
# Size is 3 because some examples are skipped for having poor bboxes
self.assertEqual(meta_data_dict['size'], expected_size)
def test_coco_cache_files_writer(self): def test_coco_cache_files_writer(self):
tempdir = self.create_tempdir() tempdir = self.create_tempdir()

View File

@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model):
generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(), generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
) -> configs.retinanet.RetinaNet: ) -> configs.retinanet.RetinaNet:
model_config = configs.retinanet.RetinaNet( model_config = configs.retinanet.RetinaNet(
min_level=3, min_level=self._model_spec.min_level,
max_level=7, max_level=self._model_spec.max_level,
num_classes=self._num_classes, num_classes=self._num_classes,
input_size=self._model_spec.input_image_shape, input_size=self._model_spec.input_image_shape,
anchor=configs.retinanet.Anchor( anchor=configs.retinanet.Anchor(

Some files were not shown because too many files have changed in this diff