Merge branch 'master' into nguyencse/facemeshioslib

commit 1001ead358
WORKSPACE (14 lines changed)

@@ -157,22 +157,22 @@ http_archive(
# 2020-08-21
http_archive(
name = "com_github_glog_glog",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
"https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
],
)
http_archive(
name = "com_github_glog_glog_no_gflags",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
build_file = "@//third_party:glog_no_gflags.BUILD",
urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
"https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
],
patches = [
"@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff",
"@//third_party:com_github_glog_glog.diff",
],
patch_args = [
"-p1",
mediapipe/BUILD (151 lines changed)

@@ -68,30 +68,108 @@ config_setting(
visibility = ["//visibility:public"],
)

# Note: this cannot just match "apple_platform_type": "macos" because that option
# defaults to "macos" even when building on Linux!
alias(
# Generic MacOS.
config_setting(
name = "macos",
actual = select({
":macos_i386": ":macos_i386",
":macos_x86_64": ":macos_x86_64",
":macos_arm64": ":macos_arm64",
"//conditions:default": ":macos_i386",  # Arbitrarily chosen from above.
}),
constraint_values = [
"@platforms//os:macos",
],
visibility = ["//visibility:public"],
)

# Note: this also matches on crosstool_top so that it does not produce ambiguous
# selectors when used together with "android".
# MacOS x86 64-bit.
config_setting(
name = "macos_x86_64",
constraint_values = [
"@platforms//os:macos",
"@platforms//cpu:x86_64",
],
visibility = ["//visibility:public"],
)

# MacOS ARM64.
config_setting(
name = "macos_arm64",
constraint_values = [
"@platforms//os:macos",
"@platforms//cpu:arm64",
],
visibility = ["//visibility:public"],
)

# Generic iOS.
config_setting(
name = "ios",
values = {
"crosstool_top": "@bazel_tools//tools/cpp:toolchain",
"apple_platform_type": "ios",
},
constraint_values = [
"@platforms//os:ios",
],
visibility = ["//visibility:public"],
)

# iOS device ARM32.
config_setting(
name = "ios_armv7",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm",
],
visibility = ["//visibility:public"],
)

# iOS device ARM64.
config_setting(
name = "ios_arm64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64",
],
visibility = ["//visibility:public"],
)

# iOS device ARM64E.
config_setting(
name = "ios_arm64e",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64e",
],
visibility = ["//visibility:public"],
)

# iOS simulator x86 32-bit.
config_setting(
name = "ios_i386",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:x86_32",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)

# iOS simulator x86 64-bit.
config_setting(
name = "ios_x86_64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:x86_64",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)

# iOS simulator ARM64.
config_setting(
name = "ios_sim_arm64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)

# Generic Apple.
alias(
name = "apple",
actual = select({

@@ -102,49 +180,6 @@ alias(
visibility = ["//visibility:public"],
)

config_setting(
name = "macos_i386",
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)

config_setting(
name = "macos_x86_64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_x86_64",
},
visibility = ["//visibility:public"],
)

config_setting(
name = "macos_arm64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_arm64",
},
visibility = ["//visibility:public"],
)

[
config_setting(
name = arch,
values = {"cpu": arch},
visibility = ["//visibility:public"],
)
for arch in [
"ios_i386",
"ios_x86_64",
"ios_armv7",
"ios_arm64",
"ios_arm64e",
"ios_sim_arm64",
]
]

config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
@@ -210,6 +210,23 @@ REGISTER_CALCULATOR(SpectrogramCalculator);
// Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;

namespace {
std::unique_ptr<audio_dsp::WindowFunction> MakeWindowFun(
const SpectrogramCalculatorOptions::WindowType window_type) {
switch (window_type) {
// The cosine window and square root of Hann are equivalent.
case SpectrogramCalculatorOptions::COSINE:
case SpectrogramCalculatorOptions::SQRT_HANN:
return std::make_unique<audio_dsp::CosineWindow>();
case SpectrogramCalculatorOptions::HANN:
return std::make_unique<audio_dsp::HannWindow>();
case SpectrogramCalculatorOptions::HAMMING:
return std::make_unique<audio_dsp::HammingWindow>();
}
return nullptr;
}
} // namespace

absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
SpectrogramCalculatorOptions spectrogram_options =
cc->Options<SpectrogramCalculatorOptions>();

@@ -266,28 +283,14 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {

output_scale_ = spectrogram_options.output_scale();

auto window_fun = MakeWindowFun(spectrogram_options.window_type());
if (window_fun == nullptr) {
return absl::Status(absl::StatusCode::kInvalidArgument,
absl::StrCat("Invalid window type ",
spectrogram_options.window_type()));
}
std::vector<double> window;
switch (spectrogram_options.window_type()) {
case SpectrogramCalculatorOptions::COSINE:
audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::HANN:
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::HAMMING:
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::SQRT_HANN: {
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
absl::c_transform(window, window.begin(),
[](double x) { return std::sqrt(x); });
break;
}
}
window_fun->GetPeriodicSamples(frame_duration_samples_, &window);

// Propagate settings down to the actual Spectrogram object.
spectrogram_generators_.clear();
@@ -68,7 +68,7 @@ message SpectrogramCalculatorOptions {
HANN = 0;
HAMMING = 1;
COSINE = 2;
SQRT_HANN = 4;
SQRT_HANN = 4; // Alias of COSINE.
}
optional WindowType window_type = 6 [default = HANN];
@@ -381,17 +381,6 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "clip_detection_vector_size_calculator",
srcs = ["clip_detection_vector_size_calculator.cc"],
deps = [
":clip_vector_size_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
],
alwayslink = 1,
)

cc_test(
name = "clip_vector_size_calculator_test",
srcs = ["clip_vector_size_calculator_test.cc"],
@@ -76,4 +76,9 @@ REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator);
// A calculator to process std::vector<mediapipe::Image>.
typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator;
REGISTER_CALCULATOR(BeginLoopImageCalculator);

// A calculator to process std::vector<float>.
typedef BeginLoopCalculator<std::vector<float>> BeginLoopFloatCalculator;
REGISTER_CALCULATOR(BeginLoopFloatCalculator);

} // namespace mediapipe
@@ -123,8 +123,11 @@ class PreviousLoopbackCalculator : public Node {
// However, LOOP packet is empty.
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
} else {
// Avoids sending leftovers to a stream that's already closed.
if (!kPrevLoop(cc).IsClosed()) {
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
}
}
loop_packets_.pop_front();
main_packet_specs_.pop_front();
}
@@ -135,7 +135,6 @@ cc_library(
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status",
],
@@ -112,7 +112,7 @@ class BilateralFilterCalculator : public CalculatorBase {
REGISTER_CALCULATOR(BilateralFilterCalculator);

absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

if (cc->Inputs().HasTag(kInputFrameTag) &&
cc->Inputs().HasTag(kInputFrameTagGpu)) {
@@ -110,7 +110,7 @@ REGISTER_CALCULATOR(SegmentationSmoothingCalculator);

absl::Status SegmentationSmoothingCalculator::GetContract(
CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

cc->Inputs().Tag(kCurrentMaskTag).Set<Image>();
cc->Inputs().Tag(kPreviousMaskTag).Set<Image>();
@@ -142,7 +142,7 @@ class SetAlphaCalculator : public CalculatorBase {
REGISTER_CALCULATOR(SetAlphaCalculator);

absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

bool use_gpu = false;
@@ -38,7 +38,7 @@ std::string FourCCToString(libyuv::FourCC fourcc) {
buf[0] = (fourcc >> 24) & 0xff;
buf[1] = (fourcc >> 16) & 0xff;
buf[2] = (fourcc >> 8) & 0xff;
buf[3] = (fourcc)&0xff;
buf[3] = (fourcc) & 0xff;
buf[4] = 0;
return std::string(buf);
}
@@ -282,13 +282,17 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
if (options.has_volume_gain_db()) {
gain_ = pow(10, options.volume_gain_db() / 20.0);
}
if (options.has_source_sample_rate()) {
source_sample_rate_ = options.source_sample_rate();
} else {
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
!kAudioIn(cc).Header().IsEmpty())
<< "Must either specify the time series header of the \"AUDIO\" stream "
"or have the \"SAMPLE_RATE\" stream connected.";
if (!kAudioIn(cc).Header().IsEmpty()) {
mediapipe::TimeSeriesHeader input_header;
MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
MP_RETURN_IF_ERROR(
mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
kAudioIn(cc).Header(), &input_header));
if (stream_mode_) {
MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));

@@ -296,6 +300,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
source_sample_rate_ = input_header.sample_rate();
}
}
}
AppendZerosToSampleBuffer(padding_samples_before_);
if (options.has_fft_size()) {
RET_CHECK(IsValidFftSize(options.fft_size()))
@@ -85,4 +85,7 @@ message AudioToTensorCalculatorOptions {
// The volume gain, measured in dB.
// Scale the input audio amplitude by 10^(volume_gain_db/20).
optional double volume_gain_db = 12;

// The source number of samples per second (hertz) of the input audio buffers.
optional double source_sample_rate = 13;
}
@@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl
gpu_delegate_options);
absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
bool UseSerializedModel() const { return use_serialized_model_; }

private:
bool use_kernel_caching_ = false;

@@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
}

absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
MP_RETURN_IF_ERROR(
on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
return gpu_helper_.RunInGlContext([this]() -> absl::Status {
tflite_gpu_runner_.reset();
return absl::OkStatus();

@@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
tflite_gpu_runner_->GetOutputShapes()[i].c};
}

if (on_disk_cache_helper_.UseSerializedModel()) {
tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
}

MP_RETURN_IF_ERROR(
on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
return tflite_gpu_runner_->Build();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
}

#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
@@ -256,6 +256,7 @@ class TensorsToDetectionsCalculator : public Node {

bool gpu_inited_ = false;
bool gpu_input_ = false;
bool gpu_has_enough_work_groups_ = true;
bool anchors_init_ = false;
};
MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);

@@ -291,7 +292,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
auto output_detections = absl::make_unique<std::vector<Detection>>();
bool gpu_processing = false;
if (CanUseGpu()) {
if (CanUseGpu() && gpu_has_enough_work_groups_) {
// Use GPU processing only if at least one input tensor is already on GPU
// (to avoid CPU->GPU overhead).
for (const auto& tensor : *kInTensors(cc)) {

@@ -321,11 +322,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
RET_CHECK(!has_custom_box_indices_);
}

if (gpu_processing) {
if (!gpu_inited_) {
MP_RETURN_IF_ERROR(GpuInit(cc));
if (gpu_processing && !gpu_inited_) {
auto status = GpuInit(cc);
if (status.ok()) {
gpu_inited_ = true;
} else if (status.code() == absl::StatusCode::kFailedPrecondition) {
// For initialization error because of hardware limitation, fallback to
// CPU processing.
LOG(WARNING) << status.message();
} else {
// For other error, let the error propagates.
return status;
}
}
if (gpu_processing && gpu_inited_) {
MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
} else {
MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));

@@ -346,17 +356,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// TODO: Add flexible input tensor size handling.
auto raw_box_tensor =
&input_tensors[tensor_mapping_.detections_tensor_index()];
RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
if (raw_box_tensor->shape().dims.size() == 3) {
// The tensors from CPU inference has dim 3.
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
} else if (raw_box_tensor->shape().dims.size() == 4) {
// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
// we allow tensors with 4 dims.
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_);
} else {
return absl::InvalidArgumentError(
"The dimensions of box Tensor must be 3 or 4.");
}
auto raw_score_tensor =
&input_tensors[tensor_mapping_.scores_tensor_index()];
RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
if (raw_score_tensor->shape().dims.size() == 3) {
// The tensors from CPU inference has dim 3.
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
} else if (raw_score_tensor->shape().dims.size() == 4) {
// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
// we allow tensors with 4 dims.
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_);
RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_);
} else {
return absl::InvalidArgumentError(
"The dimensions of score Tensor must be 3 or 4.");
}
auto raw_box_view = raw_box_tensor->GetCpuReadView();
auto raw_boxes = raw_box_view.buffer<float>();
auto raw_scores_view = raw_score_tensor->GetCpuReadView();

@@ -1111,8 +1145,13 @@ void main() {
int max_wg_size; // typically <= 1024
glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
&max_wg_size); // y-dim
CHECK_LT(num_classes_, max_wg_size)
<< "# classes must be < " << max_wg_size;
gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
if (!gpu_has_enough_work_groups_) {
return absl::FailedPreconditionError(absl::StrFormat(
"Hardware limitation: Processing will be done on CPU, because "
"num_classes %d exceeds the max work_group size %d.",
num_classes_, max_wg_size));
}
// TODO support better filtering.
if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(),

@@ -1370,7 +1409,13 @@ kernel void scoreKernel(
Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2});
// # filter classes supported is hardware dependent.
int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup;
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
if (!gpu_has_enough_work_groups_) {
return absl::FailedPreconditionError(absl::StrFormat(
"Hardware limitation: Processing will be done on CPU, because "
"num_classes %d exceeds the max work_group size %d.",
num_classes_, max_wg_size));
}
}

#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
@@ -406,8 +406,13 @@ cc_library(
alwayslink = 1,
)

# This dependency removed tensorflow_jellyfish_deps and xprofilez_with_server because they failed
# Boq conformance test. Weigh your use case to see if this will work for you.
# This dependency removed the following 3 targets because they failed Boq conformance test:
#
# tensorflow_jellyfish_deps
# jfprof_lib
# xprofilez_with_server
#
# If you need them plz consider tensorflow_inference_calculator_no_envelope_loader.
cc_library(
name = "tensorflow_inference_calculator_for_boq",
srcs = ["tensorflow_inference_calculator.cc"],

@@ -927,7 +932,6 @@ cc_test(
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/formats:location_opencv",
"//mediapipe/framework/port:gtest_main",
@@ -164,7 +164,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
}

CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag))
<< "Neither the output stream nor the output side packet is set to "
"output the sequence example.";
@@ -23,7 +23,6 @@
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/location_opencv.h"
#include "mediapipe/framework/port/gmock.h"

@@ -96,7 +95,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2);

@@ -139,7 +139,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2);

@@ -378,7 +379,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
Adopt(input_sequence.release());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
auto image_ptr =

@@ -410,7 +412,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {

cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);

@@ -618,7 +621,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
}
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(width);

@@ -767,7 +771,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {

cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);

@@ -813,7 +818,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string);

@@ -970,7 +976,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2);

@@ -1021,7 +1028,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size());
int height = 2;
@@ -172,7 +172,7 @@ class AnnotationOverlayCalculator : public CalculatorBase {
REGISTER_CALCULATOR(AnnotationOverlayCalculator);

absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1);
RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

bool use_gpu = false;

@@ -189,13 +189,13 @@ absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
#if !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kGpuBufferTag)) {
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
CHECK(cc->Outputs().HasTag(kGpuBufferTag));
RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag));
use_gpu = true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kImageFrameTag)) {
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
CHECK(cc->Outputs().HasTag(kImageFrameTag));
RET_CHECK(cc->Outputs().HasTag(kImageFrameTag));
}

// Data streams to render.
@@ -322,6 +322,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
options_.presence_threshold(), options_.connection_color(), thickness,
/*normalized=*/false, render_data.get());
}
if (options_.render_landmarks()) {
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const Landmark& landmark = landmarks.landmark(i);

@@ -335,7 +336,8 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
landmark_data_render,
options_.min_depth_circle_thickness(),
options_.max_depth_circle_thickness());
}

@@ -345,6 +347,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
landmark_data->set_y(landmark.y());
}
}
}

if (cc->Inputs().HasTag(kNormLandmarksTag)) {
const NormalizedLandmarkList& landmarks =

@@ -368,6 +371,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
options_.presence_threshold(), options_.connection_color(), thickness,
/*normalized=*/true, render_data.get());
}
if (options_.render_landmarks()) {
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const NormalizedLandmark& landmark = landmarks.landmark(i);

@@ -381,7 +385,8 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
landmark_data_render,
options_.min_depth_circle_thickness(),
options_.max_depth_circle_thickness());
}

@@ -391,6 +396,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
landmark_data->set_y(landmark.y());
}
}
}

cc->Outputs()
.Tag(kRenderDataTag)
@@ -32,6 +32,10 @@ message LandmarksToRenderDataCalculatorOptions {

// Color of the landmarks.
optional Color landmark_color = 2;

// Whether to render landmarks as points.
optional bool render_landmarks = 14 [default = true];

// Color of the connections.
optional Color connection_color = 3;
@@ -124,7 +124,7 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
int center_row = out_lms.landmark(lm_index).y() * hm_height;
// Point is outside of the image let's keep it intact.
if (center_col < 0 || center_col >= hm_width || center_row < 0 ||
center_col >= hm_height) {
center_row >= hm_height) {
continue;
}
@@ -130,7 +130,6 @@ cc_library(
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util",
],

@@ -341,7 +340,6 @@ cc_test(
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:test_util",
"@com_google_absl//absl/flags:flag",
],
)

@@ -367,7 +365,6 @@ cc_test(
"//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:test_util",
"@com_google_absl//absl/flags:flag",
],
)

@@ -451,7 +448,6 @@ cc_test(
"//mediapipe/framework/tool:test_util",
"//mediapipe/util/tracking:box_tracker_cc_proto",
"//mediapipe/util/tracking:tracking_cc_proto",
"@com_google_absl//absl/flags:flag",
],
)
Binary file not shown.

@@ -1,6 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "facedetectioncpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "facedetectiongpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "faceeffect",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "facemeshgpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "handdetectiongpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "handtrackinggpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "helloworld",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "holistictrackinggpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "iristrackinggpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "objectdetectioncpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "objectdetectiongpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "objectdetectiontrackinggpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "posetrackinggpu",

@@ -24,7 +24,7 @@ load(

licenses(["notice"])

MIN_IOS_VERSION = "11.0"
MIN_IOS_VERSION = "12.0"

alias(
name = "selfiesegmentationgpu",
@@ -44,6 +44,9 @@ bzl_library(
"encode_binary_proto.bzl",
],
visibility = ["//visibility:public"],
deps = [
"@bazel_skylib//lib:paths",
],
)

alias(
@@ -64,58 +64,13 @@ class CalculatorBaseFactoryFor<
namespace api2 {
namespace internal {

// Defining a member of this type causes P to be ODR-used, which forces its
// instantiation if it's a static member of a template.
// Previously we depended on the pointer's value to determine whether the size
// of a character array is 0 or 1, forcing it to be instantiated so the
// compiler can determine the object's layout. But using it as a template
// argument is more compact.
template <auto* P>
struct ForceStaticInstantiation {
#ifdef _MSC_VER
// Just having it as the template argument does not count as a use for
// MSVC.
static constexpr bool Use() { return P != nullptr; }
char force_static[Use()];
#endif // _MSC_VER
};
MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(
NodeRegistrator, mediapipe::CalculatorBaseRegistry, T::kCalculatorName,
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>)

// Helper template for forcing the definition of a static registration token.
template <typename T>
struct NodeRegistrationStatic {
static NoDestructor<mediapipe::RegistrationToken> registration;

static mediapipe::RegistrationToken Make() {
return mediapipe::CalculatorBaseRegistry::Register(
T::kCalculatorName,
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>,
__FILE__, __LINE__);
}

using RequireStatics = ForceStaticInstantiation<&registration>;
};

// Static members of template classes can be defined in the header.
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
NodeRegistrationStatic<T>::registration(NodeRegistrationStatic<T>::Make());

template <typename T>
struct SubgraphRegistrationImpl {
static NoDestructor<mediapipe::RegistrationToken> registration;

static mediapipe::RegistrationToken Make() {
return mediapipe::SubgraphRegistry::Register(
T::kCalculatorName, absl::make_unique<T>, __FILE__, __LINE__);
}

using RequireStatics = ForceStaticInstantiation<&registration>;
};

template <typename T>
NoDestructor<mediapipe::RegistrationToken>
SubgraphRegistrationImpl<T>::registration(
SubgraphRegistrationImpl<T>::Make());
MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(SubgraphRegistrator,
mediapipe::SubgraphRegistry,
T::kCalculatorName, absl::make_unique<T>)

} // namespace internal

@@ -128,14 +83,7 @@ template <class Impl = void>
class RegisteredNode;

template <class Impl>
class RegisteredNode : public Node {
private:
// The member below triggers instantiation of the registration static.
// Note that the constructor of calculator subclasses is only invoked through
// the registration token, and so we cannot simply use the static in the
// constructor.
typename internal::NodeRegistrationStatic<Impl>::RequireStatics register_;
};
class RegisteredNode : public Node, private internal::NodeRegistrator<Impl> {};

// No-op version for backwards compatibility.
template <>

@@ -217,20 +165,17 @@ class NodeImpl : public RegisteredNode<Impl>, public Intf {
// TODO: verify that the subgraph config fully implements the
// declared interface.
template <class Intf, class Impl>
class SubgraphImpl : public Subgraph, public Intf {
private:
typename internal::SubgraphRegistrationImpl<Impl>::RequireStatics register_;
};
class SubgraphImpl : public Subgraph,
public Intf,
private internal::SubgraphRegistrator<Impl> {};

// This macro is used to register a calculator that does not use automatic
// registration. Deprecated.
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(calculator_registration, \
__LINE__)(mediapipe::CalculatorBaseRegistry::Register( \
MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \
mediapipe::CalculatorBaseRegistry, calculator_registration, \
Impl::kCalculatorName, \
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>, \
__FILE__, __LINE__))
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>)

// This macro is used to register a non-split-contract calculator. Deprecated.
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)

@@ -238,10 +183,9 @@ class SubgraphImpl : public Subgraph, public Intf {
// This macro is used to define a subgraph that does not use automatic
// registration. Deprecated.
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(subgraph_registration, \
__LINE__)(mediapipe::SubgraphRegistry::Register( \
Impl::kCalculatorName, absl::make_unique<Impl>, __FILE__, __LINE__))
MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \
mediapipe::SubgraphRegistry, subgraph_registration, \
Impl::kCalculatorName, absl::make_unique<Impl>)

} // namespace api2
} // namespace mediapipe
@@ -183,8 +183,7 @@ TEST(CalculatorTest, CreateByNameWhitelisted) {
CalculatorBaseRegistry::Register(
"::mediapipe::test_ns::whitelisted_ns::DeadCalculator",
absl::make_unique<internal::CalculatorBaseFactoryFor<
mediapipe::test_ns::whitelisted_ns::DeadCalculator>>,
__FILE__, __LINE__);
mediapipe::test_ns::whitelisted_ns::DeadCalculator>>);

// A whitelisted calculator can be found in its own namespace.
MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( //
@@ -16,7 +16,6 @@
#define MEDIAPIPE_DEPS_REGISTRATION_H_

#include <algorithm>
#include <cstdint>
#include <functional>
#include <string>
#include <tuple>

@@ -145,6 +144,23 @@ template <typename T>
struct WrapStatusOr<absl::StatusOr<T>> {
using type = absl::StatusOr<T>;
};

// Defining a member of this type causes P to be ODR-used, which forces its
// instantiation if it's a static member of a template.
// Previously we depended on the pointer's value to determine whether the size
// of a character array is 0 or 1, forcing it to be instantiated so the
// compiler can determine the object's layout. But using it as a template
// argument is more compact.
template <auto* P>
struct ForceStaticInstantiation {
#ifdef _MSC_VER
// Just having it as the template argument does not count as a use for
// MSVC.
static constexpr bool Use() { return P != nullptr; }
char force_static[Use()];
#endif // _MSC_VER
};

} // namespace registration_internal

class NamespaceAllowlist {

@@ -162,8 +178,7 @@ class FunctionRegistry {
FunctionRegistry(const FunctionRegistry&) = delete;
FunctionRegistry& operator=(const FunctionRegistry&) = delete;

RegistrationToken Register(absl::string_view name, Function func,
std::string filename, uint64_t line)
RegistrationToken Register(absl::string_view name, Function func)
ABSL_LOCKS_EXCLUDED(lock_) {
std::string normalized_name = GetNormalizedName(name);
absl::WriterMutexLock lock(&lock_);

@@ -173,21 +188,10 @@ class FunctionRegistry {
}
if (functions_.insert(std::make_pair(normalized_name, std::move(func)))
.second) {
#ifndef NDEBUG
locations_.emplace(normalized_name,
std::make_pair(std::move(filename), line));
#endif
return RegistrationToken(
[this, normalized_name]() { Unregister(normalized_name); });
}
#ifndef NDEBUG
LOG(FATAL) << "Function with name " << name << " already registered."
<< " First registration at "
<< locations_.at(normalized_name).first << ":"
<< locations_.at(normalized_name).second;
#else
LOG(FATAL) << "Function with name " << name << " already registered.";
#endif
return RegistrationToken([]() {});
}

@@ -316,11 +320,6 @@ class FunctionRegistry {
private:
mutable absl::Mutex lock_;
absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
#ifndef NDEBUG
// Stores filename and line number for useful debug log.
absl::flat_hash_map<std::string, std::pair<std::string, uint32_t>> locations_
ABSL_GUARDED_BY(lock_);
#endif

// For names included in NamespaceAllowlist, strips the namespace.
std::string GetAdjustedName(absl::string_view name) {

@@ -351,10 +350,8 @@ class GlobalFactoryRegistry {

public:
static RegistrationToken Register(absl::string_view name,
typename Functions::Function func,
std::string filename, uint64_t line) {
return functions()->Register(name, std::move(func), std::move(filename),
line);
typename Functions::Function func) {
return functions()->Register(name, std::move(func));
}

// Invokes the specified factory function and returns the result.

@@ -414,12 +411,77 @@ class GlobalFactoryRegistry {
#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...) \
static auto* REGISTRY_STATIC_VAR(registration_##name, __LINE__) = \
new mediapipe::RegistrationToken( \
RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
RegistryType::Register(#name, __VA_ARGS__))

#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \
name, ...) \
static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \
new mediapipe::RegistrationToken( \
RegistryType::Register(name, __VA_ARGS__))

// TODO: migrate to the above.
#define REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, name, ...) \
static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \
new mediapipe::RegistrationToken( \
RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
RegistryType::Register(#name, __VA_ARGS__))

// Defines a utility registrator class which can be used to automatically
// register factory functions.
//
// Example:
// === Defining a registry ================================================
//
// class Component {};
//
// using ComponentRegistry = GlobalFactoryRegistry<std::unique_ptr<Component>>;
//
// === Defining a registrator =============================================
//
// MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(ComponentRegistrator,
//                                       ComponentRegistry, T::kName,
//                                       absl::make_unique<T>);
//
// === Defining and registering a new component. ==========================
//
// class MyComponent : public Component,
//                     private ComponentRegistrator<MyComponent> {
//  public:
//   static constexpr char kName[] = "MyComponent";
//   ...
// };
//
// NOTE:
// - MyComponent is automatically registered in ComponentRegistry by
//   "MyComponent" name.
// - Every component is require to provide its name (T::kName here.)
#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType, \
name, ...) \
template <typename T> \
struct Internal##RegistratorName { \
static NoDestructor<mediapipe::RegistrationToken> registration; \
\
static mediapipe::RegistrationToken Make() { \
return RegistryType::Register(name, __VA_ARGS__); \
} \
\
using RequireStatics = \
registration_internal::ForceStaticInstantiation<&registration>; \
}; \
/* Static members of template classes can be defined in the header. */ \
template <typename T> \
NoDestructor<mediapipe::RegistrationToken> \
Internal##RegistratorName<T>::registration( \
Internal##RegistratorName<T>::Make()); \
\
template <typename T> \
class RegistratorName { \
private: \
/* The member below triggers instantiation of the registration static. */ \
/* Note that the constructor of calculator subclasses is only invoked */ \
/* through the registration token, and so we cannot simply use the */ \
/* static in theconstructor. */ \
typename Internal##RegistratorName<T>::RequireStatics register_; \
};

} // namespace mediapipe
@@ -37,29 +37,33 @@ Args:
output: The desired name of the output file. Optional.
"""

load("@bazel_skylib//lib:paths.bzl", "paths")

PROTOC = "@com_google_protobuf//:protoc"

def _canonicalize_proto_path_oss(all_protos, genfile_path):
"""For the protos from external repository, canonicalize the proto path and the file name.
def _canonicalize_proto_path_oss(f):
if not f.root.path:
return struct(
proto_path = ".",
file_name = f.short_path,
)

Returns:
Proto path list and proto source file list.
"""
proto_paths = []
proto_file_names = []
for s in all_protos.to_list():
if s.path.startswith(genfile_path):
repo_name, _, file_name = s.path[len(genfile_path + "/external/"):].partition("/")

# handle virtual imports
if file_name.startswith("_virtual_imports"):
repo_name = repo_name + "/" + "/".join(file_name.split("/", 2)[:2])
# `f.path` looks like "<genfiles>/external/<repo>/(_virtual_imports/<library>/)?<file_name>"
repo_name, _, file_name = f.path[len(paths.join(f.root.path, "external") + "/"):].partition("/")
if file_name.startswith("_virtual_imports/"):
# This is a virtual import; move "_virtual_imports/<library>" from `repo_name` to `file_name`.
repo_name = paths.join(repo_name, *file_name.split("/", 2)[:2])
file_name = file_name.split("/", 2)[-1]
proto_paths.append(genfile_path + "/external/" + repo_name)
proto_file_names.append(file_name)
else:
proto_file_names.append(s.path)
return ([" --proto_path=" + path for path in proto_paths], proto_file_names)
return struct(
proto_path = paths.join(f.root.path, "external", repo_name),
file_name = file_name,
)

def _map_root_path(f):
return _canonicalize_proto_path_oss(f).proto_path

def _map_short_path(f):
return _canonicalize_proto_path_oss(f).file_name

def _get_proto_provider(dep):
"""Get the provider for protocol buffers from a dependnecy.

@@ -90,24 +94,35 @@ def _encode_binary_proto_impl(ctx):
sibling = textpb,
)

path_list, file_list = _canonicalize_proto_path_oss(all_protos, ctx.genfiles_dir.path)
args = ctx.actions.args()
args.add(textpb)
args.add(binarypb)
args.add(ctx.executable._proto_compiler)
args.add(ctx.attr.message_type, format = "--encode=%s")
args.add("--proto_path=.")
args.add_all(
all_protos,
map_each = _map_root_path,
format_each = "--proto_path=%s",
uniquify = True,
)
args.add_all(
all_protos,
map_each = _map_short_path,
uniquify = True,
)

# Note: the combination of absolute_paths and proto_path, as well as the exact
# order of gendir before ., is needed for the proto compiler to resolve
# import statements that reference proto files produced by a genrule.
ctx.actions.run_shell(
tools = all_protos.to_list() + [textpb, ctx.executable._proto_compiler],
outputs = [binarypb],
command = " ".join(
[
ctx.executable._proto_compiler.path,
"--encode=" + ctx.attr.message_type,
"--proto_path=" + ctx.genfiles_dir.path,
"--proto_path=" + ctx.bin_dir.path,
"--proto_path=.",
] + path_list + file_list +
["<", textpb.path, ">", binarypb.path],
tools = depset(
direct = [textpb, ctx.executable._proto_compiler],
transitive = [all_protos],
),
outputs = [binarypb],
command = "${@:3} < $1 > $2",
arguments = [args],
mnemonic = "EncodeProto",
)
@@ -19,7 +19,7 @@ package mediapipe;
// Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of
// the joint and its visibility.
message Joint {
// Joint rotation in 6D contineous representation ordered as
// Joint rotation in 6D continuous representation ordered as
// [a1, b1, a2, b2, a3, b3].
//
// Such representation is more sutable for NN model training and can be
@@ -15,7 +15,7 @@ def mediapipe_cc_test(
platforms = ["linux", "android", "ios", "wasm"],
exclude_platforms = None,
# ios_unit_test arguments
ios_minimum_os_version = "11.0",
ios_minimum_os_version = "12.0",
# android_cc_test arguments
open_gl_driver = None,
emulator_mini_boot = True,
@@ -466,8 +466,7 @@ struct MessageRegistrationImpl {
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder,
__FILE__, __LINE__));
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));

// For non-Message payloads, this does nothing.
template <typename T, typename Enable = void>
@@ -261,8 +261,8 @@ cc_library(
)

cc_library(
name = "opencv_highgui",
hdrs = ["opencv_highgui_inc.h"],
name = "opencv_photo",
hdrs = ["opencv_photo_inc.h"],
deps = [
":opencv_core",
"//third_party:opencv",

@@ -297,6 +297,15 @@ cc_library(
],
)

cc_library(
name = "opencv_highgui",
hdrs = ["opencv_highgui_inc.h"],
deps = [
":opencv_core",
"//third_party:opencv",
],
)

cc_library(
name = "opencv_videoio",
hdrs = ["opencv_videoio_inc.h"],
@@ -1,4 +1,4 @@
// Copyright 2019 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
#define MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
#ifndef MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
#define MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_

#include <opencv2/core/version.hpp>

@@ -25,4 +25,4 @@
#include <opencv2/highgui.hpp>
#endif

#endif // MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
#endif // MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
@ -1,4 +1,4 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
@ -12,15 +12,9 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
#ifndef MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
|
||||
#define MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
|
||||
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "third_party/OpenCV/photo.hpp"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef ClipVectorSizeCalculator<::mediapipe::Detection>
|
||||
ClipDetectionVectorSizeCalculator;
|
||||
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
#endif // MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -48,6 +48,18 @@ class MuxInputStreamHandler : public InputStreamHandler {
|
|||
: InputStreamHandler(std::move(tag_map), cc_manager, options,
|
||||
calculator_run_in_parallel) {}
|
||||
|
||||
private:
|
||||
CollectionItemId GetControlStreamId() const {
|
||||
return input_stream_managers_.EndId() - 1;
|
||||
}
|
||||
void RemoveOutdatedDataPackets(Timestamp timestamp) {
|
||||
const CollectionItemId control_stream_id = GetControlStreamId();
|
||||
for (CollectionItemId id = input_stream_managers_.BeginId();
|
||||
id < control_stream_id; ++id) {
|
||||
input_stream_managers_.Get(id)->ErasePacketsEarlierThan(timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// In MuxInputStreamHandler, a node is "ready" if:
|
||||
// - the control stream is done (need to call Close() in this case), or
|
||||
|
@ -58,9 +70,15 @@ class MuxInputStreamHandler : public InputStreamHandler {
|
|||
absl::MutexLock lock(&input_streams_mutex_);
|
||||
|
||||
const auto& control_stream =
|
||||
input_stream_managers_.Get(input_stream_managers_.EndId() - 1);
|
||||
input_stream_managers_.Get(GetControlStreamId());
|
||||
bool empty;
|
||||
*min_stream_timestamp = control_stream->MinTimestampOrBound(&empty);
|
||||
|
||||
// Data streams may contain some outdated packets which failed to be popped
|
||||
// out during "FillInputSet". (This handler doesn't sync input streams,
|
||||
// hence "FillInputSet" can be triggerred before every input stream is
|
||||
// filled with packets corresponding to the same timestamp.)
|
||||
RemoveOutdatedDataPackets(*min_stream_timestamp);
|
||||
if (empty) {
|
||||
if (*min_stream_timestamp == Timestamp::Done()) {
|
||||
// Calculator is done if the control input stream is done.
|
||||
|
@ -78,11 +96,6 @@ class MuxInputStreamHandler : public InputStreamHandler {
|
|||
const auto& data_stream = input_stream_managers_.Get(
|
||||
input_stream_managers_.BeginId() + control_value);
|
||||
|
||||
// Data stream may contain some outdated packets which failed to be popped
|
||||
// out during "FillInputSet". (This handler doesn't sync input streams,
|
||||
// hence "FillInputSet" can be triggerred before every input stream is
|
||||
// filled with packets corresponding to the same timestamp.)
|
||||
data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
|
||||
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
|
||||
if (empty) {
|
||||
if (stream_timestamp <= *min_stream_timestamp) {
|
||||
|
@ -111,8 +124,7 @@ class MuxInputStreamHandler : public InputStreamHandler {
|
|||
CHECK(input_set);
|
||||
absl::MutexLock lock(&input_streams_mutex_);
|
||||
|
||||
const CollectionItemId control_stream_id =
|
||||
input_stream_managers_.EndId() - 1;
|
||||
const CollectionItemId control_stream_id = GetControlStreamId();
|
||||
auto& control_stream = input_stream_managers_.Get(control_stream_id);
|
||||
int num_packets_dropped = 0;
|
||||
bool stream_is_done = false;
|
||||
|
@ -140,15 +152,8 @@ class MuxInputStreamHandler : public InputStreamHandler {
|
|||
AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet),
|
||||
stream_is_done);
|
||||
|
||||
// Discard old packets on other streams.
|
||||
// Note that control_stream_id is the last valid id.
|
||||
auto next_timestamp = input_timestamp.NextAllowedInStream();
|
||||
for (CollectionItemId id = input_stream_managers_.BeginId();
|
||||
id < control_stream_id; ++id) {
|
||||
if (id == data_stream_id) continue;
|
||||
auto& other_stream = input_stream_managers_.Get(id);
|
||||
other_stream->ErasePacketsEarlierThan(next_timestamp);
|
||||
}
|
||||
// Discard old packets on data streams.
|
||||
RemoveOutdatedDataPackets(input_timestamp.NextAllowedInStream());
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -645,5 +645,41 @@ TEST(MuxInputStreamHandlerTest,
|
|||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxInputStreamHandlerTest, RemovesUnusedDataStreamPackets) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input0"
|
||||
input_stream: "input1"
|
||||
input_stream: "select"
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:input0"
|
||||
input_stream: "INPUT:1:input1"
|
||||
input_stream: "SELECT:select"
|
||||
output_stream: "OUTPUT:output"
|
||||
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
|
||||
}
|
||||
)pb");
|
||||
config.set_max_queue_size(1);
|
||||
config.set_report_deadlock(true);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"select", MakePacket<int>(0).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input0", MakePacket<int>(1000).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Add two delayed packets to the deselected input. They should be discarded
|
||||
// instead of triggering the deadlock detection (max_queue_size = 1).
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(900).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input1", MakePacket<int>(900).At(Timestamp(2))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
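For readers less familiar with the C++ handler, here is a rough Python sketch of the pruning idea above (hypothetical names, not MediaPipe's API): every data stream drops packets older than the timestamp chosen from the control stream, so stale packets on deselected inputs can no longer fill the queues and trip deadlock detection, which is exactly what the new test exercises.

from collections import deque

class Stream:
  """Toy stand-in for an input stream queue of (timestamp, value) packets."""

  def __init__(self):
    self.packets = deque()

  def add(self, timestamp, value):
    self.packets.append((timestamp, value))

  def erase_packets_earlier_than(self, timestamp):
    while self.packets and self.packets[0][0] < timestamp:
      self.packets.popleft()

def remove_outdated_data_packets(data_streams, control_timestamp):
  """Prunes every data stream, mirroring RemoveOutdatedDataPackets."""
  for stream in data_streams:
    stream.erase_packets_earlier_than(control_timestamp)

selected, deselected = Stream(), Stream()
deselected.add(1, 'stale')
deselected.add(2, 'stale')
selected.add(2, 1000)
remove_outdated_data_packets([selected, deselected], control_timestamp=2)
assert len(deselected.packets) == 1  # only the packet at timestamp 2 survives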
|
@ -64,13 +64,13 @@ GraphRegistry::GraphRegistry(
|
|||
void GraphRegistry::Register(
|
||||
const std::string& type_name,
|
||||
std::function<std::unique_ptr<Subgraph>()> factory) {
|
||||
local_factories_.Register(type_name, factory, __FILE__, __LINE__);
|
||||
local_factories_.Register(type_name, factory);
|
||||
}
|
||||
|
||||
// TODO: Remove this convenience function.
|
||||
void GraphRegistry::Register(const std::string& type_name,
|
||||
const CalculatorGraphConfig& config) {
|
||||
Register(type_name, [config] {
|
||||
local_factories_.Register(type_name, [config] {
|
||||
auto result = absl::make_unique<ProtoSubgraph>(config);
|
||||
return std::unique_ptr<Subgraph>(result.release());
|
||||
});
|
||||
|
@ -79,7 +79,7 @@ void GraphRegistry::Register(const std::string& type_name,
|
|||
// TODO: Remove this convenience function.
|
||||
void GraphRegistry::Register(const std::string& type_name,
|
||||
const CalculatorGraphTemplate& templ) {
|
||||
Register(type_name, [templ] {
|
||||
local_factories_.Register(type_name, [templ] {
|
||||
auto result = absl::make_unique<TemplateSubgraph>(templ);
|
||||
return std::unique_ptr<Subgraph>(result.release());
|
||||
});
|
||||
|
|
|
@ -228,7 +228,9 @@ absl::Status CompareAndSaveImageOutput(
|
|||
auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
|
||||
options.max_alpha_diff, options.max_avg_diff,
|
||||
diff_img);
|
||||
if (diff_img) {
|
||||
ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
|
|
@ -1121,7 +1121,7 @@ objc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
MIN_IOS_VERSION = "12.0"
|
||||
|
||||
test_suite(
|
||||
name = "ios",
|
||||
|
|
|
@ -109,9 +109,8 @@ absl::Status GlContext::CreateContext(
|
|||
}
|
||||
MP_RETURN_IF_ERROR(status);
|
||||
|
||||
LOG(INFO) << "Successfully created a WebGL context with major version "
|
||||
VLOG(1) << "Successfully created a WebGL context with major version "
|
||||
<< gl_major_version_ << " and handle " << context_;
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -104,6 +104,7 @@ class GlScalerCalculator : public CalculatorBase {
|
|||
bool vertical_flip_output_;
|
||||
bool horizontal_flip_output_;
|
||||
FrameScaleMode scale_mode_ = FrameScaleMode::kStretch;
|
||||
bool use_nearest_neighbor_interpolation_ = false;
|
||||
};
|
||||
REGISTER_CALCULATOR(GlScalerCalculator);
|
||||
|
||||
|
@ -186,7 +187,8 @@ absl::Status GlScalerCalculator::Open(CalculatorContext* cc) {
|
|||
scale_mode_ =
|
||||
FrameScaleModeFromProto(options.scale_mode(), FrameScaleMode::kStretch);
|
||||
}
|
||||
|
||||
use_nearest_neighbor_interpolation_ =
|
||||
options.use_nearest_neighbor_interpolation();
|
||||
if (HasTagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) {
|
||||
const auto& dimensions =
|
||||
TagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)
|
||||
|
@ -297,6 +299,11 @@ absl::Status GlScalerCalculator::Process(CalculatorContext* cc) {
|
|||
glBindTexture(src2.target(), src2.name());
|
||||
}
|
||||
|
||||
if (use_nearest_neighbor_interpolation_) {
|
||||
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
|
||||
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(renderer->GlRender(
|
||||
src1.width(), src1.height(), dst.width(), dst.height(), scale_mode_,
|
||||
rotation_, horizontal_flip_output_, vertical_flip_output_,
|
||||
|
|
|
@ -19,7 +19,7 @@ package mediapipe;
|
|||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/gpu/scale_mode.proto";
|
||||
|
||||
// Next id: 8.
|
||||
// Next id: 9.
|
||||
message GlScalerCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional GlScalerCalculatorOptions ext = 166373014;
|
||||
|
@ -39,4 +39,7 @@ message GlScalerCalculatorOptions {
|
|||
// Flip the output texture horizontally. This is applied after rotation.
|
||||
optional bool flip_horizontal = 5;
|
||||
optional ScaleMode.Mode scale_mode = 6;
|
||||
// Whether to use nearest neighbor interpolation. Defaults to linear
|
||||
// interpolation.
|
||||
optional bool use_nearest_neighbor_interpolation = 8 [default = false];
|
||||
}
|
||||
|
|
|
@ -100,6 +100,10 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
|
|||
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
|
||||
#endif // TARGET_OS_OSX
|
||||
}},
|
||||
{GpuBufferFormat::kOneComponent8Alpha,
|
||||
{
|
||||
{GL_ALPHA, GL_ALPHA, GL_UNSIGNED_BYTE, 1},
|
||||
}},
|
||||
{GpuBufferFormat::kOneComponent8Red,
|
||||
{
|
||||
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
|
||||
|
@ -221,6 +225,7 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
|
|||
case GpuBufferFormat::kRGBA32:
|
||||
// TODO: this likely maps to ImageFormat::SRGBA
|
||||
case GpuBufferFormat::kGrayHalf16:
|
||||
case GpuBufferFormat::kOneComponent8Alpha:
|
||||
case GpuBufferFormat::kOneComponent8Red:
|
||||
case GpuBufferFormat::kTwoComponent8:
|
||||
case GpuBufferFormat::kTwoComponentHalf16:
|
||||
|
|
|
@ -43,6 +43,7 @@ enum class GpuBufferFormat : uint32_t {
|
|||
kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'),
|
||||
kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'),
|
||||
kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'),
|
||||
kOneComponent8Alpha = MEDIAPIPE_FOURCC('A', '0', '0', '8'),
|
||||
kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'),
|
||||
kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'),
|
||||
kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'),
|
||||
|
@ -101,6 +102,7 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
|
|||
return kCVPixelFormatType_OneComponent32Float;
|
||||
case GpuBufferFormat::kOneComponent8:
|
||||
return kCVPixelFormatType_OneComponent8;
|
||||
case GpuBufferFormat::kOneComponent8Alpha:
|
||||
case GpuBufferFormat::kOneComponent8Red:
|
||||
return -1;
|
||||
case GpuBufferFormat::kTwoComponent8:
|
||||
|
|
|
@ -78,17 +78,21 @@ public class AppTextureFrame implements TextureFrame {
|
|||
* Use {@link waitUntilReleasedWithGpuSync} whenever possible.
|
||||
*/
|
||||
public void waitUntilReleased() throws InterruptedException {
|
||||
GlSyncToken tokenToRelease = null;
|
||||
synchronized (this) {
|
||||
while (inUse && releaseSyncToken == null) {
|
||||
wait();
|
||||
}
|
||||
if (releaseSyncToken != null) {
|
||||
releaseSyncToken.waitOnCpu();
|
||||
releaseSyncToken.release();
|
||||
tokenToRelease = releaseSyncToken;
|
||||
inUse = false;
|
||||
releaseSyncToken = null;
|
||||
}
|
||||
}
|
||||
if (tokenToRelease != null) {
|
||||
tokenToRelease.waitOnCpu();
|
||||
tokenToRelease.release();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -98,17 +102,21 @@ public class AppTextureFrame implements TextureFrame {
|
|||
* TextureFrame.
|
||||
*/
|
||||
public void waitUntilReleasedWithGpuSync() throws InterruptedException {
|
||||
GlSyncToken tokenToRelease = null;
|
||||
synchronized (this) {
|
||||
while (inUse && releaseSyncToken == null) {
|
||||
wait();
|
||||
}
|
||||
if (releaseSyncToken != null) {
|
||||
releaseSyncToken.waitOnGpu();
|
||||
releaseSyncToken.release();
|
||||
tokenToRelease = releaseSyncToken;
|
||||
inUse = false;
|
||||
releaseSyncToken = null;
|
||||
}
|
||||
}
|
||||
if (tokenToRelease != null) {
|
||||
tokenToRelease.waitOnGpu();
|
||||
tokenToRelease.release();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
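The Java change above moves the blocking sync-token wait out of the synchronized block: the token is captured and the shared state cleared while holding the lock, and the potentially slow CPU/GPU wait happens afterwards. A small Python sketch of the same pattern (illustrative names, not the MediaPipe API):

import threading

class Frame:
  """Toy frame that is released via a callback token, as in the Java code."""

  def __init__(self):
    self._cond = threading.Condition()
    self._in_use = True
    self._release_token = None

  def set_release_token(self, token):
    with self._cond:
      self._release_token = token
      self._cond.notify_all()

  def wait_until_released(self):
    token_to_release = None
    with self._cond:
      while self._in_use and self._release_token is None:
        self._cond.wait()
      if self._release_token is not None:
        token_to_release = self._release_token
        self._in_use = False
        self._release_token = None
    if token_to_release is not None:
      token_to_release()  # the blocking sync/release runs outside the lock

frame = Frame()
threading.Timer(0.1, frame.set_release_token, args=[lambda: print('released')]).start()
frame.wait_until_released()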
|
@ -239,7 +239,7 @@ public final class PacketGetter {
|
|||
|
||||
/**
|
||||
* Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer
|
||||
* array has the the same size of image list packet, and assumes the output buffer stores pixels
|
||||
* array has the same size of image list packet, and assumes the output buffer stores pixels
|
||||
* contiguously. It returns false if this assumption does not hold.
|
||||
*
|
||||
* <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of
|
||||
|
|
|
@ -24,6 +24,7 @@ package_group(
|
|||
package_group(
|
||||
name = "1p_client",
|
||||
packages = [
|
||||
"//cloud/ml/applications/vision/model_garden/model_oss/mediapipe/...",
|
||||
"//research/privacy/learning/fl_eval/pcvr/...",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -57,3 +57,14 @@ py_test(
|
|||
srcs = ["classification_dataset_test.py"],
|
||||
deps = [":classification_dataset"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cache_files",
|
||||
srcs = ["cache_files.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cache_files_test",
|
||||
srcs = ["cache_files_test.py"],
|
||||
deps = [":cache_files"],
|
||||
)
|
||||
|
|
112
mediapipe/model_maker/python/core/data/cache_files.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common TFRecord cache files library."""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
|
||||
# Suffix of the meta data file name.
|
||||
METADATA_FILE_SUFFIX = '_metadata.yaml'
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TFRecordCacheFiles:
|
||||
"""TFRecordCacheFiles dataclass to store and load cached TFRecord files.
|
||||
|
||||
Attributes:
|
||||
cache_prefix_filename: The cache prefix filename. This is usually provided
|
||||
as a hash of the original data source to avoid different data sources
|
||||
resulting in the same cache file.
|
||||
cache_dir: The cache directory to save TFRecord and metadata file. When
|
||||
cache_dir is None, a temporary folder will be created and will not be
|
||||
removed automatically after training, so it can be reused later.
|
||||
num_shards: Number of shards for output tfrecord files.
|
||||
"""
|
||||
|
||||
cache_prefix_filename: str = 'cache_prefix'
|
||||
cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp)
|
||||
num_shards: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if not tf.io.gfile.exists(self.cache_dir):
|
||||
tf.io.gfile.makedirs(self.cache_dir)
|
||||
if not self.cache_prefix_filename:
|
||||
raise ValueError('cache_prefix_filename cannot be empty.')
|
||||
if self.num_shards <= 0:
|
||||
raise ValueError(
|
||||
f'num_shards must be greater than 0, got {self.num_shards}'
|
||||
)
|
||||
|
||||
@property
|
||||
def cache_prefix(self) -> str:
|
||||
"""The cache prefix including the cache directory and the cache prefix filename."""
|
||||
return os.path.join(self.cache_dir, self.cache_prefix_filename)
|
||||
|
||||
@property
|
||||
def tfrecord_files(self) -> Sequence[str]:
|
||||
"""The TFRecord files."""
|
||||
tfrecord_files = [
|
||||
self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards)
|
||||
for i in range(self.num_shards)
|
||||
]
|
||||
return tfrecord_files
|
||||
|
||||
@property
|
||||
def metadata_file(self) -> str:
|
||||
"""The metadata file."""
|
||||
return self.cache_prefix + METADATA_FILE_SUFFIX
|
||||
|
||||
def get_writers(self) -> Sequence[tf.io.TFRecordWriter]:
|
||||
"""Gets an array of TFRecordWriter objects.
|
||||
|
||||
Note that these writers should each be closed using .close() when done.
|
||||
|
||||
Returns:
|
||||
Array of TFRecordWriter objects
|
||||
"""
|
||||
return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
|
||||
|
||||
def save_metadata(self, metadata):
|
||||
"""Writes metadata to file.
|
||||
|
||||
Args:
|
||||
metadata: A dictionary of metadata content to write. Exact format is
|
||||
dependent on the specific dataset, but typically includes a 'size' and
|
||||
'label_names' entry.
|
||||
"""
|
||||
with tf.io.gfile.GFile(self.metadata_file, 'w') as f:
|
||||
yaml.dump(metadata, f)
|
||||
|
||||
def load_metadata(self) -> Mapping[Any, Any]:
|
||||
"""Reads metadata from file.
|
||||
|
||||
Returns:
|
||||
Dictionary object containing metadata
|
||||
"""
|
||||
if not tf.io.gfile.exists(self.metadata_file):
|
||||
return {}
|
||||
with tf.io.gfile.GFile(self.metadata_file, 'r') as f:
|
||||
metadata = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return metadata
|
||||
|
||||
def is_cached(self) -> bool:
|
||||
"""Checks whether this CacheFiles is already cached."""
|
||||
all_cached_files = list(self.tfrecord_files) + [self.metadata_file]
|
||||
return all(tf.io.gfile.exists(f) for f in all_cached_files)
|
77
mediapipe/model_maker/python/core/data/cache_files_test.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import cache_files
|
||||
|
||||
|
||||
class CacheFilesTest(tf.test.TestCase):
|
||||
|
||||
def test_tfrecord_cache_files(self):
|
||||
cf = cache_files.TFRecordCacheFiles(
|
||||
cache_prefix_filename='tfrecord',
|
||||
cache_dir='/tmp/cache_dir',
|
||||
num_shards=2,
|
||||
)
|
||||
self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord')
|
||||
self.assertEqual(
|
||||
cf.metadata_file,
|
||||
'/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX,
|
||||
)
|
||||
expected_tfrecord_files = [
|
||||
'/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2)
|
||||
for i in range(2)
|
||||
]
|
||||
self.assertEqual(cf.tfrecord_files, expected_tfrecord_files)
|
||||
|
||||
# Writing TFRecord Files
|
||||
self.assertFalse(cf.is_cached())
|
||||
for tfrecord_file in cf.tfrecord_files:
|
||||
self.assertFalse(tf.io.gfile.exists(tfrecord_file))
|
||||
writers = cf.get_writers()
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
for tfrecord_file in cf.tfrecord_files:
|
||||
self.assertTrue(tf.io.gfile.exists(tfrecord_file))
|
||||
self.assertFalse(cf.is_cached())
|
||||
|
||||
# Writing Metadata Files
|
||||
original_metadata = {'size': 10, 'label_names': ['label1', 'label2']}
|
||||
cf.save_metadata(original_metadata)
|
||||
self.assertTrue(cf.is_cached())
|
||||
metadata = cf.load_metadata()
|
||||
self.assertEqual(metadata, original_metadata)
|
||||
|
||||
def test_recordio_cache_files_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'cache_prefix_filename cannot be empty'
|
||||
):
|
||||
cache_files.TFRecordCacheFiles(
|
||||
cache_prefix_filename='',
|
||||
cache_dir='/tmp/cache_dir',
|
||||
num_shards=2,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'num_shards must be greater than 0, got 0'
|
||||
):
|
||||
cache_files.TFRecordCacheFiles(
|
||||
cache_prefix_filename='tfrecord',
|
||||
cache_dir='/tmp/cache_dir',
|
||||
num_shards=0,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Common classification dataset library."""
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds
|
|||
class ClassificationDataset(ds.Dataset):
|
||||
"""Dataset Loader for classification models."""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int,
|
||||
label_names: List[str]):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: tf.data.Dataset,
|
||||
label_names: List[str],
|
||||
size: Optional[int] = None,
|
||||
):
|
||||
super().__init__(dataset, size)
|
||||
self._label_names = label_names
|
||||
|
||||
|
|
|
@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase):
|
|||
value: A value variable stored by the mock dataset class for testing.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int,
|
||||
label_names: List[str], value: Any):
|
||||
super().__init__(dataset=dataset, size=size, label_names=label_names)
|
||||
def __init__(
|
||||
self,
|
||||
dataset: tf.data.Dataset,
|
||||
label_names: List[str],
|
||||
value: Any,
|
||||
size: int,
|
||||
):
|
||||
super().__init__(dataset=dataset, label_names=label_names, size=size)
|
||||
self.value = value
|
||||
|
||||
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
|
||||
|
@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase):
|
|||
# Create data loader from sample data.
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
data = MagicClassificationDataset(
|
||||
dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
|
||||
dataset=ds, label_names=label_names, value=magic_value, size=len(ds)
|
||||
)
|
||||
|
||||
# Train/Test data split.
|
||||
fraction = .25
|
||||
|
|
|
@ -56,15 +56,14 @@ class Dataset(object):
|
|||
def size(self) -> Optional[int]:
|
||||
"""Returns the size of the dataset.
|
||||
|
||||
Note that this function may return None becuase the exact size of the
|
||||
dataset isn't a necessary parameter to create an instance of this class,
|
||||
and tf.data.Dataset donesn't support a function to get the length directly
|
||||
since it's lazy-loaded and may be infinite.
|
||||
In most cases, however, when an instance of this class is created by helper
|
||||
functions like 'from_folder', the size of the dataset will be preprocessed,
|
||||
and this function can return an int representing the size of the dataset.
|
||||
Same functionality as calling __len__. See the __len__ method definition for
|
||||
more information.
|
||||
|
||||
Raises:
|
||||
TypeError if self._size is not set and the cardinality of self._dataset
|
||||
is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
|
||||
"""
|
||||
return self._size
|
||||
return self.__len__()
|
||||
|
||||
def gen_tf_dataset(
|
||||
self,
|
||||
|
@ -116,8 +115,22 @@ class Dataset(object):
|
|||
# here.
|
||||
return dataset
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of element of the dataset."""
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of element of the dataset.
|
||||
|
||||
If size is not set, this method will fallback to using the __len__ method
|
||||
of the tf.data.Dataset in self._dataset. Calling __len__ on a
|
||||
tf.data.Dataset instance may throw a TypeError because the dataset may
|
||||
be lazy-loaded with an unknown size or have infinite size.
|
||||
|
||||
In most cases, however, when an instance of this class is created by helper
|
||||
functions like 'from_folder', the size of the dataset will be preprocessed,
|
||||
and the _size instance variable will be already set.
|
||||
|
||||
Raises:
|
||||
TypeError if self._size is not set and the cardinality of self._dataset
|
||||
is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
|
||||
"""
|
||||
if self._size is not None:
|
||||
return self._size
|
||||
else:
|
||||
|
@ -152,15 +165,25 @@ class Dataset(object):
|
|||
|
||||
Returns:
|
||||
The splitted two sub datasets.
|
||||
|
||||
Raises:
|
||||
ValueError: if the provided fraction is not between 0 and 1.
|
||||
ValueError: if this dataset does not have a set size.
|
||||
"""
|
||||
assert (fraction > 0 and fraction < 1)
|
||||
if not (fraction > 0 and fraction < 1):
|
||||
raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}')
|
||||
if not self._size:
|
||||
raise ValueError(
|
||||
'Dataset size unknown. Cannot split the dataset when '
|
||||
'the size is unknown.'
|
||||
)
|
||||
|
||||
dataset = self._dataset
|
||||
|
||||
train_size = int(self._size * fraction)
|
||||
trainset = self.__class__(dataset.take(train_size), train_size, *args)
|
||||
trainset = self.__class__(dataset.take(train_size), *args, size=train_size)
|
||||
|
||||
test_size = self._size - train_size
|
||||
testset = self.__class__(dataset.skip(train_size), test_size, *args)
|
||||
testset = self.__class__(dataset.skip(train_size), *args, size=test_size)
|
||||
|
||||
return trainset, testset
|
||||
|
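The assertion-based check is replaced with explicit ValueErrors above; here is a standalone sketch of the same validation (a simplified stand-in, not the model_maker Dataset class itself):

import tensorflow as tf

# Simplified stand-in for the split() validation above.
def split_dataset(dataset: tf.data.Dataset, size: int, fraction: float):
  """Splits a dataset of known size into (train, test) by `fraction`."""
  if not 0 < fraction < 1:
    raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}')
  if not size:
    raise ValueError(
        'Dataset size unknown. Cannot split the dataset when the size is'
        ' unknown.'
    )
  train_size = int(size * fraction)
  return dataset.take(train_size), dataset.skip(train_size)

ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
train, test = split_dataset(ds, size=4, fraction=0.5)
print([t.numpy() for t in train])  # [b'good', b'bad']
print([t.numpy() for t in test])  # [b'neutral', b'odd']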
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import dataclasses
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
from typing import Mapping, Optional
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -36,6 +36,8 @@ class BaseHParams:
|
|||
steps_per_epoch: An optional integer indicate the number of training steps
|
||||
per epoch. If not set, the training pipeline calculates the default steps
|
||||
per epoch as the training dataset size divided by batch size.
|
||||
class_weights: An optional mapping of indices to weights for weighting the
|
||||
loss function during training.
|
||||
shuffle: True if the dataset is shuffled before training.
|
||||
export_dir: The location of the model checkpoint files.
|
||||
distribution_strategy: A string specifying which Distribution Strategy to
|
||||
|
@ -57,6 +59,7 @@ class BaseHParams:
|
|||
batch_size: int
|
||||
epochs: int
|
||||
steps_per_epoch: Optional[int] = None
|
||||
class_weights: Optional[Mapping[int, float]] = None
|
||||
|
||||
# Dataset-related parameters
|
||||
shuffle: bool = False
|
||||
|
|
|
@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel):
|
|||
# dataset is exhausted even if there are epochs remaining.
|
||||
steps_per_epoch=None,
|
||||
validation_data=validation_dataset,
|
||||
callbacks=self._callbacks)
|
||||
callbacks=self._callbacks,
|
||||
class_weight=self._hparams.class_weights,
|
||||
)
|
||||
|
||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||
"""Evaluates the classifier with the provided evaluation dataset.
|
||||
|
|
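The new class_weights hyperparameter is passed straight through to Keras' class_weight argument, which scales the loss per ground-truth class. A minimal sketch with made-up data and weights:

import numpy as np
import tensorflow as tf

# Toy data; the weights below up-weight class 1, e.g. for class imbalance.
x = np.random.rand(100, 4).astype('float32')
y = np.random.randint(0, 2, size=(100,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(x, y, epochs=1, batch_size=16, class_weight={0: 1.0, 1: 3.0})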
|
@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
|
|||
"""
|
||||
|
||||
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
|
||||
"""Constructor.
|
||||
"""Initializes FocalLoss.
|
||||
|
||||
Args:
|
||||
gamma: Focal loss gamma, as described in class docs.
|
||||
|
@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
|
|||
return tf.reduce_sum(losses) / batch_size
|
||||
|
||||
|
||||
class SparseFocalLoss(FocalLoss):
|
||||
"""Sparse implementation of Focal Loss.
|
||||
|
||||
This is the same as FocalLoss, except the labels are expected to be class ids
|
||||
instead of 1-hot encoded vectors. See FocalLoss class documentation defined
|
||||
in this same file for more details.
|
||||
|
||||
Example usage:
|
||||
>>> y_true = [1, 2]
|
||||
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
|
||||
>>> gamma = 2
|
||||
>>> focal_loss = SparseFocalLoss(gamma, 3)
|
||||
>>> focal_loss(y_true, y_pred).numpy()
|
||||
0.9326
|
||||
|
||||
>>> # Calling with 'sample_weight'.
|
||||
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
|
||||
0.6528
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
|
||||
):
|
||||
"""Initializes SparseFocalLoss.
|
||||
|
||||
Args:
|
||||
gamma: Focal loss gamma, as described in class docs.
|
||||
num_classes: Number of classes.
|
||||
class_weight: A weight to apply to the loss, one for each class. The
|
||||
weight is applied for each input where the ground truth label matches.
|
||||
"""
|
||||
super().__init__(gamma, class_weight=class_weight)
|
||||
self._num_classes = num_classes
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
y_true: tf.Tensor,
|
||||
y_pred: tf.Tensor,
|
||||
sample_weight: Optional[tf.Tensor] = None,
|
||||
) -> tf.Tensor:
|
||||
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
|
||||
y_true_one_hot = tf.one_hot(y_true, self._num_classes)
|
||||
return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PerceptualLossWeight:
|
||||
"""The weight for each perceptual loss.
|
||||
|
|
|
@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertNear(loss, expected_loss, 1e-4)
|
||||
|
||||
|
||||
class SparseFocalLossTest(tf.test.TestCase):
|
||||
|
||||
def test_sparse_focal_loss_matches_focal_loss(self):
|
||||
num_classes = 2
|
||||
y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
|
||||
y_true = tf.constant([1, 0])
|
||||
y_true_one_hot = tf.one_hot(y_true, num_classes)
|
||||
for gamma in [0.0, 0.5, 1.0]:
|
||||
expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
|
||||
loss_fn = loss_functions.SparseFocalLoss(
|
||||
gamma=gamma, num_classes=num_classes
|
||||
)
|
||||
expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
self.assertNear(loss, expected_loss, 1e-4)
|
||||
|
||||
|
||||
class MockPerceptualLoss(loss_functions.PerceptualLoss):
|
||||
"""A mock class with implementation of abstract methods for testing."""
|
||||
|
||||
|
|
|
@ -46,13 +46,17 @@ class BertModelSpec:
|
|||
"""
|
||||
|
||||
downloaded_files: file_util.DownloadedFiles
|
||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||
hparams: hp.BaseHParams = dataclasses.field(
|
||||
default_factory=lambda: hp.BaseHParams(
|
||||
epochs=3,
|
||||
batch_size=32,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='mirrored')
|
||||
model_options: bert_model_options.BertModelOptions = (
|
||||
bert_model_options.BertModelOptions())
|
||||
distribution_strategy='mirrored',
|
||||
)
|
||||
)
|
||||
model_options: bert_model_options.BertModelOptions = dataclasses.field(
|
||||
default_factory=bert_model_options.BertModelOptions
|
||||
)
|
||||
do_lower_case: bool = True
|
||||
tflite_input_name: Dict[str, str] = dataclasses.field(
|
||||
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python strict binary and library compatibility macro.
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||
|
@ -76,7 +76,10 @@ py_test(
|
|||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -88,7 +91,10 @@ py_test(
|
|||
py_library(
|
||||
name = "preprocessor",
|
||||
srcs = ["preprocessor.py"],
|
||||
deps = [":dataset"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -99,6 +105,7 @@ py_test(
|
|||
":dataset",
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -124,6 +131,7 @@ py_library(
|
|||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
|
@ -147,6 +155,7 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
":text_classifier_import",
|
||||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,11 +15,15 @@
|
|||
|
||||
import csv
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from typing import Optional, Sequence
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
|
||||
|
||||
|
@ -46,21 +50,49 @@ class CSVParameters:
|
|||
class Dataset(classification_dataset.ClassificationDataset):
|
||||
"""Dataset library for text classifier."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: tf.data.Dataset,
|
||||
label_names: List[str],
|
||||
tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
|
||||
size: Optional[int] = None,
|
||||
):
|
||||
super().__init__(dataset, label_names, size)
|
||||
if not tfrecord_cache_files:
|
||||
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
|
||||
cache_prefix_filename="tfrecord", num_shards=1
|
||||
)
|
||||
self.tfrecord_cache_files = tfrecord_cache_files
|
||||
|
||||
@classmethod
|
||||
def from_csv(cls,
|
||||
def from_csv(
|
||||
cls,
|
||||
filename: str,
|
||||
csv_params: CSVParameters,
|
||||
shuffle: bool = True) -> "Dataset":
|
||||
shuffle: bool = True,
|
||||
cache_dir: Optional[str] = None,
|
||||
num_shards: int = 1,
|
||||
) -> "Dataset":
|
||||
"""Loads text with labels from a CSV file.
|
||||
|
||||
Args:
|
||||
filename: Name of the CSV file.
|
||||
csv_params: Parameters used for reading the CSV file.
|
||||
shuffle: If True, randomly shuffle the data.
|
||||
cache_dir: Optional parameter to specify where to store the preprocessed
|
||||
dataset. Only used for BERT models.
|
||||
num_shards: Optional number of shards for the preprocessed dataset.
|
||||
Note that using more than 1 shard will reorder the dataset. Only used
|
||||
for BERT models.
|
||||
|
||||
Returns:
|
||||
Dataset containing (text, label) pairs and other related info.
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = tempfile.mkdtemp()
|
||||
# calculate hash for cache based off of files
|
||||
hasher = hashlib.md5()
|
||||
hasher.update(os.path.basename(filename).encode("utf-8"))
|
||||
with tf.io.gfile.GFile(filename, "r") as f:
|
||||
reader = csv.DictReader(
|
||||
f,
|
||||
|
@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
quotechar=csv_params.quotechar)
|
||||
|
||||
lines = list(reader)
|
||||
for line in lines:
|
||||
hasher.update(str(line).encode("utf-8"))
|
||||
|
||||
if shuffle:
|
||||
random.shuffle(lines)
|
||||
|
||||
|
@ -81,8 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
index_by_label[line[csv_params.label_column]] for line in lines
|
||||
]
|
||||
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
||||
tf.cast(label_indices, tf.int64))
|
||||
tf.cast(label_indices, tf.int64)
|
||||
)
|
||||
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
||||
|
||||
hasher.update(str(num_shards).encode("utf-8"))
|
||||
cache_prefix_filename = hasher.hexdigest()
|
||||
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
|
||||
cache_prefix_filename, cache_dir, num_shards
|
||||
)
|
||||
return Dataset(
|
||||
dataset=text_label_ds, size=len(texts), label_names=label_names)
|
||||
dataset=text_label_ds,
|
||||
label_names=label_names,
|
||||
tfrecord_cache_files=tfrecord_cache_files,
|
||||
size=len(texts),
|
||||
)
|
||||
|
|
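An illustrative call using the new caching parameters documented above; the CSV path is a placeholder and the CSVParameters field names are assumed for illustration.

from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds

# Placeholder CSV path; the CSVParameters field names (text_column,
# label_column) are assumed.
csv_params = text_classifier_ds.CSVParameters(
    text_column='text', label_column='label'
)
data = text_classifier_ds.Dataset.from_csv(
    filename='/tmp/train.csv',
    csv_params=csv_params,
    shuffle=True,
    cache_dir='/tmp/text_classifier_cache',  # only used by BERT models
    num_shards=1,
)
print(len(data), data.label_names)
print(data.tfrecord_cache_files.cache_prefix)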
|
@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
|
|||
|
||||
def test_split(self):
|
||||
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
|
||||
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
|
||||
data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
|
||||
train_data, test_data = data.split(0.5)
|
||||
expected_train_data = [b'good', b'bad']
|
||||
expected_test_data = [b'neutral', b'odd']
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Union
|
||||
from typing import Sequence, Union
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):
|
|||
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
optimizer: Optimizer to use for training. Only supported values are "adamw"
|
||||
and "lamb".
|
||||
end_learning_rate: End learning rate for linear decay. Defaults to 0.
|
||||
batch_size: Batch size for training. Defaults to 48.
|
||||
epochs: Number of training iterations over the dataset. Defaults to 2.
|
||||
optimizer: Optimizer to use for training. Supported values are defined in
|
||||
BertOptimizer enum: ADAMW and LAMB.
|
||||
weight_decay: Weight decay of the optimizer. Defaults to 0.01.
|
||||
desired_precisions: If specified, adds a RecallAtPrecision metric per
|
||||
desired_precisions[i] entry which tracks the recall given the constraint
|
||||
on precision. Only supported for binary classification.
|
||||
desired_recalls: If specified, adds a PrecisionAtRecall metric per
|
||||
desired_recalls[i] entry which tracks the precision given the constraint
|
||||
on recall. Only supported for binary classification.
|
||||
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
|
||||
value to 0. Defaults to 2.0.
|
||||
"""
|
||||
|
||||
learning_rate: float = 3e-5
|
||||
end_learning_rate: float = 0.0
|
||||
|
||||
batch_size: int = 48
|
||||
epochs: int = 2
|
||||
optimizer: BertOptimizer = BertOptimizer.ADAMW
|
||||
weight_decay: float = 0.01
|
||||
|
||||
desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
|
||||
desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
|
||||
|
||||
gamma: float = 2.0
|
||||
|
||||
|
||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
||||
|
|
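An illustrative construction using the fields documented above; the values are arbitrary and the module path is assumed to be the text classifier's hyperparameters module.

from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp

# Module path assumed; field values are arbitrary.
hparams = hp.BertHParams(
    learning_rate=3e-5,
    end_learning_rate=0.0,
    batch_size=48,
    epochs=2,
    optimizer=hp.BertOptimizer.LAMB,
    weight_decay=0.01,
    desired_precisions=[0.9],  # adds a RecallAtPrecision metric (binary only)
    gamma=2.0,  # 0 falls back to plain cross entropy
)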
|
@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec:
|
|||
"""
|
||||
|
||||
# `learning_rate` is unused for the average word embedding model
|
||||
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
|
||||
hparams: hp.AverageWordEmbeddingHParams = dataclasses.field(
|
||||
default_factory=lambda: hp.AverageWordEmbeddingHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0
|
||||
)
|
||||
model_options: mo.AverageWordEmbeddingModelOptions = (
|
||||
mo.AverageWordEmbeddingModelOptions())
|
||||
)
|
||||
model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field(
|
||||
default_factory=mo.AverageWordEmbeddingModelOptions
|
||||
)
|
||||
name: str = 'AverageWordEmbedding'
|
||||
|
||||
average_word_embedding_classifier_spec = functools.partial(
|
||||
|
@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
|
|||
inherited from the BertModelSpec.
|
||||
"""
|
||||
|
||||
hparams: hp.BertHParams = hp.BertHParams()
|
||||
hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
|
||||
|
||||
|
||||
mobilebert_classifier_spec = functools.partial(
|
||||
|
@ -76,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
|
|||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='MobileBert',
|
||||
tflite_input_name={
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'segment_ids': 'serving_default_input_2:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
},
|
||||
)
|
||||
|
||||
exbert_classifier_spec = functools.partial(
|
||||
|
@ -90,11 +88,6 @@ exbert_classifier_spec = functools.partial(
|
|||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='ExBert',
|
||||
tflite_input_name={
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'segment_ids': 'serving_default_input_2:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
|
||||
self.assertTrue(model_spec_obj.do_lower_case)
|
||||
self.assertEqual(
|
||||
model_spec_obj.tflite_input_name, {
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
'segment_ids': 'serving_default_input_2:0'
|
||||
})
|
||||
model_spec_obj.tflite_input_name,
|
||||
{
|
||||
'ids': 'serving_default_input_word_ids:0',
|
||||
'mask': 'serving_default_input_mask:0',
|
||||
'segment_ids': 'serving_default_input_type_ids:0',
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.BertModelOptions(
|
||||
|
|
|
@ -15,14 +15,15 @@
|
|||
"""Preprocessors for text classification."""
|
||||
|
||||
import collections
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import Mapping, Sequence, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub
|
||||
|
||||
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||
from official.nlp.data import classifier_data_lib
|
||||
from official.nlp.tools import tokenization
|
||||
|
@ -75,19 +76,20 @@ def _decode_record(
|
|||
return bert_features, example["label_ids"]
|
||||
|
||||
|
||||
def _single_file_dataset(
|
||||
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
||||
def _tfrecord_dataset(
|
||||
tfrecord_files: Sequence[str],
|
||||
name_to_features: Mapping[str, tf.io.FixedLenFeature],
|
||||
) -> tf.data.TFRecordDataset:
|
||||
"""Creates a single-file dataset to be passed for BERT custom training.
|
||||
|
||||
Args:
|
||||
input_file: Filepath for the dataset.
|
||||
tfrecord_files: Filepaths for the dataset.
|
||||
name_to_features: Maps record keys to feature types.
|
||||
|
||||
Returns:
|
||||
Dataset containing BERT model input features and labels.
|
||||
"""
|
||||
d = tf.data.TFRecordDataset(input_file)
|
||||
d = tf.data.TFRecordDataset(tfrecord_files)
|
||||
d = d.map(
|
||||
lambda record: _decode_record(record, name_to_features),
|
||||
num_parallel_calls=tf.data.AUTOTUNE)
|
||||
|
@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
|
|||
seq_len: Length of the input sequence to the model.
|
||||
vocab_file: File containing the BERT vocab.
|
||||
tokenizer: BERT tokenizer.
|
||||
model_name: Name of the model provided by the model_spec. Used to associate
|
||||
cached files with specific Bert model vocab.
|
||||
"""
|
||||
|
||||
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
|
||||
def __init__(
|
||||
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
|
||||
):
|
||||
self._seq_len = seq_len
|
||||
# Vocab filepath is tied to the BERT module's URI.
|
||||
self._vocab_file = os.path.join(
|
||||
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
|
||||
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
|
||||
do_lower_case)
|
||||
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
|
||||
)
|
||||
self._do_lower_case = do_lower_case
|
||||
self._tokenizer = tokenization.FullTokenizer(
|
||||
self._vocab_file, self._do_lower_case
|
||||
)
|
||||
self._model_name = model_name
|
||||
|
||||
def _get_name_to_features(self):
|
||||
"""Gets the dictionary mapping record keys to feature types."""
|
||||
|
@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
|
|||
"""Returns the vocab file of the BertClassifierPreprocessor."""
|
||||
return self._vocab_file
|
||||
|
||||
def _get_tfrecord_cache_files(
|
||||
self, ds_cache_files
|
||||
) -> cache_files_lib.TFRecordCacheFiles:
|
||||
"""Helper to regenerate cache prefix filename using preprocessor info.
|
||||
|
||||
We need to update the dataset cache_prefix cache because the actual cached
|
||||
dataset depends on the preprocessor parameters such as model_name, seq_len,
|
||||
and do_lower_case, in addition to the raw dataset parameters which are already
|
||||
included in ds_cache_files.cache_prefix_filename.
|
||||
|
||||
Specifically, the new cache_prefix_filename used by the preprocessor will
|
||||
be a hash generated from the following:
|
||||
1. cache_prefix_filename of the initial raw dataset
|
||||
2. model_name
|
||||
3. seq_len
|
||||
4. do_lower_case
|
||||
|
||||
Args:
|
||||
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
|
||||
|
||||
Returns:
|
||||
A new TFRecordCacheFiles object which incorporates the preprocessor
|
||||
parameters.
|
||||
"""
|
||||
hasher = hashlib.md5()
|
||||
hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
|
||||
hasher.update(self._model_name.encode("utf-8"))
|
||||
hasher.update(str(self._seq_len).encode("utf-8"))
|
||||
hasher.update(str(self._do_lower_case).encode("utf-8"))
|
||||
cache_prefix_filename = hasher.hexdigest()
|
||||
return cache_files_lib.TFRecordCacheFiles(
|
||||
cache_prefix_filename,
|
||||
ds_cache_files.cache_dir,
|
||||
ds_cache_files.num_shards,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
||||
self, dataset: text_classifier_ds.Dataset
|
||||
) -> text_classifier_ds.Dataset:
|
||||
"""Preprocesses data into input for a BERT-based classifier.
|
||||
|
||||
Args:
|
||||
|
@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
|
|||
Returns:
|
||||
Dataset containing (bert_features, label) data.
|
||||
"""
|
||||
examples = []
|
||||
ds_cache_files = dataset.tfrecord_cache_files
|
||||
# Get new tfrecord_cache_files by including preprocessor information.
|
||||
tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
|
||||
if not tfrecord_cache_files.is_cached():
|
||||
print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
|
||||
writers = tfrecord_cache_files.get_writers()
|
||||
size = 0
|
||||
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
||||
_validate_text_and_label(text, label)
|
||||
examples.append(
|
||||
classifier_data_lib.InputExample(
|
||||
example = classifier_data_lib.InputExample(
|
||||
guid=str(index),
|
||||
text_a=text.numpy()[0].decode("utf-8"),
|
||||
text_b=None,
|
||||
# InputExample expects the label name rather than the int ID
|
||||
label=dataset.label_names[label.numpy()[0]]))
|
||||
# label=dataset.label_names[label.numpy()[0]])
|
||||
label=label.numpy()[0],
|
||||
)
|
||||
feature = classifier_data_lib.convert_single_example(
|
||||
index, example, None, self._seq_len, self._tokenizer
|
||||
)
|
||||
|
||||
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
|
||||
classifier_data_lib.file_based_convert_examples_to_features(
|
||||
examples=examples,
|
||||
label_list=dataset.label_names,
|
||||
max_seq_length=self._seq_len,
|
||||
tokenizer=self._tokenizer,
|
||||
output_file=tfrecord_file)
|
||||
preprocessed_ds = _single_file_dataset(tfrecord_file,
|
||||
self._get_name_to_features())
|
||||
def create_int_feature(values):
|
||||
f = tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=list(values))
|
||||
)
|
||||
return f
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(feature.input_ids)
|
||||
features["input_mask"] = create_int_feature(feature.input_mask)
|
||||
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
||||
features["label_ids"] = create_int_feature([feature.label_id])
|
||||
tf_example = tf.train.Example(
|
||||
features=tf.train.Features(feature=features)
|
||||
)
|
||||
writers[index % len(writers)].write(tf_example.SerializeToString())
|
||||
size = index + 1
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
metadata = {"size": size, "label_names": dataset.label_names}
|
||||
tfrecord_cache_files.save_metadata(metadata)
|
||||
else:
|
||||
print(
|
||||
f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
|
||||
)
|
||||
metadata = tfrecord_cache_files.load_metadata()
|
||||
size = metadata["size"]
|
||||
label_names = metadata["label_names"]
|
||||
preprocessed_ds = _tfrecord_dataset(
|
||||
tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
|
||||
)
|
||||
return text_classifier_ds.Dataset(
|
||||
dataset=preprocessed_ds,
|
||||
size=dataset.size,
|
||||
label_names=dataset.label_names)
|
||||
size=size,
|
||||
label_names=label_names,
|
||||
tfrecord_cache_files=tfrecord_cache_files,
|
||||
)
|
||||
|
||||
|
||||
TextClassifierPreprocessor = (
|
||||
Union[BertClassifierPreprocessor,
|
||||
AverageWordEmbeddingClassifierPreprocessor])
|
||||
TextClassifierPreprocessor = Union[
|
||||
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
|
||||
]
|
||||
|
|
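A standalone sketch of the cache-key derivation described above: the preprocessor folds its own parameters into the dataset's cache prefix so that changing the model, sequence length, or casing produces a different cache (illustrative names, not the MediaPipe API).

import hashlib

def preprocessor_cache_prefix(
    ds_prefix: str, model_name: str, seq_len: int, do_lower_case: bool
) -> str:
  """Illustrative re-derivation of the preprocessor cache prefix."""
  hasher = hashlib.md5()
  hasher.update(ds_prefix.encode('utf-8'))
  hasher.update(model_name.encode('utf-8'))
  hasher.update(str(seq_len).encode('utf-8'))
  hasher.update(str(do_lower_case).encode('utf-8'))
  return hasher.hexdigest()

# Different preprocessor settings yield different cache prefixes.
assert (preprocessor_cache_prefix('raw', 'ExBert', 5, True)
        != preprocessor_cache_prefix('raw', 'ExBert', 10, True))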
|
@ -13,14 +13,17 @@
|
|||
# limitations under the License.
|
||||
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
import mock
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import cache_files
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
@@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = []
@@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0])
self.assertSameElements(
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
)
for feature in features.values():
self.assertEqual(feature.shape, [1, 5])
input_masks.append(features['input_mask'].numpy()[0])
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
[0, 0, 0, 0, 0])
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
)
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
)
self.assertEqual(labels, [1, 0])

def test_bert_preprocessor_cache(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file,
csv_params=self.CSV_PARAMS_,
cache_dir=self.get_temp_dir(),
)
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
ds_cache_files = dataset.tfrecord_cache_files
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
ds_cache_files
)
self.assertFalse(preprocessed_cache_files.is_cached())
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
self.assertTrue(preprocessed_cache_files.is_cached())
self.assertEqual(
preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
)

# The second time running preprocessor, it should load from cache directly
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
_ = bert_preprocessor.preprocess(dataset)
self.assertEqual(
mock_stdout.getvalue(),
'Using existing cache files at'
f' {preprocessed_cache_files.cache_prefix}\n',
)

def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=seq_len,
do_lower_case=do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
return new_cf.cache_prefix_filename

def test_bert_get_tfrecord_cache_files(self):
# Test to ensure regenerated cache_files have different prefixes
all_cf_prefixes = set()
cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
new_cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='new_cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))

# Each item of all_cf_prefixes should be unique, so 7 total.
self.assertLen(all_cf_prefixes, 7)


if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()
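The new cache test above hinges on the preprocessor deriving a TFRecordCacheFiles handle from the dataset and only writing TFRecords on the first preprocess() call. A minimal sketch of that flow outside the test harness follows; the CSV path is a placeholder and the CSVParameters column names are assumptions, not values from this change.

    import tempfile

    from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
    from mediapipe.model_maker.python.text.text_classifier import model_spec
    from mediapipe.model_maker.python.text.text_classifier import preprocessor

    csv_params = text_classifier_ds.CSVParameters(text_column='text', label_column='label')
    data = text_classifier_ds.Dataset.from_csv(
        filename='/tmp/train.csv',  # placeholder path
        csv_params=csv_params,
        cache_dir=tempfile.mkdtemp(),
    )
    spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
    bert_preprocessor = preprocessor.BertClassifierPreprocessor(
        seq_len=5,
        do_lower_case=spec.do_lower_case,
        uri=spec.downloaded_files.get_path(),
        model_name=spec.name,
    )
    cached = bert_preprocessor._get_tfrecord_cache_files(data.tfrecord_cache_files)
    assert not cached.is_cached()           # nothing written yet
    _ = bert_preprocessor.preprocess(data)  # first call writes the TFRecord cache
    assert cached.is_cached()               # later calls reuse it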
@@ -16,8 +16,8 @@
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
@@ -27,8 +27,8 @@
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
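The two metadata hunks above swap the positions of the "mask" and "segment_ids" entries, so the generated BERT metadata now lists segment_ids before mask. As a reading aid, the resulting input-tensor order is sketched below; the leading token-ids entry sits outside these hunks and is assumed here.

    # Assumed overall order after the swap; only the last two entries are visible above.
    bert_metadata_input_order = [
        'ids',          # token ids (assumed leading entry, not shown in this diff)
        'segment_ids',  # 0 for the first sequence, 1 for the second sequence if exists
        'mask',         # 1 for real tokens, 0 for padding tokens
    ]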
@@ -24,6 +24,7 @@ import tensorflow_hub as hub

from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
@@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
):
text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data,
options,
train_data.label_names))
text_classifier = _BertClassifier.create_bert_classifier(
train_data, validation_data, options
)
elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
text_classifier = (
_AverageWordEmbeddingClassifier
.create_average_word_embedding_classifier(train_data, validation_data,
options,
train_data.label_names))
text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
train_data, validation_data, options
)
else:
raise ValueError(f"Unknown model {options.supported_model}")
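TextClassifier.create() now dispatches on options.supported_model alone and reads label names from the training dataset. A small usage sketch of the two option configurations that reach the two branches above (a sketch, not code from this change):

    from mediapipe.model_maker.python.text import text_classifier

    bert_options = text_classifier.TextClassifierOptions(
        supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER
    )
    avg_options = text_classifier.TextClassifierOptions(
        supported_model=text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
        model_options=text_classifier.AverageWordEmbeddingModelOptions(),
    )
    # TextClassifier.create(train_data, validation_data, options) picks the BERT or
    # average-word-embedding path from the options; label names now come from train_data.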
@@ -166,27 +164,7 @@ class TextClassifier(classifier.Classifier):
processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)

additional_metrics = []
if desired_precisions and len(data.label_names) == 2:
for precision in desired_precisions:
additional_metrics.append(
metrics.BinarySparseRecallAtPrecision(
precision, name=f"recall_at_precision_{precision}"
)
)
if desired_recalls and len(data.label_names) == 2:
for recall in desired_recalls:
additional_metrics.append(
metrics.BinarySparsePrecisionAtRecall(
recall, name=f"precision_at_recall_{recall}"
)
)
metric_functions = self._metric_functions + additional_metrics
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=metric_functions,
)
with self._hparams.get_strategy().scope():
return self._model.evaluate(dataset)

def export_model(
@@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):

@classmethod
def create_average_word_embedding_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
) -> "_AverageWordEmbeddingClassifier":
"""Creates, trains, and returns an Average Word Embedding classifier.

Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.

Returns:
An Average Word Embedding classifier.
@@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options
with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._loss_function = loss_functions.SparseFocalLoss(
self._hparams.gamma, self._num_classes
)
self._metric_functions = self._create_metrics()
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None

@classmethod
def create_bert_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_BertClassifier":
) -> "_BertClassifier":
"""Creates, trains, and returns a BERT-based classifier.

Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.

Returns:
A BERT-based classifier.
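The BERT path now trains with loss_functions.SparseFocalLoss(gamma, num_classes) instead of plain sparse categorical cross-entropy. That class is not shown in this diff; the standard focal-loss form it presumably follows is

    FL(p_t) = -(1 - p_t)^gamma * log(p_t)

so gamma = 0 reduces to ordinary cross-entropy, while a larger gamma (for example the 3.5 used in the new test later in this diff) down-weights well-classified examples and focuses training on hard ones.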
@@ -435,9 +411,59 @@ class _BertClassifier(TextClassifier):
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.downloaded_files.get_path(),
model_name=self._model_spec.name,
)
return (self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data))
return (
self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data),
)

def _create_metrics(self):
"""Creates metrics for training and evaluation.

The default metrics are accuracy, precision, and recall.

For binary classification tasks only (num_classes=2):
Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
the desired_precisions and desired_recalls fields in BertHParams.

Returns:
A list of tf.keras.Metric subclasses which can be used with model.compile
"""
metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
if self._num_classes == 2:
if self._hparams.desired_precisions:
for desired_precision in self._hparams.desired_precisions:
metric_functions.append(
metrics.BinarySparseRecallAtPrecision(
desired_precision,
name=f"recall_at_precision_{desired_precision}",
num_thresholds=1000,
)
)
if self._hparams.desired_recalls:
for desired_recall in self._hparams.desired_recalls:
metric_functions.append(
metrics.BinarySparseRecallAtPrecision(
desired_recall,
name=f"precision_at_recall_{desired_recall}",
num_thresholds=1000,
)
)
else:
if self._hparams.desired_precisions or self._hparams.desired_recalls:
raise ValueError(
"desired_recalls and desired_precisions parameters are binary"
" metrics and not supported for num_classes > 2. Found"
f" num_classes: {self._num_classes}"
)
return metric_functions

def _create_model(self):
"""Creates a BERT-based classifier model.
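Per the _create_metrics docstring, the extra binary-only metrics are driven by BertHParams. An illustrative configuration, mirroring the test added later in this diff (the threshold values are arbitrary examples):

    from mediapipe.model_maker.python.text import text_classifier

    hparams = text_classifier.BertHParams(
        desired_precisions=[0.9],  # adds recall_at_precision_0.9
        desired_recalls=[0.2],     # adds precision_at_recall_0.2
        gamma=3.5,                 # focal-loss focusing parameter
        epochs=1,
        batch_size=1,
        learning_rate=3e-5,
        distribution_strategy='off',
    )
    options = text_classifier.TextClassifierOptions(
        supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
        hparams=hparams,
    )
    # With more than two labels, these fields trigger the ValueError raised above.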
@@ -447,11 +473,20 @@ class _BertClassifier(TextClassifier):
"""
encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_word_ids",
),
input_mask=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_mask",
),
input_type_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_type_ids",
),
)
encoder = hub.KerasLayer(
self._model_spec.downloaded_files.get_path(),
@@ -493,16 +528,21 @@ class _BertClassifier(TextClassifier):
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr,
decay_steps=total_steps,
end_learning_rate=0.0,
power=1.0)
end_learning_rate=self._hparams.end_learning_rate,
power=1.0,
)
if warmup_steps:
lr_schedule = model_util.WarmUp(
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
warmup_steps=warmup_steps,
)
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
lr_schedule,
weight_decay=self._hparams.weight_decay,
epsilon=1e-6,
global_clipnorm=1.0,
)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"]
@@ -510,7 +550,7 @@ class _BertClassifier(TextClassifier):
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB(
lr_schedule,
weight_decay_rate=0.01,
weight_decay_rate=self._hparams.weight_decay,
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0,
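With power=1.0 the PolynomialDecay above is a straight line from initial_lr down to the new BertHParams.end_learning_rate over total_steps, optionally preceded by the package's WarmUp ramp, and both AdamW and LAMB now take their weight decay from BertHParams instead of a hard-coded 0.01. A worked example with illustrative numbers (not values from this change):

    import tensorflow as tf

    schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=3e-5, decay_steps=1000, end_learning_rate=0.0, power=1.0
    )
    print(schedule(0).numpy(), schedule(500).numpy(), schedule(1000).numpy())
    # -> roughly 3e-05 1.5e-05 0.0: a linear ramp down; WarmUp would scale the rate
    #    up linearly over the first warmup_steps before handing off to this schedule.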
@@ -84,8 +84,8 @@ def run(data_dir,
options)

# Gets evaluation results.
_, acc = model.evaluate(validation_data)
print('Eval accuracy: %f' % acc)
metrics = model.evaluate(validation_data)
print('Eval accuracy: %f' % metrics[1])

model.export_model(quantization_config=quantization_config)
model.export_labels(export_dir=options.hparams.export_dir)
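Because the classifier now compiles several metrics, model.evaluate() returns the full Keras results list rather than a (loss, accuracy) pair; index 0 is the loss and, with the default metrics, index 1 is accuracy. A sketch continuing the demo snippet above (`model` and `validation_data` are the objects already defined there):

    results = model.evaluate(validation_data)
    loss, accuracy = results[0], results[1]  # further entries: precision, recall, ...
    print('Eval accuracy: %f' % accuracy)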
@@ -16,17 +16,17 @@ import csv
import filecmp
import os
import tempfile
import unittest
from unittest import mock as unittest_mock

from absl.testing import parameterized
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils


@unittest.skip('b/275624089')
class TextClassifierTest(tf.test.TestCase):
class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):

_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifier.create(train_data, validation_data,
options))

_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
metrics = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy

# Test export_model
average_word_embedding_classifier.export_model()
@@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
filecmp.cmp(
output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
shallow=False))
shallow=False,
)
)

def test_create_and_train_bert(self):
@parameterized.named_parameters(
# Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
# dict(
# testcase_name='mobilebert',
# supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
# ),
dict(
testcase_name='exbert',
supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
),
)
def test_create_and_train_bert(self, supported_model):
train_data, validation_data = self._get_data()
options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
supported_model=supported_model,
model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2
),
@@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)

_, accuracy = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
metrics = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy

# Test export_model
bert_classifier.export_model()
@@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
)

def test_label_mismatch(self):
options = (
text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
options = text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
)
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
with self.assertRaisesRegex(
ValueError,
'Training data label names .* not equal to validation data label names'
'Training data label names .* not equal to validation data label names',
):
text_classifier.TextClassifier.create(train_data, validation_data,
options)
text_classifier.TextClassifier.create(
train_data, validation_data, options
)

def test_options_mismatch(self):
train_data, validation_data = self._get_data()

avg_options = (
text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
model_options=text_classifier.AverageWordEmbeddingModelOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
avg_options)
avg_options = text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
model_options=text_classifier.AverageWordEmbeddingModelOptions(),
)
with self.assertRaisesWithLiteralMatch(
ValueError,
'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.EXBERT_CLASSIFIER',
):
text_classifier.TextClassifier.create(
train_data, validation_data, avg_options
)

bert_options = (
text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=text_classifier.BertModelOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
bert_options)
bert_options = text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
),
model_options=text_classifier.BertModelOptions(),
)
with self.assertRaisesWithLiteralMatch(
ValueError,
'Expected a Bert Classifier(MobileBERT or EXBERT), got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
):
text_classifier.TextClassifier.create(
train_data, validation_data, bert_options
)

def test_bert_loss_and_metrics_creation(self):
train_data, validation_data = self._get_data()
supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
hparams = text_classifier.BertHParams(
desired_recalls=[0.2],
desired_precisions=[0.9],
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off',
gamma=3.5,
)
options = text_classifier.TextClassifierOptions(
supported_model=supported_model, hparams=hparams
)
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options
)
loss_fn = bert_classifier._loss_function
self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
self.assertEqual(loss_fn._gamma, 3.5)
self.assertEqual(loss_fn._num_classes, 2)
metric_names = [m.name for m in bert_classifier._metric_functions]
expected_metric_names = [
'accuracy',
'recall',
'precision',
'precision_at_recall_0.2',
'recall_at_precision_0.9',
]
self.assertCountEqual(metric_names, expected_metric_names)

# Non-binary data
tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
with self.assertRaisesWithLiteralMatch(
ValueError,
'desired_recalls and desired_precisions parameters are binary metrics'
' and not supported for num_classes > 2. Found num_classes: 3',
):
text_classifier.TextClassifier.create(data, data, options)


if __name__ == '__main__':
@@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset):
', '.join(label_names),
)
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names
dataset=image_label_ds,
label_names=label_names,
size=all_image_size,
)
@@ -13,7 +13,7 @@
# limitations under the License.

# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict binary and library compatibility macro.

licenses(["notice"])
@@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset):
len(valid_hand_data), len(label_names), ','.join(label_names)))
return Dataset(
dataset=hand_embedding_label_ds,
label_names=label_names,
size=len(valid_hand_data),
label_names=label_names)
)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python library rule.

licenses(["notice"])
@@ -15,28 +15,12 @@

import os
import random

from typing import List, Optional
import tensorflow as tf
import tensorflow_datasets as tfds

from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils


def _create_data(
name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
label_names: List[str]
) -> Optional[classification_dataset.ClassificationDataset]:
"""Creates a Dataset object from tfds data."""
if name not in data:
return None
data = data[name]
data = data.map(lambda a: (a['image'], a['label']))
size = info.splits[name].num_examples
return Dataset(data, size, label_names)


class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for image classifier."""

@@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names))
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names)
dataset=image_label_ds, label_names=label_names, size=all_image_size
)
@@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase):

def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
train_data, test_data = data.split(fraction=0.5)

self.assertLen(train_data, 2)
@@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
ds = tf.data.Dataset.from_generator(
self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
[self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
['cyan', 'magenta', 'yellow'])
data = image_classifier.Dataset(
ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3
)
return data

def setUp(self):
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict binary and library compatibility macro.
# Placeholder for internal Python strict test compatibility macro.

licenses(["notice"])
@@ -54,6 +54,7 @@ py_library(
srcs = ["dataset.py"],
deps = [
":dataset_util",
"//mediapipe/model_maker/python/core/data:cache_files",
"//mediapipe/model_maker/python/core/data:classification_dataset",
],
)
@@ -73,6 +74,7 @@ py_test(
py_library(
name = "dataset_util",
srcs = ["dataset_util.py"],
deps = ["//mediapipe/model_maker/python/core/data:cache_files"],
)

py_test(
@@ -16,8 +16,8 @@
from typing import Optional

import tensorflow as tf
import yaml

from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.object_detector import dataset_util
from official.vision.dataloaders import tf_example_decoder
@@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset):
ValueError: If the label_name for id 0 is set to something other than
the 'background' class.
"""
cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
if not dataset_util.is_cached(cache_files):
tfrecord_cache_files = dataset_util.get_cache_files_coco(
data_dir, cache_dir
)
if not tfrecord_cache_files.is_cached():
label_map = dataset_util.get_label_map_coco(data_dir)
cache_writer = dataset_util.COCOCacheFilesWriter(
label_map=label_map, max_num_images=max_num_images
)
cache_writer.write_files(cache_files, data_dir)
return cls.from_cache(cache_files.cache_prefix)
cache_writer.write_files(tfrecord_cache_files, data_dir)
return cls.from_cache(tfrecord_cache_files)

@classmethod
def from_pascal_voc_folder(
@@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset):
Raises:
ValueError: if the input data directory is empty.
"""
cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
if not dataset_util.is_cached(cache_files):
tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc(
data_dir, cache_dir
)
if not tfrecord_cache_files.is_cached():
label_map = dataset_util.get_label_map_pascal_voc(data_dir)
cache_writer = dataset_util.PascalVocCacheFilesWriter(
label_map=label_map, max_num_images=max_num_images
)
cache_writer.write_files(cache_files, data_dir)
cache_writer.write_files(tfrecord_cache_files, data_dir)

return cls.from_cache(cache_files.cache_prefix)
return cls.from_cache(tfrecord_cache_files)

@classmethod
def from_cache(cls, cache_prefix: str) -> 'Dataset':
def from_cache(
cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles
) -> 'Dataset':
"""Loads the TFRecord data from cache.

Args:
cache_prefix: The cache prefix including the cache directory and the cache
prefix filename, e.g: '/tmp/cache/train'.
tfrecord_cache_files: The TFRecordCacheFiles object containing the already
cached TFRecord and metadata files.

Returns:
ObjectDetectorDataset object.

Raises:
ValueError if tfrecord_cache_files are not already cached.
"""
# Get TFRecord Files
tfrecord_file_pattern = cache_prefix + '*.tfrecord'
matched_files = tf.io.gfile.glob(tfrecord_file_pattern)
if not matched_files:
raise ValueError('TFRecord files are empty.')
if not tfrecord_cache_files.is_cached():
raise ValueError(
'Cache files must be already cached to use the from_cache method.'
)

# Load meta_data.
meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
if not tf.io.gfile.exists(meta_data_file):
raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
with tf.io.gfile.GFile(meta_data_file, 'r') as f:
meta_data = yaml.load(f, Loader=yaml.FullLoader)
metadata = tfrecord_cache_files.load_metadata()

dataset = tf.data.TFRecordDataset(matched_files)
dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files)
decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)

label_map = meta_data['label_map']
label_map = metadata['label_map']
label_names = [label_map[k] for k in sorted(label_map.keys())]

return Dataset(
dataset=dataset, size=meta_data['size'], label_names=label_names
dataset=dataset, label_names=label_names, size=metadata['size']
)
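from_coco_folder and from_pascal_voc_folder now build a TFRecordCacheFiles handle, write the TFRecord/metadata cache only when it is missing, and hand the handle itself (not a prefix string) to from_cache. A hedged usage sketch; the data and cache paths are placeholders, and the keyword names follow the signatures shown in this diff:

    from mediapipe.model_maker.python.vision.object_detector import dataset
    from mediapipe.model_maker.python.vision.object_detector import dataset_util

    data = dataset.Dataset.from_coco_folder(
        '/path/to/coco', cache_dir='/tmp/od_cache'  # placeholder paths
    )  # first run writes the TFRecord shards + metadata; later runs reuse them

    # Reloading straight from the cache, without re-parsing the image folder:
    cache = dataset_util.get_cache_files_coco('/path/to/coco', '/tmp/od_cache')
    if cache.is_cached():
      data = dataset.Dataset.from_cache(cache)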
@@ -15,25 +15,20 @@

import abc
import collections
import dataclasses
import hashlib
import json
import math
import os
import tempfile
from typing import Any, Dict, List, Mapping, Optional, Sequence
from typing import Any, Dict, List, Mapping, Optional
import xml.etree.ElementTree as ET

import tensorflow as tf
import yaml

from mediapipe.model_maker.python.core.data import cache_files
from official.vision.data import tfrecord_lib


# Suffix of the meta data file name.
META_DATA_FILE_SUFFIX = '_meta_data.yaml'


def _xml_get(node: ET.Element, name: str) -> ET.Element:
"""Gets a named child from an XML Element node.

@@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str:
return os.path.basename(os.path.abspath(data_dir))


@dataclasses.dataclass(frozen=True)
class CacheFiles:
"""Cache files for object detection."""

cache_prefix: str
tfrecord_files: Sequence[str]
meta_data_file: str


def _get_cache_files(
cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
) -> CacheFiles:
) -> cache_files.TFRecordCacheFiles:
"""Creates an object of CacheFiles class.

Args:
@@ -96,28 +82,16 @@ def _get_cache_files(
An object of CacheFiles class.
"""
cache_dir = _get_cache_dir_or_create(cache_dir)
# The cache prefix including the cache directory and the cache prefix
# filename, e.g: '/tmp/cache/train'.
cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
tf.compat.v1.logging.info(
'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
% (cache_dir, cache_prefix_filename, cache_prefix)
)

# Cached files including the TFRecord files and the meta data file.
tfrecord_files = [
cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
for i in range(num_shards)
]
meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
return CacheFiles(
cache_prefix=cache_prefix,
tfrecord_files=tuple(tfrecord_files),
meta_data_file=meta_data_file,
return cache_files.TFRecordCacheFiles(
cache_prefix_filename=cache_prefix_filename,
cache_dir=cache_dir,
num_shards=num_shards,
)


def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
def get_cache_files_coco(
data_dir: str, cache_dir: str
) -> cache_files.TFRecordCacheFiles:
"""Creates an object of CacheFiles class using a COCO formatted dataset.

Args:
@@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)


def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
def get_cache_files_pascal_voc(
data_dir: str, cache_dir: str
) -> cache_files.TFRecordCacheFiles:
"""Gets an object of CacheFiles using a PASCAL VOC formatted dataset.

Args:
@@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)


def is_cached(cache_files: CacheFiles) -> bool:
"""Checks whether cache files are already cached."""
all_cached_files = list(cache_files.tfrecord_files) + [
cache_files.meta_data_file
]
return all(tf.io.gfile.exists(path) for path in all_cached_files)


class CacheFilesWriter(abc.ABC):
"""CacheFilesWriter class to write the cached files."""

@@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC):
self.label_map = label_map
self.max_num_images = max_num_images

def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
"""Writes TFRecord and meta_data files.
def write_files(
self,
tfrecord_cache_files: cache_files.TFRecordCacheFiles,
*args,
**kwargs,
) -> None:
"""Writes TFRecord and metadata files.

Args:
cache_files: CacheFiles object including a list of TFRecord files and the
meta data yaml file to save the meta_data including data size and
label_map.
tfrecord_cache_files: TFRecordCacheFiles object including a list of
TFRecord files and the meta data yaml file to save the metadata
including data size and label_map.
*args: Non-keyword of parameters used in the `_get_example` method.
**kwargs: Keyword parameters used in the `_get_example` method.
"""
writers = [
tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
]
writers = tfrecord_cache_files.get_writers()

# Writes tf.Example into TFRecord files.
size = 0
@@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC):
for writer in writers:
writer.close()

# Writes meta_data into meta_data_file.
meta_data = {'size': size, 'label_map': self.label_map}
with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
yaml.dump(meta_data, f)
# Writes metadata into metadata_file.
metadata = {'size': size, 'label_map': self.label_map}
tfrecord_cache_files.save_metadata(metadata)

@abc.abstractmethod
def _get_example(self, *args, **kwargs):
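The writer path now delegates shard handling and metadata persistence to TFRecordCacheFiles. A sketch of that interface as it is used in this file; only methods and fields that appear in this diff are exercised, and the cache directory is a placeholder:

    from mediapipe.model_maker.python.core.data import cache_files

    cf = cache_files.TFRecordCacheFiles(
        cache_prefix_filename='train', cache_dir='/tmp/od_cache', num_shards=2
    )
    if not cf.is_cached():
      writers = cf.get_writers()  # one tf.io.TFRecordWriter per shard
      # ... serialize tf.Example protos round-robin across `writers` ...
      for writer in writers:
        writer.close()
      cf.save_metadata({'size': 0, 'label_map': {}})  # YAML stored next to the shards
    metadata = cf.load_metadata()  # -> {'size': ..., 'label_map': ...}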
@@ -19,7 +19,6 @@ import shutil
from unittest import mock as unittest_mock

import tensorflow as tf
import yaml

from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import dataset_util
@@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase):

def _assert_cache_files_equal(self, cf1, cf2):
self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)
self.assertEqual(cf1.num_shards, cf2.num_shards)

def _assert_cache_files_not_equal(self, cf1, cf2):
self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)

def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
@@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertEqual(
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
)
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')

def test_matching_get_cache_files_coco(self):
cache_dir = self.create_tempdir()
@@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertEqual(
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
)
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')

def test_matching_get_cache_files_pascal_voc(self):
cache_dir = self.create_tempdir()
@@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase):
cache_files = dataset_util.get_cache_files_coco(
tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
)
self.assertFalse(dataset_util.is_cached(cache_files))
self.assertFalse(cache_files.is_cached())
with open(cache_files.tfrecord_files[0], 'w') as f:
f.write('test')
self.assertFalse(dataset_util.is_cached(cache_files))
with open(cache_files.meta_data_file, 'w') as f:
self.assertFalse(cache_files.is_cached())
with open(cache_files.metadata_file, 'w') as f:
f.write('test')
self.assertTrue(dataset_util.is_cached(cache_files))
self.assertTrue(cache_files.is_cached())

def test_get_label_map_coco(self):
coco_dir = tasks_test_utils.get_test_data_path('coco_data')
@@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase):
self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)

# Checks the meta_data file
self.assertTrue(os.path.isfile(cache_files.meta_data_file))
self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
# Size is 3 because some examples are skipped for having poor bboxes
self.assertEqual(meta_data_dict['size'], expected_size)
# Checks the metadata file
self.assertTrue(os.path.isfile(cache_files.metadata_file))
self.assertGreater(os.path.getsize(cache_files.metadata_file), 0)
metadata_dict = cache_files.load_metadata()
self.assertEqual(metadata_dict['size'], expected_size)

def test_coco_cache_files_writer(self):
tempdir = self.create_tempdir()