1 //===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utilities to support construction of simple RPC APIs. 10 // 11 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ 12 // programmers, high performance, low memory overhead, and efficient use of the 13 // communications channel. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H 18 #define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H 19 20 #include <map> 21 #include <thread> 22 #include <vector> 23 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" 26 #include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" 27 #include "llvm/Support/MSVCErrorWorkarounds.h" 28 29 #include <future> 30 31 namespace llvm { 32 namespace orc { 33 namespace shared { 34 35 /// Base class of all fatal RPC errors (those that necessarily result in the 36 /// termination of the RPC session). 37 class RPCFatalError : public ErrorInfo<RPCFatalError> { 38 public: 39 static char ID; 40 }; 41 42 /// RPCConnectionClosed is returned from RPC operations if the RPC connection 43 /// has already been closed due to either an error or graceful disconnection. 44 class ConnectionClosed : public ErrorInfo<ConnectionClosed> { 45 public: 46 static char ID; 47 std::error_code convertToErrorCode() const override; 48 void log(raw_ostream &OS) const override; 49 }; 50 51 /// BadFunctionCall is returned from handleOne when the remote makes a call with 52 /// an unrecognized function id. 
53 /// 54 /// This error is fatal because Orc RPC needs to know how to parse a function 55 /// call to know where the next call starts, and if it doesn't recognize the 56 /// function id it cannot parse the call. 57 template <typename FnIdT, typename SeqNoT> 58 class BadFunctionCall 59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { 60 public: 61 static char ID; 62 63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) 64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} 65 66 std::error_code convertToErrorCode() const override { 67 return orcError(OrcErrorCode::UnexpectedRPCCall); 68 } 69 70 void log(raw_ostream &OS) const override { 71 OS << "Call to invalid RPC function id '" << FnId 72 << "' with " 73 "sequence number " 74 << SeqNo; 75 } 76 77 private: 78 FnIdT FnId; 79 SeqNoT SeqNo; 80 }; 81 82 template <typename FnIdT, typename SeqNoT> 83 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; 84 85 /// InvalidSequenceNumberForResponse is returned from handleOne when a response 86 /// call arrives with a sequence number that doesn't correspond to any in-flight 87 /// function call. 88 /// 89 /// This error is fatal because Orc RPC needs to know how to parse the rest of 90 /// the response call to know where the next call starts, and if it doesn't have 91 /// a result parser for this sequence number it can't do that. 
92 template <typename SeqNoT> 93 class InvalidSequenceNumberForResponse 94 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, 95 RPCFatalError> { 96 public: 97 static char ID; 98 99 InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {} 100 101 std::error_code convertToErrorCode() const override { 102 return orcError(OrcErrorCode::UnexpectedRPCCall); 103 }; 104 105 void log(raw_ostream &OS) const override { 106 OS << "Response has unknown sequence number " << SeqNo; 107 } 108 109 private: 110 SeqNoT SeqNo; 111 }; 112 113 template <typename SeqNoT> 114 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; 115 116 /// This non-fatal error will be passed to asynchronous result handlers in place 117 /// of a result if the connection goes down before a result returns, or if the 118 /// function to be called cannot be negotiated with the remote. 119 class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { 120 public: 121 static char ID; 122 123 std::error_code convertToErrorCode() const override; 124 void log(raw_ostream &OS) const override; 125 }; 126 127 /// This error is returned if the remote does not have a handler installed for 128 /// the given RPC function. 129 class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { 130 public: 131 static char ID; 132 133 CouldNotNegotiate(std::string Signature); 134 std::error_code convertToErrorCode() const override; 135 void log(raw_ostream &OS) const override; 136 const std::string &getSignature() const { return Signature; } 137 138 private: 139 std::string Signature; 140 }; 141 142 template <typename DerivedFunc, typename FnT> class RPCFunction; 143 144 // RPC Function class. 145 // DerivedFunc should be a user defined class with a static 'getName()' method 146 // returning a const char* representing the function's name. 147 template <typename DerivedFunc, typename RetT, typename... 
ArgTs> 148 class RPCFunction<DerivedFunc, RetT(ArgTs...)> { 149 public: 150 /// User defined function type. 151 using Type = RetT(ArgTs...); 152 153 /// Return type. 154 using ReturnType = RetT; 155 156 /// Returns the full function prototype as a string. 157 static const char *getPrototype() { 158 static std::string Name = [] { 159 std::string Name; 160 raw_string_ostream(Name) 161 << SerializationTypeName<RetT>::getName() << " " 162 << DerivedFunc::getName() << "(" 163 << SerializationTypeNameSequence<ArgTs...>() << ")"; 164 return Name; 165 }(); 166 return Name.data(); 167 } 168 }; 169 170 /// Allocates RPC function ids during autonegotiation. 171 /// Specializations of this class must provide four members: 172 /// 173 /// static T getInvalidId(): 174 /// Should return a reserved id that will be used to represent missing 175 /// functions during autonegotiation. 176 /// 177 /// static T getResponseId(): 178 /// Should return a reserved id that will be used to send function responses 179 /// (return values). 180 /// 181 /// static T getNegotiateId(): 182 /// Should return a reserved id for the negotiate function, which will be used 183 /// to negotiate ids for user defined functions. 184 /// 185 /// template <typename Func> T allocate(): 186 /// Allocate a unique id for function Func. 187 template <typename T, typename = void> class RPCFunctionIdAllocator; 188 189 /// This specialization of RPCFunctionIdAllocator provides a default 190 /// implementation for integral types. 191 template <typename T> 192 class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> { 193 public: 194 static T getInvalidId() { return T(0); } 195 static T getResponseId() { return T(1); } 196 static T getNegotiateId() { return T(2); } 197 198 template <typename Func> T allocate() { return NextId++; } 199 200 private: 201 T NextId = 3; 202 }; 203 204 namespace detail { 205 206 /// Provides a typedef for a tuple containing the decayed argument types. 
207 template <typename T> class RPCFunctionArgsTuple; 208 209 template <typename RetT, typename... ArgTs> 210 class RPCFunctionArgsTuple<RetT(ArgTs...)> { 211 public: 212 using Type = std::tuple<std::decay_t<std::remove_reference_t<ArgTs>>...>; 213 }; 214 215 // ResultTraits provides typedefs and utilities specific to the return type 216 // of functions. 217 template <typename RetT> class ResultTraits { 218 public: 219 // The return type wrapped in llvm::Expected. 220 using ErrorReturnType = Expected<RetT>; 221 222 #ifdef _MSC_VER 223 // The ErrorReturnType wrapped in a std::promise. 224 using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>; 225 226 // The ErrorReturnType wrapped in a std::future. 227 using ReturnFutureType = std::future<MSVCPExpected<RetT>>; 228 #else 229 // The ErrorReturnType wrapped in a std::promise. 230 using ReturnPromiseType = std::promise<ErrorReturnType>; 231 232 // The ErrorReturnType wrapped in a std::future. 233 using ReturnFutureType = std::future<ErrorReturnType>; 234 #endif 235 236 // Create a 'blank' value of the ErrorReturnType, ready and safe to 237 // overwrite. 238 static ErrorReturnType createBlankErrorReturnValue() { 239 return ErrorReturnType(RetT()); 240 } 241 242 // Consume an abandoned ErrorReturnType. 243 static void consumeAbandoned(ErrorReturnType RetOrErr) { 244 consumeError(RetOrErr.takeError()); 245 } 246 247 static ErrorReturnType returnError(Error Err) { return std::move(Err); } 248 }; 249 250 // ResultTraits specialization for void functions. 251 template <> class ResultTraits<void> { 252 public: 253 // For void functions, ErrorReturnType is llvm::Error. 254 using ErrorReturnType = Error; 255 256 #ifdef _MSC_VER 257 // The ErrorReturnType wrapped in a std::promise. 258 using ReturnPromiseType = std::promise<MSVCPError>; 259 260 // The ErrorReturnType wrapped in a std::future. 261 using ReturnFutureType = std::future<MSVCPError>; 262 #else 263 // The ErrorReturnType wrapped in a std::promise. 
264 using ReturnPromiseType = std::promise<ErrorReturnType>; 265 266 // The ErrorReturnType wrapped in a std::future. 267 using ReturnFutureType = std::future<ErrorReturnType>; 268 #endif 269 270 // Create a 'blank' value of the ErrorReturnType, ready and safe to 271 // overwrite. 272 static ErrorReturnType createBlankErrorReturnValue() { 273 return ErrorReturnType::success(); 274 } 275 276 // Consume an abandoned ErrorReturnType. 277 static void consumeAbandoned(ErrorReturnType Err) { 278 consumeError(std::move(Err)); 279 } 280 281 static ErrorReturnType returnError(Error Err) { return Err; } 282 }; 283 284 // ResultTraits<Error> is equivalent to ResultTraits<void>. This allows 285 // handlers for void RPC functions to return either void (in which case they 286 // implicitly succeed) or Error (in which case their error return is 287 // propagated). See usage in HandlerTraits::runHandlerHelper. 288 template <> class ResultTraits<Error> : public ResultTraits<void> {}; 289 290 // ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows 291 // handlers for RPC functions returning a T to return either a T (in which 292 // case they implicitly succeed) or Expected<T> (in which case their error 293 // return is propagated). See usage in HandlerTraits::runHandlerHelper. 294 template <typename RetT> 295 class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; 296 297 // Determines whether an RPC function's defined error return type supports 298 // error return value. 299 template <typename T> class SupportsErrorReturn { 300 public: 301 static const bool value = false; 302 }; 303 304 template <> class SupportsErrorReturn<Error> { 305 public: 306 static const bool value = true; 307 }; 308 309 template <typename T> class SupportsErrorReturn<Expected<T>> { 310 public: 311 static const bool value = true; 312 }; 313 314 // RespondHelper packages return values based on whether or not the declared 315 // RPC function return type supports error returns. 
316 template <bool FuncSupportsErrorReturn> class RespondHelper; 317 318 // RespondHelper specialization for functions that support error returns. 319 template <> class RespondHelper<true> { 320 public: 321 // Send Expected<T>. 322 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 323 typename FunctionIdT, typename SequenceNumberT> 324 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 325 SequenceNumberT SeqNo, 326 Expected<HandlerRetT> ResultOrErr) { 327 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) 328 return ResultOrErr.takeError(); 329 330 // Open the response message. 331 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 332 return Err; 333 334 // Serialize the result. 335 if (auto Err = 336 SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>:: 337 serialize(C, std::move(ResultOrErr))) 338 return Err; 339 340 // Close the response message. 341 if (auto Err = C.endSendMessage()) 342 return Err; 343 return C.send(); 344 } 345 346 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 347 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 348 SequenceNumberT SeqNo, Error Err) { 349 if (Err && Err.isA<RPCFatalError>()) 350 return Err; 351 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 352 return Err2; 353 if (auto Err2 = serializeSeq(C, std::move(Err))) 354 return Err2; 355 if (auto Err2 = C.endSendMessage()) 356 return Err2; 357 return C.send(); 358 } 359 }; 360 361 // RespondHelper specialization for functions that do not support error returns. 
362 template <> class RespondHelper<false> { 363 public: 364 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 365 typename FunctionIdT, typename SequenceNumberT> 366 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 367 SequenceNumberT SeqNo, 368 Expected<HandlerRetT> ResultOrErr) { 369 if (auto Err = ResultOrErr.takeError()) 370 return Err; 371 372 // Open the response message. 373 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 374 return Err; 375 376 // Serialize the result. 377 if (auto Err = 378 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( 379 C, *ResultOrErr)) 380 return Err; 381 382 // End the response message. 383 if (auto Err = C.endSendMessage()) 384 return Err; 385 386 return C.send(); 387 } 388 389 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 390 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 391 SequenceNumberT SeqNo, Error Err) { 392 if (Err) 393 return Err; 394 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 395 return Err2; 396 if (auto Err2 = C.endSendMessage()) 397 return Err2; 398 return C.send(); 399 } 400 }; 401 402 // Send a response of the given wire return type (WireRetT) over the 403 // channel, with the given sequence number. 404 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 405 typename FunctionIdT, typename SequenceNumberT> 406 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, 407 Expected<HandlerRetT> ResultOrErr) { 408 return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: 409 template sendResult<WireRetT>(C, ResponseId, SeqNo, 410 std::move(ResultOrErr)); 411 } 412 413 // Send an empty response message on the given channel to indicate that 414 // the handler ran. 
// Dispatches to the RespondHelper specialization selected by whether WireRetT
// supports error returns (SupportsErrorReturn<WireRetT>::value).
template <typename WireRetT, typename HandlerRetT, typename ChannelT,
          typename FunctionIdT, typename SequenceNumberT>
Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
              Expected<HandlerRetT> ResultOrErr) {
  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
      template sendResult<WireRetT>(C, ResponseId, SeqNo,
                                    std::move(ResultOrErr));
}

// Send a response for a handler that produced an Error rather than a value.
// As above, the work is done by the RespondHelper specialization selected by
// SupportsErrorReturn<WireRetT>::value.
template <typename WireRetT, typename ChannelT, typename FunctionIdT,
          typename SequenceNumberT>
Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
              Error Err) {
  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult(
      C, ResponseId, SeqNo, std::move(Err));
}

