Use model bundle for gesture recognizer.

PiperOrigin-RevId: 482960305
This commit is contained in:
MediaPipe Team 2022-10-21 21:52:40 -07:00 committed by Copybara-Service
parent 404323f631
commit d8006a2f87
15 changed files with 309 additions and 229 deletions

View File

@ -62,13 +62,19 @@ cc_library(
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
@ -93,10 +99,14 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
@ -140,6 +150,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
@ -112,57 +113,38 @@ CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<GestureRecognizerGraphOptionsProto> std::unique_ptr<GestureRecognizerGraphOptionsProto>
ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
auto options_proto = std::make_unique<GestureRecognizerGraphOptionsProto>(); auto options_proto = std::make_unique<GestureRecognizerGraphOptionsProto>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE);
bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE;
// TODO remove these workarounds for base options of subgraphs.
// Configure hand detector options. // Configure hand detector options.
auto base_options_proto_for_hand_detector =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_hand_detector)));
base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode);
auto* hand_detector_graph_options = auto* hand_detector_graph_options =
options_proto->mutable_hand_landmarker_graph_options() options_proto->mutable_hand_landmarker_graph_options()
->mutable_hand_detector_graph_options(); ->mutable_hand_detector_graph_options();
hand_detector_graph_options->mutable_base_options()->Swap(
base_options_proto_for_hand_detector.get());
hand_detector_graph_options->set_num_hands(options->num_hands); hand_detector_graph_options->set_num_hands(options->num_hands);
hand_detector_graph_options->set_min_detection_confidence( hand_detector_graph_options->set_min_detection_confidence(
options->min_hand_detection_confidence); options->min_hand_detection_confidence);
// Configure hand landmark detector options. // Configure hand landmark detector options.
auto base_options_proto_for_hand_landmarker =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_hand_landmarker)));
base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode);
auto* hand_landmarks_detector_graph_options =
options_proto->mutable_hand_landmarker_graph_options()
->mutable_hand_landmarks_detector_graph_options();
hand_landmarks_detector_graph_options->mutable_base_options()->Swap(
base_options_proto_for_hand_landmarker.get());
hand_landmarks_detector_graph_options->set_min_detection_confidence(
options->min_hand_presence_confidence);
auto* hand_landmarker_graph_options = auto* hand_landmarker_graph_options =
options_proto->mutable_hand_landmarker_graph_options(); options_proto->mutable_hand_landmarker_graph_options();
hand_landmarker_graph_options->set_min_tracking_confidence( hand_landmarker_graph_options->set_min_tracking_confidence(
options->min_tracking_confidence); options->min_tracking_confidence);
auto* hand_landmarks_detector_graph_options =
hand_landmarker_graph_options
->mutable_hand_landmarks_detector_graph_options();
hand_landmarks_detector_graph_options->set_min_detection_confidence(
options->min_hand_presence_confidence);
// Configure hand gesture recognizer options. // Configure hand gesture recognizer options.
auto base_options_proto_for_gesture_recognizer =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_gesture_recognizer)));
base_options_proto_for_gesture_recognizer->set_use_stream_mode(
use_stream_mode);
auto* hand_gesture_recognizer_graph_options = auto* hand_gesture_recognizer_graph_options =
options_proto->mutable_hand_gesture_recognizer_graph_options(); options_proto->mutable_hand_gesture_recognizer_graph_options();
hand_gesture_recognizer_graph_options->mutable_base_options()->Swap(
base_options_proto_for_gesture_recognizer.get());
if (options->min_gesture_confidence >= 0) { if (options->min_gesture_confidence >= 0) {
hand_gesture_recognizer_graph_options->mutable_classifier_options() hand_gesture_recognizer_graph_options
->mutable_canned_gesture_classifier_graph_options()
->mutable_classifier_options()
->set_score_threshold(options->min_gesture_confidence); ->set_score_threshold(options->min_gesture_confidence);
} }
return options_proto; return options_proto;

View File

