Merge branch 'master' into nguyencse/facemeshioslib

commit 1001ead358

WORKSPACE (14 changed lines)
@@ -157,22 +157,22 @@ http_archive(
 # 2020-08-21
 http_archive(
 name = "com_github_glog_glog",
-strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
-sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
+strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
+sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
 urls = [
-"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
+"https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
 ],
 )
 http_archive(
 name = "com_github_glog_glog_no_gflags",
-strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6",
-sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab",
+strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
+sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
 build_file = "@//third_party:glog_no_gflags.BUILD",
 urls = [
-"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip",
+"https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
 ],
 patches = [
-"@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff",
+"@//third_party:com_github_glog_glog.diff",
 ],
 patch_args = [
 "-p1",

mediapipe/BUILD (151 changed lines)
@@ -68,30 +68,108 @@ config_setting(
 visibility = ["//visibility:public"],
 )

-# Note: this cannot just match "apple_platform_type": "macos" because that option
-# defaults to "macos" even when building on Linux!
-alias(
+# Generic MacOS.
+config_setting(
 name = "macos",
-actual = select({
-":macos_i386": ":macos_i386",
-":macos_x86_64": ":macos_x86_64",
-":macos_arm64": ":macos_arm64",
-"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
-}),
+constraint_values = [
+"@platforms//os:macos",
+],
 visibility = ["//visibility:public"],
 )

-# Note: this also matches on crosstool_top so that it does not produce ambiguous
-# selectors when used together with "android".
+# MacOS x86 64-bit.
+config_setting(
+name = "macos_x86_64",
+constraint_values = [
+"@platforms//os:macos",
+"@platforms//cpu:x86_64",
+],
+visibility = ["//visibility:public"],
+)
+
+# MacOS ARM64.
+config_setting(
+name = "macos_arm64",
+constraint_values = [
+"@platforms//os:macos",
+"@platforms//cpu:arm64",
+],
+visibility = ["//visibility:public"],
+)
+
+# Generic iOS.
 config_setting(
 name = "ios",
-values = {
-"crosstool_top": "@bazel_tools//tools/cpp:toolchain",
-"apple_platform_type": "ios",
-},
+constraint_values = [
+"@platforms//os:ios",
+],
 visibility = ["//visibility:public"],
 )

+# iOS device ARM32.
+config_setting(
+name = "ios_armv7",
+constraint_values = [
+"@platforms//os:ios",
+"@platforms//cpu:arm",
+],
+visibility = ["//visibility:public"],
+)
+
+# iOS device ARM64.
+config_setting(
+name = "ios_arm64",
+constraint_values = [
+"@platforms//os:ios",
+"@platforms//cpu:arm64",
+],
+visibility = ["//visibility:public"],
+)
+
+# iOS device ARM64E.
+config_setting(
+name = "ios_arm64e",
+constraint_values = [
+"@platforms//os:ios",
+"@platforms//cpu:arm64e",
+],
+visibility = ["//visibility:public"],
+)
+
+# iOS simulator x86 32-bit.
+config_setting(
+name = "ios_i386",
+constraint_values = [
+"@platforms//os:ios",
+"@platforms//cpu:x86_32",
+"@build_bazel_apple_support//constraints:simulator",
+],
+visibility = ["//visibility:public"],
+)
+
+# iOS simulator x86 64-bit.
+config_setting(
+name = "ios_x86_64",
+constraint_values = [
+"@platforms//os:ios",
+"@platforms//cpu:x86_64",
+"@build_bazel_apple_support//constraints:simulator",
+],
+visibility = ["//visibility:public"],
+)
+
+# iOS simulator ARM64.
+config_setting(
+name = "ios_sim_arm64",
+constraint_values = [
+"@platforms//os:ios",
+"@platforms//cpu:arm64",
+"@build_bazel_apple_support//constraints:simulator",
+],
+visibility = ["//visibility:public"],
+)
+
+# Generic Apple.
 alias(
 name = "apple",
 actual = select({

@@ -102,49 +180,6 @@ alias(
 visibility = ["//visibility:public"],
 )

-config_setting(
-name = "macos_i386",
-values = {
-"apple_platform_type": "macos",
-"cpu": "darwin",
-},
-visibility = ["//visibility:public"],
-)
-
-config_setting(
-name = "macos_x86_64",
-values = {
-"apple_platform_type": "macos",
-"cpu": "darwin_x86_64",
-},
-visibility = ["//visibility:public"],
-)
-
-config_setting(
-name = "macos_arm64",
-values = {
-"apple_platform_type": "macos",
-"cpu": "darwin_arm64",
-},
-visibility = ["//visibility:public"],
-)
-
-[
-config_setting(
-name = arch,
-values = {"cpu": arch},
-visibility = ["//visibility:public"],
-)
-for arch in [
-"ios_i386",
-"ios_x86_64",
-"ios_armv7",
-"ios_arm64",
-"ios_arm64e",
-"ios_sim_arm64",
-]
-]
-
 config_setting(
 name = "windows",
 values = {"cpu": "x64_windows"},

@@ -210,6 +210,23 @@ REGISTER_CALCULATOR(SpectrogramCalculator);
 // Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
 const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;

+namespace {
+std::unique_ptr<audio_dsp::WindowFunction> MakeWindowFun(
+const SpectrogramCalculatorOptions::WindowType window_type) {
+switch (window_type) {
+// The cosine window and square root of Hann are equivalent.
+case SpectrogramCalculatorOptions::COSINE:
+case SpectrogramCalculatorOptions::SQRT_HANN:
+return std::make_unique<audio_dsp::CosineWindow>();
+case SpectrogramCalculatorOptions::HANN:
+return std::make_unique<audio_dsp::HannWindow>();
+case SpectrogramCalculatorOptions::HAMMING:
+return std::make_unique<audio_dsp::HammingWindow>();
+}
+return nullptr;
+}
+} // namespace
+
 absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
 SpectrogramCalculatorOptions spectrogram_options =
 cc->Options<SpectrogramCalculatorOptions>();

@@ -266,28 +283,14 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {

 output_scale_ = spectrogram_options.output_scale();

-std::vector<double> window;
-switch (spectrogram_options.window_type()) {
-case SpectrogramCalculatorOptions::COSINE:
-audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_,
-&window);
-break;
-case SpectrogramCalculatorOptions::HANN:
-audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
-&window);
-break;
-case SpectrogramCalculatorOptions::HAMMING:
-audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
-&window);
-break;
-case SpectrogramCalculatorOptions::SQRT_HANN: {
-audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
-&window);
-absl::c_transform(window, window.begin(),
-[](double x) { return std::sqrt(x); });
-break;
-}
+auto window_fun = MakeWindowFun(spectrogram_options.window_type());
+if (window_fun == nullptr) {
+return absl::Status(absl::StatusCode::kInvalidArgument,
+absl::StrCat("Invalid window type ",
+spectrogram_options.window_type()));
 }
+std::vector<double> window;
+window_fun->GetPeriodicSamples(frame_duration_samples_, &window);

 // Propagate settings down to the actual Spectrogram object.
 spectrogram_generators_.clear();

@@ -68,7 +68,7 @@ message SpectrogramCalculatorOptions {
 HANN = 0;
 HAMMING = 1;
 COSINE = 2;
-SQRT_HANN = 4;
+SQRT_HANN = 4; // Alias of COSINE.
 }
 optional WindowType window_type = 6 [default = HANN];
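
The comment in the new MakeWindowFun helper above states that the cosine window and the square root of a Hann window are equivalent, which is why COSINE and SQRT_HANN now share one case, and kLnSquaredMagnitudeToDb is documented as 10.0/ln(10.0). A standalone C++ sketch, not part of this commit (the frame length N is illustrative), that checks both claims numerically:

// Standalone check (not from the diff): sqrt(periodic Hann) equals the sine
// ("cosine") window, and 10/ln(10) matches kLnSquaredMagnitudeToDb.
#include <cmath>
#include <cstdio>

int main() {
  const double kPi = std::acos(-1.0);
  const int N = 8;  // Illustrative; the calculator uses frame_duration_samples_.
  for (int n = 0; n < N; ++n) {
    // Periodic Hann sample: 0.5 * (1 - cos(2*pi*n/N)) == sin^2(pi*n/N).
    double hann = 0.5 * (1.0 - std::cos(2.0 * kPi * n / N));
    double cosine_window = std::sin(kPi * n / N);  // the "cosine" window sample
    std::printf("n=%d sqrt(hann)=%.6f cosine=%.6f\n", n, std::sqrt(hann), cosine_window);
  }
  std::printf("10/ln(10) = %.15f\n", 10.0 / std::log(10.0));  // ~4.342944819032518
  return 0;
}
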

@@ -381,17 +381,6 @@ cc_library(
 alwayslink = 1,
 )

-cc_library(
-name = "clip_detection_vector_size_calculator",
-srcs = ["clip_detection_vector_size_calculator.cc"],
-deps = [
-":clip_vector_size_calculator",
-"//mediapipe/framework:calculator_framework",
-"//mediapipe/framework/formats:detection_cc_proto",
-],
-alwayslink = 1,
-)
-
 cc_test(
 name = "clip_vector_size_calculator_test",
 srcs = ["clip_vector_size_calculator_test.cc"],

@@ -76,4 +76,9 @@ REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator);
 // A calculator to process std::vector<mediapipe::Image>.
 typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator;
 REGISTER_CALCULATOR(BeginLoopImageCalculator);
+
+// A calculator to process std::vector<float>.
+typedef BeginLoopCalculator<std::vector<float>> BeginLoopFloatCalculator;
+REGISTER_CALCULATOR(BeginLoopFloatCalculator);
+
 } // namespace mediapipe

@@ -123,7 +123,10 @@ class PreviousLoopbackCalculator : public Node {
 // However, LOOP packet is empty.
 kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
 } else {
-kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
+// Avoids sending leftovers to a stream that's already closed.
+if (!kPrevLoop(cc).IsClosed()) {
+kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
+}
 }
 loop_packets_.pop_front();
 main_packet_specs_.pop_front();

@@ -135,7 +135,6 @@ cc_library(
 deps = [
 "//mediapipe/framework:calculator_framework",
 "//mediapipe/framework/formats:image_frame_opencv",
-"//mediapipe/framework/port:opencv_imgcodecs",
 "//mediapipe/framework/port:opencv_imgproc",
 "//mediapipe/framework/port:status",
 ],

@@ -112,7 +112,7 @@ class BilateralFilterCalculator : public CalculatorBase {
 REGISTER_CALCULATOR(BilateralFilterCalculator);

 absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) {
-CHECK_GE(cc->Inputs().NumEntries(), 1);
+RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

 if (cc->Inputs().HasTag(kInputFrameTag) &&
 cc->Inputs().HasTag(kInputFrameTagGpu)) {

@@ -110,7 +110,7 @@ REGISTER_CALCULATOR(SegmentationSmoothingCalculator);

 absl::Status SegmentationSmoothingCalculator::GetContract(
 CalculatorContract* cc) {
-CHECK_GE(cc->Inputs().NumEntries(), 1);
+RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

 cc->Inputs().Tag(kCurrentMaskTag).Set<Image>();
 cc->Inputs().Tag(kPreviousMaskTag).Set<Image>();

@@ -142,7 +142,7 @@ class SetAlphaCalculator : public CalculatorBase {
 REGISTER_CALCULATOR(SetAlphaCalculator);

 absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) {
-CHECK_GE(cc->Inputs().NumEntries(), 1);
+RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

 bool use_gpu = false;


@@ -38,7 +38,7 @@ std::string FourCCToString(libyuv::FourCC fourcc) {
 buf[0] = (fourcc >> 24) & 0xff;
 buf[1] = (fourcc >> 16) & 0xff;
 buf[2] = (fourcc >> 8) & 0xff;
-buf[3] = (fourcc)&0xff;
+buf[3] = (fourcc) & 0xff;
 buf[4] = 0;
 return std::string(buf);
 }

@@ -282,18 +282,23 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
 if (options.has_volume_gain_db()) {
 gain_ = pow(10, options.volume_gain_db() / 20.0);
 }
-RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
-!kAudioIn(cc).Header().IsEmpty())
-<< "Must either specify the time series header of the \"AUDIO\" stream "
-"or have the \"SAMPLE_RATE\" stream connected.";
-if (!kAudioIn(cc).Header().IsEmpty()) {
-mediapipe::TimeSeriesHeader input_header;
-MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
-kAudioIn(cc).Header(), &input_header));
-if (stream_mode_) {
-MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
-} else {
-source_sample_rate_ = input_header.sample_rate();
+if (options.has_source_sample_rate()) {
+source_sample_rate_ = options.source_sample_rate();
+} else {
+RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
+!kAudioIn(cc).Header().IsEmpty())
+<< "Must either specify the time series header of the \"AUDIO\" stream "
+"or have the \"SAMPLE_RATE\" stream connected.";
+if (!kAudioIn(cc).Header().IsEmpty()) {
+mediapipe::TimeSeriesHeader input_header;
+MP_RETURN_IF_ERROR(
+mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
+kAudioIn(cc).Header(), &input_header));
+if (stream_mode_) {
+MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
+} else {
+source_sample_rate_ = input_header.sample_rate();
+}
 }
 }
 AppendZerosToSampleBuffer(padding_samples_before_);

@@ -85,4 +85,7 @@ message AudioToTensorCalculatorOptions {
 // The volume gain, measured in dB.
 // Scale the input audio amplitude by 10^(volume_gain_db/20).
 optional double volume_gain_db = 12;
+
+// The source number of samples per second (hertz) of the input audio buffers.
+optional double source_sample_rate = 13;
 }
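
The comment above documents the decibel-to-amplitude mapping the calculator applies (gain_ = pow(10, options.volume_gain_db() / 20.0) in Open(), shown in the earlier hunk). A standalone sketch, not part of this commit, that tabulates a few values of that formula:

// Standalone sketch (not from the diff): dB -> linear amplitude gain,
// the same formula as gain_ = pow(10, volume_gain_db / 20.0).
#include <cmath>
#include <cstdio>

int main() {
  const double dbs[] = {-20.0, -6.0, 0.0, 6.0, 20.0};
  for (double db : dbs) {
    // -20 dB -> 0.1x, 0 dB -> 1x, +6 dB -> ~2x, +20 dB -> 10x.
    std::printf("%+6.1f dB -> gain %.4f\n", db, std::pow(10.0, db / 20.0));
  }
  return 0;
}
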

@@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl
 gpu_delegate_options);
 absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
 absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
+bool UseSerializedModel() const { return use_serialized_model_; }

 private:
 bool use_kernel_caching_ = false;

@@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
 }

 absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
-MP_RETURN_IF_ERROR(
-on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
 return gpu_helper_.RunInGlContext([this]() -> absl::Status {
 tflite_gpu_runner_.reset();
 return absl::OkStatus();

@@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
 tflite_gpu_runner_->GetOutputShapes()[i].c};
 }

+if (on_disk_cache_helper_.UseSerializedModel()) {
+tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
+}
+
 MP_RETURN_IF_ERROR(
 on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
-return tflite_gpu_runner_->Build();
+MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
+return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
 }

 #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

@@ -256,6 +256,7 @@ class TensorsToDetectionsCalculator : public Node {

 bool gpu_inited_ = false;
 bool gpu_input_ = false;
+bool gpu_has_enough_work_groups_ = true;
 bool anchors_init_ = false;
 };
 MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);

@@ -291,7 +292,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
 absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
 auto output_detections = absl::make_unique<std::vector<Detection>>();
 bool gpu_processing = false;
-if (CanUseGpu()) {
+if (CanUseGpu() && gpu_has_enough_work_groups_) {
 // Use GPU processing only if at least one input tensor is already on GPU
 // (to avoid CPU->GPU overhead).
 for (const auto& tensor : *kInTensors(cc)) {

@@ -321,11 +322,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
 RET_CHECK(!has_custom_box_indices_);
 }

-if (gpu_processing) {
-if (!gpu_inited_) {
-MP_RETURN_IF_ERROR(GpuInit(cc));
+if (gpu_processing && !gpu_inited_) {
+auto status = GpuInit(cc);
+if (status.ok()) {
 gpu_inited_ = true;
+} else if (status.code() == absl::StatusCode::kFailedPrecondition) {
+// For initialization error because of hardware limitation, fallback to
+// CPU processing.
+LOG(WARNING) << status.message();
+} else {
+// For other error, let the error propagates.
+return status;
 }
+}
+if (gpu_processing && gpu_inited_) {
 MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
 } else {
 MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
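
The hunk above turns a failed GpuInit() into a CPU fallback when the status code is kFailedPrecondition (the hardware-limitation case added in the later work-group hunks), while any other error still propagates. A standalone sketch, not part of this commit, of the same pattern; TryGpuInit, RunOnGpu and RunOnCpu are hypothetical stand-ins:

// Standalone sketch (not from the diff). TryGpuInit, RunOnGpu and RunOnCpu are
// hypothetical placeholders; only the status-based fallback mirrors the hunk.
#include <cstdio>
#include <string>
#include "absl/status/status.h"

absl::Status TryGpuInit(bool hardware_ok) {
  return hardware_ok ? absl::OkStatus()
                     : absl::FailedPreconditionError(
                           "num_classes exceeds the max work_group size");
}

absl::Status RunOnGpu() { return absl::OkStatus(); }
absl::Status RunOnCpu() { return absl::OkStatus(); }

absl::Status Process(bool hardware_ok) {
  bool gpu_inited = false;
  absl::Status status = TryGpuInit(hardware_ok);
  if (status.ok()) {
    gpu_inited = true;
  } else if (status.code() == absl::StatusCode::kFailedPrecondition) {
    // Hardware limitation: warn and keep going on the CPU path.
    std::printf("WARNING: %s\n", std::string(status.message()).c_str());
  } else {
    return status;  // Any other error still propagates to the caller.
  }
  return gpu_inited ? RunOnGpu() : RunOnCpu();
}
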

@@ -346,17 +356,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
 // TODO: Add flexible input tensor size handling.
 auto raw_box_tensor =
 &input_tensors[tensor_mapping_.detections_tensor_index()];
-RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
-RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
 RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
-RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
-RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+if (raw_box_tensor->shape().dims.size() == 3) {
+// The tensors from CPU inference has dim 3.
+RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
+RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+} else if (raw_box_tensor->shape().dims.size() == 4) {
+// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
+// we allow tensors with 4 dims.
+RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1);
+RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_);
+RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_);
+} else {
+return absl::InvalidArgumentError(
+"The dimensions of box Tensor must be 3 or 4.");
+}
 auto raw_score_tensor =
 &input_tensors[tensor_mapping_.scores_tensor_index()];
-RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
-RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
-RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
-RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+if (raw_score_tensor->shape().dims.size() == 3) {
+// The tensors from CPU inference has dim 3.
+RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
+RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+} else if (raw_score_tensor->shape().dims.size() == 4) {
+// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
+// we allow tensors with 4 dims.
+RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1);
+RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_);
+RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_);
+} else {
+return absl::InvalidArgumentError(
+"The dimensions of score Tensor must be 3 or 4.");
+}
 auto raw_box_view = raw_box_tensor->GetCpuReadView();
 auto raw_boxes = raw_box_view.buffer<float>();
 auto raw_scores_view = raw_score_tensor->GetCpuReadView();

@@ -1111,8 +1145,13 @@ void main() {
 int max_wg_size; // typically <= 1024
 glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
 &max_wg_size); // y-dim
-CHECK_LT(num_classes_, max_wg_size)
-<< "# classes must be < " << max_wg_size;
+gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
+if (!gpu_has_enough_work_groups_) {
+return absl::FailedPreconditionError(absl::StrFormat(
+"Hardware limitation: Processing will be done on CPU, because "
+"num_classes %d exceeds the max work_group size %d.",
+num_classes_, max_wg_size));
+}
 // TODO support better filtering.
 if (class_index_set_.is_allowlist) {
 CHECK_EQ(class_index_set_.values.size(),

@@ -1370,7 +1409,13 @@ kernel void scoreKernel(
 Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2});
 // # filter classes supported is hardware dependent.
 int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup;
-CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
+gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
+if (!gpu_has_enough_work_groups_) {
+return absl::FailedPreconditionError(absl::StrFormat(
+"Hardware limitation: Processing will be done on CPU, because "
+"num_classes %d exceeds the max work_group size %d.",
+num_classes_, max_wg_size));
+}
 }

 #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)

@@ -406,8 +406,13 @@ cc_library(
 alwayslink = 1,
 )

-# This dependency removed tensorflow_jellyfish_deps and xprofilez_with_server because they failed
-# Boq conformance test. Weigh your use case to see if this will work for you.
+# This dependency removed the following 3 targets because they failed Boq conformance test:
+#
+# tensorflow_jellyfish_deps
+# jfprof_lib
+# xprofilez_with_server
+#
+# If you need them plz consider tensorflow_inference_calculator_no_envelope_loader.
 cc_library(
 name = "tensorflow_inference_calculator_for_boq",
 srcs = ["tensorflow_inference_calculator.cc"],

@@ -927,7 +932,6 @@ cc_test(
 "//mediapipe/framework:timestamp",
 "//mediapipe/framework/formats:detection_cc_proto",
 "//mediapipe/framework/formats:image_frame",
-"//mediapipe/framework/formats:image_frame_opencv",
 "//mediapipe/framework/formats:location",
 "//mediapipe/framework/formats:location_opencv",
 "//mediapipe/framework/port:gtest_main",

@@ -164,8 +164,8 @@ class PackMediaSequenceCalculator : public CalculatorBase {
 }
 }

-CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
+RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
 cc->OutputSidePackets().HasTag(kSequenceExampleTag))
 << "Neither the output stream nor the output side packet is set to "
 "output the sequence example.";
 if (cc->Outputs().HasTag(kSequenceExampleTag)) {

@@ -23,7 +23,6 @@
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image_frame.h"
-#include "mediapipe/framework/formats/image_frame_opencv.h"
 #include "mediapipe/framework/formats/location.h"
 #include "mediapipe/framework/formats/location_opencv.h"
 #include "mediapipe/framework/port/gmock.h"

@@ -96,7 +95,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
 mpms::SetClipMediaId(test_video_id, input_sequence.get());
 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 OpenCvImageEncoderCalculatorResults encoded_image;
 encoded_image.set_encoded_image(bytes.data(), bytes.size());
 encoded_image.set_width(2);

@@ -139,7 +139,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
 mpms::SetClipMediaId(test_video_id, input_sequence.get());
 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 OpenCvImageEncoderCalculatorResults encoded_image;
 encoded_image.set_encoded_image(bytes.data(), bytes.size());
 encoded_image.set_width(2);

@@ -378,7 +379,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
 Adopt(input_sequence.release());
 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 OpenCvImageEncoderCalculatorResults encoded_image;
 encoded_image.set_encoded_image(bytes.data(), bytes.size());
 auto image_ptr =

@@ -410,7 +412,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {

 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 std::string test_flow_string(bytes.begin(), bytes.end());
 OpenCvImageEncoderCalculatorResults encoded_flow;
 encoded_flow.set_encoded_image(test_flow_string);

@@ -618,7 +621,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
 }
 cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 OpenCvImageEncoderCalculatorResults encoded_image;
 encoded_image.set_encoded_image(bytes.data(), bytes.size());
 encoded_image.set_width(width);

@@ -767,7 +771,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {

 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 std::string test_flow_string(bytes.begin(), bytes.end());
 OpenCvImageEncoderCalculatorResults encoded_flow;
 encoded_flow.set_encoded_image(test_flow_string);

@@ -813,7 +818,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
 mpms::SetClipMediaId(test_video_id, input_sequence.get());
 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 std::string test_flow_string(bytes.begin(), bytes.end());
 OpenCvImageEncoderCalculatorResults encoded_flow;
 encoded_flow.set_encoded_image(test_flow_string);

@@ -970,7 +976,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
 auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 OpenCvImageEncoderCalculatorResults encoded_image;
 encoded_image.set_encoded_image(bytes.data(), bytes.size());
 encoded_image.set_width(2);

@@ -1021,7 +1028,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
 auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
 cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
 std::vector<uchar> bytes;
-ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+ASSERT_TRUE(
+cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
 OpenCvImageEncoderCalculatorResults encoded_image;
 encoded_image.set_encoded_image(bytes.data(), bytes.size());
 int height = 2;

@@ -172,7 +172,7 @@ class AnnotationOverlayCalculator : public CalculatorBase {
 REGISTER_CALCULATOR(AnnotationOverlayCalculator);

 absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
-CHECK_GE(cc->Inputs().NumEntries(), 1);
+RET_CHECK_GE(cc->Inputs().NumEntries(), 1);

 bool use_gpu = false;

@@ -189,13 +189,13 @@ absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
 #if !MEDIAPIPE_DISABLE_GPU
 if (cc->Inputs().HasTag(kGpuBufferTag)) {
 cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
-CHECK(cc->Outputs().HasTag(kGpuBufferTag));
+RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag));
 use_gpu = true;
 }
 #endif // !MEDIAPIPE_DISABLE_GPU
 if (cc->Inputs().HasTag(kImageFrameTag)) {
 cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
-CHECK(cc->Outputs().HasTag(kImageFrameTag));
+RET_CHECK(cc->Outputs().HasTag(kImageFrameTag));
 }

 // Data streams to render.

@@ -322,27 +322,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
 options_.presence_threshold(), options_.connection_color(), thickness,
 /*normalized=*/false, render_data.get());
 }
-for (int i = 0; i < landmarks.landmark_size(); ++i) {
-const Landmark& landmark = landmarks.landmark(i);
+if (options_.render_landmarks()) {
+for (int i = 0; i < landmarks.landmark_size(); ++i) {
+const Landmark& landmark = landmarks.landmark(i);

 if (!IsLandmarkVisibleAndPresent<Landmark>(
 landmark, options_.utilize_visibility(),
 options_.visibility_threshold(), options_.utilize_presence(),
 options_.presence_threshold())) {
 continue;
 }

 auto* landmark_data_render = AddPointRenderData(
 options_.landmark_color(), thickness, render_data.get());
 if (visualize_depth) {
-SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
-options_.min_depth_circle_thickness(),
-options_.max_depth_circle_thickness());
+SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
+landmark_data_render,
+options_.min_depth_circle_thickness(),
+options_.max_depth_circle_thickness());
+}
+auto* landmark_data = landmark_data_render->mutable_point();
+landmark_data->set_normalized(false);
+landmark_data->set_x(landmark.x());
+landmark_data->set_y(landmark.y());
 }
-auto* landmark_data = landmark_data_render->mutable_point();
-landmark_data->set_normalized(false);
-landmark_data->set_x(landmark.x());
-landmark_data->set_y(landmark.y());
 }
 }


@@ -368,27 +371,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) {
 options_.presence_threshold(), options_.connection_color(), thickness,
 /*normalized=*/true, render_data.get());
 }
-for (int i = 0; i < landmarks.landmark_size(); ++i) {
-const NormalizedLandmark& landmark = landmarks.landmark(i);
+if (options_.render_landmarks()) {
+for (int i = 0; i < landmarks.landmark_size(); ++i) {
+const NormalizedLandmark& landmark = landmarks.landmark(i);

 if (!IsLandmarkVisibleAndPresent<NormalizedLandmark>(
 landmark, options_.utilize_visibility(),
 options_.visibility_threshold(), options_.utilize_presence(),
 options_.presence_threshold())) {
 continue;
 }

 auto* landmark_data_render = AddPointRenderData(
 options_.landmark_color(), thickness, render_data.get());
 if (visualize_depth) {
-SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render,
-options_.min_depth_circle_thickness(),
-options_.max_depth_circle_thickness());
+SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
+landmark_data_render,
+options_.min_depth_circle_thickness(),
+options_.max_depth_circle_thickness());
+}
+auto* landmark_data = landmark_data_render->mutable_point();
+landmark_data->set_normalized(true);
+landmark_data->set_x(landmark.x());
+landmark_data->set_y(landmark.y());
 }
-auto* landmark_data = landmark_data_render->mutable_point();
-landmark_data->set_normalized(true);
-landmark_data->set_x(landmark.x());
-landmark_data->set_y(landmark.y());
 }
 }


@@ -32,6 +32,10 @@ message LandmarksToRenderDataCalculatorOptions {

 // Color of the landmarks.
 optional Color landmark_color = 2;
+
+// Whether to render landmarks as points.
+optional bool render_landmarks = 14 [default = true];
+
 // Color of the connections.
 optional Color connection_color = 3;


@@ -124,7 +124,7 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RefineLandmarksFromHeatMap(
 int center_row = out_lms.landmark(lm_index).y() * hm_height;
 // Point is outside of the image let's keep it intact.
 if (center_col < 0 || center_col >= hm_width || center_row < 0 ||
-center_col >= hm_height) {
+center_row >= hm_height) {
 continue;
 }


@@ -130,7 +130,6 @@ cc_library(
 "//mediapipe/framework/formats:video_stream_header",
 "//mediapipe/framework/port:opencv_imgproc",
 "//mediapipe/framework/port:opencv_video",
-"//mediapipe/framework/port:ret_check",
 "//mediapipe/framework/port:status",
 "//mediapipe/framework/tool:status_util",
 ],

@@ -341,7 +340,6 @@ cc_test(
 "//mediapipe/framework/port:opencv_core",
 "//mediapipe/framework/port:parse_text_proto",
 "//mediapipe/framework/tool:test_util",
-"@com_google_absl//absl/flags:flag",
 ],
 )

@@ -367,7 +365,6 @@ cc_test(
 "//mediapipe/framework/port:opencv_video",
 "//mediapipe/framework/port:parse_text_proto",
 "//mediapipe/framework/tool:test_util",
-"@com_google_absl//absl/flags:flag",
 ],
 )

@@ -451,7 +448,6 @@ cc_test(
 "//mediapipe/framework/tool:test_util",
 "//mediapipe/util/tracking:box_tracker_cc_proto",
 "//mediapipe/util/tracking:tracking_cc_proto",
-"@com_google_absl//absl/flags:flag",
 ],
 )

Binary file not shown.

@@ -1,6 +1,6 @@
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip
 networkTimeout=10000
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "facedetectioncpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "facedetectiongpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "faceeffect",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "facemeshgpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "handdetectiongpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "handtrackinggpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "helloworld",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "holistictrackinggpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "iristrackinggpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "objectdetectioncpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "objectdetectiongpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "objectdetectiontrackinggpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "posetrackinggpu",

@@ -24,7 +24,7 @@ load(

 licenses(["notice"])

-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"

 alias(
 name = "selfiesegmentationgpu",

@@ -44,6 +44,9 @@ bzl_library(
 "encode_binary_proto.bzl",
 ],
 visibility = ["//visibility:public"],
+deps = [
+"@bazel_skylib//lib:paths",
+],
 )

 alias(

@@ -64,58 +64,13 @@ class CalculatorBaseFactoryFor<
 namespace api2 {
 namespace internal {

-// Defining a member of this type causes P to be ODR-used, which forces its
-// instantiation if it's a static member of a template.
-// Previously we depended on the pointer's value to determine whether the size
-// of a character array is 0 or 1, forcing it to be instantiated so the
-// compiler can determine the object's layout. But using it as a template
-// argument is more compact.
-template <auto* P>
-struct ForceStaticInstantiation {
-#ifdef _MSC_VER
-// Just having it as the template argument does not count as a use for
-// MSVC.
-static constexpr bool Use() { return P != nullptr; }
-char force_static[Use()];
-#endif // _MSC_VER
-};
+MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(
+NodeRegistrator, mediapipe::CalculatorBaseRegistry, T::kCalculatorName,
+absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>)

-// Helper template for forcing the definition of a static registration token.
-template <typename T>
-struct NodeRegistrationStatic {
-static NoDestructor<mediapipe::RegistrationToken> registration;
-
-static mediapipe::RegistrationToken Make() {
-return mediapipe::CalculatorBaseRegistry::Register(
-T::kCalculatorName,
-absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<T>>,
-__FILE__, __LINE__);
-}
-
-using RequireStatics = ForceStaticInstantiation<&registration>;
-};
-
-// Static members of template classes can be defined in the header.
-template <typename T>
-NoDestructor<mediapipe::RegistrationToken>
-NodeRegistrationStatic<T>::registration(NodeRegistrationStatic<T>::Make());
-
-template <typename T>
-struct SubgraphRegistrationImpl {
-static NoDestructor<mediapipe::RegistrationToken> registration;
-
-static mediapipe::RegistrationToken Make() {
-return mediapipe::SubgraphRegistry::Register(
-T::kCalculatorName, absl::make_unique<T>, __FILE__, __LINE__);
-}
-
-using RequireStatics = ForceStaticInstantiation<&registration>;
-};
-
-template <typename T>
-NoDestructor<mediapipe::RegistrationToken>
-SubgraphRegistrationImpl<T>::registration(
-SubgraphRegistrationImpl<T>::Make());
+MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(SubgraphRegistrator,
+mediapipe::SubgraphRegistry,
+T::kCalculatorName, absl::make_unique<T>)

 } // namespace internal


@@ -128,14 +83,7 @@ template <class Impl = void>
 class RegisteredNode;

 template <class Impl>
-class RegisteredNode : public Node {
-private:
-// The member below triggers instantiation of the registration static.
-// Note that the constructor of calculator subclasses is only invoked through
-// the registration token, and so we cannot simply use the static in the
-// constructor.
-typename internal::NodeRegistrationStatic<Impl>::RequireStatics register_;
-};
+class RegisteredNode : public Node, private internal::NodeRegistrator<Impl> {};

 // No-op version for backwards compatibility.
 template <>
|
@ -217,31 +165,27 @@ class NodeImpl : public RegisteredNode<Impl>, public Intf {
|
||||||
// TODO: verify that the subgraph config fully implements the
|
// TODO: verify that the subgraph config fully implements the
|
||||||
// declared interface.
|
// declared interface.
|
||||||
template <class Intf, class Impl>
|
template <class Intf, class Impl>
|
||||||
class SubgraphImpl : public Subgraph, public Intf {
|
class SubgraphImpl : public Subgraph,
|
||||||
private:
|
public Intf,
|
||||||
typename internal::SubgraphRegistrationImpl<Impl>::RequireStatics register_;
|
private internal::SubgraphRegistrator<Impl> {};
|
||||||
};
|
|
||||||
|
|
||||||
// This macro is used to register a calculator that does not use automatic
|
// This macro is used to register a calculator that does not use automatic
|
||||||
// registration. Deprecated.
|
// registration. Deprecated.
|
||||||
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
|
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
|
||||||
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
|
MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \
|
||||||
REGISTRY_STATIC_VAR(calculator_registration, \
|
mediapipe::CalculatorBaseRegistry, calculator_registration, \
|
||||||
__LINE__)(mediapipe::CalculatorBaseRegistry::Register( \
|
Impl::kCalculatorName, \
|
||||||
Impl::kCalculatorName, \
|
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>)
|
||||||
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>, \
|
|
||||||
__FILE__, __LINE__))
|
|
||||||
|
|
||||||
// This macro is used to register a non-split-contract calculator. Deprecated.
|
// This macro is used to register a non-split-contract calculator. Deprecated.
|
||||||
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)
|
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)
|
||||||
|
|
||||||
// This macro is used to define a subgraph that does not use automatic
|
// This macro is used to define a subgraph that does not use automatic
|
||||||
// registration. Deprecated.
|
// registration. Deprecated.
|
||||||
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
|
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
|
||||||
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
|
MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \
|
||||||
REGISTRY_STATIC_VAR(subgraph_registration, \
|
mediapipe::SubgraphRegistry, subgraph_registration, \
|
||||||
__LINE__)(mediapipe::SubgraphRegistry::Register( \
|
Impl::kCalculatorName, absl::make_unique<Impl>)
|
||||||
Impl::kCalculatorName, absl::make_unique<Impl>, __FILE__, __LINE__))
|
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
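For context, calculators written against the api2 Node interface are unaffected at the call site: deriving from Node (via NodeImpl/RegisteredNode, which now privately inherits internal::NodeRegistrator<Impl>) keeps auto-registration working. A minimal sketch of such a calculator, assuming the usual api2 header and accessor names:

#include "mediapipe/framework/api2/node.h"

namespace mediapipe {
namespace api2 {

// Sketch only: a pass-through calculator in the api2 style. Registration is
// wired up through the RegisteredNode / NodeRegistrator machinery shown above.
class MyPassThroughCalculator : public Node {
 public:
  static constexpr Input<int> kIn{"IN"};
  static constexpr Output<int> kOut{"OUT"};
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  absl::Status Process(CalculatorContext* cc) final {
    const int value = *kIn(cc);  // read the input packet payload
    kOut(cc).Send(value);        // forward it unchanged
    return absl::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(MyPassThroughCalculator);

}  // namespace api2
}  // namespace mediapipe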
@@ -183,8 +183,7 @@ TEST(CalculatorTest, CreateByNameWhitelisted) {
   CalculatorBaseRegistry::Register(
       "::mediapipe::test_ns::whitelisted_ns::DeadCalculator",
       absl::make_unique<internal::CalculatorBaseFactoryFor<
-          mediapipe::test_ns::whitelisted_ns::DeadCalculator>>,
-      __FILE__, __LINE__);
+          mediapipe::test_ns::whitelisted_ns::DeadCalculator>>);
 
   // A whitelisted calculator can be found in its own namespace.
   MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace(  //
@@ -16,7 +16,6 @@
 #define MEDIAPIPE_DEPS_REGISTRATION_H_
 
 #include <algorithm>
-#include <cstdint>
 #include <functional>
 #include <string>
 #include <tuple>
@@ -145,6 +144,23 @@ template <typename T>
 struct WrapStatusOr<absl::StatusOr<T>> {
   using type = absl::StatusOr<T>;
 };
 
+// Defining a member of this type causes P to be ODR-used, which forces its
+// instantiation if it's a static member of a template.
+// Previously we depended on the pointer's value to determine whether the size
+// of a character array is 0 or 1, forcing it to be instantiated so the
+// compiler can determine the object's layout. But using it as a template
+// argument is more compact.
+template <auto* P>
+struct ForceStaticInstantiation {
+#ifdef _MSC_VER
+  // Just having it as the template argument does not count as a use for
+  // MSVC.
+  static constexpr bool Use() { return P != nullptr; }
+  char force_static[Use()];
+#endif  // _MSC_VER
+};
+
 }  // namespace registration_internal
 
 class NamespaceAllowlist {
@@ -162,8 +178,7 @@ class FunctionRegistry {
   FunctionRegistry(const FunctionRegistry&) = delete;
   FunctionRegistry& operator=(const FunctionRegistry&) = delete;
 
-  RegistrationToken Register(absl::string_view name, Function func,
-                             std::string filename, uint64_t line)
+  RegistrationToken Register(absl::string_view name, Function func)
       ABSL_LOCKS_EXCLUDED(lock_) {
     std::string normalized_name = GetNormalizedName(name);
     absl::WriterMutexLock lock(&lock_);
 
@@ -173,21 +188,10 @@ class FunctionRegistry {
     }
     if (functions_.insert(std::make_pair(normalized_name, std::move(func)))
             .second) {
-#ifndef NDEBUG
-      locations_.emplace(normalized_name,
-                         std::make_pair(std::move(filename), line));
-#endif
       return RegistrationToken(
           [this, normalized_name]() { Unregister(normalized_name); });
     }
-#ifndef NDEBUG
-    LOG(FATAL) << "Function with name " << name << " already registered."
-               << " First registration at "
-               << locations_.at(normalized_name).first << ":"
-               << locations_.at(normalized_name).second;
-#else
     LOG(FATAL) << "Function with name " << name << " already registered.";
-#endif
     return RegistrationToken([]() {});
   }
 
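To make the simplified Register flow easier to follow, here is a self-contained sketch of a registry doing the same two things as the code above: store the factory under a lock and return a token that can later unregister it. It is illustrative only and omits the namespace handling, the NoDestructor token, and the duplicate-name CHECK of the real FunctionRegistry.

#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Token whose only job is to undo a registration when asked.
class RegistrationToken {
 public:
  explicit RegistrationToken(std::function<void()> unregister)
      : unregister_(std::move(unregister)) {}
  void Unregister() {
    if (unregister_) {
      unregister_();
      unregister_ = nullptr;
    }
  }

 private:
  std::function<void()> unregister_;
};

template <typename Function>
class SimpleFunctionRegistry {
 public:
  RegistrationToken Register(const std::string& name, Function func) {
    std::lock_guard<std::mutex> lock(mutex_);
    functions_.emplace(name, std::move(func));
    // The token captures only what it needs to remove the entry later.
    return RegistrationToken([this, name] { Unregister(name); });
  }

 private:
  void Unregister(const std::string& name) {
    std::lock_guard<std::mutex> lock(mutex_);
    functions_.erase(name);
  }

  std::mutex mutex_;
  std::unordered_map<std::string, Function> functions_;
};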
@@ -316,11 +320,6 @@ class FunctionRegistry {
  private:
   mutable absl::Mutex lock_;
   absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
-#ifndef NDEBUG
-  // Stores filename and line number for useful debug log.
-  absl::flat_hash_map<std::string, std::pair<std::string, uint32_t>> locations_
-      ABSL_GUARDED_BY(lock_);
-#endif
 
   // For names included in NamespaceAllowlist, strips the namespace.
   std::string GetAdjustedName(absl::string_view name) {
@@ -351,10 +350,8 @@ class GlobalFactoryRegistry {
 
  public:
   static RegistrationToken Register(absl::string_view name,
-                                    typename Functions::Function func,
-                                    std::string filename, uint64_t line) {
-    return functions()->Register(name, std::move(func), std::move(filename),
-                                 line);
+                                    typename Functions::Function func) {
+    return functions()->Register(name, std::move(func));
   }
 
   // Invokes the specified factory function and returns the result.
@@ -414,12 +411,77 @@ class GlobalFactoryRegistry {
 #define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...)  \
   static auto* REGISTRY_STATIC_VAR(registration_##name, __LINE__) =   \
       new mediapipe::RegistrationToken(                               \
-          RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
+          RegistryType::Register(#name, __VA_ARGS__))
+
+#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \
+                                                      name, ...)              \
+  static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) =                      \
+      new mediapipe::RegistrationToken(                                       \
+          RegistryType::Register(name, __VA_ARGS__))
 
+// TODO: migrate to the above.
 #define REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, name, ...) \
   static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) =                       \
       new mediapipe::RegistrationToken(                                        \
-          RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__))
+          RegistryType::Register(#name, __VA_ARGS__))
+
+// Defines a utility registrator class which can be used to automatically
+// register factory functions.
+//
+// Example:
+// === Defining a registry ================================================
+//
+//  class Component {};
+//
+//  using ComponentRegistry = GlobalFactoryRegistry<std::unique_ptr<Component>>;
+//
+// === Defining a registrator =============================================
+//
+//  MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(ComponentRegistrator,
+//                                        ComponentRegistry, T::kName,
+//                                        absl::make_unique<T>);
+//
+// === Defining and registering a new component. ==========================
+//
+//  class MyComponent : public Component,
+//                      private ComponentRegistrator<MyComponent> {
+//   public:
+//    static constexpr char kName[] = "MyComponent";
+//    ...
+//  };
+//
+// NOTE:
+// - MyComponent is automatically registered in ComponentRegistry by
+//   "MyComponent" name.
+// - Every component is require to provide its name (T::kName here.)
+#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType, \
+                                              name, ...)                     \
+  template <typename T>                                                      \
+  struct Internal##RegistratorName {                                         \
+    static NoDestructor<mediapipe::RegistrationToken> registration;          \
+                                                                              \
+    static mediapipe::RegistrationToken Make() {                             \
+      return RegistryType::Register(name, __VA_ARGS__);                      \
+    }                                                                         \
+                                                                              \
+    using RequireStatics =                                                    \
+        registration_internal::ForceStaticInstantiation<&registration>;      \
+  };                                                                          \
+  /* Static members of template classes can be defined in the header. */     \
+  template <typename T>                                                       \
+  NoDestructor<mediapipe::RegistrationToken>                                  \
+      Internal##RegistratorName<T>::registration(                            \
+          Internal##RegistratorName<T>::Make());                             \
+                                                                              \
+  template <typename T>                                                       \
+  class RegistratorName {                                                     \
+   private:                                                                   \
+    /* The member below triggers instantiation of the registration static. */ \
+    /* Note that the constructor of calculator subclasses is only invoked */ \
+    /* through the registration token, and so we cannot simply use the */    \
+    /* static in theconstructor. */                                          \
+    typename Internal##RegistratorName<T>::RequireStatics register_;         \
+  };
 
 }  // namespace mediapipe
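The macro above is easiest to read when expanded by hand. The following standalone sketch uses hypothetical Widget/WidgetRegistry types and a plain int instead of RegistrationToken, but shows the same mechanism: a static data member of a class template whose initializer performs the registration, plus a member whose type ODR-uses that static, so merely deriving from the registrator forces the registration to run during static initialization.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins for the registry machinery (illustrative only).
struct Widget {
  virtual ~Widget() = default;
};
struct WidgetRegistry {
  using Factory = std::function<std::unique_ptr<Widget>()>;
  static std::map<std::string, Factory>& Map() {
    static auto* m = new std::map<std::string, Factory>();
    return *m;
  }
  static int Register(const std::string& name, Factory factory) {
    Map()[name] = std::move(factory);
    return 0;
  }
};

// Same trick the macro generates: naming &registration as a template argument
// ODR-uses it, so instantiating the registrator forces the registration.
template <auto* P>
struct ForceStaticInstantiation {};

template <typename T>
struct InternalWidgetRegistrator {
  static int registration;
  using RequireStatics = ForceStaticInstantiation<&registration>;
};
template <typename T>
int InternalWidgetRegistrator<T>::registration = WidgetRegistry::Register(
    T::kName, [] { return std::make_unique<T>(); });

template <typename T>
class WidgetRegistrator {
 private:
  typename InternalWidgetRegistrator<T>::RequireStatics register_;
};

// Deriving from the registrator is all it takes to register MyWidget.
struct MyWidget : Widget, private WidgetRegistrator<MyWidget> {
  static constexpr char kName[] = "MyWidget";
};

int main() {
  // Expected to print 1 once static initialization has run.
  std::cout << WidgetRegistry::Map().count("MyWidget") << "\n";
}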
@@ -37,29 +37,33 @@ Args:
   output: The desired name of the output file. Optional.
 """
 
+load("@bazel_skylib//lib:paths.bzl", "paths")
+
 PROTOC = "@com_google_protobuf//:protoc"
 
-def _canonicalize_proto_path_oss(all_protos, genfile_path):
-    """For the protos from external repository, canonicalize the proto path and the file name.
-
-    Returns:
-      Proto path list and proto source file list.
-    """
-    proto_paths = []
-    proto_file_names = []
-    for s in all_protos.to_list():
-        if s.path.startswith(genfile_path):
-            repo_name, _, file_name = s.path[len(genfile_path + "/external/"):].partition("/")
-
-            # handle virtual imports
-            if file_name.startswith("_virtual_imports"):
-                repo_name = repo_name + "/" + "/".join(file_name.split("/", 2)[:2])
-                file_name = file_name.split("/", 2)[-1]
-            proto_paths.append(genfile_path + "/external/" + repo_name)
-            proto_file_names.append(file_name)
-        else:
-            proto_file_names.append(s.path)
-    return ([" --proto_path=" + path for path in proto_paths], proto_file_names)
+def _canonicalize_proto_path_oss(f):
+    if not f.root.path:
+        return struct(
+            proto_path = ".",
+            file_name = f.short_path,
+        )
+
+    # `f.path` looks like "<genfiles>/external/<repo>/(_virtual_imports/<library>/)?<file_name>"
+    repo_name, _, file_name = f.path[len(paths.join(f.root.path, "external") + "/"):].partition("/")
+    if file_name.startswith("_virtual_imports/"):
+        # This is a virtual import; move "_virtual_imports/<library>" from `repo_name` to `file_name`.
+        repo_name = paths.join(repo_name, *file_name.split("/", 2)[:2])
+        file_name = file_name.split("/", 2)[-1]
+    return struct(
+        proto_path = paths.join(f.root.path, "external", repo_name),
+        file_name = file_name,
+    )
+
+def _map_root_path(f):
+    return _canonicalize_proto_path_oss(f).proto_path
+
+def _map_short_path(f):
+    return _canonicalize_proto_path_oss(f).file_name
 
 def _get_proto_provider(dep):
     """Get the provider for protocol buffers from a dependnecy.
@@ -90,24 +94,35 @@ def _encode_binary_proto_impl(ctx):
         sibling = textpb,
     )
 
-    path_list, file_list = _canonicalize_proto_path_oss(all_protos, ctx.genfiles_dir.path)
+    args = ctx.actions.args()
+    args.add(textpb)
+    args.add(binarypb)
+    args.add(ctx.executable._proto_compiler)
+    args.add(ctx.attr.message_type, format = "--encode=%s")
+    args.add("--proto_path=.")
+    args.add_all(
+        all_protos,
+        map_each = _map_root_path,
+        format_each = "--proto_path=%s",
+        uniquify = True,
+    )
+    args.add_all(
+        all_protos,
+        map_each = _map_short_path,
+        uniquify = True,
+    )
 
     # Note: the combination of absolute_paths and proto_path, as well as the exact
     # order of gendir before ., is needed for the proto compiler to resolve
     # import statements that reference proto files produced by a genrule.
     ctx.actions.run_shell(
-        tools = all_protos.to_list() + [textpb, ctx.executable._proto_compiler],
-        outputs = [binarypb],
-        command = " ".join(
-            [
-                ctx.executable._proto_compiler.path,
-                "--encode=" + ctx.attr.message_type,
-                "--proto_path=" + ctx.genfiles_dir.path,
-                "--proto_path=" + ctx.bin_dir.path,
-                "--proto_path=.",
-            ] + path_list + file_list +
-            ["<", textpb.path, ">", binarypb.path],
+        tools = depset(
+            direct = [textpb, ctx.executable._proto_compiler],
+            transitive = [all_protos],
         ),
+        outputs = [binarypb],
+        command = "${@:3} < $1 > $2",
+        arguments = [args],
         mnemonic = "EncodeProto",
     )
 
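The args.add_all(..., map_each = ...) calls above rely on _canonicalize_proto_path_oss to turn each proto file into a --proto_path root plus a root-relative file name. A plain-Python rendering of that mapping may make the intent clearer; it is a sketch with hypothetical paths, not the Starlark rule itself.

import posixpath

def canonicalize_proto_path(root, path, short_path):
    """Mirrors the Starlark helper above, for illustration only."""
    if not root:
        # Source files keep "." as the proto path and their workspace-relative name.
        return ".", short_path
    # Generated external files look like
    # "<root>/external/<repo>/(_virtual_imports/<library>/)?<file_name>".
    prefix = posixpath.join(root, "external") + "/"
    repo_name, _, file_name = path[len(prefix):].partition("/")
    if file_name.startswith("_virtual_imports/"):
        # Move "_virtual_imports/<library>" from the file name into the root.
        repo_name = posixpath.join(repo_name, *file_name.split("/", 2)[:2])
        file_name = file_name.split("/", 2)[-1]
    return posixpath.join(root, "external", repo_name), file_name

# Example with made-up paths:
print(canonicalize_proto_path(
    "bazel-out/k8-fastbuild/bin",
    "bazel-out/k8-fastbuild/bin/external/some_repo/_virtual_imports/foo_proto/foo.proto",
    "external/some_repo/foo.proto"))
# -> ('bazel-out/k8-fastbuild/bin/external/some_repo/_virtual_imports/foo_proto', 'foo.proto')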
@@ -19,7 +19,7 @@ package mediapipe;
 // Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of
 // the joint and its visibility.
 message Joint {
-  // Joint rotation in 6D contineous representation ordered as
+  // Joint rotation in 6D continuous representation ordered as
   // [a1, b1, a2, b2, a3, b3].
   //
   // Such representation is more sutable for NN model training and can be
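For readers unfamiliar with the 6D rotation encoding mentioned in the comment: it is the continuous representation of Zhou et al. ("On the Continuity of Rotation Representations in Neural Networks"), decoded by Gram-Schmidt orthonormalization of the two encoded 3D vectors. A hedged NumPy sketch following the [a1, b1, a2, b2, a3, b3] ordering from the comment; this is not MediaPipe's own decoder.

import numpy as np

def rot6d_to_matrix(r6):
    """Decode a 6D rotation [a1, b1, a2, b2, a3, b3] into a 3x3 rotation matrix."""
    a = np.asarray(r6[0::2], dtype=float)  # first encoded column: a1, a2, a3
    b = np.asarray(r6[1::2], dtype=float)  # second encoded column: b1, b2, b3
    x = a / np.linalg.norm(a)
    b = b - np.dot(x, b) * x                # remove the component along x
    y = b / np.linalg.norm(b)
    z = np.cross(x, y)                      # third column completes the basis
    return np.stack([x, y, z], axis=-1)

print(rot6d_to_matrix([1, 0, 0, 1, 0, 0]))  # identity rotation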
@@ -15,7 +15,7 @@ def mediapipe_cc_test(
         platforms = ["linux", "android", "ios", "wasm"],
         exclude_platforms = None,
         # ios_unit_test arguments
-        ios_minimum_os_version = "11.0",
+        ios_minimum_os_version = "12.0",
         # android_cc_test arguments
         open_gl_driver = None,
         emulator_mini_boot = True,
@@ -466,8 +466,7 @@ struct MessageRegistrationImpl {
 template <typename T>
 NoDestructor<mediapipe::RegistrationToken>
     MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
-        T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder,
-        __FILE__, __LINE__));
+        T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
 
 // For non-Message payloads, this does nothing.
 template <typename T, typename Enable = void>
@@ -261,8 +261,8 @@ cc_library(
 )
 
 cc_library(
-    name = "opencv_highgui",
-    hdrs = ["opencv_highgui_inc.h"],
+    name = "opencv_photo",
+    hdrs = ["opencv_photo_inc.h"],
     deps = [
         ":opencv_core",
         "//third_party:opencv",
 
@@ -297,6 +297,15 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "opencv_highgui",
+    hdrs = ["opencv_highgui_inc.h"],
+    deps = [
+        ":opencv_core",
+        "//third_party:opencv",
+    ],
+)
+
 cc_library(
     name = "opencv_videoio",
     hdrs = ["opencv_videoio_inc.h"],
@@ -1,4 +1,4 @@
-// Copyright 2019 The MediaPipe Authors.
+// Copyright 2023 The MediaPipe Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
-#define MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
+#ifndef MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
+#define MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
 
 #include <opencv2/core/version.hpp>
 
@@ -25,4 +25,4 @@
 #include <opencv2/highgui.hpp>
 #endif
 
-#endif  // MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_
+#endif  // MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_
 
@@ -1,4 +1,4 @@
-// Copyright 2019 The MediaPipe Authors.
+// Copyright 2022 The MediaPipe Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-// Copyright 2019 The MediaPipe Authors.
+// Copyright 2023 The MediaPipe Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 
@@ -12,15 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <vector>
-
-#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
-#include "mediapipe/framework/formats/detection.pb.h"
-
-namespace mediapipe {
-
-typedef ClipVectorSizeCalculator<::mediapipe::Detection>
-    ClipDetectionVectorSizeCalculator;
-REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
-
-}  // namespace mediapipe
+#ifndef MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
+#define MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
+
+#include "third_party/OpenCV/photo.hpp"
+
+#endif  // MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_
 
@@ -1,4 +1,4 @@
-// Copyright 2019 The MediaPipe Authors.
+// Copyright 2022 The MediaPipe Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -48,6 +48,18 @@ class MuxInputStreamHandler : public InputStreamHandler {
       : InputStreamHandler(std::move(tag_map), cc_manager, options,
                            calculator_run_in_parallel) {}
 
+ private:
+  CollectionItemId GetControlStreamId() const {
+    return input_stream_managers_.EndId() - 1;
+  }
+  void RemoveOutdatedDataPackets(Timestamp timestamp) {
+    const CollectionItemId control_stream_id = GetControlStreamId();
+    for (CollectionItemId id = input_stream_managers_.BeginId();
+         id < control_stream_id; ++id) {
+      input_stream_managers_.Get(id)->ErasePacketsEarlierThan(timestamp);
+    }
+  }
+
  protected:
   // In MuxInputStreamHandler, a node is "ready" if:
   // - the control stream is done (need to call Close() in this case), or
 
@@ -58,9 +70,15 @@ class MuxInputStreamHandler : public InputStreamHandler {
     absl::MutexLock lock(&input_streams_mutex_);
 
     const auto& control_stream =
-        input_stream_managers_.Get(input_stream_managers_.EndId() - 1);
+        input_stream_managers_.Get(GetControlStreamId());
     bool empty;
     *min_stream_timestamp = control_stream->MinTimestampOrBound(&empty);
+
+    // Data streams may contain some outdated packets which failed to be popped
+    // out during "FillInputSet". (This handler doesn't sync input streams,
+    // hence "FillInputSet" can be triggerred before every input stream is
+    // filled with packets corresponding to the same timestamp.)
+    RemoveOutdatedDataPackets(*min_stream_timestamp);
     if (empty) {
       if (*min_stream_timestamp == Timestamp::Done()) {
         // Calculator is done if the control input stream is done.
 
@@ -78,11 +96,6 @@ class MuxInputStreamHandler : public InputStreamHandler {
     const auto& data_stream = input_stream_managers_.Get(
         input_stream_managers_.BeginId() + control_value);
 
-    // Data stream may contain some outdated packets which failed to be popped
-    // out during "FillInputSet". (This handler doesn't sync input streams,
-    // hence "FillInputSet" can be triggerred before every input stream is
-    // filled with packets corresponding to the same timestamp.)
-    data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
     Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
     if (empty) {
       if (stream_timestamp <= *min_stream_timestamp) {
 
@@ -111,8 +124,7 @@ class MuxInputStreamHandler : public InputStreamHandler {
     CHECK(input_set);
     absl::MutexLock lock(&input_streams_mutex_);
 
-    const CollectionItemId control_stream_id =
-        input_stream_managers_.EndId() - 1;
+    const CollectionItemId control_stream_id = GetControlStreamId();
     auto& control_stream = input_stream_managers_.Get(control_stream_id);
     int num_packets_dropped = 0;
     bool stream_is_done = false;
 
@@ -140,15 +152,8 @@ class MuxInputStreamHandler : public InputStreamHandler {
     AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet),
                      stream_is_done);
 
-    // Discard old packets on other streams.
-    // Note that control_stream_id is the last valid id.
-    auto next_timestamp = input_timestamp.NextAllowedInStream();
-    for (CollectionItemId id = input_stream_managers_.BeginId();
-         id < control_stream_id; ++id) {
-      if (id == data_stream_id) continue;
-      auto& other_stream = input_stream_managers_.Get(id);
-      other_stream->ErasePacketsEarlierThan(next_timestamp);
-    }
+    // Discard old packets on data streams.
+    RemoveOutdatedDataPackets(input_timestamp.NextAllowedInStream());
   }
 
  private:
@@ -645,5 +645,41 @@ TEST(MuxInputStreamHandlerTest,
   MP_ASSERT_OK(graph.WaitUntilDone());
 }
 
+TEST(MuxInputStreamHandlerTest, RemovesUnusedDataStreamPackets) {
+  CalculatorGraphConfig config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: "input0"
+        input_stream: "input1"
+        input_stream: "select"
+        node {
+          calculator: "MuxCalculator"
+          input_stream: "INPUT:0:input0"
+          input_stream: "INPUT:1:input1"
+          input_stream: "SELECT:select"
+          output_stream: "OUTPUT:output"
+          input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
+        }
+      )pb");
+  config.set_max_queue_size(1);
+  config.set_report_deadlock(true);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(config));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "select", MakePacket<int>(0).At(Timestamp(2))));
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input0", MakePacket<int>(1000).At(Timestamp(2))));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Add two delayed packets to the deselected input. They should be discarded
+  // instead of triggering the deadlock detection (max_queue_size = 1).
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input1", MakePacket<int>(900).At(Timestamp(1))));
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input1", MakePacket<int>(900).At(Timestamp(2))));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+}
+
 }  // namespace
 }  // namespace mediapipe
@@ -64,13 +64,13 @@ GraphRegistry::GraphRegistry(
 void GraphRegistry::Register(
     const std::string& type_name,
     std::function<std::unique_ptr<Subgraph>()> factory) {
-  local_factories_.Register(type_name, factory, __FILE__, __LINE__);
+  local_factories_.Register(type_name, factory);
 }
 
 // TODO: Remove this convenience function.
 void GraphRegistry::Register(const std::string& type_name,
                              const CalculatorGraphConfig& config) {
-  Register(type_name, [config] {
+  local_factories_.Register(type_name, [config] {
     auto result = absl::make_unique<ProtoSubgraph>(config);
     return std::unique_ptr<Subgraph>(result.release());
   });
 
@@ -79,7 +79,7 @@ void GraphRegistry::Register(const std::string& type_name,
 // TODO: Remove this convenience function.
 void GraphRegistry::Register(const std::string& type_name,
                              const CalculatorGraphTemplate& templ) {
-  Register(type_name, [templ] {
+  local_factories_.Register(type_name, [templ] {
     auto result = absl::make_unique<TemplateSubgraph>(templ);
     return std::unique_ptr<Subgraph>(result.release());
   });
@@ -228,7 +228,9 @@ absl::Status CompareAndSaveImageOutput(
   auto status = CompareImageFrames(**expected, actual, options.max_color_diff,
                                    options.max_alpha_diff, options.max_avg_diff,
                                    diff_img);
-  ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
+  if (diff_img) {
+    ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff"));
+  }
 
   return status;
 }
@@ -1121,7 +1121,7 @@ objc_library(
     alwayslink = 1,
 )
 
-MIN_IOS_VERSION = "11.0"
+MIN_IOS_VERSION = "12.0"
 
 test_suite(
     name = "ios",
 
@@ -109,9 +109,8 @@ absl::Status GlContext::CreateContext(
   }
   MP_RETURN_IF_ERROR(status);
 
-  LOG(INFO) << "Successfully created a WebGL context with major version "
-            << gl_major_version_ << " and handle " << context_;
+  VLOG(1) << "Successfully created a WebGL context with major version "
+          << gl_major_version_ << " and handle " << context_;
 
   return absl::OkStatus();
 }
 
@@ -104,6 +104,7 @@ class GlScalerCalculator : public CalculatorBase {
   bool vertical_flip_output_;
   bool horizontal_flip_output_;
   FrameScaleMode scale_mode_ = FrameScaleMode::kStretch;
+  bool use_nearest_neighbor_interpolation_ = false;
 };
 REGISTER_CALCULATOR(GlScalerCalculator);
 
@@ -186,7 +187,8 @@ absl::Status GlScalerCalculator::Open(CalculatorContext* cc) {
     scale_mode_ =
         FrameScaleModeFromProto(options.scale_mode(), FrameScaleMode::kStretch);
   }
+  use_nearest_neighbor_interpolation_ =
+      options.use_nearest_neighbor_interpolation();
   if (HasTagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) {
     const auto& dimensions =
         TagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)
 
@@ -297,6 +299,11 @@ absl::Status GlScalerCalculator::Process(CalculatorContext* cc) {
     glBindTexture(src2.target(), src2.name());
   }
 
+  if (use_nearest_neighbor_interpolation_) {
+    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+  }
+
   MP_RETURN_IF_ERROR(renderer->GlRender(
       src1.width(), src1.height(), dst.width(), dst.height(), scale_mode_,
       rotation_, horizontal_flip_output_, vertical_flip_output_,
@@ -19,7 +19,7 @@ package mediapipe;
 import "mediapipe/framework/calculator.proto";
 import "mediapipe/gpu/scale_mode.proto";
 
-// Next id: 8.
+// Next id: 9.
 message GlScalerCalculatorOptions {
   extend CalculatorOptions {
     optional GlScalerCalculatorOptions ext = 166373014;
 
@@ -39,4 +39,7 @@ message GlScalerCalculatorOptions {
   // Flip the output texture horizontally. This is applied after rotation.
   optional bool flip_horizontal = 5;
   optional ScaleMode.Mode scale_mode = 6;
+  // Whether to use nearest neighbor interpolation. Default to use linear
+  // interpolation.
+  optional bool use_nearest_neighbor_interpolation = 8 [default = false];
 }
@@ -100,6 +100,10 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
           {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
 #endif  // TARGET_OS_OSX
       }},
+      {GpuBufferFormat::kOneComponent8Alpha,
+       {
+           {GL_ALPHA, GL_ALPHA, GL_UNSIGNED_BYTE, 1},
+       }},
       {GpuBufferFormat::kOneComponent8Red,
       {
           {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
 
@@ -221,6 +225,7 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
     case GpuBufferFormat::kRGBA32:
       // TODO: this likely maps to ImageFormat::SRGBA
     case GpuBufferFormat::kGrayHalf16:
+    case GpuBufferFormat::kOneComponent8Alpha:
     case GpuBufferFormat::kOneComponent8Red:
     case GpuBufferFormat::kTwoComponent8:
     case GpuBufferFormat::kTwoComponentHalf16:
 
@@ -43,6 +43,7 @@ enum class GpuBufferFormat : uint32_t {
   kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'),
   kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'),
   kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'),
+  kOneComponent8Alpha = MEDIAPIPE_FOURCC('A', '0', '0', '8'),
   kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'),
   kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'),
   kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'),
 
@@ -101,6 +102,7 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
       return kCVPixelFormatType_OneComponent32Float;
     case GpuBufferFormat::kOneComponent8:
       return kCVPixelFormatType_OneComponent8;
+    case GpuBufferFormat::kOneComponent8Alpha:
     case GpuBufferFormat::kOneComponent8Red:
       return -1;
     case GpuBufferFormat::kTwoComponent8:
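The new kOneComponent8Alpha entry follows the existing pattern of packing a four-character code into the enum value. The exact MEDIAPIPE_FOURCC definition is not part of this diff; a common formulation, assumed here only for illustration (byte order may differ), looks like this:

#include <cstdint>
#include <iostream>

// A typical four-character-code packing; the real MEDIAPIPE_FOURCC macro is
// defined elsewhere in the tree and may order the bytes differently.
constexpr uint32_t Fourcc(char a, char b, char c, char d) {
  return (static_cast<uint32_t>(a) << 24) | (static_cast<uint32_t>(b) << 16) |
         (static_cast<uint32_t>(c) << 8) | static_cast<uint32_t>(d);
}

int main() {
  // The new kOneComponent8Alpha entry would pack 'A','0','0','8' this way.
  constexpr uint32_t kOneComponent8Alpha = Fourcc('A', '0', '0', '8');
  std::cout << std::hex << kOneComponent8Alpha << "\n";  // 41303038
}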
@@ -78,17 +78,21 @@ public class AppTextureFrame implements TextureFrame {
    * Use {@link waitUntilReleasedWithGpuSync} whenever possible.
    */
   public void waitUntilReleased() throws InterruptedException {
+    GlSyncToken tokenToRelease = null;
     synchronized (this) {
       while (inUse && releaseSyncToken == null) {
         wait();
       }
       if (releaseSyncToken != null) {
-        releaseSyncToken.waitOnCpu();
-        releaseSyncToken.release();
+        tokenToRelease = releaseSyncToken;
         inUse = false;
         releaseSyncToken = null;
       }
     }
+    if (tokenToRelease != null) {
+      tokenToRelease.waitOnCpu();
+      tokenToRelease.release();
+    }
   }
 
   /**
 
@@ -98,17 +102,21 @@ public class AppTextureFrame implements TextureFrame {
    * TextureFrame.
    */
   public void waitUntilReleasedWithGpuSync() throws InterruptedException {
+    GlSyncToken tokenToRelease = null;
     synchronized (this) {
       while (inUse && releaseSyncToken == null) {
         wait();
       }
       if (releaseSyncToken != null) {
-        releaseSyncToken.waitOnGpu();
-        releaseSyncToken.release();
+        tokenToRelease = releaseSyncToken;
         inUse = false;
         releaseSyncToken = null;
       }
     }
+    if (tokenToRelease != null) {
+      tokenToRelease.waitOnGpu();
+      tokenToRelease.release();
+    }
   }
 
   /**
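The Java change above applies a standard idiom: capture the sync token while holding the object monitor, then perform the potentially slow waitOnCpu()/waitOnGpu() and release() after leaving the synchronized block, so other threads that need the monitor are not blocked during the wait. A stripped-down illustration with a hypothetical Token type (not the actual GlSyncToken API):

public final class LockScopeExample {
  /** Hypothetical stand-in for a GPU sync token. */
  interface Token {
    void await() throws InterruptedException;
    void release();
  }

  private Token pendingToken;   // guarded by "this"
  private boolean inUse = true; // guarded by "this"

  public void waitUntilReleased() throws InterruptedException {
    Token tokenToRelease = null;
    synchronized (this) {
      while (inUse && pendingToken == null) {
        wait();
      }
      if (pendingToken != null) {
        // Only the hand-off happens under the lock; no blocking work here.
        tokenToRelease = pendingToken;
        pendingToken = null;
        inUse = false;
      }
    }
    if (tokenToRelease != null) {
      // Slow operations run without holding the monitor.
      tokenToRelease.await();
      tokenToRelease.release();
    }
  }
}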
@@ -239,7 +239,7 @@ public final class PacketGetter {
 
   /**
    * Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer
-   * array has the the same size of image list packet, and assumes the output buffer stores pixels
+   * array has the same size of image list packet, and assumes the output buffer stores pixels
    * contiguously. It returns false if this assumption does not hold.
    *
    * <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of
@@ -24,6 +24,7 @@ package_group(
 package_group(
     name = "1p_client",
     packages = [
+        "//cloud/ml/applications/vision/model_garden/model_oss/mediapipe/...",
        "//research/privacy/learning/fl_eval/pcvr/...",
     ],
 )
 
@@ -57,3 +57,14 @@ py_test(
     srcs = ["classification_dataset_test.py"],
     deps = [":classification_dataset"],
 )
+
+py_library(
+    name = "cache_files",
+    srcs = ["cache_files.py"],
+)
+
+py_test(
+    name = "cache_files_test",
+    srcs = ["cache_files_test.py"],
+    deps = [":cache_files"],
+)
mediapipe/model_maker/python/core/data/cache_files.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common TFRecord cache files library."""

import dataclasses
import os
import tempfile
from typing import Any, Mapping, Sequence

import tensorflow as tf
import yaml


# Suffix of the meta data file name.
METADATA_FILE_SUFFIX = '_metadata.yaml'


@dataclasses.dataclass(frozen=True)
class TFRecordCacheFiles:
  """TFRecordCacheFiles dataclass to store and load cached TFRecord files.

  Attributes:
    cache_prefix_filename: The cache prefix filename. This is usually provided
      as a hash of the original data source to avoid different data sources
      resulting in the same cache file.
    cache_dir: The cache directory to save TFRecord and metadata file. When
      cache_dir is None, a temporary folder will be created and will not be
      removed automatically after training which makes it can be used later.
    num_shards: Number of shards for output tfrecord files.
  """

  cache_prefix_filename: str = 'cache_prefix'
  cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp)
  num_shards: int = 1

  def __post_init__(self):
    if not tf.io.gfile.exists(self.cache_dir):
      tf.io.gfile.makedirs(self.cache_dir)
    if not self.cache_prefix_filename:
      raise ValueError('cache_prefix_filename cannot be empty.')
    if self.num_shards <= 0:
      raise ValueError(
          f'num_shards must be greater than 0, got {self.num_shards}'
      )

  @property
  def cache_prefix(self) -> str:
    """The cache prefix including the cache directory and the cache prefix filename."""
    return os.path.join(self.cache_dir, self.cache_prefix_filename)

  @property
  def tfrecord_files(self) -> Sequence[str]:
    """The TFRecord files."""
    tfrecord_files = [
        self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards)
        for i in range(self.num_shards)
    ]
    return tfrecord_files

  @property
  def metadata_file(self) -> str:
    """The metadata file."""
    return self.cache_prefix + METADATA_FILE_SUFFIX

  def get_writers(self) -> Sequence[tf.io.TFRecordWriter]:
    """Gets an array of TFRecordWriter objects.

    Note that these writers should each be closed using .close() when done.

    Returns:
      Array of TFRecordWriter objects
    """
    return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]

  def save_metadata(self, metadata):
    """Writes metadata to file.

    Args:
      metadata: A dictionary of metadata content to write. Exact format is
        dependent on the specific dataset, but typically includes a 'size' and
        'label_names' entry.
    """
    with tf.io.gfile.GFile(self.metadata_file, 'w') as f:
      yaml.dump(metadata, f)

  def load_metadata(self) -> Mapping[Any, Any]:
    """Reads metadata from file.

    Returns:
      Dictionary object containing metadata
    """
    if not tf.io.gfile.exists(self.metadata_file):
      return {}
    with tf.io.gfile.GFile(self.metadata_file, 'r') as f:
      metadata = yaml.load(f, Loader=yaml.FullLoader)
    return metadata

  def is_cached(self) -> bool:
    """Checks whether this CacheFiles is already cached."""
    all_cached_files = list(self.tfrecord_files) + [self.metadata_file]
    return all(tf.io.gfile.exists(f) for f in all_cached_files)
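A small usage sketch for the new TFRecordCacheFiles helper, using only the methods defined above; the paths and record contents are made up.

import tensorflow as tf

from mediapipe.model_maker.python.core.data import cache_files

# Cache a tiny dataset and read its metadata back (illustrative only).
cf = cache_files.TFRecordCacheFiles(
    cache_prefix_filename='my_dataset_hash',
    cache_dir='/tmp/mm_cache',
    num_shards=2,
)

if not cf.is_cached():
  writers = cf.get_writers()
  for i, example_bytes in enumerate([b'example-0', b'example-1', b'example-2']):
    writers[i % len(writers)].write(example_bytes)
  for writer in writers:
    writer.close()
  cf.save_metadata({'size': 3, 'label_names': ['cat', 'dog']})

dataset = tf.data.TFRecordDataset(cf.tfrecord_files)
metadata = cf.load_metadata()
print(metadata['size'], metadata['label_names'])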
mediapipe/model_maker/python/core/data/cache_files_test.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from mediapipe.model_maker.python.core.data import cache_files


class CacheFilesTest(tf.test.TestCase):

  def test_tfrecord_cache_files(self):
    cf = cache_files.TFRecordCacheFiles(
        cache_prefix_filename='tfrecord',
        cache_dir='/tmp/cache_dir',
        num_shards=2,
    )
    self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord')
    self.assertEqual(
        cf.metadata_file,
        '/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX,
    )
    expected_tfrecord_files = [
        '/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2)
        for i in range(2)
    ]
    self.assertEqual(cf.tfrecord_files, expected_tfrecord_files)

    # Writing TFRecord Files
    self.assertFalse(cf.is_cached())
    for tfrecord_file in cf.tfrecord_files:
      self.assertFalse(tf.io.gfile.exists(tfrecord_file))
    writers = cf.get_writers()
    for writer in writers:
      writer.close()
    for tfrecord_file in cf.tfrecord_files:
      self.assertTrue(tf.io.gfile.exists(tfrecord_file))
    self.assertFalse(cf.is_cached())

    # Writing Metadata Files
    original_metadata = {'size': 10, 'label_names': ['label1', 'label2']}
    cf.save_metadata(original_metadata)
    self.assertTrue(cf.is_cached())
    metadata = cf.load_metadata()
    self.assertEqual(metadata, original_metadata)

  def test_recordio_cache_files_error(self):
    with self.assertRaisesRegex(
        ValueError, 'cache_prefix_filename cannot be empty'
    ):
      cache_files.TFRecordCacheFiles(
          cache_prefix_filename='',
          cache_dir='/tmp/cache_dir',
          num_shards=2,
      )
    with self.assertRaisesRegex(
        ValueError, 'num_shards must be greater than 0, got 0'
    ):
      cache_files.TFRecordCacheFiles(
          cache_prefix_filename='tfrecord',
          cache_dir='/tmp/cache_dir',
          num_shards=0,
      )


if __name__ == '__main__':
  tf.test.main()
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Common classification dataset library."""
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import tensorflow as tf
 
@@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds
 class ClassificationDataset(ds.Dataset):
   """Dataset Loader for classification models."""
 
-  def __init__(self, dataset: tf.data.Dataset, size: int,
-               label_names: List[str]):
+  def __init__(
+      self,
+      dataset: tf.data.Dataset,
+      label_names: List[str],
+      size: Optional[int] = None,
+  ):
     super().__init__(dataset, size)
     self._label_names = label_names
 
@@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase):
       value: A value variable stored by the mock dataset class for testing.
     """
 
-    def __init__(self, dataset: tf.data.Dataset, size: int,
-                 label_names: List[str], value: Any):
-      super().__init__(dataset=dataset, size=size, label_names=label_names)
+    def __init__(
+        self,
+        dataset: tf.data.Dataset,
+        label_names: List[str],
+        value: Any,
+        size: int,
+    ):
+      super().__init__(dataset=dataset, label_names=label_names, size=size)
       self.value = value
 
     def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
 
@@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase):
     # Create data loader from sample data.
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
     data = MagicClassificationDataset(
-        dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
+        dataset=ds, label_names=label_names, value=magic_value, size=len(ds)
+    )
 
     # Train/Test data split.
     fraction = .25
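With the reordered constructor, size becomes an optional keyword argument that follows label_names. A minimal, illustrative call site (the data is made up):

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Wrap an in-memory tf.data.Dataset with the updated signature, where `size`
# is now an optional keyword after `label_names`.
raw = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = classification_dataset.ClassificationDataset(
    dataset=raw, label_names=['neg', 'pos'], size=len(raw))
print(len(data))  # 4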
@@ -56,15 +56,14 @@ class Dataset(object):
   def size(self) -> Optional[int]:
     """Returns the size of the dataset.

-    Note that this function may return None becuase the exact size of the
-    dataset isn't a necessary parameter to create an instance of this class,
-    and tf.data.Dataset donesn't support a function to get the length directly
-    since it's lazy-loaded and may be infinite.
-    In most cases, however, when an instance of this class is created by helper
-    functions like 'from_folder', the size of the dataset will be preprocessed,
-    and this function can return an int representing the size of the dataset.
+    Same functionality as calling __len__. See the __len__ method definition for
+    more information.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
     """
-    return self._size
+    return self.__len__()

   def gen_tf_dataset(
       self,

@@ -116,8 +115,22 @@ class Dataset(object):
     # here.
     return dataset

-  def __len__(self):
-    """Returns the number of element of the dataset."""
+  def __len__(self) -> int:
+    """Returns the number of element of the dataset.
+
+    If size is not set, this method will fallback to using the __len__ method
+    of the tf.data.Dataset in self._dataset. Calling __len__ on a
+    tf.data.Dataset instance may throw a TypeError because the dataset may
+    be lazy-loaded with an unknown size or have infinite size.
+
+    In most cases, however, when an instance of this class is created by helper
+    functions like 'from_folder', the size of the dataset will be preprocessed,
+    and the _size instance variable will be already set.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
+    """
     if self._size is not None:
       return self._size
     else:

@@ -152,15 +165,25 @@ class Dataset(object):

     Returns:
       The splitted two sub datasets.
+
+    Raises:
+      ValueError: if the provided fraction is not between 0 and 1.
+      ValueError: if this dataset does not have a set size.
     """
-    assert (fraction > 0 and fraction < 1)
+    if not (fraction > 0 and fraction < 1):
+      raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}')
+    if not self._size:
+      raise ValueError(
+          'Dataset size unknown. Cannot split the dataset when '
+          'the size is unknown.'
+      )

     dataset = self._dataset

     train_size = int(self._size * fraction)
-    trainset = self.__class__(dataset.take(train_size), train_size, *args)
+    trainset = self.__class__(dataset.take(train_size), *args, size=train_size)

     test_size = self._size - train_size
-    testset = self.__class__(dataset.skip(train_size), test_size, *args)
+    testset = self.__class__(dataset.skip(train_size), *args, size=test_size)

     return trainset, testset
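For context on the __len__ fallback above: this is roughly how len() behaves on a tf.data.Dataset, which is why the change keeps a preset _size for lazily-loaded or infinite datasets. A minimal illustrative sketch, not part of the change; values are made up:

import tensorflow as tf

# Finite dataset: len() works because the cardinality is known.
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
print(len(ds))  # 4

# Infinite dataset: len() raises TypeError, so split() needs an explicit size.
repeated = ds.repeat()
try:
  len(repeated)
except TypeError as e:
  print("cannot take length:", e)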
@@ -15,7 +15,7 @@

 import dataclasses
 import tempfile
-from typing import Optional
+from typing import Mapping, Optional

 import tensorflow as tf

@@ -36,6 +36,8 @@ class BaseHParams:
     steps_per_epoch: An optional integer indicate the number of training steps
       per epoch. If not set, the training pipeline calculates the default steps
       per epoch as the training dataset size divided by batch size.
+    class_weights: An optional mapping of indices to weights for weighting the
+      loss function during training.
     shuffle: True if the dataset is shuffled before training.
     export_dir: The location of the model checkpoint files.
     distribution_strategy: A string specifying which Distribution Strategy to

@@ -57,6 +59,7 @@ class BaseHParams:
   batch_size: int
   epochs: int
   steps_per_epoch: Optional[int] = None
+  class_weights: Optional[Mapping[int, float]] = None

   # Dataset-related parameters
   shuffle: bool = False
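One common way to fill the new class_weights field is to weight each class inversely to its frequency. A small self-contained sketch; the helper name and label values are illustrative only:

import collections

def balanced_class_weights(labels):
  # Weight each class inversely proportional to how often it appears.
  counts = collections.Counter(labels)
  total = sum(counts.values())
  return {cls: total / (len(counts) * n) for cls, n in counts.items()}

weights = balanced_class_weights([0, 0, 0, 1])
# -> {0: 0.666..., 1: 2.0}; could then be passed via class_weights=weights.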
@@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel):
         # dataset is exhausted even if there are epochs remaining.
         steps_per_epoch=None,
         validation_data=validation_dataset,
-        callbacks=self._callbacks)
+        callbacks=self._callbacks,
+        class_weight=self._hparams.class_weights,
+    )

   def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
     """Evaluates the classifier with the provided evaluation dataset.
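The new class_weight argument relies on standard Keras behavior: per-class multipliers applied to the loss during fit. A minimal standalone sketch with made-up data and weights:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation="softmax")])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

x = tf.random.normal((8, 4))
y = tf.constant([0, 0, 0, 0, 0, 0, 1, 1])
# Keys are class ids, values are loss multipliers for examples of that class.
model.fit(x, y, epochs=1, class_weight={0: 0.5, 1: 2.0})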
@@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
   """

   def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
-    """Constructor.
+    """Initializes FocalLoss.

     Args:
       gamma: Focal loss gamma, as described in class docs.

@@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
     return tf.reduce_sum(losses) / batch_size


+class SparseFocalLoss(FocalLoss):
+  """Sparse implementation of Focal Loss.
+
+  This is the same as FocalLoss, except the labels are expected to be class ids
+  instead of 1-hot encoded vectors. See FocalLoss class documentation defined
+  in this same file for more details.
+
+  Example usage:
+  >>> y_true = [1, 2]
+  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
+  >>> gamma = 2
+  >>> focal_loss = SparseFocalLoss(gamma, 3)
+  >>> focal_loss(y_true, y_pred).numpy()
+  0.9326
+
+  >>> # Calling with 'sample_weight'.
+  >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
+  0.6528
+  """
+
+  def __init__(
+      self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
+  ):
+    """Initializes SparseFocalLoss.
+
+    Args:
+      gamma: Focal loss gamma, as described in class docs.
+      num_classes: Number of classes.
+      class_weight: A weight to apply to the loss, one for each class. The
+        weight is applied for each input where the ground truth label matches.
+    """
+    super().__init__(gamma, class_weight=class_weight)
+    self._num_classes = num_classes
+
+  def __call__(
+      self,
+      y_true: tf.Tensor,
+      y_pred: tf.Tensor,
+      sample_weight: Optional[tf.Tensor] = None,
+  ) -> tf.Tensor:
+    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
+    y_true_one_hot = tf.one_hot(y_true, self._num_classes)
+    return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
+
+
 @dataclasses.dataclass
 class PerceptualLossWeight:
   """The weight for each perceptual loss.
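The sparse variant above just converts integer labels to one-hot before the dense focal loss. For readers unfamiliar with focal loss, here is a standalone sketch of the underlying idea (cross entropy modulated by (1 - p)^gamma); the function name, normalization, and the absence of per-class weights are simplifications, not the library implementation:

import tensorflow as tf

def sparse_focal_loss(y_true_ids, y_pred, gamma=2.0, num_classes=3):
  # One-hot the integer labels, then down-weight easy (high-probability) examples.
  y_true = tf.one_hot(tf.cast(y_true_ids, tf.int32), num_classes)
  y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
  ce = -y_true * tf.math.log(y_pred)            # cross entropy per class
  modulator = tf.pow(1.0 - y_pred, gamma)       # focal modulation term
  return tf.reduce_mean(tf.reduce_sum(modulator * ce, axis=-1))

loss = sparse_focal_loss(
    tf.constant([1, 2]),
    tf.constant([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]))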
@@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNear(loss, expected_loss, 1e-4)


+class SparseFocalLossTest(tf.test.TestCase):
+
+  def test_sparse_focal_loss_matches_focal_loss(self):
+    num_classes = 2
+    y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
+    y_true = tf.constant([1, 0])
+    y_true_one_hot = tf.one_hot(y_true, num_classes)
+    for gamma in [0.0, 0.5, 1.0]:
+      expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
+      loss_fn = loss_functions.SparseFocalLoss(
+          gamma=gamma, num_classes=num_classes
+      )
+      expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
+      loss = loss_fn(y_true, y_pred)
+      self.assertNear(loss, expected_loss, 1e-4)
+
+
 class MockPerceptualLoss(loss_functions.PerceptualLoss):
   """A mock class with implementation of abstract methods for testing."""
@@ -46,13 +46,17 @@ class BertModelSpec:
   """

   downloaded_files: file_util.DownloadedFiles
-  hparams: hp.BaseHParams = hp.BaseHParams(
-      epochs=3,
-      batch_size=32,
-      learning_rate=3e-5,
-      distribution_strategy='mirrored')
-  model_options: bert_model_options.BertModelOptions = (
-      bert_model_options.BertModelOptions())
+  hparams: hp.BaseHParams = dataclasses.field(
+      default_factory=lambda: hp.BaseHParams(
+          epochs=3,
+          batch_size=32,
+          learning_rate=3e-5,
+          distribution_strategy='mirrored',
+      )
+  )
+  model_options: bert_model_options.BertModelOptions = dataclasses.field(
+      default_factory=bert_model_options.BertModelOptions
+  )
   do_lower_case: bool = True
   tflite_input_name: Dict[str, str] = dataclasses.field(
       default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
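The switch to dataclasses.field(default_factory=...) avoids sharing one mutable default object across every spec instance. A minimal illustration of the pitfall being avoided; the Config class and field names are made up:

import dataclasses

@dataclasses.dataclass
class Config:
  # A bare mutable default such as `options: dict = {}` raises
  # "ValueError: mutable default ... is not allowed"; default_factory
  # builds a fresh object for each instance instead of sharing one.
  options: dict = dataclasses.field(default_factory=dict)

a, b = Config(), Config()
a.options["seq_len"] = 128
print(b.options)  # {} -- unaffected, each instance gets its own dict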
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# Placeholder for internal Python strict library and test compatibility macro.
|
# Placeholder for internal Python strict binary and library compatibility macro.
|
||||||
# Placeholder for internal Python strict test compatibility macro.
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||||
|
@ -76,7 +76,10 @@ py_test(
|
||||||
py_library(
|
py_library(
|
||||||
name = "dataset",
|
name = "dataset",
|
||||||
srcs = ["dataset.py"],
|
srcs = ["dataset.py"],
|
||||||
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -88,7 +91,10 @@ py_test(
|
||||||
py_library(
|
py_library(
|
||||||
name = "preprocessor",
|
name = "preprocessor",
|
||||||
srcs = ["preprocessor.py"],
|
srcs = ["preprocessor.py"],
|
||||||
deps = [":dataset"],
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -99,6 +105,7 @@ py_test(
|
||||||
":dataset",
|
":dataset",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":preprocessor",
|
":preprocessor",
|
||||||
|
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -124,6 +131,7 @@ py_library(
|
||||||
":text_classifier_options",
|
":text_classifier_options",
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||||
"//mediapipe/model_maker/python/core/utils:metrics",
|
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
@ -147,6 +155,7 @@ py_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":text_classifier_import",
|
":text_classifier_import",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,11 +15,15 @@
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
import tempfile
|
||||||
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
from typing import Optional, Sequence
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
||||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,21 +50,49 @@ class CSVParameters:
|
||||||
class Dataset(classification_dataset.ClassificationDataset):
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
"""Dataset library for text classifier."""
|
"""Dataset library for text classifier."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: tf.data.Dataset,
|
||||||
|
label_names: List[str],
|
||||||
|
tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
|
||||||
|
size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__(dataset, label_names, size)
|
||||||
|
if not tfrecord_cache_files:
|
||||||
|
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename="tfrecord", num_shards=1
|
||||||
|
)
|
||||||
|
self.tfrecord_cache_files = tfrecord_cache_files
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_csv(cls,
|
def from_csv(
|
||||||
filename: str,
|
cls,
|
||||||
csv_params: CSVParameters,
|
filename: str,
|
||||||
shuffle: bool = True) -> "Dataset":
|
csv_params: CSVParameters,
|
||||||
|
shuffle: bool = True,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
num_shards: int = 1,
|
||||||
|
) -> "Dataset":
|
||||||
"""Loads text with labels from a CSV file.
|
"""Loads text with labels from a CSV file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: Name of the CSV file.
|
filename: Name of the CSV file.
|
||||||
csv_params: Parameters used for reading the CSV file.
|
csv_params: Parameters used for reading the CSV file.
|
||||||
shuffle: If True, randomly shuffle the data.
|
shuffle: If True, randomly shuffle the data.
|
||||||
|
cache_dir: Optional parameter to specify where to store the preprocessed
|
||||||
|
dataset. Only used for BERT models.
|
||||||
|
num_shards: Optional parameter for num shards of the preprocessed dataset.
|
||||||
|
Note that using more than 1 shard will reorder the dataset. Only used
|
||||||
|
for BERT models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset containing (text, label) pairs and other related info.
|
Dataset containing (text, label) pairs and other related info.
|
||||||
"""
|
"""
|
||||||
|
if cache_dir is None:
|
||||||
|
cache_dir = tempfile.mkdtemp()
|
||||||
|
# calculate hash for cache based off of files
|
||||||
|
hasher = hashlib.md5()
|
||||||
|
hasher.update(os.path.basename(filename).encode("utf-8"))
|
||||||
with tf.io.gfile.GFile(filename, "r") as f:
|
with tf.io.gfile.GFile(filename, "r") as f:
|
||||||
reader = csv.DictReader(
|
reader = csv.DictReader(
|
||||||
f,
|
f,
|
||||||
|
@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
quotechar=csv_params.quotechar)
|
quotechar=csv_params.quotechar)
|
||||||
|
|
||||||
lines = list(reader)
|
lines = list(reader)
|
||||||
|
for line in lines:
|
||||||
|
hasher.update(str(line).encode("utf-8"))
|
||||||
|
|
||||||
if shuffle:
|
if shuffle:
|
||||||
random.shuffle(lines)
|
random.shuffle(lines)
|
||||||
|
|
||||||
|
@ -81,8 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
index_by_label[line[csv_params.label_column]] for line in lines
|
index_by_label[line[csv_params.label_column]] for line in lines
|
||||||
]
|
]
|
||||||
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
tf.cast(label_indices, tf.int64))
|
tf.cast(label_indices, tf.int64)
|
||||||
|
)
|
||||||
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
||||||
|
|
||||||
|
hasher.update(str(num_shards).encode("utf-8"))
|
||||||
|
cache_prefix_filename = hasher.hexdigest()
|
||||||
|
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename, cache_dir, num_shards
|
||||||
|
)
|
||||||
return Dataset(
|
return Dataset(
|
||||||
dataset=text_label_ds, size=len(texts), label_names=label_names)
|
dataset=text_label_ds,
|
||||||
|
label_names=label_names,
|
||||||
|
tfrecord_cache_files=tfrecord_cache_files,
|
||||||
|
size=len(texts),
|
||||||
|
)
|
||||||
|
|
|
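The cache key built in from_csv is just an MD5 digest over the file name, every CSV row, and the shard count, so identical inputs map to the same cache prefix. A self-contained sketch of that idea; the helper name and sample row are illustrative:

import hashlib
import os
from typing import Dict, List

def csv_cache_key(filename: str, rows: List[Dict[str, str]], num_shards: int) -> str:
  # Same file name + same row contents + same shard count -> same cache prefix.
  hasher = hashlib.md5()
  hasher.update(os.path.basename(filename).encode("utf-8"))
  for row in rows:
    hasher.update(str(row).encode("utf-8"))
  hasher.update(str(num_shards).encode("utf-8"))
  return hasher.hexdigest()

print(csv_cache_key("train.csv", [{"text": "hi", "label": "pos"}], 1))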
@@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):

   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
-    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
+    data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
     train_data, test_data = data.split(0.5)
     expected_train_data = [b'good', b'bad']
     expected_test_data = [b'neutral', b'odd']
@@ -15,7 +15,7 @@

 import dataclasses
 import enum
-from typing import Union
+from typing import Sequence, Union

 from mediapipe.model_maker.python.core import hyperparameters as hp

@@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):

   Attributes:
     learning_rate: Learning rate to use for gradient descent training.
-    batch_size: Batch size for training.
-    epochs: Number of training iterations over the dataset.
-    optimizer: Optimizer to use for training. Only supported values are "adamw"
-      and "lamb".
+    end_learning_rate: End learning rate for linear decay. Defaults to 0.
+    batch_size: Batch size for training. Defaults to 48.
+    epochs: Number of training iterations over the dataset. Defaults to 2.
+    optimizer: Optimizer to use for training. Supported values are defined in
+      BertOptimizer enum: ADAMW and LAMB.
+    weight_decay: Weight decay of the optimizer. Defaults to 0.01.
+    desired_precisions: If specified, adds a RecallAtPrecision metric per
+      desired_precisions[i] entry which tracks the recall given the constraint
+      on precision. Only supported for binary classification.
+    desired_recalls: If specified, adds a PrecisionAtRecall metric per
+      desired_recalls[i] entry which tracks the precision given the constraint
+      on recall. Only supported for binary classification.
+    gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
+      value to 0. Defaults to 2.0.
   """

   learning_rate: float = 3e-5
+  end_learning_rate: float = 0.0
+
   batch_size: int = 48
   epochs: int = 2
   optimizer: BertOptimizer = BertOptimizer.ADAMW
+  weight_decay: float = 0.01
+
+  desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
+  desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
+
+  gamma: float = 2.0


 HParams = Union[BertHParams, AverageWordEmbeddingHParams]
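A hedged usage sketch of the expanded BertHParams. It assumes the inherited BaseHParams fields (epochs, batch_size, learning_rate, export_dir) are accepted as keyword arguments as shown; the concrete values are illustrative only:

hparams = BertHParams(
    epochs=2,
    batch_size=48,
    learning_rate=3e-5,
    end_learning_rate=0.0,
    weight_decay=0.01,
    gamma=2.0,                 # 0.0 falls back to plain cross entropy
    desired_precisions=[0.9],  # adds a recall@precision=0.9 metric (binary only)
    desired_recalls=[0.8],     # adds a precision@recall=0.8 metric (binary only)
    export_dir='/tmp/text_classifier',  # illustrative path
)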
@@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec:
   """

   # `learning_rate` is unused for the average word embedding model
-  hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
-      epochs=10, batch_size=32, learning_rate=0
+  hparams: hp.AverageWordEmbeddingHParams = dataclasses.field(
+      default_factory=lambda: hp.AverageWordEmbeddingHParams(
+          epochs=10, batch_size=32, learning_rate=0
+      )
+  )
+  model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field(
+      default_factory=mo.AverageWordEmbeddingModelOptions
   )
-  model_options: mo.AverageWordEmbeddingModelOptions = (
-      mo.AverageWordEmbeddingModelOptions())
   name: str = 'AverageWordEmbedding'

 average_word_embedding_classifier_spec = functools.partial(

@@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
   inherited from the BertModelSpec.
   """

-  hparams: hp.BertHParams = hp.BertHParams()
+  hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)


 mobilebert_classifier_spec = functools.partial(

@@ -76,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='MobileBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )

 exbert_classifier_spec = functools.partial(

@@ -90,11 +88,6 @@ exbert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='ExBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )
@@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
     self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
     self.assertTrue(model_spec_obj.do_lower_case)
     self.assertEqual(
-        model_spec_obj.tflite_input_name, {
-            'ids': 'serving_default_input_1:0',
-            'mask': 'serving_default_input_3:0',
-            'segment_ids': 'serving_default_input_2:0'
-        })
+        model_spec_obj.tflite_input_name,
+        {
+            'ids': 'serving_default_input_word_ids:0',
+            'mask': 'serving_default_input_mask:0',
+            'segment_ids': 'serving_default_input_type_ids:0',
+        },
+    )
     self.assertEqual(
         model_spec_obj.model_options,
         classifier_model_options.BertModelOptions(
@@ -15,14 +15,15 @@
 """Preprocessors for text classification."""

 import collections
+import hashlib
 import os
 import re
-import tempfile
 from typing import Mapping, Sequence, Tuple, Union

 import tensorflow as tf
 import tensorflow_hub

+from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
 from official.nlp.data import classifier_data_lib
 from official.nlp.tools import tokenization

@@ -75,19 +76,20 @@ def _decode_record(
   return bert_features, example["label_ids"]


-def _single_file_dataset(
-    input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
+def _tfrecord_dataset(
+    tfrecord_files: Sequence[str],
+    name_to_features: Mapping[str, tf.io.FixedLenFeature],
 ) -> tf.data.TFRecordDataset:
   """Creates a single-file dataset to be passed for BERT custom training.

   Args:
-    input_file: Filepath for the dataset.
+    tfrecord_files: Filepaths for the dataset.
     name_to_features: Maps record keys to feature types.

   Returns:
     Dataset containing BERT model input features and labels.
   """
-  d = tf.data.TFRecordDataset(input_file)
+  d = tf.data.TFRecordDataset(tfrecord_files)
   d = d.map(
       lambda record: _decode_record(record, name_to_features),
       num_parallel_calls=tf.data.AUTOTUNE)

@@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
     seq_len: Length of the input sequence to the model.
     vocab_file: File containing the BERT vocab.
     tokenizer: BERT tokenizer.
+    model_name: Name of the model provided by the model_spec. Used to associate
+      cached files with specific Bert model vocab.
   """

-  def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
+  def __init__(
+      self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
+  ):
     self._seq_len = seq_len
     # Vocab filepath is tied to the BERT module's URI.
     self._vocab_file = os.path.join(
-        tensorflow_hub.resolve(uri), "assets", "vocab.txt")
-    self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
-                                                 do_lower_case)
+        tensorflow_hub.resolve(uri), "assets", "vocab.txt"
+    )
+    self._do_lower_case = do_lower_case
+    self._tokenizer = tokenization.FullTokenizer(
+        self._vocab_file, self._do_lower_case
+    )
+    self._model_name = model_name

   def _get_name_to_features(self):
     """Gets the dictionary mapping record keys to feature types."""

@@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
     """Returns the vocab file of the BertClassifierPreprocessor."""
     return self._vocab_file

+  def _get_tfrecord_cache_files(
+      self, ds_cache_files
+  ) -> cache_files_lib.TFRecordCacheFiles:
+    """Helper to regenerate cache prefix filename using preprocessor info.
+
+    We need to update the dataset cache_prefix cache because the actual cached
+    dataset depends on the preprocessor parameters such as model_name, seq_len,
+    and do_lower_case in addition to the raw dataset parameters which is already
+    included in the ds_cache_files.cache_prefix_filename
+
+    Specifically, the new cache_prefix_filename used by the preprocessor will
+    be a hash generated from the following:
+      1. cache_prefix_filename of the initial raw dataset
+      2. model_name
+      3. seq_len
+      4. do_lower_case
+
+    Args:
+      ds_cache_files: TFRecordCacheFiles from the original raw dataset object
+
+    Returns:
+      A new TFRecordCacheFiles object which incorporates the preprocessor
+      parameters.
+    """
+    hasher = hashlib.md5()
+    hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
+    hasher.update(self._model_name.encode("utf-8"))
+    hasher.update(str(self._seq_len).encode("utf-8"))
+    hasher.update(str(self._do_lower_case).encode("utf-8"))
+    cache_prefix_filename = hasher.hexdigest()
+    return cache_files_lib.TFRecordCacheFiles(
+        cache_prefix_filename,
+        ds_cache_files.cache_dir,
+        ds_cache_files.num_shards,
+    )
+
   def preprocess(
-      self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
+      self, dataset: text_classifier_ds.Dataset
+  ) -> text_classifier_ds.Dataset:
     """Preprocesses data into input for a BERT-based classifier.

     Args:

@@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
     Returns:
       Dataset containing (bert_features, label) data.
     """
-    examples = []
-    for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
-      _validate_text_and_label(text, label)
-      examples.append(
-          classifier_data_lib.InputExample(
-              guid=str(index),
-              text_a=text.numpy()[0].decode("utf-8"),
-              text_b=None,
-              # InputExample expects the label name rather than the int ID
-              label=dataset.label_names[label.numpy()[0]]))
+    ds_cache_files = dataset.tfrecord_cache_files
+    # Get new tfrecord_cache_files by including preprocessor information.
+    tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
+    if not tfrecord_cache_files.is_cached():
+      print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
+      writers = tfrecord_cache_files.get_writers()
+      size = 0
+      for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
+        _validate_text_and_label(text, label)
+        example = classifier_data_lib.InputExample(
+            guid=str(index),
+            text_a=text.numpy()[0].decode("utf-8"),
+            text_b=None,
+            # InputExample expects the label name rather than the int ID
+            # label=dataset.label_names[label.numpy()[0]])
+            label=label.numpy()[0],
+        )
+        feature = classifier_data_lib.convert_single_example(
+            index, example, None, self._seq_len, self._tokenizer
+        )

-    tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
-    classifier_data_lib.file_based_convert_examples_to_features(
-        examples=examples,
-        label_list=dataset.label_names,
-        max_seq_length=self._seq_len,
-        tokenizer=self._tokenizer,
-        output_file=tfrecord_file)
-    preprocessed_ds = _single_file_dataset(tfrecord_file,
-                                           self._get_name_to_features())
+        def create_int_feature(values):
+          f = tf.train.Feature(
+              int64_list=tf.train.Int64List(value=list(values))
+          )
+          return f
+
+        features = collections.OrderedDict()
+        features["input_ids"] = create_int_feature(feature.input_ids)
+        features["input_mask"] = create_int_feature(feature.input_mask)
+        features["segment_ids"] = create_int_feature(feature.segment_ids)
+        features["label_ids"] = create_int_feature([feature.label_id])
+        tf_example = tf.train.Example(
+            features=tf.train.Features(feature=features)
+        )
+        writers[index % len(writers)].write(tf_example.SerializeToString())
+        size = index + 1
+      for writer in writers:
+        writer.close()
+      metadata = {"size": size, "label_names": dataset.label_names}
+      tfrecord_cache_files.save_metadata(metadata)
+    else:
+      print(
+          f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
+      )
+      metadata = tfrecord_cache_files.load_metadata()
+    size = metadata["size"]
+    label_names = metadata["label_names"]
+    preprocessed_ds = _tfrecord_dataset(
+        tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
+    )
     return text_classifier_ds.Dataset(
         dataset=preprocessed_ds,
-        size=dataset.size,
-        label_names=dataset.label_names)
+        size=size,
+        label_names=label_names,
+        tfrecord_cache_files=tfrecord_cache_files,
+    )


-TextClassifierPreprocessor = (
-    Union[BertClassifierPreprocessor,
-          AverageWordEmbeddingClassifierPreprocessor])
+TextClassifierPreprocessor = Union[
+    BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
+]
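The caching path above is a plain TFRecord round trip: serialize one tf.train.Example per text with Int64List features, then parse them back with FixedLenFeature specs. A standalone sketch of that round trip; the file path and token ids are illustrative:

import tensorflow as tf

path = "/tmp/bert_features_demo.tfrecord"  # illustrative path

# Write one serialized Example using the same feature keys as the preprocessor.
with tf.io.TFRecordWriter(path) as writer:
  features = {
      "input_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=[101, 2023, 102, 0, 0])),
      "input_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 1, 1, 0, 0])),
      "segment_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 0, 0, 0, 0])),
      "label_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
  }
  example = tf.train.Example(features=tf.train.Features(feature=features))
  writer.write(example.SerializeToString())

# Read it back the way _tfrecord_dataset does.
name_to_features = {
    "input_ids": tf.io.FixedLenFeature([5], tf.int64),
    "input_mask": tf.io.FixedLenFeature([5], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([5], tf.int64),
    "label_ids": tf.io.FixedLenFeature([], tf.int64),
}
ds = tf.data.TFRecordDataset([path]).map(
    lambda record: tf.io.parse_single_example(record, name_to_features))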
|
@ -13,14 +13,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
|
import mock
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.testing as npt
|
import numpy.testing as npt
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import cache_files
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
csv_file = self._get_csv_file()
|
csv_file = self._get_csv_file()
|
||||||
dataset = text_classifier_ds.Dataset.from_csv(
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
seq_len=5,
|
seq_len=5,
|
||||||
do_lower_case=bert_spec.do_lower_case,
|
do_lower_case=bert_spec.do_lower_case,
|
||||||
uri=bert_spec.downloaded_files.get_path(),
|
uri=bert_spec.downloaded_files.get_path(),
|
||||||
|
model_name=bert_spec.name,
|
||||||
)
|
)
|
||||||
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
labels = []
|
labels = []
|
||||||
|
@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
self.assertEqual(label.shape, [1])
|
self.assertEqual(label.shape, [1])
|
||||||
labels.append(label.numpy()[0])
|
labels.append(label.numpy()[0])
|
||||||
self.assertSameElements(
|
self.assertSameElements(
|
||||||
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
|
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
|
||||||
|
)
|
||||||
for feature in features.values():
|
for feature in features.values():
|
||||||
self.assertEqual(feature.shape, [1, 5])
|
self.assertEqual(feature.shape, [1, 5])
|
||||||
input_masks.append(features['input_mask'].numpy()[0])
|
input_masks.append(features['input_mask'].numpy()[0])
|
||||||
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
|
npt.assert_array_equal(
|
||||||
[0, 0, 0, 0, 0])
|
features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
|
||||||
|
)
|
||||||
npt.assert_array_equal(
|
npt.assert_array_equal(
|
||||||
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
|
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
|
||||||
|
)
|
||||||
self.assertEqual(labels, [1, 0])
|
self.assertEqual(labels, [1, 0])
|
||||||
|
|
||||||
|
def test_bert_preprocessor_cache(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
|
filename=csv_file,
|
||||||
|
csv_params=self.CSV_PARAMS_,
|
||||||
|
cache_dir=self.get_temp_dir(),
|
||||||
|
)
|
||||||
|
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||||
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=5,
|
||||||
|
do_lower_case=bert_spec.do_lower_case,
|
||||||
|
uri=bert_spec.downloaded_files.get_path(),
|
||||||
|
model_name=bert_spec.name,
|
||||||
|
)
|
||||||
|
ds_cache_files = dataset.tfrecord_cache_files
|
||||||
|
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
|
||||||
|
ds_cache_files
|
||||||
|
)
|
||||||
|
self.assertFalse(preprocessed_cache_files.is_cached())
|
||||||
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
|
self.assertTrue(preprocessed_cache_files.is_cached())
|
||||||
|
self.assertEqual(
|
||||||
|
preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
|
||||||
|
)
|
||||||
|
|
||||||
|
# The second time running preprocessor, it should load from cache directly
|
||||||
|
mock_stdout = io.StringIO()
|
||||||
|
with mock.patch('sys.stdout', mock_stdout):
|
||||||
|
_ = bert_preprocessor.preprocess(dataset)
|
||||||
|
self.assertEqual(
|
||||||
|
mock_stdout.getvalue(),
|
||||||
|
'Using existing cache files at'
|
||||||
|
f' {preprocessed_cache_files.cache_prefix}\n',
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
|
||||||
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=seq_len,
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
|
uri=bert_spec.downloaded_files.get_path(),
|
||||||
|
model_name=bert_spec.name,
|
||||||
|
)
|
||||||
|
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
|
||||||
|
return new_cf.cache_prefix_filename
|
||||||
|
|
||||||
|
def test_bert_get_tfrecord_cache_files(self):
|
||||||
|
# Test to ensure regenerated cache_files have different prefixes
|
||||||
|
all_cf_prefixes = set()
|
||||||
|
cf = cache_files.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename='cache_prefix',
|
||||||
|
cache_dir=self.get_temp_dir(),
|
||||||
|
num_shards=1,
|
||||||
|
)
|
||||||
|
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
|
||||||
|
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
|
||||||
|
new_cf = cache_files.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename='new_cache_prefix',
|
||||||
|
cache_dir=self.get_temp_dir(),
|
||||||
|
num_shards=1,
|
||||||
|
)
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
|
||||||
|
|
||||||
|
# Each item of all_cf_prefixes should be unique, so 7 total.
|
||||||
|
self.assertLen(all_cf_prefixes, 7)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Load compressed models from tensorflow_hub
|
# Load compressed models from tensorflow_hub
|
||||||
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
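The cache-hit assertion works by temporarily redirecting sys.stdout to a StringIO buffer so the printed message can be compared exactly. A standalone sketch of that pattern using only the standard library (the test above uses the separate mock package, which behaves the same way here):

import io
from unittest import mock

def greet():
  print("hello from cache")

buf = io.StringIO()
with mock.patch("sys.stdout", buf):
  greet()
assert buf.getvalue() == "hello from cache\n"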
@@ -16,8 +16,8 @@
         }
       },
       {
-        "name": "mask",
-        "description": "Mask with 1 for real tokens and 0 for padding tokens.",
+        "name": "segment_ids",
+        "description": "0 for the first sequence, 1 for the second sequence if exists.",
         "content": {
           "content_properties_type": "FeatureProperties",
           "content_properties": {

@@ -27,8 +27,8 @@
         }
       },
       {
-        "name": "segment_ids",
-        "description": "0 for the first sequence, 1 for the second sequence if exists.",
+        "name": "mask",
+        "description": "Mask with 1 for real tokens and 0 for padding tokens.",
         "content": {
           "content_properties_type": "FeatureProperties",
           "content_properties": {
@@ -24,6 +24,7 @@ import tensorflow_hub as hub

 from mediapipe.model_maker.python.core.data import dataset as ds
 from mediapipe.model_maker.python.core.tasks import classifier
+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.core.utils import metrics
 from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.core.utils import quantization

@@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
         options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
         or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
     ):
-      text_classifier = (
-          _BertClassifier.create_bert_classifier(train_data, validation_data,
-                                                 options,
-                                                 train_data.label_names))
+      text_classifier = _BertClassifier.create_bert_classifier(
+          train_data, validation_data, options
+      )
     elif (options.supported_model ==
           ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
-      text_classifier = (
-          _AverageWordEmbeddingClassifier
-          .create_average_word_embedding_classifier(train_data, validation_data,
-                                                    options,
-                                                    train_data.label_names))
+      text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
+          train_data, validation_data, options
+      )
     else:
       raise ValueError(f"Unknown model {options.supported_model}")

@@ -166,28 +164,8 @@ class TextClassifier(classifier.Classifier):
     processed_data = self._text_preprocessor.preprocess(data)
     dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)

-    additional_metrics = []
-    if desired_precisions and len(data.label_names) == 2:
-      for precision in desired_precisions:
-        additional_metrics.append(
-            metrics.BinarySparseRecallAtPrecision(
-                precision, name=f"recall_at_precision_{precision}"
-            )
-        )
-    if desired_recalls and len(data.label_names) == 2:
-      for recall in desired_recalls:
-        additional_metrics.append(
-            metrics.BinarySparsePrecisionAtRecall(
-                recall, name=f"precision_at_recall_{recall}"
-            )
-        )
-    metric_functions = self._metric_functions + additional_metrics
-    self._model.compile(
-        optimizer=self._optimizer,
-        loss=self._loss_function,
-        metrics=metric_functions,
-    )
-    return self._model.evaluate(dataset)
+    with self._hparams.get_strategy().scope():
+      return self._model.evaluate(dataset)

   def export_model(
       self,

@@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):

   @classmethod
   def create_average_word_embedding_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
+  ) -> "_AverageWordEmbeddingClassifier":
     """Creates, trains, and returns an Average Word Embedding classifier.

     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.

     Returns:
       An Average Word Embedding classifier.

@@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
     self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._model_options = model_options
     with self._hparams.get_strategy().scope():
-      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
-      self._metric_functions = [
-          tf.keras.metrics.SparseCategoricalAccuracy(
-              "test_accuracy", dtype=tf.float32
-          ),
-          metrics.SparsePrecision(name="precision", dtype=tf.float32),
-          metrics.SparseRecall(name="recall", dtype=tf.float32),
-      ]
-      self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
+      self._loss_function = loss_functions.SparseFocalLoss(
+          self._hparams.gamma, self._num_classes
+      )
+      self._metric_functions = self._create_metrics()
+      self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None

   @classmethod
   def create_bert_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_BertClassifier":
+  ) -> "_BertClassifier":
     """Creates, trains, and returns a BERT-based classifier.

     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.

     Returns:
       A BERT-based classifier.

@@ -435,9 +411,59 @@ class _BertClassifier(TextClassifier):
         seq_len=self._model_options.seq_len,
         do_lower_case=self._model_spec.do_lower_case,
         uri=self._model_spec.downloaded_files.get_path(),
+        model_name=self._model_spec.name,
     )
-    return (self._text_preprocessor.preprocess(train_data),
-            self._text_preprocessor.preprocess(validation_data))
+    return (
+        self._text_preprocessor.preprocess(train_data),
+        self._text_preprocessor.preprocess(validation_data),
+    )
+
+  def _create_metrics(self):
+    """Creates metrics for training and evaluation.
+
+    The default metrics are accuracy, precision, and recall.
+
+    For binary classification tasks only (num_classes=2):
+      Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
+      the desired_precisions and desired_recalls fields in BertHParams.
+
+    Returns:
+      A list of tf.keras.Metric subclasses which can be used with model.compile
+    """
+    metric_functions = [
+        tf.keras.metrics.SparseCategoricalAccuracy(
+            "accuracy", dtype=tf.float32
+        ),
+        metrics.SparsePrecision(name="precision", dtype=tf.float32),
+        metrics.SparseRecall(name="recall", dtype=tf.float32),
+    ]
+    if self._num_classes == 2:
+      if self._hparams.desired_precisions:
+        for desired_precision in self._hparams.desired_precisions:
+          metric_functions.append(
+              metrics.BinarySparseRecallAtPrecision(
+                  desired_precision,
+                  name=f"recall_at_precision_{desired_precision}",
+                  num_thresholds=1000,
+              )
+          )
+      if self._hparams.desired_recalls:
+        for desired_recall in self._hparams.desired_recalls:
+          metric_functions.append(
+              metrics.BinarySparseRecallAtPrecision(
+                  desired_recall,
+                  name=f"precision_at_recall_{desired_recall}",
+                  num_thresholds=1000,
+              )
+          )
+    else:
+      if self._hparams.desired_precisions or self._hparams.desired_recalls:
+        raise ValueError(
+            "desired_recalls and desired_precisions parameters are binary"
+            " metrics and not supported for num_classes > 2. Found"
+            f" num_classes: {self._num_classes}"
+        )
+    return metric_functions

   def _create_model(self):
     """Creates a BERT-based classifier model.

@@ -447,11 +473,20 @@ class _BertClassifier(TextClassifier):
     """
     encoder_inputs = dict(
         input_word_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_word_ids",
+        ),
         input_mask=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_mask",
+        ),
         input_type_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_type_ids",
+        ),
     )
     encoder = hub.KerasLayer(
         self._model_spec.downloaded_files.get_path(),

@@ -493,16 +528,21 @@ class _BertClassifier(TextClassifier):
     lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
         initial_learning_rate=initial_lr,
         decay_steps=total_steps,
-        end_learning_rate=0.0,
-        power=1.0)
+        end_learning_rate=self._hparams.end_learning_rate,
+        power=1.0,
+    )
     if warmup_steps:
       lr_schedule = model_util.WarmUp(
           initial_learning_rate=initial_lr,
           decay_schedule_fn=lr_schedule,
-          warmup_steps=warmup_steps)
+          warmup_steps=warmup_steps,
+      )
     if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
       self._optimizer = tf.keras.optimizers.experimental.AdamW(
-          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
+          lr_schedule,
+          weight_decay=self._hparams.weight_decay,
+          epsilon=1e-6,
+          global_clipnorm=1.0,
       )
       self._optimizer.exclude_from_weight_decay(
           var_names=["LayerNorm", "layer_norm", "bias"]

@@ -510,7 +550,7 @@ class _BertClassifier(TextClassifier):
     elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
       self._optimizer = tfa_optimizers.LAMB(
           lr_schedule,
-          weight_decay_rate=0.01,
+          weight_decay_rate=self._hparams.weight_decay,
           epsilon=1e-6,
           exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
           global_clipnorm=1.0,
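The optimizer wiring changed above is a standard linear-decay-plus-AdamW setup: decay the learning rate to a configurable end value, apply decoupled weight decay, and exclude normalization and bias variables from decay. A standalone sketch with illustrative hyperparameter values (the model_util.WarmUp wrapper used by the library is omitted here):

import tensorflow as tf

total_steps = 1000                              # illustrative values
initial_lr, end_lr, weight_decay = 3e-5, 0.0, 0.01

lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_lr,
    decay_steps=total_steps,
    end_learning_rate=end_lr,
    power=1.0,  # power=1.0 makes the decay linear
)
optimizer = tf.keras.optimizers.experimental.AdamW(
    lr_schedule,
    weight_decay=weight_decay,
    epsilon=1e-6,
    global_clipnorm=1.0,
)
# Decoupled weight decay is normally not applied to normalization and bias terms.
optimizer.exclude_from_weight_decay(
    var_names=["LayerNorm", "layer_norm", "bias"])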
@@ -84,8 +84,8 @@ def run(data_dir,
       options)

   # Gets evaluation results.
-  _, acc = model.evaluate(validation_data)
-  print('Eval accuracy: %f' % acc)
+  metrics = model.evaluate(validation_data)
+  print('Eval accuracy: %f' % metrics[1])

   model.export_model(quantization_config=quantization_config)
   model.export_labels(export_dir=options.hparams.export_dir)
@@ -16,17 +16,17 @@ import csv
 import filecmp
 import os
 import tempfile
-import unittest
 from unittest import mock as unittest_mock
 
+from absl.testing import parameterized
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.text import text_classifier
 from mediapipe.tasks.python.test import test_utils
 
 
-@unittest.skip('b/275624089')
-class TextClassifierTest(tf.test.TestCase):
+class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
 
   _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
       test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
         text_classifier.TextClassifier.create(train_data, validation_data,
                                               options))
 
-    _, accuracy = average_word_embedding_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = average_word_embedding_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy
 
     # Test export_model
     average_word_embedding_classifier.export_model()
@@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
         filecmp.cmp(
             output_metadata_file,
             self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
-            shallow=False))
+            shallow=False,
+        )
+    )
 
-  def test_create_and_train_bert(self):
+  @parameterized.named_parameters(
+      # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
+      # dict(
+      #     testcase_name='mobilebert',
+      #     supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+      # ),
+      dict(
+          testcase_name='exbert',
+          supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
+      ),
+  )
+  def test_create_and_train_bert(self, supported_model):
     train_data, validation_data = self._get_data()
     options = text_classifier.TextClassifierOptions(
-        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+        supported_model=supported_model,
         model_options=text_classifier.BertModelOptions(
             do_fine_tuning=False, seq_len=2
         ),
@@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
     bert_classifier = text_classifier.TextClassifier.create(
         train_data, validation_data, options)
 
-    _, accuracy = bert_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = bert_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy
 
     # Test export_model
     bert_classifier.export_model()
@@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
     )
 
   def test_label_mismatch(self):
-    options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
+    options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
+    )
     train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
+    train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
     validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
+    validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
     with self.assertRaisesRegex(
         ValueError,
-        'Training data label names .* not equal to validation data label names'
+        'Training data label names .* not equal to validation data label names',
     ):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            options)
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, options
+      )
 
   def test_options_mismatch(self):
     train_data, validation_data = self._get_data()
 
-    avg_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
-            model_options=text_classifier.AverageWordEmbeddingModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
-        ' SupportedModels.MOBILEBERT_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            avg_options)
+    avg_options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
+        model_options=text_classifier.AverageWordEmbeddingModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
+        ' SupportedModels.EXBERT_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, avg_options
+      )
 
-    bert_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(text_classifier.SupportedModels
-                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
-            model_options=text_classifier.BertModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
-        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            bert_options)
+    bert_options = text_classifier.TextClassifierOptions(
+        supported_model=(
+            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
+        ),
+        model_options=text_classifier.BertModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected a Bert Classifier(MobileBERT or EXBERT), got'
+        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, bert_options
+      )
+
+  def test_bert_loss_and_metrics_creation(self):
+    train_data, validation_data = self._get_data()
+    supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
+    hparams = text_classifier.BertHParams(
+        desired_recalls=[0.2],
+        desired_precisions=[0.9],
+        epochs=1,
+        batch_size=1,
+        learning_rate=3e-5,
+        distribution_strategy='off',
+        gamma=3.5,
+    )
+    options = text_classifier.TextClassifierOptions(
+        supported_model=supported_model, hparams=hparams
+    )
+    bert_classifier = text_classifier.TextClassifier.create(
+        train_data, validation_data, options
+    )
+    loss_fn = bert_classifier._loss_function
+    self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
+    self.assertEqual(loss_fn._gamma, 3.5)
+    self.assertEqual(loss_fn._num_classes, 2)
+    metric_names = [m.name for m in bert_classifier._metric_functions]
+    expected_metric_names = [
+        'accuracy',
+        'recall',
+        'precision',
+        'precision_at_recall_0.2',
+        'recall_at_precision_0.9',
+    ]
+    self.assertCountEqual(metric_names, expected_metric_names)
+
+    # Non-binary data
+    tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
+    data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'desired_recalls and desired_precisions parameters are binary metrics'
+        ' and not supported for num_classes > 2. Found num_classes: 3',
+    ):
+      text_classifier.TextClassifier.create(data, data, options)
 
 
 if __name__ == '__main__':
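The new test_bert_loss_and_metrics_creation checks that the classifier wires up loss_functions.SparseFocalLoss with the configured gamma. As a rough sketch of what a sparse focal loss computes, using the standard formulation FL(p_t) = -(1 - p_t)^gamma * log(p_t); MediaPipe's implementation may differ in details such as class weighting:

import tensorflow as tf

def sparse_focal_loss(y_true, logits, gamma=2.0):
  """Mean focal loss over a batch of integer labels and unnormalized logits."""
  probs = tf.nn.softmax(logits, axis=-1)
  # Probability the model assigns to the true class of each example.
  p_t = tf.gather(probs, y_true, batch_dims=1)
  p_t = tf.clip_by_value(p_t, 1e-7, 1.0)
  return tf.reduce_mean(-tf.pow(1.0 - p_t, gamma) * tf.math.log(p_t))

# Toy example: two classes, three examples, gamma matching the test above.
logits = tf.constant([[2.0, 0.5], [0.1, 1.5], [0.3, 0.2]])
labels = tf.constant([0, 1, 1])
print(sparse_focal_loss(labels, logits, gamma=3.5).numpy())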
@@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         ', '.join(label_names),
     )
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names
+        dataset=image_label_ds,
+        label_names=label_names,
+        size=all_image_size,
     )
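This hunk, and several later ones, reorder Dataset construction so that label_names comes before size, both positionally and as keywords. A hypothetical stand-in constructor showing the call order the diff migrates to (the real class is MediaPipe's ClassificationDataset; the names here are illustrative only):

import tensorflow as tf
from typing import List

class ClassificationDataset:
  """Illustrative only: mirrors the (dataset, label_names, size) argument order."""

  def __init__(self, dataset: tf.data.Dataset, label_names: List[str], size: int):
    self._dataset = dataset
    self._label_names = label_names
    self._size = size

ds = tf.data.Dataset.from_tensor_slices([[0], [1]])
data = ClassificationDataset(ds, ['neg', 'pos'], 2)  # label_names before size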
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Placeholder for internal Python strict test compatibility macro.
-# Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict binary and library compatibility macro.
 
 licenses(["notice"])
 
@@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset):
             len(valid_hand_data), len(label_names), ','.join(label_names)))
     return Dataset(
         dataset=hand_embedding_label_ds,
+        label_names=label_names,
         size=len(valid_hand_data),
-        label_names=label_names)
+    )
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict binary and library compatibility macro.
 # Placeholder for internal Python library rule.
 
 licenses(["notice"])
@@ -15,28 +15,12 @@
 
 import os
 import random
 
-from typing import List, Optional
 import tensorflow as tf
-import tensorflow_datasets as tfds
 
 from mediapipe.model_maker.python.core.data import classification_dataset
 from mediapipe.model_maker.python.vision.core import image_utils
 
 
-def _create_data(
-    name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
-    label_names: List[str]
-) -> Optional[classification_dataset.ClassificationDataset]:
-  """Creates a Dataset object from tfds data."""
-  if name not in data:
-    return None
-  data = data[name]
-  data = data.map(lambda a: (a['image'], a['label']))
-  size = info.splits[name].num_examples
-  return Dataset(data, size, label_names)
-
-
 class Dataset(classification_dataset.ClassificationDataset):
   """Dataset library for image classifier."""
 
@@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset):
         'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
         all_label_size, ', '.join(label_names))
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names)
+        dataset=image_label_ds, label_names=label_names, size=all_image_size
+    )
@@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase):
 
   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
+    data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
    train_data, test_data = data.split(fraction=0.5)
 
     self.assertLen(train_data, 2)
@@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     ds = tf.data.Dataset.from_generator(
         self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
             [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
-    data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
-                                    ['cyan', 'magenta', 'yellow'])
+    data = image_classifier.Dataset(
+        ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3
+    )
     return data
 
   def setUp(self):
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict binary and library compatibility macro.
 # Placeholder for internal Python strict test compatibility macro.
 
 licenses(["notice"])
@@ -54,6 +54,7 @@ py_library(
     srcs = ["dataset.py"],
     deps = [
         ":dataset_util",
+        "//mediapipe/model_maker/python/core/data:cache_files",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
     ],
 )
@@ -73,6 +74,7 @@ py_test(
 py_library(
     name = "dataset_util",
     srcs = ["dataset_util.py"],
+    deps = ["//mediapipe/model_maker/python/core/data:cache_files"],
 )
 
 py_test(
@@ -16,8 +16,8 @@
 from typing import Optional
 
 import tensorflow as tf
-import yaml
 
+from mediapipe.model_maker.python.core.data import cache_files
 from mediapipe.model_maker.python.core.data import classification_dataset
 from mediapipe.model_maker.python.vision.object_detector import dataset_util
 from official.vision.dataloaders import tf_example_decoder
@@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset):
       ValueError: If the label_name for id 0 is set to something other than
         the 'background' class.
     """
-    cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
-    if not dataset_util.is_cached(cache_files):
+    tfrecord_cache_files = dataset_util.get_cache_files_coco(
+        data_dir, cache_dir
+    )
+    if not tfrecord_cache_files.is_cached():
       label_map = dataset_util.get_label_map_coco(data_dir)
       cache_writer = dataset_util.COCOCacheFilesWriter(
           label_map=label_map, max_num_images=max_num_images
       )
-      cache_writer.write_files(cache_files, data_dir)
-    return cls.from_cache(cache_files.cache_prefix)
+      cache_writer.write_files(tfrecord_cache_files, data_dir)
+    return cls.from_cache(tfrecord_cache_files)
 
   @classmethod
   def from_pascal_voc_folder(
@@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset):
     Raises:
       ValueError: if the input data directory is empty.
     """
-    cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
-    if not dataset_util.is_cached(cache_files):
+    tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc(
+        data_dir, cache_dir
+    )
+    if not tfrecord_cache_files.is_cached():
       label_map = dataset_util.get_label_map_pascal_voc(data_dir)
       cache_writer = dataset_util.PascalVocCacheFilesWriter(
          label_map=label_map, max_num_images=max_num_images
       )
-      cache_writer.write_files(cache_files, data_dir)
+      cache_writer.write_files(tfrecord_cache_files, data_dir)
 
-    return cls.from_cache(cache_files.cache_prefix)
+    return cls.from_cache(tfrecord_cache_files)
 
   @classmethod
-  def from_cache(cls, cache_prefix: str) -> 'Dataset':
+  def from_cache(
+      cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles
+  ) -> 'Dataset':
     """Loads the TFRecord data from cache.
 
     Args:
-      cache_prefix: The cache prefix including the cache directory and the cache
-        prefix filename, e.g: '/tmp/cache/train'.
+      tfrecord_cache_files: The TFRecordCacheFiles object containing the already
+        cached TFRecord and metadata files.
 
     Returns:
       ObjectDetectorDataset object.
+
+    Raises:
+      ValueError if tfrecord_cache_files are not already cached.
     """
-    # Get TFRecord Files
-    tfrecord_file_pattern = cache_prefix + '*.tfrecord'
-    matched_files = tf.io.gfile.glob(tfrecord_file_pattern)
-    if not matched_files:
-      raise ValueError('TFRecord files are empty.')
-
-    # Load meta_data.
-    meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
-    if not tf.io.gfile.exists(meta_data_file):
-      raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
-    with tf.io.gfile.GFile(meta_data_file, 'r') as f:
-      meta_data = yaml.load(f, Loader=yaml.FullLoader)
+    if not tfrecord_cache_files.is_cached():
+      raise ValueError(
+          'Cache files must be already cached to use the from_cache method.'
+      )
+
+    metadata = tfrecord_cache_files.load_metadata()
 
-    dataset = tf.data.TFRecordDataset(matched_files)
+    dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files)
     decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
     dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)
 
-    label_map = meta_data['label_map']
+    label_map = metadata['label_map']
     label_names = [label_map[k] for k in sorted(label_map.keys())]
 
     return Dataset(
-        dataset=dataset, size=meta_data['size'], label_names=label_names
+        dataset=dataset, label_names=label_names, size=metadata['size']
     )
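from_coco_folder, from_pascal_voc_folder and from_cache now pass a TFRecordCacheFiles object around instead of a bare cache_prefix string. A hedged sketch of the interface these call sites rely on, with attribute and method names taken from this diff; the real class lives in mediapipe/model_maker/python/core/data/cache_files.py and its implementation may differ:

import dataclasses
import os
from typing import Any, Dict, List

import tensorflow as tf
import yaml


@dataclasses.dataclass
class TFRecordCacheFiles:
  """Sketch of the cache-file bundle the object detector dataset passes around."""

  cache_prefix_filename: str
  cache_dir: str
  num_shards: int = 10

  @property
  def cache_prefix(self) -> str:
    # e.g. '/tmp/cache/train' for cache_dir='/tmp/cache', filename='train'.
    return os.path.join(self.cache_dir, self.cache_prefix_filename)

  @property
  def tfrecord_files(self) -> List[str]:
    return [
        self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards)
        for i in range(self.num_shards)
    ]

  @property
  def metadata_file(self) -> str:
    # Matches the '/tmp/train_metadata.yaml' expectation in the updated tests.
    return self.cache_prefix + '_metadata.yaml'

  def get_writers(self) -> List[tf.io.TFRecordWriter]:
    return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]

  def save_metadata(self, metadata: Dict[str, Any]) -> None:
    with tf.io.gfile.GFile(self.metadata_file, 'w') as f:
      yaml.dump(metadata, f)

  def load_metadata(self) -> Dict[str, Any]:
    with tf.io.gfile.GFile(self.metadata_file, 'r') as f:
      return yaml.load(f, Loader=yaml.FullLoader)

  def is_cached(self) -> bool:
    return all(
        tf.io.gfile.exists(path)
        for path in self.tfrecord_files + [self.metadata_file]
    )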
@@ -15,25 +15,20 @@
 
 import abc
 import collections
-import dataclasses
 import hashlib
 import json
 import math
 import os
 import tempfile
-from typing import Any, Dict, List, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Mapping, Optional
 import xml.etree.ElementTree as ET
 
 import tensorflow as tf
-import yaml
 
+from mediapipe.model_maker.python.core.data import cache_files
 from official.vision.data import tfrecord_lib
 
 
-# Suffix of the meta data file name.
-META_DATA_FILE_SUFFIX = '_meta_data.yaml'
-
-
 def _xml_get(node: ET.Element, name: str) -> ET.Element:
   """Gets a named child from an XML Element node.
 
@@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str:
   return os.path.basename(os.path.abspath(data_dir))
 
 
-@dataclasses.dataclass(frozen=True)
-class CacheFiles:
-  """Cache files for object detection."""
-
-  cache_prefix: str
-  tfrecord_files: Sequence[str]
-  meta_data_file: str
-
-
 def _get_cache_files(
     cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
-) -> CacheFiles:
+) -> cache_files.TFRecordCacheFiles:
   """Creates an object of CacheFiles class.
 
   Args:
@@ -96,28 +82,16 @@ def _get_cache_files(
     An object of CacheFiles class.
   """
   cache_dir = _get_cache_dir_or_create(cache_dir)
-  # The cache prefix including the cache directory and the cache prefix
-  # filename, e.g: '/tmp/cache/train'.
-  cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
-  tf.compat.v1.logging.info(
-      'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
-      % (cache_dir, cache_prefix_filename, cache_prefix)
-  )
-
-  # Cached files including the TFRecord files and the meta data file.
-  tfrecord_files = [
-      cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
-      for i in range(num_shards)
-  ]
-  meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
-  return CacheFiles(
-      cache_prefix=cache_prefix,
-      tfrecord_files=tuple(tfrecord_files),
-      meta_data_file=meta_data_file,
+  return cache_files.TFRecordCacheFiles(
+      cache_prefix_filename=cache_prefix_filename,
+      cache_dir=cache_dir,
+      num_shards=num_shards,
   )
 
 
-def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
+def get_cache_files_coco(
+    data_dir: str, cache_dir: str
+) -> cache_files.TFRecordCacheFiles:
   """Creates an object of CacheFiles class using a COCO formatted dataset.
 
   Args:
@@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
   return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
 
 
-def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
+def get_cache_files_pascal_voc(
+    data_dir: str, cache_dir: str
+) -> cache_files.TFRecordCacheFiles:
   """Gets an object of CacheFiles using a PASCAL VOC formatted dataset.
 
   Args:
@@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
   return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
 
 
-def is_cached(cache_files: CacheFiles) -> bool:
-  """Checks whether cache files are already cached."""
-  all_cached_files = list(cache_files.tfrecord_files) + [
-      cache_files.meta_data_file
-  ]
-  return all(tf.io.gfile.exists(path) for path in all_cached_files)
-
-
 class CacheFilesWriter(abc.ABC):
   """CacheFilesWriter class to write the cached files."""
 
@@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC):
     self.label_map = label_map
     self.max_num_images = max_num_images
 
-  def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
-    """Writes TFRecord and meta_data files.
+  def write_files(
+      self,
+      tfrecord_cache_files: cache_files.TFRecordCacheFiles,
+      *args,
+      **kwargs,
+  ) -> None:
+    """Writes TFRecord and metadata files.
 
     Args:
-      cache_files: CacheFiles object including a list of TFRecord files and the
-        meta data yaml file to save the meta_data including data size and
-        label_map.
+      tfrecord_cache_files: TFRecordCacheFiles object including a list of
+        TFRecord files and the meta data yaml file to save the metadata
+        including data size and label_map.
       *args: Non-keyword of parameters used in the `_get_example` method.
      **kwargs: Keyword parameters used in the `_get_example` method.
     """
-    writers = [
-        tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
-    ]
+    writers = tfrecord_cache_files.get_writers()
 
     # Writes tf.Example into TFRecord files.
     size = 0
@@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC):
     for writer in writers:
       writer.close()
 
-    # Writes meta_data into meta_data_file.
-    meta_data = {'size': size, 'label_map': self.label_map}
-    with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
-      yaml.dump(meta_data, f)
+    # Writes metadata into metadata_file.
+    metadata = {'size': size, 'label_map': self.label_map}
+    tfrecord_cache_files.save_metadata(metadata)
 
   @abc.abstractmethod
   def _get_example(self, *args, **kwargs):
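Taken together, the dataset_util refactor keeps the caching flow the same while delegating file bookkeeping to TFRecordCacheFiles. A hedged usage sketch built from the functions shown in this diff (the directory paths and argument values are placeholders):

from mediapipe.model_maker.python.vision.object_detector import dataset_util

# Placeholder directories for illustration only.
data_dir = '/tmp/coco_data'
cache_dir = '/tmp/od_cache'

cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
if not cache_files.is_cached():
  label_map = dataset_util.get_label_map_coco(data_dir)
  writer = dataset_util.COCOCacheFilesWriter(
      label_map=label_map, max_num_images=None
  )
  writer.write_files(cache_files, data_dir)

metadata = cache_files.load_metadata()
print(metadata['size'], sorted(metadata['label_map'].items()))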
@@ -19,7 +19,6 @@ import shutil
 from unittest import mock as unittest_mock
 
 import tensorflow as tf
-import yaml
 
 from mediapipe.model_maker.python.vision.core import test_utils
 from mediapipe.model_maker.python.vision.object_detector import dataset_util
@@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase):
 
   def _assert_cache_files_equal(self, cf1, cf2):
     self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
-    self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
-    self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)
+    self.assertEqual(cf1.num_shards, cf2.num_shards)
 
   def _assert_cache_files_not_equal(self, cf1, cf2):
     self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
-    self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
-    self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)
 
   def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
     def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
@@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertEqual(
         cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
     )
-    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
+    self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
 
   def test_matching_get_cache_files_coco(self):
     cache_dir = self.create_tempdir()
@@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertEqual(
         cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
     )
-    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
+    self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
 
   def test_matching_get_cache_files_pascal_voc(self):
     cache_dir = self.create_tempdir()
@@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase):
     cache_files = dataset_util.get_cache_files_coco(
         tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
     )
-    self.assertFalse(dataset_util.is_cached(cache_files))
+    self.assertFalse(cache_files.is_cached())
     with open(cache_files.tfrecord_files[0], 'w') as f:
       f.write('test')
-    self.assertFalse(dataset_util.is_cached(cache_files))
-    with open(cache_files.meta_data_file, 'w') as f:
+    self.assertFalse(cache_files.is_cached())
+    with open(cache_files.metadata_file, 'w') as f:
       f.write('test')
-    self.assertTrue(dataset_util.is_cached(cache_files))
+    self.assertTrue(cache_files.is_cached())
 
   def test_get_label_map_coco(self):
     coco_dir = tasks_test_utils.get_test_data_path('coco_data')
@@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
     self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)
 
-    # Checks the meta_data file
-    self.assertTrue(os.path.isfile(cache_files.meta_data_file))
-    self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
-    with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
-      meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
-      # Size is 3 because some examples are skipped for having poor bboxes
-      self.assertEqual(meta_data_dict['size'], expected_size)
+    # Checks the metadata file
+    self.assertTrue(os.path.isfile(cache_files.metadata_file))
+    self.assertGreater(os.path.getsize(cache_files.metadata_file), 0)
+    metadata_dict = cache_files.load_metadata()
+    self.assertEqual(metadata_dict['size'], expected_size)
 
   def test_coco_cache_files_writer(self):
     tempdir = self.create_tempdir()