diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD new file mode 100644 index 000000000..5e7c5afa5 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD @@ -0,0 +1,44 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "kmeans_embedding_lookup", + srcs = ["kmeans_embedding_lookup.cc"], + hdrs = ["kmeans_embedding_lookup.h"], + copts = tflite_copts(), + deps = [ + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_test( + name = "kmeans_embedding_lookup_test", + size = "small", + srcs = ["kmeans_embedding_lookup_test.cc"], + deps = [ + ":kmeans_embedding_lookup", + "//mediapipe/framework/port:gtest_main", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc new file mode 100644 index 000000000..2ab3ed74d --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc @@ -0,0 +1,145 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" + +#include +#include +#include +#include + +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite::ops::custom { +namespace kmeans_embedding_lookup_op { + +namespace { + +constexpr int kInputMessage = 0; +constexpr int kEncodingTable = 1; +constexpr int kCodebook = 2; +constexpr int kOutputLabel = 0; + +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); + output_size->data[0] = 1; + const TfLiteTensor* input = GetInput(context, node, kInputMessage); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable); + TF_LITE_ENSURE(context, encoding_table != nullptr); + const TfLiteTensor* codebook = GetInput(context, node, kCodebook); + TF_LITE_ENSURE(context, codebook != nullptr); + const int encoding_size = encoding_table->dims->data[1]; + const int block_size = codebook->dims->data[1]; + output_size->data[1] = encoding_size * block_size; + + // Check if the inputs and output are typed correctly. + if (input->type != kTfLiteInt32) { + context->ReportError(context, "Input type must be Int32."); + return kTfLiteError; + } + if (encoding_table->type != kTfLiteUInt8) { + context->ReportError(context, "Encoding Table type must be UInt8."); + return kTfLiteError; + } + if (codebook->type != kTfLiteFloat32) { + context->ReportError(context, "Codebook type must be Float32."); + return kTfLiteError; + } + if (output->type != kTfLiteFloat32) { + context->ReportError(context, "Output type must be Float32."); + return kTfLiteError; + } + + return context->ResizeTensor(context, output, output_size); +} + +// This is the core method that generates the aggregated embedding from the +// given input, encoding table and codebook tensors. +void GetEmbedding(const TfLiteTensor* input, const TfLiteTensor* encoding_table, + const TfLiteTensor* codebook, float* data) { + const int input_encoding_size = encoding_table->dims->data[1]; + const int block_size = codebook->dims->data[1]; + const int num_tokens = input->dims->data[1]; + const int output_embedding_size = input_encoding_size * block_size; + + int num_embeddings = 0; + std::vector final_embedding(output_embedding_size, 0.0); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + const int32_t token = GetTensorData(input)[token_idx]; + if (token == 0) { + break; + } + ++num_embeddings; + + for (int encoding_dim_idx = 0; encoding_dim_idx < input_encoding_size; + encoding_dim_idx++) { + int codebook_idx = GetTensorData( + encoding_table)[token * input_encoding_size + encoding_dim_idx]; + for (int block_offset = 0; block_offset < block_size; block_offset++) { + final_embedding[encoding_dim_idx * block_size + block_offset] += + GetTensorData( + codebook)[codebook_idx * block_size + block_offset]; + } + } + } + + // Compute the mean of the embeddings. + for (int embed_dim_idx = 0; embed_dim_idx < output_embedding_size; + embed_dim_idx++) { + data[embed_dim_idx] = + final_embedding[embed_dim_idx] / (std::max(num_embeddings, 1)); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputMessage); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable); + TF_LITE_ENSURE(context, encoding_table != nullptr); + const TfLiteTensor* codebook = GetInput(context, node, kCodebook); + TF_LITE_ENSURE(context, codebook != nullptr); + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + + // Sanity checks on the input. + const int batch_size = input->dims->data[0]; + if (batch_size != 1) { + context->ReportError(context, "`batch_size` must be == 1."); + return kTfLiteError; + } + + // Compute the output embedding. + GetEmbedding(input, encoding_table, codebook, GetTensorData(output)); + + return kTfLiteOk; +} + +} // namespace kmeans_embedding_lookup_op + +TfLiteRegistration* Register_KmeansEmbeddingLookup() { + static TfLiteRegistration r = {nullptr, nullptr, + kmeans_embedding_lookup_op::Prepare, + kmeans_embedding_lookup_op::Eval}; + return &r; +} + +} // namespace tflite::ops::custom diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h new file mode 100644 index 000000000..99025b1f6 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This op was originally written by the Learn2Compress team. +// It takes in a list of indices, an encoding table which consists of +// integer indices into a codebook with floating point vectors. +// For each index, it looks up the corresponding row in the encoding table and +// for each entry in the row of the encoding table, it looks up the +// corresponding row in the codebook and populates it in an output embedding. +// The average of the output embeddings for each of the input indices is the +// output of this op. + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace tflite::ops::custom { + +TfLiteRegistration* Register_KmeansEmbeddingLookup(); + +} // namespace tflite::ops::custom + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc new file mode 100644 index 000000000..7bfcb93b9 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc @@ -0,0 +1,176 @@ +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" + +#include +#include +#include +#include +#include + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/test_util.h" + +namespace tflite::ops::custom { +namespace { + +using ::testing::ElementsAreArray; +using ::tflite::ArrayFloatNear; + +// Helper class for testing the op. +class KmeansEmbeddingLookupModel : public SingleOpModel { + public: + explicit KmeansEmbeddingLookupModel( + std::initializer_list input_shape, + std::initializer_list encoding_table_shape, + std::initializer_list codebook_shape, + std::initializer_list output_shape) { + // Setup the model inputs and the interpreter. + output_ = AddOutput({TensorType_FLOAT32, output_shape}); + SetCustomOp("KmeansEmbeddingLookup", std::vector(), + Register_KmeansEmbeddingLookup); + BuildInterpreter({input_shape, encoding_table_shape, codebook_shape}); + } + + TfLiteStatus SetUpInputTensor(const std::vector& input, + const std::vector& encoding_table, + const std::vector& codebook) { + PopulateTensor(input_, {input}); + PopulateTensor(encoding_table_, {encoding_table}); + PopulateTensor(codebook_, {codebook}); + return interpreter_->AllocateTensors(); + } + + void Invoke(const std::vector& input, + const std::vector& encoding_table, + const std::vector& codebook) { + CHECK_EQ(SetUpInputTensor(input, encoding_table, codebook), kTfLiteOk); + CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); + } + + TfLiteStatus InvokeUnchecked(const std::vector& input, + const std::vector& encoding_table, + const std::vector& codebook) { + TfLiteStatus allocation_status = + SetUpInputTensor(input, encoding_table, codebook); + if (allocation_status != kTfLiteOk) { + return allocation_status; + } + return SingleOpModel::Invoke(); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_ = AddInput(TensorType_INT32); + int encoding_table_ = AddInput(TensorType_UINT8); + int codebook_ = AddInput(TensorType_FLOAT32); + int output_; +}; + +template +std::vector FlattenVector2D(std::vector> input_vec) { + std::vector output_vec(input_vec.size() * input_vec[0].size()); + for (int i = 0, k = 0; i < input_vec.size(); i++) { + for (int j = 0; j < input_vec[i].size(); j++, k++) { + output_vec[k] = input_vec[i][j]; + } + } + return output_vec; +} + +class KmeansEmbeddingLookupTestWithSampleInputs : public ::testing::Test { + public: + KmeansEmbeddingLookupTestWithSampleInputs() { + input_ = std::vector({1, 2, 3, 0, 0}); + encoding_table_ = std::vector>( + {{0, 0}, {1, 1}, {1, 2}, {1, 0}, {1, 0}, {2, 0}}); + codebook_ = + std::vector>({{0.0, 0.0}, {7.0, 7.0}, {7.0, 0.0}}); + expected_output_ = std::vector( + // The output is the average of the embeddings at the three indices + // (1, 2, 3). + {7.0, 7.0, 4.66667, 2.33333}); + } + + protected: + std::vector input_; + std::vector> encoding_table_; + std::vector> codebook_; + std::vector expected_output_; +}; + +TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, ReturnsCorrectly) { + // Check if the expected output is returned + KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 5}, + /*encoding_table_shape=*/{6, 2}, + /*codebook_shape=*/{3, 2}, + /*output_shape=*/{1, 4}); + + m.Invoke(input_, FlattenVector2D(encoding_table_), + FlattenVector2D(codebook_)); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear(expected_output_, 1e-5))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4})); +} + +TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, + HandlesNegativeValuesInCodebook) { + KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4}, + /*encoding_table_shape=*/{4, 2}, + /*codebook_shape=*/{4, 3}, + /*output_shape=*/{1, 6}); + std::vector input = std::vector({2, 2, 1, 3}); + std::vector> encoding_table = + std::vector>({{0, 0}, {1, 2}, {3, 0}, {2, 3}}); + std::vector> codebook = std::vector>( + {{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}}); + m.Invoke(input, FlattenVector2D(encoding_table), + FlattenVector2D(codebook)); + std::vector expected_output = + std::vector({2.55, 0.1, 2.4, 2.925, 1.1, 2.65}); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear(expected_output, 1e-5))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6})); +} + +TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, IgnoresIndicesAfterZero) { + KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4}, + /*encoding_table_shape=*/{4, 2}, + /*codebook_shape=*/{4, 3}, + /*output_shape=*/{1, 6}); + std::vector input = std::vector({2, 2, 0, 3}); + std::vector> encoding_table = + std::vector>({{0, 0}, {1, 2}, {3, 0}, {2, 3}}); + std::vector> codebook = std::vector>( + {{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}}); + m.Invoke(input, FlattenVector2D(encoding_table), + FlattenVector2D(codebook)); + std::vector expected_output = + std::vector({0.5, -2.0, 1.0, 5.0, 2.0, 3.0}); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear(expected_output, 1e-5))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6})); +} + +TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) { + // Check that the op errors out when the batch size is greater than 1. + KmeansEmbeddingLookupModel m(/*input_shape=*/{2, 1}, + /*encoding_table_shape=*/{1, 1}, + /*codebook shape=*/{1, 2}, + /*output_shape=*/{2, 2}); + + EXPECT_EQ(m.InvokeUnchecked(/*input=*/{1, 1}, + /*encoding_table=*/{0}, + /*codebook=*/{2.3, 4.5}), + kTfLiteError); +} + +} // namespace +} // namespace tflite::ops::custom