@ -39,12 +39,6 @@ struct GestureRecognizerOptions {
// model file with metadata, accelerator options, op resolver, etc. // model file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options; tasks::core::BaseOptions base_options;
// TODO: remove these. Temporary solutions before bundle asset is
// ready.
tasks::core::BaseOptions base_options_for_hand_landmarker;
tasks::core::BaseOptions base_options_for_hand_detector;
tasks::core::BaseOptions base_options_for_gesture_recognizer;
// The running mode of the task. Default to the image mode. // The running mode of the task. Default to the image mode.
// GestureRecognizer has three running modes: // GestureRecognizer has three running modes:
// 1) The image mode for recognizing hand gestures on single image inputs. // 1) The image mode for recognizing hand gestures on single image inputs.

View File

@ -25,9 +25,13 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
@ -46,6 +50,8 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::gesture_recognizer::proto:: using ::mediapipe::tasks::vision::gesture_recognizer::proto::
GestureRecognizerGraphOptions; GestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::proto:: using ::mediapipe::tasks::vision::gesture_recognizer::proto::
@ -61,6 +67,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task";
constexpr char kHandGestureRecognizerBundleAssetName[] =
"hand_gesture_recognizer.task";
struct GestureRecognizerOutputs { struct GestureRecognizerOutputs {
Source<std::vector<ClassificationList>> gesture; Source<std::vector<ClassificationList>> gesture;
@ -70,6 +79,53 @@ struct GestureRecognizerOutputs {
Source<Image> image; Source<Image> image;
}; };
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
GestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_landmarker_file,
resources.GetModelFile(kHandLandmarkerBundleAssetName));
auto* hand_landmarker_graph_options =
options->mutable_hand_landmarker_graph_options();
SetExternalFile(hand_landmarker_file,
hand_landmarker_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
hand_landmarker_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(
const auto hand_gesture_recognizer_file,
resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
auto* hand_gesture_recognizer_graph_options =
options->mutable_hand_gesture_recognizer_graph_options();
SetExternalFile(hand_gesture_recognizer_file,
hand_gesture_recognizer_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
hand_gesture_recognizer_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
if (!hand_gesture_recognizer_graph_options->base_options()
.acceleration()
.has_xnnpack() &&
!hand_gesture_recognizer_graph_options->base_options()
.acceleration()
.has_tflite()) {
hand_gesture_recognizer_graph_options->mutable_base_options()
->mutable_acceleration()
->mutable_xnnpack();
LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets "
<< "HandGestureRecognizerGraph acceleartion to Xnnpack.";
}
hand_gesture_recognizer_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
return absl::OkStatus();
}
} // namespace } // namespace
// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs // A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs
@ -136,6 +192,21 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
Graph graph; Graph graph;
if (sc->Options<GestureRecognizerGraphOptions>()
.base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN(
const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<GestureRecognizerGraphOptions>(sc));
// When the model resources cache service is available, filling in
// the file pointer meta in the subtasks' base options. Otherwise,
// providing the file contents instead.
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<GestureRecognizerGraphOptions>(),
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
BuildGestureRecognizerGraph( BuildGestureRecognizerGraph(
*sc->MutableOptions<GestureRecognizerGraphOptions>(), *sc->MutableOptions<GestureRecognizerGraphOptions>(),

View File

@ -30,11 +30,17 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
@ -51,6 +57,8 @@ using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::processors:: using ::mediapipe::tasks::components::processors::
ConfigureTensorsToClassificationCalculator; ConfigureTensorsToClassificationCalculator;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::gesture_recognizer::proto:: using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions; HandGestureRecognizerGraphOptions;
@ -70,6 +78,14 @@ constexpr char kVectorTag[] = "VECTOR";
constexpr char kIndexTag[] = "INDEX"; constexpr char kIndexTag[] = "INDEX";
constexpr char kIterableTag[] = "ITERABLE"; constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END"; constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite";
constexpr char kCannedGestureClassifierTFLiteName[] =
"canned_gesture_classifier.tflite";
struct SubTaskModelResources {
const core::ModelResources* gesture_embedder_model_resource;
const core::ModelResources* canned_gesture_classifier_model_resource;
};
Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix, Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
Graph& graph) { Graph& graph) {
@ -78,6 +94,41 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
return node[Output<std::vector<Tensor>>{"TENSORS"}]; return node[Output<std::vector<Tensor>>{"TENSORS"}];
} }
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
HandGestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto gesture_embedder_file,
resources.GetModelFile(kGestureEmbedderTFLiteName));
auto* gesture_embedder_graph_options =
options->mutable_gesture_embedder_graph_options();
SetExternalFile(gesture_embedder_file,
gesture_embedder_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
gesture_embedder_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
auto* canned_gesture_classifier_graph_options =
options->mutable_canned_gesture_classifier_graph_options();
SetExternalFile(
canned_gesture_classifier_file,
canned_gesture_classifier_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
canned_gesture_classifier_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
canned_gesture_classifier_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
return absl::OkStatus();
}
} // namespace } // namespace
// A // A
@ -128,14 +179,29 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
if (sc->Options<HandGestureRecognizerGraphOptions>()
.base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
const auto* model_resources, const auto* model_asset_bundle_resources,
CreateModelResources<HandGestureRecognizerGraphOptions>(sc)); CreateModelAssetBundleResources<HandGestureRecognizerGraphOptions>(
sc));
// When the model resources cache service is available, filling in
// the file pointer meta in the subtasks' base options. Otherwise,
// providing the file contents instead.
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<HandGestureRecognizerGraphOptions>(),
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(const auto sub_task_model_resources,
CreateSubTaskModelResources(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(auto hand_gestures,
auto hand_gestures,
BuildGestureRecognizerGraph( BuildGestureRecognizerGraph(
sc->Options<HandGestureRecognizerGraphOptions>(), *model_resources, sc->Options<HandGestureRecognizerGraphOptions>(),
sub_task_model_resources,
graph[Input<ClassificationList>(kHandednessTag)], graph[Input<ClassificationList>(kHandednessTag)],
graph[Input<NormalizedLandmarkList>(kLandmarksTag)], graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
graph[Input<LandmarkList>(kWorldLandmarksTag)], graph[Input<LandmarkList>(kWorldLandmarksTag)],
@ -146,9 +212,37 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
} }
private: private:
absl::StatusOr<SubTaskModelResources> CreateSubTaskModelResources(
SubgraphContext* sc) {
auto* options = sc->MutableOptions<HandGestureRecognizerGraphOptions>();
SubTaskModelResources sub_task_model_resources;
auto& gesture_embedder_model_asset =
*options->mutable_gesture_embedder_graph_options()
->mutable_base_options()
->mutable_model_asset();
ASSIGN_OR_RETURN(
sub_task_model_resources.gesture_embedder_model_resource,
CreateModelResources(sc,
std::make_unique<core::proto::ExternalFile>(
std::move(gesture_embedder_model_asset)),
"_gesture_embedder"));
auto& canned_gesture_classifier_model_asset =
*options->mutable_canned_gesture_classifier_graph_options()
->mutable_base_options()
->mutable_model_asset();
ASSIGN_OR_RETURN(
sub_task_model_resources.canned_gesture_classifier_model_resource,
CreateModelResources(
sc,
std::make_unique<core::proto::ExternalFile>(
std::move(canned_gesture_classifier_model_asset)),
"_canned_gesture_classifier"));
return sub_task_model_resources;
}
absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph( absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
const HandGestureRecognizerGraphOptions& graph_options, const HandGestureRecognizerGraphOptions& graph_options,
const core::ModelResources& model_resources, const SubTaskModelResources& sub_task_model_resources,
Source<ClassificationList> handedness, Source<ClassificationList> handedness,
Source<NormalizedLandmarkList> hand_landmarks, Source<NormalizedLandmarkList> hand_landmarks,
Source<LandmarkList> hand_world_landmarks, Source<LandmarkList> hand_world_landmarks,
@ -209,17 +303,33 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
auto concatenated_tensors = concatenate_tensor_vector.Out(""); auto concatenated_tensors = concatenate_tensor_vector.Out("");
// Inference for static hand gesture recognition. // Inference for static hand gesture recognition.
// TODO add embedding step. auto& gesture_embedder_inference =
auto& inference = AddInference( AddInference(*sub_task_model_resources.gesture_embedder_model_resource,
model_resources, graph_options.base_options().acceleration(), graph); graph_options.gesture_embedder_graph_options()
concatenated_tensors >> inference.In(kTensorsTag); .base_options()
auto inference_output_tensors = inference.Out(kTensorsTag); .acceleration(),
graph);
concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag);
auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag);
auto& canned_gesture_classifier_inference = AddInference(
*sub_task_model_resources.canned_gesture_classifier_model_resource,
graph_options.canned_gesture_classifier_graph_options()
.base_options()
.acceleration(),
graph);
embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag);
auto inference_output_tensors =
canned_gesture_classifier_inference.Out(kTensorsTag);
auto& tensors_to_classification = auto& tensors_to_classification =
graph.AddNode("TensorsToClassificationCalculator"); graph.AddNode("TensorsToClassificationCalculator");
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
graph_options.classifier_options(), graph_options.canned_gesture_classifier_graph_options()
*model_resources.GetMetadataExtractor(), 0, .classifier_options(),
*sub_task_model_resources.canned_gesture_classifier_model_resource
->GetMetadataExtractor(),
0,
&tensors_to_classification.GetOptions< &tensors_to_classification.GetOptions<
mediapipe::TensorsToClassificationCalculatorOptions>())); mediapipe::TensorsToClassificationCalculatorOptions>()));
inference_output_tensors >> tensors_to_classification.In(kTensorsTag); inference_output_tensors >> tensors_to_classification.In(kTensorsTag);

