summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitmodules2
-rw-r--r--CMakeLists.txt19
-rw-r--r--externals/CMakeLists.txt6
m---------externals/VulkanMemoryAllocator0
-rw-r--r--externals/demangle/ItaniumDemangle.cpp171
-rw-r--r--externals/demangle/llvm/Demangle/Demangle.h37
-rw-r--r--externals/demangle/llvm/Demangle/DemangleConfig.h4
-rw-r--r--externals/demangle/llvm/Demangle/ItaniumDemangle.h3888
-rw-r--r--externals/demangle/llvm/Demangle/ItaniumNodes.def96
-rw-r--r--externals/demangle/llvm/Demangle/StringView.h32
-rw-r--r--externals/demangle/llvm/Demangle/StringViewExtras.h39
-rw-r--r--externals/demangle/llvm/Demangle/Utility.h208
m---------externals/vma/VulkanMemoryAllocator0
-rw-r--r--src/android/app/src/main/res/values/strings.xml2
-rw-r--r--src/common/demangle.cpp2
-rw-r--r--src/common/detached_tasks.cpp4
-rw-r--r--src/common/socket_types.h17
-rw-r--r--src/core/CMakeLists.txt18
-rw-r--r--src/core/hle/kernel/k_thread.cpp4
-rw-r--r--src/core/hle/kernel/kernel.cpp6
-rw-r--r--src/core/hle/service/am/am.cpp4
-rw-r--r--src/core/hle/service/nifm/nifm.cpp1
-rw-r--r--src/core/hle/service/nifm/nifm.h7
-rw-r--r--src/core/hle/service/sockets/bsd.cpp120
-rw-r--r--src/core/hle/service/sockets/bsd.h13
-rw-r--r--src/core/hle/service/sockets/nsd.cpp58
-rw-r--r--src/core/hle/service/sockets/nsd.h4
-rw-r--r--src/core/hle/service/sockets/sfdnsres.cpp388
-rw-r--r--src/core/hle/service/sockets/sfdnsres.h3
-rw-r--r--src/core/hle/service/sockets/sockets.h33
-rw-r--r--src/core/hle/service/sockets/sockets_translate.cpp114
-rw-r--r--src/core/hle/service/sockets/sockets_translate.h17
-rw-r--r--src/core/hle/service/ssl/ssl.cpp353
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h45
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp16
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp351
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp544
-rw-r--r--src/core/hle/service/ssl/ssl_backend_securetransport.cpp222
-rw-r--r--src/core/internal_network/network.cpp286
-rw-r--r--src/core/internal_network/network.h36
-rw-r--r--src/core/internal_network/socket_proxy.cpp23
-rw-r--r--src/core/internal_network/socket_proxy.h12
-rw-r--r--src/core/internal_network/sockets.h19
-rw-r--r--src/video_core/CMakeLists.txt6
-rw-r--r--src/video_core/renderer_base.cpp4
-rw-r--r--src/video_core/renderer_opengl/gl_graphics_pipeline.cpp15
-rw-r--r--src/video_core/renderer_opengl/gl_shader_cache.cpp8
-rw-r--r--src/video_core/renderer_vulkan/vk_buffer_cache.cpp33
-rw-r--r--src/video_core/renderer_vulkan/vk_pipeline_cache.cpp12
-rw-r--r--src/video_core/renderer_vulkan/vk_query_cache.cpp11
-rw-r--r--src/video_core/renderer_vulkan/vk_texture_cache.cpp6
-rw-r--r--src/video_core/vulkan_common/vma.cpp (renamed from externals/vma/vma.cpp)0
-rw-r--r--src/web_service/announce_room_json.cpp10
53 files changed, 4795 insertions, 2534 deletions
diff --git a/.gitmodules b/.gitmodules
index 9f96b70be..361f4845b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -56,5 +56,5 @@
path = externals/nx_tzdb/tzdb_to_nx
url = https://github.com/lat9nq/tzdb_to_nx.git
[submodule "VulkanMemoryAllocator"]
- path = externals/vma/VulkanMemoryAllocator
+ path = externals/VulkanMemoryAllocator
url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7f8febb90..00d540f1f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -63,6 +63,18 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF)
CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF)
+set(DEFAULT_ENABLE_OPENSSL ON)
+if (ANDROID OR WIN32 OR APPLE)
+ # - Windows defaults to the Schannel backend.
+ # - macOS defaults to the SecureTransport backend.
+ # - Android currently has no SSL backend as the NDK doesn't include any SSL
+ # library; a proper 'native' backend would have to go through Java.
+ # But you can force builds for those platforms to use OpenSSL if you have
+ # your own copy of it.
+ set(DEFAULT_ENABLE_OPENSSL OFF)
+endif()
+option(ENABLE_OPENSSL "Enable OpenSSL backend for ISslConnection" ${DEFAULT_ENABLE_OPENSSL})
+
# On Android, fetch and compile libcxx before doing anything else
if (ANDROID)
set(CMAKE_SKIP_INSTALL_RULES ON)
@@ -277,10 +289,11 @@ find_package(Boost 1.79.0 REQUIRED context)
find_package(enet 1.3 MODULE)
find_package(fmt 9 REQUIRED)
find_package(inih 52 MODULE COMPONENTS INIReader)
-find_package(LLVM MODULE COMPONENTS Demangle)
+find_package(LLVM 17 MODULE COMPONENTS Demangle)
find_package(lz4 REQUIRED)
find_package(nlohmann_json 3.8 REQUIRED)
find_package(Opus 1.3 MODULE)
+find_package(VulkanMemoryAllocator CONFIG)
find_package(ZLIB 1.2 REQUIRED)
find_package(zstd 1.5 REQUIRED)
@@ -322,6 +335,10 @@ if (MINGW)
find_library(MSWSOCK_LIBRARY mswsock REQUIRED)
endif()
+if(ENABLE_OPENSSL)
+ find_package(OpenSSL 1.1.1 REQUIRED)
+endif()
+
# Please consider this as a stub
if(ENABLE_QT6 AND Qt6_LOCATION)
list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}")
diff --git a/externals/CMakeLists.txt b/externals/CMakeLists.txt
index 4ff588851..1f7cd598e 100644
--- a/externals/CMakeLists.txt
+++ b/externals/CMakeLists.txt
@@ -144,9 +144,9 @@ endif()
add_subdirectory(nx_tzdb)
# VMA
-add_library(vma vma/vma.cpp)
-target_include_directories(vma PUBLIC ./vma/VulkanMemoryAllocator/include)
-target_link_libraries(vma PRIVATE Vulkan::Headers)
+if (NOT TARGET GPUOpen::VulkanMemoryAllocator)
+ add_subdirectory(VulkanMemoryAllocator)
+endif()
if (NOT TARGET LLVM::Demangle)
add_library(demangle demangle/ItaniumDemangle.cpp)
diff --git a/externals/VulkanMemoryAllocator b/externals/VulkanMemoryAllocator
new file mode 160000
+Subproject 9b0fc3e7b02afe97895eb3e945fe800c3a7485a
diff --git a/externals/demangle/ItaniumDemangle.cpp b/externals/demangle/ItaniumDemangle.cpp
index b055a2fd7..47dd5d301 100644
--- a/externals/demangle/ItaniumDemangle.cpp
+++ b/externals/demangle/ItaniumDemangle.cpp
@@ -20,9 +20,7 @@
#include <cstdlib>
#include <cstring>
#include <functional>
-#include <numeric>
#include <utility>
-#include <vector>
using namespace llvm;
using namespace llvm::itanium_demangle;
@@ -81,8 +79,8 @@ struct DumpVisitor {
}
void printStr(const char *S) { fprintf(stderr, "%s", S); }
- void print(StringView SV) {
- fprintf(stderr, "\"%.*s\"", (int)SV.size(), SV.begin());
+ void print(std::string_view SV) {
+ fprintf(stderr, "\"%.*s\"", (int)SV.size(), SV.data());
}
void print(const Node *N) {
if (N)
@@ -90,14 +88,6 @@ struct DumpVisitor {
else
printStr("<null>");
}
- void print(NodeOrString NS) {
- if (NS.isNode())
- print(NS.asNode());
- else if (NS.isString())
- print(NS.asString());
- else
- printStr("NodeOrString()");
- }
void print(NodeArray A) {
++Depth;
printStr("{");
@@ -116,13 +106,11 @@ struct DumpVisitor {
// Overload used when T is exactly 'bool', not merely convertible to 'bool'.
void print(bool B) { printStr(B ? "true" : "false"); }
- template <class T>
- typename std::enable_if<std::is_unsigned<T>::value>::type print(T N) {
+ template <class T> std::enable_if_t<std::is_unsigned<T>::value> print(T N) {
fprintf(stderr, "%llu", (unsigned long long)N);
}
- template <class T>
- typename std::enable_if<std::is_signed<T>::value>::type print(T N) {
+ template <class T> std::enable_if_t<std::is_signed<T>::value> print(T N) {
fprintf(stderr, "%lld", (long long)N);
}
@@ -185,6 +173,50 @@ struct DumpVisitor {
return printStr("TemplateParamKind::Template");
}
}
+ void print(Node::Prec P) {
+ switch (P) {
+ case Node::Prec::Primary:
+ return printStr("Node::Prec::Primary");
+ case Node::Prec::Postfix:
+ return printStr("Node::Prec::Postfix");
+ case Node::Prec::Unary:
+ return printStr("Node::Prec::Unary");
+ case Node::Prec::Cast:
+ return printStr("Node::Prec::Cast");
+ case Node::Prec::PtrMem:
+ return printStr("Node::Prec::PtrMem");
+ case Node::Prec::Multiplicative:
+ return printStr("Node::Prec::Multiplicative");
+ case Node::Prec::Additive:
+ return printStr("Node::Prec::Additive");
+ case Node::Prec::Shift:
+ return printStr("Node::Prec::Shift");
+ case Node::Prec::Spaceship:
+ return printStr("Node::Prec::Spaceship");
+ case Node::Prec::Relational:
+ return printStr("Node::Prec::Relational");
+ case Node::Prec::Equality:
+ return printStr("Node::Prec::Equality");
+ case Node::Prec::And:
+ return printStr("Node::Prec::And");
+ case Node::Prec::Xor:
+ return printStr("Node::Prec::Xor");
+ case Node::Prec::Ior:
+ return printStr("Node::Prec::Ior");
+ case Node::Prec::AndIf:
+ return printStr("Node::Prec::AndIf");
+ case Node::Prec::OrIf:
+ return printStr("Node::Prec::OrIf");
+ case Node::Prec::Conditional:
+ return printStr("Node::Prec::Conditional");
+ case Node::Prec::Assign:
+ return printStr("Node::Prec::Assign");
+ case Node::Prec::Comma:
+ return printStr("Node::Prec::Comma");
+ case Node::Prec::Default:
+ return printStr("Node::Prec::Default");
+ }
+ }
void newLine() {
printStr("\n");
@@ -334,36 +366,21 @@ public:
using Demangler = itanium_demangle::ManglingParser<DefaultAllocator>;
-char *llvm::itaniumDemangle(const char *MangledName, char *Buf,
- size_t *N, int *Status) {
- if (MangledName == nullptr || (Buf != nullptr && N == nullptr)) {
- if (Status)
- *Status = demangle_invalid_args;
+char *llvm::itaniumDemangle(std::string_view MangledName) {
+ if (MangledName.empty())
return nullptr;
- }
-
- int InternalStatus = demangle_success;
- Demangler Parser(MangledName, MangledName + std::strlen(MangledName));
- OutputStream S;
+ Demangler Parser(MangledName.data(),
+ MangledName.data() + MangledName.length());
Node *AST = Parser.parse();
+ if (!AST)
+ return nullptr;
- if (AST == nullptr)
- InternalStatus = demangle_invalid_mangled_name;
- else if (!initializeOutputStream(Buf, N, S, 1024))
- InternalStatus = demangle_memory_alloc_failure;
- else {
- assert(Parser.ForwardTemplateRefs.empty());
- AST->print(S);
- S += '\0';
- if (N != nullptr)
- *N = S.getCurrentPosition();
- Buf = S.getBuffer();
- }
-
- if (Status)
- *Status = InternalStatus;
- return InternalStatus == demangle_success ? Buf : nullptr;
+ OutputBuffer OB;
+ assert(Parser.ForwardTemplateRefs.empty());
+ AST->print(OB);
+ OB += '\0';
+ return OB.getBuffer();
}
ItaniumPartialDemangler::ItaniumPartialDemangler()
@@ -396,14 +413,12 @@ bool ItaniumPartialDemangler::partialDemangle(const char *MangledName) {
}
static char *printNode(const Node *RootNode, char *Buf, size_t *N) {
- OutputStream S;
- if (!initializeOutputStream(Buf, N, S, 128))
- return nullptr;
- RootNode->print(S);
- S += '\0';
+ OutputBuffer OB(Buf, N);
+ RootNode->print(OB);
+ OB += '\0';
if (N != nullptr)
- *N = S.getCurrentPosition();
- return S.getBuffer();
+ *N = OB.getCurrentPosition();
+ return OB.getBuffer();
}
char *ItaniumPartialDemangler::getFunctionBaseName(char *Buf, size_t *N) const {
@@ -417,8 +432,8 @@ char *ItaniumPartialDemangler::getFunctionBaseName(char *Buf, size_t *N) const {
case Node::KAbiTagAttr:
Name = static_cast<const AbiTagAttr *>(Name)->Base;
continue;
- case Node::KStdQualifiedName:
- Name = static_cast<const StdQualifiedName *>(Name)->Child;
+ case Node::KModuleEntity:
+ Name = static_cast<const ModuleEntity *>(Name)->Name;
continue;
case Node::KNestedName:
Name = static_cast<const NestedName *>(Name)->Name;
@@ -441,9 +456,7 @@ char *ItaniumPartialDemangler::getFunctionDeclContextName(char *Buf,
return nullptr;
const Node *Name = static_cast<const FunctionEncoding *>(RootNode)->getName();
- OutputStream S;
- if (!initializeOutputStream(Buf, N, S, 128))
- return nullptr;
+ OutputBuffer OB(Buf, N);
KeepGoingLocalFunction:
while (true) {
@@ -458,27 +471,27 @@ char *ItaniumPartialDemangler::getFunctionDeclContextName(char *Buf,
break;
}
+ if (Name->getKind() == Node::KModuleEntity)
+ Name = static_cast<const ModuleEntity *>(Name)->Name;
+
switch (Name->getKind()) {
- case Node::KStdQualifiedName:
- S += "std";
- break;
case Node::KNestedName:
- static_cast<const NestedName *>(Name)->Qual->print(S);
+ static_cast<const NestedName *>(Name)->Qual->print(OB);
break;
case Node::KLocalName: {
auto *LN = static_cast<const LocalName *>(Name);
- LN->Encoding->print(S);
- S += "::";
+ LN->Encoding->print(OB);
+ OB += "::";
Name = LN->Entity;
goto KeepGoingLocalFunction;
}
default:
break;
}
- S += '\0';
+ OB += '\0';
if (N != nullptr)
- *N = S.getCurrentPosition();
- return S.getBuffer();
+ *N = OB.getCurrentPosition();
+ return OB.getBuffer();
}
char *ItaniumPartialDemangler::getFunctionName(char *Buf, size_t *N) const {
@@ -494,17 +507,15 @@ char *ItaniumPartialDemangler::getFunctionParameters(char *Buf,
return nullptr;
NodeArray Params = static_cast<FunctionEncoding *>(RootNode)->getParams();
- OutputStream S;
- if (!initializeOutputStream(Buf, N, S, 128))
- return nullptr;
+ OutputBuffer OB(Buf, N);
- S += '(';
- Params.printWithComma(S);
- S += ')';
- S += '\0';
+ OB += '(';
+ Params.printWithComma(OB);
+ OB += ')';
+ OB += '\0';
if (N != nullptr)
- *N = S.getCurrentPosition();
- return S.getBuffer();
+ *N = OB.getCurrentPosition();
+ return OB.getBuffer();
}
char *ItaniumPartialDemangler::getFunctionReturnType(
@@ -512,18 +523,16 @@ char *ItaniumPartialDemangler::getFunctionReturnType(
if (!isFunction())
return nullptr;
- OutputStream S;
- if (!initializeOutputStream(Buf, N, S, 128))
- return nullptr;
+ OutputBuffer OB(Buf, N);
if (const Node *Ret =
static_cast<const FunctionEncoding *>(RootNode)->getReturnType())
- Ret->print(S);
+ Ret->print(OB);
- S += '\0';
+ OB += '\0';
if (N != nullptr)
- *N = S.getCurrentPosition();
- return S.getBuffer();
+ *N = OB.getCurrentPosition();
+ return OB.getBuffer();
}
char *ItaniumPartialDemangler::finishDemangle(char *Buf, size_t *N) const {
@@ -563,8 +572,8 @@ bool ItaniumPartialDemangler::isCtorOrDtor() const {
case Node::KNestedName:
N = static_cast<const NestedName *>(N)->Name;
break;
- case Node::KStdQualifiedName:
- N = static_cast<const StdQualifiedName *>(N)->Child;
+ case Node::KModuleEntity:
+ N = static_cast<const ModuleEntity *>(N)->Name;
break;
}
}
diff --git a/externals/demangle/llvm/Demangle/Demangle.h b/externals/demangle/llvm/Demangle/Demangle.h
index 5b673e4e1..1552a501a 100644
--- a/externals/demangle/llvm/Demangle/Demangle.h
+++ b/externals/demangle/llvm/Demangle/Demangle.h
@@ -12,6 +12,7 @@
#include <cstddef>
#include <string>
+#include <string_view>
namespace llvm {
/// This is a llvm local version of __cxa_demangle. Other than the name and
@@ -29,9 +30,10 @@ enum : int {
demangle_success = 0,
};
-char *itaniumDemangle(const char *mangled_name, char *buf, size_t *n,
- int *status);
-
+/// Returns a non-NULL pointer to a NUL-terminated C style string
+/// that should be explicitly freed, if successful. Otherwise, may return
+/// nullptr if mangled_name is not a valid mangling or is nullptr.
+char *itaniumDemangle(std::string_view mangled_name);
enum MSDemangleFlags {
MSDF_None = 0,
@@ -40,10 +42,34 @@ enum MSDemangleFlags {
MSDF_NoCallingConvention = 1 << 2,
MSDF_NoReturnType = 1 << 3,
MSDF_NoMemberType = 1 << 4,
+ MSDF_NoVariableType = 1 << 5,
};
-char *microsoftDemangle(const char *mangled_name, char *buf, size_t *n,
+
+/// Demangles the Microsoft symbol pointed at by mangled_name and returns it.
+/// Returns a pointer to the start of a null-terminated demangled string on
+/// success, or nullptr on error.
+/// If n_read is non-null and demangling was successful, it receives how many
+/// bytes of the input string were consumed.
+/// status receives one of the demangle_ enum entries above if it's not nullptr.
+/// Flags controls various details of the demangled representation.
+char *microsoftDemangle(std::string_view mangled_name, size_t *n_read,
int *status, MSDemangleFlags Flags = MSDF_None);
+// Demangles a Rust v0 mangled symbol.
+char *rustDemangle(std::string_view MangledName);
+
+// Demangles a D mangled symbol.
+char *dlangDemangle(std::string_view MangledName);
+
+/// Attempt to demangle a string using different demangling schemes.
+/// The function uses heuristics to determine which demangling scheme to use.
+/// \param MangledName - reference to string to demangle.
+/// \returns - the demangled string, or a copy of the input string if no
+/// demangling occurred.
+std::string demangle(std::string_view MangledName);
+
+bool nonMicrosoftDemangle(std::string_view MangledName, std::string &Result);
+
/// "Partial" demangler. This supports demangling a string into an AST
/// (typically an intermediate stage in itaniumDemangle) and querying certain
/// properties or partially printing the demangled name.
@@ -59,7 +85,7 @@ struct ItaniumPartialDemangler {
bool partialDemangle(const char *MangledName);
/// Just print the entire mangled name into Buf. Buf and N behave like the
- /// second and third parameters to itaniumDemangle.
+ /// second and third parameters to __cxa_demangle.
char *finishDemangle(char *Buf, size_t *N) const;
/// Get the base name of a function. This doesn't include trailing template
@@ -95,6 +121,7 @@ struct ItaniumPartialDemangler {
bool isSpecialName() const;
~ItaniumPartialDemangler();
+
private:
void *RootNode;
void *Context;
diff --git a/externals/demangle/llvm/Demangle/DemangleConfig.h b/externals/demangle/llvm/Demangle/DemangleConfig.h
index a8aef9df1..c7f86d766 100644
--- a/externals/demangle/llvm/Demangle/DemangleConfig.h
+++ b/externals/demangle/llvm/Demangle/DemangleConfig.h
@@ -13,8 +13,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef LLVM_DEMANGLE_COMPILER_H
-#define LLVM_DEMANGLE_COMPILER_H
+#ifndef LLVM_DEMANGLE_DEMANGLECONFIG_H
+#define LLVM_DEMANGLE_DEMANGLECONFIG_H
#ifndef __has_feature
#define __has_feature(x) 0
diff --git a/externals/demangle/llvm/Demangle/ItaniumDemangle.h b/externals/demangle/llvm/Demangle/ItaniumDemangle.h
index 64b35c142..0dc3d7337 100644
--- a/externals/demangle/llvm/Demangle/ItaniumDemangle.h
+++ b/externals/demangle/llvm/Demangle/ItaniumDemangle.h
@@ -1,5 +1,5 @@
-//===------------------------- ItaniumDemangle.h ----------------*- C++ -*-===//
-//
+//===--- ItaniumDemangle.h -----------*- mode:c++;eval:(read-only-mode) -*-===//
+// Do not edit! See README.txt.
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-FileCopyrightText: Part of the LLVM Project
@@ -7,144 +7,220 @@
//
//===----------------------------------------------------------------------===//
//
-// Generic itanium demangler library. This file has two byte-per-byte identical
-// copies in the source tree, one in libcxxabi, and the other in llvm.
+// Generic itanium demangler library.
+// There are two copies of this file in the source tree. The one under
+// libcxxabi is the original and the one under llvm is the copy. Use
+// cp-to-llvm.sh to update the copy. See README.txt for more details.
//
//===----------------------------------------------------------------------===//
#ifndef DEMANGLE_ITANIUMDEMANGLE_H
#define DEMANGLE_ITANIUMDEMANGLE_H
-// FIXME: (possibly) incomplete list of features that clang mangles that this
-// file does not yet support:
-// - C++ modules TS
-
#include "DemangleConfig.h"
-#include "StringView.h"
+#include "StringViewExtras.h"
#include "Utility.h"
+#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
-#include <numeric>
+#include <limits>
+#include <new>
+#include <string_view>
+#include <type_traits>
#include <utility>
-#define FOR_EACH_NODE_KIND(X) \
- X(NodeArrayNode) \
- X(DotSuffix) \
- X(VendorExtQualType) \
- X(QualType) \
- X(ConversionOperatorType) \
- X(PostfixQualifiedType) \
- X(ElaboratedTypeSpefType) \
- X(NameType) \
- X(AbiTagAttr) \
- X(EnableIfAttr) \
- X(ObjCProtoName) \
- X(PointerType) \
- X(ReferenceType) \
- X(PointerToMemberType) \
- X(ArrayType) \
- X(FunctionType) \
- X(NoexceptSpec) \
- X(DynamicExceptionSpec) \
- X(FunctionEncoding) \
- X(LiteralOperator) \
- X(SpecialName) \
- X(CtorVtableSpecialName) \
- X(QualifiedName) \
- X(NestedName) \
- X(LocalName) \
- X(VectorType) \
- X(PixelVectorType) \
- X(SyntheticTemplateParamName) \
- X(TypeTemplateParamDecl) \
- X(NonTypeTemplateParamDecl) \
- X(TemplateTemplateParamDecl) \
- X(TemplateParamPackDecl) \
- X(ParameterPack) \
- X(TemplateArgumentPack) \
- X(ParameterPackExpansion) \
- X(TemplateArgs) \
- X(ForwardTemplateReference) \
- X(NameWithTemplateArgs) \
- X(GlobalQualifiedName) \
- X(StdQualifiedName) \
- X(ExpandedSpecialSubstitution) \
- X(SpecialSubstitution) \
- X(CtorDtorName) \
- X(DtorName) \
- X(UnnamedTypeName) \
- X(ClosureTypeName) \
- X(StructuredBindingName) \
- X(BinaryExpr) \
- X(ArraySubscriptExpr) \
- X(PostfixExpr) \
- X(ConditionalExpr) \
- X(MemberExpr) \
- X(EnclosingExpr) \
- X(CastExpr) \
- X(SizeofParamPackExpr) \
- X(CallExpr) \
- X(NewExpr) \
- X(DeleteExpr) \
- X(PrefixExpr) \
- X(FunctionParam) \
- X(ConversionExpr) \
- X(InitListExpr) \
- X(FoldExpr) \
- X(ThrowExpr) \
- X(UUIDOfExpr) \
- X(BoolExpr) \
- X(StringLiteral) \
- X(LambdaExpr) \
- X(IntegerCastExpr) \
- X(IntegerLiteral) \
- X(FloatLiteral) \
- X(DoubleLiteral) \
- X(LongDoubleLiteral) \
- X(BracedExpr) \
- X(BracedRangeExpr)
-
DEMANGLE_NAMESPACE_BEGIN
+template <class T, size_t N> class PODSmallVector {
+ static_assert(std::is_pod<T>::value,
+ "T is required to be a plain old data type");
+
+ T *First = nullptr;
+ T *Last = nullptr;
+ T *Cap = nullptr;
+ T Inline[N] = {0};
+
+ bool isInline() const { return First == Inline; }
+
+ void clearInline() {
+ First = Inline;
+ Last = Inline;
+ Cap = Inline + N;
+ }
+
+ void reserve(size_t NewCap) {
+ size_t S = size();
+ if (isInline()) {
+ auto *Tmp = static_cast<T *>(std::malloc(NewCap * sizeof(T)));
+ if (Tmp == nullptr)
+ std::terminate();
+ std::copy(First, Last, Tmp);
+ First = Tmp;
+ } else {
+ First = static_cast<T *>(std::realloc(First, NewCap * sizeof(T)));
+ if (First == nullptr)
+ std::terminate();
+ }
+ Last = First + S;
+ Cap = First + NewCap;
+ }
+
+public:
+ PODSmallVector() : First(Inline), Last(First), Cap(Inline + N) {}
+
+ PODSmallVector(const PODSmallVector &) = delete;
+ PODSmallVector &operator=(const PODSmallVector &) = delete;
+
+ PODSmallVector(PODSmallVector &&Other) : PODSmallVector() {
+ if (Other.isInline()) {
+ std::copy(Other.begin(), Other.end(), First);
+ Last = First + Other.size();
+ Other.clear();
+ return;
+ }
+
+ First = Other.First;
+ Last = Other.Last;
+ Cap = Other.Cap;
+ Other.clearInline();
+ }
+
+ PODSmallVector &operator=(PODSmallVector &&Other) {
+ if (Other.isInline()) {
+ if (!isInline()) {
+ std::free(First);
+ clearInline();
+ }
+ std::copy(Other.begin(), Other.end(), First);
+ Last = First + Other.size();
+ Other.clear();
+ return *this;
+ }
+
+ if (isInline()) {
+ First = Other.First;
+ Last = Other.Last;
+ Cap = Other.Cap;
+ Other.clearInline();
+ return *this;
+ }
+
+ std::swap(First, Other.First);
+ std::swap(Last, Other.Last);
+ std::swap(Cap, Other.Cap);
+ Other.clear();
+ return *this;
+ }
+
+ // NOLINTNEXTLINE(readability-identifier-naming)
+ void push_back(const T &Elem) {
+ if (Last == Cap)
+ reserve(size() * 2);
+ *Last++ = Elem;
+ }
+
+ // NOLINTNEXTLINE(readability-identifier-naming)
+ void pop_back() {
+ assert(Last != First && "Popping empty vector!");
+ --Last;
+ }
+
+ void dropBack(size_t Index) {
+ assert(Index <= size() && "dropBack() can't expand!");
+ Last = First + Index;
+ }
+
+ T *begin() { return First; }
+ T *end() { return Last; }
+
+ bool empty() const { return First == Last; }
+ size_t size() const { return static_cast<size_t>(Last - First); }
+ T &back() {
+ assert(Last != First && "Calling back() on empty vector!");
+ return *(Last - 1);
+ }
+ T &operator[](size_t Index) {
+ assert(Index < size() && "Invalid access!");
+ return *(begin() + Index);
+ }
+ void clear() { Last = First; }
+
+ ~PODSmallVector() {
+ if (!isInline())
+ std::free(First);
+ }
+};
+
// Base class of all AST nodes. The AST is built by the parser, then is
// traversed by the printLeft/Right functions to produce a demangled string.
class Node {
public:
enum Kind : unsigned char {
-#define ENUMERATOR(NodeKind) K ## NodeKind,
- FOR_EACH_NODE_KIND(ENUMERATOR)
-#undef ENUMERATOR
+#define NODE(NodeKind) K##NodeKind,
+#include "ItaniumNodes.def"
};
/// Three-way bool to track a cached value. Unknown is possible if this node
/// has an unexpanded parameter pack below it that may affect this cache.
enum class Cache : unsigned char { Yes, No, Unknown, };
+ /// Operator precedence for expression nodes. Used to determine required
+ /// parens in expression emission.
+ enum class Prec {
+ Primary,
+ Postfix,
+ Unary,
+ Cast,
+ PtrMem,
+ Multiplicative,
+ Additive,
+ Shift,
+ Spaceship,
+ Relational,
+ Equality,
+ And,
+ Xor,
+ Ior,
+ AndIf,
+ OrIf,
+ Conditional,
+ Assign,
+ Comma,
+ Default,
+ };
+
private:
Kind K;
+ Prec Precedence : 6;
+
// FIXME: Make these protected.
public:
/// Tracks if this node has a component on its right side, in which case we
/// need to call printRight.
- Cache RHSComponentCache;
+ Cache RHSComponentCache : 2;
/// Track if this node is a (possibly qualified) array type. This can affect
/// how we format the output string.
- Cache ArrayCache;
+ Cache ArrayCache : 2;
/// Track if this node is a (possibly qualified) function type. This can
/// affect how we format the output string.
- Cache FunctionCache;
+ Cache FunctionCache : 2;
public:
- Node(Kind K_, Cache RHSComponentCache_ = Cache::No,
- Cache ArrayCache_ = Cache::No, Cache FunctionCache_ = Cache::No)
- : K(K_), RHSComponentCache(RHSComponentCache_), ArrayCache(ArrayCache_),
- FunctionCache(FunctionCache_) {}
+ Node(Kind K_, Prec Precedence_ = Prec::Primary,
+ Cache RHSComponentCache_ = Cache::No, Cache ArrayCache_ = Cache::No,
+ Cache FunctionCache_ = Cache::No)
+ : K(K_), Precedence(Precedence_), RHSComponentCache(RHSComponentCache_),
+ ArrayCache(ArrayCache_), FunctionCache(FunctionCache_) {}
+ Node(Kind K_, Cache RHSComponentCache_, Cache ArrayCache_ = Cache::No,
+ Cache FunctionCache_ = Cache::No)
+ : Node(K_, Prec::Primary, RHSComponentCache_, ArrayCache_,
+ FunctionCache_) {}
/// Visit the most-derived object corresponding to this object.
template<typename Fn> void visit(Fn F) const;
@@ -155,52 +231,65 @@ public:
// would construct an equivalent node.
//template<typename Fn> void match(Fn F) const;
- bool hasRHSComponent(OutputStream &S) const {
+ bool hasRHSComponent(OutputBuffer &OB) const {
if (RHSComponentCache != Cache::Unknown)
return RHSComponentCache == Cache::Yes;
- return hasRHSComponentSlow(S);
+ return hasRHSComponentSlow(OB);
}
- bool hasArray(OutputStream &S) const {
+ bool hasArray(OutputBuffer &OB) const {
if (ArrayCache != Cache::Unknown)
return ArrayCache == Cache::Yes;
- return hasArraySlow(S);
+ return hasArraySlow(OB);
}
- bool hasFunction(OutputStream &S) const {
+ bool hasFunction(OutputBuffer &OB) const {
if (FunctionCache != Cache::Unknown)
return FunctionCache == Cache::Yes;
- return hasFunctionSlow(S);
+ return hasFunctionSlow(OB);
}
Kind getKind() const { return K; }
- virtual bool hasRHSComponentSlow(OutputStream &) const { return false; }
- virtual bool hasArraySlow(OutputStream &) const { return false; }
- virtual bool hasFunctionSlow(OutputStream &) const { return false; }
+ Prec getPrecedence() const { return Precedence; }
+
+ virtual bool hasRHSComponentSlow(OutputBuffer &) const { return false; }
+ virtual bool hasArraySlow(OutputBuffer &) const { return false; }
+ virtual bool hasFunctionSlow(OutputBuffer &) const { return false; }
// Dig through "glue" nodes like ParameterPack and ForwardTemplateReference to
// get at a node that actually represents some concrete syntax.
- virtual const Node *getSyntaxNode(OutputStream &) const {
- return this;
- }
-
- void print(OutputStream &S) const {
- printLeft(S);
+ virtual const Node *getSyntaxNode(OutputBuffer &) const { return this; }
+
+ // Print this node as an expression operand, surrounding it in parentheses if
+ // its precedence is [Strictly] weaker than P.
+ void printAsOperand(OutputBuffer &OB, Prec P = Prec::Default,
+ bool StrictlyWorse = false) const {
+ bool Paren =
+ unsigned(getPrecedence()) >= unsigned(P) + unsigned(StrictlyWorse);
+ if (Paren)
+ OB.printOpen();
+ print(OB);
+ if (Paren)
+ OB.printClose();
+ }
+
+ void print(OutputBuffer &OB) const {
+ printLeft(OB);
if (RHSComponentCache != Cache::No)
- printRight(S);
+ printRight(OB);
}
- // Print the "left" side of this Node into OutputStream.
- virtual void printLeft(OutputStream &) const = 0;
+ // Print the "left" side of this Node into OutputBuffer.
+ virtual void printLeft(OutputBuffer &) const = 0;
// Print the "right". This distinction is necessary to represent C++ types
// that appear on the RHS of their subtype, such as arrays or functions.
// Since most types don't have such a component, provide a default
// implementation.
- virtual void printRight(OutputStream &) const {}
+ virtual void printRight(OutputBuffer &) const {}
- virtual StringView getBaseName() const { return StringView(); }
+ virtual std::string_view getBaseName() const { return {}; }
// Silence compiler warnings, this dtor will never be called.
virtual ~Node() = default;
@@ -227,19 +316,19 @@ public:
Node *operator[](size_t Idx) const { return Elements[Idx]; }
- void printWithComma(OutputStream &S) const {
+ void printWithComma(OutputBuffer &OB) const {
bool FirstElement = true;
for (size_t Idx = 0; Idx != NumElements; ++Idx) {
- size_t BeforeComma = S.getCurrentPosition();
+ size_t BeforeComma = OB.getCurrentPosition();
if (!FirstElement)
- S += ", ";
- size_t AfterComma = S.getCurrentPosition();
- Elements[Idx]->print(S);
+ OB += ", ";
+ size_t AfterComma = OB.getCurrentPosition();
+ Elements[Idx]->printAsOperand(OB, Node::Prec::Comma);
// Elements[Idx] is an empty parameter pack expansion, we should erase the
// comma we just printed.
- if (AfterComma == S.getCurrentPosition()) {
- S.setCurrentPosition(BeforeComma);
+ if (AfterComma == OB.getCurrentPosition()) {
+ OB.setCurrentPosition(BeforeComma);
continue;
}
@@ -254,43 +343,48 @@ struct NodeArrayNode : Node {
template<typename Fn> void match(Fn F) const { F(Array); }
- void printLeft(OutputStream &S) const override {
- Array.printWithComma(S);
- }
+ void printLeft(OutputBuffer &OB) const override { Array.printWithComma(OB); }
};
class DotSuffix final : public Node {
const Node *Prefix;
- const StringView Suffix;
+ const std::string_view Suffix;
public:
- DotSuffix(const Node *Prefix_, StringView Suffix_)
+ DotSuffix(const Node *Prefix_, std::string_view Suffix_)
: Node(KDotSuffix), Prefix(Prefix_), Suffix(Suffix_) {}
template<typename Fn> void match(Fn F) const { F(Prefix, Suffix); }
- void printLeft(OutputStream &s) const override {
- Prefix->print(s);
- s += " (";
- s += Suffix;
- s += ")";
+ void printLeft(OutputBuffer &OB) const override {
+ Prefix->print(OB);
+ OB += " (";
+ OB += Suffix;
+ OB += ")";
}
};
class VendorExtQualType final : public Node {
const Node *Ty;
- StringView Ext;
+ std::string_view Ext;
+ const Node *TA;
public:
- VendorExtQualType(const Node *Ty_, StringView Ext_)
- : Node(KVendorExtQualType), Ty(Ty_), Ext(Ext_) {}
+ VendorExtQualType(const Node *Ty_, std::string_view Ext_, const Node *TA_)
+ : Node(KVendorExtQualType), Ty(Ty_), Ext(Ext_), TA(TA_) {}
+
+ const Node *getTy() const { return Ty; }
+ std::string_view getExt() const { return Ext; }
+ const Node *getTA() const { return TA; }
- template<typename Fn> void match(Fn F) const { F(Ty, Ext); }
+ template <typename Fn> void match(Fn F) const { F(Ty, Ext, TA); }
- void printLeft(OutputStream &S) const override {
- Ty->print(S);
- S += " ";
- S += Ext;
+ void printLeft(OutputBuffer &OB) const override {
+ Ty->print(OB);
+ OB += " ";
+ OB += Ext;
+ if (TA != nullptr)
+ TA->print(OB);
}
};
@@ -316,13 +410,13 @@ protected:
const Qualifiers Quals;
const Node *Child;
- void printQuals(OutputStream &S) const {
+ void printQuals(OutputBuffer &OB) const {
if (Quals & QualConst)
- S += " const";
+ OB += " const";
if (Quals & QualVolatile)
- S += " volatile";
+ OB += " volatile";
if (Quals & QualRestrict)
- S += " restrict";
+ OB += " restrict";
}
public:
@@ -331,24 +425,27 @@ public:
Child_->ArrayCache, Child_->FunctionCache),
Quals(Quals_), Child(Child_) {}
+ Qualifiers getQuals() const { return Quals; }
+ const Node *getChild() const { return Child; }
+
template<typename Fn> void match(Fn F) const { F(Child, Quals); }
- bool hasRHSComponentSlow(OutputStream &S) const override {
- return Child->hasRHSComponent(S);
+ bool hasRHSComponentSlow(OutputBuffer &OB) const override {
+ return Child->hasRHSComponent(OB);
}
- bool hasArraySlow(OutputStream &S) const override {
- return Child->hasArray(S);
+ bool hasArraySlow(OutputBuffer &OB) const override {
+ return Child->hasArray(OB);
}
- bool hasFunctionSlow(OutputStream &S) const override {
- return Child->hasFunction(S);
+ bool hasFunctionSlow(OutputBuffer &OB) const override {
+ return Child->hasFunction(OB);
}
- void printLeft(OutputStream &S) const override {
- Child->printLeft(S);
- printQuals(S);
+ void printLeft(OutputBuffer &OB) const override {
+ Child->printLeft(OB);
+ printQuals(OB);
}
- void printRight(OutputStream &S) const override { Child->printRight(S); }
+ void printRight(OutputBuffer &OB) const override { Child->printRight(OB); }
};
class ConversionOperatorType final : public Node {
@@ -360,74 +457,96 @@ public:
template<typename Fn> void match(Fn F) const { F(Ty); }
- void printLeft(OutputStream &S) const override {
- S += "operator ";
- Ty->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "operator ";
+ Ty->print(OB);
}
};
class PostfixQualifiedType final : public Node {
const Node *Ty;
- const StringView Postfix;
+ const std::string_view Postfix;
public:
- PostfixQualifiedType(Node *Ty_, StringView Postfix_)
+ PostfixQualifiedType(const Node *Ty_, std::string_view Postfix_)
: Node(KPostfixQualifiedType), Ty(Ty_), Postfix(Postfix_) {}
template<typename Fn> void match(Fn F) const { F(Ty, Postfix); }
- void printLeft(OutputStream &s) const override {
- Ty->printLeft(s);
- s += Postfix;
+ void printLeft(OutputBuffer &OB) const override {
+ Ty->printLeft(OB);
+ OB += Postfix;
}
};
class NameType final : public Node {
- const StringView Name;
+ const std::string_view Name;
public:
- NameType(StringView Name_) : Node(KNameType), Name(Name_) {}
+ NameType(std::string_view Name_) : Node(KNameType), Name(Name_) {}
template<typename Fn> void match(Fn F) const { F(Name); }
- StringView getName() const { return Name; }
- StringView getBaseName() const override { return Name; }
+ std::string_view getName() const { return Name; }
+ std::string_view getBaseName() const override { return Name; }
+
+ void printLeft(OutputBuffer &OB) const override { OB += Name; }
+};
+
+class BitIntType final : public Node {
+ const Node *Size;
+ bool Signed;
+
+public:
+ BitIntType(const Node *Size_, bool Signed_)
+ : Node(KBitIntType), Size(Size_), Signed(Signed_) {}
+
+ template <typename Fn> void match(Fn F) const { F(Size, Signed); }
- void printLeft(OutputStream &s) const override { s += Name; }
+ void printLeft(OutputBuffer &OB) const override {
+ if (!Signed)
+ OB += "unsigned ";
+ OB += "_BitInt";
+ OB.printOpen();
+ Size->printAsOperand(OB);
+ OB.printClose();
+ }
};
class ElaboratedTypeSpefType : public Node {
- StringView Kind;
+ std::string_view Kind;
Node *Child;
public:
- ElaboratedTypeSpefType(StringView Kind_, Node *Child_)
+ ElaboratedTypeSpefType(std::string_view Kind_, Node *Child_)
: Node(KElaboratedTypeSpefType), Kind(Kind_), Child(Child_) {}
template<typename Fn> void match(Fn F) const { F(Kind, Child); }
- void printLeft(OutputStream &S) const override {
- S += Kind;
- S += ' ';
- Child->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += Kind;
+ OB += ' ';
+ Child->print(OB);
}
};
struct AbiTagAttr : Node {
Node *Base;
- StringView Tag;
+ std::string_view Tag;
- AbiTagAttr(Node* Base_, StringView Tag_)
- : Node(KAbiTagAttr, Base_->RHSComponentCache,
- Base_->ArrayCache, Base_->FunctionCache),
+ AbiTagAttr(Node *Base_, std::string_view Tag_)
+ : Node(KAbiTagAttr, Base_->RHSComponentCache, Base_->ArrayCache,
+ Base_->FunctionCache),
Base(Base_), Tag(Tag_) {}
template<typename Fn> void match(Fn F) const { F(Base, Tag); }
- void printLeft(OutputStream &S) const override {
- Base->printLeft(S);
- S += "[abi:";
- S += Tag;
- S += "]";
+ std::string_view getBaseName() const override { return Base->getBaseName(); }
+
+ void printLeft(OutputBuffer &OB) const override {
+ Base->printLeft(OB);
+ OB += "[abi:";
+ OB += Tag;
+ OB += "]";
}
};
@@ -439,21 +558,21 @@ public:
template<typename Fn> void match(Fn F) const { F(Conditions); }
- void printLeft(OutputStream &S) const override {
- S += " [enable_if:";
- Conditions.printWithComma(S);
- S += ']';
+ void printLeft(OutputBuffer &OB) const override {
+ OB += " [enable_if:";
+ Conditions.printWithComma(OB);
+ OB += ']';
}
};
class ObjCProtoName : public Node {
const Node *Ty;
- StringView Protocol;
+ std::string_view Protocol;
friend class PointerType;
public:
- ObjCProtoName(const Node *Ty_, StringView Protocol_)
+ ObjCProtoName(const Node *Ty_, std::string_view Protocol_)
: Node(KObjCProtoName), Ty(Ty_), Protocol(Protocol_) {}
template<typename Fn> void match(Fn F) const { F(Ty, Protocol); }
@@ -463,11 +582,11 @@ public:
static_cast<const NameType *>(Ty)->getName() == "objc_object";
}
- void printLeft(OutputStream &S) const override {
- Ty->print(S);
- S += "<";
- S += Protocol;
- S += ">";
+ void printLeft(OutputBuffer &OB) const override {
+ Ty->print(OB);
+ OB += "<";
+ OB += Protocol;
+ OB += ">";
}
};
@@ -479,36 +598,38 @@ public:
: Node(KPointerType, Pointee_->RHSComponentCache),
Pointee(Pointee_) {}
+ const Node *getPointee() const { return Pointee; }
+
template<typename Fn> void match(Fn F) const { F(Pointee); }
- bool hasRHSComponentSlow(OutputStream &S) const override {
- return Pointee->hasRHSComponent(S);
+ bool hasRHSComponentSlow(OutputBuffer &OB) const override {
+ return Pointee->hasRHSComponent(OB);
}
- void printLeft(OutputStream &s) const override {
+ void printLeft(OutputBuffer &OB) const override {
// We rewrite objc_object<SomeProtocol>* into id<SomeProtocol>.
if (Pointee->getKind() != KObjCProtoName ||
!static_cast<const ObjCProtoName *>(Pointee)->isObjCObject()) {
- Pointee->printLeft(s);
- if (Pointee->hasArray(s))
- s += " ";
- if (Pointee->hasArray(s) || Pointee->hasFunction(s))
- s += "(";
- s += "*";
+ Pointee->printLeft(OB);
+ if (Pointee->hasArray(OB))
+ OB += " ";
+ if (Pointee->hasArray(OB) || Pointee->hasFunction(OB))
+ OB += "(";
+ OB += "*";
} else {
const auto *objcProto = static_cast<const ObjCProtoName *>(Pointee);
- s += "id<";
- s += objcProto->Protocol;
- s += ">";
+ OB += "id<";
+ OB += objcProto->Protocol;
+ OB += ">";
}
}
- void printRight(OutputStream &s) const override {
+ void printRight(OutputBuffer &OB) const override {
if (Pointee->getKind() != KObjCProtoName ||
!static_cast<const ObjCProtoName *>(Pointee)->isObjCObject()) {
- if (Pointee->hasArray(s) || Pointee->hasFunction(s))
- s += ")";
- Pointee->printRight(s);
+ if (Pointee->hasArray(OB) || Pointee->hasFunction(OB))
+ OB += ")";
+ Pointee->printRight(OB);
}
}
};
@@ -528,15 +649,30 @@ class ReferenceType : public Node {
// Dig through any refs to refs, collapsing the ReferenceTypes as we go. The
// rule here is rvalue ref to rvalue ref collapses to a rvalue ref, and any
// other combination collapses to a lvalue ref.
- std::pair<ReferenceKind, const Node *> collapse(OutputStream &S) const {
+ //
+ // A combination of a TemplateForwardReference and a back-ref Substitution
+ // from an ill-formed string may have created a cycle; use cycle detection to
+ // avoid looping forever.
+ std::pair<ReferenceKind, const Node *> collapse(OutputBuffer &OB) const {
auto SoFar = std::make_pair(RK, Pointee);
+ // Track the chain of nodes for the Floyd's 'tortoise and hare'
+ // cycle-detection algorithm, since getSyntaxNode(S) is impure
+ PODSmallVector<const Node *, 8> Prev;
for (;;) {
- const Node *SN = SoFar.second->getSyntaxNode(S);
+ const Node *SN = SoFar.second->getSyntaxNode(OB);
if (SN->getKind() != KReferenceType)
break;
auto *RT = static_cast<const ReferenceType *>(SN);
SoFar.second = RT->Pointee;
SoFar.first = std::min(SoFar.first, RT->RK);
+
+ // The middle of Prev is the 'slow' pointer moving at half speed
+ Prev.push_back(SoFar.second);
+ if (Prev.size() > 1 && SoFar.second == Prev[(Prev.size() - 1) / 2]) {
+ // Cycle detected
+ SoFar.second = nullptr;
+ break;
+ }
}
return SoFar;
}
@@ -548,31 +684,35 @@ public:
template<typename Fn> void match(Fn F) const { F(Pointee, RK); }
- bool hasRHSComponentSlow(OutputStream &S) const override {
- return Pointee->hasRHSComponent(S);
+ bool hasRHSComponentSlow(OutputBuffer &OB) const override {
+ return Pointee->hasRHSComponent(OB);
}
- void printLeft(OutputStream &s) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (Printing)
return;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- std::pair<ReferenceKind, const Node *> Collapsed = collapse(s);
- Collapsed.second->printLeft(s);
- if (Collapsed.second->hasArray(s))
- s += " ";
- if (Collapsed.second->hasArray(s) || Collapsed.second->hasFunction(s))
- s += "(";
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ std::pair<ReferenceKind, const Node *> Collapsed = collapse(OB);
+ if (!Collapsed.second)
+ return;
+ Collapsed.second->printLeft(OB);
+ if (Collapsed.second->hasArray(OB))
+ OB += " ";
+ if (Collapsed.second->hasArray(OB) || Collapsed.second->hasFunction(OB))
+ OB += "(";
- s += (Collapsed.first == ReferenceKind::LValue ? "&" : "&&");
+ OB += (Collapsed.first == ReferenceKind::LValue ? "&" : "&&");
}
- void printRight(OutputStream &s) const override {
+ void printRight(OutputBuffer &OB) const override {
if (Printing)
return;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- std::pair<ReferenceKind, const Node *> Collapsed = collapse(s);
- if (Collapsed.second->hasArray(s) || Collapsed.second->hasFunction(s))
- s += ")";
- Collapsed.second->printRight(s);
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ std::pair<ReferenceKind, const Node *> Collapsed = collapse(OB);
+ if (!Collapsed.second)
+ return;
+ if (Collapsed.second->hasArray(OB) || Collapsed.second->hasFunction(OB))
+ OB += ")";
+ Collapsed.second->printRight(OB);
}
};
@@ -587,69 +727,33 @@ public:
template<typename Fn> void match(Fn F) const { F(ClassType, MemberType); }
- bool hasRHSComponentSlow(OutputStream &S) const override {
- return MemberType->hasRHSComponent(S);
+ bool hasRHSComponentSlow(OutputBuffer &OB) const override {
+ return MemberType->hasRHSComponent(OB);
}
- void printLeft(OutputStream &s) const override {
- MemberType->printLeft(s);
- if (MemberType->hasArray(s) || MemberType->hasFunction(s))
- s += "(";
+ void printLeft(OutputBuffer &OB) const override {
+ MemberType->printLeft(OB);
+ if (MemberType->hasArray(OB) || MemberType->hasFunction(OB))
+ OB += "(";
else
- s += " ";
- ClassType->print(s);
- s += "::*";
- }
-
- void printRight(OutputStream &s) const override {
- if (MemberType->hasArray(s) || MemberType->hasFunction(s))
- s += ")";
- MemberType->printRight(s);
- }
-};
-
-class NodeOrString {
- const void *First;
- const void *Second;
-
-public:
- /* implicit */ NodeOrString(StringView Str) {
- const char *FirstChar = Str.begin();
- const char *SecondChar = Str.end();
- if (SecondChar == nullptr) {
- assert(FirstChar == SecondChar);
- ++FirstChar, ++SecondChar;
- }
- First = static_cast<const void *>(FirstChar);
- Second = static_cast<const void *>(SecondChar);
+ OB += " ";
+ ClassType->print(OB);
+ OB += "::*";
}
- /* implicit */ NodeOrString(Node *N)
- : First(static_cast<const void *>(N)), Second(nullptr) {}
- NodeOrString() : First(nullptr), Second(nullptr) {}
-
- bool isString() const { return Second && First; }
- bool isNode() const { return First && !Second; }
- bool isEmpty() const { return !First && !Second; }
-
- StringView asString() const {
- assert(isString());
- return StringView(static_cast<const char *>(First),
- static_cast<const char *>(Second));
- }
-
- const Node *asNode() const {
- assert(isNode());
- return static_cast<const Node *>(First);
+ void printRight(OutputBuffer &OB) const override {
+ if (MemberType->hasArray(OB) || MemberType->hasFunction(OB))
+ OB += ")";
+ MemberType->printRight(OB);
}
};
class ArrayType final : public Node {
const Node *Base;
- NodeOrString Dimension;
+ Node *Dimension;
public:
- ArrayType(const Node *Base_, NodeOrString Dimension_)
+ ArrayType(const Node *Base_, Node *Dimension_)
: Node(KArrayType,
/*RHSComponentCache=*/Cache::Yes,
/*ArrayCache=*/Cache::Yes),
@@ -657,21 +761,19 @@ public:
template<typename Fn> void match(Fn F) const { F(Base, Dimension); }
- bool hasRHSComponentSlow(OutputStream &) const override { return true; }
- bool hasArraySlow(OutputStream &) const override { return true; }
+ bool hasRHSComponentSlow(OutputBuffer &) const override { return true; }
+ bool hasArraySlow(OutputBuffer &) const override { return true; }
- void printLeft(OutputStream &S) const override { Base->printLeft(S); }
+ void printLeft(OutputBuffer &OB) const override { Base->printLeft(OB); }
- void printRight(OutputStream &S) const override {
- if (S.back() != ']')
- S += " ";
- S += "[";
- if (Dimension.isString())
- S += Dimension.asString();
- else if (Dimension.isNode())
- Dimension.asNode()->print(S);
- S += "]";
- Base->printRight(S);
+ void printRight(OutputBuffer &OB) const override {
+ if (OB.back() != ']')
+ OB += " ";
+ OB += "[";
+ if (Dimension)
+ Dimension->print(OB);
+ OB += "]";
+ Base->printRight(OB);
}
};
@@ -695,8 +797,8 @@ public:
F(Ret, Params, CVQuals, RefQual, ExceptionSpec);
}
- bool hasRHSComponentSlow(OutputStream &) const override { return true; }
- bool hasFunctionSlow(OutputStream &) const override { return true; }
+ bool hasRHSComponentSlow(OutputBuffer &) const override { return true; }
+ bool hasFunctionSlow(OutputBuffer &) const override { return true; }
// Handle C++'s ... quirky decl grammar by using the left & right
// distinction. Consider:
@@ -705,32 +807,32 @@ public:
// that takes a char and returns an int. If we're trying to print f, start
// by printing out the return types's left, then print our parameters, then
// finally print right of the return type.
- void printLeft(OutputStream &S) const override {
- Ret->printLeft(S);
- S += " ";
+ void printLeft(OutputBuffer &OB) const override {
+ Ret->printLeft(OB);
+ OB += " ";
}
- void printRight(OutputStream &S) const override {
- S += "(";
- Params.printWithComma(S);
- S += ")";
- Ret->printRight(S);
+ void printRight(OutputBuffer &OB) const override {
+ OB.printOpen();
+ Params.printWithComma(OB);
+ OB.printClose();
+ Ret->printRight(OB);
if (CVQuals & QualConst)
- S += " const";
+ OB += " const";
if (CVQuals & QualVolatile)
- S += " volatile";
+ OB += " volatile";
if (CVQuals & QualRestrict)
- S += " restrict";
+ OB += " restrict";
if (RefQual == FrefQualLValue)
- S += " &";
+ OB += " &";
else if (RefQual == FrefQualRValue)
- S += " &&";
+ OB += " &&";
if (ExceptionSpec != nullptr) {
- S += ' ';
- ExceptionSpec->print(S);
+ OB += ' ';
+ ExceptionSpec->print(OB);
}
}
};
@@ -742,10 +844,11 @@ public:
template<typename Fn> void match(Fn F) const { F(E); }
- void printLeft(OutputStream &S) const override {
- S += "noexcept(";
- E->print(S);
- S += ")";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "noexcept";
+ OB.printOpen();
+ E->printAsOperand(OB);
+ OB.printClose();
}
};
@@ -757,10 +860,11 @@ public:
template<typename Fn> void match(Fn F) const { F(Types); }
- void printLeft(OutputStream &S) const override {
- S += "throw(";
- Types.printWithComma(S);
- S += ')';
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "throw";
+ OB.printOpen();
+ Types.printWithComma(OB);
+ OB.printClose();
}
};
@@ -791,41 +895,41 @@ public:
NodeArray getParams() const { return Params; }
const Node *getReturnType() const { return Ret; }
- bool hasRHSComponentSlow(OutputStream &) const override { return true; }
- bool hasFunctionSlow(OutputStream &) const override { return true; }
+ bool hasRHSComponentSlow(OutputBuffer &) const override { return true; }
+ bool hasFunctionSlow(OutputBuffer &) const override { return true; }
const Node *getName() const { return Name; }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (Ret) {
- Ret->printLeft(S);
- if (!Ret->hasRHSComponent(S))
- S += " ";
+ Ret->printLeft(OB);
+ if (!Ret->hasRHSComponent(OB))
+ OB += " ";
}
- Name->print(S);
+ Name->print(OB);
}
- void printRight(OutputStream &S) const override {
- S += "(";
- Params.printWithComma(S);
- S += ")";
+ void printRight(OutputBuffer &OB) const override {
+ OB.printOpen();
+ Params.printWithComma(OB);
+ OB.printClose();
if (Ret)
- Ret->printRight(S);
+ Ret->printRight(OB);
if (CVQuals & QualConst)
- S += " const";
+ OB += " const";
if (CVQuals & QualVolatile)
- S += " volatile";
+ OB += " volatile";
if (CVQuals & QualRestrict)
- S += " restrict";
+ OB += " restrict";
if (RefQual == FrefQualLValue)
- S += " &";
+ OB += " &";
else if (RefQual == FrefQualRValue)
- S += " &&";
+ OB += " &&";
if (Attrs != nullptr)
- Attrs->print(S);
+ Attrs->print(OB);
}
};
@@ -838,25 +942,25 @@ public:
template<typename Fn> void match(Fn F) const { F(OpName); }
- void printLeft(OutputStream &S) const override {
- S += "operator\"\" ";
- OpName->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "operator\"\" ";
+ OpName->print(OB);
}
};
class SpecialName final : public Node {
- const StringView Special;
+ const std::string_view Special;
const Node *Child;
public:
- SpecialName(StringView Special_, const Node *Child_)
+ SpecialName(std::string_view Special_, const Node *Child_)
: Node(KSpecialName), Special(Special_), Child(Child_) {}
template<typename Fn> void match(Fn F) const { F(Special, Child); }
- void printLeft(OutputStream &S) const override {
- S += Special;
- Child->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += Special;
+ Child->print(OB);
}
};
@@ -871,11 +975,11 @@ public:
template<typename Fn> void match(Fn F) const { F(FirstType, SecondType); }
- void printLeft(OutputStream &S) const override {
- S += "construction vtable for ";
- FirstType->print(S);
- S += "-in-";
- SecondType->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "construction vtable for ";
+ FirstType->print(OB);
+ OB += "-in-";
+ SecondType->print(OB);
}
};
@@ -888,12 +992,52 @@ struct NestedName : Node {
template<typename Fn> void match(Fn F) const { F(Qual, Name); }
- StringView getBaseName() const override { return Name->getBaseName(); }
+ std::string_view getBaseName() const override { return Name->getBaseName(); }
+
+ void printLeft(OutputBuffer &OB) const override {
+ Qual->print(OB);
+ OB += "::";
+ Name->print(OB);
+ }
+};
+
+struct ModuleName : Node {
+ ModuleName *Parent;
+ Node *Name;
+ bool IsPartition;
+
+ ModuleName(ModuleName *Parent_, Node *Name_, bool IsPartition_ = false)
+ : Node(KModuleName), Parent(Parent_), Name(Name_),
+ IsPartition(IsPartition_) {}
+
+ template <typename Fn> void match(Fn F) const {
+ F(Parent, Name, IsPartition);
+ }
+
+ void printLeft(OutputBuffer &OB) const override {
+ if (Parent)
+ Parent->print(OB);
+ if (Parent || IsPartition)
+ OB += IsPartition ? ':' : '.';
+ Name->print(OB);
+ }
+};
+
+struct ModuleEntity : Node {
+ ModuleName *Module;
+ Node *Name;
+
+ ModuleEntity(ModuleName *Module_, Node *Name_)
+ : Node(KModuleEntity), Module(Module_), Name(Name_) {}
+
+ template <typename Fn> void match(Fn F) const { F(Module, Name); }
+
+ std::string_view getBaseName() const override { return Name->getBaseName(); }
- void printLeft(OutputStream &S) const override {
- Qual->print(S);
- S += "::";
- Name->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ Name->print(OB);
+ OB += '@';
+ Module->print(OB);
}
};
@@ -906,10 +1050,10 @@ struct LocalName : Node {
template<typename Fn> void match(Fn F) const { F(Encoding, Entity); }
- void printLeft(OutputStream &S) const override {
- Encoding->print(S);
- S += "::";
- Entity->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ Encoding->print(OB);
+ OB += "::";
+ Entity->print(OB);
}
};
@@ -924,51 +1068,66 @@ public:
template<typename Fn> void match(Fn F) const { F(Qualifier, Name); }
- StringView getBaseName() const override { return Name->getBaseName(); }
+ std::string_view getBaseName() const override { return Name->getBaseName(); }
- void printLeft(OutputStream &S) const override {
- Qualifier->print(S);
- S += "::";
- Name->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ Qualifier->print(OB);
+ OB += "::";
+ Name->print(OB);
}
};
class VectorType final : public Node {
const Node *BaseType;
- const NodeOrString Dimension;
+ const Node *Dimension;
public:
- VectorType(const Node *BaseType_, NodeOrString Dimension_)
- : Node(KVectorType), BaseType(BaseType_),
- Dimension(Dimension_) {}
+ VectorType(const Node *BaseType_, const Node *Dimension_)
+ : Node(KVectorType), BaseType(BaseType_), Dimension(Dimension_) {}
+
+ const Node *getBaseType() const { return BaseType; }
+ const Node *getDimension() const { return Dimension; }
template<typename Fn> void match(Fn F) const { F(BaseType, Dimension); }
- void printLeft(OutputStream &S) const override {
- BaseType->print(S);
- S += " vector[";
- if (Dimension.isNode())
- Dimension.asNode()->print(S);
- else if (Dimension.isString())
- S += Dimension.asString();
- S += "]";
+ void printLeft(OutputBuffer &OB) const override {
+ BaseType->print(OB);
+ OB += " vector[";
+ if (Dimension)
+ Dimension->print(OB);
+ OB += "]";
}
};
class PixelVectorType final : public Node {
- const NodeOrString Dimension;
+ const Node *Dimension;
public:
- PixelVectorType(NodeOrString Dimension_)
+ PixelVectorType(const Node *Dimension_)
: Node(KPixelVectorType), Dimension(Dimension_) {}
template<typename Fn> void match(Fn F) const { F(Dimension); }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
// FIXME: This should demangle as "vector pixel".
- S += "pixel vector[";
- S += Dimension.asString();
- S += "]";
+ OB += "pixel vector[";
+ Dimension->print(OB);
+ OB += "]";
+ }
+};
+
+class BinaryFPType final : public Node {
+ const Node *Dimension;
+
+public:
+ BinaryFPType(const Node *Dimension_)
+ : Node(KBinaryFPType), Dimension(Dimension_) {}
+
+ template<typename Fn> void match(Fn F) const { F(Dimension); }
+
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "_Float";
+ Dimension->print(OB);
}
};
@@ -990,20 +1149,20 @@ public:
template<typename Fn> void match(Fn F) const { F(Kind, Index); }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
switch (Kind) {
case TemplateParamKind::Type:
- S += "$T";
+ OB += "$T";
break;
case TemplateParamKind::NonType:
- S += "$N";
+ OB += "$N";
break;
case TemplateParamKind::Template:
- S += "$TT";
+ OB += "$TT";
break;
}
if (Index > 0)
- S << Index - 1;
+ OB << Index - 1;
}
};
@@ -1017,13 +1176,9 @@ public:
template<typename Fn> void match(Fn F) const { F(Name); }
- void printLeft(OutputStream &S) const override {
- S += "typename ";
- }
+ void printLeft(OutputBuffer &OB) const override { OB += "typename "; }
- void printRight(OutputStream &S) const override {
- Name->print(S);
- }
+ void printRight(OutputBuffer &OB) const override { Name->print(OB); }
};
/// A non-type template parameter declaration, 'int N'.
@@ -1037,15 +1192,15 @@ public:
template<typename Fn> void match(Fn F) const { F(Name, Type); }
- void printLeft(OutputStream &S) const override {
- Type->printLeft(S);
- if (!Type->hasRHSComponent(S))
- S += " ";
+ void printLeft(OutputBuffer &OB) const override {
+ Type->printLeft(OB);
+ if (!Type->hasRHSComponent(OB))
+ OB += " ";
}
- void printRight(OutputStream &S) const override {
- Name->print(S);
- Type->printRight(S);
+ void printRight(OutputBuffer &OB) const override {
+ Name->print(OB);
+ Type->printRight(OB);
}
};
@@ -1062,15 +1217,14 @@ public:
template<typename Fn> void match(Fn F) const { F(Name, Params); }
- void printLeft(OutputStream &S) const override {
- S += "template<";
- Params.printWithComma(S);
- S += "> typename ";
+ void printLeft(OutputBuffer &OB) const override {
+ ScopedOverride<unsigned> LT(OB.GtIsGt, 0);
+ OB += "template<";
+ Params.printWithComma(OB);
+ OB += "> typename ";
}
- void printRight(OutputStream &S) const override {
- Name->print(S);
- }
+ void printRight(OutputBuffer &OB) const override { Name->print(OB); }
};
/// A template parameter pack declaration, 'typename ...T'.
@@ -1083,14 +1237,12 @@ public:
template<typename Fn> void match(Fn F) const { F(Param); }
- void printLeft(OutputStream &S) const override {
- Param->printLeft(S);
- S += "...";
+ void printLeft(OutputBuffer &OB) const override {
+ Param->printLeft(OB);
+ OB += "...";
}
- void printRight(OutputStream &S) const override {
- Param->printRight(S);
- }
+ void printRight(OutputBuffer &OB) const override { Param->printRight(OB); }
};
/// An unexpanded parameter pack (either in the expression or type context). If
@@ -1104,11 +1256,12 @@ public:
class ParameterPack final : public Node {
NodeArray Data;
- // Setup OutputStream for a pack expansion unless we're already expanding one.
- void initializePackExpansion(OutputStream &S) const {
- if (S.CurrentPackMax == std::numeric_limits<unsigned>::max()) {
- S.CurrentPackMax = static_cast<unsigned>(Data.size());
- S.CurrentPackIndex = 0;
+ // Setup OutputBuffer for a pack expansion, unless we're already expanding
+ // one.
+ void initializePackExpansion(OutputBuffer &OB) const {
+ if (OB.CurrentPackMax == std::numeric_limits<unsigned>::max()) {
+ OB.CurrentPackMax = static_cast<unsigned>(Data.size());
+ OB.CurrentPackIndex = 0;
}
}
@@ -1131,38 +1284,38 @@ public:
template<typename Fn> void match(Fn F) const { F(Data); }
- bool hasRHSComponentSlow(OutputStream &S) const override {
- initializePackExpansion(S);
- size_t Idx = S.CurrentPackIndex;
- return Idx < Data.size() && Data[Idx]->hasRHSComponent(S);
+ bool hasRHSComponentSlow(OutputBuffer &OB) const override {
+ initializePackExpansion(OB);
+ size_t Idx = OB.CurrentPackIndex;
+ return Idx < Data.size() && Data[Idx]->hasRHSComponent(OB);
}
- bool hasArraySlow(OutputStream &S) const override {
- initializePackExpansion(S);
- size_t Idx = S.CurrentPackIndex;
- return Idx < Data.size() && Data[Idx]->hasArray(S);
+ bool hasArraySlow(OutputBuffer &OB) const override {
+ initializePackExpansion(OB);
+ size_t Idx = OB.CurrentPackIndex;
+ return Idx < Data.size() && Data[Idx]->hasArray(OB);
}
- bool hasFunctionSlow(OutputStream &S) const override {
- initializePackExpansion(S);
- size_t Idx = S.CurrentPackIndex;
- return Idx < Data.size() && Data[Idx]->hasFunction(S);
+ bool hasFunctionSlow(OutputBuffer &OB) const override {
+ initializePackExpansion(OB);
+ size_t Idx = OB.CurrentPackIndex;
+ return Idx < Data.size() && Data[Idx]->hasFunction(OB);
}
- const Node *getSyntaxNode(OutputStream &S) const override {
- initializePackExpansion(S);
- size_t Idx = S.CurrentPackIndex;
- return Idx < Data.size() ? Data[Idx]->getSyntaxNode(S) : this;
+ const Node *getSyntaxNode(OutputBuffer &OB) const override {
+ initializePackExpansion(OB);
+ size_t Idx = OB.CurrentPackIndex;
+ return Idx < Data.size() ? Data[Idx]->getSyntaxNode(OB) : this;
}
- void printLeft(OutputStream &S) const override {
- initializePackExpansion(S);
- size_t Idx = S.CurrentPackIndex;
+ void printLeft(OutputBuffer &OB) const override {
+ initializePackExpansion(OB);
+ size_t Idx = OB.CurrentPackIndex;
if (Idx < Data.size())
- Data[Idx]->printLeft(S);
+ Data[Idx]->printLeft(OB);
}
- void printRight(OutputStream &S) const override {
- initializePackExpansion(S);
- size_t Idx = S.CurrentPackIndex;
+ void printRight(OutputBuffer &OB) const override {
+ initializePackExpansion(OB);
+ size_t Idx = OB.CurrentPackIndex;
if (Idx < Data.size())
- Data[Idx]->printRight(S);
+ Data[Idx]->printRight(OB);
}
};
@@ -1181,8 +1334,8 @@ public:
NodeArray getElements() const { return Elements; }
- void printLeft(OutputStream &S) const override {
- Elements.printWithComma(S);
+ void printLeft(OutputBuffer &OB) const override {
+ Elements.printWithComma(OB);
}
};
@@ -1199,35 +1352,35 @@ public:
const Node *getChild() const { return Child; }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
constexpr unsigned Max = std::numeric_limits<unsigned>::max();
- SwapAndRestore<unsigned> SavePackIdx(S.CurrentPackIndex, Max);
- SwapAndRestore<unsigned> SavePackMax(S.CurrentPackMax, Max);
- size_t StreamPos = S.getCurrentPosition();
+ ScopedOverride<unsigned> SavePackIdx(OB.CurrentPackIndex, Max);
+ ScopedOverride<unsigned> SavePackMax(OB.CurrentPackMax, Max);
+ size_t StreamPos = OB.getCurrentPosition();
// Print the first element in the pack. If Child contains a ParameterPack,
// it will set up S.CurrentPackMax and print the first element.
- Child->print(S);
+ Child->print(OB);
// No ParameterPack was found in Child. This can occur if we've found a pack
// expansion on a <function-param>.
- if (S.CurrentPackMax == Max) {
- S += "...";
+ if (OB.CurrentPackMax == Max) {
+ OB += "...";
return;
}
// We found a ParameterPack, but it has no elements. Erase whatever we may
// of printed.
- if (S.CurrentPackMax == 0) {
- S.setCurrentPosition(StreamPos);
+ if (OB.CurrentPackMax == 0) {
+ OB.setCurrentPosition(StreamPos);
return;
}
// Else, iterate through the rest of the elements in the pack.
- for (unsigned I = 1, E = S.CurrentPackMax; I < E; ++I) {
- S += ", ";
- S.CurrentPackIndex = I;
- Child->print(S);
+ for (unsigned I = 1, E = OB.CurrentPackMax; I < E; ++I) {
+ OB += ", ";
+ OB.CurrentPackIndex = I;
+ Child->print(OB);
}
}
};
@@ -1242,12 +1395,11 @@ public:
NodeArray getParams() { return Params; }
- void printLeft(OutputStream &S) const override {
- S += "<";
- Params.printWithComma(S);
- if (S.back() == '>')
- S += " ";
- S += ">";
+ void printLeft(OutputBuffer &OB) const override {
+ ScopedOverride<unsigned> LT(OB.GtIsGt, 0);
+ OB += "<";
+ Params.printWithComma(OB);
+ OB += ">";
}
};
@@ -1289,42 +1441,42 @@ struct ForwardTemplateReference : Node {
// special handling.
template<typename Fn> void match(Fn F) const = delete;
- bool hasRHSComponentSlow(OutputStream &S) const override {
+ bool hasRHSComponentSlow(OutputBuffer &OB) const override {
if (Printing)
return false;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- return Ref->hasRHSComponent(S);
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ return Ref->hasRHSComponent(OB);
}
- bool hasArraySlow(OutputStream &S) const override {
+ bool hasArraySlow(OutputBuffer &OB) const override {
if (Printing)
return false;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- return Ref->hasArray(S);
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ return Ref->hasArray(OB);
}
- bool hasFunctionSlow(OutputStream &S) const override {
+ bool hasFunctionSlow(OutputBuffer &OB) const override {
if (Printing)
return false;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- return Ref->hasFunction(S);
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ return Ref->hasFunction(OB);
}
- const Node *getSyntaxNode(OutputStream &S) const override {
+ const Node *getSyntaxNode(OutputBuffer &OB) const override {
if (Printing)
return this;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- return Ref->getSyntaxNode(S);
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ return Ref->getSyntaxNode(OB);
}
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (Printing)
return;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- Ref->printLeft(S);
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ Ref->printLeft(OB);
}
- void printRight(OutputStream &S) const override {
+ void printRight(OutputBuffer &OB) const override {
if (Printing)
return;
- SwapAndRestore<bool> SavePrinting(Printing, true);
- Ref->printRight(S);
+ ScopedOverride<bool> SavePrinting(Printing, true);
+ Ref->printRight(OB);
}
};
@@ -1338,11 +1490,11 @@ struct NameWithTemplateArgs : Node {
template<typename Fn> void match(Fn F) const { F(Name, TemplateArgs); }
- StringView getBaseName() const override { return Name->getBaseName(); }
+ std::string_view getBaseName() const override { return Name->getBaseName(); }
- void printLeft(OutputStream &S) const override {
- Name->print(S);
- TemplateArgs->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ Name->print(OB);
+ TemplateArgs->print(OB);
}
};
@@ -1355,26 +1507,11 @@ public:
template<typename Fn> void match(Fn F) const { F(Child); }
- StringView getBaseName() const override { return Child->getBaseName(); }
+ std::string_view getBaseName() const override { return Child->getBaseName(); }
- void printLeft(OutputStream &S) const override {
- S += "::";
- Child->print(S);
- }
-};
-
-struct StdQualifiedName : Node {
- Node *Child;
-
- StdQualifiedName(Node *Child_) : Node(KStdQualifiedName), Child(Child_) {}
-
- template<typename Fn> void match(Fn F) const { F(Child); }
-
- StringView getBaseName() const override { return Child->getBaseName(); }
-
- void printLeft(OutputStream &S) const override {
- S += "std::";
- Child->print(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "::";
+ Child->print(OB);
}
};
@@ -1387,109 +1524,81 @@ enum class SpecialSubKind {
iostream,
};
-class ExpandedSpecialSubstitution final : public Node {
+class SpecialSubstitution;
+class ExpandedSpecialSubstitution : public Node {
+protected:
SpecialSubKind SSK;
+ ExpandedSpecialSubstitution(SpecialSubKind SSK_, Kind K_)
+ : Node(K_), SSK(SSK_) {}
public:
ExpandedSpecialSubstitution(SpecialSubKind SSK_)
- : Node(KExpandedSpecialSubstitution), SSK(SSK_) {}
+ : ExpandedSpecialSubstitution(SSK_, KExpandedSpecialSubstitution) {}
+ inline ExpandedSpecialSubstitution(SpecialSubstitution const *);
template<typename Fn> void match(Fn F) const { F(SSK); }
- StringView getBaseName() const override {
+protected:
+ bool isInstantiation() const {
+ return unsigned(SSK) >= unsigned(SpecialSubKind::string);
+ }
+
+ std::string_view getBaseName() const override {
switch (SSK) {
case SpecialSubKind::allocator:
- return StringView("allocator");
+ return {"allocator"};
case SpecialSubKind::basic_string:
- return StringView("basic_string");
+ return {"basic_string"};
case SpecialSubKind::string:
- return StringView("basic_string");
+ return {"basic_string"};
case SpecialSubKind::istream:
- return StringView("basic_istream");
+ return {"basic_istream"};
case SpecialSubKind::ostream:
- return StringView("basic_ostream");
+ return {"basic_ostream"};
case SpecialSubKind::iostream:
- return StringView("basic_iostream");
+ return {"basic_iostream"};
}
DEMANGLE_UNREACHABLE;
}
- void printLeft(OutputStream &S) const override {
- switch (SSK) {
- case SpecialSubKind::allocator:
- S += "std::allocator";
- break;
- case SpecialSubKind::basic_string:
- S += "std::basic_string";
- break;
- case SpecialSubKind::string:
- S += "std::basic_string<char, std::char_traits<char>, "
- "std::allocator<char> >";
- break;
- case SpecialSubKind::istream:
- S += "std::basic_istream<char, std::char_traits<char> >";
- break;
- case SpecialSubKind::ostream:
- S += "std::basic_ostream<char, std::char_traits<char> >";
- break;
- case SpecialSubKind::iostream:
- S += "std::basic_iostream<char, std::char_traits<char> >";
- break;
+private:
+ void printLeft(OutputBuffer &OB) const override {
+ OB << "std::" << getBaseName();
+ if (isInstantiation()) {
+ OB << "<char, std::char_traits<char>";
+ if (SSK == SpecialSubKind::string)
+ OB << ", std::allocator<char>";
+ OB << ">";
}
}
};
-class SpecialSubstitution final : public Node {
+class SpecialSubstitution final : public ExpandedSpecialSubstitution {
public:
- SpecialSubKind SSK;
-
SpecialSubstitution(SpecialSubKind SSK_)
- : Node(KSpecialSubstitution), SSK(SSK_) {}
+ : ExpandedSpecialSubstitution(SSK_, KSpecialSubstitution) {}
template<typename Fn> void match(Fn F) const { F(SSK); }
- StringView getBaseName() const override {
- switch (SSK) {
- case SpecialSubKind::allocator:
- return StringView("allocator");
- case SpecialSubKind::basic_string:
- return StringView("basic_string");
- case SpecialSubKind::string:
- return StringView("string");
- case SpecialSubKind::istream:
- return StringView("istream");
- case SpecialSubKind::ostream:
- return StringView("ostream");
- case SpecialSubKind::iostream:
- return StringView("iostream");
+ std::string_view getBaseName() const override {
+ std::string_view SV = ExpandedSpecialSubstitution::getBaseName();
+ if (isInstantiation()) {
+ // The instantiations are typedefs that drop the "basic_" prefix.
+ assert(llvm::itanium_demangle::starts_with(SV, "basic_"));
+ SV.remove_prefix(sizeof("basic_") - 1);
}
- DEMANGLE_UNREACHABLE;
+ return SV;
}
- void printLeft(OutputStream &S) const override {
- switch (SSK) {
- case SpecialSubKind::allocator:
- S += "std::allocator";
- break;
- case SpecialSubKind::basic_string:
- S += "std::basic_string";
- break;
- case SpecialSubKind::string:
- S += "std::string";
- break;
- case SpecialSubKind::istream:
- S += "std::istream";
- break;
- case SpecialSubKind::ostream:
- S += "std::ostream";
- break;
- case SpecialSubKind::iostream:
- S += "std::iostream";
- break;
- }
+ void printLeft(OutputBuffer &OB) const override {
+ OB << "std::" << getBaseName();
}
};
+inline ExpandedSpecialSubstitution::ExpandedSpecialSubstitution(
+ SpecialSubstitution const *SS)
+ : ExpandedSpecialSubstitution(SS->SSK) {}
+
class CtorDtorName final : public Node {
const Node *Basename;
const bool IsDtor;
@@ -1502,10 +1611,10 @@ public:
template<typename Fn> void match(Fn F) const { F(Basename, IsDtor, Variant); }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (IsDtor)
- S += "~";
- S += Basename->getBaseName();
+ OB += "~";
+ OB += Basename->getBaseName();
}
};
@@ -1517,35 +1626,36 @@ public:
template<typename Fn> void match(Fn F) const { F(Base); }
- void printLeft(OutputStream &S) const override {
- S += "~";
- Base->printLeft(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "~";
+ Base->printLeft(OB);
}
};
class UnnamedTypeName : public Node {
- const StringView Count;
+ const std::string_view Count;
public:
- UnnamedTypeName(StringView Count_) : Node(KUnnamedTypeName), Count(Count_) {}
+ UnnamedTypeName(std::string_view Count_)
+ : Node(KUnnamedTypeName), Count(Count_) {}
template<typename Fn> void match(Fn F) const { F(Count); }
- void printLeft(OutputStream &S) const override {
- S += "'unnamed";
- S += Count;
- S += "\'";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "'unnamed";
+ OB += Count;
+ OB += "\'";
}
};
class ClosureTypeName : public Node {
NodeArray TemplateParams;
NodeArray Params;
- StringView Count;
+ std::string_view Count;
public:
ClosureTypeName(NodeArray TemplateParams_, NodeArray Params_,
- StringView Count_)
+ std::string_view Count_)
: Node(KClosureTypeName), TemplateParams(TemplateParams_),
Params(Params_), Count(Count_) {}
@@ -1553,22 +1663,23 @@ public:
F(TemplateParams, Params, Count);
}
- void printDeclarator(OutputStream &S) const {
+ void printDeclarator(OutputBuffer &OB) const {
if (!TemplateParams.empty()) {
- S += "<";
- TemplateParams.printWithComma(S);
- S += ">";
+ ScopedOverride<unsigned> LT(OB.GtIsGt, 0);
+ OB += "<";
+ TemplateParams.printWithComma(OB);
+ OB += ">";
}
- S += "(";
- Params.printWithComma(S);
- S += ")";
+ OB.printOpen();
+ Params.printWithComma(OB);
+ OB.printClose();
}
- void printLeft(OutputStream &S) const override {
- S += "\'lambda";
- S += Count;
- S += "\'";
- printDeclarator(S);
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "\'lambda";
+ OB += Count;
+ OB += "\'";
+ printDeclarator(OB);
}
};
@@ -1580,10 +1691,10 @@ public:
template<typename Fn> void match(Fn F) const { F(Bindings); }
- void printLeft(OutputStream &S) const override {
- S += '[';
- Bindings.printWithComma(S);
- S += ']';
+ void printLeft(OutputBuffer &OB) const override {
+ OB.printOpen('[');
+ Bindings.printWithComma(OB);
+ OB.printClose(']');
}
};
@@ -1591,32 +1702,35 @@ public:
class BinaryExpr : public Node {
const Node *LHS;
- const StringView InfixOperator;
+ const std::string_view InfixOperator;
const Node *RHS;
public:
- BinaryExpr(const Node *LHS_, StringView InfixOperator_, const Node *RHS_)
- : Node(KBinaryExpr), LHS(LHS_), InfixOperator(InfixOperator_), RHS(RHS_) {
- }
-
- template<typename Fn> void match(Fn F) const { F(LHS, InfixOperator, RHS); }
-
- void printLeft(OutputStream &S) const override {
- // might be a template argument expression, then we need to disambiguate
- // with parens.
- if (InfixOperator == ">")
- S += "(";
-
- S += "(";
- LHS->print(S);
- S += ") ";
- S += InfixOperator;
- S += " (";
- RHS->print(S);
- S += ")";
-
- if (InfixOperator == ">")
- S += ")";
+ BinaryExpr(const Node *LHS_, std::string_view InfixOperator_,
+ const Node *RHS_, Prec Prec_)
+ : Node(KBinaryExpr, Prec_), LHS(LHS_), InfixOperator(InfixOperator_),
+ RHS(RHS_) {}
+
+ template <typename Fn> void match(Fn F) const {
+ F(LHS, InfixOperator, RHS, getPrecedence());
+ }
+
+ void printLeft(OutputBuffer &OB) const override {
+ bool ParenAll = OB.isGtInsideTemplateArgs() &&
+ (InfixOperator == ">" || InfixOperator == ">>");
+ if (ParenAll)
+ OB.printOpen();
+ // Assignment is right associative, with special LHS precedence.
+ bool IsAssign = getPrecedence() == Prec::Assign;
+ LHS->printAsOperand(OB, IsAssign ? Prec::OrIf : getPrecedence(), !IsAssign);
+ // No space before comma operator
+ if (!(InfixOperator == ","))
+ OB += " ";
+ OB += InfixOperator;
+ OB += " ";
+ RHS->printAsOperand(OB, getPrecedence(), IsAssign);
+ if (ParenAll)
+ OB.printClose();
}
};
@@ -1625,35 +1739,36 @@ class ArraySubscriptExpr : public Node {
const Node *Op2;
public:
- ArraySubscriptExpr(const Node *Op1_, const Node *Op2_)
- : Node(KArraySubscriptExpr), Op1(Op1_), Op2(Op2_) {}
+ ArraySubscriptExpr(const Node *Op1_, const Node *Op2_, Prec Prec_)
+ : Node(KArraySubscriptExpr, Prec_), Op1(Op1_), Op2(Op2_) {}
- template<typename Fn> void match(Fn F) const { F(Op1, Op2); }
+ template <typename Fn> void match(Fn F) const {
+ F(Op1, Op2, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- S += "(";
- Op1->print(S);
- S += ")[";
- Op2->print(S);
- S += "]";
+ void printLeft(OutputBuffer &OB) const override {
+ Op1->printAsOperand(OB, getPrecedence());
+ OB.printOpen('[');
+ Op2->printAsOperand(OB);
+ OB.printClose(']');
}
};
class PostfixExpr : public Node {
const Node *Child;
- const StringView Operator;
+ const std::string_view Operator;
public:
- PostfixExpr(const Node *Child_, StringView Operator_)
- : Node(KPostfixExpr), Child(Child_), Operator(Operator_) {}
+ PostfixExpr(const Node *Child_, std::string_view Operator_, Prec Prec_)
+ : Node(KPostfixExpr, Prec_), Child(Child_), Operator(Operator_) {}
- template<typename Fn> void match(Fn F) const { F(Child, Operator); }
+ template <typename Fn> void match(Fn F) const {
+ F(Child, Operator, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- S += "(";
- Child->print(S);
- S += ")";
- S += Operator;
+ void printLeft(OutputBuffer &OB) const override {
+ Child->printAsOperand(OB, getPrecedence(), true);
+ OB += Operator;
}
};
@@ -1663,78 +1778,128 @@ class ConditionalExpr : public Node {
const Node *Else;
public:
- ConditionalExpr(const Node *Cond_, const Node *Then_, const Node *Else_)
- : Node(KConditionalExpr), Cond(Cond_), Then(Then_), Else(Else_) {}
+ ConditionalExpr(const Node *Cond_, const Node *Then_, const Node *Else_,
+ Prec Prec_)
+ : Node(KConditionalExpr, Prec_), Cond(Cond_), Then(Then_), Else(Else_) {}
- template<typename Fn> void match(Fn F) const { F(Cond, Then, Else); }
+ template <typename Fn> void match(Fn F) const {
+ F(Cond, Then, Else, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- S += "(";
- Cond->print(S);
- S += ") ? (";
- Then->print(S);
- S += ") : (";
- Else->print(S);
- S += ")";
+ void printLeft(OutputBuffer &OB) const override {
+ Cond->printAsOperand(OB, getPrecedence());
+ OB += " ? ";
+ Then->printAsOperand(OB);
+ OB += " : ";
+ Else->printAsOperand(OB, Prec::Assign, true);
}
};
class MemberExpr : public Node {
const Node *LHS;
- const StringView Kind;
+ const std::string_view Kind;
const Node *RHS;
public:
- MemberExpr(const Node *LHS_, StringView Kind_, const Node *RHS_)
- : Node(KMemberExpr), LHS(LHS_), Kind(Kind_), RHS(RHS_) {}
+ MemberExpr(const Node *LHS_, std::string_view Kind_, const Node *RHS_,
+ Prec Prec_)
+ : Node(KMemberExpr, Prec_), LHS(LHS_), Kind(Kind_), RHS(RHS_) {}
+
+ template <typename Fn> void match(Fn F) const {
+ F(LHS, Kind, RHS, getPrecedence());
+ }
+
+ void printLeft(OutputBuffer &OB) const override {
+ LHS->printAsOperand(OB, getPrecedence(), true);
+ OB += Kind;
+ RHS->printAsOperand(OB, getPrecedence(), false);
+ }
+};
+
+class SubobjectExpr : public Node {
+ const Node *Type;
+ const Node *SubExpr;
+ std::string_view Offset;
+ NodeArray UnionSelectors;
+ bool OnePastTheEnd;
- template<typename Fn> void match(Fn F) const { F(LHS, Kind, RHS); }
+public:
+ SubobjectExpr(const Node *Type_, const Node *SubExpr_,
+ std::string_view Offset_, NodeArray UnionSelectors_,
+ bool OnePastTheEnd_)
+ : Node(KSubobjectExpr), Type(Type_), SubExpr(SubExpr_), Offset(Offset_),
+ UnionSelectors(UnionSelectors_), OnePastTheEnd(OnePastTheEnd_) {}
- void printLeft(OutputStream &S) const override {
- LHS->print(S);
- S += Kind;
- RHS->print(S);
+ template<typename Fn> void match(Fn F) const {
+ F(Type, SubExpr, Offset, UnionSelectors, OnePastTheEnd);
+ }
+
+ void printLeft(OutputBuffer &OB) const override {
+ SubExpr->print(OB);
+ OB += ".<";
+ Type->print(OB);
+ OB += " at offset ";
+ if (Offset.empty()) {
+ OB += "0";
+ } else if (Offset[0] == 'n') {
+ OB += "-";
+ OB += std::string_view(Offset.data() + 1, Offset.size() - 1);
+ } else {
+ OB += Offset;
+ }
+ OB += ">";
}
};
class EnclosingExpr : public Node {
- const StringView Prefix;
+ const std::string_view Prefix;
const Node *Infix;
- const StringView Postfix;
+ const std::string_view Postfix;
public:
- EnclosingExpr(StringView Prefix_, Node *Infix_, StringView Postfix_)
- : Node(KEnclosingExpr), Prefix(Prefix_), Infix(Infix_),
- Postfix(Postfix_) {}
+ EnclosingExpr(std::string_view Prefix_, const Node *Infix_,
+ Prec Prec_ = Prec::Primary)
+ : Node(KEnclosingExpr, Prec_), Prefix(Prefix_), Infix(Infix_) {}
- template<typename Fn> void match(Fn F) const { F(Prefix, Infix, Postfix); }
+ template <typename Fn> void match(Fn F) const {
+ F(Prefix, Infix, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- S += Prefix;
- Infix->print(S);
- S += Postfix;
+ void printLeft(OutputBuffer &OB) const override {
+ OB += Prefix;
+ OB.printOpen();
+ Infix->print(OB);
+ OB.printClose();
+ OB += Postfix;
}
};
class CastExpr : public Node {
// cast_kind<to>(from)
- const StringView CastKind;
+ const std::string_view CastKind;
const Node *To;
const Node *From;
public:
- CastExpr(StringView CastKind_, const Node *To_, const Node *From_)
- : Node(KCastExpr), CastKind(CastKind_), To(To_), From(From_) {}
+ CastExpr(std::string_view CastKind_, const Node *To_, const Node *From_,
+ Prec Prec_)
+ : Node(KCastExpr, Prec_), CastKind(CastKind_), To(To_), From(From_) {}
- template<typename Fn> void match(Fn F) const { F(CastKind, To, From); }
+ template <typename Fn> void match(Fn F) const {
+ F(CastKind, To, From, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- S += CastKind;
- S += "<";
- To->printLeft(S);
- S += ">(";
- From->printLeft(S);
- S += ")";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += CastKind;
+ {
+ ScopedOverride<unsigned> LT(OB.GtIsGt, 0);
+ OB += "<";
+ To->printLeft(OB);
+ OB += ">";
+ }
+ OB.printOpen();
+ From->printAsOperand(OB);
+ OB.printClose();
}
};
@@ -1747,11 +1912,12 @@ public:
template<typename Fn> void match(Fn F) const { F(Pack); }
- void printLeft(OutputStream &S) const override {
- S += "sizeof...(";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "sizeof...";
+ OB.printOpen();
ParameterPackExpansion PPE(Pack);
- PPE.printLeft(S);
- S += ")";
+ PPE.printLeft(OB);
+ OB.printClose();
}
};
@@ -1760,16 +1926,18 @@ class CallExpr : public Node {
NodeArray Args;
public:
- CallExpr(const Node *Callee_, NodeArray Args_)
- : Node(KCallExpr), Callee(Callee_), Args(Args_) {}
+ CallExpr(const Node *Callee_, NodeArray Args_, Prec Prec_)
+ : Node(KCallExpr, Prec_), Callee(Callee_), Args(Args_) {}
- template<typename Fn> void match(Fn F) const { F(Callee, Args); }
+ template <typename Fn> void match(Fn F) const {
+ F(Callee, Args, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- Callee->print(S);
- S += "(";
- Args.printWithComma(S);
- S += ")";
+ void printLeft(OutputBuffer &OB) const override {
+ Callee->print(OB);
+ OB.printOpen();
+ Args.printWithComma(OB);
+ OB.printClose();
}
};
@@ -1782,33 +1950,32 @@ class NewExpr : public Node {
bool IsArray; // new[] ?
public:
NewExpr(NodeArray ExprList_, Node *Type_, NodeArray InitList_, bool IsGlobal_,
- bool IsArray_)
- : Node(KNewExpr), ExprList(ExprList_), Type(Type_), InitList(InitList_),
- IsGlobal(IsGlobal_), IsArray(IsArray_) {}
+ bool IsArray_, Prec Prec_)
+ : Node(KNewExpr, Prec_), ExprList(ExprList_), Type(Type_),
+ InitList(InitList_), IsGlobal(IsGlobal_), IsArray(IsArray_) {}
template<typename Fn> void match(Fn F) const {
- F(ExprList, Type, InitList, IsGlobal, IsArray);
+ F(ExprList, Type, InitList, IsGlobal, IsArray, getPrecedence());
}
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (IsGlobal)
- S += "::operator ";
- S += "new";
+ OB += "::";
+ OB += "new";
if (IsArray)
- S += "[]";
- S += ' ';
+ OB += "[]";
if (!ExprList.empty()) {
- S += "(";
- ExprList.printWithComma(S);
- S += ")";
+ OB.printOpen();
+ ExprList.printWithComma(OB);
+ OB.printClose();
}
- Type->print(S);
+ OB += " ";
+ Type->print(OB);
if (!InitList.empty()) {
- S += "(";
- InitList.printWithComma(S);
- S += ")";
+ OB.printOpen();
+ InitList.printWithComma(OB);
+ OB.printClose();
}
-
}
};
@@ -1818,50 +1985,55 @@ class DeleteExpr : public Node {
bool IsArray;
public:
- DeleteExpr(Node *Op_, bool IsGlobal_, bool IsArray_)
- : Node(KDeleteExpr), Op(Op_), IsGlobal(IsGlobal_), IsArray(IsArray_) {}
+ DeleteExpr(Node *Op_, bool IsGlobal_, bool IsArray_, Prec Prec_)
+ : Node(KDeleteExpr, Prec_), Op(Op_), IsGlobal(IsGlobal_),
+ IsArray(IsArray_) {}
- template<typename Fn> void match(Fn F) const { F(Op, IsGlobal, IsArray); }
+ template <typename Fn> void match(Fn F) const {
+ F(Op, IsGlobal, IsArray, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (IsGlobal)
- S += "::";
- S += "delete";
+ OB += "::";
+ OB += "delete";
if (IsArray)
- S += "[] ";
- Op->print(S);
+ OB += "[]";
+ OB += ' ';
+ Op->print(OB);
}
};
class PrefixExpr : public Node {
- StringView Prefix;
+ std::string_view Prefix;
Node *Child;
public:
- PrefixExpr(StringView Prefix_, Node *Child_)
- : Node(KPrefixExpr), Prefix(Prefix_), Child(Child_) {}
+ PrefixExpr(std::string_view Prefix_, Node *Child_, Prec Prec_)
+ : Node(KPrefixExpr, Prec_), Prefix(Prefix_), Child(Child_) {}
- template<typename Fn> void match(Fn F) const { F(Prefix, Child); }
+ template <typename Fn> void match(Fn F) const {
+ F(Prefix, Child, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- S += Prefix;
- S += "(";
- Child->print(S);
- S += ")";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += Prefix;
+ Child->printAsOperand(OB, getPrecedence());
}
};
class FunctionParam : public Node {
- StringView Number;
+ std::string_view Number;
public:
- FunctionParam(StringView Number_) : Node(KFunctionParam), Number(Number_) {}
+ FunctionParam(std::string_view Number_)
+ : Node(KFunctionParam), Number(Number_) {}
template<typename Fn> void match(Fn F) const { F(Number); }
- void printLeft(OutputStream &S) const override {
- S += "fp";
- S += Number;
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "fp";
+ OB += Number;
}
};
@@ -1870,17 +2042,45 @@ class ConversionExpr : public Node {
NodeArray Expressions;
public:
- ConversionExpr(const Node *Type_, NodeArray Expressions_)
- : Node(KConversionExpr), Type(Type_), Expressions(Expressions_) {}
+ ConversionExpr(const Node *Type_, NodeArray Expressions_, Prec Prec_)
+ : Node(KConversionExpr, Prec_), Type(Type_), Expressions(Expressions_) {}
+
+ template <typename Fn> void match(Fn F) const {
+ F(Type, Expressions, getPrecedence());
+ }
+
+ void printLeft(OutputBuffer &OB) const override {
+ OB.printOpen();
+ Type->print(OB);
+ OB.printClose();
+ OB.printOpen();
+ Expressions.printWithComma(OB);
+ OB.printClose();
+ }
+};
+
+class PointerToMemberConversionExpr : public Node {
+ const Node *Type;
+ const Node *SubExpr;
+ std::string_view Offset;
+
+public:
+ PointerToMemberConversionExpr(const Node *Type_, const Node *SubExpr_,
+ std::string_view Offset_, Prec Prec_)
+ : Node(KPointerToMemberConversionExpr, Prec_), Type(Type_),
+ SubExpr(SubExpr_), Offset(Offset_) {}
- template<typename Fn> void match(Fn F) const { F(Type, Expressions); }
+ template <typename Fn> void match(Fn F) const {
+ F(Type, SubExpr, Offset, getPrecedence());
+ }
- void printLeft(OutputStream &S) const override {
- S += "(";
- Type->print(S);
- S += ")(";
- Expressions.printWithComma(S);
- S += ")";
+ void printLeft(OutputBuffer &OB) const override {
+ OB.printOpen();
+ Type->print(OB);
+ OB.printClose();
+ OB.printOpen();
+ SubExpr->print(OB);
+ OB.printClose();
}
};
@@ -1893,12 +2093,12 @@ public:
template<typename Fn> void match(Fn F) const { F(Ty, Inits); }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (Ty)
- Ty->print(S);
- S += '{';
- Inits.printWithComma(S);
- S += '}';
+ Ty->print(OB);
+ OB += '{';
+ Inits.printWithComma(OB);
+ OB += '}';
}
};
@@ -1912,18 +2112,18 @@ public:
template<typename Fn> void match(Fn F) const { F(Elem, Init, IsArray); }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (IsArray) {
- S += '[';
- Elem->print(S);
- S += ']';
+ OB += '[';
+ Elem->print(OB);
+ OB += ']';
} else {
- S += '.';
- Elem->print(S);
+ OB += '.';
+ Elem->print(OB);
}
if (Init->getKind() != KBracedExpr && Init->getKind() != KBracedRangeExpr)
- S += " = ";
- Init->print(S);
+ OB += " = ";
+ Init->print(OB);
}
};
@@ -1937,25 +2137,25 @@ public:
template<typename Fn> void match(Fn F) const { F(First, Last, Init); }
- void printLeft(OutputStream &S) const override {
- S += '[';
- First->print(S);
- S += " ... ";
- Last->print(S);
- S += ']';
+ void printLeft(OutputBuffer &OB) const override {
+ OB += '[';
+ First->print(OB);
+ OB += " ... ";
+ Last->print(OB);
+ OB += ']';
if (Init->getKind() != KBracedExpr && Init->getKind() != KBracedRangeExpr)
- S += " = ";
- Init->print(S);
+ OB += " = ";
+ Init->print(OB);
}
};
class FoldExpr : public Node {
const Node *Pack, *Init;
- StringView OperatorName;
+ std::string_view OperatorName;
bool IsLeftFold;
public:
- FoldExpr(bool IsLeftFold_, StringView OperatorName_, const Node *Pack_,
+ FoldExpr(bool IsLeftFold_, std::string_view OperatorName_, const Node *Pack_,
const Node *Init_)
: Node(KFoldExpr), Pack(Pack_), Init(Init_), OperatorName(OperatorName_),
IsLeftFold(IsLeftFold_) {}
@@ -1964,43 +2164,35 @@ public:
F(IsLeftFold, OperatorName, Pack, Init);
}
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
auto PrintPack = [&] {
- S += '(';
- ParameterPackExpansion(Pack).print(S);
- S += ')';
+ OB.printOpen();
+ ParameterPackExpansion(Pack).print(OB);
+ OB.printClose();
};
- S += '(';
-
- if (IsLeftFold) {
- // init op ... op pack
- if (Init != nullptr) {
- Init->print(S);
- S += ' ';
- S += OperatorName;
- S += ' ';
- }
- // ... op pack
- S += "... ";
- S += OperatorName;
- S += ' ';
- PrintPack();
- } else { // !IsLeftFold
- // pack op ...
- PrintPack();
- S += ' ';
- S += OperatorName;
- S += " ...";
- // pack op ... op init
- if (Init != nullptr) {
- S += ' ';
- S += OperatorName;
- S += ' ';
- Init->print(S);
- }
+ OB.printOpen();
+ // Either '[init op ]... op pack' or 'pack op ...[ op init]'
+ // Refactored to '[(init|pack) op ]...[ op (pack|init)]'
+ // Fold expr operands are cast-expressions
+ if (!IsLeftFold || Init != nullptr) {
+ // '(init|pack) op '
+ if (IsLeftFold)
+ Init->printAsOperand(OB, Prec::Cast, true);
+ else
+ PrintPack();
+ OB << " " << OperatorName << " ";
+ }
+ OB << "...";
+ if (IsLeftFold || Init != nullptr) {
+ // ' op (init|pack)'
+ OB << " " << OperatorName << " ";
+ if (IsLeftFold)
+ PrintPack();
+ else
+ Init->printAsOperand(OB, Prec::Cast, true);
}
- S += ')';
+ OB.printClose();
}
};
@@ -2012,24 +2204,9 @@ public:
template<typename Fn> void match(Fn F) const { F(Op); }
- void printLeft(OutputStream &S) const override {
- S += "throw ";
- Op->print(S);
- }
-};
-
-// MSVC __uuidof extension, generated by clang in -fms-extensions mode.
-class UUIDOfExpr : public Node {
- Node *Operand;
-public:
- UUIDOfExpr(Node *Operand_) : Node(KUUIDOfExpr), Operand(Operand_) {}
-
- template<typename Fn> void match(Fn F) const { F(Operand); }
-
- void printLeft(OutputStream &S) const override {
- S << "__uuidof(";
- Operand->print(S);
- S << ")";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "throw ";
+ Op->print(OB);
}
};
@@ -2041,8 +2218,8 @@ public:
template<typename Fn> void match(Fn F) const { F(Value); }
- void printLeft(OutputStream &S) const override {
- S += Value ? StringView("true") : StringView("false");
+ void printLeft(OutputBuffer &OB) const override {
+ OB += Value ? std::string_view("true") : std::string_view("false");
}
};
@@ -2054,10 +2231,10 @@ public:
template<typename Fn> void match(Fn F) const { F(Type); }
- void printLeft(OutputStream &S) const override {
- S += "\"<";
- Type->print(S);
- S += ">\"";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "\"<";
+ Type->print(OB);
+ OB += ">\"";
}
};
@@ -2069,58 +2246,61 @@ public:
template<typename Fn> void match(Fn F) const { F(Type); }
- void printLeft(OutputStream &S) const override {
- S += "[]";
+ void printLeft(OutputBuffer &OB) const override {
+ OB += "[]";
if (Type->getKind() == KClosureTypeName)
- static_cast<const ClosureTypeName *>(Type)->printDeclarator(S);
- S += "{...}";
+ static_cast<const ClosureTypeName *>(Type)->printDeclarator(OB);
+ OB += "{...}";
}
};
-class IntegerCastExpr : public Node {
+class EnumLiteral : public Node {
// ty(integer)
const Node *Ty;
- StringView Integer;
+ std::string_view Integer;
public:
- IntegerCastExpr(const Node *Ty_, StringView Integer_)
- : Node(KIntegerCastExpr), Ty(Ty_), Integer(Integer_) {}
+ EnumLiteral(const Node *Ty_, std::string_view Integer_)
+ : Node(KEnumLiteral), Ty(Ty_), Integer(Integer_) {}
template<typename Fn> void match(Fn F) const { F(Ty, Integer); }
- void printLeft(OutputStream &S) const override {
- S += "(";
- Ty->print(S);
- S += ")";
- S += Integer;
+ void printLeft(OutputBuffer &OB) const override {
+ OB.printOpen();
+ Ty->print(OB);
+ OB.printClose();
+
+ if (Integer[0] == 'n')
+ OB << '-' << std::string_view(Integer.data() + 1, Integer.size() - 1);
+ else
+ OB << Integer;
}
};
class IntegerLiteral : public Node {
- StringView Type;
- StringView Value;
+ std::string_view Type;
+ std::string_view Value;
public:
- IntegerLiteral(StringView Type_, StringView Value_)
+ IntegerLiteral(std::string_view Type_, std::string_view Value_)
: Node(KIntegerLiteral), Type(Type_), Value(Value_) {}
template<typename Fn> void match(Fn F) const { F(Type, Value); }
- void printLeft(OutputStream &S) const override {
+ void printLeft(OutputBuffer &OB) const override {
if (Type.size() > 3) {
- S += "(";
- S += Type;
- S += ")";
+ OB.printOpen();
+ OB += Type;
+ OB.printClose();
}
- if (Value[0] == 'n') {
- S += "-";
- S += Value.dropFront(1);
- } else
- S += Value;
+ if (Value[0] == 'n')
+ OB << '-' << std::string_view(Value.data() + 1, Value.size() - 1);
+ else
+ OB += Value;
if (Type.size() <= 3)
- S += Type;
+ OB += Type;
}
};
@@ -2139,29 +2319,26 @@ constexpr Node::Kind getFloatLiteralKind(long double *) {
}
template <class Float> class FloatLiteralImpl : public Node {
- const StringView Contents;
+ const std::string_view Contents;
static constexpr Kind KindForClass =
float_literal_impl::getFloatLiteralKind((Float *)nullptr);
public:
- FloatLiteralImpl(StringView Contents_)
+ FloatLiteralImpl(std::string_view Contents_)
: Node(KindForClass), Contents(Contents_) {}
template<typename Fn> void match(Fn F) const { F(Contents); }
- void printLeft(OutputStream &s) const override {
- const char *first = Contents.begin();
- const char *last = Contents.end() + 1;
-
+ void printLeft(OutputBuffer &OB) const override {
const size_t N = FloatData<Float>::mangled_size;
- if (static_cast<std::size_t>(last - first) > N) {
- last = first + N;
+ if (Contents.size() >= N) {
union {
Float value;
char buf[sizeof(Float)];
};
- const char *t = first;
+ const char *t = Contents.data();
+ const char *last = t + N;
char *e = buf;
for (; t != last; ++t, ++e) {
unsigned d1 = isdigit(*t) ? static_cast<unsigned>(*t - '0')
@@ -2176,7 +2353,7 @@ public:
#endif
char num[FloatData<Float>::max_demangled_size] = {0};
int n = snprintf(num, sizeof(num), FloatData<Float>::spec, value);
- s += StringView(num, num + n);
+ OB += std::string_view(num, n);
}
}
};
@@ -2190,143 +2367,22 @@ using LongDoubleLiteral = FloatLiteralImpl<long double>;
template<typename Fn>
void Node::visit(Fn F) const {
switch (K) {
-#define CASE(X) case K ## X: return F(static_cast<const X*>(this));
- FOR_EACH_NODE_KIND(CASE)
-#undef CASE
+#define NODE(X) \
+ case K##X: \
+ return F(static_cast<const X *>(this));
+#include "ItaniumNodes.def"
}
assert(0 && "unknown mangling node kind");
}
/// Determine the kind of a node from its type.
template<typename NodeT> struct NodeKind;
-#define SPECIALIZATION(X) \
- template<> struct NodeKind<X> { \
- static constexpr Node::Kind Kind = Node::K##X; \
- static constexpr const char *name() { return #X; } \
+#define NODE(X) \
+ template <> struct NodeKind<X> { \
+ static constexpr Node::Kind Kind = Node::K##X; \
+ static constexpr const char *name() { return #X; } \
};
-FOR_EACH_NODE_KIND(SPECIALIZATION)
-#undef SPECIALIZATION
-
-#undef FOR_EACH_NODE_KIND
-
-template <class T, size_t N>
-class PODSmallVector {
- static_assert(std::is_pod<T>::value,
- "T is required to be a plain old data type");
-
- T* First;
- T* Last;
- T* Cap;
- T Inline[N];
-
- bool isInline() const { return First == Inline; }
-
- void clearInline() {
- First = Inline;
- Last = Inline;
- Cap = Inline + N;
- }
-
- void reserve(size_t NewCap) {
- size_t S = size();
- if (isInline()) {
- auto* Tmp = static_cast<T*>(std::malloc(NewCap * sizeof(T)));
- if (Tmp == nullptr)
- std::terminate();
- std::copy(First, Last, Tmp);
- First = Tmp;
- } else {
- First = static_cast<T*>(std::realloc(First, NewCap * sizeof(T)));
- if (First == nullptr)
- std::terminate();
- }
- Last = First + S;
- Cap = First + NewCap;
- }
-
-public:
- PODSmallVector() : First(Inline), Last(First), Cap(Inline + N) {}
-
- PODSmallVector(const PODSmallVector&) = delete;
- PODSmallVector& operator=(const PODSmallVector&) = delete;
-
- PODSmallVector(PODSmallVector&& Other) : PODSmallVector() {
- if (Other.isInline()) {
- std::copy(Other.begin(), Other.end(), First);
- Last = First + Other.size();
- Other.clear();
- return;
- }
-
- First = Other.First;
- Last = Other.Last;
- Cap = Other.Cap;
- Other.clearInline();
- }
-
- PODSmallVector& operator=(PODSmallVector&& Other) {
- if (Other.isInline()) {
- if (!isInline()) {
- std::free(First);
- clearInline();
- }
- std::copy(Other.begin(), Other.end(), First);
- Last = First + Other.size();
- Other.clear();
- return *this;
- }
-
- if (isInline()) {
- First = Other.First;
- Last = Other.Last;
- Cap = Other.Cap;
- Other.clearInline();
- return *this;
- }
-
- std::swap(First, Other.First);
- std::swap(Last, Other.Last);
- std::swap(Cap, Other.Cap);
- Other.clear();
- return *this;
- }
-
- void push_back(const T& Elem) {
- if (Last == Cap)
- reserve(size() * 2);
- *Last++ = Elem;
- }
-
- void pop_back() {
- assert(Last != First && "Popping empty vector!");
- --Last;
- }
-
- void dropBack(size_t Index) {
- assert(Index <= size() && "dropBack() can't expand!");
- Last = First + Index;
- }
-
- T* begin() { return First; }
- T* end() { return Last; }
-
- bool empty() const { return First == Last; }
- size_t size() const { return static_cast<size_t>(Last - First); }
- T& back() {
- assert(Last != First && "Calling back() on empty vector!");
- return *(Last - 1);
- }
- T& operator[](size_t Index) {
- assert(Index < size() && "Invalid access!");
- return *(begin() + Index);
- }
- void clear() { Last = First; }
-
- ~PODSmallVector() {
- if (!isInline())
- std::free(First);
- }
-};
+#include "ItaniumNodes.def"
template <typename Derived, typename Alloc> struct AbstractManglingParser {
const char *First;
@@ -2350,9 +2406,9 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser {
TemplateParamList Params;
public:
- ScopedTemplateParamList(AbstractManglingParser *Parser)
- : Parser(Parser),
- OldNumTemplateParamLists(Parser->TemplateParams.size()) {
+ ScopedTemplateParamList(AbstractManglingParser *TheParser)
+ : Parser(TheParser),
+ OldNumTemplateParamLists(TheParser->TemplateParams.size()) {
Parser->TemplateParams.push_back(&Params);
}
~ScopedTemplateParamList() {
@@ -2424,8 +2480,9 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser {
return res;
}
- bool consumeIf(StringView S) {
- if (StringView(First, Last).startsWith(S)) {
+ bool consumeIf(std::string_view S) {
+ if (llvm::itanium_demangle::starts_with(
+ std::string_view(First, Last - First), S)) {
First += S.size();
return true;
}
@@ -2442,7 +2499,7 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser {
char consume() { return First != Last ? *First++ : '\0'; }
- char look(unsigned Lookahead = 0) {
+ char look(unsigned Lookahead = 0) const {
if (static_cast<size_t>(Last - First) <= Lookahead)
return '\0';
return First[Lookahead];
@@ -2450,10 +2507,10 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser {
size_t numLeft() const { return static_cast<size_t>(Last - First); }
- StringView parseNumber(bool AllowNegative = false);
+ std::string_view parseNumber(bool AllowNegative = false);
Qualifiers parseCVQualifiers();
bool parsePositiveInteger(size_t *Out);
- StringView parseBareSourceName();
+ std::string_view parseBareSourceName();
bool parseSeqId(size_t *Out);
Node *parseSubstitution();
@@ -2464,16 +2521,17 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser {
/// Parse the <expr> production.
Node *parseExpr();
- Node *parsePrefixExpr(StringView Kind);
- Node *parseBinaryExpr(StringView Kind);
- Node *parseIntegerLiteral(StringView Lit);
+ Node *parsePrefixExpr(std::string_view Kind, Node::Prec Prec);
+ Node *parseBinaryExpr(std::string_view Kind, Node::Prec Prec);
+ Node *parseIntegerLiteral(std::string_view Lit);
Node *parseExprPrimary();
template <class Float> Node *parseFloatingLiteral();
Node *parseFunctionParam();
- Node *parseNewExpr();
Node *parseConversionExpr();
Node *parseBracedExpr();
Node *parseFoldExpr();
+ Node *parsePointerToMemberConversionExpr(Node::Prec Prec);
+ Node *parseSubobjectExpr();
/// Parse the <type> production.
Node *parseType();
@@ -2520,17 +2578,81 @@ template <typename Derived, typename Alloc> struct AbstractManglingParser {
Node *parseName(NameState *State = nullptr);
Node *parseLocalName(NameState *State);
Node *parseOperatorName(NameState *State);
- Node *parseUnqualifiedName(NameState *State);
+ bool parseModuleNameOpt(ModuleName *&Module);
+ Node *parseUnqualifiedName(NameState *State, Node *Scope, ModuleName *Module);
Node *parseUnnamedTypeName(NameState *State);
Node *parseSourceName(NameState *State);
- Node *parseUnscopedName(NameState *State);
+ Node *parseUnscopedName(NameState *State, bool *isSubstName);
Node *parseNestedName(NameState *State);
Node *parseCtorDtorName(Node *&SoFar, NameState *State);
Node *parseAbiTags(Node *N);
+ struct OperatorInfo {
+ enum OIKind : unsigned char {
+ Prefix, // Prefix unary: @ expr
+ Postfix, // Postfix unary: expr @
+ Binary, // Binary: lhs @ rhs
+ Array, // Array index: lhs [ rhs ]
+ Member, // Member access: lhs @ rhs
+ New, // New
+ Del, // Delete
+ Call, // Function call: expr (expr*)
+ CCast, // C cast: (type)expr
+ Conditional, // Conditional: expr ? expr : expr
+ NameOnly, // Overload only, not allowed in expression.
+ // Below do not have operator names
+ NamedCast, // Named cast, @<type>(expr)
+ OfIdOp, // alignof, sizeof, typeid
+
+ Unnameable = NamedCast,
+ };
+ char Enc[2]; // Encoding
+ OIKind Kind; // Kind of operator
+ bool Flag : 1; // Entry-specific flag
+ Node::Prec Prec : 7; // Precedence
+ const char *Name; // Spelling
+
+ public:
+ constexpr OperatorInfo(const char (&E)[3], OIKind K, bool F, Node::Prec P,
+ const char *N)
+ : Enc{E[0], E[1]}, Kind{K}, Flag{F}, Prec{P}, Name{N} {}
+
+ public:
+ bool operator<(const OperatorInfo &Other) const {
+ return *this < Other.Enc;
+ }
+ bool operator<(const char *Peek) const {
+ return Enc[0] < Peek[0] || (Enc[0] == Peek[0] && Enc[1] < Peek[1]);
+ }
+ bool operator==(const char *Peek) const {
+ return Enc[0] == Peek[0] && Enc[1] == Peek[1];
+ }
+ bool operator!=(const char *Peek) const { return !this->operator==(Peek); }
+
+ public:
+ std::string_view getSymbol() const {
+ std::string_view Res = Name;
+ if (Kind < Unnameable) {
+ assert(llvm::itanium_demangle::starts_with(Res, "operator") &&
+ "operator name does not start with 'operator'");
+ Res.remove_prefix(sizeof("operator") - 1);
+ if (llvm::itanium_demangle::starts_with(Res, ' '))
+ Res.remove_prefix(1);
+ }
+ return Res;
+ }
+ std::string_view getName() const { return Name; }
+ OIKind getKind() const { return Kind; }
+ bool getFlag() const { return Flag; }
+ Node::Prec getPrecedence() const { return Prec; }
+ };
+ static const OperatorInfo Ops[];
+ static const size_t NumOps;
+ const OperatorInfo *parseOperatorEncoding();
+
/// Parse the <unresolved-name> production.
- Node *parseUnresolvedName();
+ Node *parseUnresolvedName(bool Global);
Node *parseSimpleId();
Node *parseBaseUnresolvedName();
Node *parseUnresolvedType();
@@ -2551,41 +2673,35 @@ const char* parse_discriminator(const char* first, const char* last);
// ::= <substitution>
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseName(NameState *State) {
- consumeIf('L'); // extension
-
if (look() == 'N')
return getDerived().parseNestedName(State);
if (look() == 'Z')
return getDerived().parseLocalName(State);
- // ::= <unscoped-template-name> <template-args>
- if (look() == 'S' && look(1) != 't') {
- Node *S = getDerived().parseSubstitution();
- if (S == nullptr)
- return nullptr;
- if (look() != 'I')
- return nullptr;
- Node *TA = getDerived().parseTemplateArgs(State != nullptr);
- if (TA == nullptr)
- return nullptr;
- if (State) State->EndsWithTemplateArgs = true;
- return make<NameWithTemplateArgs>(S, TA);
- }
+ Node *Result = nullptr;
+ bool IsSubst = false;
- Node *N = getDerived().parseUnscopedName(State);
- if (N == nullptr)
+ Result = getDerived().parseUnscopedName(State, &IsSubst);
+ if (!Result)
return nullptr;
- // ::= <unscoped-template-name> <template-args>
+
if (look() == 'I') {
- Subs.push_back(N);
+ // ::= <unscoped-template-name> <template-args>
+ if (!IsSubst)
+ // An unscoped-template-name is substitutable.
+ Subs.push_back(Result);
Node *TA = getDerived().parseTemplateArgs(State != nullptr);
if (TA == nullptr)
return nullptr;
- if (State) State->EndsWithTemplateArgs = true;
- return make<NameWithTemplateArgs>(N, TA);
+ if (State)
+ State->EndsWithTemplateArgs = true;
+ Result = make<NameWithTemplateArgs>(Result, TA);
+ } else if (IsSubst) {
+ // The substitution case must be followed by <template-args>.
+ return nullptr;
}
- // ::= <unscoped-name>
- return N;
+
+ return Result;
}
// <local-name> := Z <function encoding> E <entity name> [<discriminator>]
@@ -2626,34 +2742,63 @@ Node *AbstractManglingParser<Derived, Alloc>::parseLocalName(NameState *State) {
// <unscoped-name> ::= <unqualified-name>
// ::= St <unqualified-name> # ::std::
-// extension ::= StL<unqualified-name>
+// [*] extension
template <typename Derived, typename Alloc>
Node *
-AbstractManglingParser<Derived, Alloc>::parseUnscopedName(NameState *State) {
- if (consumeIf("StL") || consumeIf("St")) {
- Node *R = getDerived().parseUnqualifiedName(State);
- if (R == nullptr)
+AbstractManglingParser<Derived, Alloc>::parseUnscopedName(NameState *State,
+ bool *IsSubst) {
+
+ Node *Std = nullptr;
+ if (consumeIf("St")) {
+ Std = make<NameType>("std");
+ if (Std == nullptr)
return nullptr;
- return make<StdQualifiedName>(R);
}
- return getDerived().parseUnqualifiedName(State);
+
+ Node *Res = nullptr;
+ ModuleName *Module = nullptr;
+ if (look() == 'S') {
+ Node *S = getDerived().parseSubstitution();
+ if (!S)
+ return nullptr;
+ if (S->getKind() == Node::KModuleName)
+ Module = static_cast<ModuleName *>(S);
+ else if (IsSubst && Std == nullptr) {
+ Res = S;
+ *IsSubst = true;
+ } else {
+ return nullptr;
+ }
+ }
+
+ if (Res == nullptr || Std != nullptr) {
+ Res = getDerived().parseUnqualifiedName(State, Std, Module);
+ }
+
+ return Res;
}
-// <unqualified-name> ::= <operator-name> [abi-tags]
-// ::= <ctor-dtor-name>
-// ::= <source-name>
-// ::= <unnamed-type-name>
-// ::= DC <source-name>+ E # structured binding declaration
+// <unqualified-name> ::= [<module-name>] L? <operator-name> [<abi-tags>]
+// ::= [<module-name>] <ctor-dtor-name> [<abi-tags>]
+// ::= [<module-name>] L? <source-name> [<abi-tags>]
+// ::= [<module-name>] L? <unnamed-type-name> [<abi-tags>]
+// # structured binding declaration
+// ::= [<module-name>] L? DC <source-name>+ E
template <typename Derived, typename Alloc>
-Node *
-AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName(NameState *State) {
- // <ctor-dtor-name>s are special-cased in parseNestedName().
+Node *AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName(
+ NameState *State, Node *Scope, ModuleName *Module) {
+ if (getDerived().parseModuleNameOpt(Module))
+ return nullptr;
+
+ consumeIf('L');
+
Node *Result;
- if (look() == 'U')
- Result = getDerived().parseUnnamedTypeName(State);
- else if (look() >= '1' && look() <= '9')
+ if (look() >= '1' && look() <= '9') {
Result = getDerived().parseSourceName(State);
- else if (consumeIf("DC")) {
+ } else if (look() == 'U') {
+ Result = getDerived().parseUnnamedTypeName(State);
+ } else if (consumeIf("DC")) {
+ // Structured binding
size_t BindingsBegin = Names.size();
do {
Node *Binding = getDerived().parseSourceName(State);
@@ -2662,13 +2807,46 @@ AbstractManglingParser<Derived, Alloc>::parseUnqualifiedName(NameState *State) {
Names.push_back(Binding);
} while (!consumeIf('E'));
Result = make<StructuredBindingName>(popTrailingNodeArray(BindingsBegin));
- } else
+ } else if (look() == 'C' || look() == 'D') {
+ // A <ctor-dtor-name>.
+ if (Scope == nullptr || Module != nullptr)
+ return nullptr;
+ Result = getDerived().parseCtorDtorName(Scope, State);
+ } else {
Result = getDerived().parseOperatorName(State);
+ }
+
+ if (Result != nullptr && Module != nullptr)
+ Result = make<ModuleEntity>(Module, Result);
if (Result != nullptr)
Result = getDerived().parseAbiTags(Result);
+ if (Result != nullptr && Scope != nullptr)
+ Result = make<NestedName>(Scope, Result);
+
return Result;
}
+// <module-name> ::= <module-subname>
+// ::= <module-name> <module-subname>
+// ::= <substitution> # passed in by caller
+// <module-subname> ::= W <source-name>
+// ::= W P <source-name>
+template <typename Derived, typename Alloc>
+bool AbstractManglingParser<Derived, Alloc>::parseModuleNameOpt(
+ ModuleName *&Module) {
+ while (consumeIf('W')) {
+ bool IsPartition = consumeIf('P');
+ Node *Sub = getDerived().parseSourceName(nullptr);
+ if (!Sub)
+ return true;
+ Module =
+ static_cast<ModuleName *>(make<ModuleName>(Module, Sub, IsPartition));
+ Subs.push_back(Module);
+ }
+
+ return false;
+}
+
// <unnamed-type-name> ::= Ut [<nonnegative number>] _
// ::= <closure-type-name>
//
@@ -2684,19 +2862,19 @@ AbstractManglingParser<Derived, Alloc>::parseUnnamedTypeName(NameState *State) {
TemplateParams.clear();
if (consumeIf("Ut")) {
- StringView Count = parseNumber();
+ std::string_view Count = parseNumber();
if (!consumeIf('_'))
return nullptr;
return make<UnnamedTypeName>(Count);
}
if (consumeIf("Ul")) {
- SwapAndRestore<size_t> SwapParams(ParsingLambdaParamsAtLevel,
+ ScopedOverride<size_t> SwapParams(ParsingLambdaParamsAtLevel,
TemplateParams.size());
ScopedTemplateParamList LambdaTemplateParams(this);
size_t ParamsBegin = Names.size();
while (look() == 'T' &&
- StringView("yptn").find(look(1)) != StringView::npos) {
+ std::string_view("yptn").find(look(1)) != std::string_view::npos) {
Node *T = parseTemplateParamDecl();
if (!T)
return nullptr;
@@ -2739,7 +2917,7 @@ AbstractManglingParser<Derived, Alloc>::parseUnnamedTypeName(NameState *State) {
}
NodeArray Params = popTrailingNodeArray(ParamsBegin);
- StringView Count = parseNumber();
+ std::string_view Count = parseNumber();
if (!consumeIf('_'))
return nullptr;
return make<ClosureTypeName>(TempParams, Params, Count);
@@ -2761,104 +2939,138 @@ Node *AbstractManglingParser<Derived, Alloc>::parseSourceName(NameState *) {
return nullptr;
if (numLeft() < Length || Length == 0)
return nullptr;
- StringView Name(First, First + Length);
+ std::string_view Name(First, Length);
First += Length;
- if (Name.startsWith("_GLOBAL__N"))
+ if (llvm::itanium_demangle::starts_with(Name, "_GLOBAL__N"))
return make<NameType>("(anonymous namespace)");
return make<NameType>(Name);
}
-// <operator-name> ::= aa # &&
-// ::= ad # & (unary)
-// ::= an # &
-// ::= aN # &=
-// ::= aS # =
-// ::= cl # ()
-// ::= cm # ,
-// ::= co # ~
-// ::= cv <type> # (cast)
-// ::= da # delete[]
-// ::= de # * (unary)
-// ::= dl # delete
-// ::= dv # /
-// ::= dV # /=
-// ::= eo # ^
-// ::= eO # ^=
-// ::= eq # ==
-// ::= ge # >=
-// ::= gt # >
-// ::= ix # []
-// ::= le # <=
+// Operator encodings
+template <typename Derived, typename Alloc>
+const typename AbstractManglingParser<
+ Derived, Alloc>::OperatorInfo AbstractManglingParser<Derived,
+ Alloc>::Ops[] = {
+ // Keep ordered by encoding
+ {"aN", OperatorInfo::Binary, false, Node::Prec::Assign, "operator&="},
+ {"aS", OperatorInfo::Binary, false, Node::Prec::Assign, "operator="},
+ {"aa", OperatorInfo::Binary, false, Node::Prec::AndIf, "operator&&"},
+ {"ad", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator&"},
+ {"an", OperatorInfo::Binary, false, Node::Prec::And, "operator&"},
+ {"at", OperatorInfo::OfIdOp, /*Type*/ true, Node::Prec::Unary, "alignof "},
+ {"aw", OperatorInfo::NameOnly, false, Node::Prec::Primary,
+ "operator co_await"},
+ {"az", OperatorInfo::OfIdOp, /*Type*/ false, Node::Prec::Unary, "alignof "},
+ {"cc", OperatorInfo::NamedCast, false, Node::Prec::Postfix, "const_cast"},
+ {"cl", OperatorInfo::Call, false, Node::Prec::Postfix, "operator()"},
+ {"cm", OperatorInfo::Binary, false, Node::Prec::Comma, "operator,"},
+ {"co", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator~"},
+ {"cv", OperatorInfo::CCast, false, Node::Prec::Cast, "operator"}, // C Cast
+ {"dV", OperatorInfo::Binary, false, Node::Prec::Assign, "operator/="},
+ {"da", OperatorInfo::Del, /*Ary*/ true, Node::Prec::Unary,
+ "operator delete[]"},
+ {"dc", OperatorInfo::NamedCast, false, Node::Prec::Postfix, "dynamic_cast"},
+ {"de", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator*"},
+ {"dl", OperatorInfo::Del, /*Ary*/ false, Node::Prec::Unary,
+ "operator delete"},
+ {"ds", OperatorInfo::Member, /*Named*/ false, Node::Prec::PtrMem,
+ "operator.*"},
+ {"dt", OperatorInfo::Member, /*Named*/ false, Node::Prec::Postfix,
+ "operator."},
+ {"dv", OperatorInfo::Binary, false, Node::Prec::Assign, "operator/"},
+ {"eO", OperatorInfo::Binary, false, Node::Prec::Assign, "operator^="},
+ {"eo", OperatorInfo::Binary, false, Node::Prec::Xor, "operator^"},
+ {"eq", OperatorInfo::Binary, false, Node::Prec::Equality, "operator=="},
+ {"ge", OperatorInfo::Binary, false, Node::Prec::Relational, "operator>="},
+ {"gt", OperatorInfo::Binary, false, Node::Prec::Relational, "operator>"},
+ {"ix", OperatorInfo::Array, false, Node::Prec::Postfix, "operator[]"},
+ {"lS", OperatorInfo::Binary, false, Node::Prec::Assign, "operator<<="},
+ {"le", OperatorInfo::Binary, false, Node::Prec::Relational, "operator<="},
+ {"ls", OperatorInfo::Binary, false, Node::Prec::Shift, "operator<<"},
+ {"lt", OperatorInfo::Binary, false, Node::Prec::Relational, "operator<"},
+ {"mI", OperatorInfo::Binary, false, Node::Prec::Assign, "operator-="},
+ {"mL", OperatorInfo::Binary, false, Node::Prec::Assign, "operator*="},
+ {"mi", OperatorInfo::Binary, false, Node::Prec::Additive, "operator-"},
+ {"ml", OperatorInfo::Binary, false, Node::Prec::Multiplicative,
+ "operator*"},
+ {"mm", OperatorInfo::Postfix, false, Node::Prec::Postfix, "operator--"},
+ {"na", OperatorInfo::New, /*Ary*/ true, Node::Prec::Unary,
+ "operator new[]"},
+ {"ne", OperatorInfo::Binary, false, Node::Prec::Equality, "operator!="},
+ {"ng", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator-"},
+ {"nt", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator!"},
+ {"nw", OperatorInfo::New, /*Ary*/ false, Node::Prec::Unary, "operator new"},
+ {"oR", OperatorInfo::Binary, false, Node::Prec::Assign, "operator|="},
+ {"oo", OperatorInfo::Binary, false, Node::Prec::OrIf, "operator||"},
+ {"or", OperatorInfo::Binary, false, Node::Prec::Ior, "operator|"},
+ {"pL", OperatorInfo::Binary, false, Node::Prec::Assign, "operator+="},
+ {"pl", OperatorInfo::Binary, false, Node::Prec::Additive, "operator+"},
+ {"pm", OperatorInfo::Member, /*Named*/ false, Node::Prec::PtrMem,
+ "operator->*"},
+ {"pp", OperatorInfo::Postfix, false, Node::Prec::Postfix, "operator++"},
+ {"ps", OperatorInfo::Prefix, false, Node::Prec::Unary, "operator+"},
+ {"pt", OperatorInfo::Member, /*Named*/ true, Node::Prec::Postfix,
+ "operator->"},
+ {"qu", OperatorInfo::Conditional, false, Node::Prec::Conditional,
+ "operator?"},
+ {"rM", OperatorInfo::Binary, false, Node::Prec::Assign, "operator%="},
+ {"rS", OperatorInfo::Binary, false, Node::Prec::Assign, "operator>>="},
+ {"rc", OperatorInfo::NamedCast, false, Node::Prec::Postfix,
+ "reinterpret_cast"},
+ {"rm", OperatorInfo::Binary, false, Node::Prec::Multiplicative,
+ "operator%"},
+ {"rs", OperatorInfo::Binary, false, Node::Prec::Shift, "operator>>"},
+ {"sc", OperatorInfo::NamedCast, false, Node::Prec::Postfix, "static_cast"},
+ {"ss", OperatorInfo::Binary, false, Node::Prec::Spaceship, "operator<=>"},
+ {"st", OperatorInfo::OfIdOp, /*Type*/ true, Node::Prec::Unary, "sizeof "},
+ {"sz", OperatorInfo::OfIdOp, /*Type*/ false, Node::Prec::Unary, "sizeof "},
+ {"te", OperatorInfo::OfIdOp, /*Type*/ false, Node::Prec::Postfix,
+ "typeid "},
+ {"ti", OperatorInfo::OfIdOp, /*Type*/ true, Node::Prec::Postfix, "typeid "},
+};
+template <typename Derived, typename Alloc>
+const size_t AbstractManglingParser<Derived, Alloc>::NumOps = sizeof(Ops) /
+ sizeof(Ops[0]);
+
+// If the next 2 chars are an operator encoding, consume them and return their
+// OperatorInfo. Otherwise return nullptr.
+template <typename Derived, typename Alloc>
+const typename AbstractManglingParser<Derived, Alloc>::OperatorInfo *
+AbstractManglingParser<Derived, Alloc>::parseOperatorEncoding() {
+ if (numLeft() < 2)
+ return nullptr;
+
+ // We can't use lower_bound as that can link to symbols in the C++ library,
+ // and this must remain independant of that.
+ size_t lower = 0u, upper = NumOps - 1; // Inclusive bounds.
+ while (upper != lower) {
+ size_t middle = (upper + lower) / 2;
+ if (Ops[middle] < First)
+ lower = middle + 1;
+ else
+ upper = middle;
+ }
+ if (Ops[lower] != First)
+ return nullptr;
+
+ First += 2;
+ return &Ops[lower];
+}
+
+// <operator-name> ::= See parseOperatorEncoding()
// ::= li <source-name> # operator ""
-// ::= ls # <<
-// ::= lS # <<=
-// ::= lt # <
-// ::= mi # -
-// ::= mI # -=
-// ::= ml # *
-// ::= mL # *=
-// ::= mm # -- (postfix in <expression> context)
-// ::= na # new[]
-// ::= ne # !=
-// ::= ng # - (unary)
-// ::= nt # !
-// ::= nw # new
-// ::= oo # ||
-// ::= or # |
-// ::= oR # |=
-// ::= pm # ->*
-// ::= pl # +
-// ::= pL # +=
-// ::= pp # ++ (postfix in <expression> context)
-// ::= ps # + (unary)
-// ::= pt # ->
-// ::= qu # ?
-// ::= rm # %
-// ::= rM # %=
-// ::= rs # >>
-// ::= rS # >>=
-// ::= ss # <=> C++2a
-// ::= v <digit> <source-name> # vendor extended operator
+// ::= v <digit> <source-name> # vendor extended operator
template <typename Derived, typename Alloc>
Node *
AbstractManglingParser<Derived, Alloc>::parseOperatorName(NameState *State) {
- switch (look()) {
- case 'a':
- switch (look(1)) {
- case 'a':
- First += 2;
- return make<NameType>("operator&&");
- case 'd':
- case 'n':
- First += 2;
- return make<NameType>("operator&");
- case 'N':
- First += 2;
- return make<NameType>("operator&=");
- case 'S':
- First += 2;
- return make<NameType>("operator=");
- }
- return nullptr;
- case 'c':
- switch (look(1)) {
- case 'l':
- First += 2;
- return make<NameType>("operator()");
- case 'm':
- First += 2;
- return make<NameType>("operator,");
- case 'o':
- First += 2;
- return make<NameType>("operator~");
- // ::= cv <type> # (cast)
- case 'v': {
- First += 2;
- SwapAndRestore<bool> SaveTemplate(TryToParseTemplateArgs, false);
+ if (const auto *Op = parseOperatorEncoding()) {
+ if (Op->getKind() == OperatorInfo::CCast) {
+ // ::= cv <type> # (cast)
+ ScopedOverride<bool> SaveTemplate(TryToParseTemplateArgs, false);
// If we're parsing an encoding, State != nullptr and the conversion
// operators' <type> could have a <template-param> that refers to some
// <template-arg>s further ahead in the mangled name.
- SwapAndRestore<bool> SavePermit(PermitForwardTemplateReferences,
+ ScopedOverride<bool> SavePermit(PermitForwardTemplateReferences,
PermitForwardTemplateReferences ||
State != nullptr);
Node *Ty = getDerived().parseType();
@@ -2867,185 +3079,29 @@ AbstractManglingParser<Derived, Alloc>::parseOperatorName(NameState *State) {
if (State) State->CtorDtorConversion = true;
return make<ConversionOperatorType>(Ty);
}
- }
- return nullptr;
- case 'd':
- switch (look(1)) {
- case 'a':
- First += 2;
- return make<NameType>("operator delete[]");
- case 'e':
- First += 2;
- return make<NameType>("operator*");
- case 'l':
- First += 2;
- return make<NameType>("operator delete");
- case 'v':
- First += 2;
- return make<NameType>("operator/");
- case 'V':
- First += 2;
- return make<NameType>("operator/=");
- }
- return nullptr;
- case 'e':
- switch (look(1)) {
- case 'o':
- First += 2;
- return make<NameType>("operator^");
- case 'O':
- First += 2;
- return make<NameType>("operator^=");
- case 'q':
- First += 2;
- return make<NameType>("operator==");
- }
- return nullptr;
- case 'g':
- switch (look(1)) {
- case 'e':
- First += 2;
- return make<NameType>("operator>=");
- case 't':
- First += 2;
- return make<NameType>("operator>");
- }
- return nullptr;
- case 'i':
- if (look(1) == 'x') {
- First += 2;
- return make<NameType>("operator[]");
- }
- return nullptr;
- case 'l':
- switch (look(1)) {
- case 'e':
- First += 2;
- return make<NameType>("operator<=");
+
+ if (Op->getKind() >= OperatorInfo::Unnameable)
+ /* Not a nameable operator. */
+ return nullptr;
+ if (Op->getKind() == OperatorInfo::Member && !Op->getFlag())
+ /* Not a nameable MemberExpr */
+ return nullptr;
+
+ return make<NameType>(Op->getName());
+ }
+
+ if (consumeIf("li")) {
// ::= li <source-name> # operator ""
- case 'i': {
- First += 2;
- Node *SN = getDerived().parseSourceName(State);
- if (SN == nullptr)
- return nullptr;
- return make<LiteralOperator>(SN);
- }
- case 's':
- First += 2;
- return make<NameType>("operator<<");
- case 'S':
- First += 2;
- return make<NameType>("operator<<=");
- case 't':
- First += 2;
- return make<NameType>("operator<");
- }
- return nullptr;
- case 'm':
- switch (look(1)) {
- case 'i':
- First += 2;
- return make<NameType>("operator-");
- case 'I':
- First += 2;
- return make<NameType>("operator-=");
- case 'l':
- First += 2;
- return make<NameType>("operator*");
- case 'L':
- First += 2;
- return make<NameType>("operator*=");
- case 'm':
- First += 2;
- return make<NameType>("operator--");
- }
- return nullptr;
- case 'n':
- switch (look(1)) {
- case 'a':
- First += 2;
- return make<NameType>("operator new[]");
- case 'e':
- First += 2;
- return make<NameType>("operator!=");
- case 'g':
- First += 2;
- return make<NameType>("operator-");
- case 't':
- First += 2;
- return make<NameType>("operator!");
- case 'w':
- First += 2;
- return make<NameType>("operator new");
- }
- return nullptr;
- case 'o':
- switch (look(1)) {
- case 'o':
- First += 2;
- return make<NameType>("operator||");
- case 'r':
- First += 2;
- return make<NameType>("operator|");
- case 'R':
- First += 2;
- return make<NameType>("operator|=");
- }
- return nullptr;
- case 'p':
- switch (look(1)) {
- case 'm':
- First += 2;
- return make<NameType>("operator->*");
- case 'l':
- First += 2;
- return make<NameType>("operator+");
- case 'L':
- First += 2;
- return make<NameType>("operator+=");
- case 'p':
- First += 2;
- return make<NameType>("operator++");
- case 's':
- First += 2;
- return make<NameType>("operator+");
- case 't':
- First += 2;
- return make<NameType>("operator->");
- }
- return nullptr;
- case 'q':
- if (look(1) == 'u') {
- First += 2;
- return make<NameType>("operator?");
- }
- return nullptr;
- case 'r':
- switch (look(1)) {
- case 'm':
- First += 2;
- return make<NameType>("operator%");
- case 'M':
- First += 2;
- return make<NameType>("operator%=");
- case 's':
- First += 2;
- return make<NameType>("operator>>");
- case 'S':
- First += 2;
- return make<NameType>("operator>>=");
- }
- return nullptr;
- case 's':
- if (look(1) == 's') {
- First += 2;
- return make<NameType>("operator<=>");
- }
- return nullptr;
- // ::= v <digit> <source-name> # vendor extended operator
- case 'v':
- if (std::isdigit(look(1))) {
- First += 2;
+ Node *SN = getDerived().parseSourceName(State);
+ if (SN == nullptr)
+ return nullptr;
+ return make<LiteralOperator>(SN);
+ }
+
+ if (consumeIf('v')) {
+ // ::= v <digit> <source-name> # vendor extended operator
+ if (look() >= '0' && look() <= '9') {
+ First++;
Node *SN = getDerived().parseSourceName(State);
if (SN == nullptr)
return nullptr;
@@ -3053,6 +3109,7 @@ AbstractManglingParser<Derived, Alloc>::parseOperatorName(NameState *State) {
}
return nullptr;
}
+
return nullptr;
}
@@ -3071,19 +3128,11 @@ Node *
AbstractManglingParser<Derived, Alloc>::parseCtorDtorName(Node *&SoFar,
NameState *State) {
if (SoFar->getKind() == Node::KSpecialSubstitution) {
- auto SSK = static_cast<SpecialSubstitution *>(SoFar)->SSK;
- switch (SSK) {
- case SpecialSubKind::string:
- case SpecialSubKind::istream:
- case SpecialSubKind::ostream:
- case SpecialSubKind::iostream:
- SoFar = make<ExpandedSpecialSubstitution>(SSK);
- if (!SoFar)
- return nullptr;
- break;
- default:
- break;
- }
+ // Expand the special substitution.
+ SoFar = make<ExpandedSpecialSubstitution>(
+ static_cast<SpecialSubstitution *>(SoFar));
+ if (!SoFar)
+ return nullptr;
}
if (consumeIf('C')) {
@@ -3112,8 +3161,10 @@ AbstractManglingParser<Derived, Alloc>::parseCtorDtorName(Node *&SoFar,
return nullptr;
}
-// <nested-name> ::= N [<CV-Qualifiers>] [<ref-qualifier>] <prefix> <unqualified-name> E
-// ::= N [<CV-Qualifiers>] [<ref-qualifier>] <template-prefix> <template-args> E
+// <nested-name> ::= N [<CV-Qualifiers>] [<ref-qualifier>] <prefix>
+// <unqualified-name> E
+// ::= N [<CV-Qualifiers>] [<ref-qualifier>] <template-prefix>
+// <template-args> E
//
// <prefix> ::= <prefix> <unqualified-name>
// ::= <template-prefix> <template-args>
@@ -3122,7 +3173,7 @@ AbstractManglingParser<Derived, Alloc>::parseCtorDtorName(Node *&SoFar,
// ::= # empty
// ::= <substitution>
// ::= <prefix> <data-member-prefix>
-// extension ::= L
+// [*] extension
//
// <data-member-prefix> := <member source-name> [<template-args>] M
//
@@ -3142,90 +3193,76 @@ AbstractManglingParser<Derived, Alloc>::parseNestedName(NameState *State) {
if (State) State->ReferenceQualifier = FrefQualRValue;
} else if (consumeIf('R')) {
if (State) State->ReferenceQualifier = FrefQualLValue;
- } else
+ } else {
if (State) State->ReferenceQualifier = FrefQualNone;
-
- Node *SoFar = nullptr;
- auto PushComponent = [&](Node *Comp) {
- if (!Comp) return false;
- if (SoFar) SoFar = make<NestedName>(SoFar, Comp);
- else SoFar = Comp;
- if (State) State->EndsWithTemplateArgs = false;
- return SoFar != nullptr;
- };
-
- if (consumeIf("St")) {
- SoFar = make<NameType>("std");
- if (!SoFar)
- return nullptr;
}
+ Node *SoFar = nullptr;
while (!consumeIf('E')) {
- consumeIf('L'); // extension
+ if (State)
+ // Only set end-with-template on the case that does that.
+ State->EndsWithTemplateArgs = false;
- // <data-member-prefix> := <member source-name> [<template-args>] M
- if (consumeIf('M')) {
- if (SoFar == nullptr)
- return nullptr;
- continue;
- }
-
- // ::= <template-param>
if (look() == 'T') {
- if (!PushComponent(getDerived().parseTemplateParam()))
- return nullptr;
- Subs.push_back(SoFar);
- continue;
- }
-
- // ::= <template-prefix> <template-args>
- if (look() == 'I') {
+ // ::= <template-param>
+ if (SoFar != nullptr)
+ return nullptr; // Cannot have a prefix.
+ SoFar = getDerived().parseTemplateParam();
+ } else if (look() == 'I') {
+ // ::= <template-prefix> <template-args>
+ if (SoFar == nullptr)
+ return nullptr; // Must have a prefix.
Node *TA = getDerived().parseTemplateArgs(State != nullptr);
- if (TA == nullptr || SoFar == nullptr)
- return nullptr;
- SoFar = make<NameWithTemplateArgs>(SoFar, TA);
- if (!SoFar)
- return nullptr;
- if (State) State->EndsWithTemplateArgs = true;
- Subs.push_back(SoFar);
- continue;
- }
-
- // ::= <decltype>
- if (look() == 'D' && (look(1) == 't' || look(1) == 'T')) {
- if (!PushComponent(getDerived().parseDecltype()))
+ if (TA == nullptr)
return nullptr;
- Subs.push_back(SoFar);
- continue;
- }
-
- // ::= <substitution>
- if (look() == 'S' && look(1) != 't') {
- Node *S = getDerived().parseSubstitution();
- if (!PushComponent(S))
+ if (SoFar->getKind() == Node::KNameWithTemplateArgs)
+ // Semantically <template-args> <template-args> cannot be generated by a
+ // C++ entity. There will always be [something like] a name between
+ // them.
return nullptr;
- if (SoFar != S)
- Subs.push_back(S);
- continue;
- }
+ if (State)
+ State->EndsWithTemplateArgs = true;
+ SoFar = make<NameWithTemplateArgs>(SoFar, TA);
+ } else if (look() == 'D' && (look(1) == 't' || look(1) == 'T')) {
+ // ::= <decltype>
+ if (SoFar != nullptr)
+ return nullptr; // Cannot have a prefix.
+ SoFar = getDerived().parseDecltype();
+ } else {
+ ModuleName *Module = nullptr;
+
+ if (look() == 'S') {
+ // ::= <substitution>
+ Node *S = nullptr;
+ if (look(1) == 't') {
+ First += 2;
+ S = make<NameType>("std");
+ } else {
+ S = getDerived().parseSubstitution();
+ }
+ if (!S)
+ return nullptr;
+ if (S->getKind() == Node::KModuleName) {
+ Module = static_cast<ModuleName *>(S);
+ } else if (SoFar != nullptr) {
+ return nullptr; // Cannot have a prefix.
+ } else {
+ SoFar = S;
+ continue; // Do not push a new substitution.
+ }
+ }
- // Parse an <unqualified-name> thats actually a <ctor-dtor-name>.
- if (look() == 'C' || (look() == 'D' && look(1) != 'C')) {
- if (SoFar == nullptr)
- return nullptr;
- if (!PushComponent(getDerived().parseCtorDtorName(SoFar, State)))
- return nullptr;
- SoFar = getDerived().parseAbiTags(SoFar);
- if (SoFar == nullptr)
- return nullptr;
- Subs.push_back(SoFar);
- continue;
+ // ::= [<prefix>] <unqualified-name>
+ SoFar = getDerived().parseUnqualifiedName(State, SoFar, Module);
}
- // ::= <prefix> <unqualified-name>
- if (!PushComponent(getDerived().parseUnqualifiedName(State)))
+ if (SoFar == nullptr)
return nullptr;
Subs.push_back(SoFar);
+
+ // No longer used.
+ // <data-member-prefix> := <member source-name> [<template-args>] M
+ consumeIf('M');
}
if (SoFar == nullptr || Subs.empty())
@@ -3320,6 +3357,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseBaseUnresolvedName() {
// ::= [gs] <base-unresolved-name> # x or (with "gs") ::x
// ::= [gs] sr <unresolved-qualifier-level>+ E <base-unresolved-name>
// # A::x, N::y, A<T>::z; "gs" means leading "::"
+// [gs] has been parsed by caller.
// ::= sr <unresolved-type> <base-unresolved-name> # T::x / decltype(p)::x
// extension ::= sr <unresolved-type> <template-args> <base-unresolved-name>
// # T::N::x /decltype(p)::N::x
@@ -3327,7 +3365,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseBaseUnresolvedName() {
//
// <unresolved-qualifier-level> ::= <simple-id>
template <typename Derived, typename Alloc>
-Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName() {
+Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName(bool Global) {
Node *SoFar = nullptr;
// srN <unresolved-type> [<template-args>] <unresolved-qualifier-level>* E <base-unresolved-name>
@@ -3361,8 +3399,6 @@ Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName() {
return make<QualifiedName>(SoFar, Base);
}
- bool Global = consumeIf("gs");
-
// [gs] <base-unresolved-name> # x or (with "gs") ::x
if (!consumeIf("sr")) {
SoFar = getDerived().parseBaseUnresolvedName();
@@ -3419,7 +3455,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseUnresolvedName() {
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseAbiTags(Node *N) {
while (consumeIf('B')) {
- StringView SN = parseBareSourceName();
+ std::string_view SN = parseBareSourceName();
if (SN.empty())
return nullptr;
N = make<AbiTagAttr>(N, SN);
@@ -3431,16 +3467,16 @@ Node *AbstractManglingParser<Derived, Alloc>::parseAbiTags(Node *N) {
// <number> ::= [n] <non-negative decimal integer>
template <typename Alloc, typename Derived>
-StringView
+std::string_view
AbstractManglingParser<Alloc, Derived>::parseNumber(bool AllowNegative) {
const char *Tmp = First;
if (AllowNegative)
consumeIf('n');
if (numLeft() == 0 || !std::isdigit(*First))
- return StringView();
+ return std::string_view();
while (numLeft() != 0 && std::isdigit(*First))
++First;
- return StringView(Tmp, First);
+ return std::string_view(Tmp, First - Tmp);
}
// <positive length number> ::= [0-9]*
@@ -3457,11 +3493,11 @@ bool AbstractManglingParser<Alloc, Derived>::parsePositiveInteger(size_t *Out) {
}
template <typename Alloc, typename Derived>
-StringView AbstractManglingParser<Alloc, Derived>::parseBareSourceName() {
+std::string_view AbstractManglingParser<Alloc, Derived>::parseBareSourceName() {
size_t Int = 0;
if (parsePositiveInteger(&Int) || numLeft() < Int)
- return StringView();
- StringView R(First, First + Int);
+ return {};
+ std::string_view R(First, Int);
First += Int;
return R;
}
@@ -3549,7 +3585,9 @@ Node *AbstractManglingParser<Derived, Alloc>::parseVectorType() {
if (!consumeIf("Dv"))
return nullptr;
if (look() >= '1' && look() <= '9') {
- StringView DimensionNumber = parseNumber();
+ Node *DimensionNumber = make<NameType>(parseNumber());
+ if (!DimensionNumber)
+ return nullptr;
if (!consumeIf('_'))
return nullptr;
if (consumeIf('p'))
@@ -3574,7 +3612,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseVectorType() {
Node *ElemType = getDerived().parseType();
if (!ElemType)
return nullptr;
- return make<VectorType>(ElemType, StringView());
+ return make<VectorType>(ElemType, /*Dimension=*/nullptr);
}
// <decltype> ::= Dt <expression> E # decltype of an id-expression or class member access (C++0x)
@@ -3590,7 +3628,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseDecltype() {
return nullptr;
if (!consumeIf('E'))
return nullptr;
- return make<EnclosingExpr>("decltype(", E, ")");
+ return make<EnclosingExpr>("decltype", E);
}
// <array-type> ::= A <positive dimension number> _ <element type>
@@ -3600,10 +3638,12 @@ Node *AbstractManglingParser<Derived, Alloc>::parseArrayType() {
if (!consumeIf('A'))
return nullptr;
- NodeOrString Dimension;
+ Node *Dimension = nullptr;
if (std::isdigit(look())) {
- Dimension = parseNumber();
+ Dimension = make<NameType>(parseNumber());
+ if (!Dimension)
+ return nullptr;
if (!consumeIf('_'))
return nullptr;
} else if (!consumeIf('_')) {
@@ -3641,7 +3681,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parsePointerToMemberType() {
// ::= Te <name> # dependent elaborated type specifier using 'enum'
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseClassEnumType() {
- StringView ElabSpef;
+ std::string_view ElabSpef;
if (consumeIf("Ts"))
ElabSpef = "struct";
else if (consumeIf("Tu"))
@@ -3665,19 +3705,18 @@ Node *AbstractManglingParser<Derived, Alloc>::parseClassEnumType() {
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseQualifiedType() {
if (consumeIf('U')) {
- StringView Qual = parseBareSourceName();
+ std::string_view Qual = parseBareSourceName();
if (Qual.empty())
return nullptr;
- // FIXME parse the optional <template-args> here!
-
// extension ::= U <objc-name> <objc-type> # objc-type<identifier>
- if (Qual.startsWith("objcproto")) {
- StringView ProtoSourceName = Qual.dropFront(std::strlen("objcproto"));
- StringView Proto;
+ if (llvm::itanium_demangle::starts_with(Qual, "objcproto")) {
+ constexpr size_t Len = sizeof("objcproto") - 1;
+ std::string_view ProtoSourceName(Qual.data() + Len, Qual.size() - Len);
+ std::string_view Proto;
{
- SwapAndRestore<const char *> SaveFirst(First, ProtoSourceName.begin()),
- SaveLast(Last, ProtoSourceName.end());
+ ScopedOverride<const char *> SaveFirst(First, ProtoSourceName.data()),
+ SaveLast(Last, &*ProtoSourceName.rbegin() + 1);
Proto = parseBareSourceName();
}
if (Proto.empty())
@@ -3688,10 +3727,17 @@ Node *AbstractManglingParser<Derived, Alloc>::parseQualifiedType() {
return make<ObjCProtoName>(Child, Proto);
}
+ Node *TA = nullptr;
+ if (look() == 'I') {
+ TA = getDerived().parseTemplateArgs();
+ if (TA == nullptr)
+ return nullptr;
+ }
+
Node *Child = getDerived().parseQualifiedType();
if (Child == nullptr)
return nullptr;
- return make<VendorExtQualType>(Child, Qual);
+ return make<VendorExtQualType>(Child, Qual, TA);
}
Qualifiers Quals = parseCVQualifiers();
@@ -3838,7 +3884,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() {
// <builtin-type> ::= u <source-name> # vendor extended type
case 'u': {
++First;
- StringView Res = parseBareSourceName();
+ std::string_view Res = parseBareSourceName();
if (Res.empty())
return nullptr;
// Typically, <builtin-type>s are not considered substitution candidates,
@@ -3864,7 +3910,33 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() {
// ::= Dh # IEEE 754r half-precision floating point (16 bits)
case 'h':
First += 2;
- return make<NameType>("decimal16");
+ return make<NameType>("half");
+ // ::= DF <number> _ # ISO/IEC TS 18661 binary floating point (N bits)
+ case 'F': {
+ First += 2;
+ Node *DimensionNumber = make<NameType>(parseNumber());
+ if (!DimensionNumber)
+ return nullptr;
+ if (!consumeIf('_'))
+ return nullptr;
+ return make<BinaryFPType>(DimensionNumber);
+ }
+ // ::= DB <number> _ # C23 signed _BitInt(N)
+ // ::= DB <instantiation-dependent expression> _ # C23 signed _BitInt(N)
+ // ::= DU <number> _ # C23 unsigned _BitInt(N)
+ // ::= DU <instantiation-dependent expression> _ # C23 unsigned _BitInt(N)
+ case 'B':
+ case 'U': {
+ bool Signed = look(1) == 'B';
+ First += 2;
+ Node *Size = std::isdigit(look()) ? make<NameType>(parseNumber())
+ : getDerived().parseExpr();
+ if (!Size)
+ return nullptr;
+ if (!consumeIf('_'))
+ return nullptr;
+ return make<BitIntType>(Size, Signed);
+ }
// ::= Di # char32_t
case 'i':
First += 2;
@@ -4012,9 +4084,10 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() {
}
// ::= <substitution> # See Compression below
case 'S': {
- if (look(1) && look(1) != 't') {
- Node *Sub = getDerived().parseSubstitution();
- if (Sub == nullptr)
+ if (look(1) != 't') {
+ bool IsSubst = false;
+ Result = getDerived().parseUnscopedName(nullptr, &IsSubst);
+ if (!Result)
return nullptr;
// Sub could be either of:
@@ -4027,17 +4100,19 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() {
// If this is followed by some <template-args>, and we're permitted to
// parse them, take the second production.
- if (TryToParseTemplateArgs && look() == 'I') {
+ if (look() == 'I' && (!IsSubst || TryToParseTemplateArgs)) {
+ if (!IsSubst)
+ Subs.push_back(Result);
Node *TA = getDerived().parseTemplateArgs();
if (TA == nullptr)
return nullptr;
- Result = make<NameWithTemplateArgs>(Sub, TA);
- break;
+ Result = make<NameWithTemplateArgs>(Result, TA);
+ } else if (IsSubst) {
+ // If all we parsed was a substitution, don't re-insert into the
+ // substitution table.
+ return Result;
}
-
- // If all we parsed was a substitution, don't re-insert into the
- // substitution table.
- return Sub;
+ break;
}
DEMANGLE_FALLTHROUGH;
}
@@ -4057,28 +4132,32 @@ Node *AbstractManglingParser<Derived, Alloc>::parseType() {
}
template <typename Derived, typename Alloc>
-Node *AbstractManglingParser<Derived, Alloc>::parsePrefixExpr(StringView Kind) {
+Node *
+AbstractManglingParser<Derived, Alloc>::parsePrefixExpr(std::string_view Kind,
+ Node::Prec Prec) {
Node *E = getDerived().parseExpr();
if (E == nullptr)
return nullptr;
- return make<PrefixExpr>(Kind, E);
+ return make<PrefixExpr>(Kind, E, Prec);
}
template <typename Derived, typename Alloc>
-Node *AbstractManglingParser<Derived, Alloc>::parseBinaryExpr(StringView Kind) {
+Node *
+AbstractManglingParser<Derived, Alloc>::parseBinaryExpr(std::string_view Kind,
+ Node::Prec Prec) {
Node *LHS = getDerived().parseExpr();
if (LHS == nullptr)
return nullptr;
Node *RHS = getDerived().parseExpr();
if (RHS == nullptr)
return nullptr;
- return make<BinaryExpr>(LHS, Kind, RHS);
+ return make<BinaryExpr>(LHS, Kind, RHS, Prec);
}
template <typename Derived, typename Alloc>
-Node *
-AbstractManglingParser<Derived, Alloc>::parseIntegerLiteral(StringView Lit) {
- StringView Tmp = parseNumber(true);
+Node *AbstractManglingParser<Derived, Alloc>::parseIntegerLiteral(
+ std::string_view Lit) {
+ std::string_view Tmp = parseNumber(true);
if (!Tmp.empty() && consumeIf('E'))
return make<IntegerLiteral>(Lit, Tmp);
return nullptr;
@@ -4101,11 +4180,14 @@ Qualifiers AbstractManglingParser<Alloc, Derived>::parseCVQualifiers() {
// ::= fp <top-level CV-Qualifiers> <parameter-2 non-negative number> _ # L == 0, second and later parameters
// ::= fL <L-1 non-negative number> p <top-level CV-Qualifiers> _ # L > 0, first parameter
// ::= fL <L-1 non-negative number> p <top-level CV-Qualifiers> <parameter-2 non-negative number> _ # L > 0, second and later parameters
+// ::= fpT # 'this' expression (not part of standard?)
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseFunctionParam() {
+ if (consumeIf("fpT"))
+ return make<NameType>("this");
if (consumeIf("fp")) {
parseCVQualifiers();
- StringView Num = parseNumber();
+ std::string_view Num = parseNumber();
if (!consumeIf('_'))
return nullptr;
return make<FunctionParam>(Num);
@@ -4116,7 +4198,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFunctionParam() {
if (!consumeIf('p'))
return nullptr;
parseCVQualifiers();
- StringView Num = parseNumber();
+ std::string_view Num = parseNumber();
if (!consumeIf('_'))
return nullptr;
return make<FunctionParam>(Num);
@@ -4124,43 +4206,6 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFunctionParam() {
return nullptr;
}
-// [gs] nw <expression>* _ <type> E # new (expr-list) type
-// [gs] nw <expression>* _ <type> <initializer> # new (expr-list) type (init)
-// [gs] na <expression>* _ <type> E # new[] (expr-list) type
-// [gs] na <expression>* _ <type> <initializer> # new[] (expr-list) type (init)
-// <initializer> ::= pi <expression>* E # parenthesized initialization
-template <typename Derived, typename Alloc>
-Node *AbstractManglingParser<Derived, Alloc>::parseNewExpr() {
- bool Global = consumeIf("gs");
- bool IsArray = look(1) == 'a';
- if (!consumeIf("nw") && !consumeIf("na"))
- return nullptr;
- size_t Exprs = Names.size();
- while (!consumeIf('_')) {
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return nullptr;
- Names.push_back(Ex);
- }
- NodeArray ExprList = popTrailingNodeArray(Exprs);
- Node *Ty = getDerived().parseType();
- if (Ty == nullptr)
- return Ty;
- if (consumeIf("pi")) {
- size_t InitsBegin = Names.size();
- while (!consumeIf('E')) {
- Node *Init = getDerived().parseExpr();
- if (Init == nullptr)
- return Init;
- Names.push_back(Init);
- }
- NodeArray Inits = popTrailingNodeArray(InitsBegin);
- return make<NewExpr>(ExprList, Ty, Inits, Global, IsArray);
- } else if (!consumeIf('E'))
- return nullptr;
- return make<NewExpr>(ExprList, Ty, NodeArray(), Global, IsArray);
-}
-
// cv <type> <expression> # conversion with one argument
// cv <type> _ <expression>* E # conversion with a different number of arguments
template <typename Derived, typename Alloc>
@@ -4169,7 +4214,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseConversionExpr() {
return nullptr;
Node *Ty;
{
- SwapAndRestore<bool> SaveTemp(TryToParseTemplateArgs, false);
+ ScopedOverride<bool> SaveTemp(TryToParseTemplateArgs, false);
Ty = getDerived().parseType();
}
@@ -4262,7 +4307,13 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExprPrimary() {
return getDerived().template parseFloatingLiteral<double>();
case 'e':
++First;
+#if defined(__powerpc__) || defined(__s390__)
+ // Handle cases where long doubles encoded with e have the same size
+ // and representation as doubles.
+ return getDerived().template parseFloatingLiteral<double>();
+#else
return getDerived().template parseFloatingLiteral<long double>();
+#endif
case '_':
if (consumeIf("_Z")) {
Node *R = getDerived().parseEncoding();
@@ -4280,7 +4331,7 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExprPrimary() {
return nullptr;
}
case 'D':
- if (consumeIf("DnE"))
+ if (consumeIf("Dn") && (consumeIf('0'), consumeIf('E')))
return make<NameType>("nullptr");
return nullptr;
case 'T':
@@ -4301,12 +4352,12 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExprPrimary() {
Node *T = getDerived().parseType();
if (T == nullptr)
return nullptr;
- StringView N = parseNumber();
+ std::string_view N = parseNumber(/*AllowNegative=*/true);
if (N.empty())
return nullptr;
if (!consumeIf('E'))
return nullptr;
- return make<IntegerCastExpr>(T, N);
+ return make<EnumLiteral>(T, N);
}
}
}
@@ -4367,55 +4418,38 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFoldExpr() {
if (!consumeIf('f'))
return nullptr;
- char FoldKind = look();
- bool IsLeftFold, HasInitializer;
- HasInitializer = FoldKind == 'L' || FoldKind == 'R';
- if (FoldKind == 'l' || FoldKind == 'L')
- IsLeftFold = true;
- else if (FoldKind == 'r' || FoldKind == 'R')
- IsLeftFold = false;
- else
+ bool IsLeftFold = false, HasInitializer = false;
+ switch (look()) {
+ default:
return nullptr;
+ case 'L':
+ IsLeftFold = true;
+ HasInitializer = true;
+ break;
+ case 'R':
+ HasInitializer = true;
+ break;
+ case 'l':
+ IsLeftFold = true;
+ break;
+ case 'r':
+ break;
+ }
++First;
- // FIXME: This map is duplicated in parseOperatorName and parseExpr.
- StringView OperatorName;
- if (consumeIf("aa")) OperatorName = "&&";
- else if (consumeIf("an")) OperatorName = "&";
- else if (consumeIf("aN")) OperatorName = "&=";
- else if (consumeIf("aS")) OperatorName = "=";
- else if (consumeIf("cm")) OperatorName = ",";
- else if (consumeIf("ds")) OperatorName = ".*";
- else if (consumeIf("dv")) OperatorName = "/";
- else if (consumeIf("dV")) OperatorName = "/=";
- else if (consumeIf("eo")) OperatorName = "^";
- else if (consumeIf("eO")) OperatorName = "^=";
- else if (consumeIf("eq")) OperatorName = "==";
- else if (consumeIf("ge")) OperatorName = ">=";
- else if (consumeIf("gt")) OperatorName = ">";
- else if (consumeIf("le")) OperatorName = "<=";
- else if (consumeIf("ls")) OperatorName = "<<";
- else if (consumeIf("lS")) OperatorName = "<<=";
- else if (consumeIf("lt")) OperatorName = "<";
- else if (consumeIf("mi")) OperatorName = "-";
- else if (consumeIf("mI")) OperatorName = "-=";
- else if (consumeIf("ml")) OperatorName = "*";
- else if (consumeIf("mL")) OperatorName = "*=";
- else if (consumeIf("ne")) OperatorName = "!=";
- else if (consumeIf("oo")) OperatorName = "||";
- else if (consumeIf("or")) OperatorName = "|";
- else if (consumeIf("oR")) OperatorName = "|=";
- else if (consumeIf("pl")) OperatorName = "+";
- else if (consumeIf("pL")) OperatorName = "+=";
- else if (consumeIf("rm")) OperatorName = "%";
- else if (consumeIf("rM")) OperatorName = "%=";
- else if (consumeIf("rs")) OperatorName = ">>";
- else if (consumeIf("rS")) OperatorName = ">>=";
- else return nullptr;
-
- Node *Pack = getDerived().parseExpr(), *Init = nullptr;
+ const auto *Op = parseOperatorEncoding();
+ if (!Op)
+ return nullptr;
+ if (!(Op->getKind() == OperatorInfo::Binary
+ || (Op->getKind() == OperatorInfo::Member
+ && Op->getName().back() == '*')))
+ return nullptr;
+
+ Node *Pack = getDerived().parseExpr();
if (Pack == nullptr)
return nullptr;
+
+ Node *Init = nullptr;
if (HasInitializer) {
Init = getDerived().parseExpr();
if (Init == nullptr)
@@ -4425,7 +4459,53 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFoldExpr() {
if (IsLeftFold && Init)
std::swap(Pack, Init);
- return make<FoldExpr>(IsLeftFold, OperatorName, Pack, Init);
+ return make<FoldExpr>(IsLeftFold, Op->getSymbol(), Pack, Init);
+}
+
+// <expression> ::= mc <parameter type> <expr> [<offset number>] E
+//
+// Not yet in the spec: https://github.com/itanium-cxx-abi/cxx-abi/issues/47
+template <typename Derived, typename Alloc>
+Node *
+AbstractManglingParser<Derived, Alloc>::parsePointerToMemberConversionExpr(
+ Node::Prec Prec) {
+ Node *Ty = getDerived().parseType();
+ if (!Ty)
+ return nullptr;
+ Node *Expr = getDerived().parseExpr();
+ if (!Expr)
+ return nullptr;
+ std::string_view Offset = getDerived().parseNumber(true);
+ if (!consumeIf('E'))
+ return nullptr;
+ return make<PointerToMemberConversionExpr>(Ty, Expr, Offset, Prec);
+}
+
+// <expression> ::= so <referent type> <expr> [<offset number>] <union-selector>* [p] E
+// <union-selector> ::= _ [<number>]
+//
+// Not yet in the spec: https://github.com/itanium-cxx-abi/cxx-abi/issues/47
+template <typename Derived, typename Alloc>
+Node *AbstractManglingParser<Derived, Alloc>::parseSubobjectExpr() {
+ Node *Ty = getDerived().parseType();
+ if (!Ty)
+ return nullptr;
+ Node *Expr = getDerived().parseExpr();
+ if (!Expr)
+ return nullptr;
+ std::string_view Offset = getDerived().parseNumber(true);
+ size_t SelectorsBegin = Names.size();
+ while (consumeIf('_')) {
+ Node *Selector = make<NameType>(parseNumber());
+ if (!Selector)
+ return nullptr;
+ Names.push_back(Selector);
+ }
+ bool OnePastTheEnd = consumeIf('p');
+ if (!consumeIf('E'))
+ return nullptr;
+ return make<SubobjectExpr>(
+ Ty, Expr, Offset, popTrailingNodeArray(SelectorsBegin), OnePastTheEnd);
}
// <expression> ::= <unary operator-name> <expression>
@@ -4475,313 +4555,127 @@ Node *AbstractManglingParser<Derived, Alloc>::parseFoldExpr() {
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseExpr() {
bool Global = consumeIf("gs");
- if (numLeft() < 2)
- return nullptr;
- switch (*First) {
- case 'L':
- return getDerived().parseExprPrimary();
- case 'T':
- return getDerived().parseTemplateParam();
- case 'f': {
- // Disambiguate a fold expression from a <function-param>.
- if (look(1) == 'p' || (look(1) == 'L' && std::isdigit(look(2))))
- return getDerived().parseFunctionParam();
- return getDerived().parseFoldExpr();
- }
- case 'a':
- switch (First[1]) {
- case 'a':
- First += 2;
- return getDerived().parseBinaryExpr("&&");
- case 'd':
- First += 2;
- return getDerived().parsePrefixExpr("&");
- case 'n':
- First += 2;
- return getDerived().parseBinaryExpr("&");
- case 'N':
- First += 2;
- return getDerived().parseBinaryExpr("&=");
- case 'S':
- First += 2;
- return getDerived().parseBinaryExpr("=");
- case 't': {
- First += 2;
- Node *Ty = getDerived().parseType();
- if (Ty == nullptr)
- return nullptr;
- return make<EnclosingExpr>("alignof (", Ty, ")");
- }
- case 'z': {
- First += 2;
- Node *Ty = getDerived().parseExpr();
- if (Ty == nullptr)
- return nullptr;
- return make<EnclosingExpr>("alignof (", Ty, ")");
- }
- }
- return nullptr;
- case 'c':
- switch (First[1]) {
- // cc <type> <expression> # const_cast<type>(expression)
- case 'c': {
- First += 2;
- Node *Ty = getDerived().parseType();
- if (Ty == nullptr)
- return Ty;
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return Ex;
- return make<CastExpr>("const_cast", Ty, Ex);
- }
- // cl <expression>+ E # call
- case 'l': {
- First += 2;
- Node *Callee = getDerived().parseExpr();
- if (Callee == nullptr)
- return Callee;
- size_t ExprsBegin = Names.size();
- while (!consumeIf('E')) {
- Node *E = getDerived().parseExpr();
- if (E == nullptr)
- return E;
- Names.push_back(E);
- }
- return make<CallExpr>(Callee, popTrailingNodeArray(ExprsBegin));
- }
- case 'm':
- First += 2;
- return getDerived().parseBinaryExpr(",");
- case 'o':
- First += 2;
- return getDerived().parsePrefixExpr("~");
- case 'v':
- return getDerived().parseConversionExpr();
- }
- return nullptr;
- case 'd':
- switch (First[1]) {
- case 'a': {
- First += 2;
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return Ex;
- return make<DeleteExpr>(Ex, Global, /*is_array=*/true);
- }
- case 'c': {
- First += 2;
- Node *T = getDerived().parseType();
- if (T == nullptr)
- return T;
+ const auto *Op = parseOperatorEncoding();
+ if (Op) {
+ auto Sym = Op->getSymbol();
+ switch (Op->getKind()) {
+ case OperatorInfo::Binary:
+ // Binary operator: lhs @ rhs
+ return getDerived().parseBinaryExpr(Sym, Op->getPrecedence());
+ case OperatorInfo::Prefix:
+ // Prefix unary operator: @ expr
+ return getDerived().parsePrefixExpr(Sym, Op->getPrecedence());
+ case OperatorInfo::Postfix: {
+ // Postfix unary operator: expr @
+ if (consumeIf('_'))
+ return getDerived().parsePrefixExpr(Sym, Op->getPrecedence());
Node *Ex = getDerived().parseExpr();
if (Ex == nullptr)
- return Ex;
- return make<CastExpr>("dynamic_cast", T, Ex);
- }
- case 'e':
- First += 2;
- return getDerived().parsePrefixExpr("*");
- case 'l': {
- First += 2;
- Node *E = getDerived().parseExpr();
- if (E == nullptr)
- return E;
- return make<DeleteExpr>(E, Global, /*is_array=*/false);
+ return nullptr;
+ return make<PostfixExpr>(Ex, Sym, Op->getPrecedence());
}
- case 'n':
- return getDerived().parseUnresolvedName();
- case 's': {
- First += 2;
- Node *LHS = getDerived().parseExpr();
- if (LHS == nullptr)
+ case OperatorInfo::Array: {
+ // Array Index: lhs [ rhs ]
+ Node *Base = getDerived().parseExpr();
+ if (Base == nullptr)
return nullptr;
- Node *RHS = getDerived().parseExpr();
- if (RHS == nullptr)
+ Node *Index = getDerived().parseExpr();
+ if (Index == nullptr)
return nullptr;
- return make<MemberExpr>(LHS, ".*", RHS);
+ return make<ArraySubscriptExpr>(Base, Index, Op->getPrecedence());
}
- case 't': {
- First += 2;
+ case OperatorInfo::Member: {
+ // Member access lhs @ rhs
Node *LHS = getDerived().parseExpr();
if (LHS == nullptr)
- return LHS;
+ return nullptr;
Node *RHS = getDerived().parseExpr();
if (RHS == nullptr)
return nullptr;
- return make<MemberExpr>(LHS, ".", RHS);
- }
- case 'v':
- First += 2;
- return getDerived().parseBinaryExpr("/");
- case 'V':
- First += 2;
- return getDerived().parseBinaryExpr("/=");
- }
- return nullptr;
- case 'e':
- switch (First[1]) {
- case 'o':
- First += 2;
- return getDerived().parseBinaryExpr("^");
- case 'O':
- First += 2;
- return getDerived().parseBinaryExpr("^=");
- case 'q':
- First += 2;
- return getDerived().parseBinaryExpr("==");
- }
- return nullptr;
- case 'g':
- switch (First[1]) {
- case 'e':
- First += 2;
- return getDerived().parseBinaryExpr(">=");
- case 't':
- First += 2;
- return getDerived().parseBinaryExpr(">");
- }
- return nullptr;
- case 'i':
- switch (First[1]) {
- case 'x': {
- First += 2;
- Node *Base = getDerived().parseExpr();
- if (Base == nullptr)
+ return make<MemberExpr>(LHS, Sym, RHS, Op->getPrecedence());
+ }
+ case OperatorInfo::New: {
+ // New
+ // # new (expr-list) type [(init)]
+ // [gs] nw <expression>* _ <type> [pi <expression>*] E
+ // # new[] (expr-list) type [(init)]
+ // [gs] na <expression>* _ <type> [pi <expression>*] E
+ size_t Exprs = Names.size();
+ while (!consumeIf('_')) {
+ Node *Ex = getDerived().parseExpr();
+ if (Ex == nullptr)
+ return nullptr;
+ Names.push_back(Ex);
+ }
+ NodeArray ExprList = popTrailingNodeArray(Exprs);
+ Node *Ty = getDerived().parseType();
+ if (Ty == nullptr)
return nullptr;
- Node *Index = getDerived().parseExpr();
- if (Index == nullptr)
- return Index;
- return make<ArraySubscriptExpr>(Base, Index);
- }
- case 'l': {
- First += 2;
+ bool HaveInits = consumeIf("pi");
size_t InitsBegin = Names.size();
while (!consumeIf('E')) {
- Node *E = getDerived().parseBracedExpr();
- if (E == nullptr)
+ if (!HaveInits)
return nullptr;
- Names.push_back(E);
+ Node *Init = getDerived().parseExpr();
+ if (Init == nullptr)
+ return Init;
+ Names.push_back(Init);
}
- return make<InitListExpr>(nullptr, popTrailingNodeArray(InitsBegin));
- }
+ NodeArray Inits = popTrailingNodeArray(InitsBegin);
+ return make<NewExpr>(ExprList, Ty, Inits, Global,
+ /*IsArray=*/Op->getFlag(), Op->getPrecedence());
}
- return nullptr;
- case 'l':
- switch (First[1]) {
- case 'e':
- First += 2;
- return getDerived().parseBinaryExpr("<=");
- case 's':
- First += 2;
- return getDerived().parseBinaryExpr("<<");
- case 'S':
- First += 2;
- return getDerived().parseBinaryExpr("<<=");
- case 't':
- First += 2;
- return getDerived().parseBinaryExpr("<");
- }
- return nullptr;
- case 'm':
- switch (First[1]) {
- case 'i':
- First += 2;
- return getDerived().parseBinaryExpr("-");
- case 'I':
- First += 2;
- return getDerived().parseBinaryExpr("-=");
- case 'l':
- First += 2;
- return getDerived().parseBinaryExpr("*");
- case 'L':
- First += 2;
- return getDerived().parseBinaryExpr("*=");
- case 'm':
- First += 2;
- if (consumeIf('_'))
- return getDerived().parsePrefixExpr("--");
+ case OperatorInfo::Del: {
+ // Delete
Node *Ex = getDerived().parseExpr();
if (Ex == nullptr)
return nullptr;
- return make<PostfixExpr>(Ex, "--");
- }
- return nullptr;
- case 'n':
- switch (First[1]) {
- case 'a':
- case 'w':
- return getDerived().parseNewExpr();
- case 'e':
- First += 2;
- return getDerived().parseBinaryExpr("!=");
- case 'g':
- First += 2;
- return getDerived().parsePrefixExpr("-");
- case 't':
- First += 2;
- return getDerived().parsePrefixExpr("!");
- case 'x':
- First += 2;
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return Ex;
- return make<EnclosingExpr>("noexcept (", Ex, ")");
- }
- return nullptr;
- case 'o':
- switch (First[1]) {
- case 'n':
- return getDerived().parseUnresolvedName();
- case 'o':
- First += 2;
- return getDerived().parseBinaryExpr("||");
- case 'r':
- First += 2;
- return getDerived().parseBinaryExpr("|");
- case 'R':
- First += 2;
- return getDerived().parseBinaryExpr("|=");
+ return make<DeleteExpr>(Ex, Global, /*IsArray=*/Op->getFlag(),
+ Op->getPrecedence());
}
- return nullptr;
- case 'p':
- switch (First[1]) {
- case 'm':
- First += 2;
- return getDerived().parseBinaryExpr("->*");
- case 'l':
- First += 2;
- return getDerived().parseBinaryExpr("+");
- case 'L':
- First += 2;
- return getDerived().parseBinaryExpr("+=");
- case 'p': {
- First += 2;
- if (consumeIf('_'))
- return getDerived().parsePrefixExpr("++");
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return Ex;
- return make<PostfixExpr>(Ex, "++");
+ case OperatorInfo::Call: {
+ // Function Call
+ Node *Callee = getDerived().parseExpr();
+ if (Callee == nullptr)
+ return nullptr;
+ size_t ExprsBegin = Names.size();
+ while (!consumeIf('E')) {
+ Node *E = getDerived().parseExpr();
+ if (E == nullptr)
+ return nullptr;
+ Names.push_back(E);
+ }
+ return make<CallExpr>(Callee, popTrailingNodeArray(ExprsBegin),
+ Op->getPrecedence());
}
- case 's':
- First += 2;
- return getDerived().parsePrefixExpr("+");
- case 't': {
- First += 2;
- Node *L = getDerived().parseExpr();
- if (L == nullptr)
+ case OperatorInfo::CCast: {
+ // C Cast: (type)expr
+ Node *Ty;
+ {
+ ScopedOverride<bool> SaveTemp(TryToParseTemplateArgs, false);
+ Ty = getDerived().parseType();
+ }
+ if (Ty == nullptr)
return nullptr;
- Node *R = getDerived().parseExpr();
- if (R == nullptr)
+
+ size_t ExprsBegin = Names.size();
+ bool IsMany = consumeIf('_');
+ while (!consumeIf('E')) {
+ Node *E = getDerived().parseExpr();
+ if (E == nullptr)
+ return E;
+ Names.push_back(E);
+ if (!IsMany)
+ break;
+ }
+ NodeArray Exprs = popTrailingNodeArray(ExprsBegin);
+ if (!IsMany && Exprs.size() != 1)
return nullptr;
- return make<MemberExpr>(L, "->", R);
+ return make<ConversionExpr>(Ty, Exprs, Op->getPrecedence());
}
- }
- return nullptr;
- case 'q':
- if (First[1] == 'u') {
- First += 2;
+ case OperatorInfo::Conditional: {
+ // Conditional operator: expr ? expr : expr
Node *Cond = getDerived().parseExpr();
if (Cond == nullptr)
return nullptr;
@@ -4791,169 +4685,158 @@ Node *AbstractManglingParser<Derived, Alloc>::parseExpr() {
Node *RHS = getDerived().parseExpr();
if (RHS == nullptr)
return nullptr;
- return make<ConditionalExpr>(Cond, LHS, RHS);
- }
- return nullptr;
- case 'r':
- switch (First[1]) {
- case 'c': {
- First += 2;
- Node *T = getDerived().parseType();
- if (T == nullptr)
- return T;
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return Ex;
- return make<CastExpr>("reinterpret_cast", T, Ex);
+ return make<ConditionalExpr>(Cond, LHS, RHS, Op->getPrecedence());
}
- case 'm':
- First += 2;
- return getDerived().parseBinaryExpr("%");
- case 'M':
- First += 2;
- return getDerived().parseBinaryExpr("%=");
- case 's':
- First += 2;
- return getDerived().parseBinaryExpr(">>");
- case 'S':
- First += 2;
- return getDerived().parseBinaryExpr(">>=");
- }
- return nullptr;
- case 's':
- switch (First[1]) {
- case 'c': {
- First += 2;
- Node *T = getDerived().parseType();
- if (T == nullptr)
- return T;
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return Ex;
- return make<CastExpr>("static_cast", T, Ex);
- }
- case 'p': {
- First += 2;
- Node *Child = getDerived().parseExpr();
- if (Child == nullptr)
- return nullptr;
- return make<ParameterPackExpansion>(Child);
- }
- case 'r':
- return getDerived().parseUnresolvedName();
- case 't': {
- First += 2;
+ case OperatorInfo::NamedCast: {
+ // Named cast operation, @<type>(expr)
Node *Ty = getDerived().parseType();
if (Ty == nullptr)
- return Ty;
- return make<EnclosingExpr>("sizeof (", Ty, ")");
- }
- case 'z': {
- First += 2;
+ return nullptr;
Node *Ex = getDerived().parseExpr();
if (Ex == nullptr)
- return Ex;
- return make<EnclosingExpr>("sizeof (", Ex, ")");
- }
- case 'Z':
- First += 2;
- if (look() == 'T') {
- Node *R = getDerived().parseTemplateParam();
- if (R == nullptr)
- return nullptr;
- return make<SizeofParamPackExpr>(R);
- } else if (look() == 'f') {
- Node *FP = getDerived().parseFunctionParam();
- if (FP == nullptr)
- return nullptr;
- return make<EnclosingExpr>("sizeof... (", FP, ")");
- }
- return nullptr;
- case 'P': {
- First += 2;
- size_t ArgsBegin = Names.size();
- while (!consumeIf('E')) {
- Node *Arg = getDerived().parseTemplateArg();
- if (Arg == nullptr)
- return nullptr;
- Names.push_back(Arg);
- }
- auto *Pack = make<NodeArrayNode>(popTrailingNodeArray(ArgsBegin));
- if (!Pack)
return nullptr;
- return make<EnclosingExpr>("sizeof... (", Pack, ")");
+ return make<CastExpr>(Sym, Ty, Ex, Op->getPrecedence());
}
+ case OperatorInfo::OfIdOp: {
+ // [sizeof/alignof/typeid] ( <type>|<expr> )
+ Node *Arg =
+ Op->getFlag() ? getDerived().parseType() : getDerived().parseExpr();
+ if (!Arg)
+ return nullptr;
+ return make<EnclosingExpr>(Sym, Arg, Op->getPrecedence());
}
- return nullptr;
- case 't':
- switch (First[1]) {
- case 'e': {
- First += 2;
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
- return Ex;
- return make<EnclosingExpr>("typeid (", Ex, ")");
+ case OperatorInfo::NameOnly: {
+ // Not valid as an expression operand.
+ return nullptr;
}
- case 'i': {
- First += 2;
- Node *Ty = getDerived().parseType();
- if (Ty == nullptr)
- return Ty;
- return make<EnclosingExpr>("typeid (", Ty, ")");
}
- case 'l': {
- First += 2;
- Node *Ty = getDerived().parseType();
- if (Ty == nullptr)
+ DEMANGLE_UNREACHABLE;
+ }
+
+ if (numLeft() < 2)
+ return nullptr;
+
+ if (look() == 'L')
+ return getDerived().parseExprPrimary();
+ if (look() == 'T')
+ return getDerived().parseTemplateParam();
+ if (look() == 'f') {
+ // Disambiguate a fold expression from a <function-param>.
+ if (look(1) == 'p' || (look(1) == 'L' && std::isdigit(look(2))))
+ return getDerived().parseFunctionParam();
+ return getDerived().parseFoldExpr();
+ }
+ if (consumeIf("il")) {
+ size_t InitsBegin = Names.size();
+ while (!consumeIf('E')) {
+ Node *E = getDerived().parseBracedExpr();
+ if (E == nullptr)
return nullptr;
- size_t InitsBegin = Names.size();
- while (!consumeIf('E')) {
- Node *E = getDerived().parseBracedExpr();
- if (E == nullptr)
- return nullptr;
- Names.push_back(E);
- }
- return make<InitListExpr>(Ty, popTrailingNodeArray(InitsBegin));
+ Names.push_back(E);
}
- case 'r':
- First += 2;
- return make<NameType>("throw");
- case 'w': {
- First += 2;
- Node *Ex = getDerived().parseExpr();
- if (Ex == nullptr)
+ return make<InitListExpr>(nullptr, popTrailingNodeArray(InitsBegin));
+ }
+ if (consumeIf("mc"))
+ return parsePointerToMemberConversionExpr(Node::Prec::Unary);
+ if (consumeIf("nx")) {
+ Node *Ex = getDerived().parseExpr();
+ if (Ex == nullptr)
+ return Ex;
+ return make<EnclosingExpr>("noexcept ", Ex, Node::Prec::Unary);
+ }
+ if (consumeIf("so"))
+ return parseSubobjectExpr();
+ if (consumeIf("sp")) {
+ Node *Child = getDerived().parseExpr();
+ if (Child == nullptr)
+ return nullptr;
+ return make<ParameterPackExpansion>(Child);
+ }
+ if (consumeIf("sZ")) {
+ if (look() == 'T') {
+ Node *R = getDerived().parseTemplateParam();
+ if (R == nullptr)
return nullptr;
- return make<ThrowExpr>(Ex);
+ return make<SizeofParamPackExpr>(R);
}
+ Node *FP = getDerived().parseFunctionParam();
+ if (FP == nullptr)
+ return nullptr;
+ return make<EnclosingExpr>("sizeof... ", FP);
+ }
+ if (consumeIf("sP")) {
+ size_t ArgsBegin = Names.size();
+ while (!consumeIf('E')) {
+ Node *Arg = getDerived().parseTemplateArg();
+ if (Arg == nullptr)
+ return nullptr;
+ Names.push_back(Arg);
}
- return nullptr;
- case '1':
- case '2':
- case '3':
- case '4':
- case '5':
- case '6':
- case '7':
- case '8':
- case '9':
- return getDerived().parseUnresolvedName();
- }
-
- if (consumeIf("u8__uuidoft")) {
+ auto *Pack = make<NodeArrayNode>(popTrailingNodeArray(ArgsBegin));
+ if (!Pack)
+ return nullptr;
+ return make<EnclosingExpr>("sizeof... ", Pack);
+ }
+ if (consumeIf("tl")) {
Node *Ty = getDerived().parseType();
- if (!Ty)
+ if (Ty == nullptr)
return nullptr;
- return make<UUIDOfExpr>(Ty);
+ size_t InitsBegin = Names.size();
+ while (!consumeIf('E')) {
+ Node *E = getDerived().parseBracedExpr();
+ if (E == nullptr)
+ return nullptr;
+ Names.push_back(E);
+ }
+ return make<InitListExpr>(Ty, popTrailingNodeArray(InitsBegin));
}
-
- if (consumeIf("u8__uuidofz")) {
+ if (consumeIf("tr"))
+ return make<NameType>("throw");
+ if (consumeIf("tw")) {
Node *Ex = getDerived().parseExpr();
- if (!Ex)
+ if (Ex == nullptr)
return nullptr;
- return make<UUIDOfExpr>(Ex);
+ return make<ThrowExpr>(Ex);
+ }
+ if (consumeIf('u')) {
+ Node *Name = getDerived().parseSourceName(/*NameState=*/nullptr);
+ if (!Name)
+ return nullptr;
+ // Special case legacy __uuidof mangling. The 't' and 'z' appear where the
+ // standard encoding expects a <template-arg>, and would be otherwise be
+ // interpreted as <type> node 'short' or 'ellipsis'. However, neither
+ // __uuidof(short) nor __uuidof(...) can actually appear, so there is no
+ // actual conflict here.
+ bool IsUUID = false;
+ Node *UUID = nullptr;
+ if (Name->getBaseName() == "__uuidof") {
+ if (consumeIf('t')) {
+ UUID = getDerived().parseType();
+ IsUUID = true;
+ } else if (consumeIf('z')) {
+ UUID = getDerived().parseExpr();
+ IsUUID = true;
+ }
+ }
+ size_t ExprsBegin = Names.size();
+ if (IsUUID) {
+ if (UUID == nullptr)
+ return nullptr;
+ Names.push_back(UUID);
+ } else {
+ while (!consumeIf('E')) {
+ Node *E = getDerived().parseTemplateArg();
+ if (E == nullptr)
+ return E;
+ Names.push_back(E);
+ }
+ }
+ return make<CallExpr>(Name, popTrailingNodeArray(ExprsBegin),
+ Node::Prec::Postfix);
}
- return nullptr;
+ // Only unresolved names remain.
+ return getDerived().parseUnresolvedName(Global);
}
// <call-offset> ::= h <nv-offset> _
@@ -4986,19 +4869,32 @@ bool AbstractManglingParser<Alloc, Derived>::parseCallOffset() {
// # second call-offset is result adjustment
// ::= T <call-offset> <base encoding>
// # base is the nominal target function of thunk
-// ::= GV <object name> # Guard variable for one-time initialization
+// # Guard variable for one-time initialization
+// ::= GV <object name>
// # No <type>
// ::= TW <object name> # Thread-local wrapper
// ::= TH <object name> # Thread-local initialization
// ::= GR <object name> _ # First temporary
// ::= GR <object name> <seq-id> _ # Subsequent temporaries
-// extension ::= TC <first type> <number> _ <second type> # construction vtable for second-in-first
+// # construction vtable for second-in-first
+// extension ::= TC <first type> <number> _ <second type>
// extension ::= GR <object name> # reference temporary for object
+// extension ::= GI <module name> # module global initializer
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseSpecialName() {
switch (look()) {
case 'T':
switch (look(1)) {
+ // TA <template-arg> # template parameter object
+ //
+ // Not yet in the spec: https://github.com/itanium-cxx-abi/cxx-abi/issues/63
+ case 'A': {
+ First += 2;
+ Node *Arg = getDerived().parseTemplateArg();
+ if (Arg == nullptr)
+ return nullptr;
+ return make<SpecialName>("template parameter object for ", Arg);
+ }
// TV <type> # virtual table
case 'V': {
First += 2;
@@ -5110,6 +5006,16 @@ Node *AbstractManglingParser<Derived, Alloc>::parseSpecialName() {
return nullptr;
return make<SpecialName>("reference temporary for ", Name);
}
+ // GI <module-name> v
+ case 'I': {
+ First += 2;
+ ModuleName *Module = nullptr;
+ if (getDerived().parseModuleNameOpt(Module))
+ return nullptr;
+ if (Module == nullptr)
+ return nullptr;
+ return make<SpecialName>("initializer for module ", Module);
+ }
}
}
return nullptr;
@@ -5120,6 +5026,26 @@ Node *AbstractManglingParser<Derived, Alloc>::parseSpecialName() {
// ::= <special-name>
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseEncoding() {
+ // The template parameters of an encoding are unrelated to those of the
+ // enclosing context.
+ class SaveTemplateParams {
+ AbstractManglingParser *Parser;
+ decltype(TemplateParams) OldParams;
+ decltype(OuterTemplateParams) OldOuterParams;
+
+ public:
+ SaveTemplateParams(AbstractManglingParser *TheParser) : Parser(TheParser) {
+ OldParams = std::move(Parser->TemplateParams);
+ OldOuterParams = std::move(Parser->OuterTemplateParams);
+ Parser->TemplateParams.clear();
+ Parser->OuterTemplateParams.clear();
+ }
+ ~SaveTemplateParams() {
+ Parser->TemplateParams = std::move(OldParams);
+ Parser->OuterTemplateParams = std::move(OldOuterParams);
+ }
+ } SaveTemplateParams(this);
+
if (look() == 'G' || look() == 'T')
return getDerived().parseSpecialName();
@@ -5204,14 +5130,19 @@ template <>
struct FloatData<long double>
{
#if defined(__mips__) && defined(__mips_n64) || defined(__aarch64__) || \
- defined(__wasm__)
+ defined(__wasm__) || defined(__riscv) || defined(__loongarch__)
static const size_t mangled_size = 32;
#elif defined(__arm__) || defined(__mips__) || defined(__hexagon__)
static const size_t mangled_size = 16;
#else
static const size_t mangled_size = 20; // May need to be adjusted to 16 or 24 on other platforms
#endif
- static const size_t max_demangled_size = 40;
+ // `-0x1.ffffffffffffffffffffffffffffp+16383` + 'L' + '\0' == 42 bytes.
+ // 28 'f's * 4 bits == 112 bits, which is the number of mantissa bits.
+ // Negatives are one character longer than positives.
+ // `0x1.` and `p` are constant, and exponents `+16383` and `-16382` are the
+ // same length. 1 sign bit, 112 mantissa bits, and 15 exponent bits == 128.
+ static const size_t max_demangled_size = 42;
static constexpr const char *spec = "%LaL";
};
@@ -5221,7 +5152,7 @@ Node *AbstractManglingParser<Alloc, Derived>::parseFloatingLiteral() {
const size_t N = FloatData<Float>::mangled_size;
if (numLeft() <= N)
return nullptr;
- StringView Data(First, First + N);
+ std::string_view Data(First, N);
for (char C : Data)
if (!std::isxdigit(C))
return nullptr;
@@ -5264,43 +5195,41 @@ bool AbstractManglingParser<Alloc, Derived>::parseSeqId(size_t *Out) {
// <substitution> ::= Si # ::std::basic_istream<char, std::char_traits<char> >
// <substitution> ::= So # ::std::basic_ostream<char, std::char_traits<char> >
// <substitution> ::= Sd # ::std::basic_iostream<char, std::char_traits<char> >
+// The St case is handled specially in parseNestedName.
template <typename Derived, typename Alloc>
Node *AbstractManglingParser<Derived, Alloc>::parseSubstitution() {
if (!consumeIf('S'))
return nullptr;
- if (std::islower(look())) {
- Node *SpecialSub;
+ if (look() >= 'a' && look() <= 'z') {
+ SpecialSubKind Kind;
switch (look()) {
case 'a':
- ++First;
- SpecialSub = make<SpecialSubstitution>(SpecialSubKind::allocator);
+ Kind = SpecialSubKind::allocator;
break;
case 'b':
- ++First;
- SpecialSub = make<SpecialSubstitution>(SpecialSubKind::basic_string);
+ Kind = SpecialSubKind::basic_string;
break;
- case 's':
- ++First;
- SpecialSub = make<SpecialSubstitution>(SpecialSubKind::string);
+ case 'd':
+ Kind = SpecialSubKind::iostream;
break;
case 'i':
- ++First;
- SpecialSub = make<SpecialSubstitution>(SpecialSubKind::istream);
+ Kind = SpecialSubKind::istream;
break;
case 'o':
- ++First;
- SpecialSub = make<SpecialSubstitution>(SpecialSubKind::ostream);
+ Kind = SpecialSubKind::ostream;
break;
- case 'd':
- ++First;
- SpecialSub = make<SpecialSubstitution>(SpecialSubKind::iostream);
+ case 's':
+ Kind = SpecialSubKind::string;
break;
default:
return nullptr;
}
+ ++First;
+ auto *SpecialSub = make<SpecialSubstitution>(Kind);
if (!SpecialSub)
return nullptr;
+
// Itanium C++ ABI 5.1.2: If a name that would use a built-in <substitution>
// has ABI tags, the tags are appended to the substitution; the result is a
// substitutable component.
@@ -5543,7 +5472,8 @@ Node *AbstractManglingParser<Derived, Alloc>::parse() {
if (Encoding == nullptr)
return nullptr;
if (look() == '.') {
- Encoding = make<DotSuffix>(Encoding, StringView(First, Last));
+ Encoding =
+ make<DotSuffix>(Encoding, std::string_view(First, Last - First));
First = Last;
}
if (numLeft() != 0)
diff --git a/externals/demangle/llvm/Demangle/ItaniumNodes.def b/externals/demangle/llvm/Demangle/ItaniumNodes.def
new file mode 100644
index 000000000..5985769ef
--- /dev/null
+++ b/externals/demangle/llvm/Demangle/ItaniumNodes.def
@@ -0,0 +1,96 @@
+//===--- ItaniumNodes.def ------------*- mode:c++;eval:(read-only-mode) -*-===//
+// Do not edit! See README.txt.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-FileCopyrightText: Part of the LLVM Project
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Define the demangler's node names
+
+#ifndef NODE
+#error Define NODE to handle nodes
+#endif
+
+NODE(NodeArrayNode)
+NODE(DotSuffix)
+NODE(VendorExtQualType)
+NODE(QualType)
+NODE(ConversionOperatorType)
+NODE(PostfixQualifiedType)
+NODE(ElaboratedTypeSpefType)
+NODE(NameType)
+NODE(AbiTagAttr)
+NODE(EnableIfAttr)
+NODE(ObjCProtoName)
+NODE(PointerType)
+NODE(ReferenceType)
+NODE(PointerToMemberType)
+NODE(ArrayType)
+NODE(FunctionType)
+NODE(NoexceptSpec)
+NODE(DynamicExceptionSpec)
+NODE(FunctionEncoding)
+NODE(LiteralOperator)
+NODE(SpecialName)
+NODE(CtorVtableSpecialName)
+NODE(QualifiedName)
+NODE(NestedName)
+NODE(LocalName)
+NODE(ModuleName)
+NODE(ModuleEntity)
+NODE(VectorType)
+NODE(PixelVectorType)
+NODE(BinaryFPType)
+NODE(BitIntType)
+NODE(SyntheticTemplateParamName)
+NODE(TypeTemplateParamDecl)
+NODE(NonTypeTemplateParamDecl)
+NODE(TemplateTemplateParamDecl)
+NODE(TemplateParamPackDecl)
+NODE(ParameterPack)
+NODE(TemplateArgumentPack)
+NODE(ParameterPackExpansion)
+NODE(TemplateArgs)
+NODE(ForwardTemplateReference)
+NODE(NameWithTemplateArgs)
+NODE(GlobalQualifiedName)
+NODE(ExpandedSpecialSubstitution)
+NODE(SpecialSubstitution)
+NODE(CtorDtorName)
+NODE(DtorName)
+NODE(UnnamedTypeName)
+NODE(ClosureTypeName)
+NODE(StructuredBindingName)
+NODE(BinaryExpr)
+NODE(ArraySubscriptExpr)
+NODE(PostfixExpr)
+NODE(ConditionalExpr)
+NODE(MemberExpr)
+NODE(SubobjectExpr)
+NODE(EnclosingExpr)
+NODE(CastExpr)
+NODE(SizeofParamPackExpr)
+NODE(CallExpr)
+NODE(NewExpr)
+NODE(DeleteExpr)
+NODE(PrefixExpr)
+NODE(FunctionParam)
+NODE(ConversionExpr)
+NODE(PointerToMemberConversionExpr)
+NODE(InitListExpr)
+NODE(FoldExpr)
+NODE(ThrowExpr)
+NODE(BoolExpr)
+NODE(StringLiteral)
+NODE(LambdaExpr)
+NODE(EnumLiteral)
+NODE(IntegerLiteral)
+NODE(FloatLiteral)
+NODE(DoubleLiteral)
+NODE(LongDoubleLiteral)
+NODE(BracedExpr)
+NODE(BracedRangeExpr)
+
+#undef NODE
diff --git a/externals/demangle/llvm/Demangle/StringView.h b/externals/demangle/llvm/Demangle/StringView.h
index 44d2b18a3..76b215252 100644
--- a/externals/demangle/llvm/Demangle/StringView.h
+++ b/externals/demangle/llvm/Demangle/StringView.h
@@ -1,5 +1,5 @@
-//===--- StringView.h -------------------------------------------*- C++ -*-===//
-//
+//===--- StringView.h ----------------*- mode:c++;eval:(read-only-mode) -*-===//
+// Do not edit! See README.txt.
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-FileCopyrightText: Part of the LLVM Project
@@ -8,6 +8,9 @@
//===----------------------------------------------------------------------===//
//
// FIXME: Use std::string_view instead when we support C++17.
+// There are two copies of this file in the source tree. The one under
+// libcxxabi is the original and the one under llvm is the copy. Use
+// cp-to-llvm.sh to update the copy. See README.txt for more details.
//
//===----------------------------------------------------------------------===//
@@ -15,7 +18,6 @@
#define DEMANGLE_STRINGVIEW_H
#include "DemangleConfig.h"
-#include <algorithm>
#include <cassert>
#include <cstring>
@@ -37,29 +39,23 @@ public:
StringView(const char *Str) : First(Str), Last(Str + std::strlen(Str)) {}
StringView() : First(nullptr), Last(nullptr) {}
- StringView substr(size_t From) const {
- return StringView(begin() + From, size() - From);
+ StringView substr(size_t Pos, size_t Len = npos) const {
+ assert(Pos <= size());
+ if (Len > size() - Pos)
+ Len = size() - Pos;
+ return StringView(begin() + Pos, Len);
}
size_t find(char C, size_t From = 0) const {
- size_t FindBegin = std::min(From, size());
// Avoid calling memchr with nullptr.
- if (FindBegin < size()) {
+ if (From < size()) {
// Just forward to memchr, which is faster than a hand-rolled loop.
- if (const void *P = ::memchr(First + FindBegin, C, size() - FindBegin))
+ if (const void *P = ::memchr(First + From, C, size() - From))
return size_t(static_cast<const char *>(P) - First);
}
return npos;
}
- StringView substr(size_t From, size_t To) const {
- if (To >= size())
- To = size() - 1;
- if (From >= size())
- From = size() - 1;
- return StringView(First + From, First + To);
- }
-
StringView dropFront(size_t N = 1) const {
if (N >= size())
N = size();
@@ -106,7 +102,7 @@ public:
bool startsWith(StringView Str) const {
if (Str.size() > size())
return false;
- return std::equal(Str.begin(), Str.end(), begin());
+ return std::strncmp(Str.begin(), begin(), Str.size()) == 0;
}
const char &operator[](size_t Idx) const { return *(begin() + Idx); }
@@ -119,7 +115,7 @@ public:
inline bool operator==(const StringView &LHS, const StringView &RHS) {
return LHS.size() == RHS.size() &&
- std::equal(LHS.begin(), LHS.end(), RHS.begin());
+ std::strncmp(LHS.begin(), RHS.begin(), LHS.size()) == 0;
}
DEMANGLE_NAMESPACE_END
diff --git a/externals/demangle/llvm/Demangle/StringViewExtras.h b/externals/demangle/llvm/Demangle/StringViewExtras.h
new file mode 100644
index 000000000..83685c304
--- /dev/null
+++ b/externals/demangle/llvm/Demangle/StringViewExtras.h
@@ -0,0 +1,39 @@
+//===--- StringViewExtras.h ----------*- mode:c++;eval:(read-only-mode) -*-===//
+// Do not edit! See README.txt.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-FileCopyrightText: Part of the LLVM Project
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// There are two copies of this file in the source tree. The one under
+// libcxxabi is the original and the one under llvm is the copy. Use
+// cp-to-llvm.sh to update the copy. See README.txt for more details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DEMANGLE_STRINGVIEW_H
+#define DEMANGLE_STRINGVIEW_H
+
+#include "DemangleConfig.h"
+
+#include <string_view>
+
+DEMANGLE_NAMESPACE_BEGIN
+
+inline bool starts_with(std::string_view self, char C) noexcept {
+ return !self.empty() && *self.begin() == C;
+}
+
+inline bool starts_with(std::string_view haystack,
+ std::string_view needle) noexcept {
+ if (needle.size() > haystack.size())
+ return false;
+ haystack.remove_suffix(haystack.size() - needle.size());
+ return haystack == needle;
+}
+
+DEMANGLE_NAMESPACE_END
+
+#endif
diff --git a/externals/demangle/llvm/Demangle/Utility.h b/externals/demangle/llvm/Demangle/Utility.h
index 50d05c6b1..30dfbfc8d 100644
--- a/externals/demangle/llvm/Demangle/Utility.h
+++ b/externals/demangle/llvm/Demangle/Utility.h
@@ -1,5 +1,5 @@
-//===--- Utility.h ----------------------------------------------*- C++ -*-===//
-//
+//===--- Utility.h -------------------*- mode:c++;eval:(read-only-mode) -*-===//
+// Do not edit! See README.txt.
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-FileCopyrightText: Part of the LLVM Project
@@ -7,70 +7,83 @@
//
//===----------------------------------------------------------------------===//
//
-// Provide some utility classes for use in the demangler(s).
+// Provide some utility classes for use in the demangler.
+// There are two copies of this file in the source tree. The one in libcxxabi
+// is the original and the one in llvm is the copy. Use cp-to-llvm.sh to update
+// the copy. See README.txt for more details.
//
//===----------------------------------------------------------------------===//
#ifndef DEMANGLE_UTILITY_H
#define DEMANGLE_UTILITY_H
-#include "StringView.h"
+#include "DemangleConfig.h"
+
+#include <array>
+#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
-#include <iterator>
+#include <exception>
#include <limits>
+#include <string_view>
DEMANGLE_NAMESPACE_BEGIN
// Stream that AST nodes write their string representation into after the AST
// has been parsed.
-class OutputStream {
- char *Buffer;
- size_t CurrentPosition;
- size_t BufferCapacity;
+class OutputBuffer {
+ char *Buffer = nullptr;
+ size_t CurrentPosition = 0;
+ size_t BufferCapacity = 0;
- // Ensure there is at least n more positions in buffer.
+ // Ensure there are at least N more positions in the buffer.
void grow(size_t N) {
- if (N + CurrentPosition >= BufferCapacity) {
+ size_t Need = N + CurrentPosition;
+ if (Need > BufferCapacity) {
+ // Reduce the number of reallocations, with a bit of hysteresis. The
+ // number here is chosen so the first allocation will more-than-likely not
+ // allocate more than 1K.
+ Need += 1024 - 32;
BufferCapacity *= 2;
- if (BufferCapacity < N + CurrentPosition)
- BufferCapacity = N + CurrentPosition;
+ if (BufferCapacity < Need)
+ BufferCapacity = Need;
Buffer = static_cast<char *>(std::realloc(Buffer, BufferCapacity));
if (Buffer == nullptr)
std::terminate();
}
}
- void writeUnsigned(uint64_t N, bool isNeg = false) {
- // Handle special case...
- if (N == 0) {
- *this << '0';
- return;
- }
-
- char Temp[21];
- char *TempPtr = std::end(Temp);
+ OutputBuffer &writeUnsigned(uint64_t N, bool isNeg = false) {
+ std::array<char, 21> Temp;
+ char *TempPtr = Temp.data() + Temp.size();
- while (N) {
- *--TempPtr = '0' + char(N % 10);
+ // Output at least one character.
+ do {
+ *--TempPtr = char('0' + N % 10);
N /= 10;
- }
+ } while (N);
- // Add negative sign...
+ // Add negative sign.
if (isNeg)
*--TempPtr = '-';
- this->operator<<(StringView(TempPtr, std::end(Temp)));
+
+ return operator+=(
+ std::string_view(TempPtr, Temp.data() + Temp.size() - TempPtr));
}
public:
- OutputStream(char *StartBuf, size_t Size)
- : Buffer(StartBuf), CurrentPosition(0), BufferCapacity(Size) {}
- OutputStream() = default;
- void reset(char *Buffer_, size_t BufferCapacity_) {
- CurrentPosition = 0;
- Buffer = Buffer_;
- BufferCapacity = BufferCapacity_;
+ OutputBuffer(char *StartBuf, size_t Size)
+ : Buffer(StartBuf), BufferCapacity(Size) {}
+ OutputBuffer(char *StartBuf, size_t *SizePtr)
+ : OutputBuffer(StartBuf, StartBuf ? *SizePtr : 0) {}
+ OutputBuffer() = default;
+ // Non-copyable
+ OutputBuffer(const OutputBuffer &) = delete;
+ OutputBuffer &operator=(const OutputBuffer &) = delete;
+
+ operator std::string_view() const {
+ return std::string_view(Buffer, CurrentPosition);
}
/// If a ParameterPackExpansion (or similar type) is encountered, the offset
@@ -78,115 +91,116 @@ public:
unsigned CurrentPackIndex = std::numeric_limits<unsigned>::max();
unsigned CurrentPackMax = std::numeric_limits<unsigned>::max();
- OutputStream &operator+=(StringView R) {
- size_t Size = R.size();
- if (Size == 0)
- return *this;
- grow(Size);
- std::memmove(Buffer + CurrentPosition, R.begin(), Size);
- CurrentPosition += Size;
+ /// When zero, we're printing template args and '>' needs to be parenthesized.
+ /// Use a counter so we can simply increment inside parentheses.
+ unsigned GtIsGt = 1;
+
+ bool isGtInsideTemplateArgs() const { return GtIsGt == 0; }
+
+ void printOpen(char Open = '(') {
+ GtIsGt++;
+ *this += Open;
+ }
+ void printClose(char Close = ')') {
+ GtIsGt--;
+ *this += Close;
+ }
+
+ OutputBuffer &operator+=(std::string_view R) {
+ if (size_t Size = R.size()) {
+ grow(Size);
+ std::memcpy(Buffer + CurrentPosition, &*R.begin(), Size);
+ CurrentPosition += Size;
+ }
return *this;
}
- OutputStream &operator+=(char C) {
+ OutputBuffer &operator+=(char C) {
grow(1);
Buffer[CurrentPosition++] = C;
return *this;
}
- OutputStream &operator<<(StringView R) { return (*this += R); }
+ OutputBuffer &prepend(std::string_view R) {
+ size_t Size = R.size();
- OutputStream &operator<<(char C) { return (*this += C); }
+ grow(Size);
+ std::memmove(Buffer + Size, Buffer, CurrentPosition);
+ std::memcpy(Buffer, &*R.begin(), Size);
+ CurrentPosition += Size;
- OutputStream &operator<<(long long N) {
- if (N < 0)
- writeUnsigned(static_cast<unsigned long long>(-N), true);
- else
- writeUnsigned(static_cast<unsigned long long>(N));
return *this;
}
- OutputStream &operator<<(unsigned long long N) {
- writeUnsigned(N, false);
- return *this;
+ OutputBuffer &operator<<(std::string_view R) { return (*this += R); }
+
+ OutputBuffer &operator<<(char C) { return (*this += C); }
+
+ OutputBuffer &operator<<(long long N) {
+ return writeUnsigned(static_cast<unsigned long long>(std::abs(N)), N < 0);
}
- OutputStream &operator<<(long N) {
+ OutputBuffer &operator<<(unsigned long long N) {
+ return writeUnsigned(N, false);
+ }
+
+ OutputBuffer &operator<<(long N) {
return this->operator<<(static_cast<long long>(N));
}
- OutputStream &operator<<(unsigned long N) {
+ OutputBuffer &operator<<(unsigned long N) {
return this->operator<<(static_cast<unsigned long long>(N));
}
- OutputStream &operator<<(int N) {
+ OutputBuffer &operator<<(int N) {
return this->operator<<(static_cast<long long>(N));
}
- OutputStream &operator<<(unsigned int N) {
+ OutputBuffer &operator<<(unsigned int N) {
return this->operator<<(static_cast<unsigned long long>(N));
}
+ void insert(size_t Pos, const char *S, size_t N) {
+ assert(Pos <= CurrentPosition);
+ if (N == 0)
+ return;
+ grow(N);
+ std::memmove(Buffer + Pos + N, Buffer + Pos, CurrentPosition - Pos);
+ std::memcpy(Buffer + Pos, S, N);
+ CurrentPosition += N;
+ }
+
size_t getCurrentPosition() const { return CurrentPosition; }
void setCurrentPosition(size_t NewPos) { CurrentPosition = NewPos; }
char back() const {
- return CurrentPosition ? Buffer[CurrentPosition - 1] : '\0';
+ assert(CurrentPosition);
+ return Buffer[CurrentPosition - 1];
}
bool empty() const { return CurrentPosition == 0; }
char *getBuffer() { return Buffer; }
char *getBufferEnd() { return Buffer + CurrentPosition - 1; }
- size_t getBufferCapacity() { return BufferCapacity; }
+ size_t getBufferCapacity() const { return BufferCapacity; }
};
-template <class T> class SwapAndRestore {
- T &Restore;
- T OriginalValue;
- bool ShouldRestore = true;
+template <class T> class ScopedOverride {
+ T &Loc;
+ T Original;
public:
- SwapAndRestore(T &Restore_) : SwapAndRestore(Restore_, Restore_) {}
-
- SwapAndRestore(T &Restore_, T NewVal)
- : Restore(Restore_), OriginalValue(Restore) {
- Restore = std::move(NewVal);
- }
- ~SwapAndRestore() {
- if (ShouldRestore)
- Restore = std::move(OriginalValue);
- }
+ ScopedOverride(T &Loc_) : ScopedOverride(Loc_, Loc_) {}
- void shouldRestore(bool ShouldRestore_) { ShouldRestore = ShouldRestore_; }
-
- void restoreNow(bool Force) {
- if (!Force && !ShouldRestore)
- return;
-
- Restore = std::move(OriginalValue);
- ShouldRestore = false;
+ ScopedOverride(T &Loc_, T NewVal) : Loc(Loc_), Original(Loc_) {
+ Loc_ = std::move(NewVal);
}
+ ~ScopedOverride() { Loc = std::move(Original); }
- SwapAndRestore(const SwapAndRestore &) = delete;
- SwapAndRestore &operator=(const SwapAndRestore &) = delete;
+ ScopedOverride(const ScopedOverride &) = delete;
+ ScopedOverride &operator=(const ScopedOverride &) = delete;
};
-inline bool initializeOutputStream(char *Buf, size_t *N, OutputStream &S,
- size_t InitSize) {
- size_t BufferSize;
- if (Buf == nullptr) {
- Buf = static_cast<char *>(std::malloc(InitSize));
- if (Buf == nullptr)
- return false;
- BufferSize = InitSize;
- } else
- BufferSize = *N;
-
- S.reset(Buf, BufferSize);
- return true;
-}
-
DEMANGLE_NAMESPACE_END
#endif
diff --git a/externals/vma/VulkanMemoryAllocator b/externals/vma/VulkanMemoryAllocator
deleted file mode 160000
-Subproject 0aa3989b8f382f185fdf646cc83a1d16fa31d6a
diff --git a/src/android/app/src/main/res/values/strings.xml b/src/android/app/src/main/res/values/strings.xml
index b963f0119..bfdebd35b 100644
--- a/src/android/app/src/main/res/values/strings.xml
+++ b/src/android/app/src/main/res/values/strings.xml
@@ -232,7 +232,7 @@
<!-- ROM loading errors -->
<string name="loader_error_encrypted">Your ROM is encrypted</string>
- <string name="loader_error_encrypted_roms_description"><![CDATA[Please follow the guides to redump your <a href="https://yuzu-emu.org/help/quickstart/#dumping-cartridge-games">game cartidges</a> or <a href="https://yuzu-emu.org/help/quickstart/#dumping-installed-titles-eshop">installed titles</a>.]]></string>
+ <string name="loader_error_encrypted_roms_description"><![CDATA[Please follow the guides to redump your <a href="https://yuzu-emu.org/help/quickstart/#dumping-physical-titles-game-cards">game cartidges</a> or <a href="https://yuzu-emu.org/help/quickstart/#dumping-digital-titles-eshop">installed titles</a>.]]></string>
<string name="loader_error_encrypted_keys_description"><![CDATA[Please ensure your <a href="https://yuzu-emu.org/help/quickstart/#dumping-prodkeys-and-titlekeys">prod.keys</a> file is installed so that games can be decrypted.]]></string>
<string name="loader_error_video_core">An error occurred initializing the video core</string>
<string name="loader_error_video_core_description">This is usually caused by an incompatible GPU driver. Installing a custom GPU driver may resolve this problem.</string>
diff --git a/src/common/demangle.cpp b/src/common/demangle.cpp
index 3310faf86..6e117cb41 100644
--- a/src/common/demangle.cpp
+++ b/src/common/demangle.cpp
@@ -23,7 +23,7 @@ std::string DemangleSymbol(const std::string& mangled) {
SCOPE_EXIT({ std::free(demangled); });
if (is_itanium(mangled)) {
- demangled = llvm::itaniumDemangle(mangled.c_str(), nullptr, nullptr, nullptr);
+ demangled = llvm::itaniumDemangle(mangled.c_str());
}
if (!demangled) {
diff --git a/src/common/detached_tasks.cpp b/src/common/detached_tasks.cpp
index da64848da..f2ed795cc 100644
--- a/src/common/detached_tasks.cpp
+++ b/src/common/detached_tasks.cpp
@@ -30,8 +30,8 @@ DetachedTasks::~DetachedTasks() {
void DetachedTasks::AddTask(std::function<void()> task) {
std::unique_lock lock{instance->mutex};
++instance->count;
- std::thread([task{std::move(task)}]() {
- task();
+ std::thread([task_{std::move(task)}]() {
+ task_();
std::unique_lock thread_lock{instance->mutex};
--instance->count;
std::notify_all_at_thread_exit(instance->cv, std::move(thread_lock));
diff --git a/src/common/socket_types.h b/src/common/socket_types.h
index 0a801a443..63824a5c4 100644
--- a/src/common/socket_types.h
+++ b/src/common/socket_types.h
@@ -3,17 +3,22 @@
#pragma once
+#include <optional>
+#include <string>
+
#include "common/common_types.h"
namespace Network {
/// Address families
enum class Domain : u8 {
- INET, ///< Address family for IPv4
+ Unspecified, ///< Represents 0, used in getaddrinfo hints
+ INET, ///< Address family for IPv4
};
/// Socket types
enum class Type {
+ Unspecified, ///< Represents 0, used in getaddrinfo hints
STREAM,
DGRAM,
RAW,
@@ -22,6 +27,7 @@ enum class Type {
/// Protocol values for sockets
enum class Protocol : u8 {
+ Unspecified, ///< Represents 0, usable in various places
ICMP,
TCP,
UDP,
@@ -48,4 +54,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2;
constexpr u32 FLAG_MSG_DONTWAIT = 0x80;
constexpr u32 FLAG_O_NONBLOCK = 0x800;
+/// Cross-platform addrinfo structure
+struct AddrInfo {
+ Domain family;
+ Type socket_type;
+ Protocol protocol;
+ SockAddrIn addr;
+ std::optional<std::string> canon_name;
+};
+
} // namespace Network
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index 28cb6f86f..4b7395be8 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -723,6 +723,7 @@ add_library(core STATIC
hle/service/spl/spl_types.h
hle/service/ssl/ssl.cpp
hle/service/ssl/ssl.h
+ hle/service/ssl/ssl_backend.h
hle/service/time/clock_types.h
hle/service/time/ephemeral_network_system_clock_context_writer.h
hle/service/time/ephemeral_network_system_clock_core.h
@@ -864,6 +865,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64)
target_link_libraries(core PRIVATE dynarmic::dynarmic)
endif()
+if(ENABLE_OPENSSL)
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_openssl.cpp)
+ target_link_libraries(core PRIVATE OpenSSL::SSL)
+elseif (APPLE)
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_securetransport.cpp)
+ target_link_libraries(core PRIVATE "-framework Security")
+elseif (WIN32)
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_schannel.cpp)
+ target_link_libraries(core PRIVATE crypt32 secur32)
+else()
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_none.cpp)
+endif()
+
if (YUZU_USE_PRECOMPILED_HEADERS)
target_precompile_headers(core PRIVATE precompiled_headers.h)
endif()
diff --git a/src/core/hle/kernel/k_thread.cpp b/src/core/hle/kernel/k_thread.cpp
index adb6ec581..d88909889 100644
--- a/src/core/hle/kernel/k_thread.cpp
+++ b/src/core/hle/kernel/k_thread.cpp
@@ -302,12 +302,12 @@ Result KThread::InitializeServiceThread(Core::System& system, KThread* thread,
std::function<void()>&& func, s32 prio, s32 virt_core,
KProcess* owner) {
system.Kernel().GlobalSchedulerContext().AddThread(thread);
- std::function<void()> func2{[&system, func{std::move(func)}] {
+ std::function<void()> func2{[&system, func_{std::move(func)}] {
// Similar to UserModeThreadStarter.
system.Kernel().CurrentScheduler()->OnThreadStart();
// Run the guest function.
- func();
+ func_();
// Exit.
Svc::ExitThread(system);
diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp
index f33600ca5..ebe7582c6 100644
--- a/src/core/hle/kernel/kernel.cpp
+++ b/src/core/hle/kernel/kernel.cpp
@@ -1089,15 +1089,15 @@ static std::jthread RunHostThreadFunc(KernelCore& kernel, KProcess* process,
KThread::Register(kernel, thread);
return std::jthread(
- [&kernel, thread, thread_name{std::move(thread_name)}, func{std::move(func)}] {
+ [&kernel, thread, thread_name_{std::move(thread_name)}, func_{std::move(func)}] {
// Set the thread name.
- Common::SetCurrentThreadName(thread_name.c_str());
+ Common::SetCurrentThreadName(thread_name_.c_str());
// Set the thread as current.
kernel.RegisterHostThread(thread);
// Run the callback.
- func();
+ func_();
// Close the thread.
// This will free the process if it is the last reference.
diff --git a/src/core/hle/service/am/am.cpp b/src/core/hle/service/am/am.cpp
index a2375508a..4f400d341 100644
--- a/src/core/hle/service/am/am.cpp
+++ b/src/core/hle/service/am/am.cpp
@@ -506,8 +506,8 @@ void ISelfController::SetHandlesRequestToDisplay(HLERequestContext& ctx) {
void ISelfController::SetIdleTimeDetectionExtension(HLERequestContext& ctx) {
IPC::RequestParser rp{ctx};
idle_time_detection_extension = rp.Pop<u32>();
- LOG_WARNING(Service_AM, "(STUBBED) called idle_time_detection_extension={}",
- idle_time_detection_extension);
+ LOG_DEBUG(Service_AM, "(STUBBED) called idle_time_detection_extension={}",
+ idle_time_detection_extension);
IPC::ResponseBuilder rb{ctx, 2};
rb.Push(ResultSuccess);
diff --git a/src/core/hle/service/nifm/nifm.cpp b/src/core/hle/service/nifm/nifm.cpp
index 91d42853e..21b06d10b 100644
--- a/src/core/hle/service/nifm/nifm.cpp
+++ b/src/core/hle/service/nifm/nifm.cpp
@@ -7,6 +7,7 @@
#include "core/hle/service/kernel_helpers.h"
#include "core/hle/service/nifm/nifm.h"
#include "core/hle/service/server_manager.h"
+#include "network/network.h"
namespace {
diff --git a/src/core/hle/service/nifm/nifm.h b/src/core/hle/service/nifm/nifm.h
index 9b20e6823..ae99c4695 100644
--- a/src/core/hle/service/nifm/nifm.h
+++ b/src/core/hle/service/nifm/nifm.h
@@ -4,14 +4,15 @@
#pragma once
#include "core/hle/service/service.h"
-#include "network/network.h"
-#include "network/room.h"
-#include "network/room_member.h"
namespace Core {
class System;
}
+namespace Network {
+class RoomNetwork;
+}
+
namespace Service::NIFM {
void LoopProcess(Core::System& system);
diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp
index bce45d321..e63b0a357 100644
--- a/src/core/hle/service/sockets/bsd.cpp
+++ b/src/core/hle/service/sockets/bsd.cpp
@@ -20,6 +20,9 @@
#include "core/internal_network/sockets.h"
#include "network/network.h"
+using Common::Expected;
+using Common::Unexpected;
+
namespace Service::Sockets {
namespace {
@@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) {
const u32 level = rp.Pop<u32>();
const auto optname = static_cast<OptName>(rp.Pop<u32>());
- LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname);
-
std::vector<u8> optval(ctx.GetWriteBufferSize());
+ LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname,
+ optval.size());
+
+ const Errno err = GetSockOptImpl(fd, level, optname, optval);
+
ctx.WriteBuffer(optval);
IPC::ResponseBuilder rb{ctx, 5};
rb.Push(ResultSuccess);
- rb.Push<s32>(-1);
- rb.PushEnum(Errno::NOTCONN);
+ rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1);
+ rb.PushEnum(err);
rb.Push<u32>(static_cast<u32>(optval.size()));
}
@@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) {
BuildErrnoResponse(ctx, CloseImpl(fd));
}
+void BSD::DuplicateSocket(HLERequestContext& ctx) {
+ struct InputParameters {
+ s32 fd;
+ u64 reserved;
+ };
+ static_assert(sizeof(InputParameters) == 0x10);
+
+ struct OutputParameters {
+ s32 ret;
+ Errno bsd_errno;
+ };
+ static_assert(sizeof(OutputParameters) == 0x8);
+
+ IPC::RequestParser rp{ctx};
+ auto input = rp.PopRaw<InputParameters>();
+
+ Expected<s32, Errno> res = DuplicateSocketImpl(input.fd);
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .ret = res.value_or(0),
+ .bsd_errno = res ? Errno::SUCCESS : res.error(),
+ });
+}
+
void BSD::EventFd(HLERequestContext& ctx) {
IPC::RequestParser rp{ctx};
const u64 initval = rp.Pop<u64>();
@@ -477,12 +508,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco
auto room_member = room_network.GetRoomMember().lock();
if (room_member && room_member->IsConnected()) {
- descriptor.socket = std::make_unique<Network::ProxySocket>(room_network);
+ descriptor.socket = std::make_shared<Network::ProxySocket>(room_network);
} else {
- descriptor.socket = std::make_unique<Network::Socket>();
+ descriptor.socket = std::make_shared<Network::Socket>();
}
- descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol));
+ descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol));
descriptor.is_connection_based = IsConnectionBased(type);
return {fd, Errno::SUCCESS};
@@ -538,7 +569,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) {
Network::PollFD result;
result.socket = file_descriptors[pollfd.fd]->socket.get();
- result.events = TranslatePollEventsToHost(pollfd.events);
+ result.events = Translate(pollfd.events);
result.revents = Network::PollEvents{};
return result;
});
@@ -547,7 +578,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
const size_t num = host_pollfds.size();
for (size_t i = 0; i < num; ++i) {
- fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents);
+ fds[i].revents = Translate(host_pollfds[i].revents);
}
std::memcpy(write_buffer.data(), fds.data(), length);
@@ -617,7 +648,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) {
}
const SockAddrIn guest_addrin = Translate(addr_in);
- ASSERT(write_buffer.size() == sizeof(guest_addrin));
+ ASSERT(write_buffer.size() >= sizeof(guest_addrin));
+ write_buffer.resize(sizeof(guest_addrin));
std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
return Translate(bsd_errno);
}
@@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) {
}
const SockAddrIn guest_addrin = Translate(addr_in);
- ASSERT(write_buffer.size() == sizeof(guest_addrin));
+ ASSERT(write_buffer.size() >= sizeof(guest_addrin));
+ write_buffer.resize(sizeof(guest_addrin));
std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
return Translate(bsd_errno);
}
@@ -671,13 +704,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) {
}
}
-Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
- UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET
+Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) {
+ if (!IsFileDescriptorValid(fd)) {
+ return Errno::BADF;
+ }
+
+ if (level != static_cast<u32>(SocketLevel::SOCKET)) {
+ UNIMPLEMENTED_MSG("Unknown getsockopt level");
+ return Errno::SUCCESS;
+ }
+
+ Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
+
+ switch (optname) {
+ case OptName::ERROR_: {
+ auto [pending_err, getsockopt_err] = socket->GetPendingError();
+ if (getsockopt_err == Network::Errno::SUCCESS) {
+ Errno translated_pending_err = Translate(pending_err);
+ ASSERT_OR_EXECUTE_MSG(
+ optval.size() == sizeof(Errno), { return Errno::INVAL; },
+ "Incorrect getsockopt option size");
+ optval.resize(sizeof(Errno));
+ memcpy(optval.data(), &translated_pending_err, sizeof(Errno));
+ }
+ return Translate(getsockopt_err);
+ }
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
+ return Errno::SUCCESS;
+ }
+}
+Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
if (!IsFileDescriptorValid(fd)) {
return Errno::BADF;
}
+ if (level != static_cast<u32>(SocketLevel::SOCKET)) {
+ UNIMPLEMENTED_MSG("Unknown setsockopt level");
+ return Errno::SUCCESS;
+ }
+
Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
if (optname == OptName::LINGER) {
@@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con
return Translate(socket->SetSndTimeo(value));
case OptName::RCVTIMEO:
return Translate(socket->SetRcvTimeo(value));
+ case OptName::NOSIGPIPE:
+ LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value);
+ return Errno::SUCCESS;
default:
UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
return Errno::SUCCESS;
@@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) {
return bsd_errno;
}
+Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) {
+ if (!IsFileDescriptorValid(fd)) {
+ return Unexpected(Errno::BADF);
+ }
+
+ const s32 new_fd = FindFreeFileDescriptorHandle();
+ if (new_fd < 0) {
+ LOG_ERROR(Service, "No more file descriptors available");
+ return Unexpected(Errno::MFILE);
+ }
+
+ file_descriptors[new_fd] = file_descriptors[fd];
+ return new_fd;
+}
+
+std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) {
+ if (!IsFileDescriptorValid(fd)) {
+ return std::nullopt;
+ }
+ return file_descriptors[fd]->socket;
+}
+
s32 BSD::FindFreeFileDescriptorHandle() noexcept {
for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) {
if (!file_descriptors[fd]) {
@@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name)
{24, &BSD::Write, "Write"},
{25, &BSD::Read, "Read"},
{26, &BSD::Close, "Close"},
- {27, nullptr, "DuplicateSocket"},
+ {27, &BSD::DuplicateSocket, "DuplicateSocket"},
{28, nullptr, "GetResourceStatistics"},
{29, nullptr, "RecvMMsg"},
{30, nullptr, "SendMMsg"},
diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h
index 30ae9c140..430edb97c 100644
--- a/src/core/hle/service/sockets/bsd.h
+++ b/src/core/hle/service/sockets/bsd.h
@@ -8,6 +8,7 @@
#include <vector>
#include "common/common_types.h"
+#include "common/expected.h"
#include "common/socket_types.h"
#include "core/hle/service/service.h"
#include "core/hle/service/sockets/sockets.h"
@@ -29,12 +30,19 @@ public:
explicit BSD(Core::System& system_, const char* name);
~BSD() override;
+ // These methods are called from SSL; the first two are also called from
+ // this class for the corresponding IPC methods.
+ // On the real device, the SSL service makes IPC calls to this service.
+ Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd);
+ Errno CloseImpl(s32 fd);
+ std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd);
+
private:
/// Maximum number of file descriptors
static constexpr size_t MAX_FD = 128;
struct FileDescriptor {
- std::unique_ptr<Network::SocketBase> socket;
+ std::shared_ptr<Network::SocketBase> socket;
s32 flags = 0;
bool is_connection_based = false;
};
@@ -138,6 +146,7 @@ private:
void Write(HLERequestContext& ctx);
void Read(HLERequestContext& ctx);
void Close(HLERequestContext& ctx);
+ void DuplicateSocket(HLERequestContext& ctx);
void EventFd(HLERequestContext& ctx);
template <typename Work>
@@ -153,6 +162,7 @@ private:
Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer);
Errno ListenImpl(s32 fd, s32 backlog);
std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg);
+ Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval);
Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval);
Errno ShutdownImpl(s32 fd, s32 how);
std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message);
@@ -161,7 +171,6 @@ private:
std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message);
std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message,
std::span<const u8> addr);
- Errno CloseImpl(s32 fd);
s32 FindFreeFileDescriptorHandle() noexcept;
bool IsFileDescriptorValid(s32 fd) const noexcept;
diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp
index 6491a73be..0dfb0f166 100644
--- a/src/core/hle/service/sockets/nsd.cpp
+++ b/src/core/hle/service/sockets/nsd.cpp
@@ -1,10 +1,15 @@
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
+#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/sockets/nsd.h"
+#include "common/string_util.h"
+
namespace Service::Sockets {
+constexpr Result ResultOverflow{ErrorModule::NSD, 6};
+
NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} {
// clang-format off
static const FunctionInfo functions[] = {
@@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
{13, nullptr, "DeleteSettings"},
{14, nullptr, "ImportSettings"},
{15, nullptr, "SetChangeEnvironmentIdentifierDisabled"},
- {20, nullptr, "Resolve"},
- {21, nullptr, "ResolveEx"},
+ {20, &NSD::Resolve, "Resolve"},
+ {21, &NSD::ResolveEx, "ResolveEx"},
{30, nullptr, "GetNasServiceSetting"},
{31, nullptr, "GetNasServiceSettingEx"},
{40, nullptr, "GetNasRequestFqdn"},
@@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
RegisterHandlers(functions);
}
+static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
+ // The real implementation makes various substitutions.
+ // For now we just return the string as-is, which is good enough when not
+ // connecting to real Nintendo servers.
+ LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in);
+ return fqdn_in;
+}
+
+static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
+ const auto res = ResolveImpl(fqdn_in);
+ if (res.Failed()) {
+ return res.Code();
+ }
+ if (res->size() >= fqdn_out.size()) {
+ return ResultOverflow;
+ }
+ std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1);
+ return ResultSuccess;
+}
+
+void NSD::Resolve(HLERequestContext& ctx) {
+ const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
+
+ std::array<char, 0x100> fqdn_out{};
+ const Result res = ResolveCommon(fqdn_in, fqdn_out);
+
+ ctx.WriteBuffer(fqdn_out);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+}
+
+void NSD::ResolveEx(HLERequestContext& ctx) {
+ const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
+
+ std::array<char, 0x100> fqdn_out;
+ const Result res = ResolveCommon(fqdn_in, fqdn_out);
+
+ if (res.IsError()) {
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ return;
+ }
+
+ ctx.WriteBuffer(fqdn_out);
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(ResultSuccess);
+ rb.Push(ResultSuccess);
+}
+
NSD::~NSD() = default;
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/nsd.h b/src/core/hle/service/sockets/nsd.h
index 5cc12b855..a7379a8a9 100644
--- a/src/core/hle/service/sockets/nsd.h
+++ b/src/core/hle/service/sockets/nsd.h
@@ -15,6 +15,10 @@ class NSD final : public ServiceFramework<NSD> {
public:
explicit NSD(Core::System& system_, const char* name);
~NSD() override;
+
+private:
+ void Resolve(HLERequestContext& ctx);
+ void ResolveEx(HLERequestContext& ctx);
};
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.cpp b/src/core/hle/service/sockets/sfdnsres.cpp
index 132dd5797..84cc79de8 100644
--- a/src/core/hle/service/sockets/sfdnsres.cpp
+++ b/src/core/hle/service/sockets/sfdnsres.cpp
@@ -10,27 +10,18 @@
#include "core/core.h"
#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/sockets/sfdnsres.h"
+#include "core/hle/service/sockets/sockets.h"
+#include "core/hle/service/sockets/sockets_translate.h"
+#include "core/internal_network/network.h"
#include "core/memory.h"
-#ifdef _WIN32
-#include <ws2tcpip.h>
-#elif YUZU_UNIX
-#include <arpa/inet.h>
-#include <netdb.h>
-#include <netinet/in.h>
-#include <sys/socket.h>
-#ifndef EAI_NODATA
-#define EAI_NODATA EAI_NONAME
-#endif
-#endif
-
namespace Service::Sockets {
SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} {
static const FunctionInfo functions[] = {
{0, nullptr, "SetDnsAddressesPrivateRequest"},
{1, nullptr, "GetDnsAddressPrivateRequest"},
- {2, nullptr, "GetHostByNameRequest"},
+ {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"},
{3, nullptr, "GetHostByAddrRequest"},
{4, nullptr, "GetHostStringErrorRequest"},
{5, nullptr, "GetGaiStringErrorRequest"},
@@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"
{7, nullptr, "GetNameInfoRequest"},
{8, nullptr, "RequestCancelHandleRequest"},
{9, nullptr, "CancelRequest"},
- {10, nullptr, "GetHostByNameRequestWithOptions"},
+ {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"},
{11, nullptr, "GetHostByAddrRequestWithOptions"},
{12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"},
{13, nullptr, "GetNameInfoRequestWithOptions"},
- {14, nullptr, "ResolverSetOptionRequest"},
+ {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"},
{15, nullptr, "ResolverGetOptionRequest"},
};
RegisterHandlers(functions);
@@ -59,188 +50,285 @@ enum class NetDbError : s32 {
NoData = 4,
};
-static NetDbError AddrInfoErrorToNetDbError(s32 result) {
- // Best effort guess to map errors
+static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) {
+ // These combinations have been verified on console (but are not
+ // exhaustive).
switch (result) {
- case 0:
+ case GetAddrInfoError::SUCCESS:
return NetDbError::Success;
- case EAI_AGAIN:
+ case GetAddrInfoError::AGAIN:
return NetDbError::TryAgain;
- case EAI_NODATA:
- return NetDbError::NoData;
+ case GetAddrInfoError::NODATA:
+ return NetDbError::HostNotFound;
+ case GetAddrInfoError::SERVICE:
+ return NetDbError::Success;
default:
return NetDbError::HostNotFound;
}
}
-static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code,
+static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) {
+ // These combinations have been verified on console (but are not
+ // exhaustive).
+ switch (result) {
+ case GetAddrInfoError::SUCCESS:
+ // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for
+ // some reason, but that doesn't seem useful to implement.
+ return Errno::SUCCESS;
+ case GetAddrInfoError::AGAIN:
+ return Errno::SUCCESS;
+ case GetAddrInfoError::NODATA:
+ return Errno::SUCCESS;
+ case GetAddrInfoError::SERVICE:
+ return Errno::INVAL;
+ default:
+ return Errno::SUCCESS;
+ }
+}
+
+template <typename T>
+static void Append(std::vector<u8>& vec, T t) {
+ const size_t offset = vec.size();
+ vec.resize(offset + sizeof(T));
+ std::memcpy(vec.data() + offset, &t, sizeof(T));
+}
+
+static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) {
+ const size_t offset = vec.size();
+ vec.resize(offset + str.size() + 1);
+ std::memmove(vec.data() + offset, str.data(), str.size());
+}
+
+// We implement gethostbyname using the host's getaddrinfo rather than the
+// host's gethostbyname, because it simplifies portability: e.g., getaddrinfo
+// behaves the same on Unix and Windows, unlike gethostbyname where Windows
+// doesn't implement h_errno.
+static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec,
+ std::string_view host) {
+
+ std::vector<u8> data;
+ // h_name: use the input hostname (append nul-terminated)
+ AppendNulTerminated(data, host);
+ // h_aliases: leave empty
+
+ Append<u32_be>(data, 0); // count of h_aliases
+ // (If the count were nonzero, the aliases would be appended as nul-terminated here.)
+ Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype
+ Append<u16_be>(data, sizeof(Network::IPv4Address)); // h_length
+ // h_addr_list:
+ size_t count = vec.size();
+ ASSERT(count <= UINT32_MAX);
+ Append<u32_be>(data, static_cast<uint32_t>(count));
+ for (const Network::AddrInfo& addrinfo : vec) {
+ // On the Switch, this is passed through htonl despite already being
+ // big-endian, so it ends up as little-endian.
+ Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip));
+
+ LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
+ Network::IPv4AddressToString(addrinfo.addr.ip));
+ }
+ return data;
+}
+
+static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) {
+ struct InputParameters {
+ u8 use_nsd_resolve;
+ u32 cancel_handle;
+ u64 process_id;
+ };
+ static_assert(sizeof(InputParameters) == 0x10);
+
+ IPC::RequestParser rp{ctx};
+ const auto parameters = rp.PopRaw<InputParameters>();
+
+ LOG_WARNING(
+ Service,
+ "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
+ parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
+
+ const auto host_buffer = ctx.ReadBuffer(0);
+ const std::string host = Common::StringFromBuffer(host_buffer);
+ // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions.
+
+ auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt);
+ if (!res.has_value()) {
+ return {0, Translate(res.error())};
+ }
+
+ const std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host);
+ const u32 data_size = static_cast<u32>(data.size());
+ ctx.WriteBuffer(data, 0);
+
+ return {data_size, GetAddrInfoError::SUCCESS};
+}
+
+void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) {
+ auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
+
+ struct OutputParameters {
+ NetDbError netdb_error;
+ Errno bsd_errno;
+ u32 data_size;
+ };
+ static_assert(sizeof(OutputParameters) == 0xc);
+
+ IPC::ResponseBuilder rb{ctx, 5};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ .data_size = data_size,
+ });
+}
+
+void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) {
+ auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
+
+ struct OutputParameters {
+ u32 data_size;
+ NetDbError netdb_error;
+ Errno bsd_errno;
+ };
+ static_assert(sizeof(OutputParameters) == 0xc);
+
+ IPC::ResponseBuilder rb{ctx, 5};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .data_size = data_size,
+ .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ });
+}
+
+static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec,
std::string_view host) {
// Adapted from
// https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190
std::vector<u8> data;
- auto* current = addrinfo;
- while (current != nullptr) {
- struct SerializedResponseHeader {
- u32 magic;
- s32 flags;
- s32 family;
- s32 socket_type;
- s32 protocol;
- u32 address_length;
- };
- static_assert(sizeof(SerializedResponseHeader) == 0x18,
- "Response header size must be 0x18 bytes");
-
- constexpr auto header_size = sizeof(SerializedResponseHeader);
- const auto addr_size =
- current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4;
- const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1;
-
- const auto last_size = data.size();
- data.resize(last_size + header_size + addr_size + canonname_size);
-
- // Header in network byte order
- SerializedResponseHeader header{};
-
- constexpr auto HEADER_MAGIC = 0xBEEFCAFE;
- header.magic = htonl(HEADER_MAGIC);
- header.family = htonl(current->ai_family);
- header.flags = htonl(current->ai_flags);
- header.socket_type = htonl(current->ai_socktype);
- header.protocol = htonl(current->ai_protocol);
- header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0;
-
- auto* header_ptr = data.data() + last_size;
- std::memcpy(header_ptr, &header, header_size);
-
- if (header.address_length == 0) {
- std::memset(header_ptr + header_size, 0, 4);
- } else {
- switch (current->ai_family) {
- case AF_INET: {
- struct SockAddrIn {
- s16 sin_family;
- u16 sin_port;
- u32 sin_addr;
- u8 sin_zero[8];
- };
-
- SockAddrIn serialized_addr{};
- const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr);
- serialized_addr.sin_port = htons(addr.sin_port);
- serialized_addr.sin_family = htons(addr.sin_family);
- serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr);
- std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn));
-
- char addr_string_buf[64]{};
- inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf));
- LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf);
- break;
- }
- case AF_INET6: {
- struct SockAddrIn6 {
- s16 sin6_family;
- u16 sin6_port;
- u32 sin6_flowinfo;
- u8 sin6_addr[16];
- u32 sin6_scope_id;
- };
-
- SockAddrIn6 serialized_addr{};
- const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr);
- serialized_addr.sin6_family = htons(addr.sin6_family);
- serialized_addr.sin6_port = htons(addr.sin6_port);
- serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo);
- serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id);
- std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr,
- sizeof(SockAddrIn6::sin6_addr));
- std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6));
-
- char addr_string_buf[64]{};
- inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf));
- LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf);
- break;
- }
- default:
- std::memcpy(header_ptr + header_size, current->ai_addr, addr_size);
- break;
- }
- }
- if (current->ai_canonname) {
- std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size);
+ for (const Network::AddrInfo& addrinfo : vec) {
+ // serialized addrinfo:
+ Append<u32_be>(data, 0xBEEFCAFE); // magic
+ Append<u32_be>(data, 0); // ai_flags
+ Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family))); // ai_family
+ Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype
+ Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol))); // ai_protocol
+ Append<u32_be>(data, sizeof(SockAddrIn)); // ai_addrlen
+ // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size
+
+ // ai_addr:
+ Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family
+ // On the Switch, the following fields are passed through htonl despite
+ // already being big-endian, so they end up as little-endian.
+ Append<u16_le>(data, addrinfo.addr.portno); // sin_port
+ Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr
+ data.resize(data.size() + 8, 0); // sin_zero
+
+ if (addrinfo.canon_name.has_value()) {
+ AppendNulTerminated(data, *addrinfo.canon_name);
} else {
- *(header_ptr + header_size + addr_size) = 0;
+ data.push_back(0);
}
- current = current->ai_next;
+ LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
+ Network::IPv4AddressToString(addrinfo.addr.ip));
}
- // 4-byte sentinel value
- data.push_back(0);
- data.push_back(0);
- data.push_back(0);
- data.push_back(0);
+ data.resize(data.size() + 4, 0); // 4-byte sentinel value
return data;
}
-static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
- struct Parameters {
+static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
+ struct InputParameters {
u8 use_nsd_resolve;
- u32 unknown;
+ u32 cancel_handle;
u64 process_id;
};
+ static_assert(sizeof(InputParameters) == 0x10);
IPC::RequestParser rp{ctx};
- const auto parameters = rp.PopRaw<Parameters>();
+ const auto parameters = rp.PopRaw<InputParameters>();
+
+ LOG_WARNING(
+ Service,
+ "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
+ parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
- LOG_WARNING(Service,
- "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}",
- parameters.use_nsd_resolve, parameters.unknown, parameters.process_id);
+ // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve
+ // before looking up.
const auto host_buffer = ctx.ReadBuffer(0);
const std::string host = Common::StringFromBuffer(host_buffer);
- const auto service_buffer = ctx.ReadBuffer(1);
- const std::string service = Common::StringFromBuffer(service_buffer);
-
- addrinfo* addrinfo;
- // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now
- s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo);
+ std::optional<std::string> service = std::nullopt;
+ if (ctx.CanReadBuffer(1)) {
+ const std::span<const u8> service_buffer = ctx.ReadBuffer(1);
+ service = Common::StringFromBuffer(service_buffer);
+ }
- u32 data_size = 0;
- if (result_code == 0 && addrinfo != nullptr) {
- const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host);
- data_size = static_cast<u32>(data.size());
- freeaddrinfo(addrinfo);
+ // Serialized hints are also passed in a buffer, but are ignored for now.
- ctx.WriteBuffer(data, 0);
+ auto res = Network::GetAddressInfo(host, service);
+ if (!res.has_value()) {
+ return {0, Translate(res.error())};
}
- return std::make_pair(data_size, result_code);
+ const std::vector<u8> data = SerializeAddrInfo(res.value(), host);
+ const u32 data_size = static_cast<u32>(data.size());
+ ctx.WriteBuffer(data, 0);
+
+ return {data_size, GetAddrInfoError::SUCCESS};
}
void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) {
- auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx);
+ auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
+
+ struct OutputParameters {
+ Errno bsd_errno;
+ GetAddrInfoError gai_error;
+ u32 data_size;
+ };
+ static_assert(sizeof(OutputParameters) == 0xc);
- IPC::ResponseBuilder rb{ctx, 4};
+ IPC::ResponseBuilder rb{ctx, 5};
rb.Push(ResultSuccess);
- rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
- rb.Push(result_code); // errno
- rb.Push(data_size); // serialized size
+ rb.PushRaw(OutputParameters{
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ .gai_error = emu_gai_err,
+ .data_size = data_size,
+ });
}
void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) {
// Additional options are ignored
- auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx);
+ auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
+
+ struct OutputParameters {
+ u32 data_size;
+ GetAddrInfoError gai_error;
+ NetDbError netdb_error;
+ Errno bsd_errno;
+ };
+ static_assert(sizeof(OutputParameters) == 0x10);
+
+ IPC::ResponseBuilder rb{ctx, 6};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .data_size = data_size,
+ .gai_error = emu_gai_err,
+ .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ });
+}
+
+void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) {
+ LOG_WARNING(Service, "(STUBBED) called");
+
+ IPC::ResponseBuilder rb{ctx, 3};
- IPC::ResponseBuilder rb{ctx, 5};
rb.Push(ResultSuccess);
- rb.Push(data_size); // serialized size
- rb.Push(result_code); // errno
- rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
- rb.Push(0);
+ rb.Push<s32>(0); // bsd errno
}
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.h b/src/core/hle/service/sockets/sfdnsres.h
index 18e3cd60c..d99a9d560 100644
--- a/src/core/hle/service/sockets/sfdnsres.h
+++ b/src/core/hle/service/sockets/sfdnsres.h
@@ -17,8 +17,11 @@ public:
~SFDNSRES() override;
private:
+ void GetHostByNameRequest(HLERequestContext& ctx);
+ void GetHostByNameRequestWithOptions(HLERequestContext& ctx);
void GetAddrInfoRequest(HLERequestContext& ctx);
void GetAddrInfoRequestWithOptions(HLERequestContext& ctx);
+ void ResolverSetOptionRequest(HLERequestContext& ctx);
};
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sockets.h b/src/core/hle/service/sockets/sockets.h
index acd2dae7b..77426c46e 100644
--- a/src/core/hle/service/sockets/sockets.h
+++ b/src/core/hle/service/sockets/sockets.h
@@ -22,13 +22,35 @@ enum class Errno : u32 {
CONNRESET = 104,
NOTCONN = 107,
TIMEDOUT = 110,
+ INPROGRESS = 115,
+};
+
+enum class GetAddrInfoError : s32 {
+ SUCCESS = 0,
+ ADDRFAMILY = 1,
+ AGAIN = 2,
+ BADFLAGS = 3,
+ FAIL = 4,
+ FAMILY = 5,
+ MEMORY = 6,
+ NODATA = 7,
+ NONAME = 8,
+ SERVICE = 9,
+ SOCKTYPE = 10,
+ SYSTEM = 11,
+ BADHINTS = 12,
+ PROTOCOL = 13,
+ OVERFLOW_ = 14, // avoid name collision with Windows macro
+ OTHER = 15,
};
enum class Domain : u32 {
+ Unspecified = 0,
INET = 2,
};
enum class Type : u32 {
+ Unspecified = 0,
STREAM = 1,
DGRAM = 2,
RAW = 3,
@@ -36,12 +58,16 @@ enum class Type : u32 {
};
enum class Protocol : u32 {
- UNSPECIFIED = 0,
+ Unspecified = 0,
ICMP = 1,
TCP = 6,
UDP = 17,
};
+enum class SocketLevel : u32 {
+ SOCKET = 0xffff, // i.e. SOL_SOCKET
+};
+
enum class OptName : u32 {
REUSEADDR = 0x4,
KEEPALIVE = 0x8,
@@ -51,6 +77,8 @@ enum class OptName : u32 {
RCVBUF = 0x1002,
SNDTIMEO = 0x1005,
RCVTIMEO = 0x1006,
+ ERROR_ = 0x1007, // avoid name collision with Windows macro
+ NOSIGPIPE = 0x800, // at least according to libnx
};
enum class ShutdownHow : s32 {
@@ -80,6 +108,9 @@ enum class PollEvents : u16 {
Err = 1 << 3,
Hup = 1 << 4,
Nval = 1 << 5,
+ RdNorm = 1 << 6,
+ RdBand = 1 << 7,
+ WrBand = 1 << 8,
};
DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
diff --git a/src/core/hle/service/sockets/sockets_translate.cpp b/src/core/hle/service/sockets/sockets_translate.cpp
index 594e58f90..2f9a0e39c 100644
--- a/src/core/hle/service/sockets/sockets_translate.cpp
+++ b/src/core/hle/service/sockets/sockets_translate.cpp
@@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) {
return Errno::TIMEDOUT;
case Network::Errno::CONNRESET:
return Errno::CONNRESET;
+ case Network::Errno::INPROGRESS:
+ return Errno::INPROGRESS;
default:
UNIMPLEMENTED_MSG("Unimplemented errno={}", value);
return Errno::SUCCESS;
@@ -39,8 +41,50 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) {
return {value.first, Translate(value.second)};
}
+GetAddrInfoError Translate(Network::GetAddrInfoError error) {
+ switch (error) {
+ case Network::GetAddrInfoError::SUCCESS:
+ return GetAddrInfoError::SUCCESS;
+ case Network::GetAddrInfoError::ADDRFAMILY:
+ return GetAddrInfoError::ADDRFAMILY;
+ case Network::GetAddrInfoError::AGAIN:
+ return GetAddrInfoError::AGAIN;
+ case Network::GetAddrInfoError::BADFLAGS:
+ return GetAddrInfoError::BADFLAGS;
+ case Network::GetAddrInfoError::FAIL:
+ return GetAddrInfoError::FAIL;
+ case Network::GetAddrInfoError::FAMILY:
+ return GetAddrInfoError::FAMILY;
+ case Network::GetAddrInfoError::MEMORY:
+ return GetAddrInfoError::MEMORY;
+ case Network::GetAddrInfoError::NODATA:
+ return GetAddrInfoError::NODATA;
+ case Network::GetAddrInfoError::NONAME:
+ return GetAddrInfoError::NONAME;
+ case Network::GetAddrInfoError::SERVICE:
+ return GetAddrInfoError::SERVICE;
+ case Network::GetAddrInfoError::SOCKTYPE:
+ return GetAddrInfoError::SOCKTYPE;
+ case Network::GetAddrInfoError::SYSTEM:
+ return GetAddrInfoError::SYSTEM;
+ case Network::GetAddrInfoError::BADHINTS:
+ return GetAddrInfoError::BADHINTS;
+ case Network::GetAddrInfoError::PROTOCOL:
+ return GetAddrInfoError::PROTOCOL;
+ case Network::GetAddrInfoError::OVERFLOW_:
+ return GetAddrInfoError::OVERFLOW_;
+ case Network::GetAddrInfoError::OTHER:
+ return GetAddrInfoError::OTHER;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error);
+ return GetAddrInfoError::OTHER;
+ }
+}
+
Network::Domain Translate(Domain domain) {
switch (domain) {
+ case Domain::Unspecified:
+ return Network::Domain::Unspecified;
case Domain::INET:
return Network::Domain::INET;
default:
@@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) {
Domain Translate(Network::Domain domain) {
switch (domain) {
+ case Network::Domain::Unspecified:
+ return Domain::Unspecified;
case Network::Domain::INET:
return Domain::INET;
default:
@@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) {
Network::Type Translate(Type type) {
switch (type) {
+ case Type::Unspecified:
+ return Network::Type::Unspecified;
case Type::STREAM:
return Network::Type::STREAM;
case Type::DGRAM:
return Network::Type::DGRAM;
+ case Type::RAW:
+ return Network::Type::RAW;
+ case Type::SEQPACKET:
+ return Network::Type::SEQPACKET;
default:
UNIMPLEMENTED_MSG("Unimplemented type={}", type);
return Network::Type{};
}
}
-Network::Protocol Translate(Type type, Protocol protocol) {
+Type Translate(Network::Type type) {
+ switch (type) {
+ case Network::Type::Unspecified:
+ return Type::Unspecified;
+ case Network::Type::STREAM:
+ return Type::STREAM;
+ case Network::Type::DGRAM:
+ return Type::DGRAM;
+ case Network::Type::RAW:
+ return Type::RAW;
+ case Network::Type::SEQPACKET:
+ return Type::SEQPACKET;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented type={}", type);
+ return Type{};
+ }
+}
+
+Network::Protocol Translate(Protocol protocol) {
switch (protocol) {
- case Protocol::UNSPECIFIED:
- LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type");
- switch (type) {
- case Type::DGRAM:
- return Network::Protocol::UDP;
- case Type::STREAM:
- return Network::Protocol::TCP;
- default:
- return Network::Protocol::TCP;
- }
+ case Protocol::Unspecified:
+ return Network::Protocol::Unspecified;
case Protocol::TCP:
return Network::Protocol::TCP;
case Protocol::UDP:
return Network::Protocol::UDP;
default:
UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
- return Network::Protocol::TCP;
+ return Network::Protocol::Unspecified;
+ }
+}
+
+Protocol Translate(Network::Protocol protocol) {
+ switch (protocol) {
+ case Network::Protocol::Unspecified:
+ return Protocol::Unspecified;
+ case Network::Protocol::TCP:
+ return Protocol::TCP;
+ case Network::Protocol::UDP:
+ return Protocol::UDP;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
+ return Protocol::Unspecified;
}
}
-Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
+Network::PollEvents Translate(PollEvents flags) {
Network::PollEvents result{};
const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) {
if (True(flags & from)) {
@@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
translate(PollEvents::Err, Network::PollEvents::Err);
translate(PollEvents::Hup, Network::PollEvents::Hup);
translate(PollEvents::Nval, Network::PollEvents::Nval);
+ translate(PollEvents::RdNorm, Network::PollEvents::RdNorm);
+ translate(PollEvents::RdBand, Network::PollEvents::RdBand);
+ translate(PollEvents::WrBand, Network::PollEvents::WrBand);
UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
return result;
}
-PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
+PollEvents Translate(Network::PollEvents flags) {
PollEvents result{};
const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) {
if (True(flags & from)) {
@@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
translate(Network::PollEvents::Err, PollEvents::Err);
translate(Network::PollEvents::Hup, PollEvents::Hup);
translate(Network::PollEvents::Nval, PollEvents::Nval);
+ translate(Network::PollEvents::RdNorm, PollEvents::RdNorm);
+ translate(Network::PollEvents::RdBand, PollEvents::RdBand);
+ translate(Network::PollEvents::WrBand, PollEvents::WrBand);
UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
return result;
}
Network::SockAddrIn Translate(SockAddrIn value) {
- ASSERT(value.len == 0 || value.len == sizeof(value));
+ // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets
+ // sin_len to 6 when deserializing getaddrinfo results).
+ ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6);
return {
.family = Translate(static_cast<Domain>(value.family)),
diff --git a/src/core/hle/service/sockets/sockets_translate.h b/src/core/hle/service/sockets/sockets_translate.h
index c93291d3e..694868b37 100644
--- a/src/core/hle/service/sockets/sockets_translate.h
+++ b/src/core/hle/service/sockets/sockets_translate.h
@@ -17,6 +17,9 @@ Errno Translate(Network::Errno value);
/// Translate abstract return value errno pair to guest return value errno pair
std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value);
+/// Translate abstract getaddrinfo error to guest getaddrinfo error
+GetAddrInfoError Translate(Network::GetAddrInfoError value);
+
/// Translate guest domain to abstract domain
Network::Domain Translate(Domain domain);
@@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain);
/// Translate guest type to abstract type
Network::Type Translate(Type type);
+/// Translate abstract type to guest type
+Type Translate(Network::Type type);
+
/// Translate guest protocol to abstract protocol
-Network::Protocol Translate(Type type, Protocol protocol);
+Network::Protocol Translate(Protocol protocol);
-/// Translate abstract poll event flags to guest poll event flags
-Network::PollEvents TranslatePollEventsToHost(PollEvents flags);
+/// Translate abstract protocol to guest protocol
+Protocol Translate(Network::Protocol protocol);
/// Translate guest poll event flags to abstract poll event flags
-PollEvents TranslatePollEventsToGuest(Network::PollEvents flags);
+Network::PollEvents Translate(PollEvents flags);
+
+/// Translate abstract poll event flags to guest poll event flags
+PollEvents Translate(Network::PollEvents flags);
/// Translate guest socket address structure to abstract socket address structure
Network::SockAddrIn Translate(SockAddrIn value);
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..9c96f9763 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
+#include "common/string_util.h"
+
+#include "core/core.h"
#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/server_manager.h"
#include "core/hle/service/service.h"
+#include "core/hle/service/sm/sm.h"
+#include "core/hle/service/sockets/bsd.h"
#include "core/hle/service/ssl/ssl.h"
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
namespace Service::SSL {
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
CrlImportDateCheckEnable = 1,
};
+// This is nn::ssl::Connection::IoMode
+enum class IoMode : u32 {
+ Blocking = 1,
+ NonBlocking = 2,
+};
+
+// This is nn::ssl::sf::OptionType
+enum class OptionType : u32 {
+ DoNotCloseSocket = 0,
+ GetServerCertChain = 1,
+};
+
// This is nn::ssl::sf::SslVersion
struct SslVersion {
union {
@@ -34,35 +54,42 @@ struct SslVersion {
};
};
+struct SslContextSharedData {
+ u32 connection_count = 0;
+};
+
class ISslConnection final : public ServiceFramework<ISslConnection> {
public:
- explicit ISslConnection(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} {
+ explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in,
+ std::shared_ptr<SslContextSharedData>& shared_data_in,
+ std::unique_ptr<SSLConnectionBackend>&& backend_in)
+ : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in},
+ shared_data{shared_data_in}, backend{std::move(backend_in)} {
// clang-format off
static const FunctionInfo functions[] = {
- {0, nullptr, "SetSocketDescriptor"},
- {1, nullptr, "SetHostName"},
- {2, nullptr, "SetVerifyOption"},
- {3, nullptr, "SetIoMode"},
+ {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
+ {1, &ISslConnection::SetHostName, "SetHostName"},
+ {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
+ {3, &ISslConnection::SetIoMode, "SetIoMode"},
{4, nullptr, "GetSocketDescriptor"},
{5, nullptr, "GetHostName"},
{6, nullptr, "GetVerifyOption"},
{7, nullptr, "GetIoMode"},
- {8, nullptr, "DoHandshake"},
- {9, nullptr, "DoHandshakeGetServerCert"},
- {10, nullptr, "Read"},
- {11, nullptr, "Write"},
- {12, nullptr, "Pending"},
+ {8, &ISslConnection::DoHandshake, "DoHandshake"},
+ {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
+ {10, &ISslConnection::Read, "Read"},
+ {11, &ISslConnection::Write, "Write"},
+ {12, &ISslConnection::Pending, "Pending"},
{13, nullptr, "Peek"},
{14, nullptr, "Poll"},
{15, nullptr, "GetVerifyCertError"},
{16, nullptr, "GetNeededServerCertBufferSize"},
- {17, nullptr, "SetSessionCacheMode"},
+ {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
{18, nullptr, "GetSessionCacheMode"},
{19, nullptr, "FlushSessionCache"},
{20, nullptr, "SetRenegotiationMode"},
{21, nullptr, "GetRenegotiationMode"},
- {22, nullptr, "SetOption"},
+ {22, &ISslConnection::SetOption, "SetOption"},
{23, nullptr, "GetOption"},
{24, nullptr, "GetVerifyCertErrors"},
{25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,299 @@ public:
// clang-format on
RegisterHandlers(functions);
+
+ shared_data->connection_count++;
+ }
+
+ ~ISslConnection() {
+ shared_data->connection_count--;
+ if (fd_to_close.has_value()) {
+ const s32 fd = *fd_to_close;
+ if (!do_not_close_socket) {
+ LOG_ERROR(Service_SSL,
+ "do_not_close_socket was changed after setting socket; is this right?");
+ } else {
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ if (bsd) {
+ auto err = bsd->CloseImpl(fd);
+ if (err != Service::Sockets::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err);
+ }
+ }
+ }
+ }
}
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
+ std::unique_ptr<SSLConnectionBackend> backend;
+ std::optional<int> fd_to_close;
+ bool do_not_close_socket = false;
+ bool get_server_cert_chain = false;
+ std::shared_ptr<Network::SocketBase> socket;
+ bool did_set_host_name = false;
+ bool did_handshake = false;
+
+ ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
+ LOG_DEBUG(Service_SSL, "called, fd={}", fd);
+ ASSERT(!did_handshake);
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
+ s32 ret_fd;
+ // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
+ if (do_not_close_socket) {
+ auto res = bsd->DuplicateSocketImpl(fd);
+ if (!res.has_value()) {
+ LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ fd = *res;
+ fd_to_close = fd;
+ ret_fd = fd;
+ } else {
+ ret_fd = -1;
+ }
+ std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
+ if (!sock.has_value()) {
+ LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ socket = std::move(*sock);
+ backend->SetSocket(socket);
+ return ret_fd;
+ }
+
+ Result SetHostNameImpl(const std::string& hostname) {
+ LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
+ ASSERT(!did_handshake);
+ Result res = backend->SetHostName(hostname);
+ if (res == ResultSuccess) {
+ did_set_host_name = true;
+ }
+ return res;
+ }
+
+ Result SetVerifyOptionImpl(u32 option) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
+ return ResultSuccess;
+ }
+
+ Result SetIoModeImpl(u32 input_mode) {
+ auto mode = static_cast<IoMode>(input_mode);
+ ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
+ ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
+
+ const bool non_block = mode == IoMode::NonBlocking;
+ const Network::Errno error = socket->SetNonBlock(non_block);
+ if (error != Network::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
+ }
+ return ResultSuccess;
+ }
+
+ Result SetSessionCacheModeImpl(u32 mode) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
+ return ResultSuccess;
+ }
+
+ Result DoHandshakeImpl() {
+ ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
+ ASSERT_OR_EXECUTE_MSG(
+ did_set_host_name, { return ResultInternalError; },
+ "Expected SetHostName before DoHandshake");
+ Result res = backend->DoHandshake();
+ did_handshake = res.IsSuccess();
+ return res;
+ }
+
+ std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
+ struct Header {
+ u64 magic;
+ u32 count;
+ u32 pad;
+ };
+ struct EntryHeader {
+ u32 size;
+ u32 offset;
+ };
+ if (!get_server_cert_chain) {
+ // Just return the first one, unencoded.
+ ASSERT_OR_EXECUTE_MSG(
+ !certs.empty(), { return {}; }, "Should be at least one server cert");
+ return certs[0];
+ }
+ std::vector<u8> ret;
+ Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
+ size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
+ for (auto& cert : certs) {
+ EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
+ data_offset += cert.size();
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
+ reinterpret_cast<u8*>(&entry_header + 1));
+ }
+ for (auto& cert : certs) {
+ ret.insert(ret.end(), cert.begin(), cert.end());
+ }
+ return ret;
+ }
+
+ ResultVal<std::vector<u8>> ReadImpl(size_t size) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ std::vector<u8> res(size);
+ ResultVal<size_t> actual = backend->Read(res);
+ if (actual.Failed()) {
+ return actual.Code();
+ }
+ res.resize(*actual);
+ return res;
+ }
+
+ ResultVal<size_t> WriteImpl(std::span<const u8> data) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ return backend->Write(data);
+ }
+
+ ResultVal<s32> PendingImpl() {
+ LOG_WARNING(Service_SSL, "(STUBBED) called.");
+ return 0;
+ }
+
+ void SetSocketDescriptor(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const s32 fd = rp.Pop<s32>();
+ const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(-1));
+ }
+
+ void SetHostName(HLERequestContext& ctx) {
+ const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
+ const Result res = SetHostNameImpl(hostname);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetVerifyOption(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 option = rp.Pop<u32>();
+ const Result res = SetVerifyOptionImpl(option);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetIoMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetIoModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshake(HLERequestContext& ctx) {
+ const Result res = DoHandshakeImpl();
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshakeGetServerCert(HLERequestContext& ctx) {
+ struct OutputParameters {
+ u32 certs_size;
+ u32 certs_count;
+ };
+ static_assert(sizeof(OutputParameters) == 0x8);
+
+ const Result res = DoHandshakeImpl();
+ OutputParameters out{};
+ if (res == ResultSuccess) {
+ auto certs = backend->GetServerCerts();
+ if (certs.Succeeded()) {
+ const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
+ ctx.WriteBuffer(certs_buf);
+ out.certs_count = static_cast<u32>(certs->size());
+ out.certs_size = static_cast<u32>(certs_buf.size());
+ }
+ }
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(res);
+ rb.PushRaw(out);
+ }
+
+ void Read(HLERequestContext& ctx) {
+ const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ if (res.Succeeded()) {
+ rb.Push(static_cast<u32>(res->size()));
+ ctx.WriteBuffer(*res);
+ } else {
+ rb.Push(static_cast<u32>(0));
+ }
+ }
+
+ void Write(HLERequestContext& ctx) {
+ const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push(static_cast<u32>(res.ValueOr(0)));
+ }
+
+ void Pending(HLERequestContext& ctx) {
+ const ResultVal<s32> res = PendingImpl();
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(0));
+ }
+
+ void SetSessionCacheMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetSessionCacheModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetOption(HLERequestContext& ctx) {
+ struct Parameters {
+ OptionType option;
+ s32 value;
+ };
+ static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
+
+ IPC::RequestParser rp{ctx};
+ const auto parameters = rp.PopRaw<Parameters>();
+
+ switch (parameters.option) {
+ case OptionType::DoNotCloseSocket:
+ do_not_close_socket = static_cast<bool>(parameters.value);
+ break;
+ case OptionType::GetServerCertChain:
+ get_server_cert_chain = static_cast<bool>(parameters.value);
+ break;
+ default:
+ LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
+ parameters.value);
+ }
+
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(ResultSuccess);
+ }
};
class ISslContext final : public ServiceFramework<ISslContext> {
public:
explicit ISslContext(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslContext"}, ssl_version{version} {
+ : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
+ shared_data{std::make_shared<SslContextSharedData>()} {
static const FunctionInfo functions[] = {
{0, &ISslContext::SetOption, "SetOption"},
{1, nullptr, "GetOption"},
{2, &ISslContext::CreateConnection, "CreateConnection"},
- {3, nullptr, "GetConnectionCount"},
+ {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
{4, &ISslContext::ImportServerPki, "ImportServerPki"},
{5, &ISslContext::ImportClientPki, "ImportClientPki"},
{6, nullptr, "RemoveServerPki"},
@@ -111,6 +416,7 @@ public:
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
void SetOption(HLERequestContext& ctx) {
struct Parameters {
@@ -130,11 +436,24 @@ private:
}
void CreateConnection(HLERequestContext& ctx) {
- LOG_WARNING(Service_SSL, "(STUBBED) called");
+ LOG_WARNING(Service_SSL, "called");
+
+ auto backend_res = CreateSSLConnectionBackend();
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
+ rb.Push(backend_res.Code());
+ if (backend_res.Succeeded()) {
+ rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
+ std::move(*backend_res));
+ }
+ }
+
+ void GetConnectionCount(HLERequestContext& ctx) {
+ LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
+
+ IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess);
- rb.PushIpcInterface<ISslConnection>(system, ssl_version);
+ rb.Push(shared_data->connection_count);
}
void ImportServerPki(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
new file mode 100644
index 000000000..409f4367c
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -0,0 +1,45 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#pragma once
+
+#include <memory>
+#include <span>
+#include <string>
+#include <vector>
+
+#include "common/common_types.h"
+
+#include "core/hle/result.h"
+
+namespace Network {
+class SocketBase;
+}
+
+namespace Service::SSL {
+
+constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
+constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
+constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
+constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
+
+// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
+// with no way in the latter case to distinguish whether the client should poll
+// for read or write. The one official client I've seen handles this by always
+// polling for read (with a timeout).
+constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
+
+class SSLConnectionBackend {
+public:
+ virtual ~SSLConnectionBackend() {}
+ virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
+ virtual Result SetHostName(const std::string& hostname) = 0;
+ virtual Result DoHandshake() = 0;
+ virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
+ virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
+ virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
new file mode 100644
index 000000000..2f4f23c42
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -0,0 +1,16 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include "common/logging/log.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+
+namespace Service::SSL {
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ LOG_ERROR(Service_SSL,
+ "Can't create SSL connection because no SSL backend is available on this platform");
+ return ResultInternalError;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
new file mode 100644
index 000000000..6ca869dbf
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -0,0 +1,351 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+#include <openssl/bio.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#include <openssl/x509.h>
+
+#include "common/fs/file.h"
+#include "common/hex_util.h"
+#include "common/string_util.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+using namespace Common::FS;
+
+namespace Service::SSL {
+
+// Import OpenSSL's `SSL` type into the namespace. This is needed because the
+// namespace is also named `SSL`.
+using ::SSL;
+
+namespace {
+
+std::once_flag one_time_init_flag;
+bool one_time_init_success = false;
+
+SSL_CTX* ssl_ctx;
+IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
+BIO_METHOD* bio_meth;
+
+Result CheckOpenSSLErrors();
+void OneTimeInit();
+void OneTimeInitLogFile();
+bool OneTimeInitBIO();
+
+} // namespace
+
+class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ std::call_once(one_time_init_flag, OneTimeInit);
+
+ if (!one_time_init_success) {
+ LOG_ERROR(Service_SSL,
+ "Can't create SSL connection because OpenSSL one-time initialization failed");
+ return ResultInternalError;
+ }
+
+ ssl = SSL_new(ssl_ctx);
+ if (!ssl) {
+ LOG_ERROR(Service_SSL, "SSL_new failed");
+ return CheckOpenSSLErrors();
+ }
+
+ SSL_set_connect_state(ssl);
+
+ bio = BIO_new(bio_meth);
+ if (!bio) {
+ LOG_ERROR(Service_SSL, "BIO_new failed");
+ return CheckOpenSSLErrors();
+ }
+
+ BIO_set_data(bio, this);
+ BIO_set_init(bio, 1);
+ SSL_set_bio(ssl, bio, bio);
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
+ socket = std::move(socket_in);
+ }
+
+ Result SetHostName(const std::string& hostname) override {
+ if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
+ LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
+ return CheckOpenSSLErrors();
+ }
+ if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
+ LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
+ return CheckOpenSSLErrors();
+ }
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ SSL_set_verify_result(ssl, X509_V_OK);
+ const int ret = SSL_do_handshake(ssl);
+ const long verify_result = SSL_get_verify_result(ssl);
+ if (verify_result != X509_V_OK) {
+ LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
+ X509_verify_cert_error_string(verify_result));
+ return CheckOpenSSLErrors();
+ }
+ if (ret <= 0) {
+ const int ssl_err = SSL_get_error(ssl, ret);
+ if (ssl_err == SSL_ERROR_ZERO_RETURN ||
+ (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) {
+ LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
+ return ResultInternalError;
+ }
+ }
+ return HandleReturn("SSL_do_handshake", 0, ret).Code();
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ size_t actual;
+ const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
+ return HandleReturn("SSL_read_ex", actual, ret);
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ size_t actual;
+ const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
+ return HandleReturn("SSL_write_ex", actual, ret);
+ }
+
+ ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
+ const int ssl_err = SSL_get_error(ssl, ret);
+ CheckOpenSSLErrors();
+ switch (ssl_err) {
+ case SSL_ERROR_NONE:
+ return actual;
+ case SSL_ERROR_ZERO_RETURN:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
+ // DoHandshake special-cases this, but for Read and Write:
+ return size_t(0);
+ case SSL_ERROR_WANT_READ:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
+ return ResultWouldBlock;
+ case SSL_ERROR_WANT_WRITE:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
+ return ResultWouldBlock;
+ default:
+ if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
+ return size_t(0);
+ }
+ LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
+ return ResultInternalError;
+ }
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
+ if (!chain) {
+ LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
+ return ResultInternalError;
+ }
+ std::vector<std::vector<u8>> ret;
+ int count = sk_X509_num(chain);
+ ASSERT(count >= 0);
+ for (int i = 0; i < count; i++) {
+ X509* x509 = sk_X509_value(chain, i);
+ ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
+ unsigned char* buf = nullptr;
+ int len = i2d_X509(x509, &buf);
+ ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
+ ret.emplace_back(buf, buf + len);
+ OPENSSL_free(buf);
+ }
+ return ret;
+ }
+
+ ~SSLConnectionBackendOpenSSL() {
+ // these are null-tolerant:
+ SSL_free(ssl);
+ BIO_free(bio);
+ }
+
+ static void KeyLogCallback(const SSL* ssl, const char* line) {
+ std::string str(line);
+ str.push_back('\n');
+ // Do this in a single WriteString for atomicity if multiple instances
+ // are running on different threads (though that can't currently
+ // happen).
+ if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
+ LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
+ }
+ LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
+ }
+
+ static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
+ auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "OpenSSL asked to send but we have no socket");
+ BIO_clear_retry_flags(bio);
+ auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ *actual_p = actual;
+ return 1;
+ case Network::Errno::AGAIN:
+ BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
+ return 0;
+ default:
+ LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
+ return -1;
+ }
+ }
+
+ static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
+ auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket");
+ BIO_clear_retry_flags(bio);
+ auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len});
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ *actual_p = actual;
+ if (actual == 0) {
+ self->got_read_eof = true;
+ }
+ return actual ? 1 : 0;
+ case Network::Errno::AGAIN:
+ BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
+ return 0;
+ default:
+ LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
+ return -1;
+ }
+ }
+
+ static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) {
+ switch (cmd) {
+ case BIO_CTRL_FLUSH:
+ // Nothing to flush.
+ return 1;
+ case BIO_CTRL_PUSH:
+ case BIO_CTRL_POP:
+#ifdef BIO_CTRL_GET_KTLS_SEND
+ case BIO_CTRL_GET_KTLS_SEND:
+ case BIO_CTRL_GET_KTLS_RECV:
+#endif
+ // We don't support these operations, but don't bother logging them
+ // as they're nothing unusual.
+ return 0;
+ default:
+ LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg);
+ return 0;
+ }
+ }
+
+ SSL* ssl = nullptr;
+ BIO* bio = nullptr;
+ bool got_read_eof = false;
+
+ std::shared_ptr<Network::SocketBase> socket;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+namespace {
+
+Result CheckOpenSSLErrors() {
+ unsigned long rc;
+ const char* file;
+ int line;
+ const char* func;
+ const char* data;
+ int flags;
+#if OPENSSL_VERSION_NUMBER >= 0x30000000L
+ while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags)))
+#else
+ // Can't get function names from OpenSSL on this version, so use mine:
+ func = __func__;
+ while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags)))
+#endif
+ {
+ std::string msg;
+ msg.resize(1024, '\0');
+ ERR_error_string_n(rc, msg.data(), msg.size());
+ msg.resize(strlen(msg.data()), '\0');
+ if (flags & ERR_TXT_STRING) {
+ msg.append(" | ");
+ msg.append(data);
+ }
+ Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
+ Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
+ msg);
+ }
+ return ResultInternalError;
+}
+
+void OneTimeInit() {
+ ssl_ctx = SSL_CTX_new(TLS_client_method());
+ if (!ssl_ctx) {
+ LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
+ CheckOpenSSLErrors();
+ return;
+ }
+
+ SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
+
+ if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
+ LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
+ CheckOpenSSLErrors();
+ return;
+ }
+
+ OneTimeInitLogFile();
+
+ if (!OneTimeInitBIO()) {
+ return;
+ }
+
+ one_time_init_success = true;
+}
+
+void OneTimeInitLogFile() {
+ const char* logfile = getenv("SSLKEYLOGFILE");
+ if (logfile) {
+ key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
+ FileShareFlag::ShareWriteOnly);
+ if (key_log_file.IsOpen()) {
+ SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
+ } else {
+ LOG_CRITICAL(Service_SSL,
+ "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
+ }
+ }
+}
+
+bool OneTimeInitBIO() {
+ bio_meth =
+ BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
+ if (!bio_meth ||
+ !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
+ !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
+ !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
+ LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
new file mode 100644
index 000000000..d8074339a
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -0,0 +1,544 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+#include "common/error.h"
+#include "common/fs/file.h"
+#include "common/hex_util.h"
+#include "common/string_util.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+namespace {
+
+// These includes are inside the namespace to avoid a conflict on MinGW where
+// the headers define an enum containing Network and Service as enumerators
+// (which clash with the correspondingly named namespaces).
+#define SECURITY_WIN32
+#include <schnlsp.h>
+#include <security.h>
+#include <wincrypt.h>
+
+std::once_flag one_time_init_flag;
+bool one_time_init_success = false;
+
+SCHANNEL_CRED schannel_cred{};
+CredHandle cred_handle;
+
+static void OneTimeInit() {
+ schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
+ schannel_cred.dwFlags =
+ SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
+ SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
+ SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate
+ // ^ I'm assuming that nobody would want to connect Yuzu to a
+ // service that requires some OS-provided corporate client
+ // certificate, and presenting one to some arbitrary server
+ // might be a privacy concern? Who knows, though.
+
+ const SECURITY_STATUS ret =
+ AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
+ nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
+ if (ret != SEC_E_OK) {
+ // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
+ LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
+ Common::NativeErrorToString(ret));
+ return;
+ }
+
+ if (getenv("SSLKEYLOGFILE")) {
+ LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
+ "keys; not logging keys!");
+ // Not fatal.
+ }
+
+ one_time_init_success = true;
+}
+
+} // namespace
+
+namespace Service::SSL {
+
+class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ std::call_once(one_time_init_flag, OneTimeInit);
+
+ if (!one_time_init_success) {
+ LOG_ERROR(
+ Service_SSL,
+ "Can't create SSL connection because Schannel one-time initialization failed");
+ return ResultInternalError;
+ }
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
+ socket = std::move(socket_in);
+ }
+
+ Result SetHostName(const std::string& hostname_in) override {
+ hostname = hostname_in;
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ while (1) {
+ Result r;
+ switch (handshake_state) {
+ case HandshakeState::Initial:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
+ (r = CallInitializeSecurityContext()) != ResultSuccess) {
+ return r;
+ }
+ // CallInitializeSecurityContext updated `handshake_state`.
+ continue;
+ case HandshakeState::ContinueNeeded:
+ case HandshakeState::IncompleteMessage:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
+ (r = FillCiphertextReadBuf()) != ResultSuccess) {
+ return r;
+ }
+ if (ciphertext_read_buf.empty()) {
+ LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
+ return ResultInternalError;
+ }
+ if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
+ return r;
+ }
+ // CallInitializeSecurityContext updated `handshake_state`.
+ continue;
+ case HandshakeState::DoneAfterFlush:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
+ return r;
+ }
+ handshake_state = HandshakeState::Connected;
+ return ResultSuccess;
+ case HandshakeState::Connected:
+ LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
+ return ResultInternalError;
+ case HandshakeState::Error:
+ return ResultInternalError;
+ }
+ }
+ }
+
+ Result FillCiphertextReadBuf() {
+ const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
+ read_buf_fill_size = 0;
+ // This unnecessarily zeroes the buffer; oh well.
+ const size_t offset = ciphertext_read_buf.size();
+ ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
+ ciphertext_read_buf.resize(offset + fill_size, 0);
+ const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
+ const auto [actual, err] = socket->Recv(0, read_span);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ ASSERT(static_cast<size_t>(actual) <= fill_size);
+ ciphertext_read_buf.resize(offset + actual);
+ return ResultSuccess;
+ case Network::Errno::AGAIN:
+ ciphertext_read_buf.resize(offset);
+ return ResultWouldBlock;
+ default:
+ ciphertext_read_buf.resize(offset);
+ LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
+ return ResultInternalError;
+ }
+ }
+
+ // Returns success if the write buffer has been completely emptied.
+ Result FlushCiphertextWriteBuf() {
+ while (!ciphertext_write_buf.empty()) {
+ const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
+ ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
+ ciphertext_write_buf.begin() + actual);
+ break;
+ case Network::Errno::AGAIN:
+ return ResultWouldBlock;
+ default:
+ LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
+ return ResultInternalError;
+ }
+ }
+ return ResultSuccess;
+ }
+
+ Result CallInitializeSecurityContext() {
+ const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
+ ISC_REQ_USE_SUPPLIED_CREDS;
+ unsigned long attr;
+ // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
+ std::array<SecBuffer, 2> input_buffers{{
+ // only used if `initial_call_done`
+ {
+ // [0]
+ .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
+ .BufferType = SECBUFFER_TOKEN,
+ .pvBuffer = ciphertext_read_buf.data(),
+ },
+ {
+ // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
+ // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
+ // whole buffer wasn't used)
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_EMPTY,
+ .pvBuffer = nullptr,
+ },
+ }};
+ std::array<SecBuffer, 2> output_buffers{{
+ {
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_TOKEN,
+ .pvBuffer = nullptr,
+ }, // [0]
+ {
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_ALERT,
+ .pvBuffer = nullptr,
+ }, // [1]
+ }};
+ SecBufferDesc input_desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(input_buffers.size()),
+ .pBuffers = input_buffers.data(),
+ };
+ SecBufferDesc output_desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(output_buffers.size()),
+ .pBuffers = output_buffers.data(),
+ };
+ ASSERT_OR_EXECUTE_MSG(
+ input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
+ { return ResultInternalError; }, "read buffer too large");
+
+ bool initial_call_done = handshake_state != HandshakeState::Initial;
+ if (initial_call_done) {
+ LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
+ ciphertext_read_buf.size());
+ }
+
+ const SECURITY_STATUS ret =
+ InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
+ // Caller ensured we have set a hostname:
+ const_cast<char*>(hostname.value().c_str()), req,
+ 0, // Reserved1
+ 0, // TargetDataRep not used with Schannel
+ initial_call_done ? &input_desc : nullptr,
+ 0, // Reserved2
+ initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
+ nullptr); // ptsExpiry
+
+ if (output_buffers[0].pvBuffer) {
+ const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
+ output_buffers[0].cbBuffer);
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
+ FreeContextBuffer(output_buffers[0].pvBuffer);
+ }
+
+ if (output_buffers[1].pvBuffer) {
+ const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
+ output_buffers[1].cbBuffer);
+ // The documentation doesn't explain what format this data is in.
+ LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
+ Common::HexToString(span));
+ }
+
+ switch (ret) {
+ case SEC_I_CONTINUE_NEEDED:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
+ if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
+ LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
+ ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
+ ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
+ ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
+ } else {
+ ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
+ ciphertext_read_buf.clear();
+ }
+ handshake_state = HandshakeState::ContinueNeeded;
+ return ResultSuccess;
+ case SEC_E_INCOMPLETE_MESSAGE:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
+ ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
+ read_buf_fill_size = input_buffers[1].cbBuffer;
+ handshake_state = HandshakeState::IncompleteMessage;
+ return ResultSuccess;
+ case SEC_E_OK:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
+ ciphertext_read_buf.clear();
+ handshake_state = HandshakeState::DoneAfterFlush;
+ return GrabStreamSizes();
+ default:
+ LOG_ERROR(Service_SSL,
+ "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
+ Common::NativeErrorToString(ret));
+ handshake_state = HandshakeState::Error;
+ return ResultInternalError;
+ }
+ }
+
+ Result GrabStreamSizes() {
+ const SECURITY_STATUS ret =
+ QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
+ Common::NativeErrorToString(ret));
+ handshake_state = HandshakeState::Error;
+ return ResultInternalError;
+ }
+ return ResultSuccess;
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ if (handshake_state != HandshakeState::Connected) {
+ LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
+ return ResultInternalError;
+ }
+ if (data.size() == 0 || got_read_eof) {
+ return size_t(0);
+ }
+ while (1) {
+ if (!cleartext_read_buf.empty()) {
+ const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
+ std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
+ cleartext_read_buf.erase(cleartext_read_buf.begin(),
+ cleartext_read_buf.begin() + read_size);
+ return read_size;
+ }
+ if (!ciphertext_read_buf.empty()) {
+ SecBuffer empty{
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_EMPTY,
+ .pvBuffer = nullptr,
+ };
+ std::array<SecBuffer, 5> buffers{{
+ {
+ .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
+ .BufferType = SECBUFFER_DATA,
+ .pvBuffer = ciphertext_read_buf.data(),
+ },
+ empty,
+ empty,
+ empty,
+ }};
+ ASSERT_OR_EXECUTE_MSG(
+ buffers[0].cbBuffer == ciphertext_read_buf.size(),
+ { return ResultInternalError; }, "read buffer too large");
+ SecBufferDesc desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(buffers.size()),
+ .pBuffers = buffers.data(),
+ };
+ SECURITY_STATUS ret =
+ DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
+ switch (ret) {
+ case SEC_E_OK:
+ ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
+ { return ResultInternalError; });
+ ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
+ { return ResultInternalError; });
+ ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
+ { return ResultInternalError; });
+ cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
+ static_cast<u8*>(buffers[1].pvBuffer) +
+ buffers[1].cbBuffer);
+ if (buffers[3].BufferType == SECBUFFER_EXTRA) {
+ ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
+ ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
+ ciphertext_read_buf.end() - buffers[3].cbBuffer);
+ } else {
+ ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
+ ciphertext_read_buf.clear();
+ }
+ continue;
+ case SEC_E_INCOMPLETE_MESSAGE:
+ break;
+ case SEC_I_CONTEXT_EXPIRED:
+ // Server hung up by sending close_notify.
+ got_read_eof = true;
+ return size_t(0);
+ default:
+ LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
+ Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ }
+ const Result r = FillCiphertextReadBuf();
+ if (r != ResultSuccess) {
+ return r;
+ }
+ if (ciphertext_read_buf.empty()) {
+ got_read_eof = true;
+ return size_t(0);
+ }
+ }
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ if (handshake_state != HandshakeState::Connected) {
+ LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
+ return ResultInternalError;
+ }
+ if (data.size() == 0) {
+ return size_t(0);
+ }
+ data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
+ if (!cleartext_write_buf.empty()) {
+ // Already in the middle of a write. It wouldn't make sense to not
+ // finish sending the entire buffer since TLS has
+ // header/MAC/padding/etc.
+ if (data.size() != cleartext_write_buf.size() ||
+ std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
+ LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
+ return ResultInternalError;
+ }
+ return WriteAlreadyEncryptedData();
+ } else {
+ cleartext_write_buf.assign(data.begin(), data.end());
+ }
+
+ std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
+ std::vector<u8> tmp_data_buf = cleartext_write_buf;
+ std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);
+
+ std::array<SecBuffer, 3> buffers{{
+ {
+ .cbBuffer = stream_sizes.cbHeader,
+ .BufferType = SECBUFFER_STREAM_HEADER,
+ .pvBuffer = header_buf.data(),
+ },
+ {
+ .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
+ .BufferType = SECBUFFER_DATA,
+ .pvBuffer = tmp_data_buf.data(),
+ },
+ {
+ .cbBuffer = stream_sizes.cbTrailer,
+ .BufferType = SECBUFFER_STREAM_TRAILER,
+ .pvBuffer = trailer_buf.data(),
+ },
+ }};
+ ASSERT_OR_EXECUTE_MSG(
+ buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
+ "temp buffer too large");
+ SecBufferDesc desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(buffers.size()),
+ .pBuffers = buffers.data(),
+ };
+
+ const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
+ header_buf.end());
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
+ tmp_data_buf.end());
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
+ trailer_buf.end());
+ return WriteAlreadyEncryptedData();
+ }
+
+ ResultVal<size_t> WriteAlreadyEncryptedData() {
+ const Result r = FlushCiphertextWriteBuf();
+ if (r != ResultSuccess) {
+ return r;
+ }
+ // write buf is empty
+ const size_t cleartext_bytes_written = cleartext_write_buf.size();
+ cleartext_write_buf.clear();
+ return cleartext_bytes_written;
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ PCCERT_CONTEXT returned_cert = nullptr;
+ const SECURITY_STATUS ret =
+ QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL,
+ "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
+ Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ PCCERT_CONTEXT some_cert = nullptr;
+ std::vector<std::vector<u8>> certs;
+ while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
+ certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
+ static_cast<u8*>(some_cert->pbCertEncoded) +
+ some_cert->cbCertEncoded);
+ }
+ std::reverse(certs.begin(),
+ certs.end()); // Windows returns certs in reverse order from what we want
+ CertFreeCertificateContext(returned_cert);
+ return certs;
+ }
+
+ ~SSLConnectionBackendSchannel() {
+ if (handshake_state != HandshakeState::Initial) {
+ DeleteSecurityContext(&ctxt);
+ }
+ }
+
+ enum class HandshakeState {
+ // Haven't called anything yet.
+ Initial,
+ // `SEC_I_CONTINUE_NEEDED` was returned by
+ // `InitializeSecurityContext`; must finish sending data (if any) in
+ // the write buffer, then read at least one byte before calling
+ // `InitializeSecurityContext` again.
+ ContinueNeeded,
+ // `SEC_E_INCOMPLETE_MESSAGE` was returned by
+ // `InitializeSecurityContext`; hopefully the write buffer is empty;
+ // must read at least one byte before calling
+ // `InitializeSecurityContext` again.
+ IncompleteMessage,
+ // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
+ // finish sending data in the write buffer before having `DoHandshake`
+ // report success.
+ DoneAfterFlush,
+ // We finished the above and are now connected. At this point, writing
+ // and reading are separate 'state machines' represented by the
+ // nonemptiness of the ciphertext and cleartext read and write buffers.
+ Connected,
+ // Another error was returned and we shouldn't allow initialization
+ // to continue.
+ Error,
+ } handshake_state = HandshakeState::Initial;
+
+ CtxtHandle ctxt;
+ SecPkgContext_StreamSizes stream_sizes;
+
+ std::shared_ptr<Network::SocketBase> socket;
+ std::optional<std::string> hostname;
+
+ std::vector<u8> ciphertext_read_buf;
+ std::vector<u8> ciphertext_write_buf;
+ std::vector<u8> cleartext_read_buf;
+ std::vector<u8> cleartext_write_buf;
+
+ bool got_read_eof = false;
+ size_t read_buf_fill_size = 0;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendSchannel>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
new file mode 100644
index 000000000..b3083cbad
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
@@ -0,0 +1,222 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+// SecureTransport has been deprecated in its entirety in favor of
+// Network.framework, but that does not allow layering TLS on top of an
+// arbitrary socket.
+#if defined(__GNUC__) || defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#include <Security/SecureTransport.h>
+#pragma GCC diagnostic pop
+#endif
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+namespace {
+
+template <typename T>
+struct CFReleaser {
+ T ptr;
+
+ YUZU_NON_COPYABLE(CFReleaser);
+ constexpr CFReleaser() : ptr(nullptr) {}
+ constexpr CFReleaser(T ptr) : ptr(ptr) {}
+ constexpr operator T() {
+ return ptr;
+ }
+ ~CFReleaser() {
+ if (ptr) {
+ CFRelease(ptr);
+ }
+ }
+};
+
+std::string CFStringToString(CFStringRef cfstr) {
+ CFReleaser<CFDataRef> cfdata(
+ CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0));
+ ASSERT_OR_EXECUTE(cfdata, { return "???"; });
+ return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)),
+ CFDataGetLength(cfdata));
+}
+
+std::string OSStatusToString(OSStatus status) {
+ CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr));
+ if (!cfstr) {
+ return "[unknown error]";
+ }
+ return CFStringToString(cfstr);
+}
+
+} // namespace
+
+namespace Service::SSL {
+
+class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ static std::once_flag once_flag;
+ std::call_once(once_flag, []() {
+ if (getenv("SSLKEYLOGFILE")) {
+ LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not "
+ "support exporting keys; not logging keys!");
+ // Not fatal.
+ }
+ });
+
+ context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType);
+ if (!context) {
+ LOG_ERROR(Service_SSL, "SSLCreateContext failed");
+ return ResultInternalError;
+ }
+
+ OSStatus status;
+ if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) ||
+ (status = SSLSetConnection(context, this))) {
+ LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}",
+ OSStatusToString(status));
+ return ResultInternalError;
+ }
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override {
+ socket = std::move(in_socket);
+ }
+
+ Result SetHostName(const std::string& hostname) override {
+ OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size());
+ if (status) {
+ LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status));
+ return ResultInternalError;
+ }
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ OSStatus status = SSLHandshake(context);
+ return HandleReturn("SSLHandshake", 0, status).Code();
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ size_t actual;
+ OSStatus status = SSLRead(context, data.data(), data.size(), &actual);
+ ;
+ return HandleReturn("SSLRead", actual, status);
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ size_t actual;
+ OSStatus status = SSLWrite(context, data.data(), data.size(), &actual);
+ ;
+ return HandleReturn("SSLWrite", actual, status);
+ }
+
+ ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) {
+ switch (status) {
+ case 0:
+ return actual;
+ case errSSLWouldBlock:
+ return ResultWouldBlock;
+ default: {
+ std::string reason;
+ if (got_read_eof) {
+ reason = "server hung up";
+ } else {
+ reason = OSStatusToString(status);
+ }
+ LOG_ERROR(Service_SSL, "{} failed: {}", what, reason);
+ return ResultInternalError;
+ }
+ }
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ CFReleaser<SecTrustRef> trust;
+ OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
+ if (status) {
+ LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
+ return ResultInternalError;
+ }
+ std::vector<std::vector<u8>> ret;
+ for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
+ SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
+ CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
+ ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
+ const u8* ptr = CFDataGetBytePtr(data);
+ ret.emplace_back(ptr, ptr + CFDataGetLength(data));
+ }
+ return ret;
+ }
+
+ static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
+ return ReadOrWriteCallback(connection, data, dataLength, true);
+ }
+
+ static OSStatus WriteCallback(SSLConnectionRef connection, const void* data,
+ size_t* dataLength) {
+ return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false);
+ }
+
+ static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength,
+ bool is_read) {
+ auto self =
+ static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket",
+ is_read ? "read" : "write");
+
+ // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
+ // expected to read/write the full requested dataLength or return an
+ // error, so we have to add a loop ourselves.
+ size_t requested_len = *dataLength;
+ size_t offset = 0;
+ while (offset < requested_len) {
+ std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset);
+ auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0);
+ LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset,
+ actual, cur.size(), static_cast<s32>(err));
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ offset += actual;
+ if (actual == 0) {
+ ASSERT(is_read);
+ self->got_read_eof = true;
+ return errSecEndOfData;
+ }
+ break;
+ case Network::Errno::AGAIN:
+ *dataLength = offset;
+ return errSSLWouldBlock;
+ default:
+ LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}",
+ is_read ? "recv" : "send", err);
+ return errSecIO;
+ }
+ }
+ ASSERT(offset == requested_len);
+ return 0;
+ }
+
+private:
+ CFReleaser<SSLContextRef> context = nullptr;
+ bool got_read_eof = false;
+
+ std::shared_ptr<Network::SocketBase> socket;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp
index 75ac10a9c..28f89c599 100644
--- a/src/core/internal_network/network.cpp
+++ b/src/core/internal_network/network.cpp
@@ -27,6 +27,7 @@
#include "common/assert.h"
#include "common/common_types.h"
+#include "common/expected.h"
#include "common/logging/log.h"
#include "common/settings.h"
#include "core/internal_network/network.h"
@@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) {
Errno TranslateNativeError(int e) {
switch (e) {
+ case 0:
+ return Errno::SUCCESS;
case WSAEBADF:
return Errno::BADF;
case WSAEINVAL:
@@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) {
return Errno::MSGSIZE;
case WSAETIMEDOUT:
return Errno::TIMEDOUT;
+ case WSAEINPROGRESS:
+ return Errno::INPROGRESS;
default:
UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
return Errno::OTHER;
@@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) {
Errno TranslateNativeError(int e) {
switch (e) {
+ case 0:
+ return Errno::SUCCESS;
case EBADF:
return Errno::BADF;
case EINVAL:
@@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) {
return Errno::MSGSIZE;
case ETIMEDOUT:
return Errno::TIMEDOUT;
+ case EINPROGRESS:
+ return Errno::INPROGRESS;
default:
- UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
+ UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e));
return Errno::OTHER;
}
}
@@ -234,15 +243,84 @@ Errno GetAndLogLastError() {
int e = errno;
#endif
const Errno err = TranslateNativeError(e);
- if (err == Errno::AGAIN || err == Errno::TIMEDOUT) {
+ if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) {
+ // These happen during normal operation, so only log them at debug level.
+ LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
return err;
}
LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
return err;
}
-int TranslateDomain(Domain domain) {
+GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) {
+ switch (gai_err) {
+ case 0:
+ return GetAddrInfoError::SUCCESS;
+#ifdef EAI_ADDRFAMILY
+ case EAI_ADDRFAMILY:
+ return GetAddrInfoError::ADDRFAMILY;
+#endif
+ case EAI_AGAIN:
+ return GetAddrInfoError::AGAIN;
+ case EAI_BADFLAGS:
+ return GetAddrInfoError::BADFLAGS;
+ case EAI_FAIL:
+ return GetAddrInfoError::FAIL;
+ case EAI_FAMILY:
+ return GetAddrInfoError::FAMILY;
+ case EAI_MEMORY:
+ return GetAddrInfoError::MEMORY;
+ case EAI_NONAME:
+ return GetAddrInfoError::NONAME;
+ case EAI_SERVICE:
+ return GetAddrInfoError::SERVICE;
+ case EAI_SOCKTYPE:
+ return GetAddrInfoError::SOCKTYPE;
+ // These codes may not be defined on all systems:
+#ifdef EAI_SYSTEM
+ case EAI_SYSTEM:
+ return GetAddrInfoError::SYSTEM;
+#endif
+#ifdef EAI_BADHINTS
+ case EAI_BADHINTS:
+ return GetAddrInfoError::BADHINTS;
+#endif
+#ifdef EAI_PROTOCOL
+ case EAI_PROTOCOL:
+ return GetAddrInfoError::PROTOCOL;
+#endif
+#ifdef EAI_OVERFLOW
+ case EAI_OVERFLOW:
+ return GetAddrInfoError::OVERFLOW_;
+#endif
+ default:
+#ifdef EAI_NODATA
+ // This can't be a case statement because it would create a duplicate
+ // case on Windows where EAI_NODATA is an alias for EAI_NONAME.
+ if (gai_err == EAI_NODATA) {
+ return GetAddrInfoError::NODATA;
+ }
+#endif
+ return GetAddrInfoError::OTHER;
+ }
+}
+
+Domain TranslateDomainFromNative(int domain) {
+ switch (domain) {
+ case 0:
+ return Domain::Unspecified;
+ case AF_INET:
+ return Domain::INET;
+ default:
+ UNIMPLEMENTED_MSG("Unhandled domain={}", domain);
+ return Domain::INET;
+ }
+}
+
+int TranslateDomainToNative(Domain domain) {
switch (domain) {
+ case Domain::Unspecified:
+ return 0;
case Domain::INET:
return AF_INET;
default:
@@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) {
}
}
-int TranslateType(Type type) {
+Type TranslateTypeFromNative(int type) {
+ switch (type) {
+ case 0:
+ return Type::Unspecified;
+ case SOCK_STREAM:
+ return Type::STREAM;
+ case SOCK_DGRAM:
+ return Type::DGRAM;
+ case SOCK_RAW:
+ return Type::RAW;
+ case SOCK_SEQPACKET:
+ return Type::SEQPACKET;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented type={}", type);
+ return Type::STREAM;
+ }
+}
+
+int TranslateTypeToNative(Type type) {
switch (type) {
+ case Type::Unspecified:
+ return 0;
case Type::STREAM:
return SOCK_STREAM;
case Type::DGRAM:
return SOCK_DGRAM;
+ case Type::RAW:
+ return SOCK_RAW;
default:
UNIMPLEMENTED_MSG("Unimplemented type={}", type);
return 0;
}
}
-int TranslateProtocol(Protocol protocol) {
+Protocol TranslateProtocolFromNative(int protocol) {
+ switch (protocol) {
+ case 0:
+ return Protocol::Unspecified;
+ case IPPROTO_TCP:
+ return Protocol::TCP;
+ case IPPROTO_UDP:
+ return Protocol::UDP;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
+ return Protocol::Unspecified;
+ }
+}
+
+int TranslateProtocolToNative(Protocol protocol) {
switch (protocol) {
+ case Protocol::Unspecified:
+ return 0;
case Protocol::TCP:
return IPPROTO_TCP;
case Protocol::UDP:
@@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) {
}
}
-SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
- sockaddr_in input;
- std::memcpy(&input, &input_, sizeof(input));
-
+SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) {
SockAddrIn result;
- switch (input.sin_family) {
- case AF_INET:
- result.family = Domain::INET;
- break;
- default:
- UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family);
- result.family = Domain::INET;
- break;
- }
+ result.family = TranslateDomainFromNative(input.sin_family);
result.portno = ntohs(input.sin_port);
@@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
short TranslatePollEvents(PollEvents events) {
short result = 0;
- if (True(events & PollEvents::In)) {
- events &= ~PollEvents::In;
- result |= POLLIN;
- }
- if (True(events & PollEvents::Pri)) {
- events &= ~PollEvents::Pri;
+ const auto translate = [&result, &events](PollEvents guest, short host) {
+ if (True(events & guest)) {
+ events &= ~guest;
+ result |= host;
+ }
+ };
+
+ translate(PollEvents::In, POLLIN);
+ translate(PollEvents::Pri, POLLPRI);
+ translate(PollEvents::Out, POLLOUT);
+ translate(PollEvents::Err, POLLERR);
+ translate(PollEvents::Hup, POLLHUP);
+ translate(PollEvents::Nval, POLLNVAL);
+ translate(PollEvents::RdNorm, POLLRDNORM);
+ translate(PollEvents::RdBand, POLLRDBAND);
+ translate(PollEvents::WrBand, POLLWRBAND);
+
#ifdef _WIN32
- LOG_WARNING(Service, "Winsock doesn't support POLLPRI");
-#else
- result |= POLLPRI;
+ short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM;
+ // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input.
+ if (result & ~allowed_events) {
+ LOG_DEBUG(Network,
+ "Removing WSAPoll input events 0x{:x} because Windows doesn't support them",
+ result & ~allowed_events);
+ }
+ result &= allowed_events;
#endif
- }
- if (True(events & PollEvents::Out)) {
- events &= ~PollEvents::Out;
- result |= POLLOUT;
- }
UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events);
@@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) {
translate(POLLOUT, PollEvents::Out);
translate(POLLERR, PollEvents::Err);
translate(POLLHUP, PollEvents::Hup);
+ translate(POLLNVAL, PollEvents::Nval);
+ translate(POLLRDNORM, PollEvents::RdNorm);
+ translate(POLLRDBAND, PollEvents::RdBand);
+ translate(POLLWRBAND, PollEvents::WrBand);
UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents);
@@ -360,12 +480,51 @@ std::optional<IPv4Address> GetHostIPv4Address() {
return {};
}
- std::array<char, 16> ip_addr = {};
- ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) !=
- nullptr);
return TranslateIPv4(network_interface->ip_address);
}
+std::string IPv4AddressToString(IPv4Address ip_addr) {
+ std::array<char, INET_ADDRSTRLEN> buf = {};
+ ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data());
+ return std::string(buf.data());
+}
+
+u32 IPv4AddressToInteger(IPv4Address ip_addr) {
+ return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 |
+ static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]);
+}
+
+Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
+ const std::string& host, const std::optional<std::string>& service) {
+ addrinfo hints{};
+ hints.ai_family = AF_INET; // Switch only supports IPv4.
+ addrinfo* addrinfo;
+ s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr,
+ &hints, &addrinfo);
+ if (gai_err != 0) {
+ return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err));
+ }
+ std::vector<AddrInfo> ret;
+ for (auto* current = addrinfo; current; current = current->ai_next) {
+ // We should only get AF_INET results due to the hints value.
+ ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET &&
+ addrinfo->ai_addrlen == sizeof(sockaddr_in),
+ continue;);
+
+ AddrInfo& out = ret.emplace_back();
+ out.family = TranslateDomainFromNative(current->ai_family);
+ out.socket_type = TranslateTypeFromNative(current->ai_socktype);
+ out.protocol = TranslateProtocolFromNative(current->ai_protocol);
+ out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr),
+ current->ai_addrlen);
+ if (current->ai_canonname != nullptr) {
+ out.canon_name = current->ai_canonname;
+ }
+ }
+ freeaddrinfo(addrinfo);
+ return ret;
+}
+
std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) {
const size_t num = pollfds.size();
@@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept {
}
template <typename T>
-Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
+std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) {
+ T value{};
+ socklen_t len = sizeof(value);
+ const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
+ if (result != SOCKET_ERROR) {
+ ASSERT(len == sizeof(value));
+ return {value, Errno::SUCCESS};
+ }
+ return {value, GetAndLogLastError()};
+}
+
+template <typename T>
+Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) {
const int result =
- setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
+ setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
if (result != SOCKET_ERROR) {
return Errno::SUCCESS;
}
@@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
}
Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
- fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol));
+ fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type),
+ TranslateProtocolToNative(protocol));
if (fd != INVALID_SOCKET) {
return Errno::SUCCESS;
}
@@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
}
std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() {
- sockaddr addr;
+ sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
- const SOCKET new_socket = accept(fd, &addr, &addrlen);
+ const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
if (new_socket == INVALID_SOCKET) {
return {AcceptResult{}, GetAndLogLastError()};
}
- ASSERT(addrlen == sizeof(sockaddr_in));
-
AcceptResult result{
.socket = std::make_unique<Socket>(new_socket),
- .sockaddr_in = TranslateToSockAddrIn(addr),
+ .sockaddr_in = TranslateToSockAddrIn(addr, addrlen),
};
return {std::move(result), Errno::SUCCESS};
@@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) {
}
std::pair<SockAddrIn, Errno> Socket::GetPeerName() {
- sockaddr addr;
+ sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
- if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) {
+ if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
return {SockAddrIn{}, GetAndLogLastError()};
}
- ASSERT(addrlen == sizeof(sockaddr_in));
- return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
+ return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
}
std::pair<SockAddrIn, Errno> Socket::GetSockName() {
- sockaddr addr;
+ sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
- if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) {
+ if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
return {SockAddrIn{}, GetAndLogLastError()};
}
- ASSERT(addrlen == sizeof(sockaddr_in));
- return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
+ return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
}
Errno Socket::Bind(SockAddrIn addr) {
@@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) {
return GetAndLogLastError();
}
-std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
+std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) {
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -532,21 +700,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
return {-1, GetAndLogLastError()};
}
-std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) {
+std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
- sockaddr addr_in{};
+ sockaddr_in addr_in{};
socklen_t addrlen = sizeof(addr_in);
socklen_t* const p_addrlen = addr ? &addrlen : nullptr;
- sockaddr* const p_addr_in = addr ? &addr_in : nullptr;
+ sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr;
const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()),
static_cast<int>(message.size()), 0, p_addr_in, p_addrlen);
if (result != SOCKET_ERROR) {
if (addr) {
- ASSERT(addrlen == sizeof(addr_in));
- *addr = TranslateToSockAddrIn(addr_in);
+ *addr = TranslateToSockAddrIn(addr_in, addrlen);
}
return {static_cast<s32>(result), Errno::SUCCESS};
}
@@ -597,6 +764,11 @@ Errno Socket::Close() {
return Errno::SUCCESS;
}
+std::pair<Errno, Errno> Socket::GetPendingError() {
+ auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR);
+ return {TranslateNativeError(pending_err), getsockopt_err};
+}
+
Errno Socket::SetLinger(bool enable, u32 linger) {
return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger));
}
diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h
index 1e09a007a..badcb8369 100644
--- a/src/core/internal_network/network.h
+++ b/src/core/internal_network/network.h
@@ -5,6 +5,7 @@
#include <array>
#include <optional>
+#include <vector>
#include "common/common_funcs.h"
#include "common/common_types.h"
@@ -16,6 +17,11 @@
#include <netinet/in.h>
#endif
+namespace Common {
+template <typename T, typename E>
+class Expected;
+}
+
namespace Network {
class SocketBase;
@@ -36,6 +42,26 @@ enum class Errno {
NETUNREACH,
TIMEDOUT,
MSGSIZE,
+ INPROGRESS,
+ OTHER,
+};
+
+enum class GetAddrInfoError {
+ SUCCESS,
+ ADDRFAMILY,
+ AGAIN,
+ BADFLAGS,
+ FAIL,
+ FAMILY,
+ MEMORY,
+ NODATA,
+ NONAME,
+ SERVICE,
+ SOCKTYPE,
+ SYSTEM,
+ BADHINTS,
+ PROTOCOL,
+ OVERFLOW_,
OTHER,
};
@@ -49,6 +75,9 @@ enum class PollEvents : u16 {
Err = 1 << 3,
Hup = 1 << 4,
Nval = 1 << 5,
+ RdNorm = 1 << 6,
+ RdBand = 1 << 7,
+ WrBand = 1 << 8,
};
DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
@@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) {
/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array
std::optional<IPv4Address> GetHostIPv4Address();
+std::string IPv4AddressToString(IPv4Address ip_addr);
+u32 IPv4AddressToInteger(IPv4Address ip_addr);
+
+// named to avoid name collision with Windows macro
+Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
+ const std::string& host, const std::optional<std::string>& service);
+
} // namespace Network
diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp
index 7a77171c2..ce0dee970 100644
--- a/src/core/internal_network/socket_proxy.cpp
+++ b/src/core/internal_network/socket_proxy.cpp
@@ -10,6 +10,7 @@
#include "core/internal_network/network.h"
#include "core/internal_network/network_interface.h"
#include "core/internal_network/socket_proxy.h"
+#include "network/network.h"
#if YUZU_UNIX
#include <sys/socket.h>
@@ -98,7 +99,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) {
return Errno::SUCCESS;
}
-std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
+std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) {
LOG_WARNING(Network, "(STUBBED) called");
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -106,7 +107,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
return {static_cast<s32>(0), Errno::SUCCESS};
}
-std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) {
+std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -140,8 +141,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message,
}
}
-std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message,
- SockAddrIn* addr, std::size_t max_length) {
+std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
+ std::size_t max_length) {
ProxyPacket& packet = received_packets.front();
if (addr) {
addr->family = Domain::INET;
@@ -153,10 +154,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
std::size_t read_bytes;
if (packet.data.size() > max_length) {
read_bytes = max_length;
- message.clear();
- std::copy(packet.data.begin(), packet.data.begin() + read_bytes,
- std::back_inserter(message));
- message.resize(max_length);
+ memcpy(message.data(), packet.data.data(), max_length);
if (protocol == Protocol::UDP) {
if (!peek) {
@@ -171,9 +169,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
}
} else {
read_bytes = packet.data.size();
- message.clear();
- std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message));
- message.resize(max_length);
+ memcpy(message.data(), packet.data.data(), read_bytes);
if (!peek) {
received_packets.pop();
}
@@ -293,6 +289,11 @@ Errno ProxySocket::SetNonBlock(bool enable) {
return Errno::SUCCESS;
}
+std::pair<Errno, Errno> ProxySocket::GetPendingError() {
+ LOG_DEBUG(Network, "(STUBBED) called");
+ return {Errno::SUCCESS, Errno::SUCCESS};
+}
+
bool ProxySocket::IsOpened() const {
return fd != INVALID_SOCKET;
}
diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h
index 6e991fa38..70500cf4a 100644
--- a/src/core/internal_network/socket_proxy.h
+++ b/src/core/internal_network/socket_proxy.h
@@ -10,10 +10,12 @@
#include "common/common_funcs.h"
#include "core/internal_network/sockets.h"
-#include "network/network.h"
+#include "network/room_member.h"
namespace Network {
+class RoomNetwork;
+
class ProxySocket : public SocketBase {
public:
explicit ProxySocket(RoomNetwork& room_network_) noexcept;
@@ -39,11 +41,11 @@ public:
Errno Shutdown(ShutdownHow how) override;
- std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override;
+ std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
- std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override;
+ std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
- std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr,
+ std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
std::size_t max_length);
std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
@@ -74,6 +76,8 @@ public:
template <typename T>
Errno SetSockOpt(SOCKET fd, int option, T value);
+ std::pair<Errno, Errno> GetPendingError() override;
+
bool IsOpened() const override;
private:
diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h
index 11e479e50..4ba51f62c 100644
--- a/src/core/internal_network/sockets.h
+++ b/src/core/internal_network/sockets.h
@@ -15,12 +15,13 @@
#include "common/common_types.h"
#include "core/internal_network/network.h"
-#include "network/network.h"
// TODO: C++20 Replace std::vector usages with std::span
namespace Network {
+struct ProxyPacket;
+
class SocketBase {
public:
#ifdef YUZU_UNIX
@@ -59,10 +60,9 @@ public:
virtual Errno Shutdown(ShutdownHow how) = 0;
- virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0;
+ virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0;
- virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message,
- SockAddrIn* addr) = 0;
+ virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0;
virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0;
@@ -87,6 +87,8 @@ public:
virtual Errno SetNonBlock(bool enable) = 0;
+ virtual std::pair<Errno, Errno> GetPendingError() = 0;
+
virtual bool IsOpened() const = 0;
virtual void HandleProxyPacket(const ProxyPacket& packet) = 0;
@@ -126,9 +128,9 @@ public:
Errno Shutdown(ShutdownHow how) override;
- std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override;
+ std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
- std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override;
+ std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
@@ -156,6 +158,11 @@ public:
template <typename T>
Errno SetSockOpt(SOCKET fd, int option, T value);
+ std::pair<Errno, Errno> GetPendingError() override;
+
+ template <typename T>
+ std::pair<T, Errno> GetSockOpt(SOCKET fd, int option);
+
bool IsOpened() const override;
void HandleProxyPacket(const ProxyPacket& packet) override;
diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt
index 3b2fe01da..7f79111e0 100644
--- a/src/video_core/CMakeLists.txt
+++ b/src/video_core/CMakeLists.txt
@@ -274,6 +274,7 @@ add_library(video_core STATIC
vulkan_common/vulkan_wrapper.h
vulkan_common/nsight_aftermath_tracker.cpp
vulkan_common/nsight_aftermath_tracker.h
+ vulkan_common/vma.cpp
)
create_target_directory_groups(video_core)
@@ -291,7 +292,7 @@ target_link_options(video_core PRIVATE ${FFmpeg_LDFLAGS})
add_dependencies(video_core host_shaders)
target_include_directories(video_core PRIVATE ${HOST_SHADERS_INCLUDE})
-target_link_libraries(video_core PRIVATE sirit Vulkan::Headers vma)
+target_link_libraries(video_core PRIVATE sirit Vulkan::Headers GPUOpen::VulkanMemoryAllocator)
if (ENABLE_NSIGHT_AFTERMATH)
if (NOT DEFINED ENV{NSIGHT_AFTERMATH_SDK})
@@ -324,6 +325,9 @@ else()
# xbyak
set_source_files_properties(macro/macro_jit_x64.cpp PROPERTIES COMPILE_OPTIONS "-Wno-conversion;-Wno-shadow")
+
+ # VMA
+ set_source_files_properties(vulkan_common/vma.cpp PROPERTIES COMPILE_OPTIONS "-Wno-conversion;-Wno-unused-variable;-Wno-unused-parameter;-Wno-missing-field-initializers")
endif()
if (ARCHITECTURE_x86_64)
diff --git a/src/video_core/renderer_base.cpp b/src/video_core/renderer_base.cpp
index 2d3f58201..4002fa72b 100644
--- a/src/video_core/renderer_base.cpp
+++ b/src/video_core/renderer_base.cpp
@@ -38,8 +38,8 @@ void RendererBase::RequestScreenshot(void* data, std::function<void(bool)> callb
LOG_ERROR(Render, "A screenshot is already requested or in progress, ignoring the request");
return;
}
- auto async_callback{[callback = std::move(callback)](bool invert_y) {
- std::thread t{callback, invert_y};
+ auto async_callback{[callback_ = std::move(callback)](bool invert_y) {
+ std::thread t{callback_, invert_y};
t.detach();
}};
renderer_settings.screenshot_bits = data;
diff --git a/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp b/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp
index 23a48c6fe..71f720c63 100644
--- a/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp
+++ b/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp
@@ -231,24 +231,25 @@ GraphicsPipeline::GraphicsPipeline(const Device& device, TextureCache& texture_c
}
const bool in_parallel = thread_worker != nullptr;
const auto backend = device.GetShaderBackend();
- auto func{[this, sources = std::move(sources), sources_spirv = std::move(sources_spirv),
+ auto func{[this, sources_ = std::move(sources), sources_spirv_ = std::move(sources_spirv),
shader_notify, backend, in_parallel,
force_context_flush](ShaderContext::Context*) mutable {
for (size_t stage = 0; stage < 5; ++stage) {
switch (backend) {
case Settings::ShaderBackend::GLSL:
- if (!sources[stage].empty()) {
- source_programs[stage] = CreateProgram(sources[stage], Stage(stage));
+ if (!sources_[stage].empty()) {
+ source_programs[stage] = CreateProgram(sources_[stage], Stage(stage));
}
break;
case Settings::ShaderBackend::GLASM:
- if (!sources[stage].empty()) {
- assembly_programs[stage] = CompileProgram(sources[stage], AssemblyStage(stage));
+ if (!sources_[stage].empty()) {
+ assembly_programs[stage] =
+ CompileProgram(sources_[stage], AssemblyStage(stage));
}
break;
case Settings::ShaderBackend::SPIRV:
- if (!sources_spirv[stage].empty()) {
- source_programs[stage] = CreateProgram(sources_spirv[stage], Stage(stage));
+ if (!sources_spirv_[stage].empty()) {
+ source_programs[stage] = CreateProgram(sources_spirv_[stage], Stage(stage));
}
break;
}
diff --git a/src/video_core/renderer_opengl/gl_shader_cache.cpp b/src/video_core/renderer_opengl/gl_shader_cache.cpp
index 0329ed820..7e1d7f92e 100644
--- a/src/video_core/renderer_opengl/gl_shader_cache.cpp
+++ b/src/video_core/renderer_opengl/gl_shader_cache.cpp
@@ -288,9 +288,9 @@ void ShaderCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading,
const auto load_compute{[&](std::ifstream& file, FileEnvironment env) {
ComputePipelineKey key;
file.read(reinterpret_cast<char*>(&key), sizeof(key));
- queue_work([this, key, env = std::move(env), &state, &callback](Context* ctx) mutable {
+ queue_work([this, key, env_ = std::move(env), &state, &callback](Context* ctx) mutable {
ctx->pools.ReleaseContents();
- auto pipeline{CreateComputePipeline(ctx->pools, key, env, true)};
+ auto pipeline{CreateComputePipeline(ctx->pools, key, env_, true)};
std::scoped_lock lock{state.mutex};
if (pipeline) {
compute_cache.emplace(key, std::move(pipeline));
@@ -305,9 +305,9 @@ void ShaderCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading,
const auto load_graphics{[&](std::ifstream& file, std::vector<FileEnvironment> envs) {
GraphicsPipelineKey key;
file.read(reinterpret_cast<char*>(&key), sizeof(key));
- queue_work([this, key, envs = std::move(envs), &state, &callback](Context* ctx) mutable {
+ queue_work([this, key, envs_ = std::move(envs), &state, &callback](Context* ctx) mutable {
boost::container::static_vector<Shader::Environment*, 5> env_ptrs;
- for (auto& env : envs) {
+ for (auto& env : envs_) {
env_ptrs.push_back(&env);
}
ctx->pools.ReleaseContents();
diff --git a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp
index 51df18ec3..f8cd2a5d8 100644
--- a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp
@@ -206,8 +206,8 @@ public:
const size_t sub_first_offset = static_cast<size_t>(first % 4) * GetQuadsNum(num_indices);
const size_t offset =
(sub_first_offset + GetQuadsNum(first)) * 6ULL * BytesPerIndex(index_type);
- scheduler.Record([buffer = *buffer, index_type_, offset](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindIndexBuffer(buffer, offset, index_type_);
+ scheduler.Record([buffer_ = *buffer, index_type_, offset](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindIndexBuffer(buffer_, offset, index_type_);
});
}
@@ -528,17 +528,18 @@ void BufferCacheRuntime::BindVertexBuffers(VideoCommon::HostBindings<Buffer>& bi
buffer_handles.push_back(handle);
}
if (device.IsExtExtendedDynamicStateSupported()) {
- scheduler.Record([bindings = std::move(bindings),
- buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindVertexBuffers2EXT(
- bindings.min_index, bindings.max_index - bindings.min_index, buffer_handles.data(),
- bindings.offsets.data(), bindings.sizes.data(), bindings.strides.data());
+ scheduler.Record([bindings_ = std::move(bindings),
+ buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindVertexBuffers2EXT(bindings_.min_index,
+ bindings_.max_index - bindings_.min_index,
+ buffer_handles_.data(), bindings_.offsets.data(),
+ bindings_.sizes.data(), bindings_.strides.data());
});
} else {
- scheduler.Record([bindings = std::move(bindings),
- buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindVertexBuffers(bindings.min_index, bindings.max_index - bindings.min_index,
- buffer_handles.data(), bindings.offsets.data());
+ scheduler.Record([bindings_ = std::move(bindings),
+ buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindVertexBuffers(bindings_.min_index, bindings_.max_index - bindings_.min_index,
+ buffer_handles_.data(), bindings_.offsets.data());
});
}
}
@@ -573,11 +574,11 @@ void BufferCacheRuntime::BindTransformFeedbackBuffers(VideoCommon::HostBindings<
for (u32 index = 0; index < bindings.buffers.size(); ++index) {
buffer_handles.push_back(bindings.buffers[index]->Handle());
}
- scheduler.Record([bindings = std::move(bindings),
- buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindTransformFeedbackBuffersEXT(0, static_cast<u32>(buffer_handles.size()),
- buffer_handles.data(), bindings.offsets.data(),
- bindings.sizes.data());
+ scheduler.Record([bindings_ = std::move(bindings),
+ buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindTransformFeedbackBuffersEXT(0, static_cast<u32>(buffer_handles_.size()),
+ buffer_handles_.data(), bindings_.offsets.data(),
+ bindings_.sizes.data());
});
}
diff --git a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
index d600c4e61..4f84d8497 100644
--- a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
@@ -469,9 +469,9 @@ void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading
ComputePipelineCacheKey key;
file.read(reinterpret_cast<char*>(&key), sizeof(key));
- workers.QueueWork([this, key, env = std::move(env), &state, &callback]() mutable {
+ workers.QueueWork([this, key, env_ = std::move(env), &state, &callback]() mutable {
ShaderPools pools;
- auto pipeline{CreateComputePipeline(pools, key, env, state.statistics.get(), false)};
+ auto pipeline{CreateComputePipeline(pools, key, env_, state.statistics.get(), false)};
std::scoped_lock lock{state.mutex};
if (pipeline) {
compute_cache.emplace(key, std::move(pipeline));
@@ -500,10 +500,10 @@ void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading
(key.state.dynamic_vertex_input != 0) != dynamic_features.has_dynamic_vertex_input) {
return;
}
- workers.QueueWork([this, key, envs = std::move(envs), &state, &callback]() mutable {
+ workers.QueueWork([this, key, envs_ = std::move(envs), &state, &callback]() mutable {
ShaderPools pools;
boost::container::static_vector<Shader::Environment*, 5> env_ptrs;
- for (auto& env : envs) {
+ for (auto& env : envs_) {
env_ptrs.push_back(&env);
}
auto pipeline{CreateGraphicsPipeline(pools, key, MakeSpan(env_ptrs),
@@ -702,8 +702,8 @@ std::unique_ptr<ComputePipeline> PipelineCache::CreateComputePipeline(
if (!pipeline || pipeline_cache_filename.empty()) {
return pipeline;
}
- serialization_thread.QueueWork([this, key, env = std::move(env)] {
- SerializePipeline(key, std::array<const GenericEnvironment*, 1>{&env},
+ serialization_thread.QueueWork([this, key, env_ = std::move(env)] {
+ SerializePipeline(key, std::array<const GenericEnvironment*, 1>{&env_},
pipeline_cache_filename, CACHE_VERSION);
});
return pipeline;
diff --git a/src/video_core/renderer_vulkan/vk_query_cache.cpp b/src/video_core/renderer_vulkan/vk_query_cache.cpp
index d67490449..29e0b797b 100644
--- a/src/video_core/renderer_vulkan/vk_query_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_query_cache.cpp
@@ -98,10 +98,10 @@ HostCounter::HostCounter(QueryCache& cache_, std::shared_ptr<HostCounter> depend
: HostCounterBase{std::move(dependency_)}, cache{cache_}, type{type_},
query{cache_.AllocateQuery(type_)}, tick{cache_.GetScheduler().CurrentTick()} {
const vk::Device* logical = &cache.GetDevice().GetLogical();
- cache.GetScheduler().Record([logical, query = query](vk::CommandBuffer cmdbuf) {
+ cache.GetScheduler().Record([logical, query_ = query](vk::CommandBuffer cmdbuf) {
const bool use_precise = Settings::IsGPULevelHigh();
- logical->ResetQueryPool(query.first, query.second, 1);
- cmdbuf.BeginQuery(query.first, query.second,
+ logical->ResetQueryPool(query_.first, query_.second, 1);
+ cmdbuf.BeginQuery(query_.first, query_.second,
use_precise ? VK_QUERY_CONTROL_PRECISE_BIT : 0);
});
}
@@ -111,8 +111,9 @@ HostCounter::~HostCounter() {
}
void HostCounter::EndQuery() {
- cache.GetScheduler().Record(
- [query = query](vk::CommandBuffer cmdbuf) { cmdbuf.EndQuery(query.first, query.second); });
+ cache.GetScheduler().Record([query_ = query](vk::CommandBuffer cmdbuf) {
+ cmdbuf.EndQuery(query_.first, query_.second);
+ });
}
u64 HostCounter::BlockingQuery(bool async) const {
diff --git a/src/video_core/renderer_vulkan/vk_texture_cache.cpp b/src/video_core/renderer_vulkan/vk_texture_cache.cpp
index 3aac3cfab..bf6ad6c79 100644
--- a/src/video_core/renderer_vulkan/vk_texture_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_texture_cache.cpp
@@ -1412,7 +1412,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS
}
scheduler->RequestOutsideRenderPassOperationContext();
scheduler->Record([buffers = std::move(buffers_vector), image = *original_image,
- aspect_mask = aspect_mask, vk_copies](vk::CommandBuffer cmdbuf) {
+ aspect_mask_ = aspect_mask, vk_copies](vk::CommandBuffer cmdbuf) {
const VkImageMemoryBarrier read_barrier{
.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER,
.pNext = nullptr,
@@ -1424,7 +1424,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS
.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
.image = image,
.subresourceRange{
- .aspectMask = aspect_mask,
+ .aspectMask = aspect_mask_,
.baseMipLevel = 0,
.levelCount = VK_REMAINING_MIP_LEVELS,
.baseArrayLayer = 0,
@@ -1456,7 +1456,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS
.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
.image = image,
.subresourceRange{
- .aspectMask = aspect_mask,
+ .aspectMask = aspect_mask_,
.baseMipLevel = 0,
.levelCount = VK_REMAINING_MIP_LEVELS,
.baseArrayLayer = 0,
diff --git a/externals/vma/vma.cpp b/src/video_core/vulkan_common/vma.cpp
index 1fe2cf52b..1fe2cf52b 100644
--- a/externals/vma/vma.cpp
+++ b/src/video_core/vulkan_common/vma.cpp
diff --git a/src/web_service/announce_room_json.cpp b/src/web_service/announce_room_json.cpp
index 4c3195efd..f1020a5b8 100644
--- a/src/web_service/announce_room_json.cpp
+++ b/src/web_service/announce_room_json.cpp
@@ -135,11 +135,11 @@ void RoomJson::Delete() {
LOG_ERROR(WebService, "Room must be registered to be deleted");
return;
}
- Common::DetachedTasks::AddTask(
- [host{this->host}, username{this->username}, token{this->token}, room_id{this->room_id}]() {
- // create a new client here because the this->client might be destroyed.
- Client{host, username, token}.DeleteJson(fmt::format("/lobby/{}", room_id), "", false);
- });
+ Common::DetachedTasks::AddTask([host_{this->host}, username_{this->username},
+ token_{this->token}, room_id_{this->room_id}]() {
+ // create a new client here because the this->client might be destroyed.
+ Client{host_, username_, token_}.DeleteJson(fmt::format("/lobby/{}", room_id_), "", false);
+ });
}
} // namespace WebService