// Converts a given type to the equivalent error return type:
//   T            -> Expected<T>
//   Expected<T>  -> Expected<T>
//   void         -> Error
//   Error        -> Error
//   ErrorSuccess -> Error
template <typename T> class WrappedHandlerReturn {
public:
  using Type = Expected<T>;
};

template <typename T> class WrappedHandlerReturn<Expected<T>> {
public:
  using Type = Expected<T>;
};

template <> class WrappedHandlerReturn<void> {
public:
  using Type = Error;
};

template <> class WrappedHandlerReturn<Error> {
public:
  using Type = Error;
};

template <> class WrappedHandlerReturn<ErrorSuccess> {
public:
  using Type = Error;
};

// Traits class that strips the response function from the list of handler
// arguments. Type is the handler signature without its leading response
// function parameter; ResultType is what that response function accepts.
template <typename FnT> class AsyncHandlerTraits;

// Handler returning Error whose response function takes an Expected<ResultT>.
template <typename ResultT, typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>,
                               ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Expected<ResultT>;
};

// Handler returning Error whose response function takes an Error.
template <typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Handler returning ErrorSuccess whose response function takes an Error.
template <typename... ArgTs>
class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Handler returning void whose response function takes an Error.
template <typename... ArgTs>
class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Fallback for any other first-parameter type: decay it (dropping references
// and cv-qualifiers) and defer to the traits for the decayed signature.
template <typename ResponseHandlerT, typename... ArgTs>
class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)>
    : public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>,
                                      ArgTs...)> {};

// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
// specialization for the given non-function type's call operator.
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<
          decltype(&std::remove_reference<HandlerT>::type::operator())> {};

// Traits for handlers with a given function type.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(ArgTs...)> {
public:
  // Function type of the handler.
  using Type = RetT(ArgTs...);

  // Return type of the handler.
  using ReturnType = RetT;

  // Call the given handler with the elements of the Args tuple as its
  // arguments. The elements are moved out of the tuple.
  template <typename HandlerT, typename... TArgTs>
  static typename WrappedHandlerReturn<RetT>::Type
  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
    return unpackAndRunHelper(Handler, Args,
                              std::index_sequence_for<TArgTs...>());
  }

  // As unpackAndRun, but Responder is passed to the handler as its first
  // argument, ahead of the elements of the Args tuple.
  template <typename HandlerT, typename ResponderT, typename... TArgTs>
  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
                                 std::tuple<TArgTs...> &Args) {
    return unpackAndRunAsyncHelper(Handler, Responder, Args,
                                   std::index_sequence_for<TArgTs...>());
  }

  // Call the given handler with the given arguments. This overload is selected
  // for handlers that return void: it runs the handler and reports success.
  template <typename HandlerT>
  static std::enable_if_t<
      std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error>
  run(HandlerT &Handler, ArgTs &&...Args) {
    Handler(std::move(Args)...);
    return Error::success();
  }

  // Overload selected for handlers with a non-void return type: the handler's
  // own return value is passed back to the caller.
  template <typename HandlerT, typename... TArgTs>
  static std::enable_if_t<
      !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      typename HandlerTraits<HandlerT>::ReturnType>
  run(HandlerT &Handler, TArgTs... Args) {
    return Handler(std::move(Args)...);
  }

  // Serialize arguments to the channel, using the handler's declared argument
  // types (ArgTs) as the wire types.
  template <typename ChannelT, typename... CArgTs>
  static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
  }

  // Deserialize arguments from the channel into the elements of Args.
  template <typename ChannelT, typename... CArgTs>
  static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
    return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
  }

private:
  // Expands the tuple into individual references for deserialize(). The
  // index_sequence parameter is unused at run time; it only supplies Indexes.
  template <typename ChannelT, typename... CArgTs, size_t... Indexes>
  static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
                                     std::index_sequence<Indexes...> _) {
    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
        C, std::get<Indexes>(Args)...);
  }

  // Expands the tuple into individual (moved) arguments for run().
  template <typename HandlerT, typename ArgTuple, size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
                     std::index_sequence<Indexes...>) {
    return run(Handler, std::move(std::get<Indexes>(Args))...);
  }

  // As unpackAndRunHelper, with Responder inserted as the first argument.
  template <typename HandlerT, typename ResponderT, typename ArgTuple,
            size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
                          ArgTuple &Args, std::index_sequence<Indexes...>) {
    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
  }
};

// Handler traits for free functions.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> {
};

// Handler traits for class methods (especially call operators for lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...)>
    : public HandlerTraits<RetT(ArgTs...)> {};

// Handler traits for const class methods (especially call operators for
// lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...) const>
    : public HandlerTraits<RetT(ArgTs...)> {};

// Utility to peel the Expected wrapper off a response handler error type.
// ArgType is the handler's parameter type; UnwrappedArgType (only provided
// when the parameter is an Expected<ArgT>) is ArgT.
template <typename HandlerT> class ResponseHandlerArg;

template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
public:
  using ArgType = Expected<ArgT>;
  using UnwrappedArgType = ArgT;
};

template <typename ArgT>
class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
public:
  using ArgType = Expected<ArgT>;
  using UnwrappedArgType = ArgT;
};

template <> class ResponseHandlerArg<Error(Error)> {
public:
  using ArgType = Error;
};

template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
public:
  using ArgType = Error;
};

// ResponseHandler represents a handler for a not-yet-received function call
// result.
template <typename ChannelT> class ResponseHandler {
public:
  virtual ~ResponseHandler() {}

  // Reads the function result off the wire and acts on it. The meaning of
  // "act" will depend on how this method is implemented in any given
  // ResponseHandler subclass but could, for example, mean running a
  // user-specified handler or setting a promise value.
  virtual Error handleResponse(ChannelT &C) = 0;

  // Abandons this outstanding result.
  virtual void abandon() = 0;

  // Create an error instance representing an abandoned response.
  static Error createAbandonedResponseError() {
    return make_error<ResponseAbandoned>();
  }
};

