Home | History | Annotate | Line # | Download | only in Shared
      1 //===- RawByteChannel.h -----------------------------------------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 
      9 #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H
     10 #define LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H
     11 
     12 #include "llvm/ADT/StringRef.h"
     13 #include "llvm/ExecutionEngine/Orc/Shared/Serialization.h"
     14 #include "llvm/Support/Endian.h"
     15 #include "llvm/Support/Error.h"
     16 #include <cstdint>
     17 #include <mutex>
     18 #include <string>
     19 #include <type_traits>
     20 
     21 namespace llvm {
     22 namespace orc {
     23 namespace shared {
     24 
     25 /// Interface for byte-streams to be used with ORC Serialization.
     26 class RawByteChannel {
     27 public:
     28   virtual ~RawByteChannel() = default;
     29 
     30   /// Read Size bytes from the stream into *Dst.
     31   virtual Error readBytes(char *Dst, unsigned Size) = 0;
     32 
     33   /// Read size bytes from *Src and append them to the stream.
     34   virtual Error appendBytes(const char *Src, unsigned Size) = 0;
     35 
     36   /// Flush the stream if possible.
     37   virtual Error send() = 0;
     38 
     39   /// Notify the channel that we're starting a message send.
     40   /// Locks the channel for writing.
     41   template <typename FunctionIdT, typename SequenceIdT>
     42   Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
     43     writeLock.lock();
     44     if (auto Err = serializeSeq(*this, FnId, SeqNo)) {
     45       writeLock.unlock();
     46       return Err;
     47     }
     48     return Error::success();
     49   }
     50 
     51   /// Notify the channel that we're ending a message send.
     52   /// Unlocks the channel for writing.
     53   Error endSendMessage() {
     54     writeLock.unlock();
     55     return Error::success();
     56   }
     57 
     58   /// Notify the channel that we're starting a message receive.
     59   /// Locks the channel for reading.
     60   template <typename FunctionIdT, typename SequenceNumberT>
     61   Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
     62     readLock.lock();
     63     if (auto Err = deserializeSeq(*this, FnId, SeqNo)) {
     64       readLock.unlock();
     65       return Err;
     66     }
     67     return Error::success();
     68   }
     69 
     70   /// Notify the channel that we're ending a message receive.
     71   /// Unlocks the channel for reading.
     72   Error endReceiveMessage() {
     73     readLock.unlock();
     74     return Error::success();
     75   }
     76 
     77   /// Get the lock for stream reading.
     78   std::mutex &getReadLock() { return readLock; }
     79 
     80   /// Get the lock for stream writing.
     81   std::mutex &getWriteLock() { return writeLock; }
     82 
     83 private:
     84   std::mutex readLock, writeLock;
     85 };
     86 
     87 template <typename ChannelT, typename T>
     88 class SerializationTraits<
     89     ChannelT, T, T,
     90     std::enable_if_t<
     91         std::is_base_of<RawByteChannel, ChannelT>::value &&
     92         (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
     93          std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
     94          std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
     95          std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
     96          std::is_same<T, char>::value)>> {
     97 public:
     98   static Error serialize(ChannelT &C, T V) {
     99     support::endian::byte_swap<T, support::big>(V);
    100     return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
    101   };
    102 
    103   static Error deserialize(ChannelT &C, T &V) {
    104     if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
    105       return Err;
    106     support::endian::byte_swap<T, support::big>(V);
    107     return Error::success();
    108   };
    109 };
    110 
    111 template <typename ChannelT>
    112 class SerializationTraits<
    113     ChannelT, bool, bool,
    114     std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> {
    115 public:
    116   static Error serialize(ChannelT &C, bool V) {
    117     uint8_t Tmp = V ? 1 : 0;
    118     if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1))
    119       return Err;
    120     return Error::success();
    121   }
    122 
    123   static Error deserialize(ChannelT &C, bool &V) {
    124     uint8_t Tmp = 0;
    125     if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1))
    126       return Err;
    127     V = Tmp != 0;
    128     return Error::success();
    129   }
    130 };
    131 
    132 template <typename ChannelT>
    133 class SerializationTraits<
    134     ChannelT, std::string, StringRef,
    135     std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> {
    136 public:
    137   /// Serialization channel serialization for std::strings.
    138   static Error serialize(RawByteChannel &C, StringRef S) {
    139     if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size())))
    140       return Err;
    141     return C.appendBytes((const char *)S.data(), S.size());
    142   }
    143 };
    144 
    145 template <typename ChannelT, typename T>
    146 class SerializationTraits<
    147     ChannelT, std::string, T,
    148     std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value &&
    149                      (std::is_same<T, const char *>::value ||
    150                       std::is_same<T, char *>::value)>> {
    151 public:
    152   static Error serialize(RawByteChannel &C, const char *S) {
    153     return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
    154                                                                             S);
    155   }
    156 };
    157 
    158 template <typename ChannelT>
    159 class SerializationTraits<
    160     ChannelT, std::string, std::string,
    161     std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> {
    162 public:
    163   /// Serialization channel serialization for std::strings.
    164   static Error serialize(RawByteChannel &C, const std::string &S) {
    165     return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
    166                                                                             S);
    167   }
    168 
    169   /// Serialization channel deserialization for std::strings.
    170   static Error deserialize(RawByteChannel &C, std::string &S) {
    171     uint64_t Count = 0;
    172     if (auto Err = deserializeSeq(C, Count))
    173       return Err;
    174     S.resize(Count);
    175     return C.readBytes(&S[0], Count);
    176   }
    177 };
    178 
    179 } // end namespace shared
    180 } // end namespace orc
    181 } // end namespace llvm
    182 
    183 #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H
    184