View File

@ -49,7 +49,6 @@ mediapipe_proto_library(
":gesture_embedder_graph_options_proto", ":gesture_embedder_graph_options_proto",
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,6 @@ syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto; package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";
@ -37,15 +36,11 @@ message HandGestureRecognizerGraphOptions {
// Options for GestureEmbedder. // Options for GestureEmbedder.
optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2; optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2;
// Options for GestureClassifier of default gestures. // Options for GestureClassifier of canned gestures.
optional GestureClassifierGraphOptions optional GestureClassifierGraphOptions
canned_gesture_classifier_graph_options = 3; canned_gesture_classifier_graph_options = 3;
// Options for GestureClassifier of custom gestures. // Options for GestureClassifier of custom gestures.
optional GestureClassifierGraphOptions optional GestureClassifierGraphOptions
custom_gesture_classifier_graph_options = 4; custom_gesture_classifier_graph_options = 4;
// TODO: remove these. Temporary solutions before bundle asset is
// ready.
optional components.processors.proto.ClassifierOptions classifier_options = 5;
} }

View File

@ -92,18 +92,30 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
bool is_copy) { bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_detector_file, ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName)); resources.GetModelFile(kHandDetectorTFLiteName));
auto* hand_detector_graph_options =
options->mutable_hand_detector_graph_options();
SetExternalFile(hand_detector_file, SetExternalFile(hand_detector_file,
options->mutable_hand_detector_graph_options() hand_detector_graph_options->mutable_base_options()
->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
is_copy); is_copy);
hand_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
hand_detector_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
auto* hand_landmarks_detector_graph_options =
options->mutable_hand_landmarks_detector_graph_options();
SetExternalFile(hand_landmarks_detector_file, SetExternalFile(hand_landmarks_detector_file,
options->mutable_hand_landmarks_detector_graph_options() hand_landmarks_detector_graph_options->mutable_base_options()
->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
is_copy); is_copy);
hand_landmarks_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
hand_landmarks_detector_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -67,7 +67,7 @@ using ::testing::proto::Approximately;
using ::testing::proto::Partially; using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task"; constexpr char kHandLandmarkerModelBundle[] = "hand_landmarker.task";
constexpr char kLeftHandsImage[] = "left_hands.jpg"; constexpr char kLeftHandsImage[] = "left_hands.jpg";
constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg"; constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg";