// ResponseHandler subclass for RPC functions with non-void returns.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result by deserializing it from the channel then passing it
  // to the user defined handler. A deserialization or end-of-message failure
  // is returned directly and the handler is not called.
  Error handleResponse(ChannelT &C) override {
    using UnwrappedArgType = typename ResponseHandlerArg<
        typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
    UnwrappedArgType Result;
    if (auto Err =
            SerializationTraits<ChannelT, FuncRetT,
                                UnwrappedArgType>::deserialize(C, Result))
      return Err;
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(std::move(Result));
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// ResponseHandler subclass for RPC functions with void returns.
template <typename ChannelT, typename HandlerT>
class ResponseHandlerImpl<ChannelT, void, HandlerT>
    : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result (no actual value, just a notification that the function
  // has completed on the remote end) by calling the user-defined handler with
  // Error::success().
  Error handleResponse(ChannelT &C) override {
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(Error::success());
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// ResponseHandler subclass for RPC functions declared to return
// Expected<FuncRetT>: the handler receives the deserialized Expected itself.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
    : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result by deserializing it from the channel then passing it
  // to the user defined handler.
  Error handleResponse(ChannelT &C) override {
    using HandlerArgType = typename ResponseHandlerArg<
        typename HandlerTraits<HandlerT>::Type>::ArgType;
    // Start from an Expected holding a value-initialized value and let
    // deserialize() fill it in.
    HandlerArgType Result((typename HandlerArgType::value_type()));

    if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>,
                                       HandlerArgType>::deserialize(C, Result))
      return Err;
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(std::move(Result));
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// ResponseHandler subclass for RPC functions declared to return Error: the
// handler receives the deserialized Error.
template <typename ChannelT, typename HandlerT>
class ResponseHandlerImpl<ChannelT, Error, HandlerT>
    : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result by deserializing it from the channel then passing it
  // to the user defined handler.
  Error handleResponse(ChannelT &C) override {
    Error Result = Error::success();
    if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
            C, Result)) {
      // Discard Result before reporting the failure.
      consumeError(std::move(Result));
      return Err;
    }
    if (auto Err = C.endReceiveMessage()) {
      // Discard Result before reporting the failure.
      consumeError(std::move(Result));
      return Err;
    }
    return Handler(std::move(Result));
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// Create a ResponseHandler from a given user handler. FuncRetT selects which
// ResponseHandlerImpl variant is instantiated.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
  return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
      std::move(H));
}

// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
//
// The wrapper stores a reference to Instance (not a copy), so it must not be
// used after Instance has been destroyed.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);
  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}
  // Invoke Method on Instance, moving the arguments through.
  RetT operator()(ArgTs &&...Args) {
    return (Instance.*Method)(std::move(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};

// Helper that provides a Functor for deserializing arguments.
//
// Base case (no arguments left): nothing to store, always succeeds.
template <typename... ArgTs> class ReadArgs {
public:
  Error operator()() { return Error::success(); }
};

// Recursive case: holds a reference to one destination object and inherits the
// holders for the remaining ones. Calling the functor move-assigns each
// incoming value into its destination, first to last.
template <typename ArgT, typename... ArgTs>
class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
public:
  ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}

  Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) {
    this->Arg = std::move(ArgVal);
    return ReadArgs<ArgTs...>::operator()(ArgVals...);
  }

private:
  ArgT &Arg;
};

// Manage sequence numbers. Every member function takes SeqNoLock for its whole
// body.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Reset, making all sequence numbers available.
  void reset() {
    std::lock_guard<std::mutex> Lock(SeqNoLock);
    NextSequenceNumber = 0;
    FreeSequenceNumbers.clear();
  }

  // Get the next available sequence number. Will re-use numbers that have
  // been released (most recently released first); otherwise hands out the
  // next never-used number.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Lock(SeqNoLock);
    if (FreeSequenceNumbers.empty())
      return NextSequenceNumber++;
    auto SequenceNumber = FreeSequenceNumbers.back();
    FreeSequenceNumbers.pop_back();
    return SequenceNumber;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Lock(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};

// Checks that predicate P holds for each corresponding pair of type arguments
// from T1 and T2 tuple.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Both tuples exhausted: nothing left to check.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Check the leading pair, then recurse on the remaining elements.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
public:
  static const bool value =
      P<T, U>::value &&
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
};

// Applies predicate P pairwise to the (decayed) parameter types of two
// function signatures. The two static_asserts together require both
// signatures to have the same number of parameters, with a separate message
// for each direction of mismatch.
template <template <class, class> class P, typename T1Sig, typename T2Sig>
class RPCArgTypeCheck {
public:
  using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type;
  using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type;

  // Fires when T2Sig has more parameters than T1Sig.
  static_assert(std::tuple_size<T1Tuple>::value >=
                    std::tuple_size<T2Tuple>::value,
                "Too many arguments to RPC call");
  // Fires when T2Sig has fewer parameters than T1Sig.
  static_assert(std::tuple_size<T1Tuple>::value <=
                    std::tuple_size<T2Tuple>::value,
                "Too few arguments to RPC call");

  static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
};

// CanSerialize<ChannelT, WireT, ConcreteT>::value is true when
// SerializationTraits<ChannelT, WireT, ConcreteT> has a
// serialize(ChannelT &, const ConcreteT &) whose result type is Error.
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanSerialize {
private:
  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;

  // Viable only when the serialize call is well-formed and yields Error.
  template <typename T>
  static std::true_type check(
      std::enable_if_t<std::is_same<decltype(T::serialize(
                                        std::declval<ChannelT &>(),
                                        std::declval<const ConcreteT &>())),
                                    Error>::value,
                       void *>);

  // Chosen when the overload above is not viable.
  template <typename> static std::false_type check(...);

public:
  static const bool value = decltype(check<S>(0))::value;
};

// CanDeserialize<ChannelT, WireT, ConcreteT>::value is true when
// SerializationTraits<ChannelT, WireT, ConcreteT> has a
// deserialize(ChannelT &, ConcreteT &) whose result type is Error.
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanDeserialize {
private:
  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;

  // Viable only when the deserialize call is well-formed and yields Error.
  template <typename T>
  static std::true_type
  check(std::enable_if_t<
        std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
                                             std::declval<ConcreteT &>())),
                     Error>::value,
        void *>);

  // Chosen when the overload above is not viable.
  template <typename> static std::false_type check(...);

