//===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // A buffer for serialized results. // //===----------------------------------------------------------------------===// #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H #define LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" #include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h" #include "llvm/Support/Error.h" #include namespace llvm { namespace orc { namespace shared { // Must be kept in-sync with compiler-rt/lib/orc/c-api.h. union CWrapperFunctionResultDataUnion { char *ValuePtr; char Value[sizeof(ValuePtr)]; }; // Must be kept in-sync with compiler-rt/lib/orc/c-api.h. typedef struct { CWrapperFunctionResultDataUnion Data; size_t Size; } CWrapperFunctionResult; /// C++ wrapper function result: Same as CWrapperFunctionResult but /// auto-releases memory. class WrapperFunctionResult { public: /// Create a default WrapperFunctionResult. WrapperFunctionResult() { init(R); } /// Create a WrapperFunctionResult by taking ownership of a /// CWrapperFunctionResult. /// /// Warning: This should only be used by clients writing wrapper-function /// caller utilities (like TargetProcessControl). WrapperFunctionResult(CWrapperFunctionResult R) : R(R) { // Reset R. init(R); } WrapperFunctionResult(const WrapperFunctionResult &) = delete; WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; WrapperFunctionResult(WrapperFunctionResult &&Other) { init(R); std::swap(R, Other.R); } WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { WrapperFunctionResult Tmp(std::move(Other)); std::swap(R, Tmp.R); return *this; } ~WrapperFunctionResult() { if ((R.Size > sizeof(R.Data.Value)) || (R.Size == 0 && R.Data.ValuePtr != nullptr)) free(R.Data.ValuePtr); } /// Release ownership of the contained CWrapperFunctionResult. /// Warning: Do not use -- this method will be removed in the future. It only /// exists to temporarily support some code that will eventually be moved to /// the ORC runtime. CWrapperFunctionResult release() { CWrapperFunctionResult Tmp; init(Tmp); std::swap(R, Tmp); return Tmp; } /// Get a pointer to the data contained in this instance. char *data() { assert((R.Size != 0 || R.Data.ValuePtr == nullptr) && "Cannot get data for out-of-band error value"); return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value; } /// Get a const pointer to the data contained in this instance. const char *data() const { assert((R.Size != 0 || R.Data.ValuePtr == nullptr) && "Cannot get data for out-of-band error value"); return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value; } /// Returns the size of the data contained in this instance. size_t size() const { assert((R.Size != 0 || R.Data.ValuePtr == nullptr) && "Cannot get data for out-of-band error value"); return R.Size; } /// Returns true if this value is equivalent to a default-constructed /// WrapperFunctionResult. bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; } /// Create a WrapperFunctionResult with the given size and return a pointer /// to the underlying memory. static WrapperFunctionResult allocate(size_t Size) { // Reset. WrapperFunctionResult WFR; WFR.R.Size = Size; if (WFR.R.Size > sizeof(WFR.R.Data.Value)) WFR.R.Data.ValuePtr = (char *)malloc(WFR.R.Size); return WFR; } /// Copy from the given char range. static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { auto WFR = allocate(Size); memcpy(WFR.data(), Source, Size); return WFR; } /// Copy from the given null-terminated string (includes the null-terminator). static WrapperFunctionResult copyFrom(const char *Source) { return copyFrom(Source, strlen(Source) + 1); } /// Copy from the given std::string (includes the null terminator). static WrapperFunctionResult copyFrom(const std::string &Source) { return copyFrom(Source.c_str()); } /// Create an out-of-band error by copying the given string. static WrapperFunctionResult createOutOfBandError(const char *Msg) { // Reset. WrapperFunctionResult WFR; char *Tmp = (char *)malloc(strlen(Msg) + 1); strcpy(Tmp, Msg); WFR.R.Data.ValuePtr = Tmp; return WFR; } /// Create an out-of-band error by copying the given string. static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { return createOutOfBandError(Msg.c_str()); } /// If this value is an out-of-band error then this returns the error message, /// otherwise returns nullptr. const char *getOutOfBandError() const { return R.Size == 0 ? R.Data.ValuePtr : nullptr; } private: static void init(CWrapperFunctionResult &R) { R.Data.ValuePtr = nullptr; R.Size = 0; } CWrapperFunctionResult R; }; namespace detail { template WrapperFunctionResult serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...)); SPSOutputBuffer OB(Result.data(), Result.size()); if (!SPSArgListT::serialize(OB, Args...)) return WrapperFunctionResult::createOutOfBandError( "Error serializing arguments to blob in call"); return Result; } template class WrapperFunctionHandlerCaller { public: template static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, std::index_sequence) { return std::forward(H)(std::get(Args)...); } }; template <> class WrapperFunctionHandlerCaller { public: template static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, std::index_sequence) { std::forward(H)(std::get(Args)...); return SPSEmpty(); } }; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper< decltype(&std::remove_reference_t::operator()), ResultSerializer, SPSTagTs...> {}; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper { public: using ArgTuple = std::tuple...>; using ArgIndices = std::make_index_sequence::value>; template static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, size_t ArgSize) { ArgTuple Args; if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) return WrapperFunctionResult::createOutOfBandError( "Could not deserialize arguments for wrapper function call"); auto HandlerResult = WrapperFunctionHandlerCaller::call( std::forward(H), Args, ArgIndices{}); return ResultSerializer::serialize( std::move(HandlerResult)); } private: template static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, std::index_sequence) { SPSInputBuffer IB(ArgData, ArgSize); return SPSArgList::deserialize(IB, std::get(Args)...); } }; // Map function pointers to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; // Map non-const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; // Map const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionAsyncHandlerHelper : public WrapperFunctionAsyncHandlerHelper< decltype(&std::remove_reference_t::operator()), ResultSerializer, SPSTagTs...> {}; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionAsyncHandlerHelper { public: using ArgTuple = std::tuple...>; using ArgIndices = std::make_index_sequence::value>; template static void applyAsync(HandlerT &&H, SendWrapperFunctionResultT &&SendWrapperFunctionResult, const char *ArgData, size_t ArgSize) { ArgTuple Args; if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) { SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError( "Could not deserialize arguments for wrapper function call")); return; } auto SendResult = [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable { using ResultT = decltype(Result); SendWFR(ResultSerializer::serialize(std::move(Result))); }; callAsync(std::forward(H), std::move(SendResult), std::move(Args), ArgIndices{}); } private: template static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, std::index_sequence) { SPSInputBuffer IB(ArgData, ArgSize); return SPSArgList::deserialize(IB, std::get(Args)...); } template static void callAsync(HandlerT &&H, SerializeAndSendResultT &&SerializeAndSendResult, ArgTupleT Args, std::index_sequence) { (void)Args; // Silence a buggy GCC warning. return std::forward(H)(std::move(SerializeAndSendResult), std::move(std::get(Args))...); } }; // Map function pointers to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionAsyncHandlerHelper : public WrapperFunctionAsyncHandlerHelper {}; // Map non-const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionAsyncHandlerHelper : public WrapperFunctionAsyncHandlerHelper {}; // Map const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionAsyncHandlerHelper : public WrapperFunctionAsyncHandlerHelper {}; template class ResultSerializer { public: static WrapperFunctionResult serialize(RetT Result) { return serializeViaSPSToWrapperFunctionResult>( Result); } }; template class ResultSerializer { public: static WrapperFunctionResult serialize(Error Err) { return serializeViaSPSToWrapperFunctionResult>( toSPSSerializable(std::move(Err))); } }; template class ResultSerializer { public: static WrapperFunctionResult serialize(ErrorSuccess Err) { return serializeViaSPSToWrapperFunctionResult>( toSPSSerializable(std::move(Err))); } }; template class ResultSerializer> { public: static WrapperFunctionResult serialize(Expected E) { return serializeViaSPSToWrapperFunctionResult>( toSPSSerializable(std::move(E))); } }; template class ResultDeserializer { public: static RetT makeValue() { return RetT(); } static void makeSafe(RetT &Result) {} static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); if (!SPSArgList::deserialize(IB, Result)) return make_error( "Error deserializing return value from blob in call", inconvertibleErrorCode()); return Error::success(); } }; template <> class ResultDeserializer { public: static Error makeValue() { return Error::success(); } static void makeSafe(Error &Err) { cantFail(std::move(Err)); } static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); SPSSerializableError BSE; if (!SPSArgList::deserialize(IB, BSE)) return make_error( "Error deserializing return value from blob in call", inconvertibleErrorCode()); Err = fromSPSSerializable(std::move(BSE)); return Error::success(); } }; template class ResultDeserializer, Expected> { public: static Expected makeValue() { return T(); } static void makeSafe(Expected &E) { cantFail(E.takeError()); } static Error deserialize(Expected &E, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); SPSSerializableExpected BSE; if (!SPSArgList>::deserialize(IB, BSE)) return make_error( "Error deserializing return value from blob in call", inconvertibleErrorCode()); E = fromSPSSerializable(std::move(BSE)); return Error::success(); } }; template class AsyncCallResultHelper { // Did you forget to use Error / Expected in your handler? }; } // end namespace detail template class WrapperFunction; template class WrapperFunction { private: template using ResultSerializer = detail::ResultSerializer; public: /// Call a wrapper function. Caller should be callable as /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize); template static Error call(const CallerFn &Caller, RetT &Result, const ArgTs &...Args) { // RetT might be an Error or Expected value. Set the checked flag now: // we don't want the user to have to check the unused result if this // operation fails. detail::ResultDeserializer::makeSafe(Result); auto ArgBuffer = detail::serializeViaSPSToWrapperFunctionResult>( Args...); if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) return make_error(ErrMsg, inconvertibleErrorCode()); WrapperFunctionResult ResultBuffer = Caller(ArgBuffer.data(), ArgBuffer.size()); if (auto ErrMsg = ResultBuffer.getOutOfBandError()) return make_error(ErrMsg, inconvertibleErrorCode()); return detail::ResultDeserializer::deserialize( Result, ResultBuffer.data(), ResultBuffer.size()); } /// Call an async wrapper function. /// Caller should be callable as /// void Fn(unique_function SendResult, /// WrapperFunctionResult ArgBuffer); template static void callAsync(AsyncCallerFn &&Caller, SendDeserializedResultFn &&SendDeserializedResult, const ArgTs &...Args) { using RetT = typename std::tuple_element< 1, typename detail::WrapperFunctionHandlerHelper< std::remove_reference_t, ResultSerializer, SPSRetTagT>::ArgTuple>::type; auto ArgBuffer = detail::serializeViaSPSToWrapperFunctionResult>( Args...); if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) { SendDeserializedResult( make_error(ErrMsg, inconvertibleErrorCode()), detail::ResultDeserializer::makeValue()); return; } auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)]( WrapperFunctionResult R) mutable { RetT RetVal = detail::ResultDeserializer::makeValue(); detail::ResultDeserializer::makeSafe(RetVal); if (auto *ErrMsg = R.getOutOfBandError()) { SDR(make_error(ErrMsg, inconvertibleErrorCode()), std::move(RetVal)); return; } SPSInputBuffer IB(R.data(), R.size()); if (auto Err = detail::ResultDeserializer::deserialize( RetVal, R.data(), R.size())) SDR(std::move(Err), std::move(RetVal)); SDR(Error::success(), std::move(RetVal)); }; Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size()); } /// Handle a call to a wrapper function. template static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, HandlerT &&Handler) { using WFHH = detail::WrapperFunctionHandlerHelper, ResultSerializer, SPSTagTs...>; return WFHH::apply(std::forward(Handler), ArgData, ArgSize); } /// Handle a call to an async wrapper function. template static void handleAsync(const char *ArgData, size_t ArgSize, HandlerT &&Handler, SendResultT &&SendResult) { using WFAHH = detail::WrapperFunctionAsyncHandlerHelper< std::remove_reference_t, ResultSerializer, SPSTagTs...>; WFAHH::applyAsync(std::forward(Handler), std::forward(SendResult), ArgData, ArgSize); } private: template static const T &makeSerializable(const T &Value) { return Value; } static detail::SPSSerializableError makeSerializable(Error Err) { return detail::toSPSSerializable(std::move(Err)); } template static detail::SPSSerializableExpected makeSerializable(Expected E) { return detail::toSPSSerializable(std::move(E)); } }; template class WrapperFunction : private WrapperFunction { public: template static Error call(const CallerFn &Caller, const ArgTs &...Args) { SPSEmpty BE; return WrapperFunction::call(Caller, BE, Args...); } template static void callAsync(AsyncCallerFn &&Caller, SendDeserializedResultFn &&SendDeserializedResult, const ArgTs &...Args) { WrapperFunction::callAsync( std::forward(Caller), [SDR = std::move(SendDeserializedResult)](Error SerializeErr, SPSEmpty E) mutable { SDR(std::move(SerializeErr)); }, Args...); } using WrapperFunction::handle; using WrapperFunction::handleAsync; }; /// A function object that takes an ExecutorAddr as its first argument, /// casts that address to a ClassT*, then calls the given method on that /// pointer passing in the remaining function arguments. This utility /// removes some of the boilerplate from writing wrappers for method calls. /// /// @code{.cpp} /// class MyClass { /// public: /// void myMethod(uint32_t, bool) { ... } /// }; /// /// // SPS Method signature -- note MyClass object address as first argument. /// using SPSMyMethodWrapperSignature = /// SPSTuple; /// /// WrapperFunctionResult /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { /// return WrapperFunction::handle( /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); /// } /// @endcode /// template class MethodWrapperHandler { public: using MethodT = RetT (ClassT::*)(ArgTs...); MethodWrapperHandler(MethodT M) : M(M) {} RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { return (ObjAddr.toPtr()->*M)(std::forward(Args)...); } private: MethodT M; }; /// Create a MethodWrapperHandler object from the given method pointer. template MethodWrapperHandler makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { return MethodWrapperHandler(Method); } /// Represents a serialized wrapper function call. /// Serializing calls themselves allows us to batch them: We can make one /// "run-wrapper-functions" utility and send it a list of calls to run. /// /// The motivating use-case for this API is JITLink allocation actions, where /// we want to run multiple functions to finalize linked memory without having /// to make separate IPC calls for each one. class WrapperFunctionCall { public: using ArgDataBufferType = SmallVector; /// Create a WrapperFunctionCall using the given SPS serializer to serialize /// the arguments. template static Expected Create(ExecutorAddr FnAddr, const ArgTs &...Args) { ArgDataBufferType ArgData; ArgData.resize(SPSSerializer::size(Args...)); SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), ArgData.size()); if (SPSSerializer::serialize(OB, Args...)) return WrapperFunctionCall(FnAddr, std::move(ArgData)); return make_error("Cannot serialize arguments for " "AllocActionCall", inconvertibleErrorCode()); } WrapperFunctionCall() = default; /// Create a WrapperFunctionCall from a target function and arg buffer. WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} /// Returns the address to be called. const ExecutorAddr &getCallee() const { return FnAddr; } /// Returns the argument data. const ArgDataBufferType &getArgData() const { return ArgData; } /// WrapperFunctionCalls convert to true if the callee is non-null. explicit operator bool() const { return !!FnAddr; } /// Run call returning raw WrapperFunctionResult. shared::WrapperFunctionResult run() const { using FnTy = shared::CWrapperFunctionResult(const char *ArgData, size_t ArgSize); return shared::WrapperFunctionResult( FnAddr.toPtr()(ArgData.data(), ArgData.size())); } /// Run call and deserialize result using SPS. template std::enable_if_t::value, Error> runWithSPSRet(RetT &RetVal) const { auto WFR = run(); if (const char *ErrMsg = WFR.getOutOfBandError()) return make_error(ErrMsg, inconvertibleErrorCode()); shared::SPSInputBuffer IB(WFR.data(), WFR.size()); if (!shared::SPSSerializationTraits::deserialize(IB, RetVal)) return make_error("Could not deserialize result from " "serialized wrapper function call", inconvertibleErrorCode()); return Error::success(); } /// Overload for SPS functions returning void. template std::enable_if_t::value, Error> runWithSPSRet() const { shared::SPSEmpty E; return runWithSPSRet(E); } /// Run call and deserialize an SPSError result. SPSError returns and /// deserialization failures are merged into the returned error. Error runWithSPSRetErrorMerged() const { detail::SPSSerializableError RetErr; if (auto Err = runWithSPSRet(RetErr)) return Err; return detail::fromSPSSerializable(std::move(RetErr)); } private: orc::ExecutorAddr FnAddr; ArgDataBufferType ArgData; }; using SPSWrapperFunctionCall = SPSTuple>; template <> class SPSSerializationTraits { public: static size_t size(const WrapperFunctionCall &WFC) { return SPSWrapperFunctionCall::AsArgList::size(WFC.getCallee(), WFC.getArgData()); } static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { return SPSWrapperFunctionCall::AsArgList::serialize(OB, WFC.getCallee(), WFC.getArgData()); } static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { ExecutorAddr FnAddr; WrapperFunctionCall::ArgDataBufferType ArgData; if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) return false; WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); return true; } }; } // end namespace shared } // end namespace orc } // end namespace llvm #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H