Remove unsafe cast.

PiperOrigin-RevId: 555007705
This commit is contained in:
MediaPipe Team 2023-08-08 18:47:42 -07:00 committed by Copybara-Service
parent f9a0244c5b
commit 00e0314040
3 changed files with 22 additions and 13 deletions

View File

@ -18,6 +18,7 @@
#define MEDIAPIPE_FRAMEWORK_PACKET_H_ #define MEDIAPIPE_FRAMEWORK_PACKET_H_
#include <cstddef> #include <cstddef>
#include <cstdint>
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
@ -368,11 +369,14 @@ class HolderBase {
} }
// Returns a printable string identifying the type stored in the holder. // Returns a printable string identifying the type stored in the holder.
virtual const std::string DebugTypeName() const = 0; virtual const std::string DebugTypeName() const = 0;
// Returns debug data id.
virtual int64_t DebugDataId() const = 0;
// Returns the registered type name if it's available, otherwise the // Returns the registered type name if it's available, otherwise the
// empty string. // empty string.
virtual const std::string RegisteredTypeName() const = 0; virtual const std::string RegisteredTypeName() const = 0;
// Get the type id of the underlying data type. // Get the type id of the underlying data type.
virtual TypeId GetTypeId() const = 0; virtual TypeId GetTypeId() const = 0;
// Downcasts this to Holder<T>. Returns nullptr if deserialization // Downcasts this to Holder<T>. Returns nullptr if deserialization
// failed or if the requested type is not what is stored. // failed or if the requested type is not what is stored.
template <typename T> template <typename T>
@ -534,6 +538,7 @@ class Holder : public HolderBase {
const std::string DebugTypeName() const final { const std::string DebugTypeName() const final {
return MediaPipeTypeStringOrDemangled<T>(); return MediaPipeTypeStringOrDemangled<T>();
} }
int64_t DebugDataId() const final { return reinterpret_cast<int64_t>(ptr_); }
const std::string RegisteredTypeName() const final { const std::string RegisteredTypeName() const final {
const std::string* type_string = MediaPipeTypeString<T>(); const std::string* type_string = MediaPipeTypeString<T>();
if (type_string) { if (type_string) {

View File

@ -1423,5 +1423,13 @@ TEST_F(GraphTracerE2ETest, DestructGraph) {
} }
} }
TEST(TraceBuilderTest, EventDataIsExtracted) {
int value = 10;
Packet p = PointToForeign(&value);
TraceEvent event;
event.set_packet_data_id(&p);
EXPECT_EQ(event.event_data, reinterpret_cast<int64_t>(&value));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,6 +15,9 @@
#ifndef MEDIAPIPE_FRAMEWORK_PROFILER_TRACE_BUFFER_H_ #ifndef MEDIAPIPE_FRAMEWORK_PROFILER_TRACE_BUFFER_H_
#define MEDIAPIPE_FRAMEWORK_PROFILER_TRACE_BUFFER_H_ #define MEDIAPIPE_FRAMEWORK_PROFILER_TRACE_BUFFER_H_
#include <cstdint>
#include <string>
#include "absl/time/time.h" #include "absl/time/time.h"
#include "mediapipe/framework/calculator_profile.pb.h" #include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
@ -23,17 +26,6 @@
namespace mediapipe { namespace mediapipe {
namespace packet_internal {
// Returns a hash of the packet data address from a packet data holder.
inline const int64 GetPacketDataId(const HolderBase* holder) {
if (holder == nullptr) {
return 0;
}
const void* address = &(static_cast<const Holder<int>*>(holder)->data());
return reinterpret_cast<int64>(address);
}
} // namespace packet_internal
// Packet trace log event. // Packet trace log event.
struct TraceEvent { struct TraceEvent {
using EventType = GraphTrace::EventType; using EventType = GraphTrace::EventType;
@ -75,8 +67,12 @@ struct TraceEvent {
return *this; return *this;
} }
inline TraceEvent& set_packet_data_id(const Packet* packet) { inline TraceEvent& set_packet_data_id(const Packet* packet) {
this->event_data = const auto* holder = packet_internal::GetHolder(*packet);
packet_internal::GetPacketDataId(packet_internal::GetHolder(*packet)); int64_t data_id = 0;
if (holder != nullptr) {
data_id = holder->DebugDataId();
}
this->event_data = data_id;
return *this; return *this;
} }
inline TraceEvent& set_thread_id(int thread_id) { inline TraceEvent& set_thread_id(int thread_id) {