public:
  static const bool value = decltype(check<S>(0))::value;
};

/// Contains primitive utilities for defining, calling and handling calls to
/// remote procedures. ChannelT is a bidirectional stream conforming to the
/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
/// identifier type that must be serializable on ChannelT, and SequenceNumberT
/// is an integral type that will be used to number in-flight function calls.
///
/// These utilities support the construction of very primitive RPC utilities.
/// Their intent is to ensure correct serialization and deserialization of
/// procedure arguments, and to keep the client and server's view of the API in
/// sync.
933 template <typename ImplT, typename ChannelT, typename FunctionIdT, 934 typename SequenceNumberT> 935 class RPCEndpointBase { 936 protected: 937 class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> { 938 public: 939 static const char *getName() { return "__orc_rpc$invalid"; } 940 }; 941 942 class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> { 943 public: 944 static const char *getName() { return "__orc_rpc$response"; } 945 }; 946 947 class OrcRPCNegotiate 948 : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> { 949 public: 950 static const char *getName() { return "__orc_rpc$negotiate"; } 951 }; 952 953 // Helper predicate for testing for the presence of SerializeTraits 954 // serializers. 955 template <typename WireT, typename ConcreteT> 956 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { 957 public: 958 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; 959 960 static_assert(value, "Missing serializer for argument (Can't serialize the " 961 "first template type argument of CanSerializeCheck " 962 "from the second)"); 963 }; 964 965 // Helper predicate for testing for the presence of SerializeTraits 966 // deserializers. 967 template <typename WireT, typename ConcreteT> 968 class CanDeserializeCheck 969 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { 970 public: 971 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; 972 973 static_assert(value, "Missing deserializer for argument (Can't deserialize " 974 "the second template type argument of " 975 "CanDeserializeCheck from the first)"); 976 }; 977 978 public: 979 /// Construct an RPC instance on a channel. 980 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) 981 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { 982 // Hold ResponseId in a special variable, since we expect Response to be 983 // called relatively frequently, and want to avoid the map lookup. 
984 ResponseId = FnIdAllocator.getResponseId(); 985 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; 986 987 // Register the negotiate function id and handler. 988 auto NegotiateId = FnIdAllocator.getNegotiateId(); 989 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; 990 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( 991 [this](const std::string &Name) { return handleNegotiate(Name); }); 992 } 993 994 /// Negotiate a function id for Func with the other end of the channel. 995 template <typename Func> Error negotiateFunction(bool Retry = false) { 996 return getRemoteFunctionId<Func>(true, Retry).takeError(); 997 } 998 999 /// Append a call Func, does not call send on the channel. 1000 /// The first argument specifies a user-defined handler to be run when the 1001 /// function returns. The handler should take an Expected<Func::ReturnType>, 1002 /// or an Error (if Func::ReturnType is void). The handler will be called 1003 /// with an error if the return value is abandoned due to a channel error. 1004 template <typename Func, typename HandlerT, typename... ArgTs> 1005 Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) { 1006 1007 static_assert( 1008 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, 1009 void(ArgTs...)>::value, 1010 ""); 1011 1012 // Look up the function ID. 1013 FunctionIdT FnId; 1014 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) 1015 FnId = *FnIdOrErr; 1016 else { 1017 // Negotiation failed. Notify the handler then return the negotiate-failed 1018 // error. 1019 cantFail(Handler(make_error<ResponseAbandoned>())); 1020 return FnIdOrErr.takeError(); 1021 } 1022 1023 SequenceNumberT SeqNo; // initialized in locked scope below. 1024 { 1025 // Lock the pending responses map and sequence number manager. 1026 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1027 1028 // Allocate a sequence number. 
1029 SeqNo = SequenceNumberMgr.getSequenceNumber(); 1030 assert(!PendingResponses.count(SeqNo) && 1031 "Sequence number already allocated"); 1032 1033 // Install the user handler. 1034 PendingResponses[SeqNo] = 1035 detail::createResponseHandler<ChannelT, typename Func::ReturnType>( 1036 std::move(Handler)); 1037 } 1038 1039 // Open the function call message. 1040 if (auto Err = C.startSendMessage(FnId, SeqNo)) { 1041 abandonPendingResponses(); 1042 return Err; 1043 } 1044 1045 // Serialize the call arguments. 1046 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( 1047 C, Args...)) { 1048 abandonPendingResponses(); 1049 return Err; 1050 } 1051 1052 // Close the function call messagee. 1053 if (auto Err = C.endSendMessage()) { 1054 abandonPendingResponses(); 1055 return Err; 1056 } 1057 1058 return Error::success(); 1059 } 1060 1061 Error sendAppendedCalls() { return C.send(); }; 1062 1063 template <typename Func, typename HandlerT, typename... ArgTs> 1064 Error callAsync(HandlerT Handler, const ArgTs &...Args) { 1065 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) 1066 return Err; 1067 return C.send(); 1068 } 1069 1070 /// Handle one incoming call. 1071 Error handleOne() { 1072 FunctionIdT FnId; 1073 SequenceNumberT SeqNo; 1074 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { 1075 abandonPendingResponses(); 1076 return Err; 1077 } 1078 if (FnId == ResponseId) 1079 return handleResponse(SeqNo); 1080 auto I = Handlers.find(FnId); 1081 if (I != Handlers.end()) 1082 return I->second(C, SeqNo); 1083 1084 // else: No handler found. Report error to client? 1085 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, 1086 SeqNo); 1087 } 1088 1089 /// Helper for handling setter procedures - this method returns a functor that 1090 /// sets the variables referred to by Args... to values deserialized from the 1091 /// channel. 1092 /// E.g. 
1093 /// 1094 /// typedef Function<0, bool, int> Func1; 1095 /// 1096 /// ... 1097 /// bool B; 1098 /// int I; 1099 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) 1100 /// /* Handle Args */ ; 1101 /// 1102 template <typename... ArgTs> 1103 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) { 1104 return detail::ReadArgs<ArgTs...>(Args...); 1105 } 1106 1107 /// Abandon all outstanding result handlers. 1108 /// 1109 /// This will call all currently registered result handlers to receive an 1110 /// "abandoned" error as their argument. This is used internally by the RPC 1111 /// in error situations, but can also be called directly by clients who are 1112 /// disconnecting from the remote and don't or can't expect responses to their 1113 /// outstanding calls. (Especially for outstanding blocking calls, calling 1114 /// this function may be necessary to avoid dead threads). 1115 void abandonPendingResponses() { 1116 // Lock the pending responses map and sequence number manager. 1117 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1118 1119 for (auto &KV : PendingResponses) 1120 KV.second->abandon(); 1121 PendingResponses.clear(); 1122 SequenceNumberMgr.reset(); 1123 } 1124 1125 /// Remove the handler for the given function. 1126 /// A handler must currently be registered for this function. 1127 template <typename Func> void removeHandler() { 1128 auto IdItr = LocalFunctionIds.find(Func::getPrototype()); 1129 assert(IdItr != LocalFunctionIds.end() && 1130 "Function does not have a registered handler"); 1131 auto HandlerItr = Handlers.find(IdItr->second); 1132 assert(HandlerItr != Handlers.end() && 1133 "Function does not have a registered handler"); 1134 Handlers.erase(HandlerItr); 1135 } 1136 1137 /// Clear all handlers. 
1138 void clearHandlers() { Handlers.clear(); } 1139 1140 protected: 1141 FunctionIdT getInvalidFunctionId() const { 1142 return FnIdAllocator.getInvalidId(); 1143 } 1144 1145 /// Add the given handler to the handler map and make it available for 1146 /// autonegotiation and execution. 1147 template <typename Func, typename HandlerT> 1148 void addHandlerImpl(HandlerT Handler) { 1149 1150 static_assert(detail::RPCArgTypeCheck< 1151 CanDeserializeCheck, typename Func::Type, 1152 typename detail::HandlerTraits<HandlerT>::Type>::value, 1153 ""); 1154 1155 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1156 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1157 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); 1158 } 1159 1160 template <typename Func, typename HandlerT> 1161 void addAsyncHandlerImpl(HandlerT Handler) { 1162 1163 static_assert( 1164 detail::RPCArgTypeCheck< 1165 CanDeserializeCheck, typename Func::Type, 1166 typename detail::AsyncHandlerTraits< 1167 typename detail::HandlerTraits<HandlerT>::Type>::Type>::value, 1168 ""); 1169 1170 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1171 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1172 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); 1173 } 1174 1175 Error handleResponse(SequenceNumberT SeqNo) { 1176 using Handler = typename decltype(PendingResponses)::mapped_type; 1177 Handler PRHandler; 1178 1179 { 1180 // Lock the pending responses map and sequence number manager. 1181 std::unique_lock<std::mutex> Lock(ResponsesMutex); 1182 auto I = PendingResponses.find(SeqNo); 1183 1184 if (I != PendingResponses.end()) { 1185 PRHandler = std::move(I->second); 1186 PendingResponses.erase(I); 1187 SequenceNumberMgr.releaseSequenceNumber(SeqNo); 1188 } else { 1189 // Unlock the pending results map to prevent recursive lock. 
1190 Lock.unlock(); 1191 abandonPendingResponses(); 1192 return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>( 1193 SeqNo); 1194 } 1195 } 1196 1197 assert(PRHandler && 1198 "If we didn't find a response handler we should have bailed out"); 1199 1200 if (auto Err = PRHandler->handleResponse(C)) { 1201 abandonPendingResponses(); 1202 return Err; 1203 } 1204 1205 return Error::success(); 1206 } 1207 1208 FunctionIdT handleNegotiate(const std::string &Name) { 1209 auto I = LocalFunctionIds.find(Name); 1210 if (I == LocalFunctionIds.end()) 1211 return getInvalidFunctionId(); 1212 return I->second; 1213 } 1214 1215 // Find the remote FunctionId for the given function. 1216 template <typename Func> 1217 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, 1218 bool NegotiateIfInvalid) { 1219 bool DoNegotiate; 1220 1221 // Check if we already have a function id... 1222 auto I = RemoteFunctionIds.find(Func::getPrototype()); 1223 if (I != RemoteFunctionIds.end()) { 1224 // If it's valid there's nothing left to do. 1225 if (I->second != getInvalidFunctionId()) 1226 return I->second; 1227 DoNegotiate = NegotiateIfInvalid; 1228 } else 1229 DoNegotiate = NegotiateIfNotInMap; 1230 1231 // We don't have a function id for Func yet, but we're allowed to try to 1232 // negotiate one. 1233 if (DoNegotiate) { 1234 auto &Impl = static_cast<ImplT &>(*this); 1235 if (auto RemoteIdOrErr = 1236 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { 1237 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; 1238 if (*RemoteIdOrErr == getInvalidFunctionId()) 1239 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1240 return *RemoteIdOrErr; 1241 } else 1242 return RemoteIdOrErr.takeError(); 1243 } 1244 1245 // No key was available in the map and we weren't allowed to try to 1246 // negotiate one, so return an unknown function error. 
1247 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1248 } 1249 1250 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; 1251 1252 // Wrap the given user handler in the necessary argument-deserialization code, 1253 // result-serialization code, and call to the launch policy (if present). 1254 template <typename Func, typename HandlerT> 1255 WrappedHandlerFn wrapHandler(HandlerT Handler) { 1256 return [this, Handler](ChannelT &Channel, 1257 SequenceNumberT SeqNo) mutable -> Error { 1258 // Start by deserializing the arguments. 1259 using ArgsTuple = typename detail::RPCFunctionArgsTuple< 1260 typename detail::HandlerTraits<HandlerT>::Type>::Type; 1261 auto Args = std::make_shared<ArgsTuple>(); 1262 1263 if (auto Err = 1264 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1265 Channel, *Args)) 1266 return Err; 1267 1268 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1269 // for RPCArgs. Void cast RPCArgs to work around this for now. 1270 // FIXME: Remove this workaround once we can assume a working GCC version. 1271 (void)Args; 1272 1273 // End receieve message, unlocking the channel for reading. 1274 if (auto Err = Channel.endReceiveMessage()) 1275 return Err; 1276 1277 using HTraits = detail::HandlerTraits<HandlerT>; 1278 using FuncReturn = typename Func::ReturnType; 1279 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, 1280 HTraits::unpackAndRun(Handler, *Args)); 1281 }; 1282 } 1283 1284 // Wrap the given user handler in the necessary argument-deserialization code, 1285 // result-serialization code, and call to the launch policy (if present). 1286 template <typename Func, typename HandlerT> 1287 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { 1288 return [this, Handler](ChannelT &Channel, 1289 SequenceNumberT SeqNo) mutable -> Error { 1290 // Start by deserializing the arguments. 
1291 using AHTraits = detail::AsyncHandlerTraits< 1292 typename detail::HandlerTraits<HandlerT>::Type>; 1293 using ArgsTuple = 1294 typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type; 1295 auto Args = std::make_shared<ArgsTuple>(); 1296 1297 if (auto Err = 1298 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1299 Channel, *Args)) 1300 return Err; 1301 1302 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1303 // for RPCArgs. Void cast RPCArgs to work around this for now. 1304 // FIXME: Remove this workaround once we can assume a working GCC version. 1305 (void)Args; 1306 1307 // End receieve message, unlocking the channel for reading. 1308 if (auto Err = Channel.endReceiveMessage()) 1309 return Err; 1310 1311 using HTraits = detail::HandlerTraits<HandlerT>; 1312 using FuncReturn = typename Func::ReturnType; 1313 auto Responder = [this, 1314 SeqNo](typename AHTraits::ResultType RetVal) -> Error { 1315 return detail::respond<FuncReturn>(C, ResponseId, SeqNo, 1316 std::move(RetVal)); 1317 }; 1318 1319 return HTraits::unpackAndRunAsync(Handler, Responder, *Args); 1320 }; 1321 } 1322 1323 ChannelT &C; 1324 1325 bool LazyAutoNegotiation; 1326 1327 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; 1328 1329 FunctionIdT ResponseId; 1330 std::map<std::string, FunctionIdT> LocalFunctionIds; 1331 std::map<const char *, FunctionIdT> RemoteFunctionIds; 1332 1333 std::map<FunctionIdT, WrappedHandlerFn> Handlers; 1334 1335 std::mutex ResponsesMutex; 1336 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; 1337 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> 1338 PendingResponses; 1339 }; 1340 1341 } // end namespace detail 1342 1343 template <typename ChannelT, typename FunctionIdT = uint32_t, 1344 typename SequenceNumberT = uint32_t> 1345 class MultiThreadedRPCEndpoint 1346 : public detail::RPCEndpointBase< 1347 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 
1348 ChannelT, FunctionIdT, SequenceNumberT> { 1349 private: 1350 using BaseClass = detail::RPCEndpointBase< 1351 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1352 ChannelT, FunctionIdT, SequenceNumberT>; 1353 1354 public: 1355 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1356 : BaseClass(C, LazyAutoNegotiation) {} 1357 1358 /// Add a handler for the given RPC function. 1359 /// This installs the given handler functor for the given RPCFunction, and 1360 /// makes the RPC function available for negotiation/calling from the remote. 1361 template <typename Func, typename HandlerT> 1362 void addHandler(HandlerT Handler) { 1363 return this->template addHandlerImpl<Func>(std::move(Handler)); 1364 } 1365 1366 /// Add a class-method as a handler. 1367 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1368 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1369 addHandler<Func>( 1370 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1371 } 1372 1373 template <typename Func, typename HandlerT> 1374 void addAsyncHandler(HandlerT Handler) { 1375 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1376 } 1377 1378 /// Add a class-method as a handler. 1379 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1380 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1381 addAsyncHandler<Func>( 1382 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1383 } 1384 1385 /// Return type for non-blocking call primitives. 1386 template <typename Func> 1387 using NonBlockingCallResult = typename detail::ResultTraits< 1388 typename Func::ReturnType>::ReturnFutureType; 1389 1390 /// Call Func on Channel C. Does not block, does not call send. Returns a pair 1391 /// of a future result and the sequence number assigned to the result. 
1392 /// 1393 /// This utility function is primarily used for single-threaded mode support, 1394 /// where the sequence number can be used to wait for the corresponding 1395 /// result. In multi-threaded mode the appendCallNB method, which does not 1396 /// return the sequence numeber, should be preferred. 1397 template <typename Func, typename... ArgTs> 1398 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) { 1399 using RTraits = detail::ResultTraits<typename Func::ReturnType>; 1400 using ErrorReturn = typename RTraits::ErrorReturnType; 1401 using ErrorReturnPromise = typename RTraits::ReturnPromiseType; 1402 1403 ErrorReturnPromise Promise; 1404 auto FutureResult = Promise.get_future(); 1405 1406 if (auto Err = this->template appendCallAsync<Func>( 1407 [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable { 1408 Promise.set_value(std::move(RetOrErr)); 1409 return Error::success(); 1410 }, 1411 Args...)) { 1412 RTraits::consumeAbandoned(FutureResult.get()); 1413 return std::move(Err); 1414 } 1415 return std::move(FutureResult); 1416 } 1417 1418 /// The same as appendCallNBWithSeq, except that it calls C.send() to 1419 /// flush the channel after serializing the call. 1420 template <typename Func, typename... ArgTs> 1421 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) { 1422 auto Result = appendCallNB<Func>(Args...); 1423 if (!Result) 1424 return Result; 1425 if (auto Err = this->C.send()) { 1426 this->abandonPendingResponses(); 1427 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1428 std::move(Result->get())); 1429 return std::move(Err); 1430 } 1431 return Result; 1432 } 1433 1434 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error 1435 /// for void functions or an Expected<T> for functions returning a T. 1436 /// 1437 /// This function is for use in threaded code where another thread is 1438 /// handling responses and incoming calls. 
1439 template <typename Func, typename... ArgTs, 1440 typename AltRetT = typename Func::ReturnType> 1441 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1442 callB(const ArgTs &...Args) { 1443 if (auto FutureResOrErr = callNB<Func>(Args...)) 1444 return FutureResOrErr->get(); 1445 else 1446 return FutureResOrErr.takeError(); 1447 } 1448 1449 /// Handle incoming RPC calls. 1450 Error handlerLoop() { 1451 while (true) 1452 if (auto Err = this->handleOne()) 1453 return Err; 1454 return Error::success(); 1455 } 1456 }; 1457 1458 template <typename ChannelT, typename FunctionIdT = uint32_t, 1459 typename SequenceNumberT = uint32_t> 1460 class SingleThreadedRPCEndpoint 1461 : public detail::RPCEndpointBase< 1462 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1463 ChannelT, FunctionIdT, SequenceNumberT> { 1464 private: 1465 using BaseClass = detail::RPCEndpointBase< 1466 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1467 ChannelT, FunctionIdT, SequenceNumberT>; 1468 1469 public: 1470 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1471 : BaseClass(C, LazyAutoNegotiation) {} 1472 1473 template <typename Func, typename HandlerT> 1474 void addHandler(HandlerT Handler) { 1475 return this->template addHandlerImpl<Func>(std::move(Handler)); 1476 } 1477 1478 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1479 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1480 addHandler<Func>( 1481 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1482 } 1483 1484 template <typename Func, typename HandlerT> 1485 void addAsyncHandler(HandlerT Handler) { 1486 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1487 } 1488 1489 /// Add a class-method as a handler. 1490 template <typename Func, typename ClassT, typename RetT, typename... 
ArgTs> 1491 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1492 addAsyncHandler<Func>( 1493 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1494 } 1495 1496 template <typename Func, typename... ArgTs, 1497 typename AltRetT = typename Func::ReturnType> 1498 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1499 callB(const ArgTs &...Args) { 1500 bool ReceivedResponse = false; 1501 using AltRetTraits = detail::ResultTraits<AltRetT>; 1502 using ResultType = typename AltRetTraits::ErrorReturnType; 1503 ResultType Result = AltRetTraits::createBlankErrorReturnValue(); 1504 1505 // We have to 'Check' result (which we know is in a success state at this 1506 // point) so that it can be overwritten in the async handler. 1507 (void)!!Result; 1508 1509 if (Error Err = this->template appendCallAsync<Func>( 1510 [&](ResultType R) { 1511 Result = std::move(R); 1512 ReceivedResponse = true; 1513 return Error::success(); 1514 }, 1515 Args...)) { 1516 AltRetTraits::consumeAbandoned(std::move(Result)); 1517 return AltRetTraits::returnError(std::move(Err)); 1518 } 1519 1520 if (Error Err = this->C.send()) { 1521 AltRetTraits::consumeAbandoned(std::move(Result)); 1522 return AltRetTraits::returnError(std::move(Err)); 1523 } 1524 1525 while (!ReceivedResponse) { 1526 if (Error Err = this->handleOne()) { 1527 AltRetTraits::consumeAbandoned(std::move(Result)); 1528 return AltRetTraits::returnError(std::move(Err)); 1529 } 1530 } 1531 1532 return Result; 1533 } 1534 }; 1535 1536 /// Asynchronous dispatch for a function on an RPC endpoint. 1537 template <typename RPCClass, typename Func> class RPCAsyncDispatch { 1538 public: 1539 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} 1540 1541 template <typename HandlerT, typename... 
ArgTs> 1542 Error operator()(HandlerT Handler, const ArgTs &...Args) const { 1543 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); 1544 } 1545 1546 private: 1547 RPCClass &Endpoint; 1548 }; 1549 1550 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func. 1551 template <typename Func, typename RPCEndpointT> 1552 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { 1553 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); 1554 } 1555 1556 /// Allows a set of asynchrounous calls to be dispatched, and then 1557 /// waited on as a group. 1558 class ParallelCallGroup { 1559 public: 1560 ParallelCallGroup() = default; 1561 ParallelCallGroup(const ParallelCallGroup &) = delete; 1562 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; 1563 1564 /// Make as asynchronous call. 1565 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> 1566 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, 1567 const ArgTs &...Args) { 1568 // Increment the count of outstanding calls. This has to happen before 1569 // we invoke the call, as the handler may (depending on scheduling) 1570 // be run immediately on another thread, and we don't want the decrement 1571 // in the wrapped handler below to run before the increment. 1572 { 1573 std::unique_lock<std::mutex> Lock(M); 1574 ++NumOutstandingCalls; 1575 } 1576 1577 // Wrap the user handler in a lambda that will decrement the 1578 // outstanding calls count, then poke the condition variable. 
1579 using ArgType = typename detail::ResponseHandlerArg< 1580 typename detail::HandlerTraits<HandlerT>::Type>::ArgType; 1581 auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) { 1582 auto Err = Handler(std::move(Arg)); 1583 std::unique_lock<std::mutex> Lock(M); 1584 --NumOutstandingCalls; 1585 CV.notify_all(); 1586 return Err; 1587 }; 1588 1589 return AsyncDispatch(std::move(WrappedHandler), Args...); 1590 } 1591 1592 /// Blocks until all calls have been completed and their return value 1593 /// handlers run. 1594 void wait() { 1595 std::unique_lock<std::mutex> Lock(M); 1596 while (NumOutstandingCalls > 0) 1597 CV.wait(Lock); 1598 } 1599 1600 private: 1601 std::mutex M; 1602 std::condition_variable CV; 1603 uint32_t NumOutstandingCalls = 0; 1604 }; 1605 1606 /// Convenience class for grouping RPCFunctions into APIs that can be 1607 /// negotiated as a block. 1608 /// 1609 template <typename... Funcs> class APICalls { 1610 public: 1611 /// Test whether this API contains Function F. 1612 template <typename F> class Contains { 1613 public: 1614 static const bool value = false; 1615 }; 1616 1617 /// Negotiate all functions in this API. 1618 template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { 1619 return Error::success(); 1620 } 1621 }; 1622 1623 template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> { 1624 public: 1625 template <typename F> class Contains { 1626 public: 1627 static const bool value = std::is_same<F, Func>::value | 1628 APICalls<Funcs...>::template Contains<F>::value; 1629 }; 1630 1631 template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { 1632 if (auto Err = R.template negotiateFunction<Func>()) 1633 return Err; 1634 return APICalls<Funcs...>::negotiate(R); 1635 } 1636 }; 1637 1638 template <typename... InnerFuncs, typename... 
Funcs> 1639 class APICalls<APICalls<InnerFuncs...>, Funcs...> { 1640 public: 1641 template <typename F> class Contains { 1642 public: 1643 static const bool value = 1644 APICalls<InnerFuncs...>::template Contains<F>::value | 1645 APICalls<Funcs...>::template Contains<F>::value; 1646 }; 1647 1648 template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) { 1649 if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) 1650 return Err; 1651 return APICalls<Funcs...>::negotiate(R); 1652 } 1653 }; 1654 1655 } // end namespace shared 1656 } // end namespace orc 1657 } // end namespace llvm 1658 1659 #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H 1660