diff options
53 files changed, 4795 insertions, 2534 deletions
diff --git a/.gitmodules b/.gitmodules index 9f96b70be..361f4845b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -56,5 +56,5 @@ path = externals/nx_tzdb/tzdb_to_nx url = https://github.com/lat9nq/tzdb_to_nx.git [submodule "VulkanMemoryAllocator"] - path = externals/vma/VulkanMemoryAllocator + path = externals/VulkanMemoryAllocator url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f8febb90..00d540f1f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,18 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF) CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF) +set(DEFAULT_ENABLE_OPENSSL ON) +if (ANDROID OR WIN32 OR APPLE) + # - Windows defaults to the Schannel backend. + # - macOS defaults to the SecureTransport backend. + # - Android currently has no SSL backend as the NDK doesn't include any SSL + # library; a proper 'native' backend would have to go through Java. + # But you can force builds for those platforms to use OpenSSL if you have + # your own copy of it. + set(DEFAULT_ENABLE_OPENSSL OFF) +endif() +option(ENABLE_OPENSSL "Enable OpenSSL backend for ISslConnection" ${DEFAULT_ENABLE_OPENSSL}) + # On Android, fetch and compile libcxx before doing anything else if (ANDROID) set(CMAKE_SKIP_INSTALL_RULES ON) @@ -277,10 +289,11 @@ find_package(Boost 1.79.0 REQUIRED context) find_package(enet 1.3 MODULE) find_package(fmt 9 REQUIRED) find_package(inih 52 MODULE COMPONENTS INIReader) -find_package(LLVM MODULE COMPONENTS Demangle) +find_package(LLVM 17 MODULE COMPONENTS Demangle) find_package(lz4 REQUIRED) find_package(nlohmann_json 3.8 REQUIRED) find_package(Opus 1.3 MODULE) +find_package(VulkanMemoryAllocator CONFIG) find_package(ZLIB 1.2 REQUIRED) find_package(zstd 1.5 REQUIRED) @@ -322,6 +335,10 @@ if (MINGW) find_library(MSWSOCK_LIBRARY mswsock REQUIRED) endif() +if(ENABLE_OPENSSL) + find_package(OpenSSL 1.1.1 REQUIRED) +endif() + # Please consider this as a stub if(ENABLE_QT6 AND Qt6_LOCATION) list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}") diff --git a/externals/CMakeLists.txt b/externals/CMakeLists.txt index 4ff588851..1f7cd598e 100644 --- a/externals/CMakeLists.txt +++ b/externals/CMakeLists.txt @@ -144,9 +144,9 @@ endif() add_subdirectory(nx_tzdb) # VMA -add_library(vma vma/vma.cpp) -target_include_directories(vma PUBLIC ./vma/VulkanMemoryAllocator/include) -target_link_libraries(vma PRIVATE Vulkan::Headers) +if (NOT TARGET GPUOpen::VulkanMemoryAllocator) + add_subdirectory(VulkanMemoryAllocator) +endif() if (NOT TARGET LLVM::Demangle) add_library(demangle demangle/ItaniumDemangle.cpp) diff --git a/externals/VulkanMemoryAllocator b/externals/VulkanMemoryAllocator new file mode 160000 +Subproject 9b0fc3e7b02afe97895eb3e945fe800c3a7485a diff --git a/externals/demangle/ItaniumDemangle.cpp b/externals/demangle/ItaniumDemangle.cpp index b055a2fd7..47dd5d301 100644 --- a/externals/demangle/ItaniumDemangle.cpp +++ b/externals/demangle/ItaniumDemangle.cpp @@ -20,9 +20,7 @@ #include <cstdlib> #include <cstring> #include <functional> -#include <numeric> #include <utility> -#include <vector> using namespace llvm; using namespace llvm::itanium_demangle; @@ -81,8 +79,8 @@ struct DumpVisitor { } void printStr(const char *S) { fprintf(stderr, "%s", S); } - void print(StringView SV) { - fprintf(stderr, "\"%.*s\"", (int)SV.size(), SV.begin()); + void print(std::string_view SV) { + fprintf(stderr, "\"%.*s\"", (int)SV.size(), SV.data()); } void print(const Node *N) { if (N) @@ -90,14 +88,6 @@ struct DumpVisitor { else printStr("<null>"); } - void print(NodeOrString NS) { - if (NS.isNode()) - print(NS.asNode()); - else if (NS.isString()) - print(NS.asString()); - else - printStr("NodeOrString()"); - } void print(NodeArray A) { ++Depth; printStr("{"); @@ -116,13 +106,11 @@ struct DumpVisitor { // Overload used when T is exactly 'bool', not merely convertible to 'bool'. void print(bool B) { printStr(B ? "true" : "false"); } - template <class T> - typename std::enable_if<std::is_unsigned<T>::value>::type print(T N) { + template <class T> std::enable_if_t<std::is_unsigned<T>::value> print(T N) { fprintf(stderr, "%llu", (unsigned long long)N); } - template <class T> - typename std::enable_if<std::is_signed<T>::value>::type print(T N) { + template <class T> std::enable_if_t<std::is_signed<T>::value> print(T N) { fprintf(stderr, "%lld", (long long)N); } @@ -185,6 +173,50 @@ struct DumpVisitor { return printStr("TemplateParamKind::Template"); } } + void print(Node::Prec P) { + switch (P) { + case Node::Prec::Primary: + return printStr("Node::Prec::Primary"); + case Node::Prec::Postfix: + return printStr("Node::Prec::Postfix"); + case Node::Prec::Unary: + return printStr("Node::Prec::Unary"); + case Node::Prec::Cast: + return printStr("Node::Prec::Cast"); + case Node::Prec::PtrMem: + return printStr("Node::Prec::PtrMem"); + case Node::Prec::Multiplicative: + return printStr("Node::Prec::Multiplicative"); + case Node::Prec::Additive: + return printStr("Node::Prec::Additive"); + case Node::Prec::Shift: + return printStr("Node::Prec::Shift"); + case Node::Prec::Spaceship: + return printStr("Node::Prec::Spaceship"); + case Node::Prec::Relational: + return printStr("Node::Prec::Relational"); + case Node::Prec::Equality: + return printStr("Node::Prec::Equality"); + case Node::Prec::And: + return printStr("Node::Prec::And"); + case Node::Prec::Xor: + return printStr("Node::Prec::Xor"); + case Node::Prec::Ior: + return printStr("Node::Prec::Ior"); + case Node::Prec::AndIf: + return printStr("Node::Prec::AndIf"); + case Node::Prec::OrIf: + return printStr("Node::Prec::OrIf"); + case Node::Prec::Conditional: + return printStr("Node::Prec::Conditional"); + case Node::Prec::Assign: + return printStr("Node::Prec::Assign"); + case Node::Prec::Comma: + return printStr("Node::Prec::Comma"); + case Node::Prec::Default: + return printStr("Node::Prec::Default"); + } + } void newLine() { printStr("\n"); @@ -334,36 +366,21 @@ public: using Demangler = itanium_demangle::ManglingParser<DefaultAllocator>; -char *llvm::itaniumDemangle(const char *MangledName, char *Buf, - size_t *N, int *Status) { - if (MangledName == nullptr || (Buf != nullptr && N == nullptr)) { - if (Status) - *Status = demangle_invalid_args; +char *llvm::itaniumDemangle(std::string_view MangledName) { + if (MangledName.empty()) return nullptr; - } - - int InternalStatus = demangle_success; - Demangler Parser(MangledName, MangledName + std::strlen(MangledName)); - OutputStream S; + Demangler Parser(MangledName.data(), + MangledName.data() + MangledName.length()); Node *AST = Parser.parse(); + if (!AST) + return nullptr; - if (AST == nullptr) - InternalStatus = demangle_invalid_mangled_name; - else if (!initializeOutputStream(Buf, N, S, 1024)) - InternalStatus = demangle_memory_alloc_failure; - else { - assert(Parser.ForwardTemplateRefs.empty()); - AST->print(S); - S += '\0'; - if (N != nullptr) - *N = S.getCurrentPosition(); - Buf = S.getBuffer(); - } - - if (Status) - *Status = InternalStatus; - return InternalStatus == demangle_success ? Buf : nullptr; + OutputBuffer OB; + assert(Parser.ForwardTemplateRefs.empty()); + AST->print(OB); + OB += '\0'; + return OB.getBuffer(); } ItaniumPartialDemangler::ItaniumPartialDemangler() @@ -396,14 +413,12 @@ bool ItaniumPartialDemangler::partialDemangle(const char *MangledName) { } static char *printNode(const Node *RootNode, char *Buf, size_t *N) { - OutputStream S; - if (!initializeOutputStream(Buf, N, S, 128)) - return nullptr; - RootNode->print(S); - S += '\0'; + OutputBuffer OB(Buf, N); + RootNode->print(OB); + OB += '\0'; if (N != nullptr) - *N = S.getCurrentPosition(); - return S.getBuffer(); + *N = OB.getCurrentPosition(); + return OB.getBuffer(); } char *ItaniumPartialDemangler::getFunctionBaseName(char *Buf, size_t *N) const { @@ -417,8 +432,8 @@ char *ItaniumPartialDemangler::getFunctionBaseName(char *Buf, size_t *N) const { case Node::KAbiTagAttr: Name = static_cast<const AbiTagAttr *>(Name)->Base; continue; - case Node::KStdQualifiedName: - Name = static_cast<const StdQualifiedName *>(Name)->Child; + case Node::KModuleEntity: + Name = static_cast<const ModuleEntity *>(Name)->Name; continue; case Node::KNestedName: Name = static_cast<const NestedName *>(Name)->Name; @@ -441,9 +456,7 @@ char *ItaniumPartialDemangler::getFunctionDeclContextName(char *Buf, return nullptr; const Node *Name = static_cast<const FunctionEncoding *>(RootNode)->getName(); - OutputStream S; - if (!initializeOutputStream(Buf, N, S, 128)) - return nullptr; + OutputBuffer OB(Buf, N); KeepGoingLocalFunction: while (true) { @@ -458,27 +471,27 @@ char *ItaniumPartialDemangler::getFunctionDeclContextName(char *Buf, break; } + if (Name->getKind() == Node::KModuleEntity) + Name = static_cast<const ModuleEntity *>(Name)->Name; + switch (Name->getKind()) { - case Node::KStdQualifiedName: - S += "std"; - break; case Node::KNestedName: - static_cast<const NestedName *>(Name)->Qual->print(S); + static_cast<const NestedName *>(Name)->Qual->print(OB); break; case Node::KLocalName: { auto *LN = static_cast<const LocalName *>(Name); - LN->Encoding->print(S); - S += "::"; + LN->Encoding->print(OB); + OB += "::"; Name = LN->Entity; goto KeepGoingLocalFunction; } default: break; } - S += '\0'; + OB += '\0'; if (N != nullptr) - *N = S.getCurrentPosition(); - return S.getBuffer(); + *N = OB.getCurrentPosition(); + return OB.getBuffer(); } char *ItaniumPartialDemangler::getFunctionName(char *Buf, size_t *N) const { @@ -494,17 +507,15 @@ char *ItaniumPartialDemangler::getFunctionParameters(char *Buf, return nullptr; NodeArray Params = static_cast<FunctionEncoding *>(RootNode)->getParams(); - OutputStream S; - if (!initializeOutputStream(Buf, N, S, 128)) - return nullptr; + OutputBuffer OB(Buf, N); - S += '('; - Params.printWithComma(S); - S += ')'; - S += '\0'; + OB += '('; + Params.printWithComma(OB); + OB += ')'; + OB += '\0'; if (N != nullptr) - *N = S.getCurrentPosition(); - return S.getBuffer(); + *N = OB.getCurrentPosition(); + return OB.getBuffer(); } char *ItaniumPartialDemangler::getFunctionReturnType( @@ -512,18 +523,16 @@ char *ItaniumPartialDemangler::getFunctionReturnType( if (!isFunction()) return nullptr; - OutputStream S; - if (!initializeOutputStream(Buf, N, S, 128)) - return nullptr; + OutputBuffer OB(Buf, N); if (const Node *Ret = static_cast<const FunctionEncoding *>(RootNode)->getReturnType()) - Ret->print(S); + Ret->print(OB); - S += '\0'; + OB += '\0'; if (N != nullptr) - *N = S.getCurrentPosition(); - return S.getBuffer(); + *N = OB.getCurrentPosition(); + return OB.getBuffer(); } char *ItaniumPartialDemangler::finishDemangle(char *Buf, size_t *N) const { @@ -563,8 +572,8 @@ bool ItaniumPartialDemangler::isCtorOrDtor() const { case Node::KNestedName: N = static_cast<const NestedName *>(N)->Name; break; - case Node::KStdQualifiedName: - N = static_cast<const StdQualifiedName *>(N)->Child; + case Node::KModuleEntity: + N = static_cast<const ModuleEntity *>(N)->Name; break; } } diff --git a/externals/demangle/llvm/Demangle/Demangle.h b/externals/demangle/llvm/Demangle/Demangle.h index 5b673e4e1..1552a501a 100644 --- a/externals/demangle/llvm/Demangle/Demangle.h +++ b/externals/demangle/llvm/Demangle/Demangle.h @@ -12,6 +12,7 @@ #include <cstddef> #include <string> +#include <string_view> namespace llvm { /// This is a llvm local version of __cxa_demangle. Other than the name and @@ -29,9 +30,10 @@ enum : int { demangle_success = 0, }; -char *itaniumDemangle(const char *mangled_name, char *buf, size_t *n, - int *status); - +/// Returns a non-NULL pointer to a NUL-terminated C style string +/// that should be explicitly freed, if successful. Otherwise, may return +/// nullptr if mangled_name is not a valid mangling or is nullptr. +char *itaniumDemangle(std::string_view mangled_name); enum MSDemangleFlags { MSDF_None = 0, @@ -40,10 +42,34 @@ enum MSDemangleFlags { MSDF_NoCallingConvention = 1 << 2, MSDF_NoReturnType = 1 << 3, MSDF_NoMemberType = 1 << 4, + MSDF_NoVariableType = 1 << 5, }; -char *microsoftDemangle(const char *mangled_name, char *buf, size_t *n, + +/// Demangles the Microsoft symbol pointed at by mangled_name and returns it. +/// Returns a pointer to the start of a null-terminated demangled string on +/// success, or nullptr on error. +/// If n_read is non-null and demangling was successful, it receives how many +/// bytes of the input string were consumed. +/// status receives one of the demangle_ enum entries above if it's not nullptr. +/// Flags controls various details of the demangled representation. +char *microsoftDemangle(std::string_view mangled_name, size_t *n_read, int *status, MSDemangleFlags Flags = MSDF_None); +// Demangles a Rust v0 mangled symbol. +char *rustDemangle(std::string_view MangledName); + +// Demangles a D mangled symbol. +char *dlangDemangle(std::string_view MangledName); + +/// Attempt to demangle a string using different demangling schemes. +/// The function uses heuristics to determine which demangling scheme to use. +/// \param MangledName - reference to string to demangle. +/// \returns - the demangled string, or a copy of the input string if no +/// demangling occurred. +std::string demangle(std::string_view MangledName); + +bool nonMicrosoftDemangle(std::string_view MangledName, std::string &Result); + /// "Partial" demangler. This supports demangling a string into an AST /// (typically an intermediate stage in itaniumDemangle) and querying certain /// properties or partially printing the demangled name. @@ -59,7 +85,7 @@ struct ItaniumPartialDemangler { bool partialDemangle(const char *MangledName); /// Just print the entire mangled name into Buf. Buf and N behave like the - /// second and third parameters to itaniumDemangle. + /// second and third parameters to __cxa_demangle. char *finishDemangle(char *Buf, size_t *N) const; /// Get the base name of a function. This doesn't include trailing template @@ -95,6 +121,7 @@ struct ItaniumPartialDemangler { bool isSpecialName() const; ~ItaniumPartialDemangler(); + private: void *RootNode; void *Context; diff --git a/externals/demangle/llvm/Demangle/DemangleConfig.h b/externals/demangle/llvm/Demangle/DemangleConfig.h index a8aef9df1..c7f86d766 100644 --- a/externals/demangle/llvm/Demangle/DemangleConfig.h +++ b/externals/demangle/llvm/Demangle/DemangleConfig.h @@ -13,8 +13,8 @@ // //===----------------------------------------------------------------------===// -#ifndef LLVM_DEMANGLE_COMPILER_H -#define LLVM_DEMANGLE_COMPILER_H +#ifndef LLVM_DEMANGLE_DEMANGLECONFIG_H +#define LLVM_DEMANGLE_DEMANGLECONFIG_H #ifndef __has_feature #define __has_feature(x) 0 diff --git a/externals/demangle/llvm/Demangle/ItaniumDemangle.h b/externals/demangle/llvm/Demangle/ItaniumDemangle.h index 64b35c142..0dc3d7337 100644 --- a/externals/demangle/llvm/Demangle/ItaniumDemangle.h +++ b/externals/demangle/llvm/Demangle/ItaniumDemangle.h @@ -1,5 +1,5 @@ -//===------------------------- ItaniumDemangle.h ----------------*- C++ -*-===// -// +//===--- ItaniumDemangle.h -----------*- mode:c++;eval:(read-only-mode) -*-===// +// Do not edit! See README.txt. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-FileCopyrightText: Part of the LLVM Project @@ -7,144 +7,220 @@ // //===----------------------------------------------------------------------===// // -// Generic itanium demangler library. This file has two byte-per-byte identical -// copies in the source tree, one in libcxxabi, and the other in llvm. +// Generic itanium demangler library. +// There are two copies of this file in the source tree. The one under +// libcxxabi is the original and the one under llvm is the copy. Use +// cp-to-llvm.sh to update the copy. See README.txt for more details. // //===----------------------------------------------------------------------===// #ifndef DEMANGLE_ITANIUMDEMANGLE_H #define DEMANGLE_ITANIUMDEMANGLE_H -// FIXME: (possibly) incomplete list of features that clang mangles that this -// file does not yet support: -// - C++ modules TS - #include "DemangleConfig.h" -#include "StringView.h" +#include "StringViewExtras.h" #include "Utility.h" +#include <algorithm> #include <cassert> #include <cctype> #include <cstdio> #include <cstdlib> #include <cstring> -#include <numeric> +#include <limits> +#include <new> +#include <string_view> +#include <type_traits> #include <utility> -#define FOR_EACH_NODE_KIND(X) \ - X(NodeArrayNode) \ - X(DotSuffix) \ - X(VendorExtQualType) \ - X(QualType) \ - X(ConversionOperatorType) \ - X(PostfixQualifiedType) \ - X(ElaboratedTypeSpefType) \ - X(NameType) \ - X(AbiTagAttr) \ - X(EnableIfAttr) \ - X(ObjCProtoName) \ - X(PointerType) \ - X(ReferenceType) \ - X(PointerToMemberType) \ - X(ArrayType) \ - X(FunctionType) \ - X(NoexceptSpec) \ - X(DynamicExceptionSpec) \ - X(FunctionEncoding) \ - X(LiteralOperator) \ - X(SpecialName) \ - X(CtorVtableSpecialName) \ - X(QualifiedName) \ - X(NestedName) \ - X(LocalName) \ - X(VectorType) \ - X(PixelVectorType) \ - X(SyntheticTemplateParamName) \ - X(TypeTemplateParamDecl) \ - X(NonTypeTemplateParamDecl) \ - X(TemplateTemplateParamDecl) \ - X(TemplateParamPackDecl) \ - X(ParameterPack) \ - X(TemplateArgumentPack) \ - X(ParameterPackExpansion) \ - X(TemplateArgs) \ - X(ForwardTemplateReference) \ - X(NameWithTemplateArgs) \ - X(GlobalQualifiedName) \ - X(StdQualifiedName) \ - X(ExpandedSpecialSubstitution) \ - X(SpecialSubstitution) \ - X(CtorDtorName) \ - X(DtorName) \ - X(UnnamedTypeName) \ - X(ClosureTypeName) \ - X(StructuredBindingName) \ - X(BinaryExpr) \ - X(ArraySubscriptExpr) \ - X(PostfixExpr) \ - X(ConditionalExpr) \ - X(MemberExpr) \ - X(EnclosingExpr) \ - X(CastExpr) \ - X(SizeofParamPackExpr) \ - X(CallExpr) \ - X(NewExpr) \ - X(DeleteExpr) \ - X(PrefixExpr) \ - X(FunctionParam) \ - X(ConversionExpr) \ - X(InitListExpr) \ - X(FoldExpr) \ - X(ThrowExpr) \ - X(UUIDOfExpr) \ - X(BoolExpr) \ - X(StringLiteral) \ - X(LambdaExpr) \ - X(IntegerCastExpr) \ - X(IntegerLiteral) \ - X(FloatLiteral) \ - X(DoubleLiteral) \ - X(LongDoubleLiteral) \ - X(BracedExpr) \ - X(BracedRangeExpr) - DEMANGLE_NAMESPACE_BEGIN +template <class T, size_t N> class PODSmallVector { + static_assert(std::is_pod<T>::value, + "T is required to be a plain old data type"); + + T *First = nullptr; + T *Last = nullptr; + T *Cap = nullptr; + T Inline[N] = {0}; + + bool isInline() const { return First == Inline; } + + void clearInline() { + First = Inline; + Last = Inline; + Cap = Inline + N; + } + + void reserve(size_t NewCap) { + size_t S = size(); + if (isInline()) { + auto *Tmp = static_cast<T *>(std::malloc(NewCap * sizeof(T))); + if (Tmp == nullptr) + std::terminate(); + std::copy(First, Last, Tmp); + First = Tmp; + } else { + First = static_cast<T *>(std::realloc(First, NewCap * sizeof(T))); + if (First == nullptr) + std::terminate(); + } + Last = First + S; + Cap = First + NewCap; + } + +public: + PODSmallVector() : First(Inline), Last(First), Cap(Inline + N) {} + + PODSmallVector(const PODSmallVector &) = delete; + PODSmallVector &operator=(const PODSmallVector &) = delete; + + PODSmallVector(PODSmallVector &&Other) : PODSmallVector() { + if (Other.isInline()) { + std::copy(Other.begin(), Other.end(), First); + Last = First + Other.size(); + Other.clear(); + return; + } + + First = Other.First; + Last = Other.Last; + Cap = Other.Cap; + Other.clearInline(); + } + + PODSmallVector &operator=(PODSmallVector &&Other) { + if (Other.isInline()) { + if (!isInline()) { + std::free(First); + clearInline(); + } + std::copy(Other.begin(), Other.end(), First); + Last = First + Other.size(); + Other.clear(); + return *this; + } + + if (isInline()) { + First = Other.First; + Last = Other.Last; + Cap = Other.Cap; + Other.clearInline(); + return *this; + } + + std::swap(First, Other.First); + std::swap(Last, Other.Last); + std::swap(Cap, Other.Cap); + Other.clear(); + return *this; + } + + // NOLINTNEXTLINE(readability-identifier-naming) + void push_back(const T &Elem) { + if (Last == Cap) + reserve(size() * 2); + *Last++ = Elem; + } + + // NOLINTNEXTLINE(readability-identifier-naming) + void pop_back() { + assert(Last != First && "Popping empty vector!"); + --Last; + } + + void dropBack(size_t Index) { + assert(Index <= size() && "dropBack() can't expand!"); + Last = First + Index; + } + + T *begin() { return First; } + T *end() { return Last; } + + bool empty() const { return First == Last; } + size_t size() const { return static_cast<size_t>(Last - First); } + T &back() { + assert(Last != First && "Calling back() on empty vector!"); + return *(Last - 1); + } + T &operator[](size_t Index) { + assert(Index < size() && "Invalid access!"); + return *(begin() + Index); + } + void clear() { Last = First; } + + ~PODSmallVector() { + if (!isInline()) + std::free(First); + } +}; + // Base class of all AST nodes. The AST is built by the parser, then is // traversed by the printLeft/Right functions to produce a demangled string. class Node { public: enum Kind : unsigned char { -#define ENUMERATOR(NodeKind) K ## NodeKind, - FOR_EACH_NODE_KIND(ENUMERATOR) -#undef ENUMERATOR +#define NODE(NodeKind) K##NodeKind, +#include "ItaniumNodes.def" }; /// Three-way bool to track a cached value. Unknown is possible if this node /// has an unexpanded parameter pack below it that may affect this cache. enum class Cache : unsigned char { Yes, No, Unknown, }; + /// Operator precedence for expression nodes. Used to determine required + /// parens in expression emission. + enum class Prec { + Primary, + Postfix, + Unary, + Cast, + PtrMem, + Multiplicative, + Additive, + Shift, + Spaceship, + Relational, + Equality, + And, + Xor, + Ior, + AndIf, + OrIf, + Conditional, + Assign, + Comma, + Default, + }; + private: Kind K; + Prec Precedence : 6; + // FIXME: Make these protected. public: /// Tracks if this node has a component on its right side, in which case we /// need to call printRight. - Cache RHSComponentCache; + Cache RHSComponentCache : 2; /// Track if this node is a (possibly qualified) array type. This can affect /// how we format the output string. - Cache ArrayCache; + Cache ArrayCache : 2; /// Track if this node is a (possibly qualified) function type. This can /// affect how we format the output string. - Cache FunctionCache; + Cache FunctionCache : 2; public: - Node(Kind K_, Cache RHSComponentCache_ = Cache::No, - Cache ArrayCache_ = Cache::No, Cache FunctionCache_ = Cache::No) - : K(K_), RHSComponentCache(RHSComponentCache_), ArrayCache(ArrayCache_), - FunctionCache(FunctionCache_) {} + Node(Kind K_, Prec Precedence_ = Prec::Primary, + Cache RHSComponentCache_ = Cache::No, Cache ArrayCache_ = Cache::No, + Cache FunctionCache_ = Cache::No) + : K(K_), Precedence(Precedence_), RHSComponentCache(RHSComponentCache_), + ArrayCache(ArrayCache_), FunctionCache(FunctionCache_) {} + Node(Kind K_, Cache RHSComponentCache_, Cache ArrayCache_ = Cache::No, + Cache FunctionCache_ = Cache::No) + : Node(K_, Prec::Primary, RHSComponentCache_, ArrayCache_, + FunctionCache_) {} /// Visit the most-derived object corresponding to this object. template<typename Fn> void visit(Fn F) const; @@ -155,52 +231,65 @@ public: // would construct an equivalent node. //template<typename Fn> void match(Fn F) const; - bool hasRHSComponent(OutputStream &S) const { + bool hasRHSComponent(OutputBuffer &OB) const { if (RHSComponentCache != Cache::Unknown) return RHSComponentCache == Cache::Yes; - return hasRHSComponentSlow(S); + return hasRHSComponentSlow(OB); } - bool hasArray(OutputStream &S) const { + bool hasArray(OutputBuffer &OB) const { if (ArrayCache != Cache::Unknown) return ArrayCache == Cache::Yes; - return hasArraySlow(S); + return hasArraySlow(OB); } - bool hasFunction(OutputStream &S) const { + bool hasFunction(OutputBuffer &OB) const { if (FunctionCache != Cache::Unknown) return FunctionCache == Cache::Yes; - return hasFunctionSlow(S); + return hasFunctionSlow(OB); } Kind getKind() const { return K; } - virtual bool hasRHSComponentSlow(OutputStream &) const { return false; } - virtual bool hasArraySlow(OutputStream &) const { return false; } - virtual bool hasFunctionSlow(OutputStream &) const { return false; } + Prec getPrecedence() const { return Precedence; } + + virtual bool hasRHSComponentSlow(OutputBuffer &) const { return false; } + virtual bool hasArraySlow(OutputBuffer &) const { return false; } + virtual bool hasFunctionSlow(OutputBuffer &) const { return false; } // Dig through "glue" nodes like ParameterPack and ForwardTemplateReference to // get at a node that actually represents some concrete syntax. - virtual const Node *getSyntaxNode(OutputStream &) const { - return this; - } - - void print(OutputStream &S) const { - printLeft(S); + virtual const Node *getSyntaxNode(OutputBuffer &) const { return this; } + + // Print this node as an expression operand, surrounding it in parentheses if + // its precedence is [Strictly] weaker than P. + void printAsOperand(OutputBuffer &OB, Prec P = Prec::Default, + bool StrictlyWorse = false) const { + bool Paren = + unsigned(getPrecedence()) >= unsigned(P) + unsigned(StrictlyWorse); + if (Paren) + OB.printOpen(); + print(OB); + if (Paren) + OB.printClose(); + } + + void print(OutputBuffer &OB) const { + printLeft(OB); if (RHSComponentCache != Cache::No) - printRight(S); + printRight(OB); } - // Print the "left" side of this Node into OutputStream. - virtual void printLeft(OutputStream &) const = 0; + // Print the "left" side of this Node into OutputBuffer. + virtual void printLeft(OutputBuffer &) const = 0; // Print the "right". This distinction is necessary to represent C++ types // that appear on the RHS of their subtype, such as arrays or functions. // Since most types don't have such a component, provide a default // implementation. - virtual void printRight(OutputStream &) const {} + virtual void printRight(OutputBuffer &) const {} - virtual StringView getBaseName() const { return StringView(); } + virtual std::string_view getBaseName() const { return {}; } // Silence compiler warnings, this dtor will never be called. virtual ~Node() = default; @@ -227,19 +316,19 @@ public: Node *operator[](size_t Idx) const { return Elements[Idx]; } - void printWithComma(OutputStream &S) const { + void printWithComma(OutputBuffer &OB) const { bool FirstElement = true; for (size_t Idx = 0; Idx != NumElements; ++Idx) { - size_t BeforeComma = S.getCurrentPosition(); + size_t BeforeComma = OB.getCurrentPosition(); if (!FirstElement) - S += ", "; - size_t AfterComma = S.getCurrentPosition(); - Elements[Idx]->print(S); + OB += ", "; + size_t AfterComma = OB.getCurrentPosition(); + Elements[Idx]->printAsOperand(OB, Node::Prec::Comma); // Elements[Idx] is an empty parameter pack expansion, we should erase the // comma we just printed. - if (AfterComma == S.getCurrentPosition()) { - S.setCurrentPosition(BeforeComma); + if (AfterComma == OB.getCurrentPosition()) { + OB.setCurrentPosition(BeforeComma); continue; } @@ -254,43 +343,48 @@ struct NodeArrayNode : Node { template<typename Fn> void match(Fn F) const { F(Array); } - void printLeft(OutputStream &S) const override { - Array.printWithComma(S); - } + void printLeft(OutputBuffer &OB) const override { Array.printWithComma(OB); } }; class DotSuffix final : public Node { const Node *Prefix; - const StringView Suffix; + const std::string_view Suffix; public: - DotSuffix(const Node *Prefix_, StringView Suffix_) + DotSuffix(const Node *Prefix_, std::string_view Suffix_) : Node(KDotSuffix), Prefix(Prefix_), Suffix(Suffix_) {} template<typename Fn> void match(Fn F) const { F(Prefix, Suffix); } - void printLeft(OutputStream &s) const override { - Prefix->print(s); - s += " ("; - s += Suffix; - s += ")"; + void printLeft(OutputBuffer &OB) const override { + Prefix->print(OB); + OB += " ("; + OB += Suffix; + OB += ")"; } }; class VendorExtQualType final : public Node { const Node *Ty; - StringView Ext; + std::string_view Ext; + const Node *TA; public: - VendorExtQualType(const Node *Ty_, StringView Ext_) - : Node(KVendorExtQualType), Ty(Ty_), Ext(Ext_) {} + VendorExtQualType(const Node *Ty_, std::string_view Ext_, const Node *TA_) + : Node(KVendorExtQualType), Ty(Ty_), Ext(Ext_), TA(TA_) {} + + const Node *getTy() const { return Ty; } + std::string_view getExt() const { return Ext; } + const Node *getTA() const { return TA; } - template<typename Fn> void match(Fn F) const { F(Ty, Ext); } + template <typename Fn> void match(Fn F) const { F(Ty, Ext, TA); } - void printLeft(OutputStream &S) const override { - Ty->print(S); - S += " "; - S += Ext; + void printLeft(OutputBuffer &OB) const override { + Ty->print(OB); + OB += " "; + OB += Ext; + if (TA != nullptr) + TA->print(OB); } }; @@ -316,13 +410,13 @@ protected: const Qualifiers Quals; const Node *Child; - void printQuals(OutputStream &S) const { + void printQuals(OutputBuffer &OB) const { if (Quals & QualConst) - S += " const"; + OB += " const"; if (Quals & QualVolatile) - S += " volatile"; + OB += " volatile"; if (Quals & QualRestrict) - S += " restrict"; + OB += " restrict"; } public: @@ -331,24 +425,27 @@ public: Child_->ArrayCache, Child_->FunctionCache), Quals(Quals_), Child(Child_) {} + Qualifiers getQuals() const { return Quals; } + const Node *getChild() const { return Child; } + template<typename Fn> void match(Fn F) const { F(Child, Quals); } - bool hasRHSComponentSlow(OutputStream &S) const override { - return Child->hasRHSComponent(S); + bool hasRHSComponentSlow(OutputBuffer &OB) const override { + return Child->hasRHSComponent(OB); } - bool hasArraySlow(OutputStream &S) const override { - return Child->hasArray(S); + bool hasArraySlow(OutputBuffer &OB) const override { + return Child->hasArray(OB); } - bool hasFunctionSlow(OutputStream &S) const override { - return Child->hasFunction(S); + bool hasFunctionSlow(OutputBuffer &OB) const override { + return Child->hasFunction(OB); } - void printLeft(OutputStream &S) const override { - Child->printLeft(S); - printQuals(S); + void printLeft(OutputBuffer &OB) const override { + Child->printLeft(OB); + printQuals(OB); } - void printRight(OutputStream &S) const override { Child->printRight(S); } + void printRight(OutputBuffer &OB) const override { Child->printRight(OB); } }; class ConversionOperatorType final : public Node { @@ -360,74 +457,96 @@ public: template<typename Fn> void match(Fn F) const { F(Ty); } - void printLeft(OutputStream &S) const override { - S += "operator "; - Ty->print(S); + void printLeft(OutputBuffer &OB) const override { + OB += "operator "; + Ty->print(OB); } }; class PostfixQualifiedType final : public Node { const Node *Ty; - const StringView Postfix; + const std::string_view Postfix; public: - PostfixQualifiedType(Node *Ty_, StringView Postfix_) + PostfixQualifiedType(const Node *Ty_, std::string_view Postfix_) : Node(KPostfixQualifiedType), Ty(Ty_), Postfix(Postfix_) {} template<typename Fn> void match(Fn F) const { F(Ty, Postfix); } - void printLeft(OutputStream &s) const override { - Ty->printLeft(s); - s += Postfix; + void printLeft(OutputBuffer &OB) const override { + Ty->printLeft(OB); + OB += Postfix; } }; class NameType final : public Node { - const StringView Name; + const std::string_view Name; public: - NameType(StringView Name_) : Node(KNameType), Name(Name_) {} + NameType(std::string_view Name_) : Node(KNameType), Name(Name_) {} template<typename Fn> void match(Fn F) const { F(Name); } - StringView getName() const { return Name; } - StringView getBaseName() const override { return Name; } + std::string_view getName() const { return Name; } + std::string_view getBaseName() const override { return Name; } + + void printLeft(OutputBuffer &OB) const override { OB += Name; } +}; + +class BitIntType final : public Node { + const Node *Size; + bool Signed; + +public: + BitIntType(const Node *Size_, bool Signed_) + : Node(KBitIntType), Size(Size_), Signed(Signed_) {} + + template <typename Fn> void match(Fn F) const { F(Size, Signed); } - void printLeft(OutputStream &s) const override { s += Name; } + void printLeft(OutputBuffer &OB) const override { + if (!Signed) + OB += "unsigned "; + OB += "_BitInt"; + OB.printOpen(); + Size->printAsOperand(OB); + OB.printClose(); + } }; class ElaboratedTypeSpefType : public Node { - StringView Kind; + std::string_view Kind; Node *Child; public: - ElaboratedTypeSpefType(StringView Kind_, Node *Child_) + ElaboratedTypeSpefType(std::string_view Kind_, Node *Child_) : Node(KElaboratedTypeSpefType), Kind(Kind_), Child(Child_) {} template<typename Fn> void match(Fn F) const { F(Kind, Child); } - void printLeft(OutputStream &S) const override { - S += Kind; - S += ' '; - Child->print(S); + void printLeft(OutputBuffer &OB) const override { + OB += Kind; + OB += ' '; + Child->print(OB); } }; struct AbiTagAttr : Node { Node *Base; - StringView Tag; + std::string_view Tag; - AbiTagAttr(Node* Base_, StringView Tag_) - : Node(KAbiTagAttr, Base_->RHSComponentCache, - Base_->ArrayCache, Base_->FunctionCache), + AbiTagAttr(Node *Base_, std::string_view Tag_) + : Node(KAbiTagAttr, Base_->RHSComponentCache, Base_->ArrayCache, + Base_->FunctionCache), Base(Base_), Tag(Tag_) {} template<typename Fn> void match(Fn F) const { F(Base, Tag); } - void printLeft(OutputStream &S) const override { - Base->printLeft(S); - S += "[abi:"; - S += Tag; - S += "]"; + std::string_view getBaseName() const override { return Base->getBaseName(); } + + void printLeft(OutputBuffer &OB) const override { + Base->printLeft(OB); + OB += "[abi:"; + OB += Tag; + OB += "]"; } }; @@ -439,21 +558,21 @@ public: template<typename Fn> void match(Fn F) const { F(Conditions); } - void printLeft(OutputStream &S) const override { - S += " [enable_if:"; - Conditions.printWithComma(S); - S += ']'; + void printLeft(OutputBuffer &OB) const override { + OB += " [enable_if:"; + Conditions.printWithComma(OB); + OB += ']'; } }; class ObjCProtoName : public Node { const Node *Ty; - StringView Protocol; + std::string_view Protocol; friend class PointerType; public: - ObjCProtoName(const Node *Ty_, StringView Protocol_) + ObjCProtoName(const Node *Ty_, std::string_view Protocol_) : Node(KObjCProtoName), Ty(Ty_), Protocol(Protocol_) {} template<typename Fn> void match(Fn F) const { F(Ty, Protocol); } @@ -463,11 +582,11 @@ public: static_cast<const NameType *>(Ty)->getName() == "objc_object"; } - void printLeft(OutputStream &S) const override { - Ty->print(S); - S += "<"; - S += Protocol; - S += ">"; + void printLeft(OutputBuffer &OB) const override { + Ty->print(OB); + OB += "<"; + OB += Protocol; + OB += ">"; } }; @@ -479,36 +598,38 @@ public: : Node(KPointerType, Pointee_->RHSComponentCache), Pointee(Pointee_) {} + const Node *getPointee() const { return Pointee; } + template<typename Fn> void match(Fn F) const { F(Pointee); } - bool hasRHSComponentSlow(OutputStream &S) const override { - return Pointee->hasRHSComponent(S); + bool hasRHSComponentSlow(OutputBuffer &OB) const override { + return Pointee->hasRHSComponent(OB); } - void printLeft(OutputStream &s) const override { + void printLeft(OutputBuffer &OB) const override { // We rewrite objc_object<SomeProtocol>* into id<SomeProtocol>. if (Pointee->getKind() != KObjCProtoName || !static_cast<const ObjCProtoName *>(Pointee)->isObjCObject()) { - Pointee->printLeft(s); - if (Pointee->hasArray(s)) - s += " "; - if (Pointee->hasArray(s) || Pointee->hasFunction(s)) - s += "("; - s += "*"; + Pointee->printLeft(OB); + if (Pointee->hasArray(OB)) + OB += " "; + if (Pointee->hasArray(OB) || Pointee->hasFunction(OB)) + OB += "("; + OB += "*"; } else { const auto *objcProto = static_cast<const ObjCProtoName *>(Pointee); - s += "id<"; - s += objcProto->Protocol; - s += ">"; + OB += "id<"; + OB += objcProto->Protocol; + OB += ">"; } } - void printRight(OutputStream &s) const override { + void printRight(OutputBuffer &OB) const override { if (Pointee->getKind() != KObjCProtoName || !static_cast<const ObjCProtoName *>(Pointee)->isObjCObject()) { - if (Pointee->hasArray(s) || Pointee->hasFunction(s)) - s += ")"; - Pointee->printRight(s); + if (Pointee->hasArray(OB) || Pointee->hasFunction(OB)) + OB += ")"; + Pointee->printRight(OB); } } }; @@ -528,15 +649,30 @@ class ReferenceType : public Node { // Dig through any refs to refs, collapsing the ReferenceTypes as we go. The // rule here is rvalue ref to rvalue ref collapses to a rvalue ref, and any // other combination collapses to a lvalue ref. - std::pair<ReferenceKind, const Node *> collapse(OutputStream &S) const { + // + // A combination of a TemplateForwardReference and a back-ref Substitution + // from an ill-formed string may have created a cycle; use cycle detection to + // avoid looping forever. + std::pair<ReferenceKind, const Node *> collapse(OutputBuffer &OB) const { auto SoFar = std::make_pair(RK, Pointee); + // Track the chain of nodes for the Floyd's 'tortoise and hare' + // cycle-detection algorithm, since getSyntaxNode(S) is impure + PODSmallVector<const Node *, 8> Prev; for (;;) { - const Node *SN = SoFar.second->getSyntaxNode(S); + const Node *SN = SoFar.second->getSyntaxNode(OB); if (SN->getKind() != KReferenceType) break; auto *RT = static_cast<const ReferenceType *>(SN); SoFar.second = RT->Pointee; SoFar.first = std::min(SoFar.first, RT->RK); + + // The middle of Prev is the 'slow' pointer moving at half speed + Prev.push_back(SoFar.second); + if (Prev.size() > 1 && SoFar.second == Prev[(Prev.size() - 1) / 2]) { + // Cycle detected + SoFar.second = nullptr; + break; + } } return SoFar; } @@ -548,31 +684,35 @@ public: template<typename Fn> void match(Fn F) const { F(Pointee, RK); } - bool hasRHSComponentSlow(OutputStream &S) const override { - return Pointee->hasRHSComponent(S); + bool hasRHSComponentSlow(OutputBuffer &OB) const override { + return Pointee->hasRHSComponent(OB); } - void printLeft(OutputStream &s) const override { + void printLeft(OutputBuffer &OB) const override { if (Printing) return; - SwapAndRestore<bool> SavePrinting(Printing, true); - std::pair<ReferenceKind, const Node *> Collapsed = collapse(s); - Collapsed.second->printLeft(s); - if (Collapsed.second->hasArray(s)) - s += " "; - if (Collapsed.second->hasArray(s) || Collapsed.second->hasFunction(s)) - s += "("; + ScopedOverride<bool> SavePrinting(Printing, true); + std::pair<ReferenceKind, const Node *> Collapsed = collapse(OB); + if (!Collapsed.second) + return; + Collapsed.second->printLeft(OB); + if (Collapsed.second->hasArray(OB)) + OB += " "; + if (Collapsed.second->hasArray(OB) || Collapsed.second->hasFunction(OB)) + OB += "("; - s += (Collapsed.first == ReferenceKind::LValue ? "&" : "&&"); + OB += (Collapsed.first == ReferenceKind::LValue ? "&" : "&&"); } - void printRight(OutputStream &s) const override { + void printRight(OutputBuffer &OB) const override { if (Printing) return; - SwapAndRestore<bool> SavePrinting(Printing, true); - std::pair<ReferenceKind, const Node *> Collapsed = collapse(s); - if (Collapsed.second->hasArray(s) || Collapsed.second->hasFunction(s)) - s += ")"; - Collapsed.second->printRight(s); + ScopedOverride<bool> SavePrinting(Printing, true); + std::pair<ReferenceKind, const Node *> Collapsed = collapse(OB); + if (!Collapsed.second) + return; + if (Collapsed.second->hasArray(OB) || Collapsed.second->hasFunction(OB)) + OB += ")"; + Collapsed.second->printRight(OB); } }; @@ -587,69 +727,33 @@ public: template<typename Fn> void match(Fn F) const { F(ClassType, MemberType); } - bool hasRHSComponentSlow(OutputStream &S) const override { - return MemberType->hasRHSComponent(S); + bool hasRHSComponentSlow(OutputBuffer &OB) const override { + return MemberType->hasRHSComponent(OB); } - void printLeft(OutputStream &s) const override { - MemberType->printLeft(s); - if (MemberType->hasArray(s) || MemberType->hasFunction(s)) - s += "("; + void printLeft(OutputBuffer &OB) const override { + MemberType->printLeft(OB); + if (MemberType->hasArray(OB) || MemberType->hasFunction(OB)) + OB += "("; else - s += " "; - ClassType->print(s); - s += "::*"; - } - - void printRight(OutputStream &s) const override { - if (MemberType->hasArray(s) || MemberType->hasFunction(s)) - s += ")"; - MemberType->printRight(s); - } -}; - -class NodeOrString { - const void *First; - const void *Second; - -public: - /* implicit */ NodeOrString(StringView Str) { - const char *FirstChar = Str.begin(); - const char *SecondChar = Str.end(); - if (SecondChar == nullptr) { - assert(FirstChar == SecondChar); - ++FirstChar, ++SecondChar; - } - First = static_cast<const void *>(FirstChar); - Second = static_cast<const void *>(SecondChar); + OB += " "; + ClassType->print(OB); + OB += "::*"; } - /* implicit */ NodeOrString(Node *N) - : First(static_cast<const void *>(N)), Second(nullptr) {} - NodeOrString() : First(nullptr), Second(nullptr) {} - - bool isString() const { return Second && First; } - bool isNode() const { return First && !Second; } - bool isEmpty() const { return !First && !Second; } - - StringView asString() const { - assert(isString()); - return StringView(static_cast<const char *>(First), - static_cast<const char *>(Second)); - } - - const Node *asNode() const { - assert(isNode()); - return static_cast<const Node *>(First); + void printRight(OutputBuffer &OB) const override { + if (MemberType->hasArray(OB) || MemberType->hasFunction(OB)) + OB += ")"; + MemberType->printRight(OB); } }; class ArrayType final : public Node { const Node *Base; - NodeOrString Dimension; + Node *Dimension; public: - ArrayType(const Node *Base_, NodeOrString Dimension_) + ArrayType(const Node *Base_, Node *Dimension_) : Node(KArrayType, /*RHSComponentCache=*/Cache::Yes, /*ArrayCache=*/Cache::Yes), @@ -657,21 +761,19 @@ public: template<typename Fn> void match(Fn F) const { F(Base, Dimension); } - bool hasRHSComponentSlow(OutputStream &) const override { return true; } - bool hasArraySlow(OutputStream &) const override { return true; } + bool hasRHSComponentSlow(OutputBuffer &) const override { return true; } + bool hasArraySlow(OutputBuffer &) const override { return true; } - void printLeft(OutputStream &S) const override { Base->printLeft(S); } + void printLeft(OutputBuffer &OB) const override { Base->printLeft(OB); } - void printRight(OutputStream &S) const override { - if (S.back() != ']') - S += " "; - S += "["; - if (Dimension.isString()) - S += Dimension.asString(); - else if (Dimension.isNode()) - Dimension.asNode()->print(S); - S += "]"; - Base->printRight(S); + void printRight(OutputBuffer &OB) const override { + if (OB.back() != ']') + OB += " "; + OB += "["; + if (Dimension) + Dimension->print(OB); + OB += "]"; + Base->printRight(OB); } }; @@ -695,8 +797,8 @@ public: F(Ret, Params, CVQuals, RefQual, ExceptionSpec); } - bool hasRHSComponentSlow(OutputStream &) const override { return true; } - bool hasFunctionSlow(OutputStream &) const override { return true; } + bool hasRHSComponentSlow(OutputBuffer &) const override { return true; } + bool hasFunctionSlow(OutputBuffer &) const override { return true; } // Handle C++'s ... quirky decl grammar by using the left & right // distinction. Consider: @@ -705,32 +807,32 @@ public: // that takes a char and returns an int. If we're trying to print f, start // by printing out the return types's left, then print our parameters, then // finally print right of the return type. - void printLeft(OutputStream &S) const override { - Ret->printLeft(S); - S += " "; + void printLeft(OutputBuffer &OB) const override { + Ret->printLeft(OB); + OB += " "; } - void printRight(OutputStream &S) const override { - S += "("; - Params.printWithComma(S); - S += ")"; - Ret->printRight(S); + void printRight(OutputBuffer &OB) const override { + OB.printOpen(); + Params.printWithComma(OB); + OB.printClose(); + Ret->printRight(OB); if (CVQuals & QualConst) - S += " const"; + OB += " const"; if (CVQuals & QualVolatile) - S += " volatile"; + OB += " volatile"; if (CVQuals & QualRestrict) - S += " restrict"; + OB += " restrict"; if (RefQual == FrefQualLValue) - S += " &"; + OB += " &"; else if (RefQual == FrefQualRValue) - S += " &&"; + OB += " &&"; if (ExceptionSpec != nullptr) { - S += ' '; - ExceptionSpec->print(S); + OB += ' '; + ExceptionSpec->print(OB); } } }; @@ -742,10 +844,11 @@ public: template<typename Fn> void match(Fn F) const { F(E); } - void printLeft(OutputStream &S) const override { - S += "noexcept("; - E->print(S); - S += ")"; + void printLeft(OutputBuffer &OB) const override { + OB += "noexcept"; + OB.printOpen(); + E->printAsOperand(OB); + OB.printClose(); } }; @@ -757,10 +860,11 @@ public: template<typename Fn> void match(Fn F) const { F(Types); } - void printLeft(OutputStream &S) const override { - S += "throw("; - Types.printWithComma(S); - S += ')'; + void printLeft(OutputBuffer &OB) const override { + OB += "throw"; + OB.printOpen(); + Types.printWithComma(OB); + OB.printClose(); } }; @@ -791,41 +895,41 @@ public: NodeArray getParams() const { return Params; } const Node *getReturnType() const { return Ret; } - bool hasRHSComponentSlow(OutputStream &) const override { return true; } - bool hasFunctionSlow(OutputStream &) const override { return true; } + bool hasRHSComponentSlow(OutputBuffer &) const override { return true; } + bool hasFunctionSlow(OutputBuffer &) const override { return true; } const Node *getName() const { return Name; } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (Ret) { - Ret->printLeft(S); - if (!Ret->hasRHSComponent(S)) - S += " "; + Ret->printLeft(OB); + if (!Ret->hasRHSComponent(OB)) + OB += " "; } - Name->print(S); + Name->print(OB); } - void printRight(OutputStream &S) const override { - S += "("; - Params.printWithComma(S); - S += ")"; + void printRight(OutputBuffer &OB) const override { + OB.printOpen(); + Params.printWithComma(OB); + OB.printClose(); if (Ret) - Ret->printRight(S); + Ret->printRight(OB); if (CVQuals & QualConst) - S += " const"; + OB += " const"; if (CVQuals & QualVolatile) - S += " volatile"; + OB += " volatile"; if (CVQuals & QualRestrict) - S += " restrict"; + OB += " restrict"; if (RefQual == FrefQualLValue) - S += " &"; + OB += " &"; else if (RefQual == FrefQualRValue) - S += " &&"; + OB += " &&"; if (Attrs != nullptr) - Attrs->print(S); + Attrs->print(OB); } }; @@ -838,25 +942,25 @@ public: template<typename Fn> void match(Fn F) const { F(OpName); } - void printLeft(OutputStream &S) const override { - S += "operator\"\" "; - OpName->print(S); + void printLeft(OutputBuffer &OB) const override { + OB += "operator\"\" "; + OpName->print(OB); } }; class SpecialName final : public Node { - const StringView Special; + const std::string_view Special; const Node *Child; public: - SpecialName(StringView Special_, const Node *Child_) + SpecialName(std::string_view Special_, const Node *Child_) : Node(KSpecialName), Special(Special_), Child(Child_) {} template<typename Fn> void match(Fn F) const { F(Special, Child); } - void printLeft(OutputStream &S) const override { - S += Special; - Child->print(S); + void printLeft(OutputBuffer &OB) const override { + OB += Special; + Child->print(OB); } }; @@ -871,11 +975,11 @@ public: template<typename Fn> void match(Fn F) const { F(FirstType, SecondType); } - void printLeft(OutputStream &S) const override { - S += "construction vtable for "; - FirstType->print(S); - S += "-in-"; - SecondType->print(S); + void printLeft(OutputBuffer &OB) const override { + OB += "construction vtable for "; + FirstType->print(OB); + OB += "-in-"; + SecondType->print(OB); } }; @@ -888,12 +992,52 @@ struct NestedName : Node { template<typename Fn> void match(Fn F) const { F(Qual, Name); } - StringView getBaseName() const override { return Name->getBaseName(); } + std::string_view getBaseName() const override { return Name->getBaseName(); } + + void printLeft(OutputBuffer &OB) const override { + Qual->print(OB); + OB += "::"; + Name->print(OB); + } +}; + +struct ModuleName : Node { + ModuleName *Parent; + Node *Name; + bool IsPartition; + + ModuleName(ModuleName *Parent_, Node *Name_, bool IsPartition_ = false) + : Node(KModuleName), Parent(Parent_), Name(Name_), + IsPartition(IsPartition_) {} + + template <typename Fn> void match(Fn F) const { + F(Parent, Name, IsPartition); + } + + void printLeft(OutputBuffer &OB) const override { + if (Parent) + Parent->print(OB); + if (Parent || IsPartition) + OB += IsPartition ? ':' : '.'; + Name->print(OB); + } +}; + +struct ModuleEntity : Node { + ModuleName *Module; + Node *Name; + + ModuleEntity(ModuleName *Module_, Node *Name_) + : Node(KModuleEntity), Module(Module_), Name(Name_) {} + + template <typename Fn> void match(Fn F) const { F(Module, Name); } + + std::string_view getBaseName() const override { return Name->getBaseName(); } - void printLeft(OutputStream &S) const override { - Qual->print(S); - S += "::"; - Name->print(S); + void printLeft(OutputBuffer &OB) const override { + Name->print(OB); + OB += '@'; + Module->print(OB); } }; @@ -906,10 +1050,10 @@ struct LocalName : Node { template<typename Fn> void match(Fn F) const { F(Encoding, Entity); } - void printLeft(OutputStream &S) const override { - Encoding->print(S); - S += "::"; - Entity->print(S); + void printLeft(OutputBuffer &OB) const override { + Encoding->print(OB); + OB += "::"; + Entity->print(OB); } }; @@ -924,51 +1068,66 @@ public: template<typename Fn> void match(Fn F) const { F(Qualifier, Name); } - StringView getBaseName() const override { return Name->getBaseName(); } + std::string_view getBaseName() const override { return Name->getBaseName(); } - void printLeft(OutputStream &S) const override { - Qualifier->print(S); - S += "::"; - Name->print(S); + void printLeft(OutputBuffer &OB) const override { + Qualifier->print(OB); + OB += "::"; + Name->print(OB); } }; class VectorType final : public Node { const Node *BaseType; - const NodeOrString Dimension; + const Node *Dimension; public: - VectorType(const Node *BaseType_, NodeOrString Dimension_) - : Node(KVectorType), BaseType(BaseType_), - Dimension(Dimension_) {} + VectorType(const Node *BaseType_, const Node *Dimension_) + : Node(KVectorType), BaseType(BaseType_), Dimension(Dimension_) {} + + const Node *getBaseType() const { return BaseType; } + const Node *getDimension() const { return Dimension; } template<typename Fn> void match(Fn F) const { F(BaseType, Dimension); } - void printLeft(OutputStream &S) const override { - BaseType->print(S); - S += " vector["; - if (Dimension.isNode()) - Dimension.asNode()->print(S); - else if (Dimension.isString()) - S += Dimension.asString(); - S += "]"; + void printLeft(OutputBuffer &OB) const override { + BaseType->print(OB); + OB += " vector["; + if (Dimension) + Dimension->print(OB); + OB += "]"; } }; class PixelVectorType final : public Node { - const NodeOrString Dimension; + const Node *Dimension; public: - PixelVectorType(NodeOrString Dimension_) + PixelVectorType(const Node *Dimension_) : Node(KPixelVectorType), Dimension(Dimension_) {} template<typename Fn> void match(Fn F) const { F(Dimension); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { // FIXME: This should demangle as "vector pixel". - S += "pixel vector["; - S += Dimension.asString(); - S += "]"; + OB += "pixel vector["; + Dimension->print(OB); + OB += "]"; + } +}; + +class BinaryFPType final : public Node { + const Node *Dimension; + +public: + BinaryFPType(const Node *Dimension_) + : Node(KBinaryFPType), Dimension(Dimension_) {} + + template<typename Fn> void match(Fn F) const { F(Dimension); } + + void printLeft(OutputBuffer &OB) const override { + OB += "_Float"; + Dimension->print(OB); } }; @@ -990,20 +1149,20 @@ public: template<typename Fn> void match(Fn F) const { F(Kind, Index); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { switch (Kind) { case TemplateParamKind::Type: - S += "$T"; + OB += "$T"; break; case TemplateParamKind::NonType: - S += "$N"; + OB += "$N"; break; case TemplateParamKind::Template: - S += "$TT"; + OB += "$TT"; break; } if (Index > 0) - S << Index - 1; + OB << Index - 1; } }; @@ -1017,13 +1176,9 @@ public: template<typename Fn> void match(Fn F) const { F(Name); } - void printLeft(OutputStream &S) const override { - S += "typename "; - } + void printLeft(OutputBuffer &OB) const override { OB += "typename "; } - void printRight(OutputStream &S) const override { - Name->print(S); - } + void printRight(OutputBuffer &OB) const override { Name->print(OB); } }; /// A non-type template parameter declaration, 'int N'. @@ -1037,15 +1192,15 @@ public: template<typename Fn> void match(Fn F) const { F(Name, Type); } - void printLeft(OutputStream &S) const override { - Type->printLeft(S); - if (!Type->hasRHSComponent(S)) - S += " "; + void printLeft(OutputBuffer &OB) const override { + Type->printLeft(OB); + if (!Type->hasRHSComponent(OB)) + OB += " "; } - void printRight(OutputStream &S) const override { - Name->print(S); - Type->printRight(S); + void printRight(OutputBuffer &OB) const override { + Name->print(OB); + Type->printRight(OB); } }; @@ -1062,15 +1217,14 @@ public: template<typename Fn> void match(Fn F) const { F(Name, Params); } - void printLeft(OutputStream &S) const override { - S += "template<"; - Params.printWithComma(S); - S += "> typename "; + void printLeft(OutputBuffer &OB) const override { + ScopedOverride<unsigned> LT(OB.GtIsGt, 0); + OB += "template<"; + Params.printWithComma(OB); + OB += "> typename "; } - void printRight(OutputStream &S) const override { - Name->print(S); - } + void printRight(OutputBuffer &OB) const override { Name->print(OB); } }; /// A template parameter pack declaration, 'typename ...T'. @@ -1083,14 +1237,12 @@ public: template<typename Fn> void match(Fn F) const { F(Param); } - void printLeft(OutputStream &S) const override { - Param->printLeft(S); - S += "..."; + void printLeft(OutputBuffer &OB) const override { + Param->printLeft(OB); + OB += "..."; } - void printRight(OutputStream &S) const override { - Param->printRight(S); - } + void printRight(OutputBuffer &OB) const override { Param->printRight(OB); } }; /// An unexpanded parameter pack (either in the expression or type context). If @@ -1104,11 +1256,12 @@ public: class ParameterPack final : public Node { NodeArray Data; - // Setup OutputStream for a pack expansion unless we're already expanding one. - void initializePackExpansion(OutputStream &S) const { - if (S.CurrentPackMax == std::numeric_limits<unsigned>::max()) { - S.CurrentPackMax = static_cast<unsigned>(Data.size()); - S.CurrentPackIndex = 0; + // Setup OutputBuffer for a pack expansion, unless we're already expanding + // one. + void initializePackExpansion(OutputBuffer &OB) const { + if (OB.CurrentPackMax == std::numeric_limits<unsigned>::max()) { + OB.CurrentPackMax = static_cast<unsigned>(Data.size()); + OB.CurrentPackIndex = 0; } } @@ -1131,38 +1284,38 @@ public: template<typename Fn> void match(Fn F) const { F(Data); } - bool hasRHSComponentSlow(OutputStream &S) const override { - initializePackExpansion(S); - size_t Idx = S.CurrentPackIndex; - return Idx < Data.size() && Data[Idx]->hasRHSComponent(S); + bool hasRHSComponentSlow(OutputBuffer &OB) const override { + initializePackExpansion(OB); + size_t Idx = OB.CurrentPackIndex; + return Idx < Data.size() && Data[Idx]->hasRHSComponent(OB); } - bool hasArraySlow(OutputStream &S) const override { - initializePackExpansion(S); - size_t Idx = S.CurrentPackIndex; - return Idx < Data.size() && Data[Idx]->hasArray(S); + bool hasArraySlow(OutputBuffer &OB) const override { + initializePackExpansion(OB); + size_t Idx = OB.CurrentPackIndex; + return Idx < Data.size() && Data[Idx]->hasArray(OB); } - bool hasFunctionSlow(OutputStream &S) const override { - initializePackExpansion(S); - size_t Idx = S.CurrentPackIndex; - return Idx < Data.size() && Data[Idx]->hasFunction(S); + bool hasFunctionSlow(OutputBuffer &OB) const override { + initializePackExpansion(OB); + size_t Idx = OB.CurrentPackIndex; + return Idx < Data.size() && Data[Idx]->hasFunction(OB); } - const Node *getSyntaxNode(OutputStream &S) const override { - initializePackExpansion(S); - size_t Idx = S.CurrentPackIndex; - return Idx < Data.size() ? Data[Idx]->getSyntaxNode(S) : this; + const Node *getSyntaxNode(OutputBuffer &OB) const override { + initializePackExpansion(OB); + size_t Idx = OB.CurrentPackIndex; + return Idx < Data.size() ? Data[Idx]->getSyntaxNode(OB) : this; } - void printLeft(OutputStream &S) const override { - initializePackExpansion(S); - size_t Idx = S.CurrentPackIndex; + void printLeft(OutputBuffer &OB) const override { + initializePackExpansion(OB); + size_t Idx = OB.CurrentPackIndex; if (Idx < Data.size()) - Data[Idx]->printLeft(S); + Data[Idx]->printLeft(OB); } - void printRight(OutputStream &S) const override { - initializePackExpansion(S); - size_t Idx = S.CurrentPackIndex; + void printRight(OutputBuffer &OB) const override { + initializePackExpansion(OB); + size_t Idx = OB.CurrentPackIndex; if (Idx < Data.size()) - Data[Idx]->printRight(S); + Data[Idx]->printRight(OB); } }; @@ -1181,8 +1334,8 @@ public: NodeArray getElements() const { return Elements; } - void printLeft(OutputStream &S) const override { - Elements.printWithComma(S); + void printLeft(OutputBuffer &OB) const override { + Elements.printWithComma(OB); } }; @@ -1199,35 +1352,35 @@ public: const Node *getChild() const { return Child; } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { constexpr unsigned Max = std::numeric_limits<unsigned>::max(); - SwapAndRestore<unsigned> SavePackIdx(S.CurrentPackIndex, Max); - SwapAndRestore<unsigned> SavePackMax(S.CurrentPackMax, Max); - size_t StreamPos = S.getCurrentPosition(); + ScopedOverride<unsigned> SavePackIdx(OB.CurrentPackIndex, Max); + ScopedOverride<unsigned> SavePackMax(OB.CurrentPackMax, Max); + size_t StreamPos = OB.getCurrentPosition(); // Print the first element in the pack. If Child contains a ParameterPack, // it will set up S.CurrentPackMax and print the first element. - Child->print(S); + Child->print(OB); // No ParameterPack was found in Child. This can occur if we've found a pack // expansion on a <function-param>. - if (S.CurrentPackMax == Max) { - S += "..."; + if (OB.CurrentPackMax == Max) { + OB += "..."; return; } // We found a ParameterPack, but it has no elements. Erase whatever we may // of printed. - if (S.CurrentPackMax == 0) { - S.setCurrentPosition(StreamPos); + if (OB.CurrentPackMax == 0) { + OB.setCurrentPosition(StreamPos); return; } // Else, iterate through the rest of the elements in the pack. - for (unsigned I = 1, E = S.CurrentPackMax; I < E; ++I) { - S += ", "; - S.CurrentPackIndex = I; - Child->print(S); + for (unsigned I = 1, E = OB.CurrentPackMax; I < E; ++I) { + OB += ", "; + OB.CurrentPackIndex = I; + Child->print(OB); } } }; @@ -1242,12 +1395,11 @@ public: NodeArray getParams() { return Params; } - void printLeft(OutputStream &S) const override { - S += "<"; - Params.printWithComma(S); - if (S.back() == '>') - S += " "; - S += ">"; + void printLeft(OutputBuffer &OB) const override { + ScopedOverride<unsigned> LT(OB.GtIsGt, 0); + OB += "<"; + Params.printWithComma(OB); + OB += ">"; } }; @@ -1289,42 +1441,42 @@ struct ForwardTemplateReference : Node { // special handling. template<typename Fn> void match(Fn F) const = delete; - bool hasRHSComponentSlow(OutputStream &S) const override { + bool hasRHSComponentSlow(OutputBuffer &OB) const override { if (Printing) return false; - SwapAndRestore<bool> SavePrinting(Printing, true); - return Ref->hasRHSComponent(S); + ScopedOverride<bool> SavePrinting(Printing, true); + return Ref->hasRHSComponent(OB); } - bool hasArraySlow(OutputStream &S) const override { + bool hasArraySlow(OutputBuffer &OB) const override { if (Printing) return false; - SwapAndRestore<bool> SavePrinting(Printing, true); - return Ref->hasArray(S); + ScopedOverride<bool> SavePrinting(Printing, true); + return Ref->hasArray(OB); } - bool hasFunctionSlow(OutputStream &S) const override { + bool hasFunctionSlow(OutputBuffer &OB) const override { if (Printing) return false; - SwapAndRestore<bool> SavePrinting(Printing, true); - return Ref->hasFunction(S); + ScopedOverride<bool> SavePrinting(Printing, true); + return Ref->hasFunction(OB); } - const Node *getSyntaxNode(OutputStream &S) const override { + const Node *getSyntaxNode(OutputBuffer &OB) const override { if (Printing) return this; - SwapAndRestore<bool> SavePrinting(Printing, true); - return Ref->getSyntaxNode(S); + ScopedOverride<bool> SavePrinting(Printing, true); + return Ref->getSyntaxNode(OB); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (Printing) return; - SwapAndRestore<bool> SavePrinting(Printing, true); - Ref->printLeft(S); + ScopedOverride<bool> SavePrinting(Printing, true); + Ref->printLeft(OB); } - void printRight(OutputStream &S) const override { + void printRight(OutputBuffer &OB) const override { if (Printing) return; - SwapAndRestore<bool> SavePrinting(Printing, true); - Ref->printRight(S); + ScopedOverride<bool> SavePrinting(Printing, true); + Ref->printRight(OB); } }; @@ -1338,11 +1490,11 @@ struct NameWithTemplateArgs : Node { template<typename Fn> void match(Fn F) const { F(Name, TemplateArgs); } - StringView getBaseName() const override { return Name->getBaseName(); } + std::string_view getBaseName() const override { return Name->getBaseName(); } - void printLeft(OutputStream &S) const override { - Name->print(S); - TemplateArgs->print(S); + void printLeft(OutputBuffer &OB) const override { + Name->print(OB); + TemplateArgs->print(OB); } }; @@ -1355,26 +1507,11 @@ public: template<typename Fn> void match(Fn F) const { F(Child); } - StringView getBaseName() const override { return Child->getBaseName(); } + std::string_view getBaseName() const override { return Child->getBaseName(); } - void printLeft(OutputStream &S) const override { - S += "::"; - Child->print(S); - } -}; - -struct StdQualifiedName : Node { - Node *Child; - - StdQualifiedName(Node *Child_) : Node(KStdQualifiedName), Child(Child_) {} - - template<typename Fn> void match(Fn F) const { F(Child); } - - StringView getBaseName() const override { return Child->getBaseName(); } - - void printLeft(OutputStream &S) const override { - S += "std::"; - Child->print(S); + void printLeft(OutputBuffer &OB) const override { + OB += "::"; + Child->print(OB); } }; @@ -1387,109 +1524,81 @@ enum class SpecialSubKind { iostream, }; -class ExpandedSpecialSubstitution final : public Node { +class SpecialSubstitution; +class ExpandedSpecialSubstitution : public Node { +protected: SpecialSubKind SSK; + ExpandedSpecialSubstitution(SpecialSubKind SSK_, Kind K_) + : Node(K_), SSK(SSK_) {} public: ExpandedSpecialSubstitution(SpecialSubKind SSK_) - : Node(KExpandedSpecialSubstitution), SSK(SSK_) {} + : ExpandedSpecialSubstitution(SSK_, KExpandedSpecialSubstitution) {} + inline ExpandedSpecialSubstitution(SpecialSubstitution const *); template<typename Fn> void match(Fn F) const { F(SSK); } - StringView getBaseName() const override { +protected: + bool isInstantiation() const { + return unsigned(SSK) >= unsigned(SpecialSubKind::string); + } + + std::string_view getBaseName() const override { switch (SSK) { case SpecialSubKind::allocator: - return StringView("allocator"); + return {"allocator"}; case SpecialSubKind::basic_string: - return StringView("basic_string"); + return {"basic_string"}; case SpecialSubKind::string: - return StringView("basic_string"); + return {"basic_string"}; case SpecialSubKind::istream: - return StringView("basic_istream"); + return {"basic_istream"}; case SpecialSubKind::ostream: - return StringView("basic_ostream"); + return {"basic_ostream"}; case SpecialSubKind::iostream: - return StringView("basic_iostream"); + return {"basic_iostream"}; } DEMANGLE_UNREACHABLE; } - void printLeft(OutputStream &S) const override { - switch (SSK) { - case SpecialSubKind::allocator: - S += "std::allocator"; - break; - case SpecialSubKind::basic_string: - S += "std::basic_string"; - break; - case SpecialSubKind::string: - S += "std::basic_string<char, std::char_traits<char>, " - "std::allocator<char> >"; - break; - case SpecialSubKind::istream: - S += "std::basic_istream<char, std::char_traits<char> >"; - break; - case SpecialSubKind::ostream: - S += "std::basic_ostream<char, std::char_traits<char> >"; - break; - case SpecialSubKind::iostream: - S += "std::basic_iostream<char, std::char_traits<char> >"; - break; +private: + void printLeft(OutputBuffer &OB) const override { + OB << "std::" << getBaseName(); + if (isInstantiation()) { + OB << "<char, std::char_traits<char>"; + if (SSK == SpecialSubKind::string) + OB << ", std::allocator<char>"; + OB << ">"; } } }; -class SpecialSubstitution final : public Node { +class SpecialSubstitution final : public ExpandedSpecialSubstitution { public: - SpecialSubKind SSK; - SpecialSubstitution(SpecialSubKind SSK_) - : Node(KSpecialSubstitution), SSK(SSK_) {} + : ExpandedSpecialSubstitution(SSK_, KSpecialSubstitution) {} template<typename Fn> void match(Fn F) const { F(SSK); } - StringView getBaseName() const override { - switch (SSK) { - case SpecialSubKind::allocator: - return StringView("allocator"); - case SpecialSubKind::basic_string: - return StringView("basic_string"); - case SpecialSubKind::string: - return StringView("string"); - case SpecialSubKind::istream: - return StringView("istream"); - case SpecialSubKind::ostream: - return StringView("ostream"); - case SpecialSubKind::iostream: - return StringView("iostream"); + std::string_view getBaseName() const override { + std::string_view SV = ExpandedSpecialSubstitution::getBaseName(); + if (isInstantiation()) { + // The instantiations are typedefs that drop the "basic_" prefix. + assert(llvm::itanium_demangle::starts_with(SV, "basic_")); + SV.remove_prefix(sizeof("basic_") - 1); } - DEMANGLE_UNREACHABLE; + return SV; } - void printLeft(OutputStream &S) const override { - switch (SSK) { - case SpecialSubKind::allocator: - S += "std::allocator"; - break; - case SpecialSubKind::basic_string: - S += "std::basic_string"; - break; - case SpecialSubKind::string: - S += "std::string"; - break; - case SpecialSubKind::istream: - S += "std::istream"; - break; - case SpecialSubKind::ostream: - S += "std::ostream"; - break; - case SpecialSubKind::iostream: - S += "std::iostream"; - break; - } + void printLeft(OutputBuffer &OB) const override { + OB << "std::" << getBaseName(); } }; +inline ExpandedSpecialSubstitution::ExpandedSpecialSubstitution( + SpecialSubstitution const *SS) + : ExpandedSpecialSubstitution(SS->SSK) {} + class CtorDtorName final : public Node { const Node *Basename; const bool IsDtor; @@ -1502,10 +1611,10 @@ public: template<typename Fn> void match(Fn F) const { F(Basename, IsDtor, Variant); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (IsDtor) - S += "~"; - S += Basename->getBaseName(); + OB += "~"; + OB += Basename->getBaseName(); } }; @@ -1517,35 +1626,36 @@ public: template<typename Fn> void match(Fn F) const { F(Base); } - void printLeft(OutputStream &S) const override { - S += "~"; - Base->printLeft(S); + void printLeft(OutputBuffer &OB) const override { + OB += "~"; + Base->printLeft(OB); } }; class UnnamedTypeName : public Node { - const StringView Count; + const std::string_view Count; public: - UnnamedTypeName(StringView Count_) : Node(KUnnamedTypeName), Count(Count_) {} + UnnamedTypeName(std::string_view Count_) + : Node(KUnnamedTypeName), Count(Count_) {} template<typename Fn> void match(Fn F) const { F(Count); } - void printLeft(OutputStream &S) const override { - S += "'unnamed"; - S += Count; - S += "\'"; + void printLeft(OutputBuffer &OB) const override { + OB += "'unnamed"; + OB += Count; + OB += "\'"; } }; class ClosureTypeName : public Node { NodeArray TemplateParams; NodeArray Params; - StringView Count; + std::string_view Count; public: ClosureTypeName(NodeArray TemplateParams_, NodeArray Params_, - StringView Count_) + std::string_view Count_) : Node(KClosureTypeName), TemplateParams(TemplateParams_), Params(Params_), Count(Count_) {} @@ -1553,22 +1663,23 @@ public: F(TemplateParams, Params, Count); } - void printDeclarator(OutputStream &S) const { + void printDeclarator(OutputBuffer &OB) const { if (!TemplateParams.empty()) { - S += "<"; - TemplateParams.printWithComma(S); - S += ">"; + ScopedOverride<unsigned> LT(OB.GtIsGt, 0); + OB += "<"; + TemplateParams.printWithComma(OB); + OB += ">"; } - S += "("; - Params.printWithComma(S); - S += ")"; + OB.printOpen(); + Params.printWithComma(OB); + OB.printClose(); } - void printLeft(OutputStream &S) const override { - S += "\'lambda"; - S += Count; - S += "\'"; - printDeclarator(S); + void printLeft(OutputBuffer &OB) const override { + OB += "\'lambda"; + OB += Count; + OB += "\'"; + printDeclarator(OB); } }; @@ -1580,10 +1691,10 @@ public: template<typename Fn> void match(Fn F) const { F(Bindings); } - void printLeft(OutputStream &S) const override { - S += '['; - Bindings.printWithComma(S); - S += ']'; + void printLeft(OutputBuffer &OB) const override { + OB.printOpen('['); + Bindings.printWithComma(OB); + OB.printClose(']'); } }; @@ -1591,32 +1702,35 @@ public: class BinaryExpr : public Node { const Node *LHS; - const StringView InfixOperator; + const std::string_view InfixOperator; const Node *RHS; public: - BinaryExpr(const Node *LHS_, StringView InfixOperator_, const Node *RHS_) - : Node(KBinaryExpr), LHS(LHS_), InfixOperator(InfixOperator_), RHS(RHS_) { - } - - template<typename Fn> void match(Fn F) const { F(LHS, InfixOperator, RHS); } - - void printLeft(OutputStream &S) const override { - // might be a template argument expression, then we need to disambiguate - // with parens. - if (InfixOperator == ">") - S += "("; - - S += "("; - LHS->print(S); - S += ") "; - S += InfixOperator; - S += " ("; - RHS->print(S); - S += ")"; - - if (InfixOperator == ">") - S += ")"; + BinaryExpr(const Node *LHS_, std::string_view InfixOperator_, + const Node *RHS_, Prec Prec_) + : Node(KBinaryExpr, Prec_), LHS(LHS_), InfixOperator(InfixOperator_), + RHS(RHS_) {} + + template <typename Fn> void match(Fn F) const { + F(LHS, InfixOperator, RHS, getPrecedence()); + } + + void printLeft(OutputBuffer &OB) const override { + bool ParenAll = OB.isGtInsideTemplateArgs() && + (InfixOperator == ">" || InfixOperator == ">>"); + if (ParenAll) + OB.printOpen(); + // Assignment is right associative, with special LHS precedence. + bool IsAssign = getPrecedence() == Prec::Assign; + LHS->printAsOperand(OB, IsAssign ? Prec::OrIf : getPrecedence(), !IsAssign); + // No space before comma operator + if (!(InfixOperator == ",")) + OB += " "; + OB += InfixOperator; + OB += " "; + RHS->printAsOperand(OB, getPrecedence(), IsAssign); + if (ParenAll) + OB.printClose(); } }; @@ -1625,35 +1739,36 @@ class ArraySubscriptExpr : public Node { const Node *Op2; public: - ArraySubscriptExpr(const Node *Op1_, const Node *Op2_) - : Node(KArraySubscriptExpr), Op1(Op1_), Op2(Op2_) {} + ArraySubscriptExpr(const Node *Op1_, const Node *Op2_, Prec Prec_) + : Node(KArraySubscriptExpr, Prec_), Op1(Op1_), Op2(Op2_) {} - template<typename Fn> void match(Fn F) const { F(Op1, Op2); } + template <typename Fn> void match(Fn F) const { + F(Op1, Op2, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - S += "("; - Op1->print(S); - S += ")["; - Op2->print(S); - S += "]"; + void printLeft(OutputBuffer &OB) const override { + Op1->printAsOperand(OB, getPrecedence()); + OB.printOpen('['); + Op2->printAsOperand(OB); + OB.printClose(']'); } }; class PostfixExpr : public Node { const Node *Child; - const StringView Operator; + const std::string_view Operator; public: - PostfixExpr(const Node *Child_, StringView Operator_) - : Node(KPostfixExpr), Child(Child_), Operator(Operator_) {} + PostfixExpr(const Node *Child_, std::string_view Operator_, Prec Prec_) + : Node(KPostfixExpr, Prec_), Child(Child_), Operator(Operator_) {} - template<typename Fn> void match(Fn F) const { F(Child, Operator); } + template <typename Fn> void match(Fn F) const { + F(Child, Operator, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - S += "("; - Child->print(S); - S += ")"; - S += Operator; + void printLeft(OutputBuffer &OB) const override { + Child->printAsOperand(OB, getPrecedence(), true); + OB += Operator; } }; @@ -1663,78 +1778,128 @@ class ConditionalExpr : public Node { const Node *Else; public: - ConditionalExpr(const Node *Cond_, const Node *Then_, const Node *Else_) - : Node(KConditionalExpr), Cond(Cond_), Then(Then_), Else(Else_) {} + ConditionalExpr(const Node *Cond_, const Node *Then_, const Node *Else_, + Prec Prec_) + : Node(KConditionalExpr, Prec_), Cond(Cond_), Then(Then_), Else(Else_) {} - template<typename Fn> void match(Fn F) const { F(Cond, Then, Else); } + template <typename Fn> void match(Fn F) const { + F(Cond, Then, Else, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - S += "("; - Cond->print(S); - S += ") ? ("; - Then->print(S); - S += ") : ("; - Else->print(S); - S += ")"; + void printLeft(OutputBuffer &OB) const override { + Cond->printAsOperand(OB, getPrecedence()); + OB += " ? "; + Then->printAsOperand(OB); + OB += " : "; + Else->printAsOperand(OB, Prec::Assign, true); } }; class MemberExpr : public Node { const Node *LHS; - const StringView Kind; + const std::string_view Kind; const Node *RHS; public: - MemberExpr(const Node *LHS_, StringView Kind_, const Node *RHS_) - : Node(KMemberExpr), LHS(LHS_), Kind(Kind_), RHS(RHS_) {} + MemberExpr(const Node *LHS_, std::string_view Kind_, const Node *RHS_, + Prec Prec_) + : Node(KMemberExpr, Prec_), LHS(LHS_), Kind(Kind_), RHS(RHS_) {} + + template <typename Fn> void match(Fn F) const { + F(LHS, Kind, RHS, getPrecedence()); + } + + void printLeft(OutputBuffer &OB) const override { + LHS->printAsOperand(OB, getPrecedence(), true); + OB += Kind; + RHS->printAsOperand(OB, getPrecedence(), false); + } +}; + +class SubobjectExpr : public Node { + const Node *Type; + const Node *SubExpr; + std::string_view Offset; + NodeArray UnionSelectors; + bool OnePastTheEnd; - template<typename Fn> void match(Fn F) const { F(LHS, Kind, RHS); } +public: + SubobjectExpr(const Node *Type_, const Node *SubExpr_, + std::string_view Offset_, NodeArray UnionSelectors_, + bool OnePastTheEnd_) + : Node(KSubobjectExpr), Type(Type_), SubExpr(SubExpr_), Offset(Offset_), + UnionSelectors(UnionSelectors_), OnePastTheEnd(OnePastTheEnd_) {} - void printLeft(OutputStream &S) const override { - LHS->print(S); - S += Kind; - RHS->print(S); + template<typename Fn> void match(Fn F) const { + F(Type, SubExpr, Offset, UnionSelectors, OnePastTheEnd); + } + + void printLeft(OutputBuffer &OB) const override { + SubExpr->print(OB); + OB += ".<"; + Type->print(OB); + OB += " at offset "; + if (Offset.empty()) { + OB += "0"; + } else if (Offset[0] == 'n') { + OB += "-"; + OB += std::string_view(Offset.data() + 1, Offset.size() - 1); + } else { + OB += Offset; + } + OB += ">"; } }; class EnclosingExpr : public Node { - const StringView Prefix; + const std::string_view Prefix; const Node *Infix; - const StringView Postfix; + const std::string_view Postfix; public: - EnclosingExpr(StringView Prefix_, Node *Infix_, StringView Postfix_) - : Node(KEnclosingExpr), Prefix(Prefix_), Infix(Infix_), - Postfix(Postfix_) {} + EnclosingExpr(std::string_view Prefix_, const Node *Infix_, + Prec Prec_ = Prec::Primary) + : Node(KEnclosingExpr, Prec_), Prefix(Prefix_), Infix(Infix_) {} - template<typename Fn> void match(Fn F) const { F(Prefix, Infix, Postfix); } + template <typename Fn> void match(Fn F) const { + F(Prefix, Infix, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - S += Prefix; - Infix->print(S); - S += Postfix; + void printLeft(OutputBuffer &OB) const override { + OB += Prefix; + OB.printOpen(); + Infix->print(OB); + OB.printClose(); + OB += Postfix; } }; class CastExpr : public Node { // cast_kind<to>(from) - const StringView CastKind; + const std::string_view CastKind; const Node *To; const Node *From; public: - CastExpr(StringView CastKind_, const Node *To_, const Node *From_) - : Node(KCastExpr), CastKind(CastKind_), To(To_), From(From_) {} + CastExpr(std::string_view CastKind_, const Node *To_, const Node *From_, + Prec Prec_) + : Node(KCastExpr, Prec_), CastKind(CastKind_), To(To_), From(From_) {} - template<typename Fn> void match(Fn F) const { F(CastKind, To, From); } + template <typename Fn> void match(Fn F) const { + F(CastKind, To, From, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - S += CastKind; - S += "<"; - To->printLeft(S); - S += ">("; - From->printLeft(S); - S += ")"; + void printLeft(OutputBuffer &OB) const override { + OB += CastKind; + { + ScopedOverride<unsigned> LT(OB.GtIsGt, 0); + OB += "<"; + To->printLeft(OB); + OB += ">"; + } + OB.printOpen(); + From->printAsOperand(OB); + OB.printClose(); } }; @@ -1747,11 +1912,12 @@ public: template<typename Fn> void match(Fn F) const { F(Pack); } - void printLeft(OutputStream &S) const override { - S += "sizeof...("; + void printLeft(OutputBuffer &OB) const override { + OB += "sizeof..."; + OB.printOpen(); ParameterPackExpansion PPE(Pack); - PPE.printLeft(S); - S += ")"; + PPE.printLeft(OB); + OB.printClose(); } }; @@ -1760,16 +1926,18 @@ class CallExpr : public Node { NodeArray Args; public: - CallExpr(const Node *Callee_, NodeArray Args_) - : Node(KCallExpr), Callee(Callee_), Args(Args_) {} + CallExpr(const Node *Callee_, NodeArray Args_, Prec Prec_) + : Node(KCallExpr, Prec_), Callee(Callee_), Args(Args_) {} - template<typename Fn> void match(Fn F) const { F(Callee, Args); } + template <typename Fn> void match(Fn F) const { + F(Callee, Args, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - Callee->print(S); - S += "("; - Args.printWithComma(S); - S += ")"; + void printLeft(OutputBuffer &OB) const override { + Callee->print(OB); + OB.printOpen(); + Args.printWithComma(OB); + OB.printClose(); } }; @@ -1782,33 +1950,32 @@ class NewExpr : public Node { bool IsArray; // new[] ? public: NewExpr(NodeArray ExprList_, Node *Type_, NodeArray InitList_, bool IsGlobal_, - bool IsArray_) - : Node(KNewExpr), ExprList(ExprList_), Type(Type_), InitList(InitList_), - IsGlobal(IsGlobal_), IsArray(IsArray_) {} + bool IsArray_, Prec Prec_) + : Node(KNewExpr, Prec_), ExprList(ExprList_), Type(Type_), + InitList(InitList_), IsGlobal(IsGlobal_), IsArray(IsArray_) {} template<typename Fn> void match(Fn F) const { - F(ExprList, Type, InitList, IsGlobal, IsArray); + F(ExprList, Type, InitList, IsGlobal, IsArray, getPrecedence()); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (IsGlobal) - S += "::operator "; - S += "new"; + OB += "::"; + OB += "new"; if (IsArray) - S += "[]"; - S += ' '; + OB += "[]"; if (!ExprList.empty()) { - S += "("; - ExprList.printWithComma(S); - S += ")"; + OB.printOpen(); + ExprList.printWithComma(OB); + OB.printClose(); } - Type->print(S); + OB += " "; + Type->print(OB); if (!InitList.empty()) { - S += "("; - InitList.printWithComma(S); - S += ")"; + OB.printOpen(); + InitList.printWithComma(OB); + OB.printClose(); } - } }; @@ -1818,50 +1985,55 @@ class DeleteExpr : public Node { bool IsArray; public: - DeleteExpr(Node *Op_, bool IsGlobal_, bool IsArray_) - : Node(KDeleteExpr), Op(Op_), IsGlobal(IsGlobal_), IsArray(IsArray_) {} + DeleteExpr(Node *Op_, bool IsGlobal_, bool IsArray_, Prec Prec_) + : Node(KDeleteExpr, Prec_), Op(Op_), IsGlobal(IsGlobal_), + IsArray(IsArray_) {} - template<typename Fn> void match(Fn F) const { F(Op, IsGlobal, IsArray); } + template <typename Fn> void match(Fn F) const { + F(Op, IsGlobal, IsArray, getPrecedence()); + } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (IsGlobal) - S += "::"; - S += "delete"; + OB += "::"; + OB += "delete"; if (IsArray) - S += "[] "; - Op->print(S); + OB += "[]"; + OB += ' '; + Op->print(OB); } }; class PrefixExpr : public Node { - StringView Prefix; + std::string_view Prefix; Node *Child; public: - PrefixExpr(StringView Prefix_, Node *Child_) - : Node(KPrefixExpr), Prefix(Prefix_), Child(Child_) {} + PrefixExpr(std::string_view Prefix_, Node *Child_, Prec Prec_) + : Node(KPrefixExpr, Prec_), Prefix(Prefix_), Child(Child_) {} - template<typename Fn> void match(Fn F) const { F(Prefix, Child); } + template <typename Fn> void match(Fn F) const { + F(Prefix, Child, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - S += Prefix; - S += "("; - Child->print(S); - S += ")"; + void printLeft(OutputBuffer &OB) const override { + OB += Prefix; + Child->printAsOperand(OB, getPrecedence()); } }; class FunctionParam : public Node { - StringView Number; + std::string_view Number; public: - FunctionParam(StringView Number_) : Node(KFunctionParam), Number(Number_) {} + FunctionParam(std::string_view Number_) + : Node(KFunctionParam), Number(Number_) {} template<typename Fn> void match(Fn F) const { F(Number); } - void printLeft(OutputStream &S) const override { - S += "fp"; - S += Number; + void printLeft(OutputBuffer &OB) const override { + OB += "fp"; + OB += Number; } }; @@ -1870,17 +2042,45 @@ class ConversionExpr : public Node { NodeArray Expressions; public: - ConversionExpr(const Node *Type_, NodeArray Expressions_) - : Node(KConversionExpr), Type(Type_), Expressions(Expressions_) {} + ConversionExpr(const Node *Type_, NodeArray Expressions_, Prec Prec_) + : Node(KConversionExpr, Prec_), Type(Type_), Expressions(Expressions_) {} + + template <typename Fn> void match(Fn F) const { + F(Type, Expressions, getPrecedence()); + } + + void printLeft(OutputBuffer &OB) const override { + OB.printOpen(); + Type->print(OB); + OB.printClose(); + OB.printOpen(); + Expressions.printWithComma(OB); + OB.printClose(); + } +}; + +class PointerToMemberConversionExpr : public Node { + const Node *Type; + const Node *SubExpr; + std::string_view Offset; + +public: + PointerToMemberConversionExpr(const Node *Type_, const Node *SubExpr_, + std::string_view Offset_, Prec Prec_) + : Node(KPointerToMemberConversionExpr, Prec_), Type(Type_), + SubExpr(SubExpr_), Offset(Offset_) {} - template<typename Fn> void match(Fn F) const { F(Type, Expressions); } + template <typename Fn> void match(Fn F) const { + F(Type, SubExpr, Offset, getPrecedence()); + } - void printLeft(OutputStream &S) const override { - S += "("; - Type->print(S); - S += ")("; - Expressions.printWithComma(S); - S += ")"; + void printLeft(OutputBuffer &OB) const override { + OB.printOpen(); + Type->print(OB); + OB.printClose(); + OB.printOpen(); + SubExpr->print(OB); + OB.printClose(); } }; @@ -1893,12 +2093,12 @@ public: template<typename Fn> void match(Fn F) const { F(Ty, Inits); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (Ty) - Ty->print(S); - S += '{'; - Inits.printWithComma(S); - S += '}'; + Ty->print(OB); + OB += '{'; + Inits.printWithComma(OB); + OB += '}'; } }; @@ -1912,18 +2112,18 @@ public: template<typename Fn> void match(Fn F) const { F(Elem, Init, IsArray); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (IsArray) { - S += '['; - Elem->print(S); - S += ']'; + OB += '['; + Elem->print(OB); + OB += ']'; } else { - S += '.'; - Elem->print(S); + OB += '.'; + Elem->print(OB); } if (Init->getKind() != KBracedExpr && Init->getKind() != KBracedRangeExpr) - S += " = "; - Init->print(S); + OB += " = "; + Init->print(OB); } }; @@ -1937,25 +2137,25 @@ public: template<typename Fn> void match(Fn F) const { F(First, Last, Init); } - void printLeft(OutputStream &S) const override { - S += '['; - First->print(S); - S += " ... "; - Last->print(S); - S += ']'; + void printLeft(OutputBuffer &OB) const override { + OB += '['; + First->print(OB); + OB += " ... "; + Last->print(OB); + OB += ']'; if (Init->getKind() != KBracedExpr && Init->getKind() != KBracedRangeExpr) - S += " = "; - Init->print(S); + OB += " = "; + Init->print(OB); } }; class FoldExpr : public Node { const Node *Pack, *Init; - StringView OperatorName; + std::string_view OperatorName; bool IsLeftFold; public: - FoldExpr(bool IsLeftFold_, StringView OperatorName_, const Node *Pack_, + FoldExpr(bool IsLeftFold_, std::string_view OperatorName_, const Node *Pack_, const Node *Init_) : Node(KFoldExpr), Pack(Pack_), Init(Init_), OperatorName(OperatorName_), IsLeftFold(IsLeftFold_) {} @@ -1964,43 +2164,35 @@ public: F(IsLeftFold, OperatorName, Pack, Init); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { auto PrintPack = [&] { - S += '('; - ParameterPackExpansion(Pack).print(S); - S += ')'; + OB.printOpen(); + ParameterPackExpansion(Pack).print(OB); + OB.printClose(); }; - S += '('; - - if (IsLeftFold) { - // init op ... op pack - if (Init != nullptr) { - Init->print(S); - S += ' '; - S += OperatorName; - S += ' '; - } - // ... op pack - S += "... "; - S += OperatorName; - S += ' '; - PrintPack(); - } else { // !IsLeftFold - // pack op ... - PrintPack(); - S += ' '; - S += OperatorName; - S += " ..."; - // pack op ... op init - if (Init != nullptr) { - S += ' '; - S += OperatorName; - S += ' '; - Init->print(S); - } + OB.printOpen(); + // Either '[init op ]... op pack' or 'pack op ...[ op init]' + // Refactored to '[(init|pack) op ]...[ op (pack|init)]' + // Fold expr operands are cast-expressions + if (!IsLeftFold || Init != nullptr) { + // '(init|pack) op ' + if (IsLeftFold) + Init->printAsOperand(OB, Prec::Cast, true); + else + PrintPack(); + OB << " " << OperatorName << " "; + } + OB << "..."; + if (IsLeftFold || Init != nullptr) { + // ' op (init|pack)' + OB << " " << OperatorName << " "; + if (IsLeftFold) + PrintPack(); + else + Init->printAsOperand(OB, Prec::Cast, true); } - S += ')'; + OB.printClose(); } }; @@ -2012,24 +2204,9 @@ public: template<typename Fn> void match(Fn F) const { F(Op); } - void printLeft(OutputStream &S) const override { - S += "throw "; - Op->print(S); - } -}; - -// MSVC __uuidof extension, generated by clang in -fms-extensions mode. -class UUIDOfExpr : public Node { - Node *Operand; -public: - UUIDOfExpr(Node *Operand_) : Node(KUUIDOfExpr), Operand(Operand_) {} - - template<typename Fn> void match(Fn F) const { F(Operand); } - - void printLeft(OutputStream &S) const override { - S << "__uuidof("; - Operand->print(S); - S << ")"; + void printLeft(OutputBuffer &OB) const override { + OB += "throw "; + Op->print(OB); } }; @@ -2041,8 +2218,8 @@ public: template<typename Fn> void match(Fn F) const { F(Value); } - void printLeft(OutputStream &S) const override { - S += Value ? StringView("true") : StringView("false"); + void printLeft(OutputBuffer &OB) const override { + OB += Value ? std::string_view("true") : std::string_view("false"); } }; @@ -2054,10 +2231,10 @@ public: template<typename Fn> void match(Fn F) const { F(Type); } - void printLeft(OutputStream &S) const override { - S += "\"<"; - Type->print(S); - S += ">\""; + void printLeft(OutputBuffer &OB) const override { + OB += "\"<"; + Type->print(OB); + OB += ">\""; } }; @@ -2069,58 +2246,61 @@ public: template<typename Fn> void match(Fn F) const { F(Type); } - void printLeft(OutputStream &S) const override { - S += "[]"; + void printLeft(OutputBuffer &OB) const override { + OB += "[]"; if (Type->getKind() == KClosureTypeName) - static_cast<const ClosureTypeName *>(Type)->printDeclarator(S); - S += "{...}"; + static_cast<const ClosureTypeName *>(Type)->printDeclarator(OB); + OB += "{...}"; } }; -class IntegerCastExpr : public Node { +class EnumLiteral : public Node { // ty(integer) const Node *Ty; - StringView Integer; + std::string_view Integer; public: - IntegerCastExpr(const Node *Ty_, StringView Integer_) - : Node(KIntegerCastExpr), Ty(Ty_), Integer(Integer_) {} + EnumLiteral(const Node *Ty_, std::string_view Integer_) + : Node(KEnumLiteral), Ty(Ty_), Integer(Integer_) {} template<typename Fn> void match(Fn F) const { F(Ty, Integer); } - void printLeft(OutputStream &S) const override { - S += "("; - Ty->print(S); - S += ")"; - S += Integer; + void printLeft(OutputBuffer &OB) const override { + OB.printOpen(); + Ty->print(OB); + OB.printClose(); + + if (Integer[0] == 'n') + OB << '-' << std::string_view(Integer.data() + 1, Integer.size() - 1); + else + OB << Integer; } }; class IntegerLiteral : public Node { - StringView Type; - StringView Value; + std::string_view Type; + std::string_view Value; public: - IntegerLiteral(StringView Type_, StringView Value_) + IntegerLiteral(std::string_view Type_, std::string_view Value_) : Node(KIntegerLiteral), Type(Type_), Value(Value_) {} template<typename Fn> void match(Fn F) const { F(Type, Value); } - void printLeft(OutputStream &S) const override { + void printLeft(OutputBuffer &OB) const override { if (Type.size() > 3) { - S += "("; - S += Type; - S += ")"; + OB.printOpen(); + OB += Type; + OB.printClose(); } - if (Value[0] == 'n') { - S += "-"; - S += Value.dropFront(1); - } else - S += Value; + if (Value[0] == 'n') + OB << '-' << std::string_view(Value.data() + 1, Value.size() - 1); + else + OB += Value; if (Type.size() <= 3) - S += Type; + OB += Type; } }; @@ -2139,29 +2319,26 @@ constexpr Node::Kind getFloatLiteralKind(long double *) { } template <class Float> class FloatLiteralImpl : public Node { - const StringView Contents; + const std::string_view Contents; static constexpr Kind KindForClass = float_literal_impl::getFloatLiteralKind((Float *)nullptr); public: - FloatLiteralImpl(StringView Contents_) + FloatLiteralImpl(std::string_view Contents_) : Node(KindForClass), Contents(Contents_) {} template<typename Fn> void match(Fn F) const { F(Contents); } - void printLeft(OutputStream &s) const override { - const char *first = Contents.begin(); - const char *last = Contents.end() + 1; - + void printLeft(OutputBuffer &OB) const override { const size_t N = FloatData<Float>::mangled_size; - if (static_cast<std::size_t>(last - first) > N) { - last = first + N; + if (Contents.size() >= N) { union { Float value; char buf[sizeof(Float)]; }; - const char *t = first; + const char *t = Contents.data(); + const char *last = t + N; char *e = buf; for (; t != last; ++t, ++e) { unsigned d1 = isdigit(*t) ? static_cast<unsigned>(*t - '0') @@ -2176,7 +2353,7 @@ public: #endif char num[FloatData<Float>::max_demangled_size] = {0}; int n = snprintf(num, sizeof(num), FloatData<Float>::spec, value); - s += StringView(num, num + n); + OB += std::string_view(num, n); } } }; @@ -2190,143 +2367,22 @@ using LongDoubleLiteral = FloatLiteralImpl<long double>; template<typename Fn> void Node::visit(Fn F) const { switch (K) { -#define CASE(X) case K ## X: return F(static_cast<const X*>(this)); - FOR_EACH_NODE_KIND(CASE) -#undef CASE +#define NODE(X) \ + case K##X: \ + return F(static_cast<const X *>(this)); +#include "ItaniumNodes.def" } assert(0 && "unknown mangling node kind"); } /// Determine the kind of a node from its type. template<typename NodeT> struct NodeKind; -#define SPECIALIZATION(X) \ - template<> struct NodeKind<X> { \ - static constexpr Node::Kind Kind = Node::K##X; \ - static constexpr const char *name() { return #X; } \ +#define NODE(X) \ + template <> struct NodeKind<X> { \ + static constexpr Node::Kind Kind = Node::K##X; \ + static constexpr const char *name() { return #X; } \ }; -FOR_EACH_NODE_KIND(SPECIALIZATION) -#undef SPECIALIZATION - -#undef FOR_EACH_NODE_KIND - -template <class T, size_t N> -class PODSmallVector { - static_assert(std::is_pod<T>::value, - "T is required to be a plain old data type"); - - T* First; - T* Last; - T* Cap; - T Inline[N]; - - bool isInline() const { return First == Inline; } - - void clearInline() { - First = Inline; - Last = Inline; - Cap = Inline + N; - } - - void reserve(size_t NewCap) { - size_t S = size(); - if (isInline()) { - auto* Tmp = static_cast<T*>(std::malloc(NewCap * sizeof(T))); - if (Tmp == nullptr) - std::terminate(); - std::copy(First, Last, Tmp); - First = Tmp; - } else { - First = static_cast<T*>(std::realloc(First, NewCap * sizeof(T))); - if (First == nullptr) - std::terminate(); - } - Last = First + S; - Cap = First + NewCap; - } - -public: - PODSmallVector() : First(Inline), Last(First), Cap(Inline + N) {} - - PODSmallVector(const PODSmallVector&) = delete; - PODSmallVector& operator=(const PODSmallVector&) = delete; - - PODSmallVector(PODSmallVector&& Other) : PODSmallVector() { - if (Other.isInline()) { - std::copy(Other.begin(), Other.end(), First); - Last = First + Other.size(); - Other.clear(); - return; - } - - First = Other.First; - Last = Other.Last; - Cap = Other.Cap; - Other.clearInline(); - } - - PODSmallVector& operator=(PODSmallVector&& Other) { - if (Other.isInline()) { - if (!isInline()) { - std::free(First); - clearInline(); - } - std::copy(Other.begin(), Other.end(), First); - Last = First + Other.size(); - Other.clear(); - return *this; - } - - if (isInline()) { - First = Other.First; - Last = Other.Last; - Cap = Other.Cap; - Other.clearInline(); - return *this; - } - - std::swap(First, Other.First); - std::swap(Last, Other.Last); - std::swap(Cap, Other.Cap); - Other.clear(); - return *this; - } - - void push_back(const T& Elem) { - if (Last == Cap) - reserve(size() * 2); - *Last++ = Elem; - } - - void pop_back() { - assert(Last != First && "Popping empty vector!"); - --Last; - } - - void dropBack(size_t Index) { - assert(Index <= size() && "dropBack() can't expand!"); - Last = First + Index; - } - - T* begin() { return First; } - T* end() { return Last; } - - bool empty() const { return First == Last; } - size_t size() const { return static_cast<size_t>(Last - First); } - T& back() { - assert(Last != First && "Calling back() on empty vector!"); - return *(Last - 1); - } - T& operator[](size_t Index) { - assert(Index < size() && "Invalid access!"); - return *(begin() + Index); - } - void clear() { Last = First; } - - ~PODSmallVector() { - if (!isInline()) - std::free(First); - } -}; +#include "ItaniumNodes.def" template <typename Derived, typename Alloc> struct AbstractManglingParser { const char *First; @@ -2350,9 +2406,9 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser { TemplateParamList Params; public: - ScopedTemplateParamList(AbstractManglingParser *Parser) - : Parser(Parser), - OldNumTemplateParamLists(Parser->TemplateParams.size()) { + ScopedTemplateParamList(AbstractManglingParser *TheParser) + : Parser(TheParser), + OldNumTemplateParamLists(TheParser->TemplateParams.size()) { Parser->TemplateParams.push_back(&Params); } ~ScopedTemplateParamList() { @@ -2424,8 +2480,9 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser { return res; } - bool consumeIf(StringView S) { - if (StringView(First, Last).startsWith(S)) { + bool consumeIf(std::string_view S) { + if (llvm::itanium_demangle::starts_with( + std::string_view(First, Last - First), S)) { First += S.size(); return true; } @@ -2442,7 +2499,7 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser { char consume() { return First != Last ? *First++ : '\0'; } - char look(unsigned Lookahead = 0) { + char look(unsigned Lookahead = 0) const { if (static_cast<size_t>(Last - First) <= Lookahead) return '\0'; return First[Lookahead]; @@ -2450,10 +2507,10 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser { size_t numLeft() const { return static_cast<size_t>(Last - First); } - StringView parseNumber(bool AllowNegative = false); + std::string_view parseNumber(bool AllowNegative = false); Qualifiers parseCVQualifiers(); bool parsePositiveInteger(size_t *Out); - StringView parseBareSourceName(); + std::string_view parseBareSourceName(); bool parseSeqId(size_t *Out); Node *parseSubstitution(); @@ -2464,16 +2521,17 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser { /// Parse the <expr> production. Node *parseExpr(); - Node *parsePrefixExpr(StringView Kind); - Node *parseBinaryExpr(StringView Kind); - Node *parseIntegerLiteral(StringView Lit); + Node *parsePrefixExpr(std::string_view Kind, Node::Prec Prec); + Node *parseBinaryExpr(std::string_view Kind, Node::Prec Prec); + Node *parseIntegerLiteral(std::string_view Lit); Node *parseExprPrimary(); template <class Float> Node *parseFloatingLiteral(); Node *parseFunctionParam(); - Node *parseNewExpr(); Node *parseConversionExpr(); Node *parseBracedExpr(); Node *parseFoldExpr(); + Node *parsePointerToMemberConversionExpr(Node::Prec Prec); + Node *parseSubobjectExpr(); /// Parse the <type> production. Node *parseType(); @@ -2520,17 +2578,81 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser { Node *parseName(NameState *State = nullptr); Node *parseLocalName(NameState *State); Node *parseOperatorName(NameState *State); - Node *parseUnqualifiedName(NameState *State); + bool parseModuleNameOpt(ModuleName *&Module); + Node *parseUnqualifiedName(NameState *State, Node *Scope, ModuleName *Module); Node *parseUnnamedTypeName(NameState *State); Node *parseSourceName(NameState *State); - Node *parseUnscopedName(NameState *State); + Node *parseUnscopedName(NameState *State, bool *isSubstName); Node *parseNestedName(NameState *State); Node *parseCtorDtorName(Node *&SoFar, NameState *State); Node *parseAbiTags(Node *N); + struct OperatorInfo { + enum OIKind : unsigned char { + Prefix, // Prefix unary: @ expr + Postfix, // Postfix unary: expr @ + Binary, // Binary: lhs @ rhs + Array, // Array index: lhs [ rhs ] + Member, // Member access: lhs @ rhs + New, // New + Del, // Delete + Call, // Function call: expr (expr*) + CCast, // C cast: (type)expr + Conditional, // Conditional: expr ? expr : expr + NameOnly, // Overload only, not allowed in expression. + // Below do not have operator names + NamedCast, // Named cast, @<type>(expr) + OfIdOp, // alignof, sizeof, typeid + + Unnameable = NamedCast, + }; + char Enc[2]; // Encoding + OIKind Kind; // Kind of operator + bool Flag : 1; // Entry-specific flag + Node::Prec Prec : 7; // Precedence + const char *Name; // Spelling + + public: + constexpr OperatorInfo(const char (&E)[3], OIKind K, bool F, Node::Prec P, + const char *N) + : Enc{E[0], E[1]}, Kind{K}, Flag{F}, Prec{P}, Name{N} {} + + public: + bool operator<(const OperatorInfo &Other) const { + return *this < Other.Enc; + } + bool operator<(const char *Peek) const { + return Enc[0] < Peek[0] || (Enc[0] == Peek[0] && Enc[1] < Peek[1]); + } + bool operator==(const char *Peek) const { + return Enc[0] == Peek[0] && Enc[1] == Peek[1]; + } + bool operator!=(const char *Peek) const { return !this->operator==(Peek); } + + public: + std::string_view getSymbol() const { + std::string_view Res = Name; + if (Kind < Unnameable) { + assert(llvm::itanium_demangle::starts_with(Res, "operator") && + "operator name does not start with 'operator'"); + Res.remove_prefix(sizeof("operator") - 1); + if (llvm::itanium_demangle::starts_with(Res, ' ')) + Res.remove_prefix(1); + } + return Res; + } + std::string_view getName() const { return Name; } + OIKind getKind() const { return Kind; } + bool getFlag() const { return Flag; } + Node::Prec getPrecedence() const { return Prec; } + }; + static const OperatorInfo Ops[]; + static const size_t NumOps; + const OperatorInfo *parseOperatorEncoding(); + /// Parse the <unresolved-name> production. - Node *parseUnresolvedName(); + Node *parseUnresolvedName(bool Global); Node *parseSimpleId(); Node *parseBaseUnresolvedName(); Node *parseUnresolvedType(); @@ -2551,41 +2673,35 @@ const char* parse_discriminator(const char* first, const char* last); // ::= <substitution> template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseName(NameState *State) { - consumeIf('L'); // extension - if (look() == 'N') return getDerived().parseNestedName(State); if (look() == 'Z') return getDerived().parseLocalName(State); - // ::= <unscoped-template-name> <template-args> - if (look() == 'S' && look(1) != 't') { - Node *S = getDerived().parseSubstitution(); - if (S == nullptr) - return nullptr; - if (look() != 'I') - return nullptr; - Node *TA = getDerived().parseTemplateArgs(State != nullptr); - if (TA == nullptr) - return nullptr; - if (State) State->EndsWithTemplateArgs = true; - return make<NameWithTemplateArgs>(S, TA); - } + Node *Result = nullptr; + bool IsSubst = false; - Node *N = getDerived().parseUnscopedName(State); - if (N == nullptr) + Result = getDerived().parseUnscopedName(State, &IsSubst); + if (!Result) return nullptr; - // ::= <unscoped-template-name> <template-args> + if (look() == 'I') { - Subs.push_back(N); + // ::= <unscoped-template-name> <template-args> + if (!IsSubst) + // An unscoped-template-name is substitutable. + Subs.push_back(Result); Node *TA = getDerived().parseTemplateArgs(State != nullptr); if (TA == nullptr) return nullptr; - if (State) State->EndsWithTemplateArgs = true; - return make<NameWithTemplateArgs>(N, TA); + if (State) + State->EndsWithTemplateArgs = true; + Result = make<NameWithTemplateArgs>(Result, TA); + } else if (IsSubst) { + // The substitution case must be followed by <template-args>. + return nullptr; } - // ::= <unscoped-name> - return N; + + return Result; } // <local-name> := Z <function encoding> E <entity name> [<discriminator>] @@ -2626,34 +2742,63 @@ Node *AbstractManglingParser<Derived, Alloc>::parseLocalName(NameState *State) { // <unscoped-name> ::= <unqualified-name> // ::= St <unqualified-name> # ::std:: -// extension ::= StL<unqualified-name> +// [*] extension template <typename Derived, typename Alloc> Node * -AbstractManglingParser<Derived, Alloc>::parseUnscopedName(NameState *State) { - if (consumeIf("StL") || consumeIf("St")) { - Node *R = getDerived().parseUnqualifiedName(State); - if (R == nullptr) +AbstractManglingParser<Derived, Alloc>::parseUnscopedName(NameState *State, + bool *IsSubst) { + + Node *Std = nullptr; + if (consumeIf("St")) { + Std = make<NameType>("std"); + if (Std == nullptr) return nullptr; - return make<StdQualifiedName>(R); } - return getDerived().parseUnqualifiedName(State); + + Node *Res = nullptr; + ModuleName *Module = nullptr; + if (look() == 'S') { + Node *S = getDerived().parseSubstitution(); + if (!S) + return nullptr; + if (S->getKind() == Node::KModuleName) + Module = static_cast<ModuleName *>(S); + else if (IsSubst && Std == nullptr) { + Res = S; + *IsSubst = true; + } else { + return nullptr; + } + } + + if (Res == nullptr || Std != nullptr) { + Res = getDerived().parseUnqualifiedName(State, Std, Module); + } + + return Res; } -// <unqualified-name> ::= <operator-name> [abi-tags] -// ::= <ctor-dtor-name> -// ::= <source-name> -// ::= <unnamed-type-name> -// ::= DC <source-name>+ E # structured binding declaration +// <unqualified-name> ::= [<module-name>] L? <operator-name> [<abi-tags>] +// ::= [<module-name>] <ctor-dtor-name> [<abi-tags>] +// ::= [<module-name>] L? <source-name> [<abi-tags>] +// ::= [<module-name>] L? <unnamed-type-name> [<abi-tags>] +// # structured binding declaration +// ::= [<module-name>] L? DC <source-name>+ E template <typename Derived, typename Alloc> -Node * -AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName(NameState *State) { - // <ctor-dtor-name>s are special-cased in parseNestedName(). +Node *AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName( + NameState *State, Node *Scope, ModuleName *Module) { + if (getDerived().parseModuleNameOpt(Module)) + return nullptr; + + consumeIf('L'); + Node *Result; - if (look() == 'U') - Result = getDerived().parseUnnamedTypeName(State); - else if (look() >= '1' && look() <= '9') + if (look() >= '1' && look() <= '9') { Result = getDerived().parseSourceName(State); - else if (consumeIf("DC")) { + } else if (look() == 'U') { + Result = getDerived().parseUnnamedTypeName(State); + } else if (consumeIf("DC")) { + // Structured binding size_t BindingsBegin = Names.size(); do { Node *Binding = getDerived().parseSourceName(State); @@ -2662,13 +2807,46 @@ AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName(NameState *State) { Names.push_back(Binding); } while (!consumeIf('E')); Result = make<StructuredBindingName>(popTrailingNodeArray(BindingsBegin)); - } else + } else if (look() == 'C' || look() == 'D') { + // A <ctor-dtor-name>. + if (Scope == nullptr || Module != nullptr) + return nullptr; + Result = getDerived().parseCtorDtorName(Scope, State); + } else { Result = getDerived().parseOperatorName(State); + } + + if (Result != nullptr && Module != nullptr) + Result = make<ModuleEntity>(Module, Result); if (Result != nullptr) Result = getDerived().parseAbiTags(Result); + if (Result != nullptr && Scope != nullptr) + Result = make<NestedName>(Scope, Result); + return Result; } +// <module-name> ::= <module-subname> +// ::= <module-name> <module-subname> +// ::= <substitution> # passed in by caller +// <module-subname> ::= W <source-name> +// ::= W P <source-name> +template <typename Derived, typename Alloc> +bool AbstractManglingParser<Derived, Alloc>::parseModuleNameOpt( + ModuleName *&Module) { + while (consumeIf('W')) { + bool IsPartition = consumeIf('P'); + Node *Sub = getDerived().parseSourceName(nullptr); + if (!Sub) + return true; + Module = + static_cast<ModuleName *>(make<ModuleName>(Module, Sub, IsPartition)); + Subs.push_back(Module); + } + + return false; +} + // <unnamed-type-name> ::= Ut [<nonnegative number>] _ // ::= <closure-type-name> // @@ -2684,19 +2862,19 @@ AbstractManglingParser<Derived, Alloc>::parseUnnamedTypeName(NameState *State) { TemplateParams.clear(); if (consumeIf("Ut")) { - StringView Count = parseNumber(); + std::string_view Count = parseNumber(); if (!consumeIf('_')) return nullptr; return make<UnnamedTypeName>(Count); } if (consumeIf("Ul")) { - SwapAndRestore<size_t> SwapParams(ParsingLambdaParamsAtLevel, + ScopedOverride<size_t> SwapParams(ParsingLambdaParamsAtLevel, TemplateParams.size()); ScopedTemplateParamList LambdaTemplateParams(this); size_t ParamsBegin = Names.size(); while (look() == 'T' && - StringView("yptn").find(look(1)) != StringView::npos) { + std::string_view("yptn").find(look(1)) != std::string_view::npos) { Node *T = parseTemplateParamDecl(); if (!T) return nullptr; @@ -2739,7 +2917,7 @@ AbstractManglingParser<Derived, Alloc>::parseUnnamedTypeName(NameState *State) { } NodeArray Params = popTrailingNodeArray(ParamsBegin); - StringView Count = parseNumber(); + std::string_view Count = parseNumber(); if (!consumeIf('_')) return nullptr; return make<ClosureTypeName>(TempParams, Params, Count); @@ -2761,104 +2939,138 @@ Node *AbstractManglingParser<Derived, Alloc>::parseSourceName(NameState *) { return nullptr; if (numLeft() < Length || Length == 0) return nullptr; - StringView Name(First, First + Length); + std::string_view Name(First, Length); First += Length; - if (Name.startsWith("_GLOBAL__N")) + if (llvm::itanium_demangle::starts_with(Name, "_GLOBAL__N")) return make<NameType>("(anonymous namespace)"); return make<NameType>(Name); } -// <operator-name> ::= aa # && -// ::= ad # & (unary) -// ::= an # & -// ::= aN # &= -// ::= aS # = -// ::= cl # () -// ::= cm # , -// ::= co # ~ -// ::= cv <type> # (cast) -// ::= da # delete[] -// ::= de # * (unary) -// ::= dl # delete -// ::= dv # / -// ::= dV # /= -// ::= eo # ^ -// ::= eO # ^= -// ::= eq # == -// ::= ge # >= -// ::= gt # > -// ::= ix # [] -// ::= le # <= +// Operator encodings +template <typename Derived, typename Alloc> +const typename AbstractManglingParser< + Derived, Alloc>::OperatorInfo AbstractManglingParser<Derived, + Alloc>::Ops[] = { + // Keep ordered by encoding + {"aN", OperatorInfo::Binary, false, Node::Prec::Assign, "operator&="}, + {"aS", OperatorInfo::Binary, false, Node::Prec::Assign, "operator="}, + {"aa", OperatorInfo::Binary, false, Node::Prec::AndIf, "operator&&"}, + {"ad", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator&"}, + {"an", OperatorInfo::Binary, false, Node::Prec::And, "operator&"}, + {"at", OperatorInfo::OfIdOp, /*Type*/ true, Node::Prec::Unary, "alignof "}, + {"aw", OperatorInfo::NameOnly, false, Node::Prec::Primary, + "operator co_await"}, + {"az", OperatorInfo::OfIdOp, /*Type*/ false, Node::Prec::Unary, "alignof "}, + {"cc", OperatorInfo::NamedCast, false, Node::Prec::Postfix, "const_cast"}, + {"cl", OperatorInfo::Call, false, Node::Prec::Postfix, "operator()"}, + {"cm", OperatorInfo::Binary, false, Node::Prec::Comma, "operator,"}, + {"co", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator~"}, + {"cv", OperatorInfo::CCast, false, Node::Prec::Cast, "operator"}, // C Cast + {"dV", OperatorInfo::Binary, false, Node::Prec::Assign, "operator/="}, + {"da", OperatorInfo::Del, /*Ary*/ true, Node::Prec::Unary, + "operator delete[]"}, + {"dc", OperatorInfo::NamedCast, false, Node::Prec::Postfix, "dynamic_cast"}, + {"de", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator*"}, + {"dl", OperatorInfo::Del, /*Ary*/ false, Node::Prec::Unary, + "operator delete"}, + {"ds", OperatorInfo::Member, /*Named*/ false, Node::Prec::PtrMem, + "operator.*"}, + {"dt", OperatorInfo::Member, /*Named*/ false, Node::Prec::Postfix, + "operator."}, + {"dv", OperatorInfo::Binary, false, Node::Prec::Assign, "operator/"}, + {"eO", OperatorInfo::Binary, false, Node::Prec::Assign, "operator^="}, + {"eo", OperatorInfo::Binary, false, Node::Prec::Xor, "operator^"}, + {"eq", OperatorInfo::Binary, false, Node::Prec::Equality, "operator=="}, + {"ge", OperatorInfo::Binary, false, Node::Prec::Relational, "operator>="}, + {"gt", OperatorInfo::Binary, false, Node::Prec::Relational, "operator>"}, + {"ix", OperatorInfo::Array, false, Node::Prec::Postfix, "operator[]"}, + {"lS", OperatorInfo::Binary, false, Node::Prec::Assign, "operator<<="}, + {"le", OperatorInfo::Binary, false, Node::Prec::Relational, "operator<="}, + {"ls", OperatorInfo::Binary, false, Node::Prec::Shift, "operator<<"}, + {"lt", OperatorInfo::Binary, false, Node::Prec::Relational, "operator<"}, + {"mI", OperatorInfo::Binary, false, Node::Prec::Assign, "operator-="}, + {"mL", OperatorInfo::Binary, false, Node::Prec::Assign, "operator*="}, + {"mi", OperatorInfo::Binary, false, Node::Prec::Additive, "operator-"}, + {"ml", OperatorInfo::Binary, false, Node::Prec::Multiplicative, + "operator*"}, + {"mm", OperatorInfo::Postfix, false, Node::Prec::Postfix, "operator--"}, + {"na", OperatorInfo::New, /*Ary*/ true, Node::Prec::Unary, + "operator new[]"}, + {"ne", OperatorInfo::Binary, false, Node::Prec::Equality, "operator!="}, + {"ng", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator-"}, + {"nt", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator!"}, + {"nw", OperatorInfo::New, /*Ary*/ false, Node::Prec::Unary, "operator new"}, + {"oR", OperatorInfo::Binary, false, Node::Prec::Assign, "operator|="}, + {"oo", OperatorInfo::Binary, false, Node::Prec::OrIf, "operator||"}, + {"or", OperatorInfo::Binary, false, Node::Prec::Ior, "operator|"}, + {"pL", OperatorInfo::Binary, false, Node::Prec::Assign, "operator+="}, + {"pl", OperatorInfo::Binary, false, Node::Prec::Additive, "operator+"}, + {"pm", OperatorInfo::Member, /*Named*/ false, Node::Prec::PtrMem, + "operator->*"}, + {"pp", OperatorInfo::Postfix, false, Node::Prec::Postfix, "operator++"}, + {"ps", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator+"}, + {"pt", OperatorInfo::Member, /*Named*/ true, Node::Prec::Postfix, + "operator->"}, + {"qu", OperatorInfo::Conditional, false, Node::Prec::Conditional, + "operator?"}, + {"rM", OperatorInfo::Binary, false, Node::Prec::Assign, "operator%="}, + {"rS", OperatorInfo::Binary, false, Node::Prec::Assign, "operator>>="}, + {"rc", OperatorInfo::NamedCast, false, Node::Prec::Postfix, + "reinterpret_cast"}, + {"rm", OperatorInfo::Binary, false, Node::Prec::Multiplicative, + "operator%"}, + {"rs", OperatorInfo::Binary, false, Node::Prec::Shift, "operator>>"}, + {"sc", OperatorInfo::NamedCast, false, Node::Prec::Postfix, "static_cast"}, + {"ss", OperatorInfo::Binary, false, Node::Prec::Spaceship, "operator<=>"}, + {"st", OperatorInfo::OfIdOp, /*Type*/ true, Node::Prec::Unary, "sizeof "}, + {"sz", OperatorInfo::OfIdOp, /*Type*/ false, Node::Prec::Unary, "sizeof "}, + {"te", OperatorInfo::OfIdOp, /*Type*/ false, Node::Prec::Postfix, + "typeid "}, + {"ti", OperatorInfo::OfIdOp, /*Type*/ true, Node::Prec::Postfix, "typeid "}, +}; +template <typename Derived, typename Alloc> +const size_t AbstractManglingParser<Derived, Alloc>::NumOps = sizeof(Ops) / + sizeof(Ops[0]); + +// If the next 2 chars are an operator encoding, consume them and return their +// OperatorInfo. Otherwise return nullptr. +template <typename Derived, typename Alloc> +const typename AbstractManglingParser<Derived, Alloc>::OperatorInfo * +AbstractManglingParser<Derived, Alloc>::parseOperatorEncoding() { + if (numLeft() < 2) + return nullptr; + + // We can't use lower_bound as that can link to symbols in the C++ library, + // and this must remain independant of that. + size_t lower = 0u, upper = NumOps - 1; // Inclusive bounds. + while (upper != lower) { + size_t middle = (upper + lower) / 2; + if (Ops[middle] < First) + lower = middle + 1; + else + upper = middle; + } + if (Ops[lower] != First) + return nullptr; + + First += 2; + return &Ops[lower]; +} + +// <operator-name> ::= See parseOperatorEncoding() // ::= li <source-name> # operator "" -// ::= ls # << -// ::= lS # <<= -// ::= lt # < -// ::= mi # - -// ::= mI # -= -// ::= ml # * -// ::= mL # *= -// ::= mm # -- (postfix in <expression> context) -// ::= na # new[] -// ::= ne # != -// ::= ng # - (unary) -// ::= nt # ! -// ::= nw # new -// ::= oo # || -// ::= or # | -// ::= oR # |= -// ::= pm # ->* -// ::= pl # + -// ::= pL # += -// ::= pp # ++ (postfix in <expression> context) -// ::= ps # + (unary) -// ::= pt # -> -// ::= qu # ? -// ::= rm # % -// ::= rM # %= -// ::= rs # >> -// ::= rS # >>= -// ::= ss # <=> C++2a -// ::= v <digit> <source-name> # vendor extended operator +// ::= v <digit> <source-name> # vendor extended operator template <typename Derived, typename Alloc> Node * AbstractManglingParser<Derived, Alloc>::parseOperatorName(NameState *State) { - switch (look()) { - case 'a': - switch (look(1)) { - case 'a': - First += 2; - return make<NameType>("operator&&"); - case 'd': - case 'n': - First += 2; - return make<NameType>("operator&"); - case 'N': - First += 2; - return make<NameType>("operator&="); - case 'S': - First += 2; - return make<NameType>("operator="); - } - return nullptr; - case 'c': - switch (look(1)) { - case 'l': - First += 2; - return make<NameType>("operator()"); - case 'm': - First += 2; - return make<NameType>("operator,"); - case 'o': - First += 2; - return make<NameType>("operator~"); - // ::= cv <type> # (cast) - case 'v': { - First += 2; - SwapAndRestore<bool> SaveTemplate(TryToParseTemplateArgs, false); + if (const auto *Op = parseOperatorEncoding()) { + if (Op->getKind() == OperatorInfo::CCast) { + // ::= cv <type> # (cast) + ScopedOverride<bool> SaveTemplate(TryToParseTemplateArgs, false); // If we're parsing an encoding, State != nullptr and the conversion // operators' <type> could have a <template-param> that refers to some // <template-arg>s further ahead in the mangled name. - SwapAndRestore<bool> SavePermit(PermitForwardTemplateReferences, + ScopedOverride<bool> SavePermit(PermitForwardTemplateReferences, PermitForwardTemplateReferences || State != nullptr); Node *Ty = getDerived().parseType(); @@ -2867,185 +3079,29 @@ AbstractManglingParser<Derived, Alloc>::parseOperatorName(NameState *State) { if (State) State->CtorDtorConversion = true; return make<ConversionOperatorType>(Ty); } - } - return nullptr; - case 'd': - switch (look(1)) { - case 'a': - First += 2; - return make<NameType>("operator delete[]"); - case 'e': - First += 2; - return make<NameType>("operator*"); - case 'l': - First += 2; - return make<NameType>("operator delete"); - case 'v': - First += 2; - return make<NameType>("operator/"); - case 'V': - First += 2; - return make<NameType>("operator/="); - } - return nullptr; - case 'e': - switch (look(1)) { - case 'o': - First += 2; - return make<NameType>("operator^"); - case 'O': - First += 2; - return make<NameType>("operator^="); - case 'q': - First += 2; - return make<NameType>("operator=="); - } - return nullptr; - case 'g': - switch (look(1)) { - case 'e': - First += 2; - return make<NameType>("operator>="); - case 't': - First += 2; - return make<NameType>("operator>"); - } - return nullptr; - case 'i': - if (look(1) == 'x') { - First += 2; - return make<NameType>("operator[]"); - } - return nullptr; - case 'l': - switch (look(1)) { - case 'e': - First += 2; - return make<NameType>("operator<="); + + if (Op->getKind() >= OperatorInfo::Unnameable) + /* Not a nameable operator. */ + return nullptr; + if (Op->getKind() == OperatorInfo::Member && !Op->getFlag()) + /* Not a nameable MemberExpr */ + return nullptr; + + return make<NameType>(Op->getName()); + } + + if (consumeIf("li")) { // ::= li <source-name> # operator "" - case 'i': { - First += 2; - Node *SN = getDerived().parseSourceName(State); - if (SN == nullptr) - return nullptr; - return make<LiteralOperator>(SN); - } - case 's': - First += 2; - return make<NameType>("operator<<"); - case 'S': - First += 2; - return make<NameType>("operator<<="); - case 't': - First += 2; - return make<NameType>("operator<"); - } - return nullptr; - case 'm': - switch (look(1)) { - case 'i': - First += 2; - return make<NameType>("operator-"); - case 'I': - First += 2; - return make<NameType>("operator-="); - case 'l': - First += 2; - return make<NameType>("operator*"); - case 'L': - First += 2; - return make<NameType>("operator*="); - case 'm': - First += 2; - return make<NameType>("operator--"); - } - return nullptr; - case 'n': - switch (look(1)) { - case 'a': - First += 2; - return make<NameType>("operator new[]"); - case 'e': - First += 2; - return make<NameType>("operator!="); - case 'g': - First += 2; - return make<NameType>("operator-"); - case 't': - First += 2; - return make<NameType>("operator!"); - case 'w': - First += 2; - return make<NameType>("operator new"); - } - return nullptr; - case 'o': - switch (look(1)) { - case 'o': - First += 2; - return make<NameType>("operator||"); - case 'r': - First += 2; - return make<NameType>("operator|"); - case 'R': - First += 2; - return make<NameType>("operator|="); - } - return nullptr; - case 'p': - switch (look(1)) { - case 'm': - First += 2; - return make<NameType>("operator->*"); - case 'l': - First += 2; - return make<NameType>("operator+"); - case 'L': - First += 2; - return make<NameType>("operator+="); - case 'p': - First += 2; - return make<NameType>("operator++"); - case 's': - First += 2; - return make<NameType>("operator+"); - case 't': - First += 2; - return make<NameType>("operator->"); - } - return nullptr; - case 'q': - if (look(1) == 'u') { - First += 2; - return make<NameType>("operator?"); - } - return nullptr; - case 'r': - switch (look(1)) { - case 'm': - First += 2; - return make<NameType>("operator%"); - case 'M': - First += 2; - return make<NameType>("operator%="); - case 's': - First += 2; - return make<NameType>("operator>>"); - case 'S': - First += 2; - return make<NameType>("operator>>="); - } - return nullptr; - case 's': - if (look(1) == 's') { - First += 2; - return make<NameType>("operator<=>"); - } - return nullptr; - // ::= v <digit> <source-name> # vendor extended operator - case 'v': - if (std::isdigit(look(1))) { - First += 2; + Node *SN = getDerived().parseSourceName(State); + if (SN == nullptr) + return nullptr; + return make<LiteralOperator>(SN); + } + + if (consumeIf('v')) { + // ::= v <digit> <source-name> # vendor extended operator + if (look() >= '0' && look() <= '9') { + First++; Node *SN = getDerived().parseSourceName(State); if (SN == nullptr) return nullptr; @@ -3053,6 +3109,7 @@ AbstractManglingParser<Derived, Alloc>::parseOperatorName(NameState *State) { } return nullptr; } + return nullptr; } @@ -3071,19 +3128,11 @@ Node * AbstractManglingParser<Derived, Alloc>::parseCtorDtorName(Node *&SoFar, NameState *State) { if (SoFar->getKind() == Node::KSpecialSubstitution) { - auto SSK = static_cast<SpecialSubstitution *>(SoFar)->SSK; - switch (SSK) { - case SpecialSubKind::string: - case SpecialSubKind::istream: - case SpecialSubKind::ostream: - case SpecialSubKind::iostream: - SoFar = make<ExpandedSpecialSubstitution>(SSK); - if (!SoFar) - return nullptr; - break; - default: - break; - } + // Expand the special substitution. + SoFar = make<ExpandedSpecialSubstitution>( + static_cast<SpecialSubstitution *>(SoFar)); + if (!SoFar) + return nullptr; } if (consumeIf('C')) { @@ -3112,8 +3161,10 @@ AbstractManglingParser<Derived, Alloc>::parseCtorDtorName(Node *&SoFar, return nullptr; } -// <nested-name> ::= N [<CV-Qualifiers>] [<ref-qualifier>] <prefix> <unqualified-name> E -// ::= N [<CV-Qualifiers>] [<ref-qualifier>] <template-prefix> <template-args> E +// <nested-name> ::= N [<CV-Qualifiers>] [<ref-qualifier>] <prefix> +// <unqualified-name> E +// ::= N [<CV-Qualifiers>] [<ref-qualifier>] <template-prefix> +// <template-args> E // // <prefix> ::= <prefix> <unqualified-name> // ::= <template-prefix> <template-args> @@ -3122,7 +3173,7 @@ AbstractManglingParser<Derived, Alloc>::parseCtorDtorName(Node *&SoFar, // ::= # empty // ::= <substitution> // ::= <prefix> <data-member-prefix> -// extension ::= L +// [*] extension // // <data-member-prefix> := <member source-name> [<template-args>] M // @@ -3142,90 +3193,76 @@ AbstractManglingParser<Derived, Alloc>::parseNestedName(NameState *State) { if (State) State->ReferenceQualifier = FrefQualRValue; } else if (consumeIf('R')) { if (State) State->ReferenceQualifier = FrefQualLValue; - } else + } else { if (State) State->ReferenceQualifier = FrefQualNone; - - Node *SoFar = nullptr; - auto PushComponent = [&](Node *Comp) { - if (!Comp) return false; - if (SoFar) SoFar = make<NestedName>(SoFar, Comp); - else SoFar = Comp; - if (State) State->EndsWithTemplateArgs = false; - return SoFar != nullptr; - }; - - if (consumeIf("St")) { - SoFar = make<NameType>("std"); - if (!SoFar) - return nullptr; } + Node *SoFar = nullptr; while (!consumeIf('E')) { - consumeIf('L'); // extension + if (State) + // Only set end-with-template on the case that does that. + State->EndsWithTemplateArgs = false; - // <data-member-prefix> := <member source-name> [<template-args>] M - if (consumeIf('M')) { - if (SoFar == nullptr) - return nullptr; - continue; - } - - // ::= <template-param> if (look() == 'T') { - if (!PushComponent(getDerived().parseTemplateParam())) - return nullptr; - Subs.push_back(SoFar); - continue; - } - - // ::= <template-prefix> <template-args> - if (look() == 'I') { + // ::= <template-param> + if (SoFar != nullptr) + return nullptr; // Cannot have a prefix. + SoFar = getDerived().parseTemplateParam(); + } else if (look() == 'I') { + // ::= <template-prefix> <template-args> + if (SoFar == nullptr) + return nullptr; // Must have a prefix. Node *TA = getDerived().parseTemplateArgs(State != nullptr); - if (TA == nullptr || SoFar == nullptr) - return nullptr; - SoFar = make<NameWithTemplateArgs>(SoFar, TA); - if (!SoFar) - return nullptr; - if (State) State->EndsWithTemplateArgs = true; - Subs.push_back(SoFar); - continue; - } - - // ::= <decltype> - if (look() == 'D' && (look(1) == 't' || look(1) == 'T')) { - if (!PushComponent(getDerived().parseDecltype())) + if (TA == nullptr) return nullptr; - Subs.push_back(SoFar); - continue; - } - - // ::= <substitution> - if (look() == 'S' && look(1) != 't') { - Node *S = getDerived().parseSubstitution(); - if (!PushComponent(S)) + if (SoFar->getKind() == Node::KNameWithTemplateArgs) + // Semantically <template-args> <template-args> cannot be generated by a + // C++ entity. There will always be [something like] a name between + // them. return nullptr; - if (SoFar != S) - Subs.push_back(S); - continue; - } + if (State) + State->EndsWithTemplateArgs = true; + SoFar = make<NameWithTemplateArgs>(SoFar, TA); + } else if (look() == 'D' && (look(1) == 't' || look(1) == 'T')) { + // ::= <decltype> + if (SoFar != nullptr) + return nullptr; // Cannot have a prefix. + SoFar = getDerived().parseDecltype(); + } else { + ModuleName *Module = nullptr; + + if (look() == 'S') { + // ::= <substitution> + Node *S = nullptr; + if (look(1) == 't') { + First += 2; + S = make<NameType>("std"); + } else { + S = getDerived().parseSubstitution(); + } + if (!S) + return nullptr; + if (S->getKind() == Node::KModuleName) { + Module = static_cast<ModuleName *>(S); + } else if (SoFar != nullptr) { + return nullptr; // Cannot have a prefix. + } else { + SoFar = S; + continue; // Do not push a new substitution. + } + } - // Parse an <unqualified-name> thats actually a <ctor-dtor-name>. - if (look() == 'C' || (look() == 'D' && look(1) != 'C')) { - if (SoFar == nullptr) - return nullptr; - if (!PushComponent(getDerived().parseCtorDtorName(SoFar, State))) - return nullptr; - SoFar = getDerived().parseAbiTags(SoFar); - if (SoFar == nullptr) - return nullptr; - Subs.push_back(SoFar); - continue; + // ::= [<prefix>] <unqualified-name> + SoFar = getDerived().parseUnqualifiedName(State, SoFar, Module); } - // ::= <prefix> <unqualified-name> - if (!PushComponent(getDerived().parseUnqualifiedName(State))) + if (SoFar == nullptr) return nullptr; Subs.push_back(SoFar); + + // No longer used. + // <data-member-prefix> := <member source-name> [<template-args>] M + consumeIf('M'); } if (SoFar == nullptr || Subs.empty()) @@ -3320,6 +3357,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseBaseUnresolvedName() { // ::= [gs] <base-unresolved-name> # x or (with "gs") ::x // ::= [gs] sr <unresolved-qualifier-level>+ E <base-unresolved-name> // # A::x, N::y, A<T>::z; "gs" means leading "::" +// [gs] has been parsed by caller. // ::= sr <unresolved-type> <base-unresolved-name> # T::x / decltype(p)::x // extension ::= sr <unresolved-type> <template-args> <base-unresolved-name> // # T::N::x /decltype(p)::N::x @@ -3327,7 +3365,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseBaseUnresolvedName() { // // <unresolved-qualifier-level> ::= <simple-id> template <typename Derived, typename Alloc> -Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName() { +Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName(bool Global) { Node *SoFar = nullptr; // srN <unresolved-type> [<template-args>] <unresolved-qualifier-level>* E <base-unresolved-name> @@ -3361,8 +3399,6 @@ Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName() { return make<QualifiedName>(SoFar, Base); } - bool Global = consumeIf("gs"); - // [gs] <base-unresolved-name> # x or (with "gs") ::x if (!consumeIf("sr")) { SoFar = getDerived().parseBaseUnresolvedName(); @@ -3419,7 +3455,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName() { template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseAbiTags(Node *N) { while (consumeIf('B')) { - StringView SN = parseBareSourceName(); + std::string_view SN = parseBareSourceName(); if (SN.empty()) return nullptr; N = make<AbiTagAttr>(N, SN); @@ -3431,16 +3467,16 @@ Node *AbstractManglingParser<Derived, Alloc>::parseAbiTags(Node *N) { // <number> ::= [n] <non-negative decimal integer> template <typename Alloc, typename Derived> -StringView +std::string_view AbstractManglingParser<Alloc, Derived>::parseNumber(bool AllowNegative) { const char *Tmp = First; if (AllowNegative) consumeIf('n'); if (numLeft() == 0 || !std::isdigit(*First)) - return StringView(); + return std::string_view(); while (numLeft() != 0 && std::isdigit(*First)) ++First; - return StringView(Tmp, First); + return std::string_view(Tmp, First - Tmp); } // <positive length number> ::= [0-9]* @@ -3457,11 +3493,11 @@ bool AbstractManglingParser<Alloc, Derived>::parsePositiveInteger(size_t *Out) { } template <typename Alloc, typename Derived> -StringView AbstractManglingParser<Alloc, Derived>::parseBareSourceName() { +std::string_view AbstractManglingParser<Alloc, Derived>::parseBareSourceName() { size_t Int = 0; if (parsePositiveInteger(&Int) || numLeft() < Int) - return StringView(); - StringView R(First, First + Int); + return {}; + std::string_view R(First, Int); First += Int; return R; } @@ -3549,7 +3585,9 @@ Node *AbstractManglingParser<Derived, Alloc>::parseVectorType() { if (!consumeIf("Dv")) return nullptr; if (look() >= '1' && look() <= '9') { - StringView DimensionNumber = parseNumber(); + Node *DimensionNumber = make<NameType>(parseNumber()); + if (!DimensionNumber) + return nullptr; if (!consumeIf('_')) return nullptr; if (consumeIf('p')) @@ -3574,7 +3612,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseVectorType() { Node *ElemType = getDerived().parseType(); if (!ElemType) return nullptr; - return make<VectorType>(ElemType, StringView()); + return make<VectorType>(ElemType, /*Dimension=*/nullptr); } // <decltype> ::= Dt <expression> E # decltype of an id-expression or class member access (C++0x) @@ -3590,7 +3628,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseDecltype() { return nullptr; if (!consumeIf('E')) return nullptr; - return make<EnclosingExpr>("decltype(", E, ")"); + return make<EnclosingExpr>("decltype", E); } // <array-type> ::= A <positive dimension number> _ <element type> @@ -3600,10 +3638,12 @@ Node *AbstractManglingParser<Derived, Alloc>::parseArrayType() { if (!consumeIf('A')) return nullptr; - NodeOrString Dimension; + Node *Dimension = nullptr; if (std::isdigit(look())) { - Dimension = parseNumber(); + Dimension = make<NameType>(parseNumber()); + if (!Dimension) + return nullptr; if (!consumeIf('_')) return nullptr; } else if (!consumeIf('_')) { @@ -3641,7 +3681,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parsePointerToMemberType() { // ::= Te <name> # dependent elaborated type specifier using 'enum' template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseClassEnumType() { - StringView ElabSpef; + std::string_view ElabSpef; if (consumeIf("Ts")) ElabSpef = "struct"; else if (consumeIf("Tu")) @@ -3665,19 +3705,18 @@ Node *AbstractManglingParser<Derived, Alloc>::parseClassEnumType() { template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseQualifiedType() { if (consumeIf('U')) { - StringView Qual = parseBareSourceName(); + std::string_view Qual = parseBareSourceName(); if (Qual.empty()) return nullptr; - // FIXME parse the optional <template-args> here! - // extension ::= U <objc-name> <objc-type> # objc-type<identifier> - if (Qual.startsWith("objcproto")) { - StringView ProtoSourceName = Qual.dropFront(std::strlen("objcproto")); - StringView Proto; + if (llvm::itanium_demangle::starts_with(Qual, "objcproto")) { + constexpr size_t Len = sizeof("objcproto") - 1; + std::string_view ProtoSourceName(Qual.data() + Len, Qual.size() - Len); + std::string_view Proto; { - SwapAndRestore<const char *> SaveFirst(First, ProtoSourceName.begin()), - SaveLast(Last, ProtoSourceName.end()); + ScopedOverride<const char *> SaveFirst(First, ProtoSourceName.data()), + SaveLast(Last, &*ProtoSourceName.rbegin() + 1); Proto = parseBareSourceName(); } if (Proto.empty()) @@ -3688,10 +3727,17 @@ Node *AbstractManglingParser<Derived, Alloc>::parseQualifiedType() { return make<ObjCProtoName>(Child, Proto); } + Node *TA = nullptr; + if (look() == 'I') { + TA = getDerived().parseTemplateArgs(); + if (TA == nullptr) + return nullptr; + } + Node *Child = getDerived().parseQualifiedType(); if (Child == nullptr) return nullptr; - return make<VendorExtQualType>(Child, Qual); + return make<VendorExtQualType>(Child, Qual, TA); } Qualifiers Quals = parseCVQualifiers(); @@ -3838,7 +3884,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() { // <builtin-type> ::= u <source-name> # vendor extended type case 'u': { ++First; - StringView Res = parseBareSourceName(); + std::string_view Res = parseBareSourceName(); if (Res.empty()) return nullptr; // Typically, <builtin-type>s are not considered substitution candidates, @@ -3864,7 +3910,33 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() { // ::= Dh # IEEE 754r half-precision floating point (16 bits) case 'h': First += 2; - return make<NameType>("decimal16"); + return make<NameType>("half"); + // ::= DF <number> _ # ISO/IEC TS 18661 binary floating point (N bits) + case 'F': { + First += 2; + Node *DimensionNumber = make<NameType>(parseNumber()); + if (!DimensionNumber) + return nullptr; + if (!consumeIf('_')) + return nullptr; + return make<BinaryFPType>(DimensionNumber); + } + // ::= DB <number> _ # C23 signed _BitInt(N) + // ::= DB <instantiation-dependent expression> _ # C23 signed _BitInt(N) + // ::= DU <number> _ # C23 unsigned _BitInt(N) + // ::= DU <instantiation-dependent expression> _ # C23 unsigned _BitInt(N) + case 'B': + case 'U': { + bool Signed = look(1) == 'B'; + First += 2; + Node *Size = std::isdigit(look()) ? make<NameType>(parseNumber()) + : getDerived().parseExpr(); + if (!Size) + return nullptr; + if (!consumeIf('_')) + return nullptr; + return make<BitIntType>(Size, Signed); + } // ::= Di # char32_t case 'i': First += 2; @@ -4012,9 +4084,10 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() { } // ::= <substitution> # See Compression below case 'S': { - if (look(1) && look(1) != 't') { - Node *Sub = getDerived().parseSubstitution(); - if (Sub == nullptr) + if (look(1) != 't') { + bool IsSubst = false; + Result = getDerived().parseUnscopedName(nullptr, &IsSubst); + if (!Result) return nullptr; // Sub could be either of: @@ -4027,17 +4100,19 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() { // If this is followed by some <template-args>, and we're permitted to // parse them, take the second production. - if (TryToParseTemplateArgs && look() == 'I') { + if (look() == 'I' && (!IsSubst || TryToParseTemplateArgs)) { + if (!IsSubst) + Subs.push_back(Result); Node *TA = getDerived().parseTemplateArgs(); if (TA == nullptr) return nullptr; - Result = make<NameWithTemplateArgs>(Sub, TA); - break; + Result = make<NameWithTemplateArgs>(Result, TA); + } else if (IsSubst) { + // If all we parsed was a substitution, don't re-insert into the + // substitution table. + return Result; } - - // If all we parsed was a substitution, don't re-insert into the - // substitution table. - return Sub; + break; } DEMANGLE_FALLTHROUGH; } @@ -4057,28 +4132,32 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() { } template <typename Derived, typename Alloc> -Node *AbstractManglingParser<Derived, Alloc>::parsePrefixExpr(StringView Kind) { +Node * +AbstractManglingParser<Derived, Alloc>::parsePrefixExpr(std::string_view Kind, + Node::Prec Prec) { Node *E = getDerived().parseExpr(); if (E == nullptr) return nullptr; - return make<PrefixExpr>(Kind, E); + return make<PrefixExpr>(Kind, E, Prec); } template <typename Derived, typename Alloc> -Node *AbstractManglingParser<Derived, Alloc>::parseBinaryExpr(StringView Kind) { +Node * +AbstractManglingParser<Derived, Alloc>::parseBinaryExpr(std::string_view Kind, + Node::Prec Prec) { Node *LHS = getDerived().parseExpr(); if (LHS == nullptr) return nullptr; Node *RHS = getDerived().parseExpr(); if (RHS == nullptr) return nullptr; - return make<BinaryExpr>(LHS, Kind, RHS); + return make<BinaryExpr>(LHS, Kind, RHS, Prec); } template <typename Derived, typename Alloc> -Node * -AbstractManglingParser<Derived, Alloc>::parseIntegerLiteral(StringView Lit) { - StringView Tmp = parseNumber(true); +Node *AbstractManglingParser<Derived, Alloc>::parseIntegerLiteral( + std::string_view Lit) { + std::string_view Tmp = parseNumber(true); if (!Tmp.empty() && consumeIf('E')) return make<IntegerLiteral>(Lit, Tmp); return nullptr; @@ -4101,11 +4180,14 @@ Qualifiers AbstractManglingParser<Alloc, Derived>::parseCVQualifiers() { // ::= fp <top-level CV-Qualifiers> <parameter-2 non-negative number> _ # L == 0, second and later parameters // ::= fL <L-1 non-negative number> p <top-level CV-Qualifiers> _ # L > 0, first parameter // ::= fL <L-1 non-negative number> p <top-level CV-Qualifiers> <parameter-2 non-negative number> _ # L > 0, second and later parameters +// ::= fpT # 'this' expression (not part of standard?) template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseFunctionParam() { + if (consumeIf("fpT")) + return make<NameType>("this"); if (consumeIf("fp")) { parseCVQualifiers(); - StringView Num = parseNumber(); + std::string_view Num = parseNumber(); if (!consumeIf('_')) return nullptr; return make<FunctionParam>(Num); @@ -4116,7 +4198,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFunctionParam() { if (!consumeIf('p')) return nullptr; parseCVQualifiers(); - StringView Num = parseNumber(); + std::string_view Num = parseNumber(); if (!consumeIf('_')) return nullptr; return make<FunctionParam>(Num); @@ -4124,43 +4206,6 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFunctionParam() { return nullptr; } -// [gs] nw <expression>* _ <type> E # new (expr-list) type -// [gs] nw <expression>* _ <type> <initializer> # new (expr-list) type (init) -// [gs] na <expression>* _ <type> E # new[] (expr-list) type -// [gs] na <expression>* _ <type> <initializer> # new[] (expr-list) type (init) -// <initializer> ::= pi <expression>* E # parenthesized initialization -template <typename Derived, typename Alloc> -Node *AbstractManglingParser<Derived, Alloc>::parseNewExpr() { - bool Global = consumeIf("gs"); - bool IsArray = look(1) == 'a'; - if (!consumeIf("nw") && !consumeIf("na")) - return nullptr; - size_t Exprs = Names.size(); - while (!consumeIf('_')) { - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return nullptr; - Names.push_back(Ex); - } - NodeArray ExprList = popTrailingNodeArray(Exprs); - Node *Ty = getDerived().parseType(); - if (Ty == nullptr) - return Ty; - if (consumeIf("pi")) { - size_t InitsBegin = Names.size(); - while (!consumeIf('E')) { - Node *Init = getDerived().parseExpr(); - if (Init == nullptr) - return Init; - Names.push_back(Init); - } - NodeArray Inits = popTrailingNodeArray(InitsBegin); - return make<NewExpr>(ExprList, Ty, Inits, Global, IsArray); - } else if (!consumeIf('E')) - return nullptr; - return make<NewExpr>(ExprList, Ty, NodeArray(), Global, IsArray); -} - // cv <type> <expression> # conversion with one argument // cv <type> _ <expression>* E # conversion with a different number of arguments template <typename Derived, typename Alloc> @@ -4169,7 +4214,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseConversionExpr() { return nullptr; Node *Ty; { - SwapAndRestore<bool> SaveTemp(TryToParseTemplateArgs, false); + ScopedOverride<bool> SaveTemp(TryToParseTemplateArgs, false); Ty = getDerived().parseType(); } @@ -4262,7 +4307,13 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExprPrimary() { return getDerived().template parseFloatingLiteral<double>(); case 'e': ++First; +#if defined(__powerpc__) || defined(__s390__) + // Handle cases where long doubles encoded with e have the same size + // and representation as doubles. + return getDerived().template parseFloatingLiteral<double>(); +#else return getDerived().template parseFloatingLiteral<long double>(); +#endif case '_': if (consumeIf("_Z")) { Node *R = getDerived().parseEncoding(); @@ -4280,7 +4331,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExprPrimary() { return nullptr; } case 'D': - if (consumeIf("DnE")) + if (consumeIf("Dn") && (consumeIf('0'), consumeIf('E'))) return make<NameType>("nullptr"); return nullptr; case 'T': @@ -4301,12 +4352,12 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExprPrimary() { Node *T = getDerived().parseType(); if (T == nullptr) return nullptr; - StringView N = parseNumber(); + std::string_view N = parseNumber(/*AllowNegative=*/true); if (N.empty()) return nullptr; if (!consumeIf('E')) return nullptr; - return make<IntegerCastExpr>(T, N); + return make<EnumLiteral>(T, N); } } } @@ -4367,55 +4418,38 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFoldExpr() { if (!consumeIf('f')) return nullptr; - char FoldKind = look(); - bool IsLeftFold, HasInitializer; - HasInitializer = FoldKind == 'L' || FoldKind == 'R'; - if (FoldKind == 'l' || FoldKind == 'L') - IsLeftFold = true; - else if (FoldKind == 'r' || FoldKind == 'R') - IsLeftFold = false; - else + bool IsLeftFold = false, HasInitializer = false; + switch (look()) { + default: return nullptr; + case 'L': + IsLeftFold = true; + HasInitializer = true; + break; + case 'R': + HasInitializer = true; + break; + case 'l': + IsLeftFold = true; + break; + case 'r': + break; + } ++First; - // FIXME: This map is duplicated in parseOperatorName and parseExpr. - StringView OperatorName; - if (consumeIf("aa")) OperatorName = "&&"; - else if (consumeIf("an")) OperatorName = "&"; - else if (consumeIf("aN")) OperatorName = "&="; - else if (consumeIf("aS")) OperatorName = "="; - else if (consumeIf("cm")) OperatorName = ","; - else if (consumeIf("ds")) OperatorName = ".*"; - else if (consumeIf("dv")) OperatorName = "/"; - else if (consumeIf("dV")) OperatorName = "/="; - else if (consumeIf("eo")) OperatorName = "^"; - else if (consumeIf("eO")) OperatorName = "^="; - else if (consumeIf("eq")) OperatorName = "=="; - else if (consumeIf("ge")) OperatorName = ">="; - else if (consumeIf("gt")) OperatorName = ">"; - else if (consumeIf("le")) OperatorName = "<="; - else if (consumeIf("ls")) OperatorName = "<<"; - else if (consumeIf("lS")) OperatorName = "<<="; - else if (consumeIf("lt")) OperatorName = "<"; - else if (consumeIf("mi")) OperatorName = "-"; - else if (consumeIf("mI")) OperatorName = "-="; - else if (consumeIf("ml")) OperatorName = "*"; - else if (consumeIf("mL")) OperatorName = "*="; - else if (consumeIf("ne")) OperatorName = "!="; - else if (consumeIf("oo")) OperatorName = "||"; - else if (consumeIf("or")) OperatorName = "|"; - else if (consumeIf("oR")) OperatorName = "|="; - else if (consumeIf("pl")) OperatorName = "+"; - else if (consumeIf("pL")) OperatorName = "+="; - else if (consumeIf("rm")) OperatorName = "%"; - else if (consumeIf("rM")) OperatorName = "%="; - else if (consumeIf("rs")) OperatorName = ">>"; - else if (consumeIf("rS")) OperatorName = ">>="; - else return nullptr; - - Node *Pack = getDerived().parseExpr(), *Init = nullptr; + const auto *Op = parseOperatorEncoding(); + if (!Op) + return nullptr; + if (!(Op->getKind() == OperatorInfo::Binary + || (Op->getKind() == OperatorInfo::Member + && Op->getName().back() == '*'))) + return nullptr; + + Node *Pack = getDerived().parseExpr(); if (Pack == nullptr) return nullptr; + + Node *Init = nullptr; if (HasInitializer) { Init = getDerived().parseExpr(); if (Init == nullptr) @@ -4425,7 +4459,53 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFoldExpr() { if (IsLeftFold && Init) std::swap(Pack, Init); - return make<FoldExpr>(IsLeftFold, OperatorName, Pack, Init); + return make<FoldExpr>(IsLeftFold, Op->getSymbol(), Pack, Init); +} + +// <expression> ::= mc <parameter type> <expr> [<offset number>] E +// +// Not yet in the spec: https://github.com/itanium-cxx-abi/cxx-abi/issues/47 +template <typename Derived, typename Alloc> +Node * +AbstractManglingParser<Derived, Alloc>::parsePointerToMemberConversionExpr( + Node::Prec Prec) { + Node *Ty = getDerived().parseType(); + if (!Ty) + return nullptr; + Node *Expr = getDerived().parseExpr(); + if (!Expr) + return nullptr; + std::string_view Offset = getDerived().parseNumber(true); + if (!consumeIf('E')) + return nullptr; + return make<PointerToMemberConversionExpr>(Ty, Expr, Offset, Prec); +} + +// <expression> ::= so <referent type> <expr> [<offset number>] <union-selector>* [p] E +// <union-selector> ::= _ [<number>] +// +// Not yet in the spec: https://github.com/itanium-cxx-abi/cxx-abi/issues/47 +template <typename Derived, typename Alloc> +Node *AbstractManglingParser<Derived, Alloc>::parseSubobjectExpr() { + Node *Ty = getDerived().parseType(); + if (!Ty) + return nullptr; + Node *Expr = getDerived().parseExpr(); + if (!Expr) + return nullptr; + std::string_view Offset = getDerived().parseNumber(true); + size_t SelectorsBegin = Names.size(); + while (consumeIf('_')) { + Node *Selector = make<NameType>(parseNumber()); + if (!Selector) + return nullptr; + Names.push_back(Selector); + } + bool OnePastTheEnd = consumeIf('p'); + if (!consumeIf('E')) + return nullptr; + return make<SubobjectExpr>( + Ty, Expr, Offset, popTrailingNodeArray(SelectorsBegin), OnePastTheEnd); } // <expression> ::= <unary operator-name> <expression> @@ -4475,313 +4555,127 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFoldExpr() { template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseExpr() { bool Global = consumeIf("gs"); - if (numLeft() < 2) - return nullptr; - switch (*First) { - case 'L': - return getDerived().parseExprPrimary(); - case 'T': - return getDerived().parseTemplateParam(); - case 'f': { - // Disambiguate a fold expression from a <function-param>. - if (look(1) == 'p' || (look(1) == 'L' && std::isdigit(look(2)))) - return getDerived().parseFunctionParam(); - return getDerived().parseFoldExpr(); - } - case 'a': - switch (First[1]) { - case 'a': - First += 2; - return getDerived().parseBinaryExpr("&&"); - case 'd': - First += 2; - return getDerived().parsePrefixExpr("&"); - case 'n': - First += 2; - return getDerived().parseBinaryExpr("&"); - case 'N': - First += 2; - return getDerived().parseBinaryExpr("&="); - case 'S': - First += 2; - return getDerived().parseBinaryExpr("="); - case 't': { - First += 2; - Node *Ty = getDerived().parseType(); - if (Ty == nullptr) - return nullptr; - return make<EnclosingExpr>("alignof (", Ty, ")"); - } - case 'z': { - First += 2; - Node *Ty = getDerived().parseExpr(); - if (Ty == nullptr) - return nullptr; - return make<EnclosingExpr>("alignof (", Ty, ")"); - } - } - return nullptr; - case 'c': - switch (First[1]) { - // cc <type> <expression> # const_cast<type>(expression) - case 'c': { - First += 2; - Node *Ty = getDerived().parseType(); - if (Ty == nullptr) - return Ty; - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return Ex; - return make<CastExpr>("const_cast", Ty, Ex); - } - // cl <expression>+ E # call - case 'l': { - First += 2; - Node *Callee = getDerived().parseExpr(); - if (Callee == nullptr) - return Callee; - size_t ExprsBegin = Names.size(); - while (!consumeIf('E')) { - Node *E = getDerived().parseExpr(); - if (E == nullptr) - return E; - Names.push_back(E); - } - return make<CallExpr>(Callee, popTrailingNodeArray(ExprsBegin)); - } - case 'm': - First += 2; - return getDerived().parseBinaryExpr(","); - case 'o': - First += 2; - return getDerived().parsePrefixExpr("~"); - case 'v': - return getDerived().parseConversionExpr(); - } - return nullptr; - case 'd': - switch (First[1]) { - case 'a': { - First += 2; - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return Ex; - return make<DeleteExpr>(Ex, Global, /*is_array=*/true); - } - case 'c': { - First += 2; - Node *T = getDerived().parseType(); - if (T == nullptr) - return T; + const auto *Op = parseOperatorEncoding(); + if (Op) { + auto Sym = Op->getSymbol(); + switch (Op->getKind()) { + case OperatorInfo::Binary: + // Binary operator: lhs @ rhs + return getDerived().parseBinaryExpr(Sym, Op->getPrecedence()); + case OperatorInfo::Prefix: + // Prefix unary operator: @ expr + return getDerived().parsePrefixExpr(Sym, Op->getPrecedence()); + case OperatorInfo::Postfix: { + // Postfix unary operator: expr @ + if (consumeIf('_')) + return getDerived().parsePrefixExpr(Sym, Op->getPrecedence()); Node *Ex = getDerived().parseExpr(); if (Ex == nullptr) - return Ex; - return make<CastExpr>("dynamic_cast", T, Ex); - } - case 'e': - First += 2; - return getDerived().parsePrefixExpr("*"); - case 'l': { - First += 2; - Node *E = getDerived().parseExpr(); - if (E == nullptr) - return E; - return make<DeleteExpr>(E, Global, /*is_array=*/false); + return nullptr; + return make<PostfixExpr>(Ex, Sym, Op->getPrecedence()); } - case 'n': - return getDerived().parseUnresolvedName(); - case 's': { - First += 2; - Node *LHS = getDerived().parseExpr(); - if (LHS == nullptr) + case OperatorInfo::Array: { + // Array Index: lhs [ rhs ] + Node *Base = getDerived().parseExpr(); + if (Base == nullptr) return nullptr; - Node *RHS = getDerived().parseExpr(); - if (RHS == nullptr) + Node *Index = getDerived().parseExpr(); + if (Index == nullptr) return nullptr; - return make<MemberExpr>(LHS, ".*", RHS); + return make<ArraySubscriptExpr>(Base, Index, Op->getPrecedence()); } - case 't': { - First += 2; + case OperatorInfo::Member: { + // Member access lhs @ rhs Node *LHS = getDerived().parseExpr(); if (LHS == nullptr) - return LHS; + return nullptr; Node *RHS = getDerived().parseExpr(); if (RHS == nullptr) return nullptr; - return make<MemberExpr>(LHS, ".", RHS); - } - case 'v': - First += 2; - return getDerived().parseBinaryExpr("/"); - case 'V': - First += 2; - return getDerived().parseBinaryExpr("/="); - } - return nullptr; - case 'e': - switch (First[1]) { - case 'o': - First += 2; - return getDerived().parseBinaryExpr("^"); - case 'O': - First += 2; - return getDerived().parseBinaryExpr("^="); - case 'q': - First += 2; - return getDerived().parseBinaryExpr("=="); - } - return nullptr; - case 'g': - switch (First[1]) { - case 'e': - First += 2; - return getDerived().parseBinaryExpr(">="); - case 't': - First += 2; - return getDerived().parseBinaryExpr(">"); - } - return nullptr; - case 'i': - switch (First[1]) { - case 'x': { - First += 2; - Node *Base = getDerived().parseExpr(); - if (Base == nullptr) + return make<MemberExpr>(LHS, Sym, RHS, Op->getPrecedence()); + } + case OperatorInfo::New: { + // New + // # new (expr-list) type [(init)] + // [gs] nw <expression>* _ <type> [pi <expression>*] E + // # new[] (expr-list) type [(init)] + // [gs] na <expression>* _ <type> [pi <expression>*] E + size_t Exprs = Names.size(); + while (!consumeIf('_')) { + Node *Ex = getDerived().parseExpr(); + if (Ex == nullptr) + return nullptr; + Names.push_back(Ex); + } + NodeArray ExprList = popTrailingNodeArray(Exprs); + Node *Ty = getDerived().parseType(); + if (Ty == nullptr) return nullptr; - Node *Index = getDerived().parseExpr(); - if (Index == nullptr) - return Index; - return make<ArraySubscriptExpr>(Base, Index); - } - case 'l': { - First += 2; + bool HaveInits = consumeIf("pi"); size_t InitsBegin = Names.size(); while (!consumeIf('E')) { - Node *E = getDerived().parseBracedExpr(); - if (E == nullptr) + if (!HaveInits) return nullptr; - Names.push_back(E); + Node *Init = getDerived().parseExpr(); + if (Init == nullptr) + return Init; + Names.push_back(Init); } - return make<InitListExpr>(nullptr, popTrailingNodeArray(InitsBegin)); - } + NodeArray Inits = popTrailingNodeArray(InitsBegin); + return make<NewExpr>(ExprList, Ty, Inits, Global, + /*IsArray=*/Op->getFlag(), Op->getPrecedence()); } - return nullptr; - case 'l': - switch (First[1]) { - case 'e': - First += 2; - return getDerived().parseBinaryExpr("<="); - case 's': - First += 2; - return getDerived().parseBinaryExpr("<<"); - case 'S': - First += 2; - return getDerived().parseBinaryExpr("<<="); - case 't': - First += 2; - return getDerived().parseBinaryExpr("<"); - } - return nullptr; - case 'm': - switch (First[1]) { - case 'i': - First += 2; - return getDerived().parseBinaryExpr("-"); - case 'I': - First += 2; - return getDerived().parseBinaryExpr("-="); - case 'l': - First += 2; - return getDerived().parseBinaryExpr("*"); - case 'L': - First += 2; - return getDerived().parseBinaryExpr("*="); - case 'm': - First += 2; - if (consumeIf('_')) - return getDerived().parsePrefixExpr("--"); + case OperatorInfo::Del: { + // Delete Node *Ex = getDerived().parseExpr(); if (Ex == nullptr) return nullptr; - return make<PostfixExpr>(Ex, "--"); - } - return nullptr; - case 'n': - switch (First[1]) { - case 'a': - case 'w': - return getDerived().parseNewExpr(); - case 'e': - First += 2; - return getDerived().parseBinaryExpr("!="); - case 'g': - First += 2; - return getDerived().parsePrefixExpr("-"); - case 't': - First += 2; - return getDerived().parsePrefixExpr("!"); - case 'x': - First += 2; - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return Ex; - return make<EnclosingExpr>("noexcept (", Ex, ")"); - } - return nullptr; - case 'o': - switch (First[1]) { - case 'n': - return getDerived().parseUnresolvedName(); - case 'o': - First += 2; - return getDerived().parseBinaryExpr("||"); - case 'r': - First += 2; - return getDerived().parseBinaryExpr("|"); - case 'R': - First += 2; - return getDerived().parseBinaryExpr("|="); + return make<DeleteExpr>(Ex, Global, /*IsArray=*/Op->getFlag(), + Op->getPrecedence()); } - return nullptr; - case 'p': - switch (First[1]) { - case 'm': - First += 2; - return getDerived().parseBinaryExpr("->*"); - case 'l': - First += 2; - return getDerived().parseBinaryExpr("+"); - case 'L': - First += 2; - return getDerived().parseBinaryExpr("+="); - case 'p': { - First += 2; - if (consumeIf('_')) - return getDerived().parsePrefixExpr("++"); - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return Ex; - return make<PostfixExpr>(Ex, "++"); + case OperatorInfo::Call: { + // Function Call + Node *Callee = getDerived().parseExpr(); + if (Callee == nullptr) + return nullptr; + size_t ExprsBegin = Names.size(); + while (!consumeIf('E')) { + Node *E = getDerived().parseExpr(); + if (E == nullptr) + return nullptr; + Names.push_back(E); + } + return make<CallExpr>(Callee, popTrailingNodeArray(ExprsBegin), + Op->getPrecedence()); } - case 's': - First += 2; - return getDerived().parsePrefixExpr("+"); - case 't': { - First += 2; - Node *L = getDerived().parseExpr(); - if (L == nullptr) + case OperatorInfo::CCast: { + // C Cast: (type)expr + Node *Ty; + { + ScopedOverride<bool> SaveTemp(TryToParseTemplateArgs, false); + Ty = getDerived().parseType(); + } + if (Ty == nullptr) return nullptr; - Node *R = getDerived().parseExpr(); - if (R == nullptr) + + size_t ExprsBegin = Names.size(); + bool IsMany = consumeIf('_'); + while (!consumeIf('E')) { + Node *E = getDerived().parseExpr(); + if (E == nullptr) + return E; + Names.push_back(E); + if (!IsMany) + break; + } + NodeArray Exprs = popTrailingNodeArray(ExprsBegin); + if (!IsMany && Exprs.size() != 1) return nullptr; - return make<MemberExpr>(L, "->", R); + return make<ConversionExpr>(Ty, Exprs, Op->getPrecedence()); } - } - return nullptr; - case 'q': - if (First[1] == 'u') { - First += 2; + case OperatorInfo::Conditional: { + // Conditional operator: expr ? expr : expr Node *Cond = getDerived().parseExpr(); if (Cond == nullptr) return nullptr; @@ -4791,169 +4685,158 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExpr() { Node *RHS = getDerived().parseExpr(); if (RHS == nullptr) return nullptr; - return make<ConditionalExpr>(Cond, LHS, RHS); - } - return nullptr; - case 'r': - switch (First[1]) { - case 'c': { - First += 2; - Node *T = getDerived().parseType(); - if (T == nullptr) - return T; - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return Ex; - return make<CastExpr>("reinterpret_cast", T, Ex); + return make<ConditionalExpr>(Cond, LHS, RHS, Op->getPrecedence()); } - case 'm': - First += 2; - return getDerived().parseBinaryExpr("%"); - case 'M': - First += 2; - return getDerived().parseBinaryExpr("%="); - case 's': - First += 2; - return getDerived().parseBinaryExpr(">>"); - case 'S': - First += 2; - return getDerived().parseBinaryExpr(">>="); - } - return nullptr; - case 's': - switch (First[1]) { - case 'c': { - First += 2; - Node *T = getDerived().parseType(); - if (T == nullptr) - return T; - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return Ex; - return make<CastExpr>("static_cast", T, Ex); - } - case 'p': { - First += 2; - Node *Child = getDerived().parseExpr(); - if (Child == nullptr) - return nullptr; - return make<ParameterPackExpansion>(Child); - } - case 'r': - return getDerived().parseUnresolvedName(); - case 't': { - First += 2; + case OperatorInfo::NamedCast: { + // Named cast operation, @<type>(expr) Node *Ty = getDerived().parseType(); if (Ty == nullptr) - return Ty; - return make<EnclosingExpr>("sizeof (", Ty, ")"); - } - case 'z': { - First += 2; + return nullptr; Node *Ex = getDerived().parseExpr(); if (Ex == nullptr) - return Ex; - return make<EnclosingExpr>("sizeof (", Ex, ")"); - } - case 'Z': - First += 2; - if (look() == 'T') { - Node *R = getDerived().parseTemplateParam(); - if (R == nullptr) - return nullptr; - return make<SizeofParamPackExpr>(R); - } else if (look() == 'f') { - Node *FP = getDerived().parseFunctionParam(); - if (FP == nullptr) - return nullptr; - return make<EnclosingExpr>("sizeof... (", FP, ")"); - } - return nullptr; - case 'P': { - First += 2; - size_t ArgsBegin = Names.size(); - while (!consumeIf('E')) { - Node *Arg = getDerived().parseTemplateArg(); - if (Arg == nullptr) - return nullptr; - Names.push_back(Arg); - } - auto *Pack = make<NodeArrayNode>(popTrailingNodeArray(ArgsBegin)); - if (!Pack) return nullptr; - return make<EnclosingExpr>("sizeof... (", Pack, ")"); + return make<CastExpr>(Sym, Ty, Ex, Op->getPrecedence()); } + case OperatorInfo::OfIdOp: { + // [sizeof/alignof/typeid] ( <type>|<expr> ) + Node *Arg = + Op->getFlag() ? getDerived().parseType() : getDerived().parseExpr(); + if (!Arg) + return nullptr; + return make<EnclosingExpr>(Sym, Arg, Op->getPrecedence()); } - return nullptr; - case 't': - switch (First[1]) { - case 'e': { - First += 2; - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) - return Ex; - return make<EnclosingExpr>("typeid (", Ex, ")"); + case OperatorInfo::NameOnly: { + // Not valid as an expression operand. + return nullptr; } - case 'i': { - First += 2; - Node *Ty = getDerived().parseType(); - if (Ty == nullptr) - return Ty; - return make<EnclosingExpr>("typeid (", Ty, ")"); } - case 'l': { - First += 2; - Node *Ty = getDerived().parseType(); - if (Ty == nullptr) + DEMANGLE_UNREACHABLE; + } + + if (numLeft() < 2) + return nullptr; + + if (look() == 'L') + return getDerived().parseExprPrimary(); + if (look() == 'T') + return getDerived().parseTemplateParam(); + if (look() == 'f') { + // Disambiguate a fold expression from a <function-param>. + if (look(1) == 'p' || (look(1) == 'L' && std::isdigit(look(2)))) + return getDerived().parseFunctionParam(); + return getDerived().parseFoldExpr(); + } + if (consumeIf("il")) { + size_t InitsBegin = Names.size(); + while (!consumeIf('E')) { + Node *E = getDerived().parseBracedExpr(); + if (E == nullptr) return nullptr; - size_t InitsBegin = Names.size(); - while (!consumeIf('E')) { - Node *E = getDerived().parseBracedExpr(); - if (E == nullptr) - return nullptr; - Names.push_back(E); - } - return make<InitListExpr>(Ty, popTrailingNodeArray(InitsBegin)); + Names.push_back(E); } - case 'r': - First += 2; - return make<NameType>("throw"); - case 'w': { - First += 2; - Node *Ex = getDerived().parseExpr(); - if (Ex == nullptr) + return make<InitListExpr>(nullptr, popTrailingNodeArray(InitsBegin)); + } + if (consumeIf("mc")) + return parsePointerToMemberConversionExpr(Node::Prec::Unary); + if (consumeIf("nx")) { + Node *Ex = getDerived().parseExpr(); + if (Ex == nullptr) + return Ex; + return make<EnclosingExpr>("noexcept ", Ex, Node::Prec::Unary); + } + if (consumeIf("so")) + return parseSubobjectExpr(); + if (consumeIf("sp")) { + Node *Child = getDerived().parseExpr(); + if (Child == nullptr) + return nullptr; + return make<ParameterPackExpansion>(Child); + } + if (consumeIf("sZ")) { + if (look() == 'T') { + Node *R = getDerived().parseTemplateParam(); + if (R == nullptr) return nullptr; - return make<ThrowExpr>(Ex); + return make<SizeofParamPackExpr>(R); } + Node *FP = getDerived().parseFunctionParam(); + if (FP == nullptr) + return nullptr; + return make<EnclosingExpr>("sizeof... ", FP); + } + if (consumeIf("sP")) { + size_t ArgsBegin = Names.size(); + while (!consumeIf('E')) { + Node *Arg = getDerived().parseTemplateArg(); + if (Arg == nullptr) + return nullptr; + Names.push_back(Arg); } - return nullptr; - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return getDerived().parseUnresolvedName(); - } - - if (consumeIf("u8__uuidoft")) { + auto *Pack = make<NodeArrayNode>(popTrailingNodeArray(ArgsBegin)); + if (!Pack) + return nullptr; + return make<EnclosingExpr>("sizeof... ", Pack); + } + if (consumeIf("tl")) { Node *Ty = getDerived().parseType(); - if (!Ty) + if (Ty == nullptr) return nullptr; - return make<UUIDOfExpr>(Ty); + size_t InitsBegin = Names.size(); + while (!consumeIf('E')) { + Node *E = getDerived().parseBracedExpr(); + if (E == nullptr) + return nullptr; + Names.push_back(E); + } + return make<InitListExpr>(Ty, popTrailingNodeArray(InitsBegin)); } - - if (consumeIf("u8__uuidofz")) { + if (consumeIf("tr")) + return make<NameType>("throw"); + if (consumeIf("tw")) { Node *Ex = getDerived().parseExpr(); - if (!Ex) + if (Ex == nullptr) return nullptr; - return make<UUIDOfExpr>(Ex); + return make<ThrowExpr>(Ex); + } + if (consumeIf('u')) { + Node *Name = getDerived().parseSourceName(/*NameState=*/nullptr); + if (!Name) + return nullptr; + // Special case legacy __uuidof mangling. The 't' and 'z' appear where the + // standard encoding expects a <template-arg>, and would be otherwise be + // interpreted as <type> node 'short' or 'ellipsis'. However, neither + // __uuidof(short) nor __uuidof(...) can actually appear, so there is no + // actual conflict here. + bool IsUUID = false; + Node *UUID = nullptr; + if (Name->getBaseName() == "__uuidof") { + if (consumeIf('t')) { + UUID = getDerived().parseType(); + IsUUID = true; + } else if (consumeIf('z')) { + UUID = getDerived().parseExpr(); + IsUUID = true; + } + } + size_t ExprsBegin = Names.size(); + if (IsUUID) { + if (UUID == nullptr) + return nullptr; + Names.push_back(UUID); + } else { + while (!consumeIf('E')) { + Node *E = getDerived().parseTemplateArg(); + if (E == nullptr) + return E; + Names.push_back(E); + } + } + return make<CallExpr>(Name, popTrailingNodeArray(ExprsBegin), + Node::Prec::Postfix); } - return nullptr; + // Only unresolved names remain. + return getDerived().parseUnresolvedName(Global); } // <call-offset> ::= h <nv-offset> _ @@ -4986,19 +4869,32 @@ bool AbstractManglingParser<Alloc, Derived>::parseCallOffset() { // # second call-offset is result adjustment // ::= T <call-offset> <base encoding> // # base is the nominal target function of thunk -// ::= GV <object name> # Guard variable for one-time initialization +// # Guard variable for one-time initialization +// ::= GV <object name> // # No <type> // ::= TW <object name> # Thread-local wrapper // ::= TH <object name> # Thread-local initialization // ::= GR <object name> _ # First temporary // ::= GR <object name> <seq-id> _ # Subsequent temporaries -// extension ::= TC <first type> <number> _ <second type> # construction vtable for second-in-first +// # construction vtable for second-in-first +// extension ::= TC <first type> <number> _ <second type> // extension ::= GR <object name> # reference temporary for object +// extension ::= GI <module name> # module global initializer template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseSpecialName() { switch (look()) { case 'T': switch (look(1)) { + // TA <template-arg> # template parameter object + // + // Not yet in the spec: https://github.com/itanium-cxx-abi/cxx-abi/issues/63 + case 'A': { + First += 2; + Node *Arg = getDerived().parseTemplateArg(); + if (Arg == nullptr) + return nullptr; + return make<SpecialName>("template parameter object for ", Arg); + } // TV <type> # virtual table case 'V': { First += 2; @@ -5110,6 +5006,16 @@ Node *AbstractManglingParser<Derived, Alloc>::parseSpecialName() { return nullptr; return make<SpecialName>("reference temporary for ", Name); } + // GI <module-name> v + case 'I': { + First += 2; + ModuleName *Module = nullptr; + if (getDerived().parseModuleNameOpt(Module)) + return nullptr; + if (Module == nullptr) + return nullptr; + return make<SpecialName>("initializer for module ", Module); + } } } return nullptr; @@ -5120,6 +5026,26 @@ Node *AbstractManglingParser<Derived, Alloc>::parseSpecialName() { // ::= <special-name> template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseEncoding() { + // The template parameters of an encoding are unrelated to those of the + // enclosing context. + class SaveTemplateParams { + AbstractManglingParser *Parser; + decltype(TemplateParams) OldParams; + decltype(OuterTemplateParams) OldOuterParams; + + public: + SaveTemplateParams(AbstractManglingParser *TheParser) : Parser(TheParser) { + OldParams = std::move(Parser->TemplateParams); + OldOuterParams = std::move(Parser->OuterTemplateParams); + Parser->TemplateParams.clear(); + Parser->OuterTemplateParams.clear(); + } + ~SaveTemplateParams() { + Parser->TemplateParams = std::move(OldParams); + Parser->OuterTemplateParams = std::move(OldOuterParams); + } + } SaveTemplateParams(this); + if (look() == 'G' || look() == 'T') return getDerived().parseSpecialName(); @@ -5204,14 +5130,19 @@ template <> struct FloatData<long double> { #if defined(__mips__) && defined(__mips_n64) || defined(__aarch64__) || \ - defined(__wasm__) + defined(__wasm__) || defined(__riscv) || defined(__loongarch__) static const size_t mangled_size = 32; #elif defined(__arm__) || defined(__mips__) || defined(__hexagon__) static const size_t mangled_size = 16; #else static const size_t mangled_size = 20; // May need to be adjusted to 16 or 24 on other platforms #endif - static const size_t max_demangled_size = 40; + // `-0x1.ffffffffffffffffffffffffffffp+16383` + 'L' + '\0' == 42 bytes. + // 28 'f's * 4 bits == 112 bits, which is the number of mantissa bits. + // Negatives are one character longer than positives. + // `0x1.` and `p` are constant, and exponents `+16383` and `-16382` are the + // same length. 1 sign bit, 112 mantissa bits, and 15 exponent bits == 128. + static const size_t max_demangled_size = 42; static constexpr const char *spec = "%LaL"; }; @@ -5221,7 +5152,7 @@ Node *AbstractManglingParser<Alloc, Derived>::parseFloatingLiteral() { const size_t N = FloatData<Float>::mangled_size; if (numLeft() <= N) return nullptr; - StringView Data(First, First + N); + std::string_view Data(First, N); for (char C : Data) if (!std::isxdigit(C)) return nullptr; @@ -5264,43 +5195,41 @@ bool AbstractManglingParser<Alloc, Derived>::parseSeqId(size_t *Out) { // <substitution> ::= Si # ::std::basic_istream<char, std::char_traits<char> > // <substitution> ::= So # ::std::basic_ostream<char, std::char_traits<char> > // <substitution> ::= Sd # ::std::basic_iostream<char, std::char_traits<char> > +// The St case is handled specially in parseNestedName. template <typename Derived, typename Alloc> Node *AbstractManglingParser<Derived, Alloc>::parseSubstitution() { if (!consumeIf('S')) return nullptr; - if (std::islower(look())) { - Node *SpecialSub; + if (look() >= 'a' && look() <= 'z') { + SpecialSubKind Kind; switch (look()) { case 'a': - ++First; - SpecialSub = make<SpecialSubstitution>(SpecialSubKind::allocator); + Kind = SpecialSubKind::allocator; break; case 'b': - ++First; - SpecialSub = make<SpecialSubstitution>(SpecialSubKind::basic_string); + Kind = SpecialSubKind::basic_string; break; - case 's': - ++First; - SpecialSub = make<SpecialSubstitution>(SpecialSubKind::string); + case 'd': + Kind = SpecialSubKind::iostream; break; case 'i': - ++First; - SpecialSub = make<SpecialSubstitution>(SpecialSubKind::istream); + Kind = SpecialSubKind::istream; break; case 'o': - ++First; - SpecialSub = make<SpecialSubstitution>(SpecialSubKind::ostream); + Kind = SpecialSubKind::ostream; break; - case 'd': - ++First; - SpecialSub = make<SpecialSubstitution>(SpecialSubKind::iostream); + case 's': + Kind = SpecialSubKind::string; break; default: return nullptr; } + ++First; + auto *SpecialSub = make<SpecialSubstitution>(Kind); if (!SpecialSub) return nullptr; + // Itanium C++ ABI 5.1.2: If a name that would use a built-in <substitution> // has ABI tags, the tags are appended to the substitution; the result is a // substitutable component. @@ -5543,7 +5472,8 @@ Node *AbstractManglingParser<Derived, Alloc>::parse() { if (Encoding == nullptr) return nullptr; if (look() == '.') { - Encoding = make<DotSuffix>(Encoding, StringView(First, Last)); + Encoding = + make<DotSuffix>(Encoding, std::string_view(First, Last - First)); First = Last; } if (numLeft() != 0) diff --git a/externals/demangle/llvm/Demangle/ItaniumNodes.def b/externals/demangle/llvm/Demangle/ItaniumNodes.def new file mode 100644 index 000000000..5985769ef --- /dev/null +++ b/externals/demangle/llvm/Demangle/ItaniumNodes.def @@ -0,0 +1,96 @@ +//===--- ItaniumNodes.def ------------*- mode:c++;eval:(read-only-mode) -*-===// +// Do not edit! See README.txt. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-FileCopyrightText: Part of the LLVM Project +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Define the demangler's node names + +#ifndef NODE +#error Define NODE to handle nodes +#endif + +NODE(NodeArrayNode) +NODE(DotSuffix) +NODE(VendorExtQualType) +NODE(QualType) +NODE(ConversionOperatorType) +NODE(PostfixQualifiedType) +NODE(ElaboratedTypeSpefType) +NODE(NameType) +NODE(AbiTagAttr) +NODE(EnableIfAttr) +NODE(ObjCProtoName) +NODE(PointerType) +NODE(ReferenceType) +NODE(PointerToMemberType) +NODE(ArrayType) +NODE(FunctionType) +NODE(NoexceptSpec) +NODE(DynamicExceptionSpec) +NODE(FunctionEncoding) +NODE(LiteralOperator) +NODE(SpecialName) +NODE(CtorVtableSpecialName) +NODE(QualifiedName) +NODE(NestedName) +NODE(LocalName) +NODE(ModuleName) +NODE(ModuleEntity) +NODE(VectorType) +NODE(PixelVectorType) +NODE(BinaryFPType) +NODE(BitIntType) +NODE(SyntheticTemplateParamName) +NODE(TypeTemplateParamDecl) +NODE(NonTypeTemplateParamDecl) +NODE(TemplateTemplateParamDecl) +NODE(TemplateParamPackDecl) +NODE(ParameterPack) +NODE(TemplateArgumentPack) +NODE(ParameterPackExpansion) +NODE(TemplateArgs) +NODE(ForwardTemplateReference) +NODE(NameWithTemplateArgs) +NODE(GlobalQualifiedName) +NODE(ExpandedSpecialSubstitution) +NODE(SpecialSubstitution) +NODE(CtorDtorName) +NODE(DtorName) +NODE(UnnamedTypeName) +NODE(ClosureTypeName) +NODE(StructuredBindingName) +NODE(BinaryExpr) +NODE(ArraySubscriptExpr) +NODE(PostfixExpr) +NODE(ConditionalExpr) +NODE(MemberExpr) +NODE(SubobjectExpr) +NODE(EnclosingExpr) +NODE(CastExpr) +NODE(SizeofParamPackExpr) +NODE(CallExpr) +NODE(NewExpr) +NODE(DeleteExpr) +NODE(PrefixExpr) +NODE(FunctionParam) +NODE(ConversionExpr) +NODE(PointerToMemberConversionExpr) +NODE(InitListExpr) +NODE(FoldExpr) +NODE(ThrowExpr) +NODE(BoolExpr) +NODE(StringLiteral) +NODE(LambdaExpr) +NODE(EnumLiteral) +NODE(IntegerLiteral) +NODE(FloatLiteral) +NODE(DoubleLiteral) +NODE(LongDoubleLiteral) +NODE(BracedExpr) +NODE(BracedRangeExpr) + +#undef NODE diff --git a/externals/demangle/llvm/Demangle/StringView.h b/externals/demangle/llvm/Demangle/StringView.h index 44d2b18a3..76b215252 100644 --- a/externals/demangle/llvm/Demangle/StringView.h +++ b/externals/demangle/llvm/Demangle/StringView.h @@ -1,5 +1,5 @@ -//===--- StringView.h -------------------------------------------*- C++ -*-===// -// +//===--- StringView.h ----------------*- mode:c++;eval:(read-only-mode) -*-===// +// Do not edit! See README.txt. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-FileCopyrightText: Part of the LLVM Project @@ -8,6 +8,9 @@ //===----------------------------------------------------------------------===// // // FIXME: Use std::string_view instead when we support C++17. +// There are two copies of this file in the source tree. The one under +// libcxxabi is the original and the one under llvm is the copy. Use +// cp-to-llvm.sh to update the copy. See README.txt for more details. // //===----------------------------------------------------------------------===// @@ -15,7 +18,6 @@ #define DEMANGLE_STRINGVIEW_H #include "DemangleConfig.h" -#include <algorithm> #include <cassert> #include <cstring> @@ -37,29 +39,23 @@ public: StringView(const char *Str) : First(Str), Last(Str + std::strlen(Str)) {} StringView() : First(nullptr), Last(nullptr) {} - StringView substr(size_t From) const { - return StringView(begin() + From, size() - From); + StringView substr(size_t Pos, size_t Len = npos) const { + assert(Pos <= size()); + if (Len > size() - Pos) + Len = size() - Pos; + return StringView(begin() + Pos, Len); } size_t find(char C, size_t From = 0) const { - size_t FindBegin = std::min(From, size()); // Avoid calling memchr with nullptr. - if (FindBegin < size()) { + if (From < size()) { // Just forward to memchr, which is faster than a hand-rolled loop. - if (const void *P = ::memchr(First + FindBegin, C, size() - FindBegin)) + if (const void *P = ::memchr(First + From, C, size() - From)) return size_t(static_cast<const char *>(P) - First); } return npos; } - StringView substr(size_t From, size_t To) const { - if (To >= size()) - To = size() - 1; - if (From >= size()) - From = size() - 1; - return StringView(First + From, First + To); - } - StringView dropFront(size_t N = 1) const { if (N >= size()) N = size(); @@ -106,7 +102,7 @@ public: bool startsWith(StringView Str) const { if (Str.size() > size()) return false; - return std::equal(Str.begin(), Str.end(), begin()); + return std::strncmp(Str.begin(), begin(), Str.size()) == 0; } const char &operator[](size_t Idx) const { return *(begin() + Idx); } @@ -119,7 +115,7 @@ public: inline bool operator==(const StringView &LHS, const StringView &RHS) { return LHS.size() == RHS.size() && - std::equal(LHS.begin(), LHS.end(), RHS.begin()); + std::strncmp(LHS.begin(), RHS.begin(), LHS.size()) == 0; } DEMANGLE_NAMESPACE_END diff --git a/externals/demangle/llvm/Demangle/StringViewExtras.h b/externals/demangle/llvm/Demangle/StringViewExtras.h new file mode 100644 index 000000000..83685c304 --- /dev/null +++ b/externals/demangle/llvm/Demangle/StringViewExtras.h @@ -0,0 +1,39 @@ +//===--- StringViewExtras.h ----------*- mode:c++;eval:(read-only-mode) -*-===// +// Do not edit! See README.txt. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-FileCopyrightText: Part of the LLVM Project +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// There are two copies of this file in the source tree. The one under +// libcxxabi is the original and the one under llvm is the copy. Use +// cp-to-llvm.sh to update the copy. See README.txt for more details. +// +//===----------------------------------------------------------------------===// + +#ifndef DEMANGLE_STRINGVIEW_H +#define DEMANGLE_STRINGVIEW_H + +#include "DemangleConfig.h" + +#include <string_view> + +DEMANGLE_NAMESPACE_BEGIN + +inline bool starts_with(std::string_view self, char C) noexcept { + return !self.empty() && *self.begin() == C; +} + +inline bool starts_with(std::string_view haystack, + std::string_view needle) noexcept { + if (needle.size() > haystack.size()) + return false; + haystack.remove_suffix(haystack.size() - needle.size()); + return haystack == needle; +} + +DEMANGLE_NAMESPACE_END + +#endif diff --git a/externals/demangle/llvm/Demangle/Utility.h b/externals/demangle/llvm/Demangle/Utility.h index 50d05c6b1..30dfbfc8d 100644 --- a/externals/demangle/llvm/Demangle/Utility.h +++ b/externals/demangle/llvm/Demangle/Utility.h @@ -1,5 +1,5 @@ -//===--- Utility.h ----------------------------------------------*- C++ -*-===// -// +//===--- Utility.h -------------------*- mode:c++;eval:(read-only-mode) -*-===// +// Do not edit! See README.txt. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-FileCopyrightText: Part of the LLVM Project @@ -7,70 +7,83 @@ // //===----------------------------------------------------------------------===// // -// Provide some utility classes for use in the demangler(s). +// Provide some utility classes for use in the demangler. +// There are two copies of this file in the source tree. The one in libcxxabi +// is the original and the one in llvm is the copy. Use cp-to-llvm.sh to update +// the copy. See README.txt for more details. // //===----------------------------------------------------------------------===// #ifndef DEMANGLE_UTILITY_H #define DEMANGLE_UTILITY_H -#include "StringView.h" +#include "DemangleConfig.h" + +#include <array> +#include <cassert> #include <cstdint> #include <cstdlib> #include <cstring> -#include <iterator> +#include <exception> #include <limits> +#include <string_view> DEMANGLE_NAMESPACE_BEGIN // Stream that AST nodes write their string representation into after the AST // has been parsed. -class OutputStream { - char *Buffer; - size_t CurrentPosition; - size_t BufferCapacity; +class OutputBuffer { + char *Buffer = nullptr; + size_t CurrentPosition = 0; + size_t BufferCapacity = 0; - // Ensure there is at least n more positions in buffer. + // Ensure there are at least N more positions in the buffer. void grow(size_t N) { - if (N + CurrentPosition >= BufferCapacity) { + size_t Need = N + CurrentPosition; + if (Need > BufferCapacity) { + // Reduce the number of reallocations, with a bit of hysteresis. The + // number here is chosen so the first allocation will more-than-likely not + // allocate more than 1K. + Need += 1024 - 32; BufferCapacity *= 2; - if (BufferCapacity < N + CurrentPosition) - BufferCapacity = N + CurrentPosition; + if (BufferCapacity < Need) + BufferCapacity = Need; Buffer = static_cast<char *>(std::realloc(Buffer, BufferCapacity)); if (Buffer == nullptr) std::terminate(); } } - void writeUnsigned(uint64_t N, bool isNeg = false) { - // Handle special case... - if (N == 0) { - *this << '0'; - return; - } - - char Temp[21]; - char *TempPtr = std::end(Temp); + OutputBuffer &writeUnsigned(uint64_t N, bool isNeg = false) { + std::array<char, 21> Temp; + char *TempPtr = Temp.data() + Temp.size(); - while (N) { - *--TempPtr = '0' + char(N % 10); + // Output at least one character. + do { + *--TempPtr = char('0' + N % 10); N /= 10; - } + } while (N); - // Add negative sign... + // Add negative sign. if (isNeg) *--TempPtr = '-'; - this->operator<<(StringView(TempPtr, std::end(Temp))); + + return operator+=( + std::string_view(TempPtr, Temp.data() + Temp.size() - TempPtr)); } public: - OutputStream(char *StartBuf, size_t Size) - : Buffer(StartBuf), CurrentPosition(0), BufferCapacity(Size) {} - OutputStream() = default; - void reset(char *Buffer_, size_t BufferCapacity_) { - CurrentPosition = 0; - Buffer = Buffer_; - BufferCapacity = BufferCapacity_; + OutputBuffer(char *StartBuf, size_t Size) + : Buffer(StartBuf), BufferCapacity(Size) {} + OutputBuffer(char *StartBuf, size_t *SizePtr) + : OutputBuffer(StartBuf, StartBuf ? *SizePtr : 0) {} + OutputBuffer() = default; + // Non-copyable + OutputBuffer(const OutputBuffer &) = delete; + OutputBuffer &operator=(const OutputBuffer &) = delete; + + operator std::string_view() const { + return std::string_view(Buffer, CurrentPosition); } /// If a ParameterPackExpansion (or similar type) is encountered, the offset @@ -78,115 +91,116 @@ public: unsigned CurrentPackIndex = std::numeric_limits<unsigned>::max(); unsigned CurrentPackMax = std::numeric_limits<unsigned>::max(); - OutputStream &operator+=(StringView R) { - size_t Size = R.size(); - if (Size == 0) - return *this; - grow(Size); - std::memmove(Buffer + CurrentPosition, R.begin(), Size); - CurrentPosition += Size; + /// When zero, we're printing template args and '>' needs to be parenthesized. + /// Use a counter so we can simply increment inside parentheses. + unsigned GtIsGt = 1; + + bool isGtInsideTemplateArgs() const { return GtIsGt == 0; } + + void printOpen(char Open = '(') { + GtIsGt++; + *this += Open; + } + void printClose(char Close = ')') { + GtIsGt--; + *this += Close; + } + + OutputBuffer &operator+=(std::string_view R) { + if (size_t Size = R.size()) { + grow(Size); + std::memcpy(Buffer + CurrentPosition, &*R.begin(), Size); + CurrentPosition += Size; + } return *this; } - OutputStream &operator+=(char C) { + OutputBuffer &operator+=(char C) { grow(1); Buffer[CurrentPosition++] = C; return *this; } - OutputStream &operator<<(StringView R) { return (*this += R); } + OutputBuffer &prepend(std::string_view R) { + size_t Size = R.size(); - OutputStream &operator<<(char C) { return (*this += C); } + grow(Size); + std::memmove(Buffer + Size, Buffer, CurrentPosition); + std::memcpy(Buffer, &*R.begin(), Size); + CurrentPosition += Size; - OutputStream &operator<<(long long N) { - if (N < 0) - writeUnsigned(static_cast<unsigned long long>(-N), true); - else - writeUnsigned(static_cast<unsigned long long>(N)); return *this; } - OutputStream &operator<<(unsigned long long N) { - writeUnsigned(N, false); - return *this; + OutputBuffer &operator<<(std::string_view R) { return (*this += R); } + + OutputBuffer &operator<<(char C) { return (*this += C); } + + OutputBuffer &operator<<(long long N) { + return writeUnsigned(static_cast<unsigned long long>(std::abs(N)), N < 0); } - OutputStream &operator<<(long N) { + OutputBuffer &operator<<(unsigned long long N) { + return writeUnsigned(N, false); + } + + OutputBuffer &operator<<(long N) { return this->operator<<(static_cast<long long>(N)); } - OutputStream &operator<<(unsigned long N) { + OutputBuffer &operator<<(unsigned long N) { return this->operator<<(static_cast<unsigned long long>(N)); } - OutputStream &operator<<(int N) { + OutputBuffer &operator<<(int N) { return this->operator<<(static_cast<long long>(N)); } - OutputStream &operator<<(unsigned int N) { + OutputBuffer &operator<<(unsigned int N) { return this->operator<<(static_cast<unsigned long long>(N)); } + void insert(size_t Pos, const char *S, size_t N) { + assert(Pos <= CurrentPosition); + if (N == 0) + return; + grow(N); + std::memmove(Buffer + Pos + N, Buffer + Pos, CurrentPosition - Pos); + std::memcpy(Buffer + Pos, S, N); + CurrentPosition += N; + } + size_t getCurrentPosition() const { return CurrentPosition; } void setCurrentPosition(size_t NewPos) { CurrentPosition = NewPos; } char back() const { - return CurrentPosition ? Buffer[CurrentPosition - 1] : '\0'; + assert(CurrentPosition); + return Buffer[CurrentPosition - 1]; } bool empty() const { return CurrentPosition == 0; } char *getBuffer() { return Buffer; } char *getBufferEnd() { return Buffer + CurrentPosition - 1; } - size_t getBufferCapacity() { return BufferCapacity; } + size_t getBufferCapacity() const { return BufferCapacity; } }; -template <class T> class SwapAndRestore { - T &Restore; - T OriginalValue; - bool ShouldRestore = true; +template <class T> class ScopedOverride { + T &Loc; + T Original; public: - SwapAndRestore(T &Restore_) : SwapAndRestore(Restore_, Restore_) {} - - SwapAndRestore(T &Restore_, T NewVal) - : Restore(Restore_), OriginalValue(Restore) { - Restore = std::move(NewVal); - } - ~SwapAndRestore() { - if (ShouldRestore) - Restore = std::move(OriginalValue); - } + ScopedOverride(T &Loc_) : ScopedOverride(Loc_, Loc_) {} - void shouldRestore(bool ShouldRestore_) { ShouldRestore = ShouldRestore_; } - - void restoreNow(bool Force) { - if (!Force && !ShouldRestore) - return; - - Restore = std::move(OriginalValue); - ShouldRestore = false; + ScopedOverride(T &Loc_, T NewVal) : Loc(Loc_), Original(Loc_) { + Loc_ = std::move(NewVal); } + ~ScopedOverride() { Loc = std::move(Original); } - SwapAndRestore(const SwapAndRestore &) = delete; - SwapAndRestore &operator=(const SwapAndRestore &) = delete; + ScopedOverride(const ScopedOverride &) = delete; + ScopedOverride &operator=(const ScopedOverride &) = delete; }; -inline bool initializeOutputStream(char *Buf, size_t *N, OutputStream &S, - size_t InitSize) { - size_t BufferSize; - if (Buf == nullptr) { - Buf = static_cast<char *>(std::malloc(InitSize)); - if (Buf == nullptr) - return false; - BufferSize = InitSize; - } else - BufferSize = *N; - - S.reset(Buf, BufferSize); - return true; -} - DEMANGLE_NAMESPACE_END #endif diff --git a/externals/vma/VulkanMemoryAllocator b/externals/vma/VulkanMemoryAllocator deleted file mode 160000 -Subproject 0aa3989b8f382f185fdf646cc83a1d16fa31d6a diff --git a/src/android/app/src/main/res/values/strings.xml b/src/android/app/src/main/res/values/strings.xml index b963f0119..bfdebd35b 100644 --- a/src/android/app/src/main/res/values/strings.xml +++ b/src/android/app/src/main/res/values/strings.xml @@ -232,7 +232,7 @@ <!-- ROM loading errors --> <string name="loader_error_encrypted">Your ROM is encrypted</string> - <string name="loader_error_encrypted_roms_description"><![CDATA[Please follow the guides to redump your <a href="https://yuzu-emu.org/help/quickstart/#dumping-cartridge-games">game cartidges</a> or <a href="https://yuzu-emu.org/help/quickstart/#dumping-installed-titles-eshop">installed titles</a>.]]></string> + <string name="loader_error_encrypted_roms_description"><![CDATA[Please follow the guides to redump your <a href="https://yuzu-emu.org/help/quickstart/#dumping-physical-titles-game-cards">game cartidges</a> or <a href="https://yuzu-emu.org/help/quickstart/#dumping-digital-titles-eshop">installed titles</a>.]]></string> <string name="loader_error_encrypted_keys_description"><![CDATA[Please ensure your <a href="https://yuzu-emu.org/help/quickstart/#dumping-prodkeys-and-titlekeys">prod.keys</a> file is installed so that games can be decrypted.]]></string> <string name="loader_error_video_core">An error occurred initializing the video core</string> <string name="loader_error_video_core_description">This is usually caused by an incompatible GPU driver. Installing a custom GPU driver may resolve this problem.</string> diff --git a/src/common/demangle.cpp b/src/common/demangle.cpp index 3310faf86..6e117cb41 100644 --- a/src/common/demangle.cpp +++ b/src/common/demangle.cpp @@ -23,7 +23,7 @@ std::string DemangleSymbol(const std::string& mangled) { SCOPE_EXIT({ std::free(demangled); }); if (is_itanium(mangled)) { - demangled = llvm::itaniumDemangle(mangled.c_str(), nullptr, nullptr, nullptr); + demangled = llvm::itaniumDemangle(mangled.c_str()); } if (!demangled) { diff --git a/src/common/detached_tasks.cpp b/src/common/detached_tasks.cpp index da64848da..f2ed795cc 100644 --- a/src/common/detached_tasks.cpp +++ b/src/common/detached_tasks.cpp @@ -30,8 +30,8 @@ DetachedTasks::~DetachedTasks() { void DetachedTasks::AddTask(std::function<void()> task) { std::unique_lock lock{instance->mutex}; ++instance->count; - std::thread([task{std::move(task)}]() { - task(); + std::thread([task_{std::move(task)}]() { + task_(); std::unique_lock thread_lock{instance->mutex}; --instance->count; std::notify_all_at_thread_exit(instance->cv, std::move(thread_lock)); diff --git a/src/common/socket_types.h b/src/common/socket_types.h index 0a801a443..63824a5c4 100644 --- a/src/common/socket_types.h +++ b/src/common/socket_types.h @@ -3,17 +3,22 @@ #pragma once +#include <optional> +#include <string> + #include "common/common_types.h" namespace Network { /// Address families enum class Domain : u8 { - INET, ///< Address family for IPv4 + Unspecified, ///< Represents 0, used in getaddrinfo hints + INET, ///< Address family for IPv4 }; /// Socket types enum class Type { + Unspecified, ///< Represents 0, used in getaddrinfo hints STREAM, DGRAM, RAW, @@ -22,6 +27,7 @@ enum class Type { /// Protocol values for sockets enum class Protocol : u8 { + Unspecified, ///< Represents 0, usable in various places ICMP, TCP, UDP, @@ -48,4 +54,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2; constexpr u32 FLAG_MSG_DONTWAIT = 0x80; constexpr u32 FLAG_O_NONBLOCK = 0x800; +/// Cross-platform addrinfo structure +struct AddrInfo { + Domain family; + Type socket_type; + Protocol protocol; + SockAddrIn addr; + std::optional<std::string> canon_name; +}; + } // namespace Network diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 28cb6f86f..4b7395be8 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -723,6 +723,7 @@ add_library(core STATIC hle/service/spl/spl_types.h hle/service/ssl/ssl.cpp hle/service/ssl/ssl.h + hle/service/ssl/ssl_backend.h hle/service/time/clock_types.h hle/service/time/ephemeral_network_system_clock_context_writer.h hle/service/time/ephemeral_network_system_clock_core.h @@ -864,6 +865,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64) target_link_libraries(core PRIVATE dynarmic::dynarmic) endif() +if(ENABLE_OPENSSL) + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_openssl.cpp) + target_link_libraries(core PRIVATE OpenSSL::SSL) +elseif (APPLE) + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_securetransport.cpp) + target_link_libraries(core PRIVATE "-framework Security") +elseif (WIN32) + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_schannel.cpp) + target_link_libraries(core PRIVATE crypt32 secur32) +else() + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_none.cpp) +endif() + if (YUZU_USE_PRECOMPILED_HEADERS) target_precompile_headers(core PRIVATE precompiled_headers.h) endif() diff --git a/src/core/hle/kernel/k_thread.cpp b/src/core/hle/kernel/k_thread.cpp index adb6ec581..d88909889 100644 --- a/src/core/hle/kernel/k_thread.cpp +++ b/src/core/hle/kernel/k_thread.cpp @@ -302,12 +302,12 @@ Result KThread::InitializeServiceThread(Core::System& system, KThread* thread, std::function<void()>&& func, s32 prio, s32 virt_core, KProcess* owner) { system.Kernel().GlobalSchedulerContext().AddThread(thread); - std::function<void()> func2{[&system, func{std::move(func)}] { + std::function<void()> func2{[&system, func_{std::move(func)}] { // Similar to UserModeThreadStarter. system.Kernel().CurrentScheduler()->OnThreadStart(); // Run the guest function. - func(); + func_(); // Exit. Svc::ExitThread(system); diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index f33600ca5..ebe7582c6 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -1089,15 +1089,15 @@ static std::jthread RunHostThreadFunc(KernelCore& kernel, KProcess* process, KThread::Register(kernel, thread); return std::jthread( - [&kernel, thread, thread_name{std::move(thread_name)}, func{std::move(func)}] { + [&kernel, thread, thread_name_{std::move(thread_name)}, func_{std::move(func)}] { // Set the thread name. - Common::SetCurrentThreadName(thread_name.c_str()); + Common::SetCurrentThreadName(thread_name_.c_str()); // Set the thread as current. kernel.RegisterHostThread(thread); // Run the callback. - func(); + func_(); // Close the thread. // This will free the process if it is the last reference. diff --git a/src/core/hle/service/am/am.cpp b/src/core/hle/service/am/am.cpp index a2375508a..4f400d341 100644 --- a/src/core/hle/service/am/am.cpp +++ b/src/core/hle/service/am/am.cpp @@ -506,8 +506,8 @@ void ISelfController::SetHandlesRequestToDisplay(HLERequestContext& ctx) { void ISelfController::SetIdleTimeDetectionExtension(HLERequestContext& ctx) { IPC::RequestParser rp{ctx}; idle_time_detection_extension = rp.Pop<u32>(); - LOG_WARNING(Service_AM, "(STUBBED) called idle_time_detection_extension={}", - idle_time_detection_extension); + LOG_DEBUG(Service_AM, "(STUBBED) called idle_time_detection_extension={}", + idle_time_detection_extension); IPC::ResponseBuilder rb{ctx, 2}; rb.Push(ResultSuccess); diff --git a/src/core/hle/service/nifm/nifm.cpp b/src/core/hle/service/nifm/nifm.cpp index 91d42853e..21b06d10b 100644 --- a/src/core/hle/service/nifm/nifm.cpp +++ b/src/core/hle/service/nifm/nifm.cpp @@ -7,6 +7,7 @@ #include "core/hle/service/kernel_helpers.h" #include "core/hle/service/nifm/nifm.h" #include "core/hle/service/server_manager.h" +#include "network/network.h" namespace { diff --git a/src/core/hle/service/nifm/nifm.h b/src/core/hle/service/nifm/nifm.h index 9b20e6823..ae99c4695 100644 --- a/src/core/hle/service/nifm/nifm.h +++ b/src/core/hle/service/nifm/nifm.h @@ -4,14 +4,15 @@ #pragma once #include "core/hle/service/service.h" -#include "network/network.h" -#include "network/room.h" -#include "network/room_member.h" namespace Core { class System; } +namespace Network { +class RoomNetwork; +} + namespace Service::NIFM { void LoopProcess(Core::System& system); diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp index bce45d321..e63b0a357 100644 --- a/src/core/hle/service/sockets/bsd.cpp +++ b/src/core/hle/service/sockets/bsd.cpp @@ -20,6 +20,9 @@ #include "core/internal_network/sockets.h" #include "network/network.h" +using Common::Expected; +using Common::Unexpected; + namespace Service::Sockets { namespace { @@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) { const u32 level = rp.Pop<u32>(); const auto optname = static_cast<OptName>(rp.Pop<u32>()); - LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname); - std::vector<u8> optval(ctx.GetWriteBufferSize()); + LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname, + optval.size()); + + const Errno err = GetSockOptImpl(fd, level, optname, optval); + ctx.WriteBuffer(optval); IPC::ResponseBuilder rb{ctx, 5}; rb.Push(ResultSuccess); - rb.Push<s32>(-1); - rb.PushEnum(Errno::NOTCONN); + rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1); + rb.PushEnum(err); rb.Push<u32>(static_cast<u32>(optval.size())); } @@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) { BuildErrnoResponse(ctx, CloseImpl(fd)); } +void BSD::DuplicateSocket(HLERequestContext& ctx) { + struct InputParameters { + s32 fd; + u64 reserved; + }; + static_assert(sizeof(InputParameters) == 0x10); + + struct OutputParameters { + s32 ret; + Errno bsd_errno; + }; + static_assert(sizeof(OutputParameters) == 0x8); + + IPC::RequestParser rp{ctx}; + auto input = rp.PopRaw<InputParameters>(); + + Expected<s32, Errno> res = DuplicateSocketImpl(input.fd); + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(ResultSuccess); + rb.PushRaw(OutputParameters{ + .ret = res.value_or(0), + .bsd_errno = res ? Errno::SUCCESS : res.error(), + }); +} + void BSD::EventFd(HLERequestContext& ctx) { IPC::RequestParser rp{ctx}; const u64 initval = rp.Pop<u64>(); @@ -477,12 +508,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco auto room_member = room_network.GetRoomMember().lock(); if (room_member && room_member->IsConnected()) { - descriptor.socket = std::make_unique<Network::ProxySocket>(room_network); + descriptor.socket = std::make_shared<Network::ProxySocket>(room_network); } else { - descriptor.socket = std::make_unique<Network::Socket>(); + descriptor.socket = std::make_shared<Network::Socket>(); } - descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol)); + descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol)); descriptor.is_connection_based = IsConnectionBased(type); return {fd, Errno::SUCCESS}; @@ -538,7 +569,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) { Network::PollFD result; result.socket = file_descriptors[pollfd.fd]->socket.get(); - result.events = TranslatePollEventsToHost(pollfd.events); + result.events = Translate(pollfd.events); result.revents = Network::PollEvents{}; return result; }); @@ -547,7 +578,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con const size_t num = host_pollfds.size(); for (size_t i = 0; i < num; ++i) { - fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents); + fds[i].revents = Translate(host_pollfds[i].revents); } std::memcpy(write_buffer.data(), fds.data(), length); @@ -617,7 +648,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) { } const SockAddrIn guest_addrin = Translate(addr_in); - ASSERT(write_buffer.size() == sizeof(guest_addrin)); + ASSERT(write_buffer.size() >= sizeof(guest_addrin)); + write_buffer.resize(sizeof(guest_addrin)); std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); return Translate(bsd_errno); } @@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) { } const SockAddrIn guest_addrin = Translate(addr_in); - ASSERT(write_buffer.size() == sizeof(guest_addrin)); + ASSERT(write_buffer.size() >= sizeof(guest_addrin)); + write_buffer.resize(sizeof(guest_addrin)); std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); return Translate(bsd_errno); } @@ -671,13 +704,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) { } } -Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { - UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET +Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) { + if (!IsFileDescriptorValid(fd)) { + return Errno::BADF; + } + + if (level != static_cast<u32>(SocketLevel::SOCKET)) { + UNIMPLEMENTED_MSG("Unknown getsockopt level"); + return Errno::SUCCESS; + } + + Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); + + switch (optname) { + case OptName::ERROR_: { + auto [pending_err, getsockopt_err] = socket->GetPendingError(); + if (getsockopt_err == Network::Errno::SUCCESS) { + Errno translated_pending_err = Translate(pending_err); + ASSERT_OR_EXECUTE_MSG( + optval.size() == sizeof(Errno), { return Errno::INVAL; }, + "Incorrect getsockopt option size"); + optval.resize(sizeof(Errno)); + memcpy(optval.data(), &translated_pending_err, sizeof(Errno)); + } + return Translate(getsockopt_err); + } + default: + UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); + return Errno::SUCCESS; + } +} +Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { if (!IsFileDescriptorValid(fd)) { return Errno::BADF; } + if (level != static_cast<u32>(SocketLevel::SOCKET)) { + UNIMPLEMENTED_MSG("Unknown setsockopt level"); + return Errno::SUCCESS; + } + Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); if (optname == OptName::LINGER) { @@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con return Translate(socket->SetSndTimeo(value)); case OptName::RCVTIMEO: return Translate(socket->SetRcvTimeo(value)); + case OptName::NOSIGPIPE: + LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value); + return Errno::SUCCESS; default: UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); return Errno::SUCCESS; @@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) { return bsd_errno; } +Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) { + if (!IsFileDescriptorValid(fd)) { + return Unexpected(Errno::BADF); + } + + const s32 new_fd = FindFreeFileDescriptorHandle(); + if (new_fd < 0) { + LOG_ERROR(Service, "No more file descriptors available"); + return Unexpected(Errno::MFILE); + } + + file_descriptors[new_fd] = file_descriptors[fd]; + return new_fd; +} + +std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) { + if (!IsFileDescriptorValid(fd)) { + return std::nullopt; + } + return file_descriptors[fd]->socket; +} + s32 BSD::FindFreeFileDescriptorHandle() noexcept { for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) { if (!file_descriptors[fd]) { @@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name) {24, &BSD::Write, "Write"}, {25, &BSD::Read, "Read"}, {26, &BSD::Close, "Close"}, - {27, nullptr, "DuplicateSocket"}, + {27, &BSD::DuplicateSocket, "DuplicateSocket"}, {28, nullptr, "GetResourceStatistics"}, {29, nullptr, "RecvMMsg"}, {30, nullptr, "SendMMsg"}, diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h index 30ae9c140..430edb97c 100644 --- a/src/core/hle/service/sockets/bsd.h +++ b/src/core/hle/service/sockets/bsd.h @@ -8,6 +8,7 @@ #include <vector> #include "common/common_types.h" +#include "common/expected.h" #include "common/socket_types.h" #include "core/hle/service/service.h" #include "core/hle/service/sockets/sockets.h" @@ -29,12 +30,19 @@ public: explicit BSD(Core::System& system_, const char* name); ~BSD() override; + // These methods are called from SSL; the first two are also called from + // this class for the corresponding IPC methods. + // On the real device, the SSL service makes IPC calls to this service. + Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd); + Errno CloseImpl(s32 fd); + std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd); + private: /// Maximum number of file descriptors static constexpr size_t MAX_FD = 128; struct FileDescriptor { - std::unique_ptr<Network::SocketBase> socket; + std::shared_ptr<Network::SocketBase> socket; s32 flags = 0; bool is_connection_based = false; }; @@ -138,6 +146,7 @@ private: void Write(HLERequestContext& ctx); void Read(HLERequestContext& ctx); void Close(HLERequestContext& ctx); + void DuplicateSocket(HLERequestContext& ctx); void EventFd(HLERequestContext& ctx); template <typename Work> @@ -153,6 +162,7 @@ private: Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer); Errno ListenImpl(s32 fd, s32 backlog); std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); + Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval); Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); Errno ShutdownImpl(s32 fd, s32 how); std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message); @@ -161,7 +171,6 @@ private: std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message); std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message, std::span<const u8> addr); - Errno CloseImpl(s32 fd); s32 FindFreeFileDescriptorHandle() noexcept; bool IsFileDescriptorValid(s32 fd) const noexcept; diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp index 6491a73be..0dfb0f166 100644 --- a/src/core/hle/service/sockets/nsd.cpp +++ b/src/core/hle/service/sockets/nsd.cpp @@ -1,10 +1,15 @@ // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include "core/hle/service/ipc_helpers.h" #include "core/hle/service/sockets/nsd.h" +#include "common/string_util.h" + namespace Service::Sockets { +constexpr Result ResultOverflow{ErrorModule::NSD, 6}; + NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} { // clang-format off static const FunctionInfo functions[] = { @@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na {13, nullptr, "DeleteSettings"}, {14, nullptr, "ImportSettings"}, {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"}, - {20, nullptr, "Resolve"}, - {21, nullptr, "ResolveEx"}, + {20, &NSD::Resolve, "Resolve"}, + {21, &NSD::ResolveEx, "ResolveEx"}, {30, nullptr, "GetNasServiceSetting"}, {31, nullptr, "GetNasServiceSettingEx"}, {40, nullptr, "GetNasRequestFqdn"}, @@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na RegisterHandlers(functions); } +static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) { + // The real implementation makes various substitutions. + // For now we just return the string as-is, which is good enough when not + // connecting to real Nintendo servers. + LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in); + return fqdn_in; +} + +static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) { + const auto res = ResolveImpl(fqdn_in); + if (res.Failed()) { + return res.Code(); + } + if (res->size() >= fqdn_out.size()) { + return ResultOverflow; + } + std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1); + return ResultSuccess; +} + +void NSD::Resolve(HLERequestContext& ctx) { + const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); + + std::array<char, 0x100> fqdn_out{}; + const Result res = ResolveCommon(fqdn_in, fqdn_out); + + ctx.WriteBuffer(fqdn_out); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); +} + +void NSD::ResolveEx(HLERequestContext& ctx) { + const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); + + std::array<char, 0x100> fqdn_out; + const Result res = ResolveCommon(fqdn_in, fqdn_out); + + if (res.IsError()) { + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + return; + } + + ctx.WriteBuffer(fqdn_out); + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(ResultSuccess); + rb.Push(ResultSuccess); +} + NSD::~NSD() = default; } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/nsd.h b/src/core/hle/service/sockets/nsd.h index 5cc12b855..a7379a8a9 100644 --- a/src/core/hle/service/sockets/nsd.h +++ b/src/core/hle/service/sockets/nsd.h @@ -15,6 +15,10 @@ class NSD final : public ServiceFramework<NSD> { public: explicit NSD(Core::System& system_, const char* name); ~NSD() override; + +private: + void Resolve(HLERequestContext& ctx); + void ResolveEx(HLERequestContext& ctx); }; } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/sfdnsres.cpp b/src/core/hle/service/sockets/sfdnsres.cpp index 132dd5797..84cc79de8 100644 --- a/src/core/hle/service/sockets/sfdnsres.cpp +++ b/src/core/hle/service/sockets/sfdnsres.cpp @@ -10,27 +10,18 @@ #include "core/core.h" #include "core/hle/service/ipc_helpers.h" #include "core/hle/service/sockets/sfdnsres.h" +#include "core/hle/service/sockets/sockets.h" +#include "core/hle/service/sockets/sockets_translate.h" +#include "core/internal_network/network.h" #include "core/memory.h" -#ifdef _WIN32 -#include <ws2tcpip.h> -#elif YUZU_UNIX -#include <arpa/inet.h> -#include <netdb.h> -#include <netinet/in.h> -#include <sys/socket.h> -#ifndef EAI_NODATA -#define EAI_NODATA EAI_NONAME -#endif -#endif - namespace Service::Sockets { SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} { static const FunctionInfo functions[] = { {0, nullptr, "SetDnsAddressesPrivateRequest"}, {1, nullptr, "GetDnsAddressPrivateRequest"}, - {2, nullptr, "GetHostByNameRequest"}, + {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"}, {3, nullptr, "GetHostByAddrRequest"}, {4, nullptr, "GetHostStringErrorRequest"}, {5, nullptr, "GetGaiStringErrorRequest"}, @@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres" {7, nullptr, "GetNameInfoRequest"}, {8, nullptr, "RequestCancelHandleRequest"}, {9, nullptr, "CancelRequest"}, - {10, nullptr, "GetHostByNameRequestWithOptions"}, + {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"}, {11, nullptr, "GetHostByAddrRequestWithOptions"}, {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"}, {13, nullptr, "GetNameInfoRequestWithOptions"}, - {14, nullptr, "ResolverSetOptionRequest"}, + {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"}, {15, nullptr, "ResolverGetOptionRequest"}, }; RegisterHandlers(functions); @@ -59,188 +50,285 @@ enum class NetDbError : s32 { NoData = 4, }; -static NetDbError AddrInfoErrorToNetDbError(s32 result) { - // Best effort guess to map errors +static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) { + // These combinations have been verified on console (but are not + // exhaustive). switch (result) { - case 0: + case GetAddrInfoError::SUCCESS: return NetDbError::Success; - case EAI_AGAIN: + case GetAddrInfoError::AGAIN: return NetDbError::TryAgain; - case EAI_NODATA: - return NetDbError::NoData; + case GetAddrInfoError::NODATA: + return NetDbError::HostNotFound; + case GetAddrInfoError::SERVICE: + return NetDbError::Success; default: return NetDbError::HostNotFound; } } -static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code, +static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) { + // These combinations have been verified on console (but are not + // exhaustive). + switch (result) { + case GetAddrInfoError::SUCCESS: + // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for + // some reason, but that doesn't seem useful to implement. + return Errno::SUCCESS; + case GetAddrInfoError::AGAIN: + return Errno::SUCCESS; + case GetAddrInfoError::NODATA: + return Errno::SUCCESS; + case GetAddrInfoError::SERVICE: + return Errno::INVAL; + default: + return Errno::SUCCESS; + } +} + +template <typename T> +static void Append(std::vector<u8>& vec, T t) { + const size_t offset = vec.size(); + vec.resize(offset + sizeof(T)); + std::memcpy(vec.data() + offset, &t, sizeof(T)); +} + +static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) { + const size_t offset = vec.size(); + vec.resize(offset + str.size() + 1); + std::memmove(vec.data() + offset, str.data(), str.size()); +} + +// We implement gethostbyname using the host's getaddrinfo rather than the +// host's gethostbyname, because it simplifies portability: e.g., getaddrinfo +// behaves the same on Unix and Windows, unlike gethostbyname where Windows +// doesn't implement h_errno. +static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec, + std::string_view host) { + + std::vector<u8> data; + // h_name: use the input hostname (append nul-terminated) + AppendNulTerminated(data, host); + // h_aliases: leave empty + + Append<u32_be>(data, 0); // count of h_aliases + // (If the count were nonzero, the aliases would be appended as nul-terminated here.) + Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype + Append<u16_be>(data, sizeof(Network::IPv4Address)); // h_length + // h_addr_list: + size_t count = vec.size(); + ASSERT(count <= UINT32_MAX); + Append<u32_be>(data, static_cast<uint32_t>(count)); + for (const Network::AddrInfo& addrinfo : vec) { + // On the Switch, this is passed through htonl despite already being + // big-endian, so it ends up as little-endian. + Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); + + LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, + Network::IPv4AddressToString(addrinfo.addr.ip)); + } + return data; +} + +static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) { + struct InputParameters { + u8 use_nsd_resolve; + u32 cancel_handle; + u64 process_id; + }; + static_assert(sizeof(InputParameters) == 0x10); + + IPC::RequestParser rp{ctx}; + const auto parameters = rp.PopRaw<InputParameters>(); + + LOG_WARNING( + Service, + "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", + parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); + + const auto host_buffer = ctx.ReadBuffer(0); + const std::string host = Common::StringFromBuffer(host_buffer); + // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions. + + auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt); + if (!res.has_value()) { + return {0, Translate(res.error())}; + } + + const std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host); + const u32 data_size = static_cast<u32>(data.size()); + ctx.WriteBuffer(data, 0); + + return {data_size, GetAddrInfoError::SUCCESS}; +} + +void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) { + auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); + + struct OutputParameters { + NetDbError netdb_error; + Errno bsd_errno; + u32 data_size; + }; + static_assert(sizeof(OutputParameters) == 0xc); + + IPC::ResponseBuilder rb{ctx, 5}; + rb.Push(ResultSuccess); + rb.PushRaw(OutputParameters{ + .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + .data_size = data_size, + }); +} + +void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) { + auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); + + struct OutputParameters { + u32 data_size; + NetDbError netdb_error; + Errno bsd_errno; + }; + static_assert(sizeof(OutputParameters) == 0xc); + + IPC::ResponseBuilder rb{ctx, 5}; + rb.Push(ResultSuccess); + rb.PushRaw(OutputParameters{ + .data_size = data_size, + .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + }); +} + +static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec, std::string_view host) { // Adapted from // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190 std::vector<u8> data; - auto* current = addrinfo; - while (current != nullptr) { - struct SerializedResponseHeader { - u32 magic; - s32 flags; - s32 family; - s32 socket_type; - s32 protocol; - u32 address_length; - }; - static_assert(sizeof(SerializedResponseHeader) == 0x18, - "Response header size must be 0x18 bytes"); - - constexpr auto header_size = sizeof(SerializedResponseHeader); - const auto addr_size = - current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4; - const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1; - - const auto last_size = data.size(); - data.resize(last_size + header_size + addr_size + canonname_size); - - // Header in network byte order - SerializedResponseHeader header{}; - - constexpr auto HEADER_MAGIC = 0xBEEFCAFE; - header.magic = htonl(HEADER_MAGIC); - header.family = htonl(current->ai_family); - header.flags = htonl(current->ai_flags); - header.socket_type = htonl(current->ai_socktype); - header.protocol = htonl(current->ai_protocol); - header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0; - - auto* header_ptr = data.data() + last_size; - std::memcpy(header_ptr, &header, header_size); - - if (header.address_length == 0) { - std::memset(header_ptr + header_size, 0, 4); - } else { - switch (current->ai_family) { - case AF_INET: { - struct SockAddrIn { - s16 sin_family; - u16 sin_port; - u32 sin_addr; - u8 sin_zero[8]; - }; - - SockAddrIn serialized_addr{}; - const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr); - serialized_addr.sin_port = htons(addr.sin_port); - serialized_addr.sin_family = htons(addr.sin_family); - serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr); - std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn)); - - char addr_string_buf[64]{}; - inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf)); - LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf); - break; - } - case AF_INET6: { - struct SockAddrIn6 { - s16 sin6_family; - u16 sin6_port; - u32 sin6_flowinfo; - u8 sin6_addr[16]; - u32 sin6_scope_id; - }; - - SockAddrIn6 serialized_addr{}; - const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr); - serialized_addr.sin6_family = htons(addr.sin6_family); - serialized_addr.sin6_port = htons(addr.sin6_port); - serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo); - serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id); - std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr, - sizeof(SockAddrIn6::sin6_addr)); - std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6)); - - char addr_string_buf[64]{}; - inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf)); - LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf); - break; - } - default: - std::memcpy(header_ptr + header_size, current->ai_addr, addr_size); - break; - } - } - if (current->ai_canonname) { - std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size); + for (const Network::AddrInfo& addrinfo : vec) { + // serialized addrinfo: + Append<u32_be>(data, 0xBEEFCAFE); // magic + Append<u32_be>(data, 0); // ai_flags + Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family))); // ai_family + Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype + Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol))); // ai_protocol + Append<u32_be>(data, sizeof(SockAddrIn)); // ai_addrlen + // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size + + // ai_addr: + Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family + // On the Switch, the following fields are passed through htonl despite + // already being big-endian, so they end up as little-endian. + Append<u16_le>(data, addrinfo.addr.portno); // sin_port + Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr + data.resize(data.size() + 8, 0); // sin_zero + + if (addrinfo.canon_name.has_value()) { + AppendNulTerminated(data, *addrinfo.canon_name); } else { - *(header_ptr + header_size + addr_size) = 0; + data.push_back(0); } - current = current->ai_next; + LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, + Network::IPv4AddressToString(addrinfo.addr.ip)); } - // 4-byte sentinel value - data.push_back(0); - data.push_back(0); - data.push_back(0); - data.push_back(0); + data.resize(data.size() + 4, 0); // 4-byte sentinel value return data; } -static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) { - struct Parameters { +static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) { + struct InputParameters { u8 use_nsd_resolve; - u32 unknown; + u32 cancel_handle; u64 process_id; }; + static_assert(sizeof(InputParameters) == 0x10); IPC::RequestParser rp{ctx}; - const auto parameters = rp.PopRaw<Parameters>(); + const auto parameters = rp.PopRaw<InputParameters>(); + + LOG_WARNING( + Service, + "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", + parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); - LOG_WARNING(Service, - "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}", - parameters.use_nsd_resolve, parameters.unknown, parameters.process_id); + // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve + // before looking up. const auto host_buffer = ctx.ReadBuffer(0); const std::string host = Common::StringFromBuffer(host_buffer); - const auto service_buffer = ctx.ReadBuffer(1); - const std::string service = Common::StringFromBuffer(service_buffer); - - addrinfo* addrinfo; - // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now - s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo); + std::optional<std::string> service = std::nullopt; + if (ctx.CanReadBuffer(1)) { + const std::span<const u8> service_buffer = ctx.ReadBuffer(1); + service = Common::StringFromBuffer(service_buffer); + } - u32 data_size = 0; - if (result_code == 0 && addrinfo != nullptr) { - const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host); - data_size = static_cast<u32>(data.size()); - freeaddrinfo(addrinfo); + // Serialized hints are also passed in a buffer, but are ignored for now. - ctx.WriteBuffer(data, 0); + auto res = Network::GetAddressInfo(host, service); + if (!res.has_value()) { + return {0, Translate(res.error())}; } - return std::make_pair(data_size, result_code); + const std::vector<u8> data = SerializeAddrInfo(res.value(), host); + const u32 data_size = static_cast<u32>(data.size()); + ctx.WriteBuffer(data, 0); + + return {data_size, GetAddrInfoError::SUCCESS}; } void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { - auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); + auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); + + struct OutputParameters { + Errno bsd_errno; + GetAddrInfoError gai_error; + u32 data_size; + }; + static_assert(sizeof(OutputParameters) == 0xc); - IPC::ResponseBuilder rb{ctx, 4}; + IPC::ResponseBuilder rb{ctx, 5}; rb.Push(ResultSuccess); - rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode - rb.Push(result_code); // errno - rb.Push(data_size); // serialized size + rb.PushRaw(OutputParameters{ + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + .gai_error = emu_gai_err, + .data_size = data_size, + }); } void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { // Additional options are ignored - auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); + auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); + + struct OutputParameters { + u32 data_size; + GetAddrInfoError gai_error; + NetDbError netdb_error; + Errno bsd_errno; + }; + static_assert(sizeof(OutputParameters) == 0x10); + + IPC::ResponseBuilder rb{ctx, 6}; + rb.Push(ResultSuccess); + rb.PushRaw(OutputParameters{ + .data_size = data_size, + .gai_error = emu_gai_err, + .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + }); +} + +void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) { + LOG_WARNING(Service, "(STUBBED) called"); + + IPC::ResponseBuilder rb{ctx, 3}; - IPC::ResponseBuilder rb{ctx, 5}; rb.Push(ResultSuccess); - rb.Push(data_size); // serialized size - rb.Push(result_code); // errno - rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode - rb.Push(0); + rb.Push<s32>(0); // bsd errno } } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/sfdnsres.h b/src/core/hle/service/sockets/sfdnsres.h index 18e3cd60c..d99a9d560 100644 --- a/src/core/hle/service/sockets/sfdnsres.h +++ b/src/core/hle/service/sockets/sfdnsres.h @@ -17,8 +17,11 @@ public: ~SFDNSRES() override; private: + void GetHostByNameRequest(HLERequestContext& ctx); + void GetHostByNameRequestWithOptions(HLERequestContext& ctx); void GetAddrInfoRequest(HLERequestContext& ctx); void GetAddrInfoRequestWithOptions(HLERequestContext& ctx); + void ResolverSetOptionRequest(HLERequestContext& ctx); }; } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/sockets.h b/src/core/hle/service/sockets/sockets.h index acd2dae7b..77426c46e 100644 --- a/src/core/hle/service/sockets/sockets.h +++ b/src/core/hle/service/sockets/sockets.h @@ -22,13 +22,35 @@ enum class Errno : u32 { CONNRESET = 104, NOTCONN = 107, TIMEDOUT = 110, + INPROGRESS = 115, +}; + +enum class GetAddrInfoError : s32 { + SUCCESS = 0, + ADDRFAMILY = 1, + AGAIN = 2, + BADFLAGS = 3, + FAIL = 4, + FAMILY = 5, + MEMORY = 6, + NODATA = 7, + NONAME = 8, + SERVICE = 9, + SOCKTYPE = 10, + SYSTEM = 11, + BADHINTS = 12, + PROTOCOL = 13, + OVERFLOW_ = 14, // avoid name collision with Windows macro + OTHER = 15, }; enum class Domain : u32 { + Unspecified = 0, INET = 2, }; enum class Type : u32 { + Unspecified = 0, STREAM = 1, DGRAM = 2, RAW = 3, @@ -36,12 +58,16 @@ enum class Type : u32 { }; enum class Protocol : u32 { - UNSPECIFIED = 0, + Unspecified = 0, ICMP = 1, TCP = 6, UDP = 17, }; +enum class SocketLevel : u32 { + SOCKET = 0xffff, // i.e. SOL_SOCKET +}; + enum class OptName : u32 { REUSEADDR = 0x4, KEEPALIVE = 0x8, @@ -51,6 +77,8 @@ enum class OptName : u32 { RCVBUF = 0x1002, SNDTIMEO = 0x1005, RCVTIMEO = 0x1006, + ERROR_ = 0x1007, // avoid name collision with Windows macro + NOSIGPIPE = 0x800, // at least according to libnx }; enum class ShutdownHow : s32 { @@ -80,6 +108,9 @@ enum class PollEvents : u16 { Err = 1 << 3, Hup = 1 << 4, Nval = 1 << 5, + RdNorm = 1 << 6, + RdBand = 1 << 7, + WrBand = 1 << 8, }; DECLARE_ENUM_FLAG_OPERATORS(PollEvents); diff --git a/src/core/hle/service/sockets/sockets_translate.cpp b/src/core/hle/service/sockets/sockets_translate.cpp index 594e58f90..2f9a0e39c 100644 --- a/src/core/hle/service/sockets/sockets_translate.cpp +++ b/src/core/hle/service/sockets/sockets_translate.cpp @@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) { return Errno::TIMEDOUT; case Network::Errno::CONNRESET: return Errno::CONNRESET; + case Network::Errno::INPROGRESS: + return Errno::INPROGRESS; default: UNIMPLEMENTED_MSG("Unimplemented errno={}", value); return Errno::SUCCESS; @@ -39,8 +41,50 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) { return {value.first, Translate(value.second)}; } +GetAddrInfoError Translate(Network::GetAddrInfoError error) { + switch (error) { + case Network::GetAddrInfoError::SUCCESS: + return GetAddrInfoError::SUCCESS; + case Network::GetAddrInfoError::ADDRFAMILY: + return GetAddrInfoError::ADDRFAMILY; + case Network::GetAddrInfoError::AGAIN: + return GetAddrInfoError::AGAIN; + case Network::GetAddrInfoError::BADFLAGS: + return GetAddrInfoError::BADFLAGS; + case Network::GetAddrInfoError::FAIL: + return GetAddrInfoError::FAIL; + case Network::GetAddrInfoError::FAMILY: + return GetAddrInfoError::FAMILY; + case Network::GetAddrInfoError::MEMORY: + return GetAddrInfoError::MEMORY; + case Network::GetAddrInfoError::NODATA: + return GetAddrInfoError::NODATA; + case Network::GetAddrInfoError::NONAME: + return GetAddrInfoError::NONAME; + case Network::GetAddrInfoError::SERVICE: + return GetAddrInfoError::SERVICE; + case Network::GetAddrInfoError::SOCKTYPE: + return GetAddrInfoError::SOCKTYPE; + case Network::GetAddrInfoError::SYSTEM: + return GetAddrInfoError::SYSTEM; + case Network::GetAddrInfoError::BADHINTS: + return GetAddrInfoError::BADHINTS; + case Network::GetAddrInfoError::PROTOCOL: + return GetAddrInfoError::PROTOCOL; + case Network::GetAddrInfoError::OVERFLOW_: + return GetAddrInfoError::OVERFLOW_; + case Network::GetAddrInfoError::OTHER: + return GetAddrInfoError::OTHER; + default: + UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error); + return GetAddrInfoError::OTHER; + } +} + Network::Domain Translate(Domain domain) { switch (domain) { + case Domain::Unspecified: + return Network::Domain::Unspecified; case Domain::INET: return Network::Domain::INET; default: @@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) { Domain Translate(Network::Domain domain) { switch (domain) { + case Network::Domain::Unspecified: + return Domain::Unspecified; case Network::Domain::INET: return Domain::INET; default: @@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) { Network::Type Translate(Type type) { switch (type) { + case Type::Unspecified: + return Network::Type::Unspecified; case Type::STREAM: return Network::Type::STREAM; case Type::DGRAM: return Network::Type::DGRAM; + case Type::RAW: + return Network::Type::RAW; + case Type::SEQPACKET: + return Network::Type::SEQPACKET; default: UNIMPLEMENTED_MSG("Unimplemented type={}", type); return Network::Type{}; } } -Network::Protocol Translate(Type type, Protocol protocol) { +Type Translate(Network::Type type) { + switch (type) { + case Network::Type::Unspecified: + return Type::Unspecified; + case Network::Type::STREAM: + return Type::STREAM; + case Network::Type::DGRAM: + return Type::DGRAM; + case Network::Type::RAW: + return Type::RAW; + case Network::Type::SEQPACKET: + return Type::SEQPACKET; + default: + UNIMPLEMENTED_MSG("Unimplemented type={}", type); + return Type{}; + } +} + +Network::Protocol Translate(Protocol protocol) { switch (protocol) { - case Protocol::UNSPECIFIED: - LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type"); - switch (type) { - case Type::DGRAM: - return Network::Protocol::UDP; - case Type::STREAM: - return Network::Protocol::TCP; - default: - return Network::Protocol::TCP; - } + case Protocol::Unspecified: + return Network::Protocol::Unspecified; case Protocol::TCP: return Network::Protocol::TCP; case Protocol::UDP: return Network::Protocol::UDP; default: UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); - return Network::Protocol::TCP; + return Network::Protocol::Unspecified; + } +} + +Protocol Translate(Network::Protocol protocol) { + switch (protocol) { + case Network::Protocol::Unspecified: + return Protocol::Unspecified; + case Network::Protocol::TCP: + return Protocol::TCP; + case Network::Protocol::UDP: + return Protocol::UDP; + default: + UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); + return Protocol::Unspecified; } } -Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { +Network::PollEvents Translate(PollEvents flags) { Network::PollEvents result{}; const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) { if (True(flags & from)) { @@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { translate(PollEvents::Err, Network::PollEvents::Err); translate(PollEvents::Hup, Network::PollEvents::Hup); translate(PollEvents::Nval, Network::PollEvents::Nval); + translate(PollEvents::RdNorm, Network::PollEvents::RdNorm); + translate(PollEvents::RdBand, Network::PollEvents::RdBand); + translate(PollEvents::WrBand, Network::PollEvents::WrBand); UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); return result; } -PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { +PollEvents Translate(Network::PollEvents flags) { PollEvents result{}; const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) { if (True(flags & from)) { @@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { translate(Network::PollEvents::Err, PollEvents::Err); translate(Network::PollEvents::Hup, PollEvents::Hup); translate(Network::PollEvents::Nval, PollEvents::Nval); + translate(Network::PollEvents::RdNorm, PollEvents::RdNorm); + translate(Network::PollEvents::RdBand, PollEvents::RdBand); + translate(Network::PollEvents::WrBand, PollEvents::WrBand); UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); return result; } Network::SockAddrIn Translate(SockAddrIn value) { - ASSERT(value.len == 0 || value.len == sizeof(value)); + // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets + // sin_len to 6 when deserializing getaddrinfo results). + ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6); return { .family = Translate(static_cast<Domain>(value.family)), diff --git a/src/core/hle/service/sockets/sockets_translate.h b/src/core/hle/service/sockets/sockets_translate.h index c93291d3e..694868b37 100644 --- a/src/core/hle/service/sockets/sockets_translate.h +++ b/src/core/hle/service/sockets/sockets_translate.h @@ -17,6 +17,9 @@ Errno Translate(Network::Errno value); /// Translate abstract return value errno pair to guest return value errno pair std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value); +/// Translate abstract getaddrinfo error to guest getaddrinfo error +GetAddrInfoError Translate(Network::GetAddrInfoError value); + /// Translate guest domain to abstract domain Network::Domain Translate(Domain domain); @@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain); /// Translate guest type to abstract type Network::Type Translate(Type type); +/// Translate abstract type to guest type +Type Translate(Network::Type type); + /// Translate guest protocol to abstract protocol -Network::Protocol Translate(Type type, Protocol protocol); +Network::Protocol Translate(Protocol protocol); -/// Translate abstract poll event flags to guest poll event flags -Network::PollEvents TranslatePollEventsToHost(PollEvents flags); +/// Translate abstract protocol to guest protocol +Protocol Translate(Network::Protocol protocol); /// Translate guest poll event flags to abstract poll event flags -PollEvents TranslatePollEventsToGuest(Network::PollEvents flags); +Network::PollEvents Translate(PollEvents flags); + +/// Translate abstract poll event flags to guest poll event flags +PollEvents Translate(Network::PollEvents flags); /// Translate guest socket address structure to abstract socket address structure Network::SockAddrIn Translate(SockAddrIn value); diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 2b99dd7ac..9c96f9763 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp @@ -1,10 +1,18 @@ // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include "common/string_util.h" + +#include "core/core.h" #include "core/hle/service/ipc_helpers.h" #include "core/hle/service/server_manager.h" #include "core/hle/service/service.h" +#include "core/hle/service/sm/sm.h" +#include "core/hle/service/sockets/bsd.h" #include "core/hle/service/ssl/ssl.h" +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" namespace Service::SSL { @@ -20,6 +28,18 @@ enum class ContextOption : u32 { CrlImportDateCheckEnable = 1, }; +// This is nn::ssl::Connection::IoMode +enum class IoMode : u32 { + Blocking = 1, + NonBlocking = 2, +}; + +// This is nn::ssl::sf::OptionType +enum class OptionType : u32 { + DoNotCloseSocket = 0, + GetServerCertChain = 1, +}; + // This is nn::ssl::sf::SslVersion struct SslVersion { union { @@ -34,35 +54,42 @@ struct SslVersion { }; }; +struct SslContextSharedData { + u32 connection_count = 0; +}; + class ISslConnection final : public ServiceFramework<ISslConnection> { public: - explicit ISslConnection(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { + explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in, + std::shared_ptr<SslContextSharedData>& shared_data_in, + std::unique_ptr<SSLConnectionBackend>&& backend_in) + : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in}, + shared_data{shared_data_in}, backend{std::move(backend_in)} { // clang-format off static const FunctionInfo functions[] = { - {0, nullptr, "SetSocketDescriptor"}, - {1, nullptr, "SetHostName"}, - {2, nullptr, "SetVerifyOption"}, - {3, nullptr, "SetIoMode"}, + {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, + {1, &ISslConnection::SetHostName, "SetHostName"}, + {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, + {3, &ISslConnection::SetIoMode, "SetIoMode"}, {4, nullptr, "GetSocketDescriptor"}, {5, nullptr, "GetHostName"}, {6, nullptr, "GetVerifyOption"}, {7, nullptr, "GetIoMode"}, - {8, nullptr, "DoHandshake"}, - {9, nullptr, "DoHandshakeGetServerCert"}, - {10, nullptr, "Read"}, - {11, nullptr, "Write"}, - {12, nullptr, "Pending"}, + {8, &ISslConnection::DoHandshake, "DoHandshake"}, + {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, + {10, &ISslConnection::Read, "Read"}, + {11, &ISslConnection::Write, "Write"}, + {12, &ISslConnection::Pending, "Pending"}, {13, nullptr, "Peek"}, {14, nullptr, "Poll"}, {15, nullptr, "GetVerifyCertError"}, {16, nullptr, "GetNeededServerCertBufferSize"}, - {17, nullptr, "SetSessionCacheMode"}, + {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, {18, nullptr, "GetSessionCacheMode"}, {19, nullptr, "FlushSessionCache"}, {20, nullptr, "SetRenegotiationMode"}, {21, nullptr, "GetRenegotiationMode"}, - {22, nullptr, "SetOption"}, + {22, &ISslConnection::SetOption, "SetOption"}, {23, nullptr, "GetOption"}, {24, nullptr, "GetVerifyCertErrors"}, {25, nullptr, "GetCipherInfo"}, @@ -80,21 +107,299 @@ public: // clang-format on RegisterHandlers(functions); + + shared_data->connection_count++; + } + + ~ISslConnection() { + shared_data->connection_count--; + if (fd_to_close.has_value()) { + const s32 fd = *fd_to_close; + if (!do_not_close_socket) { + LOG_ERROR(Service_SSL, + "do_not_close_socket was changed after setting socket; is this right?"); + } else { + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + if (bsd) { + auto err = bsd->CloseImpl(fd); + if (err != Service::Sockets::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err); + } + } + } + } } private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data; + std::unique_ptr<SSLConnectionBackend> backend; + std::optional<int> fd_to_close; + bool do_not_close_socket = false; + bool get_server_cert_chain = false; + std::shared_ptr<Network::SocketBase> socket; + bool did_set_host_name = false; + bool did_handshake = false; + + ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { + LOG_DEBUG(Service_SSL, "called, fd={}", fd); + ASSERT(!did_handshake); + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); + s32 ret_fd; + // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor + if (do_not_close_socket) { + auto res = bsd->DuplicateSocketImpl(fd); + if (!res.has_value()) { + LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); + return ResultInvalidSocket; + } + fd = *res; + fd_to_close = fd; + ret_fd = fd; + } else { + ret_fd = -1; + } + std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); + if (!sock.has_value()) { + LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); + return ResultInvalidSocket; + } + socket = std::move(*sock); + backend->SetSocket(socket); + return ret_fd; + } + + Result SetHostNameImpl(const std::string& hostname) { + LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); + ASSERT(!did_handshake); + Result res = backend->SetHostName(hostname); + if (res == ResultSuccess) { + did_set_host_name = true; + } + return res; + } + + Result SetVerifyOptionImpl(u32 option) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); + return ResultSuccess; + } + + Result SetIoModeImpl(u32 input_mode) { + auto mode = static_cast<IoMode>(input_mode); + ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); + ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); + + const bool non_block = mode == IoMode::NonBlocking; + const Network::Errno error = socket->SetNonBlock(non_block); + if (error != Network::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); + } + return ResultSuccess; + } + + Result SetSessionCacheModeImpl(u32 mode) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); + return ResultSuccess; + } + + Result DoHandshakeImpl() { + ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); + ASSERT_OR_EXECUTE_MSG( + did_set_host_name, { return ResultInternalError; }, + "Expected SetHostName before DoHandshake"); + Result res = backend->DoHandshake(); + did_handshake = res.IsSuccess(); + return res; + } + + std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) { + struct Header { + u64 magic; + u32 count; + u32 pad; + }; + struct EntryHeader { + u32 size; + u32 offset; + }; + if (!get_server_cert_chain) { + // Just return the first one, unencoded. + ASSERT_OR_EXECUTE_MSG( + !certs.empty(), { return {}; }, "Should be at least one server cert"); + return certs[0]; + } + std::vector<u8> ret; + Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0}; + ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1)); + size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); + for (auto& cert : certs) { + EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)}; + data_offset += cert.size(); + ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header), + reinterpret_cast<u8*>(&entry_header + 1)); + } + for (auto& cert : certs) { + ret.insert(ret.end(), cert.begin(), cert.end()); + } + return ret; + } + + ResultVal<std::vector<u8>> ReadImpl(size_t size) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + std::vector<u8> res(size); + ResultVal<size_t> actual = backend->Read(res); + if (actual.Failed()) { + return actual.Code(); + } + res.resize(*actual); + return res; + } + + ResultVal<size_t> WriteImpl(std::span<const u8> data) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + return backend->Write(data); + } + + ResultVal<s32> PendingImpl() { + LOG_WARNING(Service_SSL, "(STUBBED) called."); + return 0; + } + + void SetSocketDescriptor(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const s32 fd = rp.Pop<s32>(); + const ResultVal<s32> res = SetSocketDescriptorImpl(fd); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(-1)); + } + + void SetHostName(HLERequestContext& ctx) { + const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); + const Result res = SetHostNameImpl(hostname); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetVerifyOption(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 option = rp.Pop<u32>(); + const Result res = SetVerifyOptionImpl(option); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetIoMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetIoModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshake(HLERequestContext& ctx) { + const Result res = DoHandshakeImpl(); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshakeGetServerCert(HLERequestContext& ctx) { + struct OutputParameters { + u32 certs_size; + u32 certs_count; + }; + static_assert(sizeof(OutputParameters) == 0x8); + + const Result res = DoHandshakeImpl(); + OutputParameters out{}; + if (res == ResultSuccess) { + auto certs = backend->GetServerCerts(); + if (certs.Succeeded()) { + const std::vector<u8> certs_buf = SerializeServerCerts(*certs); + ctx.WriteBuffer(certs_buf); + out.certs_count = static_cast<u32>(certs->size()); + out.certs_size = static_cast<u32>(certs_buf.size()); + } + } + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(res); + rb.PushRaw(out); + } + + void Read(HLERequestContext& ctx) { + const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + if (res.Succeeded()) { + rb.Push(static_cast<u32>(res->size())); + ctx.WriteBuffer(*res); + } else { + rb.Push(static_cast<u32>(0)); + } + } + + void Write(HLERequestContext& ctx) { + const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push(static_cast<u32>(res.ValueOr(0))); + } + + void Pending(HLERequestContext& ctx) { + const ResultVal<s32> res = PendingImpl(); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(0)); + } + + void SetSessionCacheMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetSessionCacheModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetOption(HLERequestContext& ctx) { + struct Parameters { + OptionType option; + s32 value; + }; + static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); + + IPC::RequestParser rp{ctx}; + const auto parameters = rp.PopRaw<Parameters>(); + + switch (parameters.option) { + case OptionType::DoNotCloseSocket: + do_not_close_socket = static_cast<bool>(parameters.value); + break; + case OptionType::GetServerCertChain: + get_server_cert_chain = static_cast<bool>(parameters.value); + break; + default: + LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, + parameters.value); + } + + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(ResultSuccess); + } }; class ISslContext final : public ServiceFramework<ISslContext> { public: explicit ISslContext(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { + : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, + shared_data{std::make_shared<SslContextSharedData>()} { static const FunctionInfo functions[] = { {0, &ISslContext::SetOption, "SetOption"}, {1, nullptr, "GetOption"}, {2, &ISslContext::CreateConnection, "CreateConnection"}, - {3, nullptr, "GetConnectionCount"}, + {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, {4, &ISslContext::ImportServerPki, "ImportServerPki"}, {5, &ISslContext::ImportClientPki, "ImportClientPki"}, {6, nullptr, "RemoveServerPki"}, @@ -111,6 +416,7 @@ public: private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data; void SetOption(HLERequestContext& ctx) { struct Parameters { @@ -130,11 +436,24 @@ private: } void CreateConnection(HLERequestContext& ctx) { - LOG_WARNING(Service_SSL, "(STUBBED) called"); + LOG_WARNING(Service_SSL, "called"); + + auto backend_res = CreateSSLConnectionBackend(); IPC::ResponseBuilder rb{ctx, 2, 0, 1}; + rb.Push(backend_res.Code()); + if (backend_res.Succeeded()) { + rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, + std::move(*backend_res)); + } + } + + void GetConnectionCount(HLERequestContext& ctx) { + LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); + + IPC::ResponseBuilder rb{ctx, 3}; rb.Push(ResultSuccess); - rb.PushIpcInterface<ISslConnection>(system, ssl_version); + rb.Push(shared_data->connection_count); } void ImportServerPki(HLERequestContext& ctx) { diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h new file mode 100644 index 000000000..409f4367c --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend.h @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#pragma once + +#include <memory> +#include <span> +#include <string> +#include <vector> + +#include "common/common_types.h" + +#include "core/hle/result.h" + +namespace Network { +class SocketBase; +} + +namespace Service::SSL { + +constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; +constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; +constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; +constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up + +// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, +// with no way in the latter case to distinguish whether the client should poll +// for read or write. The one official client I've seen handles this by always +// polling for read (with a timeout). +constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; + +class SSLConnectionBackend { +public: + virtual ~SSLConnectionBackend() {} + virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; + virtual Result SetHostName(const std::string& hostname) = 0; + virtual Result DoHandshake() = 0; + virtual ResultVal<size_t> Read(std::span<u8> data) = 0; + virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; + virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp new file mode 100644 index 000000000..2f4f23c42 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_none.cpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "common/logging/log.h" + +#include "core/hle/service/ssl/ssl_backend.h" + +namespace Service::SSL { + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + LOG_ERROR(Service_SSL, + "Can't create SSL connection because no SSL backend is available on this platform"); + return ResultInternalError; +} + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp new file mode 100644 index 000000000..6ca869dbf --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp @@ -0,0 +1,351 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <mutex> + +#include <openssl/bio.h> +#include <openssl/err.h> +#include <openssl/ssl.h> +#include <openssl/x509.h> + +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +using namespace Common::FS; + +namespace Service::SSL { + +// Import OpenSSL's `SSL` type into the namespace. This is needed because the +// namespace is also named `SSL`. +using ::SSL; + +namespace { + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SSL_CTX* ssl_ctx; +IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment +BIO_METHOD* bio_meth; + +Result CheckOpenSSLErrors(); +void OneTimeInit(); +void OneTimeInitLogFile(); +bool OneTimeInitBIO(); + +} // namespace + +class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR(Service_SSL, + "Can't create SSL connection because OpenSSL one-time initialization failed"); + return ResultInternalError; + } + + ssl = SSL_new(ssl_ctx); + if (!ssl) { + LOG_ERROR(Service_SSL, "SSL_new failed"); + return CheckOpenSSLErrors(); + } + + SSL_set_connect_state(ssl); + + bio = BIO_new(bio_meth); + if (!bio) { + LOG_ERROR(Service_SSL, "BIO_new failed"); + return CheckOpenSSLErrors(); + } + + BIO_set_data(bio, this); + BIO_set_init(bio, 1); + SSL_set_bio(ssl, bio, bio); + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { + socket = std::move(socket_in); + } + + Result SetHostName(const std::string& hostname) override { + if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification + LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); + return CheckOpenSSLErrors(); + } + if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI + LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); + return CheckOpenSSLErrors(); + } + return ResultSuccess; + } + + Result DoHandshake() override { + SSL_set_verify_result(ssl, X509_V_OK); + const int ret = SSL_do_handshake(ssl); + const long verify_result = SSL_get_verify_result(ssl); + if (verify_result != X509_V_OK) { + LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", + X509_verify_cert_error_string(verify_result)); + return CheckOpenSSLErrors(); + } + if (ret <= 0) { + const int ssl_err = SSL_get_error(ssl, ret); + if (ssl_err == SSL_ERROR_ZERO_RETURN || + (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + } + return HandleReturn("SSL_do_handshake", 0, ret).Code(); + } + + ResultVal<size_t> Read(std::span<u8> data) override { + size_t actual; + const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); + return HandleReturn("SSL_read_ex", actual, ret); + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + size_t actual; + const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); + return HandleReturn("SSL_write_ex", actual, ret); + } + + ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { + const int ssl_err = SSL_get_error(ssl, ret); + CheckOpenSSLErrors(); + switch (ssl_err) { + case SSL_ERROR_NONE: + return actual; + case SSL_ERROR_ZERO_RETURN: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); + // DoHandshake special-cases this, but for Read and Write: + return size_t(0); + case SSL_ERROR_WANT_READ: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); + return ResultWouldBlock; + case SSL_ERROR_WANT_WRITE: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); + return ResultWouldBlock; + default: + if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); + return size_t(0); + } + LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); + return ResultInternalError; + } + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); + if (!chain) { + LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); + return ResultInternalError; + } + std::vector<std::vector<u8>> ret; + int count = sk_X509_num(chain); + ASSERT(count >= 0); + for (int i = 0; i < count; i++) { + X509* x509 = sk_X509_value(chain, i); + ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); + unsigned char* buf = nullptr; + int len = i2d_X509(x509, &buf); + ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); + ret.emplace_back(buf, buf + len); + OPENSSL_free(buf); + } + return ret; + } + + ~SSLConnectionBackendOpenSSL() { + // these are null-tolerant: + SSL_free(ssl); + BIO_free(bio); + } + + static void KeyLogCallback(const SSL* ssl, const char* line) { + std::string str(line); + str.push_back('\n'); + // Do this in a single WriteString for atomicity if multiple instances + // are running on different threads (though that can't currently + // happen). + if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { + LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); + } + LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); + } + + static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { + auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + return 1; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return -1; + } + } + + static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { + auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len}); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + if (actual == 0) { + self->got_read_eof = true; + } + return actual ? 1 : 0; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return -1; + } + } + + static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) { + switch (cmd) { + case BIO_CTRL_FLUSH: + // Nothing to flush. + return 1; + case BIO_CTRL_PUSH: + case BIO_CTRL_POP: +#ifdef BIO_CTRL_GET_KTLS_SEND + case BIO_CTRL_GET_KTLS_SEND: + case BIO_CTRL_GET_KTLS_RECV: +#endif + // We don't support these operations, but don't bother logging them + // as they're nothing unusual. + return 0; + default: + LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg); + return 0; + } + } + + SSL* ssl = nullptr; + BIO* bio = nullptr; + bool got_read_eof = false; + + std::shared_ptr<Network::SocketBase> socket; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +namespace { + +Result CheckOpenSSLErrors() { + unsigned long rc; + const char* file; + int line; + const char* func; + const char* data; + int flags; +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) +#else + // Can't get function names from OpenSSL on this version, so use mine: + func = __func__; + while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags))) +#endif + { + std::string msg; + msg.resize(1024, '\0'); + ERR_error_string_n(rc, msg.data(), msg.size()); + msg.resize(strlen(msg.data()), '\0'); + if (flags & ERR_TXT_STRING) { + msg.append(" | "); + msg.append(data); + } + Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, + Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", + msg); + } + return ResultInternalError; +} + +void OneTimeInit() { + ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (!ssl_ctx) { + LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); + CheckOpenSSLErrors(); + return; + } + + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); + + if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { + LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); + CheckOpenSSLErrors(); + return; + } + + OneTimeInitLogFile(); + + if (!OneTimeInitBIO()) { + return; + } + + one_time_init_success = true; +} + +void OneTimeInitLogFile() { + const char* logfile = getenv("SSLKEYLOGFILE"); + if (logfile) { + key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, + FileShareFlag::ShareWriteOnly); + if (key_log_file.IsOpen()) { + SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); + } else { + LOG_CRITICAL(Service_SSL, + "SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); + } + } +} + +bool OneTimeInitBIO() { + bio_meth = + BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); + if (!bio_meth || + !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || + !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || + !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { + LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); + return false; + } + return true; +} + +} // namespace + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp new file mode 100644 index 000000000..d8074339a --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp @@ -0,0 +1,544 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <mutex> + +#include "common/error.h" +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +namespace { + +// These includes are inside the namespace to avoid a conflict on MinGW where +// the headers define an enum containing Network and Service as enumerators +// (which clash with the correspondingly named namespaces). +#define SECURITY_WIN32 +#include <schnlsp.h> +#include <security.h> +#include <wincrypt.h> + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SCHANNEL_CRED schannel_cred{}; +CredHandle cred_handle; + +static void OneTimeInit() { + schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; + schannel_cred.dwFlags = + SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols + SCH_CRED_AUTO_CRED_VALIDATION | // validate certs + SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate + // ^ I'm assuming that nobody would want to connect Yuzu to a + // service that requires some OS-provided corporate client + // certificate, and presenting one to some arbitrary server + // might be a privacy concern? Who knows, though. + + const SECURITY_STATUS ret = + AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, + nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); + if (ret != SEC_E_OK) { + // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. + LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", + Common::NativeErrorToString(ret)); + return; + } + + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " + "keys; not logging keys!"); + // Not fatal. + } + + one_time_init_success = true; +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSchannel final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR( + Service_SSL, + "Can't create SSL connection because Schannel one-time initialization failed"); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { + socket = std::move(socket_in); + } + + Result SetHostName(const std::string& hostname_in) override { + hostname = hostname_in; + return ResultSuccess; + } + + Result DoHandshake() override { + while (1) { + Result r; + switch (handshake_state) { + case HandshakeState::Initial: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::ContinueNeeded: + case HandshakeState::IncompleteMessage: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = FillCiphertextReadBuf()) != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + if ((r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::DoneAfterFlush: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { + return r; + } + handshake_state = HandshakeState::Connected; + return ResultSuccess; + case HandshakeState::Connected: + LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); + return ResultInternalError; + case HandshakeState::Error: + return ResultInternalError; + } + } + } + + Result FillCiphertextReadBuf() { + const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; + read_buf_fill_size = 0; + // This unnecessarily zeroes the buffer; oh well. + const size_t offset = ciphertext_read_buf.size(); + ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); + ciphertext_read_buf.resize(offset + fill_size, 0); + const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); + const auto [actual, err] = socket->Recv(0, read_span); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= fill_size); + ciphertext_read_buf.resize(offset + actual); + return ResultSuccess; + case Network::Errno::AGAIN: + ciphertext_read_buf.resize(offset); + return ResultWouldBlock; + default: + ciphertext_read_buf.resize(offset); + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return ResultInternalError; + } + } + + // Returns success if the write buffer has been completely emptied. + Result FlushCiphertextWriteBuf() { + while (!ciphertext_write_buf.empty()) { + const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); + ciphertext_write_buf.erase(ciphertext_write_buf.begin(), + ciphertext_write_buf.begin() + actual); + break; + case Network::Errno::AGAIN: + return ResultWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return ResultInternalError; + } + } + return ResultSuccess; + } + + Result CallInitializeSecurityContext() { + const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | + ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | + ISC_REQ_USE_SUPPLIED_CREDS; + unsigned long attr; + // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel + std::array<SecBuffer, 2> input_buffers{{ + // only used if `initial_call_done` + { + // [0] + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = ciphertext_read_buf.data(), + }, + { + // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is + // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the + // whole buffer wasn't used) + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }, + }}; + std::array<SecBuffer, 2> output_buffers{{ + { + .cbBuffer = 0, + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = nullptr, + }, // [0] + { + .cbBuffer = 0, + .BufferType = SECBUFFER_ALERT, + .pvBuffer = nullptr, + }, // [1] + }}; + SecBufferDesc input_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(input_buffers.size()), + .pBuffers = input_buffers.data(), + }; + SecBufferDesc output_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(output_buffers.size()), + .pBuffers = output_buffers.data(), + }; + ASSERT_OR_EXECUTE_MSG( + input_buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + + bool initial_call_done = handshake_state != HandshakeState::Initial; + if (initial_call_done) { + LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", + ciphertext_read_buf.size()); + } + + const SECURITY_STATUS ret = + InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, + // Caller ensured we have set a hostname: + const_cast<char*>(hostname.value().c_str()), req, + 0, // Reserved1 + 0, // TargetDataRep not used with Schannel + initial_call_done ? &input_desc : nullptr, + 0, // Reserved2 + initial_call_done ? nullptr : &ctxt, &output_desc, &attr, + nullptr); // ptsExpiry + + if (output_buffers[0].pvBuffer) { + const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), + output_buffers[0].cbBuffer); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); + FreeContextBuffer(output_buffers[0].pvBuffer); + } + + if (output_buffers[1].pvBuffer) { + const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), + output_buffers[1].cbBuffer); + // The documentation doesn't explain what format this data is in. + LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), + Common::HexToString(span)); + } + + switch (ret) { + case SEC_I_CONTINUE_NEEDED: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); + if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { + LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); + ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - input_buffers[1].cbBuffer); + } else { + ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + handshake_state = HandshakeState::ContinueNeeded; + return ResultSuccess; + case SEC_E_INCOMPLETE_MESSAGE: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); + ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); + read_buf_fill_size = input_buffers[1].cbBuffer; + handshake_state = HandshakeState::IncompleteMessage; + return ResultSuccess; + case SEC_E_OK: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); + ciphertext_read_buf.clear(); + handshake_state = HandshakeState::DoneAfterFlush; + return GrabStreamSizes(); + default: + LOG_ERROR(Service_SSL, + "InitializeSecurityContext failed (probably certificate/protocol issue): {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + } + + Result GrabStreamSizes() { + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + return ResultSuccess; + } + + ResultVal<size_t> Read(std::span<u8> data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0 || got_read_eof) { + return size_t(0); + } + while (1) { + if (!cleartext_read_buf.empty()) { + const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); + std::memcpy(data.data(), cleartext_read_buf.data(), read_size); + cleartext_read_buf.erase(cleartext_read_buf.begin(), + cleartext_read_buf.begin() + read_size); + return read_size; + } + if (!ciphertext_read_buf.empty()) { + SecBuffer empty{ + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }; + std::array<SecBuffer, 5> buffers{{ + { + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = ciphertext_read_buf.data(), + }, + empty, + empty, + empty, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + SECURITY_STATUS ret = + DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); + switch (ret) { + case SEC_E_OK: + ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, + { return ResultInternalError; }); + cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), + static_cast<u8*>(buffers[1].pvBuffer) + + buffers[1].cbBuffer); + if (buffers[3].BufferType == SECBUFFER_EXTRA) { + ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - buffers[3].cbBuffer); + } else { + ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + continue; + case SEC_E_INCOMPLETE_MESSAGE: + break; + case SEC_I_CONTEXT_EXPIRED: + // Server hung up by sending close_notify. + got_read_eof = true; + return size_t(0); + default: + LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + } + const Result r = FillCiphertextReadBuf(); + if (r != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + got_read_eof = true; + return size_t(0); + } + } + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0) { + return size_t(0); + } + data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); + if (!cleartext_write_buf.empty()) { + // Already in the middle of a write. It wouldn't make sense to not + // finish sending the entire buffer since TLS has + // header/MAC/padding/etc. + if (data.size() != cleartext_write_buf.size() || + std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { + LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); + return ResultInternalError; + } + return WriteAlreadyEncryptedData(); + } else { + cleartext_write_buf.assign(data.begin(), data.end()); + } + + std::vector<u8> header_buf(stream_sizes.cbHeader, 0); + std::vector<u8> tmp_data_buf = cleartext_write_buf; + std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); + + std::array<SecBuffer, 3> buffers{{ + { + .cbBuffer = stream_sizes.cbHeader, + .BufferType = SECBUFFER_STREAM_HEADER, + .pvBuffer = header_buf.data(), + }, + { + .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = tmp_data_buf.data(), + }, + { + .cbBuffer = stream_sizes.cbTrailer, + .BufferType = SECBUFFER_STREAM_TRAILER, + .pvBuffer = trailer_buf.data(), + }, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, + "temp buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + + const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); + return ResultInternalError; + } + ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), + header_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), + tmp_data_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), + trailer_buf.end()); + return WriteAlreadyEncryptedData(); + } + + ResultVal<size_t> WriteAlreadyEncryptedData() { + const Result r = FlushCiphertextWriteBuf(); + if (r != ResultSuccess) { + return r; + } + // write buf is empty + const size_t cleartext_bytes_written = cleartext_write_buf.size(); + cleartext_write_buf.clear(); + return cleartext_bytes_written; + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + PCCERT_CONTEXT returned_cert = nullptr; + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, + "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + PCCERT_CONTEXT some_cert = nullptr; + std::vector<std::vector<u8>> certs; + while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { + certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), + static_cast<u8*>(some_cert->pbCertEncoded) + + some_cert->cbCertEncoded); + } + std::reverse(certs.begin(), + certs.end()); // Windows returns certs in reverse order from what we want + CertFreeCertificateContext(returned_cert); + return certs; + } + + ~SSLConnectionBackendSchannel() { + if (handshake_state != HandshakeState::Initial) { + DeleteSecurityContext(&ctxt); + } + } + + enum class HandshakeState { + // Haven't called anything yet. + Initial, + // `SEC_I_CONTINUE_NEEDED` was returned by + // `InitializeSecurityContext`; must finish sending data (if any) in + // the write buffer, then read at least one byte before calling + // `InitializeSecurityContext` again. + ContinueNeeded, + // `SEC_E_INCOMPLETE_MESSAGE` was returned by + // `InitializeSecurityContext`; hopefully the write buffer is empty; + // must read at least one byte before calling + // `InitializeSecurityContext` again. + IncompleteMessage, + // `SEC_E_OK` was returned by `InitializeSecurityContext`; must + // finish sending data in the write buffer before having `DoHandshake` + // report success. + DoneAfterFlush, + // We finished the above and are now connected. At this point, writing + // and reading are separate 'state machines' represented by the + // nonemptiness of the ciphertext and cleartext read and write buffers. + Connected, + // Another error was returned and we shouldn't allow initialization + // to continue. + Error, + } handshake_state = HandshakeState::Initial; + + CtxtHandle ctxt; + SecPkgContext_StreamSizes stream_sizes; + + std::shared_ptr<Network::SocketBase> socket; + std::optional<std::string> hostname; + + std::vector<u8> ciphertext_read_buf; + std::vector<u8> ciphertext_write_buf; + std::vector<u8> cleartext_read_buf; + std::vector<u8> cleartext_write_buf; + + bool got_read_eof = false; + size_t read_buf_fill_size = 0; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendSchannel>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp new file mode 100644 index 000000000..b3083cbad --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <mutex> + +// SecureTransport has been deprecated in its entirety in favor of +// Network.framework, but that does not allow layering TLS on top of an +// arbitrary socket. +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#include <Security/SecureTransport.h> +#pragma GCC diagnostic pop +#endif + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +namespace { + +template <typename T> +struct CFReleaser { + T ptr; + + YUZU_NON_COPYABLE(CFReleaser); + constexpr CFReleaser() : ptr(nullptr) {} + constexpr CFReleaser(T ptr) : ptr(ptr) {} + constexpr operator T() { + return ptr; + } + ~CFReleaser() { + if (ptr) { + CFRelease(ptr); + } + } +}; + +std::string CFStringToString(CFStringRef cfstr) { + CFReleaser<CFDataRef> cfdata( + CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); + ASSERT_OR_EXECUTE(cfdata, { return "???"; }); + return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), + CFDataGetLength(cfdata)); +} + +std::string OSStatusToString(OSStatus status) { + CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); + if (!cfstr) { + return "[unknown error]"; + } + return CFStringToString(cfstr); +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { +public: + Result Init() { + static std::once_flag once_flag; + std::call_once(once_flag, []() { + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " + "support exporting keys; not logging keys!"); + // Not fatal. + } + }); + + context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); + if (!context) { + LOG_ERROR(Service_SSL, "SSLCreateContext failed"); + return ResultInternalError; + } + + OSStatus status; + if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || + (status = SSLSetConnection(context, this))) { + LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", + OSStatusToString(status)); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { + socket = std::move(in_socket); + } + + Result SetHostName(const std::string& hostname) override { + OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); + if (status) { + LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + return ResultSuccess; + } + + Result DoHandshake() override { + OSStatus status = SSLHandshake(context); + return HandleReturn("SSLHandshake", 0, status).Code(); + } + + ResultVal<size_t> Read(std::span<u8> data) override { + size_t actual; + OSStatus status = SSLRead(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLRead", actual, status); + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + size_t actual; + OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLWrite", actual, status); + } + + ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { + switch (status) { + case 0: + return actual; + case errSSLWouldBlock: + return ResultWouldBlock; + default: { + std::string reason; + if (got_read_eof) { + reason = "server hung up"; + } else { + reason = OSStatusToString(status); + } + LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); + return ResultInternalError; + } + } + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + CFReleaser<SecTrustRef> trust; + OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); + if (status) { + LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + std::vector<std::vector<u8>> ret; + for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { + SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); + CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); + ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); + const u8* ptr = CFDataGetBytePtr(data); + ret.emplace_back(ptr, ptr + CFDataGetLength(data)); + } + return ret; + } + + static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { + return ReadOrWriteCallback(connection, data, dataLength, true); + } + + static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, + size_t* dataLength) { + return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); + } + + static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, + bool is_read) { + auto self = + static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", + is_read ? "read" : "write"); + + // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are + // expected to read/write the full requested dataLength or return an + // error, so we have to add a loop ourselves. + size_t requested_len = *dataLength; + size_t offset = 0; + while (offset < requested_len) { + std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); + auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); + LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, + actual, cur.size(), static_cast<s32>(err)); + switch (err) { + case Network::Errno::SUCCESS: + offset += actual; + if (actual == 0) { + ASSERT(is_read); + self->got_read_eof = true; + return errSecEndOfData; + } + break; + case Network::Errno::AGAIN: + *dataLength = offset; + return errSSLWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", + is_read ? "recv" : "send", err); + return errSecIO; + } + } + ASSERT(offset == requested_len); + return 0; + } + +private: + CFReleaser<SSLContextRef> context = nullptr; + bool got_read_eof = false; + + std::shared_ptr<Network::SocketBase> socket; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp index 75ac10a9c..28f89c599 100644 --- a/src/core/internal_network/network.cpp +++ b/src/core/internal_network/network.cpp @@ -27,6 +27,7 @@ #include "common/assert.h" #include "common/common_types.h" +#include "common/expected.h" #include "common/logging/log.h" #include "common/settings.h" #include "core/internal_network/network.h" @@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) { Errno TranslateNativeError(int e) { switch (e) { + case 0: + return Errno::SUCCESS; case WSAEBADF: return Errno::BADF; case WSAEINVAL: @@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) { return Errno::MSGSIZE; case WSAETIMEDOUT: return Errno::TIMEDOUT; + case WSAEINPROGRESS: + return Errno::INPROGRESS; default: UNIMPLEMENTED_MSG("Unimplemented errno={}", e); return Errno::OTHER; @@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) { Errno TranslateNativeError(int e) { switch (e) { + case 0: + return Errno::SUCCESS; case EBADF: return Errno::BADF; case EINVAL: @@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) { return Errno::MSGSIZE; case ETIMEDOUT: return Errno::TIMEDOUT; + case EINPROGRESS: + return Errno::INPROGRESS; default: - UNIMPLEMENTED_MSG("Unimplemented errno={}", e); + UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e)); return Errno::OTHER; } } @@ -234,15 +243,84 @@ Errno GetAndLogLastError() { int e = errno; #endif const Errno err = TranslateNativeError(e); - if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { + if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) { + // These happen during normal operation, so only log them at debug level. + LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); return err; } LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); return err; } -int TranslateDomain(Domain domain) { +GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) { + switch (gai_err) { + case 0: + return GetAddrInfoError::SUCCESS; +#ifdef EAI_ADDRFAMILY + case EAI_ADDRFAMILY: + return GetAddrInfoError::ADDRFAMILY; +#endif + case EAI_AGAIN: + return GetAddrInfoError::AGAIN; + case EAI_BADFLAGS: + return GetAddrInfoError::BADFLAGS; + case EAI_FAIL: + return GetAddrInfoError::FAIL; + case EAI_FAMILY: + return GetAddrInfoError::FAMILY; + case EAI_MEMORY: + return GetAddrInfoError::MEMORY; + case EAI_NONAME: + return GetAddrInfoError::NONAME; + case EAI_SERVICE: + return GetAddrInfoError::SERVICE; + case EAI_SOCKTYPE: + return GetAddrInfoError::SOCKTYPE; + // These codes may not be defined on all systems: +#ifdef EAI_SYSTEM + case EAI_SYSTEM: + return GetAddrInfoError::SYSTEM; +#endif +#ifdef EAI_BADHINTS + case EAI_BADHINTS: + return GetAddrInfoError::BADHINTS; +#endif +#ifdef EAI_PROTOCOL + case EAI_PROTOCOL: + return GetAddrInfoError::PROTOCOL; +#endif +#ifdef EAI_OVERFLOW + case EAI_OVERFLOW: + return GetAddrInfoError::OVERFLOW_; +#endif + default: +#ifdef EAI_NODATA + // This can't be a case statement because it would create a duplicate + // case on Windows where EAI_NODATA is an alias for EAI_NONAME. + if (gai_err == EAI_NODATA) { + return GetAddrInfoError::NODATA; + } +#endif + return GetAddrInfoError::OTHER; + } +} + +Domain TranslateDomainFromNative(int domain) { + switch (domain) { + case 0: + return Domain::Unspecified; + case AF_INET: + return Domain::INET; + default: + UNIMPLEMENTED_MSG("Unhandled domain={}", domain); + return Domain::INET; + } +} + +int TranslateDomainToNative(Domain domain) { switch (domain) { + case Domain::Unspecified: + return 0; case Domain::INET: return AF_INET; default: @@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) { } } -int TranslateType(Type type) { +Type TranslateTypeFromNative(int type) { + switch (type) { + case 0: + return Type::Unspecified; + case SOCK_STREAM: + return Type::STREAM; + case SOCK_DGRAM: + return Type::DGRAM; + case SOCK_RAW: + return Type::RAW; + case SOCK_SEQPACKET: + return Type::SEQPACKET; + default: + UNIMPLEMENTED_MSG("Unimplemented type={}", type); + return Type::STREAM; + } +} + +int TranslateTypeToNative(Type type) { switch (type) { + case Type::Unspecified: + return 0; case Type::STREAM: return SOCK_STREAM; case Type::DGRAM: return SOCK_DGRAM; + case Type::RAW: + return SOCK_RAW; default: UNIMPLEMENTED_MSG("Unimplemented type={}", type); return 0; } } -int TranslateProtocol(Protocol protocol) { +Protocol TranslateProtocolFromNative(int protocol) { + switch (protocol) { + case 0: + return Protocol::Unspecified; + case IPPROTO_TCP: + return Protocol::TCP; + case IPPROTO_UDP: + return Protocol::UDP; + default: + UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); + return Protocol::Unspecified; + } +} + +int TranslateProtocolToNative(Protocol protocol) { switch (protocol) { + case Protocol::Unspecified: + return 0; case Protocol::TCP: return IPPROTO_TCP; case Protocol::UDP: @@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) { } } -SockAddrIn TranslateToSockAddrIn(sockaddr input_) { - sockaddr_in input; - std::memcpy(&input, &input_, sizeof(input)); - +SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) { SockAddrIn result; - switch (input.sin_family) { - case AF_INET: - result.family = Domain::INET; - break; - default: - UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family); - result.family = Domain::INET; - break; - } + result.family = TranslateDomainFromNative(input.sin_family); result.portno = ntohs(input.sin_port); @@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) { short TranslatePollEvents(PollEvents events) { short result = 0; - if (True(events & PollEvents::In)) { - events &= ~PollEvents::In; - result |= POLLIN; - } - if (True(events & PollEvents::Pri)) { - events &= ~PollEvents::Pri; + const auto translate = [&result, &events](PollEvents guest, short host) { + if (True(events & guest)) { + events &= ~guest; + result |= host; + } + }; + + translate(PollEvents::In, POLLIN); + translate(PollEvents::Pri, POLLPRI); + translate(PollEvents::Out, POLLOUT); + translate(PollEvents::Err, POLLERR); + translate(PollEvents::Hup, POLLHUP); + translate(PollEvents::Nval, POLLNVAL); + translate(PollEvents::RdNorm, POLLRDNORM); + translate(PollEvents::RdBand, POLLRDBAND); + translate(PollEvents::WrBand, POLLWRBAND); + #ifdef _WIN32 - LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); -#else - result |= POLLPRI; + short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM; + // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input. + if (result & ~allowed_events) { + LOG_DEBUG(Network, + "Removing WSAPoll input events 0x{:x} because Windows doesn't support them", + result & ~allowed_events); + } + result &= allowed_events; #endif - } - if (True(events & PollEvents::Out)) { - events &= ~PollEvents::Out; - result |= POLLOUT; - } UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); @@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) { translate(POLLOUT, PollEvents::Out); translate(POLLERR, PollEvents::Err); translate(POLLHUP, PollEvents::Hup); + translate(POLLNVAL, PollEvents::Nval); + translate(POLLRDNORM, PollEvents::RdNorm); + translate(POLLRDBAND, PollEvents::RdBand); + translate(POLLWRBAND, PollEvents::WrBand); UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); @@ -360,12 +480,51 @@ std::optional<IPv4Address> GetHostIPv4Address() { return {}; } - std::array<char, 16> ip_addr = {}; - ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) != - nullptr); return TranslateIPv4(network_interface->ip_address); } +std::string IPv4AddressToString(IPv4Address ip_addr) { + std::array<char, INET_ADDRSTRLEN> buf = {}; + ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data()); + return std::string(buf.data()); +} + +u32 IPv4AddressToInteger(IPv4Address ip_addr) { + return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 | + static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]); +} + +Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo( + const std::string& host, const std::optional<std::string>& service) { + addrinfo hints{}; + hints.ai_family = AF_INET; // Switch only supports IPv4. + addrinfo* addrinfo; + s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr, + &hints, &addrinfo); + if (gai_err != 0) { + return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err)); + } + std::vector<AddrInfo> ret; + for (auto* current = addrinfo; current; current = current->ai_next) { + // We should only get AF_INET results due to the hints value. + ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET && + addrinfo->ai_addrlen == sizeof(sockaddr_in), + continue;); + + AddrInfo& out = ret.emplace_back(); + out.family = TranslateDomainFromNative(current->ai_family); + out.socket_type = TranslateTypeFromNative(current->ai_socktype); + out.protocol = TranslateProtocolFromNative(current->ai_protocol); + out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr), + current->ai_addrlen); + if (current->ai_canonname != nullptr) { + out.canon_name = current->ai_canonname; + } + } + freeaddrinfo(addrinfo); + return ret; +} + std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { const size_t num = pollfds.size(); @@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept { } template <typename T> -Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { +std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) { + T value{}; + socklen_t len = sizeof(value); + const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len); + if (result != SOCKET_ERROR) { + ASSERT(len == sizeof(value)); + return {value, Errno::SUCCESS}; + } + return {value, GetAndLogLastError()}; +} + +template <typename T> +Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) { const int result = - setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); + setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); if (result != SOCKET_ERROR) { return Errno::SUCCESS; } @@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { } Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { - fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); + fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type), + TranslateProtocolToNative(protocol)); if (fd != INVALID_SOCKET) { return Errno::SUCCESS; } @@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { } std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - const SOCKET new_socket = accept(fd, &addr, &addrlen); + const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); if (new_socket == INVALID_SOCKET) { return {AcceptResult{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - AcceptResult result{ .socket = std::make_unique<Socket>(new_socket), - .sockaddr_in = TranslateToSockAddrIn(addr), + .sockaddr_in = TranslateToSockAddrIn(addr, addrlen), }; return {std::move(result), Errno::SUCCESS}; @@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) { } std::pair<SockAddrIn, Errno> Socket::GetPeerName() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { + if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { return {SockAddrIn{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; + return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; } std::pair<SockAddrIn, Errno> Socket::GetSockName() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { + if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) { return {SockAddrIn{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; + return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; } Errno Socket::Bind(SockAddrIn addr) { @@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) { return GetAndLogLastError(); } -std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { +std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) { ASSERT(flags == 0); ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); @@ -532,21 +700,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { return {-1, GetAndLogLastError()}; } -std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { +std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) { ASSERT(flags == 0); ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); - sockaddr addr_in{}; + sockaddr_in addr_in{}; socklen_t addrlen = sizeof(addr_in); socklen_t* const p_addrlen = addr ? &addrlen : nullptr; - sockaddr* const p_addr_in = addr ? &addr_in : nullptr; + sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr; const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); if (result != SOCKET_ERROR) { if (addr) { - ASSERT(addrlen == sizeof(addr_in)); - *addr = TranslateToSockAddrIn(addr_in); + *addr = TranslateToSockAddrIn(addr_in, addrlen); } return {static_cast<s32>(result), Errno::SUCCESS}; } @@ -597,6 +764,11 @@ Errno Socket::Close() { return Errno::SUCCESS; } +std::pair<Errno, Errno> Socket::GetPendingError() { + auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR); + return {TranslateNativeError(pending_err), getsockopt_err}; +} + Errno Socket::SetLinger(bool enable, u32 linger) { return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); } diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h index 1e09a007a..badcb8369 100644 --- a/src/core/internal_network/network.h +++ b/src/core/internal_network/network.h @@ -5,6 +5,7 @@ #include <array> #include <optional> +#include <vector> #include "common/common_funcs.h" #include "common/common_types.h" @@ -16,6 +17,11 @@ #include <netinet/in.h> #endif +namespace Common { +template <typename T, typename E> +class Expected; +} + namespace Network { class SocketBase; @@ -36,6 +42,26 @@ enum class Errno { NETUNREACH, TIMEDOUT, MSGSIZE, + INPROGRESS, + OTHER, +}; + +enum class GetAddrInfoError { + SUCCESS, + ADDRFAMILY, + AGAIN, + BADFLAGS, + FAIL, + FAMILY, + MEMORY, + NODATA, + NONAME, + SERVICE, + SOCKTYPE, + SYSTEM, + BADHINTS, + PROTOCOL, + OVERFLOW_, OTHER, }; @@ -49,6 +75,9 @@ enum class PollEvents : u16 { Err = 1 << 3, Hup = 1 << 4, Nval = 1 << 5, + RdNorm = 1 << 6, + RdBand = 1 << 7, + WrBand = 1 << 8, }; DECLARE_ENUM_FLAG_OPERATORS(PollEvents); @@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) { /// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array std::optional<IPv4Address> GetHostIPv4Address(); +std::string IPv4AddressToString(IPv4Address ip_addr); +u32 IPv4AddressToInteger(IPv4Address ip_addr); + +// named to avoid name collision with Windows macro +Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo( + const std::string& host, const std::optional<std::string>& service); + } // namespace Network diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp index 7a77171c2..ce0dee970 100644 --- a/src/core/internal_network/socket_proxy.cpp +++ b/src/core/internal_network/socket_proxy.cpp @@ -10,6 +10,7 @@ #include "core/internal_network/network.h" #include "core/internal_network/network_interface.h" #include "core/internal_network/socket_proxy.h" +#include "network/network.h" #if YUZU_UNIX #include <sys/socket.h> @@ -98,7 +99,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) { return Errno::SUCCESS; } -std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { +std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) { LOG_WARNING(Network, "(STUBBED) called"); ASSERT(flags == 0); ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); @@ -106,7 +107,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { return {static_cast<s32>(0), Errno::SUCCESS}; } -std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { +std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) { ASSERT(flags == 0); ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); @@ -140,8 +141,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, } } -std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message, - SockAddrIn* addr, std::size_t max_length) { +std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr, + std::size_t max_length) { ProxyPacket& packet = received_packets.front(); if (addr) { addr->family = Domain::INET; @@ -153,10 +154,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes std::size_t read_bytes; if (packet.data.size() > max_length) { read_bytes = max_length; - message.clear(); - std::copy(packet.data.begin(), packet.data.begin() + read_bytes, - std::back_inserter(message)); - message.resize(max_length); + memcpy(message.data(), packet.data.data(), max_length); if (protocol == Protocol::UDP) { if (!peek) { @@ -171,9 +169,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes } } else { read_bytes = packet.data.size(); - message.clear(); - std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message)); - message.resize(max_length); + memcpy(message.data(), packet.data.data(), read_bytes); if (!peek) { received_packets.pop(); } @@ -293,6 +289,11 @@ Errno ProxySocket::SetNonBlock(bool enable) { return Errno::SUCCESS; } +std::pair<Errno, Errno> ProxySocket::GetPendingError() { + LOG_DEBUG(Network, "(STUBBED) called"); + return {Errno::SUCCESS, Errno::SUCCESS}; +} + bool ProxySocket::IsOpened() const { return fd != INVALID_SOCKET; } diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h index 6e991fa38..70500cf4a 100644 --- a/src/core/internal_network/socket_proxy.h +++ b/src/core/internal_network/socket_proxy.h @@ -10,10 +10,12 @@ #include "common/common_funcs.h" #include "core/internal_network/sockets.h" -#include "network/network.h" +#include "network/room_member.h" namespace Network { +class RoomNetwork; + class ProxySocket : public SocketBase { public: explicit ProxySocket(RoomNetwork& room_network_) noexcept; @@ -39,11 +41,11 @@ public: Errno Shutdown(ShutdownHow how) override; - std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; + std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override; - std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; + std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override; - std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr, + std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr, std::size_t max_length); std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; @@ -74,6 +76,8 @@ public: template <typename T> Errno SetSockOpt(SOCKET fd, int option, T value); + std::pair<Errno, Errno> GetPendingError() override; + bool IsOpened() const override; private: diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h index 11e479e50..4ba51f62c 100644 --- a/src/core/internal_network/sockets.h +++ b/src/core/internal_network/sockets.h @@ -15,12 +15,13 @@ #include "common/common_types.h" #include "core/internal_network/network.h" -#include "network/network.h" // TODO: C++20 Replace std::vector usages with std::span namespace Network { +struct ProxyPacket; + class SocketBase { public: #ifdef YUZU_UNIX @@ -59,10 +60,9 @@ public: virtual Errno Shutdown(ShutdownHow how) = 0; - virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0; + virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0; - virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, - SockAddrIn* addr) = 0; + virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0; virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0; @@ -87,6 +87,8 @@ public: virtual Errno SetNonBlock(bool enable) = 0; + virtual std::pair<Errno, Errno> GetPendingError() = 0; + virtual bool IsOpened() const = 0; virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; @@ -126,9 +128,9 @@ public: Errno Shutdown(ShutdownHow how) override; - std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; + std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override; - std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; + std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override; std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; @@ -156,6 +158,11 @@ public: template <typename T> Errno SetSockOpt(SOCKET fd, int option, T value); + std::pair<Errno, Errno> GetPendingError() override; + + template <typename T> + std::pair<T, Errno> GetSockOpt(SOCKET fd, int option); + bool IsOpened() const override; void HandleProxyPacket(const ProxyPacket& packet) override; diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt index 3b2fe01da..7f79111e0 100644 --- a/src/video_core/CMakeLists.txt +++ b/src/video_core/CMakeLists.txt @@ -274,6 +274,7 @@ add_library(video_core STATIC vulkan_common/vulkan_wrapper.h vulkan_common/nsight_aftermath_tracker.cpp vulkan_common/nsight_aftermath_tracker.h + vulkan_common/vma.cpp ) create_target_directory_groups(video_core) @@ -291,7 +292,7 @@ target_link_options(video_core PRIVATE ${FFmpeg_LDFLAGS}) add_dependencies(video_core host_shaders) target_include_directories(video_core PRIVATE ${HOST_SHADERS_INCLUDE}) -target_link_libraries(video_core PRIVATE sirit Vulkan::Headers vma) +target_link_libraries(video_core PRIVATE sirit Vulkan::Headers GPUOpen::VulkanMemoryAllocator) if (ENABLE_NSIGHT_AFTERMATH) if (NOT DEFINED ENV{NSIGHT_AFTERMATH_SDK}) @@ -324,6 +325,9 @@ else() # xbyak set_source_files_properties(macro/macro_jit_x64.cpp PROPERTIES COMPILE_OPTIONS "-Wno-conversion;-Wno-shadow") + + # VMA + set_source_files_properties(vulkan_common/vma.cpp PROPERTIES COMPILE_OPTIONS "-Wno-conversion;-Wno-unused-variable;-Wno-unused-parameter;-Wno-missing-field-initializers") endif() if (ARCHITECTURE_x86_64) diff --git a/src/video_core/renderer_base.cpp b/src/video_core/renderer_base.cpp index 2d3f58201..4002fa72b 100644 --- a/src/video_core/renderer_base.cpp +++ b/src/video_core/renderer_base.cpp @@ -38,8 +38,8 @@ void RendererBase::RequestScreenshot(void* data, std::function<void(bool)> callb LOG_ERROR(Render, "A screenshot is already requested or in progress, ignoring the request"); return; } - auto async_callback{[callback = std::move(callback)](bool invert_y) { - std::thread t{callback, invert_y}; + auto async_callback{[callback_ = std::move(callback)](bool invert_y) { + std::thread t{callback_, invert_y}; t.detach(); }}; renderer_settings.screenshot_bits = data; diff --git a/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp b/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp index 23a48c6fe..71f720c63 100644 --- a/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp +++ b/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp @@ -231,24 +231,25 @@ GraphicsPipeline::GraphicsPipeline(const Device& device, TextureCache& texture_c } const bool in_parallel = thread_worker != nullptr; const auto backend = device.GetShaderBackend(); - auto func{[this, sources = std::move(sources), sources_spirv = std::move(sources_spirv), + auto func{[this, sources_ = std::move(sources), sources_spirv_ = std::move(sources_spirv), shader_notify, backend, in_parallel, force_context_flush](ShaderContext::Context*) mutable { for (size_t stage = 0; stage < 5; ++stage) { switch (backend) { case Settings::ShaderBackend::GLSL: - if (!sources[stage].empty()) { - source_programs[stage] = CreateProgram(sources[stage], Stage(stage)); + if (!sources_[stage].empty()) { + source_programs[stage] = CreateProgram(sources_[stage], Stage(stage)); } break; case Settings::ShaderBackend::GLASM: - if (!sources[stage].empty()) { - assembly_programs[stage] = CompileProgram(sources[stage], AssemblyStage(stage)); + if (!sources_[stage].empty()) { + assembly_programs[stage] = + CompileProgram(sources_[stage], AssemblyStage(stage)); } break; case Settings::ShaderBackend::SPIRV: - if (!sources_spirv[stage].empty()) { - source_programs[stage] = CreateProgram(sources_spirv[stage], Stage(stage)); + if (!sources_spirv_[stage].empty()) { + source_programs[stage] = CreateProgram(sources_spirv_[stage], Stage(stage)); } break; } diff --git a/src/video_core/renderer_opengl/gl_shader_cache.cpp b/src/video_core/renderer_opengl/gl_shader_cache.cpp index 0329ed820..7e1d7f92e 100644 --- a/src/video_core/renderer_opengl/gl_shader_cache.cpp +++ b/src/video_core/renderer_opengl/gl_shader_cache.cpp @@ -288,9 +288,9 @@ void ShaderCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading, const auto load_compute{[&](std::ifstream& file, FileEnvironment env) { ComputePipelineKey key; file.read(reinterpret_cast<char*>(&key), sizeof(key)); - queue_work([this, key, env = std::move(env), &state, &callback](Context* ctx) mutable { + queue_work([this, key, env_ = std::move(env), &state, &callback](Context* ctx) mutable { ctx->pools.ReleaseContents(); - auto pipeline{CreateComputePipeline(ctx->pools, key, env, true)}; + auto pipeline{CreateComputePipeline(ctx->pools, key, env_, true)}; std::scoped_lock lock{state.mutex}; if (pipeline) { compute_cache.emplace(key, std::move(pipeline)); @@ -305,9 +305,9 @@ void ShaderCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading, const auto load_graphics{[&](std::ifstream& file, std::vector<FileEnvironment> envs) { GraphicsPipelineKey key; file.read(reinterpret_cast<char*>(&key), sizeof(key)); - queue_work([this, key, envs = std::move(envs), &state, &callback](Context* ctx) mutable { + queue_work([this, key, envs_ = std::move(envs), &state, &callback](Context* ctx) mutable { boost::container::static_vector<Shader::Environment*, 5> env_ptrs; - for (auto& env : envs) { + for (auto& env : envs_) { env_ptrs.push_back(&env); } ctx->pools.ReleaseContents(); diff --git a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp index 51df18ec3..f8cd2a5d8 100644 --- a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp +++ b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp @@ -206,8 +206,8 @@ public: const size_t sub_first_offset = static_cast<size_t>(first % 4) * GetQuadsNum(num_indices); const size_t offset = (sub_first_offset + GetQuadsNum(first)) * 6ULL * BytesPerIndex(index_type); - scheduler.Record([buffer = *buffer, index_type_, offset](vk::CommandBuffer cmdbuf) { - cmdbuf.BindIndexBuffer(buffer, offset, index_type_); + scheduler.Record([buffer_ = *buffer, index_type_, offset](vk::CommandBuffer cmdbuf) { + cmdbuf.BindIndexBuffer(buffer_, offset, index_type_); }); } @@ -528,17 +528,18 @@ void BufferCacheRuntime::BindVertexBuffers(VideoCommon::HostBindings<Buffer>& bi buffer_handles.push_back(handle); } if (device.IsExtExtendedDynamicStateSupported()) { - scheduler.Record([bindings = std::move(bindings), - buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) { - cmdbuf.BindVertexBuffers2EXT( - bindings.min_index, bindings.max_index - bindings.min_index, buffer_handles.data(), - bindings.offsets.data(), bindings.sizes.data(), bindings.strides.data()); + scheduler.Record([bindings_ = std::move(bindings), + buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) { + cmdbuf.BindVertexBuffers2EXT(bindings_.min_index, + bindings_.max_index - bindings_.min_index, + buffer_handles_.data(), bindings_.offsets.data(), + bindings_.sizes.data(), bindings_.strides.data()); }); } else { - scheduler.Record([bindings = std::move(bindings), - buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) { - cmdbuf.BindVertexBuffers(bindings.min_index, bindings.max_index - bindings.min_index, - buffer_handles.data(), bindings.offsets.data()); + scheduler.Record([bindings_ = std::move(bindings), + buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) { + cmdbuf.BindVertexBuffers(bindings_.min_index, bindings_.max_index - bindings_.min_index, + buffer_handles_.data(), bindings_.offsets.data()); }); } } @@ -573,11 +574,11 @@ void BufferCacheRuntime::BindTransformFeedbackBuffers(VideoCommon::HostBindings< for (u32 index = 0; index < bindings.buffers.size(); ++index) { buffer_handles.push_back(bindings.buffers[index]->Handle()); } - scheduler.Record([bindings = std::move(bindings), - buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) { - cmdbuf.BindTransformFeedbackBuffersEXT(0, static_cast<u32>(buffer_handles.size()), - buffer_handles.data(), bindings.offsets.data(), - bindings.sizes.data()); + scheduler.Record([bindings_ = std::move(bindings), + buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) { + cmdbuf.BindTransformFeedbackBuffersEXT(0, static_cast<u32>(buffer_handles_.size()), + buffer_handles_.data(), bindings_.offsets.data(), + bindings_.sizes.data()); }); } diff --git a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp index d600c4e61..4f84d8497 100644 --- a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp +++ b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp @@ -469,9 +469,9 @@ void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading ComputePipelineCacheKey key; file.read(reinterpret_cast<char*>(&key), sizeof(key)); - workers.QueueWork([this, key, env = std::move(env), &state, &callback]() mutable { + workers.QueueWork([this, key, env_ = std::move(env), &state, &callback]() mutable { ShaderPools pools; - auto pipeline{CreateComputePipeline(pools, key, env, state.statistics.get(), false)}; + auto pipeline{CreateComputePipeline(pools, key, env_, state.statistics.get(), false)}; std::scoped_lock lock{state.mutex}; if (pipeline) { compute_cache.emplace(key, std::move(pipeline)); @@ -500,10 +500,10 @@ void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading (key.state.dynamic_vertex_input != 0) != dynamic_features.has_dynamic_vertex_input) { return; } - workers.QueueWork([this, key, envs = std::move(envs), &state, &callback]() mutable { + workers.QueueWork([this, key, envs_ = std::move(envs), &state, &callback]() mutable { ShaderPools pools; boost::container::static_vector<Shader::Environment*, 5> env_ptrs; - for (auto& env : envs) { + for (auto& env : envs_) { env_ptrs.push_back(&env); } auto pipeline{CreateGraphicsPipeline(pools, key, MakeSpan(env_ptrs), @@ -702,8 +702,8 @@ std::unique_ptr<ComputePipeline> PipelineCache::CreateComputePipeline( if (!pipeline || pipeline_cache_filename.empty()) { return pipeline; } - serialization_thread.QueueWork([this, key, env = std::move(env)] { - SerializePipeline(key, std::array<const GenericEnvironment*, 1>{&env}, + serialization_thread.QueueWork([this, key, env_ = std::move(env)] { + SerializePipeline(key, std::array<const GenericEnvironment*, 1>{&env_}, pipeline_cache_filename, CACHE_VERSION); }); return pipeline; diff --git a/src/video_core/renderer_vulkan/vk_query_cache.cpp b/src/video_core/renderer_vulkan/vk_query_cache.cpp index d67490449..29e0b797b 100644 --- a/src/video_core/renderer_vulkan/vk_query_cache.cpp +++ b/src/video_core/renderer_vulkan/vk_query_cache.cpp @@ -98,10 +98,10 @@ HostCounter::HostCounter(QueryCache& cache_, std::shared_ptr<HostCounter> depend : HostCounterBase{std::move(dependency_)}, cache{cache_}, type{type_}, query{cache_.AllocateQuery(type_)}, tick{cache_.GetScheduler().CurrentTick()} { const vk::Device* logical = &cache.GetDevice().GetLogical(); - cache.GetScheduler().Record([logical, query = query](vk::CommandBuffer cmdbuf) { + cache.GetScheduler().Record([logical, query_ = query](vk::CommandBuffer cmdbuf) { const bool use_precise = Settings::IsGPULevelHigh(); - logical->ResetQueryPool(query.first, query.second, 1); - cmdbuf.BeginQuery(query.first, query.second, + logical->ResetQueryPool(query_.first, query_.second, 1); + cmdbuf.BeginQuery(query_.first, query_.second, use_precise ? VK_QUERY_CONTROL_PRECISE_BIT : 0); }); } @@ -111,8 +111,9 @@ HostCounter::~HostCounter() { } void HostCounter::EndQuery() { - cache.GetScheduler().Record( - [query = query](vk::CommandBuffer cmdbuf) { cmdbuf.EndQuery(query.first, query.second); }); + cache.GetScheduler().Record([query_ = query](vk::CommandBuffer cmdbuf) { + cmdbuf.EndQuery(query_.first, query_.second); + }); } u64 HostCounter::BlockingQuery(bool async) const { diff --git a/src/video_core/renderer_vulkan/vk_texture_cache.cpp b/src/video_core/renderer_vulkan/vk_texture_cache.cpp index 3aac3cfab..bf6ad6c79 100644 --- a/src/video_core/renderer_vulkan/vk_texture_cache.cpp +++ b/src/video_core/renderer_vulkan/vk_texture_cache.cpp @@ -1412,7 +1412,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS } scheduler->RequestOutsideRenderPassOperationContext(); scheduler->Record([buffers = std::move(buffers_vector), image = *original_image, - aspect_mask = aspect_mask, vk_copies](vk::CommandBuffer cmdbuf) { + aspect_mask_ = aspect_mask, vk_copies](vk::CommandBuffer cmdbuf) { const VkImageMemoryBarrier read_barrier{ .sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, .pNext = nullptr, @@ -1424,7 +1424,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, .image = image, .subresourceRange{ - .aspectMask = aspect_mask, + .aspectMask = aspect_mask_, .baseMipLevel = 0, .levelCount = VK_REMAINING_MIP_LEVELS, .baseArrayLayer = 0, @@ -1456,7 +1456,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, .image = image, .subresourceRange{ - .aspectMask = aspect_mask, + .aspectMask = aspect_mask_, .baseMipLevel = 0, .levelCount = VK_REMAINING_MIP_LEVELS, .baseArrayLayer = 0, diff --git a/externals/vma/vma.cpp b/src/video_core/vulkan_common/vma.cpp index 1fe2cf52b..1fe2cf52b 100644 --- a/externals/vma/vma.cpp +++ b/src/video_core/vulkan_common/vma.cpp diff --git a/src/web_service/announce_room_json.cpp b/src/web_service/announce_room_json.cpp index 4c3195efd..f1020a5b8 100644 --- a/src/web_service/announce_room_json.cpp +++ b/src/web_service/announce_room_json.cpp @@ -135,11 +135,11 @@ void RoomJson::Delete() { LOG_ERROR(WebService, "Room must be registered to be deleted"); return; } - Common::DetachedTasks::AddTask( - [host{this->host}, username{this->username}, token{this->token}, room_id{this->room_id}]() { - // create a new client here because the this->client might be destroyed. - Client{host, username, token}.DeleteJson(fmt::format("/lobby/{}", room_id), "", false); - }); + Common::DetachedTasks::AddTask([host_{this->host}, username_{this->username}, + token_{this->token}, room_id_{this->room_id}]() { + // create a new client here because the this->client might be destroyed. + Client{host_, username_, token_}.DeleteJson(fmt::format("/lobby/{}", room_id_), "", false); + }); } } // namespace WebService |