diff options
Diffstat (limited to 'src/core/hle/service/ssl/ssl.cpp')
-rw-r--r-- | src/core/hle/service/ssl/ssl.cpp | 353 |
1 files changed, 336 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 2b99dd7ac..9c96f9763 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp @@ -1,10 +1,18 @@ // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include "common/string_util.h" + +#include "core/core.h" #include "core/hle/service/ipc_helpers.h" #include "core/hle/service/server_manager.h" #include "core/hle/service/service.h" +#include "core/hle/service/sm/sm.h" +#include "core/hle/service/sockets/bsd.h" #include "core/hle/service/ssl/ssl.h" +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" namespace Service::SSL { @@ -20,6 +28,18 @@ enum class ContextOption : u32 { CrlImportDateCheckEnable = 1, }; +// This is nn::ssl::Connection::IoMode +enum class IoMode : u32 { + Blocking = 1, + NonBlocking = 2, +}; + +// This is nn::ssl::sf::OptionType +enum class OptionType : u32 { + DoNotCloseSocket = 0, + GetServerCertChain = 1, +}; + // This is nn::ssl::sf::SslVersion struct SslVersion { union { @@ -34,35 +54,42 @@ struct SslVersion { }; }; +struct SslContextSharedData { + u32 connection_count = 0; +}; + class ISslConnection final : public ServiceFramework<ISslConnection> { public: - explicit ISslConnection(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { + explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in, + std::shared_ptr<SslContextSharedData>& shared_data_in, + std::unique_ptr<SSLConnectionBackend>&& backend_in) + : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in}, + shared_data{shared_data_in}, backend{std::move(backend_in)} { // clang-format off static const FunctionInfo functions[] = { - {0, nullptr, "SetSocketDescriptor"}, - {1, nullptr, "SetHostName"}, - {2, nullptr, "SetVerifyOption"}, - {3, nullptr, "SetIoMode"}, + {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, + {1, &ISslConnection::SetHostName, "SetHostName"}, + {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, + {3, &ISslConnection::SetIoMode, "SetIoMode"}, {4, nullptr, "GetSocketDescriptor"}, {5, nullptr, "GetHostName"}, {6, nullptr, "GetVerifyOption"}, {7, nullptr, "GetIoMode"}, - {8, nullptr, "DoHandshake"}, - {9, nullptr, "DoHandshakeGetServerCert"}, - {10, nullptr, "Read"}, - {11, nullptr, "Write"}, - {12, nullptr, "Pending"}, + {8, &ISslConnection::DoHandshake, "DoHandshake"}, + {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, + {10, &ISslConnection::Read, "Read"}, + {11, &ISslConnection::Write, "Write"}, + {12, &ISslConnection::Pending, "Pending"}, {13, nullptr, "Peek"}, {14, nullptr, "Poll"}, {15, nullptr, "GetVerifyCertError"}, {16, nullptr, "GetNeededServerCertBufferSize"}, - {17, nullptr, "SetSessionCacheMode"}, + {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, {18, nullptr, "GetSessionCacheMode"}, {19, nullptr, "FlushSessionCache"}, {20, nullptr, "SetRenegotiationMode"}, {21, nullptr, "GetRenegotiationMode"}, - {22, nullptr, "SetOption"}, + {22, &ISslConnection::SetOption, "SetOption"}, {23, nullptr, "GetOption"}, {24, nullptr, "GetVerifyCertErrors"}, {25, nullptr, "GetCipherInfo"}, @@ -80,21 +107,299 @@ public: // clang-format on RegisterHandlers(functions); + + shared_data->connection_count++; + } + + ~ISslConnection() { + shared_data->connection_count--; + if (fd_to_close.has_value()) { + const s32 fd = *fd_to_close; + if (!do_not_close_socket) { + LOG_ERROR(Service_SSL, + "do_not_close_socket was changed after setting socket; is this right?"); + } else { + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + if (bsd) { + auto err = bsd->CloseImpl(fd); + if (err != Service::Sockets::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err); + } + } + } + } } private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data; + std::unique_ptr<SSLConnectionBackend> backend; + std::optional<int> fd_to_close; + bool do_not_close_socket = false; + bool get_server_cert_chain = false; + std::shared_ptr<Network::SocketBase> socket; + bool did_set_host_name = false; + bool did_handshake = false; + + ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { + LOG_DEBUG(Service_SSL, "called, fd={}", fd); + ASSERT(!did_handshake); + auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); + ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); + s32 ret_fd; + // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor + if (do_not_close_socket) { + auto res = bsd->DuplicateSocketImpl(fd); + if (!res.has_value()) { + LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); + return ResultInvalidSocket; + } + fd = *res; + fd_to_close = fd; + ret_fd = fd; + } else { + ret_fd = -1; + } + std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); + if (!sock.has_value()) { + LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); + return ResultInvalidSocket; + } + socket = std::move(*sock); + backend->SetSocket(socket); + return ret_fd; + } + + Result SetHostNameImpl(const std::string& hostname) { + LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); + ASSERT(!did_handshake); + Result res = backend->SetHostName(hostname); + if (res == ResultSuccess) { + did_set_host_name = true; + } + return res; + } + + Result SetVerifyOptionImpl(u32 option) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); + return ResultSuccess; + } + + Result SetIoModeImpl(u32 input_mode) { + auto mode = static_cast<IoMode>(input_mode); + ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); + ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); + + const bool non_block = mode == IoMode::NonBlocking; + const Network::Errno error = socket->SetNonBlock(non_block); + if (error != Network::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); + } + return ResultSuccess; + } + + Result SetSessionCacheModeImpl(u32 mode) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); + return ResultSuccess; + } + + Result DoHandshakeImpl() { + ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); + ASSERT_OR_EXECUTE_MSG( + did_set_host_name, { return ResultInternalError; }, + "Expected SetHostName before DoHandshake"); + Result res = backend->DoHandshake(); + did_handshake = res.IsSuccess(); + return res; + } + + std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) { + struct Header { + u64 magic; + u32 count; + u32 pad; + }; + struct EntryHeader { + u32 size; + u32 offset; + }; + if (!get_server_cert_chain) { + // Just return the first one, unencoded. + ASSERT_OR_EXECUTE_MSG( + !certs.empty(), { return {}; }, "Should be at least one server cert"); + return certs[0]; + } + std::vector<u8> ret; + Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0}; + ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1)); + size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); + for (auto& cert : certs) { + EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)}; + data_offset += cert.size(); + ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header), + reinterpret_cast<u8*>(&entry_header + 1)); + } + for (auto& cert : certs) { + ret.insert(ret.end(), cert.begin(), cert.end()); + } + return ret; + } + + ResultVal<std::vector<u8>> ReadImpl(size_t size) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + std::vector<u8> res(size); + ResultVal<size_t> actual = backend->Read(res); + if (actual.Failed()) { + return actual.Code(); + } + res.resize(*actual); + return res; + } + + ResultVal<size_t> WriteImpl(std::span<const u8> data) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + return backend->Write(data); + } + + ResultVal<s32> PendingImpl() { + LOG_WARNING(Service_SSL, "(STUBBED) called."); + return 0; + } + + void SetSocketDescriptor(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const s32 fd = rp.Pop<s32>(); + const ResultVal<s32> res = SetSocketDescriptorImpl(fd); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(-1)); + } + + void SetHostName(HLERequestContext& ctx) { + const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); + const Result res = SetHostNameImpl(hostname); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetVerifyOption(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 option = rp.Pop<u32>(); + const Result res = SetVerifyOptionImpl(option); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetIoMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetIoModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshake(HLERequestContext& ctx) { + const Result res = DoHandshakeImpl(); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshakeGetServerCert(HLERequestContext& ctx) { + struct OutputParameters { + u32 certs_size; + u32 certs_count; + }; + static_assert(sizeof(OutputParameters) == 0x8); + + const Result res = DoHandshakeImpl(); + OutputParameters out{}; + if (res == ResultSuccess) { + auto certs = backend->GetServerCerts(); + if (certs.Succeeded()) { + const std::vector<u8> certs_buf = SerializeServerCerts(*certs); + ctx.WriteBuffer(certs_buf); + out.certs_count = static_cast<u32>(certs->size()); + out.certs_size = static_cast<u32>(certs_buf.size()); + } + } + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(res); + rb.PushRaw(out); + } + + void Read(HLERequestContext& ctx) { + const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + if (res.Succeeded()) { + rb.Push(static_cast<u32>(res->size())); + ctx.WriteBuffer(*res); + } else { + rb.Push(static_cast<u32>(0)); + } + } + + void Write(HLERequestContext& ctx) { + const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push(static_cast<u32>(res.ValueOr(0))); + } + + void Pending(HLERequestContext& ctx) { + const ResultVal<s32> res = PendingImpl(); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push<s32>(res.ValueOr(0)); + } + + void SetSessionCacheMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop<u32>(); + const Result res = SetSessionCacheModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetOption(HLERequestContext& ctx) { + struct Parameters { + OptionType option; + s32 value; + }; + static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); + + IPC::RequestParser rp{ctx}; + const auto parameters = rp.PopRaw<Parameters>(); + + switch (parameters.option) { + case OptionType::DoNotCloseSocket: + do_not_close_socket = static_cast<bool>(parameters.value); + break; + case OptionType::GetServerCertChain: + get_server_cert_chain = static_cast<bool>(parameters.value); + break; + default: + LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, + parameters.value); + } + + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(ResultSuccess); + } }; class ISslContext final : public ServiceFramework<ISslContext> { public: explicit ISslContext(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { + : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, + shared_data{std::make_shared<SslContextSharedData>()} { static const FunctionInfo functions[] = { {0, &ISslContext::SetOption, "SetOption"}, {1, nullptr, "GetOption"}, {2, &ISslContext::CreateConnection, "CreateConnection"}, - {3, nullptr, "GetConnectionCount"}, + {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, {4, &ISslContext::ImportServerPki, "ImportServerPki"}, {5, &ISslContext::ImportClientPki, "ImportClientPki"}, {6, nullptr, "RemoveServerPki"}, @@ -111,6 +416,7 @@ public: private: SslVersion ssl_version; + std::shared_ptr<SslContextSharedData> shared_data; void SetOption(HLERequestContext& ctx) { struct Parameters { @@ -130,11 +436,24 @@ private: } void CreateConnection(HLERequestContext& ctx) { - LOG_WARNING(Service_SSL, "(STUBBED) called"); + LOG_WARNING(Service_SSL, "called"); + + auto backend_res = CreateSSLConnectionBackend(); IPC::ResponseBuilder rb{ctx, 2, 0, 1}; + rb.Push(backend_res.Code()); + if (backend_res.Succeeded()) { + rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, + std::move(*backend_res)); + } + } + + void GetConnectionCount(HLERequestContext& ctx) { + LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); + + IPC::ResponseBuilder rb{ctx, 3}; rb.Push(ResultSuccess); - rb.PushIpcInterface<ISslConnection>(system, ssl_version); + rb.Push(shared_data->connection_count); } void ImportServerPki(HLERequestContext& ctx) { |