View File

@ -128,6 +128,7 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",

View File

@ -38,6 +38,7 @@ import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto;
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto;
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.HandGestureRecognizerGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.HandGestureRecognizerGraphOptionsProto;
import com.google.mediapipe.tasks.vision.handdetector.proto.HandDetectorGraphOptionsProto; import com.google.mediapipe.tasks.vision.handdetector.proto.HandDetectorGraphOptionsProto;
@ -300,13 +301,6 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
*/ */
public abstract Builder setRunningMode(RunningMode value); public abstract Builder setRunningMode(RunningMode value);
// TODO: remove these. Temporary solutions before bundle asset is ready.
public abstract Builder setBaseOptionsHandDetector(BaseOptions value);
public abstract Builder setBaseOptionsHandLandmarker(BaseOptions value);
public abstract Builder setBaseOptionsGestureRecognizer(BaseOptions value);
/** Sets the maximum number of hands can be detected by the GestureRecognizer. */ /** Sets the maximum number of hands can be detected by the GestureRecognizer. */
public abstract Builder setNumHands(Integer value); public abstract Builder setNumHands(Integer value);
@ -366,13 +360,6 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
abstract BaseOptions baseOptions(); abstract BaseOptions baseOptions();
// TODO: remove these. Temporary solutions before bundle asset is ready.
abstract BaseOptions baseOptionsHandDetector();
abstract BaseOptions baseOptionsHandLandmarker();
abstract BaseOptions baseOptionsGestureRecognizer();
abstract RunningMode runningMode(); abstract RunningMode runningMode();
abstract Optional<Integer> numHands(); abstract Optional<Integer> numHands();
@ -405,22 +392,18 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
*/ */
@Override @Override
public CalculatorOptions convertToCalculatorOptionsProto() { public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptions()));
GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.Builder taskOptionsBuilder = GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.Builder taskOptionsBuilder =
GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.newBuilder() GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder); .setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
.build());
// Setup HandDetectorGraphOptions. // Setup HandDetectorGraphOptions.
HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder
handDetectorGraphOptionsBuilder = handDetectorGraphOptionsBuilder =
HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder() HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder();
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsHandDetector())));
numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands); numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands);
minHandDetectionConfidence() minHandDetectionConfidence()
.ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence); .ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence);
@ -428,19 +411,12 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
// Setup HandLandmarkerGraphOptions. // Setup HandLandmarkerGraphOptions.
HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder
handLandmarksDetectorGraphOptionsBuilder = handLandmarksDetectorGraphOptionsBuilder =
HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder() HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder();
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker())));
minHandPresenceConfidence() minHandPresenceConfidence()
.ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); .ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence);
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder
handLandmarkerGraphOptionsBuilder = handLandmarkerGraphOptionsBuilder =
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder() HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder();
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE));
minTrackingConfidence() minTrackingConfidence()
.ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence); .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence);
handLandmarkerGraphOptionsBuilder handLandmarkerGraphOptionsBuilder
@ -450,16 +426,13 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
// Setup HandGestureRecognizerGraphOptions. // Setup HandGestureRecognizerGraphOptions.
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder
handGestureRecognizerGraphOptionsBuilder = handGestureRecognizerGraphOptionsBuilder =
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder() HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder();
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptionsGestureRecognizer())));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder(); ClassifierOptionsProto.ClassifierOptions.newBuilder();
minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold); minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold);
handGestureRecognizerGraphOptionsBuilder.setClassifierOptions( handGestureRecognizerGraphOptionsBuilder.setCannedGestureClassifierGraphOptions(
classifierOptionsBuilder.build()); GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder()
.setClassifierOptions(classifierOptionsBuilder.build()));
taskOptionsBuilder taskOptionsBuilder
.setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build()) .setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build())

