Merge branch 'master' into image-embedder-python

commit ba1ee5b404

@@ -936,6 +936,7 @@ cc_test(
         "//mediapipe/framework/tool:simulation_clock",
         "//mediapipe/framework/tool:simulation_clock_executor",
         "//mediapipe/framework/tool:sink",
+        "//mediapipe/util:packet_test_util",
         "@com_google_absl//absl/time",
     ],
 )

@@ -18,7 +18,6 @@

 #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
-#include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/util/header_util.h"

@@ -68,7 +67,7 @@ constexpr char kOptionsTag[] = "OPTIONS";
 // FlowLimiterCalculator provides limited support for multiple input streams.
 // The first input stream is treated as the main input stream and successive
 // input streams are treated as auxiliary input streams.  The auxiliary input
-// streams are limited to timestamps passed on the main input stream.
+// streams are limited to timestamps allowed by the "ALLOW" stream.
 //
 class FlowLimiterCalculator : public CalculatorBase {
  public:
@@ -100,64 +99,11 @@ class FlowLimiterCalculator : public CalculatorBase {
           cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
     }
     input_queues_.resize(cc->Inputs().NumEntries(""));
+    allowed_[Timestamp::Unset()] = true;
     RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
     return absl::OkStatus();
   }

-  // Returns true if an additional frame can be released for processing.
-  // The "ALLOW" output stream indicates this condition at each input frame.
-  bool ProcessingAllowed() {
-    return frames_in_flight_.size() < options_.max_in_flight();
-  }
-
-  // Outputs a packet indicating whether a frame was sent or dropped.
-  void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
-    if (cc->Outputs().HasTag(kAllowTag)) {
-      cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
-    }
-  }
-
-  // Sets the timestamp bound or closes an output stream.
-  void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
-    if (bound > Timestamp::Max()) {
-      stream->Close();
-    } else {
-      stream->SetNextTimestampBound(bound);
-    }
-  }
-
-  // Returns true if a certain timestamp is being processed.
-  bool IsInFlight(Timestamp timestamp) {
-    return std::find(frames_in_flight_.begin(), frames_in_flight_.end(),
-                     timestamp) != frames_in_flight_.end();
-  }
-
-  // Releases input packets up to the latest settled input timestamp.
-  void ProcessAuxiliaryInputs(CalculatorContext* cc) {
-    Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
-    for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
-      // Release settled frames from each input queue.
-      while (!input_queues_[i].empty() &&
-             input_queues_[i].front().Timestamp() < settled_bound) {
-        Packet packet = input_queues_[i].front();
-        input_queues_[i].pop_front();
-        if (IsInFlight(packet.Timestamp())) {
-          cc->Outputs().Get("", i).AddPacket(packet);
-        }
-      }
-
-      // Propagate each input timestamp bound.
-      if (!input_queues_[i].empty()) {
-        Timestamp bound = input_queues_[i].front().Timestamp();
-        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
-      } else {
-        Timestamp bound =
-            cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
-        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
-      }
-    }
-  }
-
   // Releases input packets allowed by the max_in_flight constraint.
   absl::Status Process(CalculatorContext* cc) final {
     options_ = tool::RetrieveOptions(options_, cc->Inputs());
@@ -224,13 +170,97 @@ class FlowLimiterCalculator : public CalculatorBase {
     }

     ProcessAuxiliaryInputs(cc);
+
+    // Discard old ALLOW ranges.
+    Timestamp input_bound = InputTimestampBound(cc);
+    auto first_range = std::prev(allowed_.upper_bound(input_bound));
+    allowed_.erase(allowed_.begin(), first_range);
     return absl::OkStatus();
   }

+  int LedgerSize() {
+    int result = frames_in_flight_.size() + allowed_.size();
+    for (const auto& queue : input_queues_) {
+      result += queue.size();
+    }
+    return result;
+  }
+
+ private:
+  // Returns true if an additional frame can be released for processing.
+  // The "ALLOW" output stream indicates this condition at each input frame.
+  bool ProcessingAllowed() {
+    return frames_in_flight_.size() < options_.max_in_flight();
+  }
+
+  // Outputs a packet indicating whether a frame was sent or dropped.
+  void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
+    if (cc->Outputs().HasTag(kAllowTag)) {
+      cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
+    }
+    allowed_[ts] = allow;
+  }
+
+  // Returns true if a timestamp falls within a range of allowed timestamps.
+  bool IsAllowed(Timestamp timestamp) {
+    auto it = allowed_.upper_bound(timestamp);
+    return std::prev(it)->second;
+  }
+
+  // Sets the timestamp bound or closes an output stream.
+  void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
+    if (bound > Timestamp::Max()) {
+      stream->Close();
+    } else {
+      stream->SetNextTimestampBound(bound);
+    }
+  }
+
+  // Returns the lowest unprocessed input Timestamp.
+  Timestamp InputTimestampBound(CalculatorContext* cc) {
+    Timestamp result = Timestamp::Done();
+    for (int i = 0; i < input_queues_.size(); ++i) {
+      auto& queue = input_queues_[i];
+      auto& stream = cc->Inputs().Get("", i);
+      Timestamp bound = queue.empty()
+                            ? stream.Value().Timestamp().NextAllowedInStream()
+                            : queue.front().Timestamp();
+      result = std::min(result, bound);
+    }
+    return result;
+  }
+
+  // Releases input packets up to the latest settled input timestamp.
+  void ProcessAuxiliaryInputs(CalculatorContext* cc) {
+    Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
+    for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
+      // Release settled frames from each input queue.
+      while (!input_queues_[i].empty() &&
+             input_queues_[i].front().Timestamp() < settled_bound) {
+        Packet packet = input_queues_[i].front();
+        input_queues_[i].pop_front();
+        if (IsAllowed(packet.Timestamp())) {
+          cc->Outputs().Get("", i).AddPacket(packet);
+        }
+      }
+
+      // Propagate each input timestamp bound.
+      if (!input_queues_[i].empty()) {
+        Timestamp bound = input_queues_[i].front().Timestamp();
+        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
+      } else {
+        Timestamp bound =
+            cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
+        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
+      }
+    }
+  }
+
  private:
   FlowLimiterCalculatorOptions options_;
   std::vector<std::deque<Packet>> input_queues_;
   std::deque<Timestamp> frames_in_flight_;
+  std::map<Timestamp, bool> allowed_;
 };
 REGISTER_CALCULATOR(FlowLimiterCalculator);
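
Note: the allowed_ map added above is a range ledger. Each SendAllow() call records its ALLOW decision at the frame timestamp, Timestamp::Unset() is seeded as true during setup, and IsAllowed() reads the entry at or before a query timestamp via upper_bound()/std::prev(). Below is a minimal standalone sketch of that lookup pattern; plain int64_t keys stand in for mediapipe::Timestamp, and the names are illustrative rather than the calculator's actual API.

    #include <cstdint>
    #include <iostream>
    #include <iterator>
    #include <map>

    int main() {
      // Ledger of ALLOW decisions keyed by timestamp; each entry marks the
      // start of a range that keeps its value until the next entry.
      std::map<int64_t, bool> allowed;
      allowed[INT64_MIN] = true;  // sentinel, like allowed_[Timestamp::Unset()] = true
      allowed[0] = true;          // frame at t=0 was released
      allowed[10000] = false;     // frame at t=10000 was dropped
      allowed[20000] = true;      // frame at t=20000 was released

      // Same lookup as IsAllowed(): the entry at or before the query timestamp
      // decides whether an auxiliary packet with that timestamp is forwarded.
      auto is_allowed = [&](int64_t ts) {
        return std::prev(allowed.upper_bound(ts))->second;
      };

      std::cout << is_allowed(5000) << "\n";   // 1: covered by the t=0 range
      std::cout << is_allowed(15000) << "\n";  // 0: covered by the t=10000 range
    }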

@@ -15,6 +15,7 @@
 #include <algorithm>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>

 #include "absl/time/clock.h"
@@ -32,6 +33,7 @@
 #include "mediapipe/framework/tool/simulation_clock.h"
 #include "mediapipe/framework/tool/simulation_clock_executor.h"
 #include "mediapipe/framework/tool/sink.h"
+#include "mediapipe/util/packet_test_util.h"

 namespace mediapipe {

@@ -77,6 +79,77 @@ std::vector<T> PacketValues(const std::vector<Packet>& packets) {
   return result;
 }

+template <typename T>
+std::vector<Packet> MakePackets(std::vector<std::pair<Timestamp, T>> contents) {
+  std::vector<Packet> result;
+  for (auto& entry : contents) {
+    result.push_back(MakePacket<T>(entry.second).At(entry.first));
+  }
+  return result;
+}
+
+std::string SourceString(Timestamp t) {
+  return (t.IsSpecialValue())
+             ? t.DebugString()
+             : absl::StrCat("Timestamp(", t.DebugString(), ")");
+}
+
+template <typename PacketContainer, typename PacketContent>
+class PacketsEqMatcher
+    : public ::testing::MatcherInterface<const PacketContainer&> {
+ public:
+  PacketsEqMatcher(PacketContainer packets) : packets_(packets) {}
+  void DescribeTo(::std::ostream* os) const override {
+    *os << "The expected packet contents: \n";
+    Print(packets_, os);
+  }
+  bool MatchAndExplain(
+      const PacketContainer& value,
+      ::testing::MatchResultListener* listener) const override {
+    if (!Equals(packets_, value)) {
+      if (listener->IsInterested()) {
+        *listener << "The actual packet contents: \n";
+        Print(value, listener->stream());
+      }
+      return false;
+    }
+    return true;
+  }
+
+ private:
+  bool Equals(const PacketContainer& c1, const PacketContainer& c2) const {
+    if (c1.size() != c2.size()) {
+      return false;
+    }
+    for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) {
+      Packet p1 = *i1, p2 = *i2;
+      if (p1.Timestamp() != p2.Timestamp() ||
+          p1.Get<PacketContent>() != p2.Get<PacketContent>()) {
+        return false;
+      }
+    }
+    return true;
+  }
+  void Print(const PacketContainer& packets, ::std::ostream* os) const {
+    for (auto it = packets.begin(); it != packets.end(); ++it) {
+      const Packet& packet = *it;
+      *os << (it == packets.begin() ? "{" : "") << "{"
+          << SourceString(packet.Timestamp()) << ", "
+          << packet.Get<PacketContent>() << "}"
+          << (std::next(it) == packets.end() ? "}" : ", ");
+    }
+  }
+
+  const PacketContainer packets_;
+};
+
+template <typename PacketContainer, typename PacketContent>
+::testing::Matcher<const PacketContainer&> PackestEq(
+    const PacketContainer& packets) {
+  return MakeMatcher(
+      new PacketsEqMatcher<PacketContainer, PacketContent>(packets));
+}
+
 // A Calculator::Process callback function.
 typedef std::function<absl::Status(const InputStreamShardSet&,
                                    OutputStreamShardSet*)>
@@ -651,11 +724,12 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
       input_packets_[17], input_packets_[19], input_packets_[20],
   };
   EXPECT_EQ(out_1_packets_, expected_output);
-  // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
+  // The timestamps released by FlowLimiterCalculator for in_1_sampled,
+  // plus input_packets_[21].
   std::vector<Packet> expected_output_2 = {
       input_packets_[0],  input_packets_[2],  input_packets_[4],
       input_packets_[14], input_packets_[17], input_packets_[19],
-      input_packets_[20],
+      input_packets_[20], input_packets_[21],
   };
   EXPECT_EQ(out_2_packets, expected_output_2);
 }
@@ -665,6 +739,9 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
 // The processing time "sleep_time" is reduced from 22ms to 12ms to create
 // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
 TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
+  auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
+  auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
+
   // Configure the test.
   SetUpInputData();
   SetUpSimulationClock();
@@ -699,10 +776,9 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
         }
       )pb");

-  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
-    max_in_flight: 1
-    max_in_queue: 0
-    in_flight_timeout: 100000  # 100 ms
+  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
+      R"pb(
+        max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 100000  # 100 ms
       )pb");
   std::map<std::string, Packet> side_packets = {
       {"limiter_options",
@@ -759,13 +835,131 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
       input_packets_[0],  input_packets_[2],  input_packets_[15],
      input_packets_[17], input_packets_[19],
   };
-  EXPECT_EQ(out_1_packets_, expected_output);
+  EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
   // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
   std::vector<Packet> expected_output_2 = {
       input_packets_[0],  input_packets_[2],  input_packets_[4],
       input_packets_[15], input_packets_[17], input_packets_[19],
   };
-  EXPECT_EQ(out_2_packets, expected_output_2);
+  EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
+
+  // Validate the ALLOW stream output.
+  std::vector<Packet> expected_allow = MakePackets<bool>(  //
+      {{Timestamp(0), true},       {Timestamp(10000), false},
+       {Timestamp(20000), true},   {Timestamp(30000), false},
+       {Timestamp(40000), true},   {Timestamp(50000), false},
+       {Timestamp(60000), false},  {Timestamp(70000), false},
+       {Timestamp(80000), false},  {Timestamp(90000), false},
+       {Timestamp(100000), false}, {Timestamp(110000), false},
+       {Timestamp(120000), false}, {Timestamp(130000), false},
+       {Timestamp(140000), false}, {Timestamp(150000), true},
+       {Timestamp(160000), false}, {Timestamp(170000), true},
+       {Timestamp(180000), false}, {Timestamp(190000), true},
+       {Timestamp(200000), false}});
+  EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
+}
+
+// Shows how FlowLimiterCalculator releases auxiliary input packets.
+// In this test, auxiliary input packets arrive at twice the primary rate.
+TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
+  auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
+  auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
+
+  // Configure the test.
+  SetUpInputData();
+  SetUpSimulationClock();
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: 'in_1'
+        input_stream: 'in_2'
+        node {
+          calculator: 'FlowLimiterCalculator'
+          input_side_packet: 'OPTIONS:limiter_options'
+          input_stream: 'in_1'
+          input_stream: 'in_2'
+          input_stream: 'FINISHED:out_1'
+          input_stream_info: { tag_index: 'FINISHED' back_edge: true }
+          output_stream: 'in_1_sampled'
+          output_stream: 'in_2_sampled'
+          output_stream: 'ALLOW:allow'
+        }
+        node {
+          calculator: 'SleepCalculator'
+          input_side_packet: 'WARMUP_TIME:warmup_time'
+          input_side_packet: 'SLEEP_TIME:sleep_time'
+          input_side_packet: 'CLOCK:clock'
+          input_stream: 'PACKET:in_1_sampled'
+          output_stream: 'PACKET:out_1'
+        }
+      )pb");
+
+  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
+      R"pb(
+        max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 1000000  # 1s
+      )pb");
+  std::map<std::string, Packet> side_packets = {
+      {"limiter_options",
+       MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
+      {"warmup_time", MakePacket<int64>(22000)},
+      {"sleep_time", MakePacket<int64>(22000)},
+      {"clock", MakePacket<mediapipe::Clock*>(clock_)},
+  };
+
+  // Start the graph.
+  MP_ASSERT_OK(graph_.Initialize(graph_config));
+  MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
+    out_1_packets_.push_back(p);
+    return absl::OkStatus();
+  }));
+  std::vector<Packet> out_2_packets;
+  MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
+    out_2_packets.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
+    allow_packets_.push_back(p);
+    return absl::OkStatus();
+  }));
+  simulation_clock_->ThreadStart();
+  MP_ASSERT_OK(graph_.StartRun(side_packets));
+
+  // Add packets 2,4,6,8 to stream in_1 and 1..9 to stream in_2.
+  clock_->Sleep(absl::Microseconds(10000));
+  for (int i = 1; i < 10; ++i) {
+    if (i % 2 == 0) {
+      MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
+    }
+    MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i]));
+    clock_->Sleep(absl::Microseconds(10000));
+  }
+
+  // Finish the graph run.
+  MP_EXPECT_OK(graph_.CloseAllPacketSources());
+  clock_->Sleep(absl::Microseconds(40000));
+  MP_EXPECT_OK(graph_.WaitUntilDone());
+  simulation_clock_->ThreadFinish();
+
+  // Validate the output.
+  // Input packets 4 and 8 are dropped due to max_in_flight.
+  std::vector<Packet> expected_output = {
+      input_packets_[2],
+      input_packets_[6],
+  };
+  EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
+  // Packets following input packets 2 and 6, and not input packets 4 and 8.
+  std::vector<Packet> expected_output_2 = {
+      input_packets_[1], input_packets_[2], input_packets_[3],
+      input_packets_[6], input_packets_[7],
+  };
+  EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
+
+  // Validate the ALLOW stream output.
+  std::vector<Packet> expected_allow =
+      MakePackets<bool>({{Timestamp(20000), 1},
+                         {Timestamp(40000), 0},
+                         {Timestamp(60000), 1},
+                         {Timestamp(80000), 0}});
+  EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
 }

 }  // anonymous namespace
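
Note: the PackestEq helper added above wraps a hand-written gMock MatcherInterface so that packet mismatches print both the expected and actual contents. Below is a self-contained sketch of that same custom-matcher pattern, using plain ints instead of mediapipe Packets; the names and includes here are illustrative, not taken from the diff.

    #include <ostream>
    #include <utility>
    #include <vector>

    #include "gmock/gmock.h"
    #include "gtest/gtest.h"

    // A MatcherInterface that compares a container element-wise and explains
    // mismatches, mirroring the shape of PacketsEqMatcher above.
    class IntVectorEqMatcher
        : public ::testing::MatcherInterface<const std::vector<int>&> {
     public:
      explicit IntVectorEqMatcher(std::vector<int> expected)
          : expected_(std::move(expected)) {}
      void DescribeTo(std::ostream* os) const override {
        *os << "matches the expected contents";
      }
      bool MatchAndExplain(
          const std::vector<int>& value,
          ::testing::MatchResultListener* listener) const override {
        if (value != expected_) {
          *listener << "actual size " << value.size();
          return false;
        }
        return true;
      }

     private:
      const std::vector<int> expected_;
    };

    // Factory in the same style as PackestEq: wrap the interface in a Matcher.
    ::testing::Matcher<const std::vector<int>&> IntVectorEq(
        const std::vector<int>& expected) {
      return ::testing::MakeMatcher(new IntVectorEqMatcher(expected));
    }

    TEST(IntVectorEqMatcher, MatchesEqualVectors) {
      std::vector<int> actual = {1, 2, 3};
      EXPECT_THAT(actual, IntVectorEq({1, 2, 3}));
    }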

@@ -14,45 +14,43 @@

 #include "mediapipe/framework/deps/status_builder.h"

+#include <memory>
+#include <sstream>
+
 #include "absl/memory/memory.h"
 #include "absl/status/status.h"
+#include "absl/strings/str_cat.h"

 namespace mediapipe {

-StatusBuilder::StatusBuilder(const StatusBuilder& sb) {
-  status_ = sb.status_;
-  file_ = sb.file_;
-  line_ = sb.line_;
-  no_logging_ = sb.no_logging_;
-  stream_ = sb.stream_
-                ? absl::make_unique<std::ostringstream>(sb.stream_->str())
-                : nullptr;
-  join_style_ = sb.join_style_;
-}
+StatusBuilder::StatusBuilder(const StatusBuilder& sb)
+    : impl_(sb.impl_ ? std::make_unique<Impl>(*sb.impl_) : nullptr) {}

 StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) {
-  status_ = sb.status_;
-  file_ = sb.file_;
-  line_ = sb.line_;
-  no_logging_ = sb.no_logging_;
-  stream_ = sb.stream_
-                ? absl::make_unique<std::ostringstream>(sb.stream_->str())
-                : nullptr;
-  join_style_ = sb.join_style_;
+  if (!sb.impl_) {
+    impl_ = nullptr;
+    return *this;
+  }
+  if (impl_) {
+    *impl_ = *sb.impl_;
+    return *this;
+  }
+  impl_ = std::make_unique<Impl>(*sb.impl_);
+
   return *this;
 }

 StatusBuilder& StatusBuilder::SetAppend() & {
-  if (status_.ok()) return *this;
-  join_style_ = MessageJoinStyle::kAppend;
+  if (!impl_) return *this;
+  impl_->join_style = Impl::MessageJoinStyle::kAppend;
   return *this;
 }

 StatusBuilder&& StatusBuilder::SetAppend() && { return std::move(SetAppend()); }

 StatusBuilder& StatusBuilder::SetPrepend() & {
-  if (status_.ok()) return *this;
-  join_style_ = MessageJoinStyle::kPrepend;
+  if (!impl_) return *this;
+  impl_->join_style = Impl::MessageJoinStyle::kPrepend;
   return *this;
 }

@@ -61,7 +59,8 @@ StatusBuilder&& StatusBuilder::SetPrepend() && {
 }

 StatusBuilder& StatusBuilder::SetNoLogging() & {
-  no_logging_ = true;
+  if (!impl_) return *this;
+  impl_->no_logging = true;
   return *this;
 }

@@ -70,34 +69,72 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
 }

 StatusBuilder::operator Status() const& {
-  if (!stream_ || stream_->str().empty() || no_logging_) {
-    return status_;
-  }
   return StatusBuilder(*this).JoinMessageToStatus();
 }

-StatusBuilder::operator Status() && {
-  if (!stream_ || stream_->str().empty() || no_logging_) {
-    return status_;
-  }
-  return JoinMessageToStatus();
-}
+StatusBuilder::operator Status() && { return JoinMessageToStatus(); }

 absl::Status StatusBuilder::JoinMessageToStatus() {
-  if (!stream_) {
+  if (!impl_) {
     return absl::OkStatus();
   }
-  std::string message;
-  if (join_style_ == MessageJoinStyle::kAnnotate) {
-    if (!status_.ok()) {
-      message = absl::StrCat(status_.message(), "; ", stream_->str());
-    }
-  } else {
-    message = join_style_ == MessageJoinStyle::kPrepend
-                  ? absl::StrCat(stream_->str(), status_.message())
-                  : absl::StrCat(status_.message(), stream_->str());
-  }
-  return Status(status_.code(), message);
+  return impl_->JoinMessageToStatus();
+}
+
+absl::Status StatusBuilder::Impl::JoinMessageToStatus() {
+  if (stream.str().empty() || no_logging) {
+    return status;
+  }
+  return absl::Status(status.code(), [this]() {
+    switch (join_style) {
+      case MessageJoinStyle::kAnnotate:
+        return absl::StrCat(status.message(), "; ", stream.str());
+      case MessageJoinStyle::kAppend:
+        return absl::StrCat(status.message(), stream.str());
+      case MessageJoinStyle::kPrepend:
+        return absl::StrCat(stream.str(), status.message());
+    }
+  }());
+}
+
+StatusBuilder::Impl::Impl(const absl::Status& status, const char* file,
+                          int line)
+    : status(status), line(line), file(file), stream() {}
+
+StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line)
+    : status(std::move(status)), line(line), file(file), stream() {}
+
+StatusBuilder::Impl::Impl(const absl::Status& status,
+                          mediapipe::source_location location)
+    : status(status),
+      line(location.line()),
+      file(location.file_name()),
+      stream() {}
+
+StatusBuilder::Impl::Impl(absl::Status&& status,
+                          mediapipe::source_location location)
+    : status(std::move(status)),
+      line(location.line()),
+      file(location.file_name()),
+      stream() {}
+
+StatusBuilder::Impl::Impl(const Impl& other)
+    : status(other.status),
+      line(other.line),
+      file(other.file),
+      no_logging(other.no_logging),
+      stream(other.stream.str()),
+      join_style(other.join_style) {}
+
+StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) {
+  status = other.status;
+  line = other.line;
+  file = other.file;
+  no_logging = other.no_logging;
+  stream = std::ostringstream(other.stream.str());
+  join_style = other.join_style;
+
+  return *this;
 }

 }  // namespace mediapipe
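
Note: after this refactor all per-error state (status, file, line, message stream, join style) lives in a heap-allocated Impl that exists only for non-OK statuses, so copying or streaming into an OK builder does essentially no work. Below is a toy sketch of the same null-Impl idea, independent of the mediapipe types; MiniBuilder and its members are made up for illustration.

    #include <iostream>
    #include <memory>
    #include <sstream>
    #include <string>

    // State is allocated only when there is an error to annotate, so the OK
    // path stays allocation-free, mirroring the StatusBuilder::Impl pattern.
    class MiniBuilder {
     public:
      explicit MiniBuilder(std::string error)  // empty string means "OK"
          : impl_(error.empty() ? nullptr
                                : std::make_unique<Impl>(std::move(error))) {}

      bool ok() const { return impl_ == nullptr; }  // same invariant as ok()

      template <typename T>
      MiniBuilder& operator<<(const T& msg) {
        if (impl_) impl_->stream << msg;  // no-op when OK
        return *this;
      }

      std::string Message() const {
        if (!impl_) return "OK";
        // "Annotate" join style: original message, then streamed text.
        return impl_->error + "; " + impl_->stream.str();
      }

     private:
      struct Impl {
        explicit Impl(std::string e) : error(std::move(e)) {}
        std::string error;
        std::ostringstream stream;
      };
      std::unique_ptr<Impl> impl_;
    };

    int main() {
      MiniBuilder ok_builder("");
      ok_builder << "ignored";                    // cheap: impl_ is null
      std::cout << ok_builder.Message() << "\n";  // OK

      MiniBuilder err_builder("not found");
      err_builder << "while opening " << 42;
      std::cout << err_builder.Message() << "\n";  // not found; while opening 42
    }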

@@ -22,7 +22,6 @@
 #include "absl/base/attributes.h"
 #include "absl/memory/memory.h"
 #include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/framework/deps/source_location.h"
 #include "mediapipe/framework/deps/status.h"
@@ -42,34 +41,37 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
   // occurs.  A typical user will call this with `MEDIAPIPE_LOC`.
   StatusBuilder(const absl::Status& original_status,
                 mediapipe::source_location location)
-      : status_(original_status),
-        line_(location.line()),
-        file_(location.file_name()),
-        stream_(InitStream(status_)) {}
+      : impl_(original_status.ok()
+                  ? nullptr
+                  : std::make_unique<Impl>(original_status, location)) {}

   StatusBuilder(absl::Status&& original_status,
                 mediapipe::source_location location)
-      : status_(std::move(original_status)),
-        line_(location.line()),
-        file_(location.file_name()),
-        stream_(InitStream(status_)) {}
+      : impl_(original_status.ok()
+                  ? nullptr
+                  : std::make_unique<Impl>(std::move(original_status),
+                                           location)) {}

   // Creates a `StatusBuilder` from a mediapipe status code.  If logging is
   // enabled, it will use `location` as the location from which the log message
   // occurs.  A typical user will call this with `MEDIAPIPE_LOC`.
   StatusBuilder(absl::StatusCode code, mediapipe::source_location location)
-      : status_(code, ""),
-        line_(location.line()),
-        file_(location.file_name()),
-        stream_(InitStream(status_)) {}
+      : impl_(code == absl::StatusCode::kOk
                  ? nullptr
                  : std::make_unique<Impl>(absl::Status(code, ""), location)) {}

   StatusBuilder(const absl::Status& original_status, const char* file, int line)
-      : status_(original_status),
-        line_(line),
-        file_(file),
-        stream_(InitStream(status_)) {}
+      : impl_(original_status.ok()
+                  ? nullptr
+                  : std::make_unique<Impl>(original_status, file, line)) {}

-  bool ok() const { return status_.ok(); }
+  StatusBuilder(absl::Status&& original_status, const char* file, int line)
+      : impl_(original_status.ok()
+                  ? nullptr
+                  : std::make_unique<Impl>(std::move(original_status), file,
+                                           line)) {}
+
+  bool ok() const { return !impl_; }

   StatusBuilder& SetAppend() &;
   StatusBuilder&& SetAppend() &&;
@@ -82,8 +84,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {

   template <typename T>
   StatusBuilder& operator<<(const T& msg) & {
-    if (!stream_) return *this;
-    *stream_ << msg;
+    if (!impl_) return *this;
+    impl_->stream << msg;
     return *this;
   }

@@ -98,6 +100,7 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
   absl::Status JoinMessageToStatus();

  private:
+  struct Impl {
     // Specifies how to join the error message in the original status and any
     // additional message that has been streamed into the builder.
     enum class MessageJoinStyle {
@@ -106,27 +109,33 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
       kPrepend,
     };

-  // Conditionally creates an ostringstream if the status is not ok.
-  static std::unique_ptr<std::ostringstream> InitStream(
-      const absl::Status status) {
-    if (status.ok()) {
-      return nullptr;
-    }
-    return absl::make_unique<std::ostringstream>();
-  }
+    Impl(const absl::Status& status, const char* file, int line);
+    Impl(absl::Status&& status, const char* file, int line);
+    Impl(const absl::Status& status, mediapipe::source_location location);
+    Impl(absl::Status&& status, mediapipe::source_location location);
+    Impl(const Impl&);
+    Impl& operator=(const Impl&);
+
+    absl::Status JoinMessageToStatus();

     // The status that the result will be based on.
-  absl::Status status_;
+    absl::Status status;
     // The line to record if this file is logged.
-  int line_;
+    int line;
     // Not-owned: The file to record if this status is logged.
-  const char* file_;
-  bool no_logging_ = false;
+    const char* file;
+    // Logging disabled if true.
+    bool no_logging = false;
     // The additional messages added with `<<`.  This is nullptr when status_ is
     // ok.
-  std::unique_ptr<std::ostringstream> stream_;
+    std::ostringstream stream;
     // Specifies how to join the message in `status_` and `stream_`.
-  MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate;
+    MessageJoinStyle join_style = MessageJoinStyle::kAnnotate;
+  };
+
+  // Internal store of data for the class.  An invariant of the class is that
+  // this is null when the original status is okay, and not-null otherwise.
+  std::unique_ptr<Impl> impl_;
 };

 inline StatusBuilder AlreadyExistsErrorBuilder(
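
Note: a hedged usage sketch of the header above, assuming the usual mediapipe pattern in which a non-OK builder is streamed into and implicitly converted back to absl::Status at the return boundary. The function, path parameter, and messages here are made up for illustration; only StatusBuilder, MEDIAPIPE_LOC, and the kAnnotate default come from the diff.

    #include <string>

    #include "absl/status/status.h"
    #include "mediapipe/framework/deps/source_location.h"
    #include "mediapipe/framework/deps/status_builder.h"

    // Hypothetical caller: on error, stream extra context into the builder and
    // let the implicit conversion produce the annotated absl::Status.
    absl::Status LoadConfig(const std::string& path, bool exists) {
      if (!exists) {
        // With the default kAnnotate join style this yields
        // "no such config; while loading <path>".
        return mediapipe::StatusBuilder(
                   absl::Status(absl::StatusCode::kNotFound, "no such config"),
                   MEDIAPIPE_LOC)
               << "while loading " << path;
      }
      // For an OK status the builder holds no Impl, so streamed text would be
      // dropped and the conversion simply returns absl::OkStatus().
      return absl::OkStatus();
    }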

@@ -33,6 +33,21 @@ TEST(StatusBuilder, OkStatusRvalue) {
   ASSERT_EQ(status, absl::OkStatus());
 }

+TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) {
+  absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234)
+                        << "annotated message1 "
+                        << "annotated message2";
+  ASSERT_EQ(status, absl::OkStatus());
+}
+
+TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) {
+  const auto original_status = absl::OkStatus();
+  absl::Status status = StatusBuilder(original_status, "hello.cc", 1234)
+                        << "annotated message1 "
+                        << "annotated message2";
+  ASSERT_EQ(status, absl::OkStatus());
+}
+
 TEST(StatusBuilder, AnnotateMode) {
   absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound,
                                                    "original message"),
@@ -45,6 +60,30 @@ TEST(StatusBuilder, AnnotateMode) {
             "original message; annotated message1 annotated message2");
 }

+TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) {
+  absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound,
+                                                   "original message"),
+                                      "hello.cc", 1234)
+                        << "annotated message1 "
+                        << "annotated message2";
+  ASSERT_FALSE(status.ok());
+  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
+  EXPECT_EQ(status.message(),
+            "original message; annotated message1 annotated message2");
+}
+
+TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) {
+  const auto original_status =
+      absl::Status(absl::StatusCode::kNotFound, "original message");
+  absl::Status status = StatusBuilder(original_status, "hello.cc", 1234)
+                        << "annotated message1 "
+                        << "annotated message2";
+  ASSERT_FALSE(status.ok());
+  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
+  EXPECT_EQ(status.message(),
+            "original message; annotated message1 annotated message2");
+}
+
 TEST(StatusBuilder, PrependModeLvalue) {
   StatusBuilder builder(
       absl::Status(absl::StatusCode::kInvalidArgument, "original message"),

@@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model
 class Classifier(custom_model.CustomModel):
   """An abstract base class that represents a TensorFlow classifier."""

-  def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
-               full_train: bool):
+  def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
     """Initilizes a classifier with its specifications.

     Args:
         model_spec: Specification for the model.
         label_names: A list of label names for the classes.
         shuffle: Whether the dataset should be shuffled.
-        full_train: If true, train the model end-to-end including the backbone
-          and the classification layers on top. Otherwise, only train the top
-          classification layers.
     """
     super(Classifier, self).__init__(model_spec, shuffle)
     self._label_names = label_names
-    self._full_train = full_train
     self._num_classes = len(label_names)

   def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
| 
						 | 
					
 | 
				
			||||||
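With `full_train` removed from the base `Classifier` initializer above, subclasses now pass only `model_spec`, `label_names`, and `shuffle`. A minimal sketch of the adjusted call site, assuming the module import path and using a hypothetical subclass name (the real ImageClassifier change appears later in this diff):

from typing import Any, List

from mediapipe.model_maker.python.core.tasks import classifier


class MyClassifier(classifier.Classifier):  # hypothetical subclass for illustration
  """Toy classifier showing the new base-class constructor signature."""

  def __init__(self, model_spec: Any, label_names: List[str]):
    # full_train is gone; fine-tuning behavior is now handled by the subclass itself.
    super().__init__(model_spec=model_spec, label_names=label_names, shuffle=True)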
| 
						 | 
					@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase):
 | 
				
			||||||
    super(ClassifierTest, self).setUp()
 | 
					    super(ClassifierTest, self).setUp()
 | 
				
			||||||
    label_names = ['cat', 'dog']
 | 
					    label_names = ['cat', 'dog']
 | 
				
			||||||
    self.model = MockClassifier(
 | 
					    self.model = MockClassifier(
 | 
				
			||||||
        model_spec=None,
 | 
					        model_spec=None, label_names=label_names, shuffle=False)
 | 
				
			||||||
        label_names=label_names,
 | 
					 | 
				
			||||||
        shuffle=False,
 | 
					 | 
				
			||||||
        full_train=False)
 | 
					 | 
				
			||||||
    self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
 | 
					    self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def _check_nonempty_file(self, filepath):
 | 
					  def _check_nonempty_file(self, filepath):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier):
 | 
				
			||||||
      hparams: The hyperparameters for training image classifier.
 | 
					      hparams: The hyperparameters for training image classifier.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    super().__init__(
 | 
					    super().__init__(
 | 
				
			||||||
        model_spec=model_spec,
 | 
					        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
 | 
				
			||||||
        label_names=label_names,
 | 
					 | 
				
			||||||
        shuffle=hparams.shuffle,
 | 
					 | 
				
			||||||
        full_train=hparams.do_fine_tuning)
 | 
					 | 
				
			||||||
    self._hparams = hparams
 | 
					    self._hparams = hparams
 | 
				
			||||||
    self._preprocess = image_preprocessing.Preprocessor(
 | 
					    self._preprocess = image_preprocessing.Preprocessor(
 | 
				
			||||||
        input_shape=self._model_spec.input_image_shape,
 | 
					        input_shape=self._model_spec.input_image_shape,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -93,6 +93,13 @@ cc_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
 | 
					        "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
 | 
					        "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
 | 
					    ] + select({
 | 
				
			||||||
 | 
					        # TODO: Build text_classifier_graph on Windows.
 | 
				
			||||||
 | 
					        "//mediapipe:windows": [],
 | 
				
			||||||
 | 
					        "//conditions:default": [
 | 
				
			||||||
 | 
					            "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -602,8 +602,11 @@ void PublicPacketCreators(pybind11::module* m) {
 | 
				
			||||||
      // TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&"
 | 
					      // TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&"
 | 
				
			||||||
      // as the input argument. Investigate why bazel non-optimized mode
 | 
					      // as the input argument. Investigate why bazel non-optimized mode
 | 
				
			||||||
      // triggers a memory allocation bug in Eigen::internal::aligned_free().
 | 
					      // triggers a memory allocation bug in Eigen::internal::aligned_free().
 | 
				
			||||||
      [](const Eigen::MatrixXf& matrix) {
 | 
					      [](const Eigen::MatrixXf& matrix, bool transpose) {
 | 
				
			||||||
        // MakePacket copies the data.
 | 
					        // MakePacket copies the data.
 | 
				
			||||||
 | 
					        if (transpose) {
 | 
				
			||||||
 | 
					          return MakePacket<Matrix>(matrix.transpose());
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        return MakePacket<Matrix>(matrix);
 | 
					        return MakePacket<Matrix>(matrix);
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
      R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray.
 | 
					      R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray.
 | 
				
			||||||
| 
						 | 
					@ -613,6 +616,8 @@ void PublicPacketCreators(pybind11::module* m) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Args:
 | 
					  Args:
 | 
				
			||||||
    matrix: A 2d numpy float ndarray.
 | 
					    matrix: A 2d numpy float ndarray.
 | 
				
			||||||
 | 
					    transpose: A boolean to indicate if the input matrix needs to be transposed.
 | 
				
			||||||
 | 
					      Default to False.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Returns:
 | 
					  Returns:
 | 
				
			||||||
    A MediaPipe Matrix Packet.
 | 
					    A MediaPipe Matrix Packet.
 | 
				
			||||||
| 
						 | 
					@ -625,6 +630,7 @@ void PublicPacketCreators(pybind11::module* m) {
 | 
				
			||||||
        np.array([[.1, .2, .3], [.4, .5, .6]])
 | 
					        np.array([[.1, .2, .3], [.4, .5, .6]])
 | 
				
			||||||
    matrix = mp.packet_getter.get_matrix(packet)
 | 
					    matrix = mp.packet_getter.get_matrix(packet)
 | 
				
			||||||
)doc",
 | 
					)doc",
 | 
				
			||||||
 | 
					      py::arg("matrix"), py::arg("transpose") = false,
 | 
				
			||||||
      py::return_value_policy::move);
 | 
					      py::return_value_policy::move);
 | 
				
			||||||
}  // NOLINT(readability/fn_size)
 | 
					}  // NOLINT(readability/fn_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
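A short usage sketch for the `transpose` argument added to the matrix packet creator above; the `create_matrix` name and the `mediapipe` import alias are assumptions based on the surrounding docstring:

import numpy as np
import mediapipe as mp

# 2x3 float matrix; with transpose=True the packet stores the 3x2 transpose.
data = np.array([[.1, .2, .3], [.4, .5, .6]])
packet = mp.packet_creator.create_matrix(data, transpose=True)
matrix = mp.packet_getter.get_matrix(packet)
print(matrix.shape)  # expected: (3, 2)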
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -31,13 +31,13 @@ namespace core {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Base options for MediaPipe C++ Tasks.
 | 
					// Base options for MediaPipe C++ Tasks.
 | 
				
			||||||
struct BaseOptions {
 | 
					struct BaseOptions {
 | 
				
			||||||
  // The model asset file contents as as string.
 | 
					  // The model asset file contents as a string.
 | 
				
			||||||
  std::unique_ptr<std::string> model_asset_buffer;
 | 
					  std::unique_ptr<std::string> model_asset_buffer;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The path to the model asset to open and mmap in memory.
 | 
					  // The path to the model asset to open and mmap in memory.
 | 
				
			||||||
  std::string model_asset_path = "";
 | 
					  std::string model_asset_path = "";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The delegate to run MediaPipe. If the delegate is not set, default
 | 
					  // The delegate to run MediaPipe. If the delegate is not set, the default
 | 
				
			||||||
  // delegate CPU is used.
 | 
					  // delegate CPU is used.
 | 
				
			||||||
  enum Delegate {
 | 
					  enum Delegate {
 | 
				
			||||||
    CPU = 0,
 | 
					    CPU = 0,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -273,11 +273,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
 | 
					        hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
 | 
				
			||||||
            kHandGesturesTag)];
 | 
					            kHandGesturesTag)];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return {{.gesture = hand_gestures,
 | 
					    return GestureRecognizerOutputs{
 | 
				
			||||||
             .handedness = handedness,
 | 
					        /*gesture=*/hand_gestures,
 | 
				
			||||||
             .hand_landmarks = hand_landmarks,
 | 
					        /*handedness=*/handedness,
 | 
				
			||||||
             .hand_world_landmarks = hand_world_landmarks,
 | 
					        /*hand_landmarks=*/hand_landmarks,
 | 
				
			||||||
             .image = hand_landmarker_graph[Output<Image>(kImageTag)]}};
 | 
					        /*hand_world_landmarks=*/hand_world_landmarks,
 | 
				
			||||||
 | 
					        /*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Populate normalized non rotated face bounding box
 | 
					  // Populate normalized non rotated face bounding box
 | 
				
			||||||
  return {.left = bounding_box_left,
 | 
					  return Rect{/*left=*/bounding_box_left,
 | 
				
			||||||
          .top = bounding_box_top,
 | 
					              /*top=*/bounding_box_top,
 | 
				
			||||||
          .right = bounding_box_right,
 | 
					              /*right=*/bounding_box_right,
 | 
				
			||||||
          .bottom = bounding_box_bottom};
 | 
					              /*bottom=*/bounding_box_bottom};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Uses IoU and distance of some corresponding hand landmarks to detect
 | 
					// Uses IoU and distance of some corresponding hand landmarks to detect
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,7 +48,7 @@ public abstract class BaseOptions {
 | 
				
			||||||
    public abstract Builder setModelAssetBuffer(ByteBuffer value);
 | 
					    public abstract Builder setModelAssetBuffer(ByteBuffer value);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Sets device Delegate to run the MediaPipe pipeline. If the delegate is not set, default
 | 
					     * Sets device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
 | 
				
			||||||
     * delegate CPU is used.
 | 
					     * delegate CPU is used.
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public abstract Builder setDelegate(Delegate delegate);
 | 
					    public abstract Builder setDelegate(Delegate delegate);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -68,7 +68,11 @@ class AudioData(object):
 | 
				
			||||||
      ValueError: If the input array has an incorrect shape or if
 | 
					      ValueError: If the input array has an incorrect shape or if
 | 
				
			||||||
        `offset` + `size` exceeds the length of the `src` array.
 | 
					        `offset` + `size` exceeds the length of the `src` array.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if src.shape[1] != self._audio_format.num_channels:
 | 
					    if len(src.shape) == 1:
 | 
				
			||||||
 | 
					      if self._audio_format.num_channels != 1:
 | 
				
			||||||
 | 
					        raise ValueError(f"Input audio is mono, but the audio data is expected "
 | 
				
			||||||
 | 
					                         f"to have {self._audio_format.num_channels} channels.")
 | 
				
			||||||
 | 
					    elif src.shape[1] != self._audio_format.num_channels:
 | 
				
			||||||
      raise ValueError(f"Input audio contains an invalid number of channels. "
 | 
					      raise ValueError(f"Input audio contains an invalid number of channels. "
 | 
				
			||||||
                       f"Expect {self._audio_format.num_channels}.")
 | 
					                       f"Expect {self._audio_format.num_channels}.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -93,6 +97,28 @@ class AudioData(object):
 | 
				
			||||||
      self._buffer = np.roll(self._buffer, -shift, axis=0)
 | 
					      self._buffer = np.roll(self._buffer, -shift, axis=0)
 | 
				
			||||||
      self._buffer[-shift:, :] = src[offset:offset + size].copy()
 | 
					      self._buffer[-shift:, :] = src[offset:offset + size].copy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def create_from_array(cls,
 | 
				
			||||||
 | 
					                        src: np.ndarray,
 | 
				
			||||||
 | 
					                        sample_rate: Optional[float] = None) -> "AudioData":
 | 
				
			||||||
 | 
					    """Creates an `AudioData` object from a NumPy array.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      src: A NumPy source array that contains the input audio.
 | 
				
			||||||
 | 
					      sample_rate: the optional audio sample rate.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      An `AudioData` object that contains a copy of the NumPy source array as
 | 
				
			||||||
 | 
					      the data.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    obj = cls(
 | 
				
			||||||
 | 
					        buffer_length=src.shape[0],
 | 
				
			||||||
 | 
					        audio_format=AudioFormat(
 | 
				
			||||||
 | 
					            num_channels=1 if len(src.shape) == 1 else src.shape[1],
 | 
				
			||||||
 | 
					            sample_rate=sample_rate))
 | 
				
			||||||
 | 
					    obj.load_from_array(src)
 | 
				
			||||||
 | 
					    return obj
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @property
 | 
					  @property
 | 
				
			||||||
  def audio_format(self) -> AudioFormat:
 | 
					  def audio_format(self) -> AudioFormat:
 | 
				
			||||||
    """Gets the audio format of the audio."""
 | 
					    """Gets the audio format of the audio."""
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
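A brief sketch of the new `AudioData.create_from_array` factory together with the mono-input handling added above; the containers import path is an assumption based on the file being edited:

import numpy as np

# Assumed import path for the AudioData container shown in this hunk.
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module

# A 1-D array is treated as mono audio, so num_channels is inferred as 1.
samples = np.zeros(16000, dtype=np.float32)
audio = audio_data_module.AudioData.create_from_array(samples, sample_rate=16000)
assert audio.audio_format.num_channels == 1
assert audio.audio_format.sample_rate == 16000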
| 
						 | 
					@ -157,7 +157,7 @@ def _normalize_number_fields(pb):
 | 
				
			||||||
                       descriptor.FieldDescriptor.TYPE_ENUM):
 | 
					                       descriptor.FieldDescriptor.TYPE_ENUM):
 | 
				
			||||||
      normalized_values = [int(x) for x in values]
 | 
					      normalized_values = [int(x) for x in values]
 | 
				
			||||||
    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
 | 
					    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
 | 
				
			||||||
      normalized_values = [round(x, 5) for x in values]
 | 
					      normalized_values = [round(x, 4) for x in values]
 | 
				
			||||||
    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
 | 
					    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
 | 
				
			||||||
      normalized_values = [round(float(x), 6) for x in values]
 | 
					      normalized_values = [round(float(x), 6) for x in values]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
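A quick illustration of the effect of loosening the FLOAT rounding in `_normalize_number_fields` above from 5 to 4 decimal places (the sample scores are borrowed from the new text classifier test later in this diff; the motivation for the change is an assumption):

# With round(x, 4), scores that agree to four decimal places normalize identically,
# so the proto comparison treats them as equal:
round(0.999479, 4) == round(0.999466, 4)   # True  (both become 0.9995)
# The previous round(x, 5) kept them distinct:
round(0.999479, 5) == round(0.999466, 5)   # False (0.99948 vs 0.99947)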
							
								
								
									
mediapipe/tasks/python/test/text/BUILD  (new file, 36 lines)
							| 
						 | 
					@ -0,0 +1,36 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Placeholder for internal Python strict test compatibility macro.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_test(
 | 
				
			||||||
 | 
					    name = "text_classifier_test",
 | 
				
			||||||
 | 
					    srcs = ["text_classifier_test.py"],
 | 
				
			||||||
 | 
					    data = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/testdata/text:bert_text_classifier_models",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/testdata/text:text_classifier_models",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/containers:category",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/containers:classifications",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/processors:classifier_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/core:base_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/test:test_utils",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/text:text_classifier",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/test/text/__init__.py  (new file, 13 lines)
							| 
						 | 
					@ -0,0 +1,13 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/test/text/text_classifier_test.py  (new file, 244 lines)
							| 
						 | 
					@ -0,0 +1,244 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Tests for text classifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import enum
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from absl.testing import absltest
 | 
				
			||||||
 | 
					from absl.testing import parameterized
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.containers import category
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.containers import classifications as classifications_module
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.processors import classifier_options
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.core import base_options as base_options_module
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.test import test_utils
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.text import text_classifier
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_BaseOptions = base_options_module.BaseOptions
 | 
				
			||||||
 | 
					_ClassifierOptions = classifier_options.ClassifierOptions
 | 
				
			||||||
 | 
					_Category = category.Category
 | 
				
			||||||
 | 
					_ClassificationEntry = classifications_module.ClassificationEntry
 | 
				
			||||||
 | 
					_Classifications = classifications_module.Classifications
 | 
				
			||||||
 | 
					_TextClassifierResult = classifications_module.ClassificationResult
 | 
				
			||||||
 | 
					_TextClassifier = text_classifier.TextClassifier
 | 
				
			||||||
 | 
					_TextClassifierOptions = text_classifier.TextClassifierOptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_BERT_MODEL_FILE = 'bert_text_classifier.tflite'
 | 
				
			||||||
 | 
					_REGEX_MODEL_FILE = 'test_model_text_classifier_with_regex_tokenizer.tflite'
 | 
				
			||||||
 | 
					_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_NEGATIVE_TEXT = 'What a waste of my time.'
 | 
				
			||||||
 | 
					_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.'
 | 
				
			||||||
 | 
					                  'Strongly recommend it!')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
 | 
				
			||||||
 | 
					    _Classifications(
 | 
				
			||||||
 | 
					        entries=[
 | 
				
			||||||
 | 
					            _ClassificationEntry(
 | 
				
			||||||
 | 
					                categories=[
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=0,
 | 
				
			||||||
 | 
					                        score=0.999479,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='negative'),
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=1,
 | 
				
			||||||
 | 
					                        score=0.00052154,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='positive')
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                timestamp_ms=0)
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        head_index=0,
 | 
				
			||||||
 | 
					        head_name='probability')
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
 | 
				
			||||||
 | 
					    _Classifications(
 | 
				
			||||||
 | 
					        entries=[
 | 
				
			||||||
 | 
					            _ClassificationEntry(
 | 
				
			||||||
 | 
					                categories=[
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=1,
 | 
				
			||||||
 | 
					                        score=0.999466,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='positive'),
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=0,
 | 
				
			||||||
 | 
					                        score=0.000533596,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='negative')
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                timestamp_ms=0)
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        head_index=0,
 | 
				
			||||||
 | 
					        head_name='probability')
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
 | 
				
			||||||
 | 
					    _Classifications(
 | 
				
			||||||
 | 
					        entries=[
 | 
				
			||||||
 | 
					            _ClassificationEntry(
 | 
				
			||||||
 | 
					                categories=[
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=0,
 | 
				
			||||||
 | 
					                        score=0.81313,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='Negative'),
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=1,
 | 
				
			||||||
 | 
					                        score=0.1868704,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='Positive')
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                timestamp_ms=0)
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        head_index=0,
 | 
				
			||||||
 | 
					        head_name='probability')
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
 | 
				
			||||||
 | 
					    _Classifications(
 | 
				
			||||||
 | 
					        entries=[
 | 
				
			||||||
 | 
					            _ClassificationEntry(
 | 
				
			||||||
 | 
					                categories=[
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=1,
 | 
				
			||||||
 | 
					                        score=0.5134273,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='Positive'),
 | 
				
			||||||
 | 
					                    _Category(
 | 
				
			||||||
 | 
					                        index=0,
 | 
				
			||||||
 | 
					                        score=0.486573,
 | 
				
			||||||
 | 
					                        display_name='',
 | 
				
			||||||
 | 
					                        category_name='Negative')
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                timestamp_ms=0)
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        head_index=0,
 | 
				
			||||||
 | 
					        head_name='probability')
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelFileType(enum.Enum):
 | 
				
			||||||
 | 
					  FILE_CONTENT = 1
 | 
				
			||||||
 | 
					  FILE_NAME = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TextClassifierTest(parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super().setUp()
 | 
				
			||||||
 | 
					    self.model_path = test_utils.get_test_data_path(
 | 
				
			||||||
 | 
					        os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_file_succeeds_with_valid_model_path(self):
 | 
				
			||||||
 | 
					    # Creates with default option and valid model file successfully.
 | 
				
			||||||
 | 
					    with _TextClassifier.create_from_model_path(self.model_path) as classifier:
 | 
				
			||||||
 | 
					      self.assertIsInstance(classifier, _TextClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_options_succeeds_with_valid_model_path(self):
 | 
				
			||||||
 | 
					    # Creates with options containing model file successfully.
 | 
				
			||||||
 | 
					    base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
				
			||||||
 | 
					    options = _TextClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					    with _TextClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      self.assertIsInstance(classifier, _TextClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
				
			||||||
 | 
					    # Invalid empty model path.
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
 | 
					        ValueError,
 | 
				
			||||||
 | 
					        r"ExternalFile must specify at least one of 'file_content', "
 | 
				
			||||||
 | 
					        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_path='')
 | 
				
			||||||
 | 
					      options = _TextClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					      _TextClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_options_succeeds_with_valid_model_content(self):
 | 
				
			||||||
 | 
					    # Creates with options containing model content successfully.
 | 
				
			||||||
 | 
					    with open(self.model_path, 'rb') as f:
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_buffer=f.read())
 | 
				
			||||||
 | 
					      options = _TextClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					      classifier = _TextClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					      self.assertIsInstance(classifier, _TextClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.parameters(
 | 
				
			||||||
 | 
					      (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
 | 
				
			||||||
 | 
					       _BERT_NEGATIVE_RESULTS), (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE,
 | 
				
			||||||
 | 
					                                 _NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS),
 | 
				
			||||||
 | 
					      (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _POSITIVE_TEXT,
 | 
				
			||||||
 | 
					       _BERT_POSITIVE_RESULTS), (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE,
 | 
				
			||||||
 | 
					                                 _POSITIVE_TEXT, _BERT_POSITIVE_RESULTS),
 | 
				
			||||||
 | 
					      (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _NEGATIVE_TEXT,
 | 
				
			||||||
 | 
					       _REGEX_NEGATIVE_RESULTS), (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE,
 | 
				
			||||||
 | 
					                                  _NEGATIVE_TEXT, _REGEX_NEGATIVE_RESULTS),
 | 
				
			||||||
 | 
					      (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _POSITIVE_TEXT,
 | 
				
			||||||
 | 
					       _REGEX_POSITIVE_RESULTS), (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE,
 | 
				
			||||||
 | 
					                                  _POSITIVE_TEXT, _REGEX_POSITIVE_RESULTS))
 | 
				
			||||||
 | 
					  def test_classify(self, model_file_type, model_name, text,
 | 
				
			||||||
 | 
					                    expected_classification_result):
 | 
				
			||||||
 | 
					    # Creates classifier.
 | 
				
			||||||
 | 
					    model_path = test_utils.get_test_data_path(
 | 
				
			||||||
 | 
					        os.path.join(_TEST_DATA_DIR, model_name))
 | 
				
			||||||
 | 
					    if model_file_type is ModelFileType.FILE_NAME:
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_path=model_path)
 | 
				
			||||||
 | 
					    elif model_file_type is ModelFileType.FILE_CONTENT:
 | 
				
			||||||
 | 
					      with open(model_path, 'rb') as f:
 | 
				
			||||||
 | 
					        model_content = f.read()
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_buffer=model_content)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      # Should never happen
 | 
				
			||||||
 | 
					      raise ValueError('model_file_type is invalid.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    options = _TextClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					    classifier = _TextClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Performs text classification on the input.
 | 
				
			||||||
 | 
					    text_result = classifier.classify(text)
 | 
				
			||||||
 | 
					    # Comparing results.
 | 
				
			||||||
 | 
					    test_utils.assert_proto_equals(self, text_result.to_pb2(),
 | 
				
			||||||
 | 
					                                   expected_classification_result.to_pb2())
 | 
				
			||||||
 | 
					    # Closes the classifier explicitly when the classifier is not used in
 | 
				
			||||||
 | 
					    # a context.
 | 
				
			||||||
 | 
					    classifier.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.parameters((ModelFileType.FILE_NAME, _BERT_MODEL_FILE,
 | 
				
			||||||
 | 
					                             _NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS),
 | 
				
			||||||
 | 
					                            (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE,
 | 
				
			||||||
 | 
					                             _NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS))
 | 
				
			||||||
 | 
					  def test_classify_in_context(self, model_file_type, model_name, text,
 | 
				
			||||||
 | 
					                               expected_classification_result):
 | 
				
			||||||
 | 
					    # Creates classifier.
 | 
				
			||||||
 | 
					    model_path = test_utils.get_test_data_path(
 | 
				
			||||||
 | 
					        os.path.join(_TEST_DATA_DIR, model_name))
 | 
				
			||||||
 | 
					    if model_file_type is ModelFileType.FILE_NAME:
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_path=model_path)
 | 
				
			||||||
 | 
					    elif model_file_type is ModelFileType.FILE_CONTENT:
 | 
				
			||||||
 | 
					      with open(model_path, 'rb') as f:
 | 
				
			||||||
 | 
					        model_content = f.read()
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_buffer=model_content)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      # Should never happen
 | 
				
			||||||
 | 
					      raise ValueError('model_file_type is invalid.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    options = _TextClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with _TextClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      # Performs text classification on the input.
 | 
				
			||||||
 | 
					      text_result = classifier.classify(text)
 | 
				
			||||||
 | 
					      # Comparing results.
 | 
				
			||||||
 | 
					      test_utils.assert_proto_equals(self, text_result.to_pb2(),
 | 
				
			||||||
 | 
					                                     expected_classification_result.to_pb2())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					  absltest.main()
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/text/BUILD  (new file, 38 lines)
							| 
						 | 
					@ -0,0 +1,38 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Placeholder for internal Python strict library and test compatibility macro.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "text_classifier",
 | 
				
			||||||
 | 
					    srcs = [
 | 
				
			||||||
 | 
					        "text_classifier.py",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/python:packet_creator",
 | 
				
			||||||
 | 
					        "//mediapipe/python:packet_getter",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/containers:classifications",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/processors:classifier_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/core:base_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/core:optional_dependencies",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/core:task_info",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/text/core:base_text_task_api",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/text/__init__.py  (new file, 13 lines)
							| 
						 | 
					@ -0,0 +1,13 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/text/core/BUILD  (new file, 31 lines)
							| 
						 | 
					@ -0,0 +1,31 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Placeholder for internal Python strict library and test compatibility macro.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "base_text_task_api",
 | 
				
			||||||
 | 
					    srcs = [
 | 
				
			||||||
 | 
					        "base_text_task_api.py",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_py_pb2",
 | 
				
			||||||
 | 
					        "//mediapipe/python:_framework_bindings",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/core:optional_dependencies",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/text/core/__init__.py  (new file, 16 lines)
							| 
						 | 
					@ -0,0 +1,16 @@
 | 
				
			||||||
 | 
					"""Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/text/core/base_text_task_api.py  (new file, 55 lines)
							| 
						 | 
					@ -0,0 +1,55 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""MediaPipe text task base api."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.framework import calculator_pb2
 | 
				
			||||||
 | 
					from mediapipe.python._framework_bindings import task_runner
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_TaskRunner = task_runner.TaskRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BaseTextTaskApi(object):
 | 
				
			||||||
 | 
					  """The base class of the user-facing mediapipe text task api classes."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(self,
 | 
				
			||||||
 | 
					               graph_config: calculator_pb2.CalculatorGraphConfig) -> None:
 | 
				
			||||||
 | 
					    """Initializes the `BaseVisionTaskApi` object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      graph_config: The mediapipe text task graph config proto.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    self._runner = _TaskRunner.create(graph_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def close(self) -> None:
 | 
				
			||||||
 | 
					    """Shuts down the mediapipe text task instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      RuntimeError: If the mediapipe text task failed to close.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    self._runner.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @doc_controls.do_not_generate_docs
 | 
				
			||||||
 | 
					  def __enter__(self):
 | 
				
			||||||
 | 
					    """Returns `self` upon entering the runtime context."""
 | 
				
			||||||
 | 
					    return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @doc_controls.do_not_generate_docs
 | 
				
			||||||
 | 
					  def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
 | 
				
			||||||
 | 
					    """Shuts down the mediapipe text task instance on exit of the context manager.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      RuntimeError: If the mediapipe text task failed to close.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    self.close()
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/python/text/text_classifier.py  (new file, 140 lines)
							| 
						 | 
					@ -0,0 +1,140 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""MediaPipe text classifier task."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import dataclasses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.python import packet_creator
 | 
				
			||||||
 | 
					from mediapipe.python import packet_getter
 | 
				
			||||||
 | 
					from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
 | 
				
			||||||
 | 
					from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.containers import classifications
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.processors import classifier_options
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.core import base_options as base_options_module
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.core import task_info as task_info_module
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.text.core import base_text_task_api
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TextClassifierResult = classifications.ClassificationResult
 | 
				
			||||||
 | 
					_BaseOptions = base_options_module.BaseOptions
 | 
				
			||||||
 | 
					_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
 | 
				
			||||||
 | 
					_ClassifierOptions = classifier_options.ClassifierOptions
 | 
				
			||||||
 | 
					_TaskInfo = task_info_module.TaskInfo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
 | 
				
			||||||
 | 
					_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
 | 
				
			||||||
 | 
					_TEXT_IN_STREAM_NAME = 'text_in'
 | 
				
			||||||
 | 
					_TEXT_TAG = 'TEXT'
 | 
				
			||||||
 | 
					_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
 | 
				
			||||||
 | 
					class TextClassifierOptions:
 | 
				
			||||||
 | 
					  """Options for the text classifier task.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Attributes:
 | 
				
			||||||
 | 
					    base_options: Base options for the text classifier task.
 | 
				
			||||||
 | 
					    classifier_options: Options for the text classification task.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  base_options: _BaseOptions
 | 
				
			||||||
 | 
					  classifier_options: _ClassifierOptions = _ClassifierOptions()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @doc_controls.do_not_generate_docs
 | 
				
			||||||
 | 
					  def to_pb2(self) -> _TextClassifierGraphOptionsProto:
 | 
				
			||||||
 | 
					    """Generates an TextClassifierOptions protobuf object."""
 | 
				
			||||||
 | 
					    base_options_proto = self.base_options.to_pb2()
 | 
				
			||||||
 | 
					    classifier_options_proto = self.classifier_options.to_pb2()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return _TextClassifierGraphOptionsProto(
 | 
				
			||||||
 | 
					        base_options=base_options_proto,
 | 
				
			||||||
 | 
					        classifier_options=classifier_options_proto)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TextClassifier(base_text_task_api.BaseTextTaskApi):
 | 
				
			||||||
 | 
					  """Class that performs classification on text."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def create_from_model_path(cls, model_path: str) -> 'TextClassifier':
 | 
				
			||||||
 | 
					    """Creates an `TextClassifier` object from a TensorFlow Lite model and the default `TextClassifierOptions`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_path: Path to the model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      `TextClassifier` object that's created from the model file and the
 | 
				
			||||||
 | 
					      default `TextClassifierOptions`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: If the `TextClassifier` object failed to be created from the provided
 | 
				
			||||||
 | 
					        file, e.g. due to an invalid file path.
 | 
				
			||||||
 | 
					      RuntimeError: If other types of error occurred.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    base_options = _BaseOptions(model_asset_path=model_path)
 | 
				
			||||||
 | 
					    options = TextClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					    return cls.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def create_from_options(cls,
 | 
				
			||||||
 | 
					                          options: TextClassifierOptions) -> 'TextClassifier':
 | 
				
			||||||
 | 
					    """Creates the `TextClassifier` object from text classifier options.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      options: Options for the text classifier task.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      `TextClassifier` object that's created from `options`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: If failed to create `TextClassifier` object from
 | 
				
			||||||
 | 
					        `TextClassifierOptions` such as missing the model.
 | 
				
			||||||
 | 
					      RuntimeError: If other types of error occurred.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    task_info = _TaskInfo(
 | 
				
			||||||
 | 
					        task_graph=_TASK_GRAPH_NAME,
 | 
				
			||||||
 | 
					        input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
 | 
				
			||||||
 | 
					        output_streams=[
 | 
				
			||||||
 | 
					            ':'.join([
 | 
				
			||||||
 | 
					                _CLASSIFICATION_RESULT_TAG,
 | 
				
			||||||
 | 
					                _CLASSIFICATION_RESULT_OUT_STREAM_NAME
 | 
				
			||||||
 | 
					            ])
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        task_options=options)
 | 
				
			||||||
 | 
					    return cls(task_info.generate_graph_config())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def classify(self, text: str) -> TextClassifierResult:
 | 
				
			||||||
 | 
					    """Performs classification on the input `text`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      text: The input text.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A `TextClassifierResult` object that contains a list of text
 | 
				
			||||||
 | 
					      classifications.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: If any of the input arguments is invalid.
 | 
				
			||||||
 | 
					      RuntimeError: If text classification failed to run.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    output_packets = self._runner.process(
 | 
				
			||||||
 | 
					        {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    classification_result_proto = classifications_pb2.ClassificationResult()
 | 
				
			||||||
 | 
					    classification_result_proto.CopyFrom(
 | 
				
			||||||
 | 
					        packet_getter.get_proto(
 | 
				
			||||||
 | 
					            output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return TextClassifierResult([
 | 
				
			||||||
 | 
					        classifications.Classifications.create_from_pb2(classification)
 | 
				
			||||||
 | 
					        for classification in classification_result_proto.classifications
 | 
				
			||||||
 | 
					    ])
 | 
				
			||||||
							
								
								
									
mediapipe/tasks/web/audio/audio_classifier/BUILD (new file, 33 lines)
@@ -0,0 +1,33 @@
# This contains the MediaPipe Audio Classifier Task.
#
# This task takes audio data and outputs the classification result.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_ts_library(
    name = "audio_classifier",
    srcs = [
        "audio_classifier.ts",
        "audio_classifier_options.ts",
        "audio_classifier_result.ts",
    ],
    deps = [
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/components/processors:classifier_result",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:classifier_options",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
    ],
)
							
								
								
									
mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts (new file, 215 lines)
@@ -0,0 +1,215 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url

import {AudioClassifierOptions} from './audio_classifier_options';
import {Classifications} from './audio_classifier_result';

const MEDIAPIPE_GRAPH =
    'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';

// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
// cannot be changed
// TODO: Change this to `audio_in` to match the name in the CC
// implementation
const AUDIO_STREAM = 'input_audio';
const SAMPLE_RATE_STREAM = 'sample_rate';
const CLASSIFICATION_RESULT_STREAM = 'classification_result';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

/** Performs audio classification. */
export class AudioClassifier extends TaskRunner {
  private classifications: Classifications[] = [];
  private defaultSampleRate = 48000;
  private readonly options = new AudioClassifierGraphOptions();

  /**
   * Initializes the Wasm runtime and creates a new audio classifier from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param audioClassifierOptions The options for the audio classifier. Note
   *     that either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      audioClassifierOptions: AudioClassifierOptions):
      Promise<AudioClassifier> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file loaded with this mechanism is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const classifier = await createMediaPipeLib(
        AudioClassifier, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await classifier.setOptions(audioClassifierOptions);
    return classifier;
  }

  /**
   * Initializes the Wasm runtime and creates a new audio classifier based on
   * the provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<AudioClassifier> {
    return AudioClassifier.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new audio classifier based on
   * the path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<AudioClassifier> {
    const response = await fetch(modelAssetPath.toString());
    const graphData = await response.arrayBuffer();
    return AudioClassifier.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the audio classifier.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the audio classifier.
   */
  async setOptions(options: AudioClassifierOptions): Promise<void> {
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    this.options.setClassifierOptions(convertClassifierOptionsToProto(
        options, this.options.getClassifierOptions()));
    this.refreshGraph();
  }

  /**
   * Sets the sample rate for all calls to `classify()` that omit an explicit
   * sample rate. `48000` is used as a default if this method is not called.
   *
   * @param sampleRate A sample rate (e.g. `44100`).
   */
  setDefaultSampleRate(sampleRate: number) {
    this.defaultSampleRate = sampleRate;
  }

  /**
   * Performs audio classification on the provided audio data and waits
   * synchronously for the response.
   *
   * @param audioData An array of raw audio capture data, like
   *     from a call to getChannelData on an AudioBuffer.
   * @param sampleRate The sample rate in Hz of the provided audio data. If not
   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
   *     `48000` if no custom default was set.
   * @return The classification result of the audio data
   */
  classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
    sampleRate = sampleRate ?? this.defaultSampleRate;

    // Configures the number of samples in the WASM layer. We re-configure the
    // number of samples and the sample rate for every frame, but ignore other
    // side effects of this function (such as sending the input side packet and
    // the input stream header).
    this.configureAudio(
        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);

    const timestamp = performance.now();
    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
    this.addAudioToStream(audioData, timestamp);

    this.classifications = [];
    this.finishProcessing();
    return [...this.classifications];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   **/
  private addJsAudioClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(AUDIO_STREAM);
    graphConfig.addInputStream(SAMPLE_RATE_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        AudioClassifierGraphOptions.ext, this.options);

    // Perform audio classification. Pre-processing and results post-processing
    // are built-in.
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(MEDIAPIPE_GRAPH);
    classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
    classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsAudioClassification(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}
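For reviewers who want to try the new API, here is a minimal usage sketch (not part of the commit). The import path and asset locations are placeholders, and it assumes the Wasm loader/binary are served next to the model; only methods added above are used.

import {AudioClassifier} from './audio_classifier';

async function runAudioClassifier(): Promise<void> {
  // Placeholder asset locations; point these at wherever the Wasm loader,
  // Wasm binary, and .tflite model are actually hosted.
  const classifier = await AudioClassifier.createFromModelPath(
      {
        wasmLoaderPath: '/assets/audio_wasm_internal.js',
        wasmBinaryPath: '/assets/audio_wasm_internal.wasm',
      },
      '/assets/audio_classifier.tflite');

  // Decode a clip with the Web Audio API and classify its first channel.
  const audioCtx = new AudioContext();
  const clipBytes = await (await fetch('/assets/clip.wav')).arrayBuffer();
  const audioBuffer = await audioCtx.decodeAudioData(clipBytes);
  classifier.setDefaultSampleRate(audioBuffer.sampleRate);
  const classifications = classifier.classify(audioBuffer.getChannelData(0));
  console.log(classifications);
}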
mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts (new file, 17 lines)
@@ -0,0 +1,17 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options';
mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts (new file, 18 lines)
@@ -0,0 +1,18 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
mediapipe/tasks/web/components/processors/classifier_options.ts (changed)
@@ -29,31 +29,31 @@ export function convertClassifierOptionsToProto(
     baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto {
   const classifierOptions =
       baseOptions ? baseOptions.clone() : new ClassifierOptionsProto();
-  if (options.displayNamesLocale) {
+  if (options.displayNamesLocale !== undefined) {
     classifierOptions.setDisplayNamesLocale(options.displayNamesLocale);
   } else if (options.displayNamesLocale === undefined) {
     classifierOptions.clearDisplayNamesLocale();
   }
 
-  if (options.maxResults) {
+  if (options.maxResults !== undefined) {
     classifierOptions.setMaxResults(options.maxResults);
   } else if ('maxResults' in options) {  // Check for undefined
     classifierOptions.clearMaxResults();
   }
 
-  if (options.scoreThreshold) {
+  if (options.scoreThreshold !== undefined) {
     classifierOptions.setScoreThreshold(options.scoreThreshold);
   } else if ('scoreThreshold' in options) {  // Check for undefined
     classifierOptions.clearScoreThreshold();
   }
 
-  if (options.categoryAllowlist) {
+  if (options.categoryAllowlist !== undefined) {
     classifierOptions.setCategoryAllowlistList(options.categoryAllowlist);
   } else if ('categoryAllowlist' in options) {  // Check for undefined
     classifierOptions.clearCategoryAllowlistList();
   }
 
-  if (options.categoryDenylist) {
+  if (options.categoryDenylist !== undefined) {
     classifierOptions.setCategoryDenylistList(options.categoryDenylist);
   } else if ('categoryDenylist' in options) {  // Check for undefined
     classifierOptions.clearCategoryDenylistList();
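The switch above from truthiness checks to explicit `!== undefined` checks matters for falsy-but-valid option values. A small illustration with a hypothetical options object (not from the commit):

const options = {scoreThreshold: 0};  // 0 means "report every category"
if (options.scoreThreshold) {
  // Old check: never taken for 0, so the threshold was silently dropped.
}
if (options.scoreThreshold !== undefined) {
  // New check: taken for 0, so the proto field is set as intended.
}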
mediapipe/tasks/web/core/BUILD (changed)
@@ -12,6 +12,18 @@ mediapipe_ts_library(
     ],
 )
 
+mediapipe_ts_library(
+    name = "task_runner",
+    srcs = [
+        "task_runner.ts",
+    ],
+    deps = [
+        "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
+        "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts",
+        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
+    ],
+)
+
 mediapipe_ts_library(
     name = "classifier_options",
     srcs = [
							
								
								
									
mediapipe/tasks/web/core/task_runner.ts (new file, 83 lines)
@@ -0,0 +1,83 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib';
import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib';

// tslint:disable-next-line:enforce-name-casing
const WasmMediaPipeImageLib =
    SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib));

/** Base class for all MediaPipe Tasks. */
export abstract class TaskRunner extends WasmMediaPipeImageLib {
  private processingErrors: Error[] = [];

  constructor(wasmModule: WasmModule) {
    super(wasmModule);

    // Disables the automatic render-to-screen code, which allows for pure
    // CPU processing.
    this.setAutoRenderToScreen(false);

    // Enables use of our model resource caching graph service.
    this.registerModelResourcesGraphService();
  }

  /**
   * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
   * over the video stream. Will replace the previously running MediaPipe graph,
   * if there is one.
   * @param graphData The raw MediaPipe graph data, either in binary
   *     protobuffer format (.binarypb), or else in raw text format (.pbtxt or
   *     .textproto).
   * @param isBinary This should be set to true if the graph is in
   *     binary format, and false if it is in human-readable text format.
   */
  override setGraph(graphData: Uint8Array, isBinary: boolean): void {
    this.attachErrorListener((code, message) => {
      this.processingErrors.push(new Error(message));
    });
    super.setGraph(graphData, isBinary);
    this.handleErrors();
  }

  /**
   * Forces all queued-up packets to be pushed through the MediaPipe graph as
   * far as possible, performing all processing until no more processing can be
   * done.
   */
  override finishProcessing(): void {
    super.finishProcessing();
    this.handleErrors();
  }

  /** Throws the error from the error listener if an error was raised. */
  private handleErrors() {
    const errorCount = this.processingErrors.length;
    if (errorCount === 1) {
      // Re-throw error to get a more meaningful stacktrace
      throw new Error(this.processingErrors[0].message);
    } else if (errorCount > 1) {
      throw new Error(
          'Encountered multiple errors: ' +
          this.processingErrors.map(e => e.message).join(', '));
    }
    this.processingErrors = [];
  }
}
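Because `setGraph()` and `finishProcessing()` re-throw any error collected by the error listener, synchronous task calls surface graph failures as ordinary exceptions. A hedged sketch of the resulting calling pattern (the import path is a placeholder; the helper name is made up for illustration):

import {AudioClassifier} from '../audio/audio_classifier/audio_classifier';

function classifySafely(classifier: AudioClassifier, samples: Float32Array) {
  try {
    // Any graph error raised during finishProcessing() propagates out of classify().
    return classifier.classify(samples);
  } catch (e) {
    console.error('MediaPipe graph error:', (e as Error).message);
    return [];
  }
}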
							
								
								
									
mediapipe/tasks/web/text/text_classifier/BUILD (new file, 34 lines)
@@ -0,0 +1,34 @@
# This contains the MediaPipe Text Classifier Task.
#
# This task takes text input and performs Natural Language classification
# (including BERT-based text classification).

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_ts_library(
    name = "text_classifier",
    srcs = [
        "text_classifier.ts",
        "text_classifier_options.d.ts",
        "text_classifier_result.d.ts",
    ],
    deps = [
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/components/processors:classifier_result",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:classifier_options",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
    ],
)
							
								
								
									
mediapipe/tasks/web/text/text_classifier/text_classifier.ts (new file, 180 lines)
@@ -0,0 +1,180 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url

import {TextClassifierOptions} from './text_classifier_options';
import {Classifications} from './text_classifier_result';

const INPUT_STREAM = 'text_in';
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
const TEXT_CLASSIFIER_GRAPH =
    'mediapipe.tasks.text.text_classifier.TextClassifierGraph';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

/** Performs Natural Language classification. */
export class TextClassifier extends TaskRunner {
  private classifications: Classifications[] = [];
  private readonly options = new TextClassifierGraphOptions();

  /**
   * Initializes the Wasm runtime and creates a new text classifier from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param textClassifierOptions The options for the text classifier. Note that
   *     either a path to the TFLite model or the model itself needs to be
   *     provided (via `baseOptions`).
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const classifier = await createMediaPipeLib(
        TextClassifier, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await classifier.setOptions(textClassifierOptions);
    return classifier;
  }

  /**
   * Initializes the Wasm runtime and creates a new text classifier based on the
   * provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<TextClassifier> {
    return TextClassifier.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new text classifier based on the
   * path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<TextClassifier> {
    const response = await fetch(modelAssetPath.toString());
    const graphData = await response.arrayBuffer();
    return TextClassifier.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the text classifier.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the text classifier.
   */
  async setOptions(options: TextClassifierOptions): Promise<void> {
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    this.options.setClassifierOptions(convertClassifierOptionsToProto(
        options, this.options.getClassifierOptions()));
    this.refreshGraph();
  }


  /**
   * Performs Natural Language classification on the provided text and waits
   * synchronously for the response.
   *
   * @param text The text to process.
   * @return The classification result of the text
   */
  classify(text: string): Classifications[] {
    // Get classification classes by running our MediaPipe graph.
    this.classifications = [];
    this.addStringToStream(
        text, INPUT_STREAM, /* timestamp= */ performance.now());
    this.finishProcessing();
    return [...this.classifications];
  }

  // Internal function for converting raw data into a classification, and
  // adding it to our classifications list.
  private addJsTextClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    console.log(classificationResult.toObject());
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        TextClassifierGraphOptions.ext, this.options);

    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
    classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsTextClassification(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}
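A minimal usage sketch of the new web TextClassifier (not part of the commit). The asset paths are placeholders, and the `modelAssetPath` base-option field is an assumption based on the base-options processor; `maxResults` comes from the shared ClassifierOptions shown earlier.

import {TextClassifier} from './text_classifier';

async function runTextClassifier(): Promise<void> {
  const classifier = await TextClassifier.createFromOptions(
      {
        // Placeholder locations for the Wasm loader and binary.
        wasmLoaderPath: '/assets/text_wasm_internal.js',
        wasmBinaryPath: '/assets/text_wasm_internal.wasm',
      },
      {
        // Placeholder model path; assumed field name on the base options.
        baseOptions: {modelAssetPath: '/assets/bert_text_classifier.tflite'},
        maxResults: 3,
      });

  const classifications = classifier.classify('This is a great movie!');
  console.log(classifications);
}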
							
								
								
									
mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts (new vendored file, 17 lines)
@@ -0,0 +1,17 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options';
							
								
								
									
mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts (new vendored file, 18 lines)
@@ -0,0 +1,18 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
							
								
								
									
mediapipe/tasks/web/vision/gesture_recognizer/BUILD (new file, 40 lines)
@@ -0,0 +1,40 @@
# This contains the MediaPipe Gesture Recognizer Task.
#
# This task takes video frames and outputs synchronized frames along with
# the detection results for one or more gesture categories, using Gesture Recognizer.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_ts_library(
    name = "gesture_recognizer",
    srcs = [
        "gesture_recognizer.ts",
        "gesture_recognizer_options.d.ts",
        "gesture_recognizer_result.d.ts",
    ],
    deps = [
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/framework/formats:classification_jspb_proto",
        "//mediapipe/framework/formats:landmark_jspb_proto",
        "//mediapipe/framework/formats:rect_jspb_proto",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:landmark",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:classifier_options",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
    ],
)

374  mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -0,0 +1,374 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationList} from '../../../../framework/formats/classification_pb';
import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {NormalizedRect} from '../../../../framework/formats/rect_pb';
import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb';
import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb';
import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb';
import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb';
import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb';
import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb';
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark} from '../../../../tasks/web/components/containers/landmark';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url

import {GestureRecognizerOptions} from './gesture_recognizer_options';
import {GestureRecognitionResult} from './gesture_recognizer_result';

export {ImageSource};

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const HAND_GESTURES_STREAM = 'hand_gestures';
const LANDMARKS_STREAM = 'hand_landmarks';
const WORLD_LANDMARKS_STREAM = 'world_hand_landmarks';
const HANDEDNESS_STREAM = 'handedness';
const GESTURE_RECOGNIZER_GRAPH =
    'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';

const DEFAULT_NUM_HANDS = 1;
const DEFAULT_SCORE_THRESHOLD = 0.5;
const DEFAULT_CATEGORY_INDEX = -1;

const FULL_IMAGE_RECT = new NormalizedRect();
FULL_IMAGE_RECT.setXCenter(0.5);
FULL_IMAGE_RECT.setYCenter(0.5);
FULL_IMAGE_RECT.setWidth(1);
FULL_IMAGE_RECT.setHeight(1);

/** Performs hand gesture recognition on images. */
export class GestureRecognizer extends TaskRunner {
  private gestures: Category[][] = [];
  private landmarks: Landmark[][] = [];
  private worldLandmarks: Landmark[][] = [];
  private handednesses: Category[][] = [];

  private readonly options: GestureRecognizerGraphOptions;
  private readonly handLandmarkerGraphOptions: HandLandmarkerGraphOptions;
  private readonly handLandmarksDetectorGraphOptions:
      HandLandmarksDetectorGraphOptions;
  private readonly handDetectorGraphOptions: HandDetectorGraphOptions;
  private readonly handGestureRecognizerGraphOptions:
      HandGestureRecognizerGraphOptions;

  /**
   * Initializes the Wasm runtime and creates a new gesture recognizer from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param gestureRecognizerOptions The options for the gesture recognizer.
   *     Note that either a path to the model asset or a model buffer needs to
   *     be provided (via `baseOptions`).
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      gestureRecognizerOptions: GestureRecognizerOptions):
      Promise<GestureRecognizer> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load via this mechanism is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const recognizer = await createMediaPipeLib(
        GestureRecognizer, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await recognizer.setOptions(gestureRecognizerOptions);
    return recognizer;
  }

  /**
   * Initializes the Wasm runtime and creates a new gesture recognizer based on
   * the provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> {
    return GestureRecognizer.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new gesture recognizer based on
   * the path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<GestureRecognizer> {
    const response = await fetch(modelAssetPath.toString());
    const graphData = await response.arrayBuffer();
    return GestureRecognizer.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  constructor(wasmModule: WasmModule) {
    super(wasmModule);

    this.options = new GestureRecognizerGraphOptions();
    this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions();
    this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions);
    this.handLandmarksDetectorGraphOptions =
        new HandLandmarksDetectorGraphOptions();
    this.handLandmarkerGraphOptions.setHandLandmarksDetectorGraphOptions(
        this.handLandmarksDetectorGraphOptions);
    this.handDetectorGraphOptions = new HandDetectorGraphOptions();
    this.handLandmarkerGraphOptions.setHandDetectorGraphOptions(
        this.handDetectorGraphOptions);
    this.handGestureRecognizerGraphOptions =
        new HandGestureRecognizerGraphOptions();
    this.options.setHandGestureRecognizerGraphOptions(
        this.handGestureRecognizerGraphOptions);

    this.initDefaults();

    // Disables the automatic render-to-screen code, which allows for pure
    // CPU processing.
    this.setAutoRenderToScreen(false);
  }

  /**
   * Sets new options for the gesture recognizer.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the gesture recognizer.
   */
  async setOptions(options: GestureRecognizerOptions): Promise<void> {
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    if ('numHands' in options) {
      this.handDetectorGraphOptions.setNumHands(
          options.numHands ?? DEFAULT_NUM_HANDS);
    }
    if ('minHandDetectionConfidence' in options) {
      this.handDetectorGraphOptions.setMinDetectionConfidence(
          options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD);
    }
    if ('minHandPresenceConfidence' in options) {
      this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
          options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
    }
    if ('minTrackingConfidence' in options) {
      this.handLandmarkerGraphOptions.setMinTrackingConfidence(
          options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD);
    }

    if (options.cannedGesturesClassifierOptions) {
      // Note that we have to support both JSPB and ProtobufJS and cannot
      // use JSPB's getMutableX() APIs.
      const graphOptions = new GestureClassifierGraphOptions();
      graphOptions.setClassifierOptions(convertClassifierOptionsToProto(
          options.cannedGesturesClassifierOptions,
          this.handGestureRecognizerGraphOptions
              .getCannedGestureClassifierGraphOptions()
              ?.getClassifierOptions()));
      this.handGestureRecognizerGraphOptions
          .setCannedGestureClassifierGraphOptions(graphOptions);
    } else if (options.cannedGesturesClassifierOptions === undefined) {
      this.handGestureRecognizerGraphOptions
          .getCannedGestureClassifierGraphOptions()
          ?.clearClassifierOptions();
    }

    if (options.customGesturesClassifierOptions) {
      const graphOptions = new GestureClassifierGraphOptions();
      graphOptions.setClassifierOptions(convertClassifierOptionsToProto(
          options.customGesturesClassifierOptions,
          this.handGestureRecognizerGraphOptions
              .getCustomGestureClassifierGraphOptions()
              ?.getClassifierOptions()));
      this.handGestureRecognizerGraphOptions
          .setCustomGestureClassifierGraphOptions(graphOptions);
    } else if (options.customGesturesClassifierOptions === undefined) {
      this.handGestureRecognizerGraphOptions
          .getCustomGestureClassifierGraphOptions()
          ?.clearClassifierOptions();
    }

    this.refreshGraph();
  }

  /**
   * Performs gesture recognition on the provided single image and waits
   * synchronously for the response.
   * @param imageSource An image source to process.
   * @param timestamp The timestamp of the current frame, in ms. If not
   *     provided, defaults to `performance.now()`.
   * @return The detected gestures.
   */
  recognize(imageSource: ImageSource, timestamp: number = performance.now()):
      GestureRecognitionResult {
    this.gestures = [];
    this.landmarks = [];
    this.worldLandmarks = [];
    this.handednesses = [];

    this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
    this.addProtoToStream(
        FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
        NORM_RECT_STREAM, timestamp);
    this.finishProcessing();

    return {
      gestures: this.gestures,
      landmarks: this.landmarks,
      worldLandmarks: this.worldLandmarks,
      handednesses: this.handednesses
    };
  }

  /** Sets the default values for the graph. */
  private initDefaults(): void {
    this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS);
    this.handDetectorGraphOptions.setMinDetectionConfidence(
        DEFAULT_SCORE_THRESHOLD);
    this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
        DEFAULT_SCORE_THRESHOLD);
    this.handLandmarkerGraphOptions.setMinTrackingConfidence(
        DEFAULT_SCORE_THRESHOLD);
  }

  /** Converts the proto data to a Category[][] structure. */
  private toJsCategories(data: Uint8Array[]): Category[][] {
    const result: Category[][] = [];
    for (const binaryProto of data) {
      const inputList = ClassificationList.deserializeBinary(binaryProto);
      const outputList: Category[] = [];
      for (const classification of inputList.getClassificationList()) {
        outputList.push({
          score: classification.getScore() ?? 0,
          index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
          categoryName: classification.getLabel() ?? '',
          displayName: classification.getDisplayName() ?? '',
        });
      }
      result.push(outputList);
    }
    return result;
  }

  /** Converts raw data into a landmark, and adds it to our landmarks list. */
  private addJsLandmarks(data: Uint8Array[]): void {
    for (const binaryProto of data) {
      const handLandmarksProto =
          NormalizedLandmarkList.deserializeBinary(binaryProto);
      const landmarks: Landmark[] = [];
      for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) {
        landmarks.push({
          x: handLandmarkProto.getX() ?? 0,
          y: handLandmarkProto.getY() ?? 0,
          z: handLandmarkProto.getZ() ?? 0,
          normalized: true
        });
      }
      this.landmarks.push(landmarks);
    }
  }

  /**
   * Converts raw data into a landmark, and adds it to our worldLandmarks
   * list.
   */
  private addJsWorldLandmarks(data: Uint8Array[]): void {
    for (const binaryProto of data) {
      const handWorldLandmarksProto =
          LandmarkList.deserializeBinary(binaryProto);
      const worldLandmarks: Landmark[] = [];
      for (const handWorldLandmarkProto of
               handWorldLandmarksProto.getLandmarkList()) {
        worldLandmarks.push({
          x: handWorldLandmarkProto.getX() ?? 0,
          y: handWorldLandmarkProto.getY() ?? 0,
          z: handWorldLandmarkProto.getZ() ?? 0,
          normalized: false
        });
      }
      this.worldLandmarks.push(worldLandmarks);
    }
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(IMAGE_STREAM);
    graphConfig.addInputStream(NORM_RECT_STREAM);
    graphConfig.addOutputStream(HAND_GESTURES_STREAM);
    graphConfig.addOutputStream(LANDMARKS_STREAM);
    graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM);
    graphConfig.addOutputStream(HANDEDNESS_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        GestureRecognizerGraphOptions.ext, this.options);

    const recognizerNode = new CalculatorGraphConfig.Node();
    recognizerNode.setCalculator(GESTURE_RECOGNIZER_GRAPH);
    recognizerNode.addInputStream('IMAGE:' + IMAGE_STREAM);
    recognizerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
    recognizerNode.addOutputStream('HAND_GESTURES:' + HAND_GESTURES_STREAM);
    recognizerNode.addOutputStream('LANDMARKS:' + LANDMARKS_STREAM);
    recognizerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM);
    recognizerNode.addOutputStream('HANDEDNESS:' + HANDEDNESS_STREAM);
    recognizerNode.setOptions(calculatorOptions);

    graphConfig.addNode(recognizerNode);

    this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
      this.addJsLandmarks(binaryProto);
    });
    this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
      this.addJsWorldLandmarks(binaryProto);
    });
    this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => {
      this.gestures.push(...this.toJsCategories(binaryProto));
    });
    this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
      this.handednesses.push(...this.toJsCategories(binaryProto));
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}
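
For reviewers, a minimal usage sketch of the new gesture recognizer API (the import path, asset paths, and the video element below are illustrative placeholders, not part of this change):

// Sketch only; adjust the import and asset locations for your deployment.
import {GestureRecognizer} from './gesture_recognizer';

async function runGestureRecognizer(video: HTMLVideoElement) {
  const recognizer = await GestureRecognizer.createFromModelPath(
      {
        wasmLoaderPath: '/assets/vision_wasm_internal.js',    // placeholder path
        wasmBinaryPath: '/assets/vision_wasm_internal.wasm',  // placeholder path
      },
      '/models/gesture_recognizer.task');                     // placeholder path
  await recognizer.setOptions({numHands: 2});
  const result = recognizer.recognize(video);
  // One Category[] per detected hand, converted from the HAND_GESTURES stream.
  for (const handGestures of result.gestures) {
    console.log(handGestures[0]?.categoryName, handGestures[0]?.score);
  }
}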
							
								
								
									
65  mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts  vendored  Normal file
@@ -0,0 +1,65 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {BaseOptions} from '../../../../tasks/web/core/base_options';
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';

/** Options to configure the MediaPipe Gesture Recognizer Task */
export interface GestureRecognizerOptions {
  /** Options to configure the loading of the model assets. */
  baseOptions?: BaseOptions;

  /**
   * The maximum number of hands that can be detected by the GestureRecognizer.
   * Defaults to 1.
   */
  numHands?: number|undefined;

  /**
   * The minimum confidence score for the hand detection to be considered
   * successful. Defaults to 0.5.
   */
  minHandDetectionConfidence?: number|undefined;

  /**
   * The minimum confidence score of hand presence in the hand landmark
   * detection. Defaults to 0.5.
   */
  minHandPresenceConfidence?: number|undefined;

  /**
   * The minimum confidence score for the hand tracking to be considered
   * successful. Defaults to 0.5.
   */
  minTrackingConfidence?: number|undefined;

  /**
   * Sets the optional `ClassifierOptions` controlling the canned gestures
   * classifier, such as score threshold, allow list and deny list of gestures.
   * The categories for canned gesture classifiers are: ["None", "Closed_Fist",
   * "Open_Palm", "Pointing_Up", "Thumb_Down", "Thumb_Up", "Victory",
   * "ILoveYou"].
   */
  // TODO: Note this option is subject to change
  cannedGesturesClassifierOptions?: ClassifierOptions|undefined;

  /**
   * Options for configuring the custom gestures classifier, such as score
   * threshold, allow list and deny list of gestures.
   */
  // TODO(b/251816640): Note this option is subject to change.
  customGesturesClassifierOptions?: ClassifierOptions|undefined;
}
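
An illustrative options literal for this interface (the model buffer variable and the `scoreThreshold` field name on `ClassifierOptions` are assumptions for the example, not defined in this file):

import {GestureRecognizerOptions} from './gesture_recognizer_options';

// modelBytes: a Uint8Array holding the model asset, obtained elsewhere.
declare const modelBytes: Uint8Array;

const gestureOptions: GestureRecognizerOptions = {
  baseOptions: {modelAssetBuffer: modelBytes},
  numHands: 2,
  minHandDetectionConfidence: 0.6,
  // Assumed ClassifierOptions field; see tasks/web/core/classifier_options.
  cannedGesturesClassifierOptions: {scoreThreshold: 0.7},
};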
							
								
								
									
35  mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts  vendored  Normal file
@@ -0,0 +1,35 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark} from '../../../../tasks/web/components/containers/landmark';

/**
 * Represents the gesture recognition results generated by `GestureRecognizer`.
 */
export interface GestureRecognitionResult {
  /** Hand landmarks of detected hands. */
  landmarks: Landmark[][];

  /** Hand landmarks in world coordinates of detected hands. */
  worldLandmarks: Landmark[][];

  /** Handedness of detected hands. */
  handednesses: Category[][];

  /** Recognized hand gestures of detected hands. */
  gestures: Category[][];
}

33  mediapipe/tasks/web/vision/image_classifier/BUILD  Normal file
@@ -0,0 +1,33 @@
# This contains the MediaPipe Image Classifier Task.
#
# This task takes video or image frames and outputs the classification result.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_ts_library(
    name = "image_classifier",
    srcs = [
        "image_classifier.ts",
        "image_classifier_options.ts",
        "image_classifier_result.ts",
    ],
    deps = [
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/components/processors:classifier_result",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:classifier_options",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
    ],
)

186  mediapipe/tasks/web/vision/image_classifier/image_classifier.ts  Normal file
@@ -0,0 +1,186 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url

import {ImageClassifierOptions} from './image_classifier_options';
import {Classifications} from './image_classifier_result';

const IMAGE_CLASSIFIER_GRAPH =
    'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
const INPUT_STREAM = 'input_image';
const CLASSIFICATION_RESULT_STREAM = 'classification_result';

export {ImageSource};  // Used in the public API

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

/** Performs classification on images. */
export class ImageClassifier extends TaskRunner {
  private classifications: Classifications[] = [];
  private readonly options = new ImageClassifierGraphOptions();

  /**
   * Initializes the Wasm runtime and creates a new image classifier from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param imageClassifierOptions The options for the image classifier. Note
   *     that either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      imageClassifierOptions: ImageClassifierOptions):
      Promise<ImageClassifier> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const classifier = await createMediaPipeLib(
        ImageClassifier, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await classifier.setOptions(imageClassifierOptions);
    return classifier;
  }

  /**
   * Initializes the Wasm runtime and creates a new image classifier based on
   * the provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<ImageClassifier> {
    return ImageClassifier.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new image classifier based on
   * the path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<ImageClassifier> {
    const response = await fetch(modelAssetPath.toString());
    const graphData = await response.arrayBuffer();
    return ImageClassifier.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the image classifier.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the image classifier.
   */
  async setOptions(options: ImageClassifierOptions): Promise<void> {
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    this.options.setClassifierOptions(convertClassifierOptionsToProto(
        options, this.options.getClassifierOptions()));
    this.refreshGraph();
  }

  /**
   * Performs image classification on the provided image and waits synchronously
   * for the response.
   *
   * @param imageSource An image source to process.
   * @param timestamp The timestamp of the current frame, in ms. If not
   *     provided, defaults to `performance.now()`.
   * @return The classification result of the image.
   */
  classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
    // Get classification classes by running our MediaPipe graph.
    this.classifications = [];
    this.addGpuBufferAsImageToStream(
        imageSource, INPUT_STREAM, timestamp ?? performance.now());
    this.finishProcessing();
    return [...this.classifications];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   */
  private addJsImageClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        ImageClassifierGraphOptions.ext, this.options);

    // Perform image classification. Pre-processing and results post-processing
    // are built-in.
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
    classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsImageClassification(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}
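
A corresponding usage sketch for the classifier (import and asset paths are again placeholders):

// Sketch only; createFromModelPath fetches the model over HTTP and classify()
// runs synchronously once the Wasm runtime is loaded.
import {ImageClassifier} from './image_classifier';

async function classifyImage(image: HTMLImageElement) {
  const classifier = await ImageClassifier.createFromModelPath(
      {
        wasmLoaderPath: '/assets/vision_wasm_internal.js',    // placeholder path
        wasmBinaryPath: '/assets/vision_wasm_internal.wasm',  // placeholder path
      },
      '/models/image_classifier.tflite');                     // placeholder path
  const classifications = classifier.classify(image);
  console.log(classifications);  // Classifications[] as defined in image_classifier_result.ts
}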

17  mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts
@@ -0,0 +1,17 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options';

18  mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts
@@ -0,0 +1,18 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';

30  mediapipe/tasks/web/vision/object_detector/BUILD  Normal file
@@ -0,0 +1,30 @@
# This contains the MediaPipe Object Detector Task.
#
# This task takes video frames and outputs synchronized frames along with
# the detection results for one or more object categories, using Object Detector.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_ts_library(
    name = "object_detector",
    srcs = [
        "object_detector.ts",
        "object_detector_options.d.ts",
        "object_detector_result.d.ts",
    ],
    deps = [
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/framework/formats:detection_jspb_proto",
        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
    ],
)

233  mediapipe/tasks/web/vision/object_detector/object_detector.ts  Normal file
@@ -0,0 +1,233 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url

import {ObjectDetectorOptions} from './object_detector_options';
import {Detection} from './object_detector_result';

const INPUT_STREAM = 'input_frame_gpu';
const DETECTIONS_STREAM = 'detections';
const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph';

const DEFAULT_CATEGORY_INDEX = -1;

export {ImageSource};  // Used in the public API

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

/** Performs object detection on images. */
export class ObjectDetector extends TaskRunner {
  private detections: Detection[] = [];
  private readonly options = new ObjectDetectorOptionsProto();

  /**
   * Initializes the Wasm runtime and creates a new object detector from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param objectDetectorOptions The options for the Object Detector. Note that
   *     either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   */
  static async createFromOptions(
 | 
				
			||||||
 | 
					      wasmLoaderOptions: WasmLoaderOptions,
 | 
				
			||||||
 | 
					      objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
 | 
				
			||||||
 | 
					    // Create a file locator based on the loader options
 | 
				
			||||||
 | 
					    const fileLocator: FileLocator = {
 | 
				
			||||||
 | 
					      locateFile() {
 | 
				
			||||||
 | 
					        // The only file we load is the Wasm binary
 | 
				
			||||||
 | 
					        return wasmLoaderOptions.wasmBinaryPath.toString();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const detector = await createMediaPipeLib(
 | 
				
			||||||
 | 
					        ObjectDetector, wasmLoaderOptions.wasmLoaderPath,
 | 
				
			||||||
 | 
					        /* assetLoaderScript= */ undefined,
 | 
				
			||||||
 | 
					        /* glCanvas= */ undefined, fileLocator);
 | 
				
			||||||
 | 
					    await detector.setOptions(objectDetectorOptions);
 | 
				
			||||||
 | 
					    return detector;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Initializes the Wasm runtime and creates a new object detector based on the
 | 
				
			||||||
 | 
					   * provided model asset buffer.
 | 
				
			||||||
 | 
					   * @param wasmLoaderOptions A configuration object that provides the location
 | 
				
			||||||
 | 
					   *     of the Wasm binary and its loader.
 | 
				
			||||||
 | 
					   * @param modelAssetBuffer A binary representation of the model.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  static createFromModelBuffer(
 | 
				
			||||||
 | 
					      wasmLoaderOptions: WasmLoaderOptions,
 | 
				
			||||||
 | 
					      modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
 | 
				
			||||||
 | 
					    return ObjectDetector.createFromOptions(
 | 
				
			||||||
 | 
					        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Initializes the Wasm runtime and creates a new object detector based on the
 | 
				
			||||||
 | 
					   * path to the model asset.
 | 
				
			||||||
 | 
					   * @param wasmLoaderOptions A configuration object that provides the location
 | 
				
			||||||
 | 
					   *     of the Wasm binary and its loader.
 | 
				
			||||||
 | 
					   * @param modelAssetPath The path to the model asset.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  static async createFromModelPath(
 | 
				
			||||||
 | 
					      wasmLoaderOptions: WasmLoaderOptions,
 | 
				
			||||||
 | 
					      modelAssetPath: string): Promise<ObjectDetector> {
 | 
				
			||||||
 | 
					    const response = await fetch(modelAssetPath.toString());
 | 
				
			||||||
 | 
					    const graphData = await response.arrayBuffer();
 | 
				
			||||||
 | 
					    return ObjectDetector.createFromModelBuffer(
 | 
				
			||||||
 | 
					        wasmLoaderOptions, new Uint8Array(graphData));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Sets new options for the object detector.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * Calling `setOptions()` with a subset of options only affects those options.
 | 
				
			||||||
 | 
					   * You can reset an option back to its default value by explicitly setting it
 | 
				
			||||||
 | 
					   * to `undefined`.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param options The options for the object detector.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  async setOptions(options: ObjectDetectorOptions): Promise<void> {
 | 
				
			||||||
 | 
					    if (options.baseOptions) {
 | 
				
			||||||
 | 
					      const baseOptionsProto =
 | 
				
			||||||
 | 
					          await convertBaseOptionsToProto(options.baseOptions);
 | 
				
			||||||
 | 
					      this.options.setBaseOptions(baseOptionsProto);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Note that we have to support both JSPB and ProtobufJS, hence we
 | 
				
			||||||
 | 
					    // have to expliclity clear the values instead of setting them to
 | 
				
			||||||
 | 
					    // `undefined`.
 | 
				
			||||||
 | 
					    if (options.displayNamesLocale !== undefined) {
 | 
				
			||||||
 | 
					      this.options.setDisplayNamesLocale(options.displayNamesLocale);
 | 
				
			||||||
 | 
					    } else if ('displayNamesLocale' in options) {  // Check for undefined
 | 
				
			||||||
 | 
					      this.options.clearDisplayNamesLocale();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (options.maxResults !== undefined) {
 | 
				
			||||||
 | 
					      this.options.setMaxResults(options.maxResults);
 | 
				
			||||||
 | 
					    } else if ('maxResults' in options) {  // Check for undefined
 | 
				
			||||||
 | 
					      this.options.clearMaxResults();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (options.scoreThreshold !== undefined) {
 | 
				
			||||||
 | 
					      this.options.setScoreThreshold(options.scoreThreshold);
 | 
				
			||||||
 | 
					    } else if ('scoreThreshold' in options) {  // Check for undefined
 | 
				
			||||||
 | 
					      this.options.clearScoreThreshold();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (options.categoryAllowlist !== undefined) {
 | 
				
			||||||
 | 
					      this.options.setCategoryAllowlistList(options.categoryAllowlist);
 | 
				
			||||||
 | 
					    } else if ('categoryAllowlist' in options) {  // Check for undefined
 | 
				
			||||||
 | 
					      this.options.clearCategoryAllowlistList();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (options.categoryDenylist !== undefined) {
 | 
				
			||||||
 | 
					      this.options.setCategoryDenylistList(options.categoryDenylist);
 | 
				
			||||||
 | 
					    } else if ('categoryDenylist' in options) {  // Check for undefined
 | 
				
			||||||
 | 
					      this.options.clearCategoryDenylistList();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    this.refreshGraph();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Performs object detection on the provided single image and waits
 | 
				
			||||||
 | 
					   * synchronously for the response.
 | 
				
			||||||
 | 
					   * @param imageSource An image source to process.
 | 
				
			||||||
 | 
					   * @param timestamp The timestamp of the current frame, in ms. If not
 | 
				
			||||||
 | 
					   *    provided, defaults to `performance.now()`.
 | 
				
			||||||
 | 
					   * @return The list of detected objects
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  detect(imageSource: ImageSource, timestamp?: number): Detection[] {
 | 
				
			||||||
 | 
					    // Get detections by running our MediaPipe graph.
 | 
				
			||||||
 | 
					    this.detections = [];
 | 
				
			||||||
 | 
					    this.addGpuBufferAsImageToStream(
 | 
				
			||||||
 | 
					        imageSource, INPUT_STREAM, timestamp ?? performance.now());
 | 
				
			||||||
 | 
					    this.finishProcessing();
 | 
				
			||||||
 | 
					    return [...this.detections];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Converts raw data into a Detection, and adds it to our detection list. */
 | 
				
			||||||
 | 
					  private addJsObjectDetections(data: Uint8Array[]): void {
 | 
				
			||||||
 | 
					    for (const binaryProto of data) {
 | 
				
			||||||
 | 
					      const detectionProto = DetectionProto.deserializeBinary(binaryProto);
 | 
				
			||||||
 | 
					      const scores = detectionProto.getScoreList();
 | 
				
			||||||
 | 
					      const indexes = detectionProto.getLabelIdList();
 | 
				
			||||||
 | 
					      const labels = detectionProto.getLabelList();
 | 
				
			||||||
 | 
					      const displayNames = detectionProto.getDisplayNameList();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      const detection: Detection = {categories: []};
 | 
				
			||||||
 | 
					      for (let i = 0; i < scores.length; i++) {
 | 
				
			||||||
 | 
					        detection.categories.push({
 | 
				
			||||||
 | 
					          score: scores[i],
 | 
				
			||||||
 | 
					          index: indexes[i] ?? DEFAULT_CATEGORY_INDEX,
 | 
				
			||||||
 | 
					          categoryName: labels[i] ?? '',
 | 
				
			||||||
 | 
					          displayName: displayNames[i] ?? '',
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      const boundingBox = detectionProto.getLocationData()?.getBoundingBox();
 | 
				
			||||||
 | 
					      if (boundingBox) {
 | 
				
			||||||
 | 
					        detection.boundingBox = {
 | 
				
			||||||
 | 
					          originX: boundingBox.getXmin() ?? 0,
 | 
				
			||||||
 | 
					          originY: boundingBox.getYmin() ?? 0,
 | 
				
			||||||
 | 
					          width: boundingBox.getWidth() ?? 0,
 | 
				
			||||||
 | 
					          height: boundingBox.getHeight() ?? 0
 | 
				
			||||||
 | 
					        };
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      this.detections.push(detection);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Updates the MediaPipe graph configuration. */
 | 
				
			||||||
 | 
					  private refreshGraph(): void {
 | 
				
			||||||
 | 
					    const graphConfig = new CalculatorGraphConfig();
 | 
				
			||||||
 | 
					    graphConfig.addInputStream(INPUT_STREAM);
 | 
				
			||||||
 | 
					    graphConfig.addOutputStream(DETECTIONS_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const calculatorOptions = new CalculatorOptions();
 | 
				
			||||||
 | 
					    calculatorOptions.setExtension(
 | 
				
			||||||
 | 
					        ObjectDetectorOptionsProto.ext, this.options);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const detectorNode = new CalculatorGraphConfig.Node();
 | 
				
			||||||
 | 
					    detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH);
 | 
				
			||||||
 | 
					    detectorNode.addInputStream('IMAGE:' + INPUT_STREAM);
 | 
				
			||||||
 | 
					    detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
 | 
				
			||||||
 | 
					    detectorNode.setOptions(calculatorOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    graphConfig.addNode(detectorNode);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => {
 | 
				
			||||||
 | 
					      this.addJsObjectDetections(binaryProto);
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const binaryGraph = graphConfig.serializeBinary();
 | 
				
			||||||
 | 
					    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
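
For orientation, a minimal usage sketch of the `ObjectDetector` API added above. The import path, wasm loader locations, and model URL are hypothetical placeholders, not part of this change:

import {ObjectDetector} from './object_detector';  // consumer path assumed

async function detectOnFrame(video: HTMLVideoElement) {
  const detector = await ObjectDetector.createFromModelPath(
      {
        // Hypothetical locations for the Wasm loader script and binary.
        wasmLoaderPath: '/wasm/vision_wasm_internal.js',
        wasmBinaryPath: '/wasm/vision_wasm_internal.wasm',
      },
      '/models/object_detector.tflite');  // hypothetical model path
  await detector.setOptions({maxResults: 5, scoreThreshold: 0.5});

  // detect() runs the graph synchronously and returns the collected results.
  const detections = detector.detect(video);
  for (const detection of detections) {
    const top = detection.categories[0];
    console.log(top?.categoryName, top?.score, detection.boundingBox);
  }
}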
							
								
								
									

mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts (new file, vendored, 52 lines)
@@ -0,0 +1,52 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {BaseOptions} from '../../../../tasks/web/core/base_options';

/** Options to configure the MediaPipe Object Detector Task */
export interface ObjectDetectorOptions {
  /** Options to configure the loading of the model assets. */
  baseOptions?: BaseOptions;

  /**
   * The locale to use for display names specified through the TFLite Model
   * Metadata, if any. Defaults to English.
   */
  displayNamesLocale?: string|undefined;

  /** The maximum number of top-scored detection results to return. */
  maxResults?: number|undefined;

  /**
   * Overrides the value provided in the model metadata. Results below this
   * value are rejected.
   */
  scoreThreshold?: number|undefined;

  /**
   * Allowlist of category names. If non-empty, detection results whose category
   * name is not in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryDenylist`.
   */
  categoryAllowlist?: string[]|undefined;

  /**
   * Denylist of category names. If non-empty, detection results whose category
   * name is in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryAllowlist`.
   */
  categoryDenylist?: string[]|undefined;
}
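
As a quick illustration of the option semantics defined above and in `setOptions()` (a sketch; the import path and specific values are arbitrary):

import {ObjectDetectorOptions} from './object_detector_options';  // consumer path assumed

// A setOptions() call only affects the fields that are present in the object.
const options: ObjectDetectorOptions = {
  maxResults: 3,
  scoreThreshold: 0.4,
  categoryAllowlist: ['person', 'dog'],  // mutually exclusive with categoryDenylist
};

// Explicitly passing `undefined` resets that option back to its default value.
const resetThreshold: ObjectDetectorOptions = {scoreThreshold: undefined};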
							
								
								
									

mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts (new file, vendored, 38 lines)
@@ -0,0 +1,38 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {Category} from '../../../../tasks/web/components/containers/category';

/** An integer bounding box, axis aligned. */
export interface BoundingBox {
  /** The X coordinate of the top-left corner, in pixels. */
  originX: number;
  /** The Y coordinate of the top-left corner, in pixels. */
  originY: number;
  /** The width of the bounding box, in pixels. */
  width: number;
  /** The height of the bounding box, in pixels. */
  height: number;
}

/** Represents one object detected by the `ObjectDetector`. */
export interface Detection {
  /** A list of `Category` objects. */
  categories: Category[];

  /** The bounding box of the detected object. */
  boundingBox?: BoundingBox;
}
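
Since the `BoundingBox` coordinates above are in pixels, detection results can be drawn directly onto a canvas overlay sized to the input frame. A small sketch (the canvas context and styling are assumptions of this example):

import {Detection} from './object_detector_result';  // consumer path assumed

function drawDetections(
    ctx: CanvasRenderingContext2D, detections: Detection[]): void {
  ctx.strokeStyle = 'red';
  ctx.font = '14px sans-serif';
  for (const detection of detections) {
    const box = detection.boundingBox;
    if (!box) continue;  // Not every detection carries a bounding box.
    ctx.strokeRect(box.originX, box.originY, box.width, box.height);
    const top = detection.categories[0];
    if (top) {
      ctx.fillText(
          `${top.categoryName} (${top.score.toFixed(2)})`,
          box.originX, box.originY - 4);
    }
  }
}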
							
								
								
									

mediapipe/web/graph_runner/BUILD (new file, 41 lines)
@@ -0,0 +1,41 @@
# The TypeScript graph runner used by all MediaPipe Web tasks.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = [
    ":internal",
    "//mediapipe/tasks:internal",
])

package_group(
    name = "internal",
    packages = [
        "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...",
    ],
)

mediapipe_ts_library(
    name = "wasm_mediapipe_lib_ts",
    srcs = [
        ":wasm_mediapipe_lib.ts",
    ],
    allow_unoptimized_namespaces = True,
)

mediapipe_ts_library(
    name = "wasm_mediapipe_image_lib_ts",
    srcs = [
        ":wasm_mediapipe_image_lib.ts",
    ],
    allow_unoptimized_namespaces = True,
    deps = [":wasm_mediapipe_lib_ts"],
)

mediapipe_ts_library(
    name = "register_model_resources_graph_service_ts",
    srcs = [
        ":register_model_resources_graph_service.ts",
    ],
    allow_unoptimized_namespaces = True,
    deps = [":wasm_mediapipe_lib_ts"],
)

mediapipe/web/graph_runner/register_model_resources_graph_service.ts (new file, 41 lines)
@@ -0,0 +1,41 @@
import {WasmMediaPipeLib} from './wasm_mediapipe_lib';

/**
 * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has
 * access to the wasmModule, among other things. The `any` type is required for
 * mixin constructors.
 */
// tslint:disable-next-line:no-any
type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;

/**
 * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
 * doesn't break our JS/C++ bridge.
 */
export declare interface WasmModuleRegisterModelResources {
  _registerModelResourcesGraphService: () => void;
}

/**
 * An implementation of WasmMediaPipeLib that supports registering model
 * resources to a cache, in the form of a GraphService C++-side. We implement as
 * a proper TS mixin, to allow for effective multiple inheritance. Sample usage:
 * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
 *     WasmMediaPipeLib);`
 */
// tslint:disable:enforce-name-casing
export function SupportModelResourcesGraphService<TBase extends LibConstructor>(
    Base: TBase) {
  return class extends Base {
    // tslint:enable:enforce-name-casing
    /**
     * Instructs the graph runner to use the model resource caching graph
     * service for both graph expansion/initialization, as well as for graph
     * run.
     */
    registerModelResourcesGraphService(): void {
      (this.wasmModule as unknown as WasmModuleRegisterModelResources)
          ._registerModelResourcesGraphService();
    }
  };
}
							
								
								
									

mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts (new file, 52 lines)
@@ -0,0 +1,52 @@
import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib';

/**
 * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has
 * access to the wasmModule, among other things. The `any` type is required for
 * mixin constructors.
 */
// tslint:disable-next-line:no-any
type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;

/**
 * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
 * doesn't break our JS/C++ bridge.
 */
export declare interface WasmImageModule {
  _addBoundTextureAsImageToStream:
      (streamNamePtr: number, width: number, height: number,
       timestamp: number) => void;
}

/**
 * An implementation of WasmMediaPipeLib that supports binding GPU image data as
 * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for
 * effective multiple inheritance. Example usage:
 * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);`
 */
// tslint:disable-next-line:enforce-name-casing
export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
  return class extends Base {
    /**
     * Takes the relevant information from the HTML video or image element, and
     * passes it into the WebGL-based graph for processing on the given stream
     * at the given timestamp as a MediaPipe image. Processing will not occur
     * until a blocking call (like processVideoGl or finishProcessing) is made.
     * @param imageSource Reference to the video frame we wish to add into our
     *     graph.
     * @param streamName The name of the MediaPipe graph stream to add the frame
     *     to.
     * @param timestamp The timestamp of the input frame, in ms.
     */
    addGpuBufferAsImageToStream(
        imageSource: ImageSource, streamName: string, timestamp: number): void {
      this.wrapStringPtr(streamName, (streamNamePtr: number) => {
        const [width, height] =
            this.bindTextureToStream(imageSource, streamNamePtr);
        (this.wasmModule as unknown as WasmImageModule)
            ._addBoundTextureAsImageToStream(
                streamNamePtr, width, height, timestamp);
      });
    }
  };
}
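
Because the two helpers added in this change are both written as TS mixins over `WasmMediaPipeLib`, they compose freely. A sketch of combining them in consumer code (the import paths and resulting class name are illustrative):

import {SupportModelResourcesGraphService} from './register_model_resources_graph_service';
import {SupportImage} from './wasm_mediapipe_image_lib';
import {WasmMediaPipeLib} from './wasm_mediapipe_lib';

// Layer both mixins on top of the base graph runner. Instances of the
// resulting class expose addGpuBufferAsImageToStream() and
// registerModelResourcesGraphService() in addition to the base API.
const WasmMediaPipeVisionLib =
    SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib));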
							
								
								
									

mediapipe/web/graph_runner/wasm_mediapipe_lib.ts (new file, 1044 lines)
(File diff suppressed because it is too large.)