View File

@ -43,10 +43,7 @@ import org.junit.runners.Suite.SuiteClasses;
@RunWith(Suite.class) @RunWith(Suite.class)
@SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class}) @SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class})
public class GestureRecognizerTest { public class GestureRecognizerTest {
private static final String HAND_DETECTOR_MODEL_FILE = "palm_detection_full.tflite"; private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task";
private static final String HAND_LANDMARKER_MODEL_FILE = "hand_landmark_full.tflite";
private static final String GESTURE_RECOGNIZER_MODEL_FILE =
"cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite";
private static final String TWO_HANDS_IMAGE = "right_hands.jpg"; private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
private static final String THUMB_UP_IMAGE = "thumb_up.jpg"; private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg"; private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
@ -66,13 +63,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.build(); .build();
GestureRecognizer gestureRecognizer = GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -88,13 +81,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.build(); .build();
GestureRecognizer gestureRecognizer = GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -111,16 +100,12 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
// TODO update the confidence to be in range [0,1] after embedding model // TODO update the confidence to be in range [0,1] after embedding model
// and scoring calculator is integrated. // and scoring calculator is integrated.
.setMinGestureConfidence(3.0f) .setMinGestureConfidence(2.0f)
.build(); .build();
GestureRecognizer gestureRecognizer = GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -139,13 +124,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setNumHands(2) .setNumHands(2)
.build(); .build();
GestureRecognizer gestureRecognizer = GestureRecognizer gestureRecognizer =
@ -168,19 +149,7 @@ public class GestureRecognizerTest {
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder() BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setBaseOptionsHandDetector(
BaseOptions.builder()
.setModelAssetPath(HAND_DETECTOR_MODEL_FILE)
.build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_MODEL_FILE)
.build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE)
.build()) .build())
.setRunningMode(mode) .setRunningMode(mode)
.setResultListener((gestureRecognitionResult, inputImage) -> {}) .setResultListener((gestureRecognitionResult, inputImage) -> {})
@ -201,15 +170,7 @@ public class GestureRecognizerTest {
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder() BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setBaseOptionsHandDetector(
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE)
.build()) .build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.build()); .build());
@ -223,13 +184,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE) .setRunningMode(RunningMode.IMAGE)
.build(); .build();
@ -252,13 +209,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO) .setRunningMode(RunningMode.VIDEO)
.build(); .build();
@ -281,13 +234,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((gestureRecognitionResult, inputImage) -> {}) .setResultListener((gestureRecognitionResult, inputImage) -> {})
.build(); .build();
@ -311,13 +260,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE) .setRunningMode(RunningMode.IMAGE)
.build(); .build();
@ -335,13 +280,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO) .setRunningMode(RunningMode.VIDEO)
.build(); .build();
GestureRecognizer gestureRecognizer = GestureRecognizer gestureRecognizer =
@ -363,13 +304,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(actualResult, inputImage) -> { (actualResult, inputImage) -> {
@ -397,13 +334,9 @@ public class GestureRecognizerTest {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) BaseOptions.builder()
.setBaseOptionsHandDetector( .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) .build())
.setBaseOptionsHandLandmarker(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build())
.setBaseOptionsGestureRecognizer(
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(actualResult, inputImage) -> { (actualResult, inputImage) -> {

View File

@ -35,7 +35,6 @@ mediapipe_files(srcs = [
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
"deeplabv3.tflite", "deeplabv3.tflite",
"hand_landmark.task",
"hand_landmark_full.tflite", "hand_landmark_full.tflite",
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",
"left_hands.jpg", "left_hands.jpg",
@ -67,13 +66,13 @@ mediapipe_files(srcs = [
exports_files( exports_files(
srcs = [ srcs = [
"cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite",
"expected_left_down_hand_landmarks.prototxt", "expected_left_down_hand_landmarks.prototxt",
"expected_left_down_hand_rotated_landmarks.prototxt", "expected_left_down_hand_rotated_landmarks.prototxt",
"expected_left_up_hand_landmarks.prototxt", "expected_left_up_hand_landmarks.prototxt",
"expected_left_up_hand_rotated_landmarks.prototxt", "expected_left_up_hand_rotated_landmarks.prototxt",
"expected_right_down_hand_landmarks.prototxt", "expected_right_down_hand_landmarks.prototxt",
"expected_right_up_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt",
"gesture_recognizer.task",
], ],
) )
@ -119,9 +118,9 @@ filegroup(
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
"deeplabv3.tflite", "deeplabv3.tflite",
"hand_landmark.task",
"hand_landmark_full.tflite", "hand_landmark_full.tflite",
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",
"hand_landmarker.task",
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_metadata_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite",

View File

@ -31,7 +31,7 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_bert_text_classifier_tflite", name = "com_google_mediapipe_bert_text_classifier_tflite",
sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600", sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1663009542017720"], urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"],
) )
http_file( http_file(
@ -46,12 +46,6 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"], urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
) )
http_file(
name = "com_google_mediapipe_BUILD_orig",
sha256 = "650df617b3e125e0890f1b8c936cc64c9d975707f57e616b6430fc667ce315d4",
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1665609930388174"],
)
http_file( http_file(
name = "com_google_mediapipe_burger_crop_jpg", name = "com_google_mediapipe_burger_crop_jpg",
sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50", sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50",
@ -127,7 +121,7 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_coco_ssd_mobilenet_v1_1_0_quant_2018_06_29_tflite", name = "com_google_mediapipe_coco_ssd_mobilenet_v1_1_0_quant_2018_06_29_tflite",
sha256 = "61d598093ed03ed41aa47c3a39a28ac01e960d6a810a5419b9a5016a1e9c469b", sha256 = "61d598093ed03ed41aa47c3a39a28ac01e960d6a810a5419b9a5016a1e9c469b",
urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite?generation=1661875702588267"], urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite?generation=1666144700870810"],
) )
http_file( http_file(
@ -274,6 +268,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_two_hands.pbtxt?generation=1662745353586157"], urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_two_hands.pbtxt?generation=1662745353586157"],
) )
http_file(
name = "com_google_mediapipe_hand_landmarker_task",
sha256 = "2ed44f10872e87a5834b9b1130fb9ada30e107af2c6fcc4562ad788aca4e7bc4",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmarker.task?generation=1666153732577904"],
)
http_file( http_file(
name = "com_google_mediapipe_hand_landmark_full_tflite", name = "com_google_mediapipe_hand_landmark_full_tflite",
sha256 = "11c272b891e1a99ab034208e23937a8008388cf11ed2a9d776ed3d01d0ba00e3", sha256 = "11c272b891e1a99ab034208e23937a8008388cf11ed2a9d776ed3d01d0ba00e3",
@ -287,9 +287,9 @@ def external_files():
) )
http_file( http_file(
name = "com_google_mediapipe_hand_landmark_task", name = "com_google_mediapipe_hand_landmark_tflite",
sha256 = "dd830295598e48e6bbbdf22fd9e69538fa07768106cd9ceb04d5462ca7e38c95", sha256 = "bad88ac1fd144f034e00f075afcade4f3a21d0d09c41bee8dd50504dacd70efd",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.task?generation=1665707323647357"], urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.tflite?generation=1666153735814956"],
) )
http_file( http_file(
@ -475,7 +475,7 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_mobile_object_labeler_v1_tflite", name = "com_google_mediapipe_mobile_object_labeler_v1_tflite",
sha256 = "9400671e04685f5277edd3052a311cc51533de9da94255c52ebde1e18484c77c", sha256 = "9400671e04685f5277edd3052a311cc51533de9da94255c52ebde1e18484c77c",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_object_labeler_v1.tflite?generation=1661875846924538"], urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_object_labeler_v1.tflite?generation=1666144701839813"],
) )
http_file( http_file(