summaryrefslogtreecommitdiffstats
path: root/src/core/hle/service/ssl/ssl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/hle/service/ssl/ssl.cpp')
-rw-r--r--src/core/hle/service/ssl/ssl.cpp353
1 files changed, 336 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..9c96f9763 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
+#include "common/string_util.h"
+
+#include "core/core.h"
#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/server_manager.h"
#include "core/hle/service/service.h"
+#include "core/hle/service/sm/sm.h"
+#include "core/hle/service/sockets/bsd.h"
#include "core/hle/service/ssl/ssl.h"
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
namespace Service::SSL {
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
CrlImportDateCheckEnable = 1,
};
+// This is nn::ssl::Connection::IoMode
+enum class IoMode : u32 {
+ Blocking = 1,
+ NonBlocking = 2,
+};
+
+// This is nn::ssl::sf::OptionType
+enum class OptionType : u32 {
+ DoNotCloseSocket = 0,
+ GetServerCertChain = 1,
+};
+
// This is nn::ssl::sf::SslVersion
struct SslVersion {
union {
@@ -34,35 +54,42 @@ struct SslVersion {
};
};
+struct SslContextSharedData {
+ u32 connection_count = 0;
+};
+
class ISslConnection final : public ServiceFramework<ISslConnection> {
public:
- explicit ISslConnection(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} {
+ explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in,
+ std::shared_ptr<SslContextSharedData>& shared_data_in,
+ std::unique_ptr<SSLConnectionBackend>&& backend_in)
+ : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in},
+ shared_data{shared_data_in}, backend{std::move(backend_in)} {
// clang-format off
static const FunctionInfo functions[] = {
- {0, nullptr, "SetSocketDescriptor"},
- {1, nullptr, "SetHostName"},
- {2, nullptr, "SetVerifyOption"},
- {3, nullptr, "SetIoMode"},
+ {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
+ {1, &ISslConnection::SetHostName, "SetHostName"},
+ {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
+ {3, &ISslConnection::SetIoMode, "SetIoMode"},
{4, nullptr, "GetSocketDescriptor"},
{5, nullptr, "GetHostName"},
{6, nullptr, "GetVerifyOption"},
{7, nullptr, "GetIoMode"},
- {8, nullptr, "DoHandshake"},
- {9, nullptr, "DoHandshakeGetServerCert"},
- {10, nullptr, "Read"},
- {11, nullptr, "Write"},
- {12, nullptr, "Pending"},
+ {8, &ISslConnection::DoHandshake, "DoHandshake"},
+ {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
+ {10, &ISslConnection::Read, "Read"},
+ {11, &ISslConnection::Write, "Write"},
+ {12, &ISslConnection::Pending, "Pending"},
{13, nullptr, "Peek"},
{14, nullptr, "Poll"},
{15, nullptr, "GetVerifyCertError"},
{16, nullptr, "GetNeededServerCertBufferSize"},
- {17, nullptr, "SetSessionCacheMode"},
+ {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
{18, nullptr, "GetSessionCacheMode"},
{19, nullptr, "FlushSessionCache"},
{20, nullptr, "SetRenegotiationMode"},
{21, nullptr, "GetRenegotiationMode"},
- {22, nullptr, "SetOption"},
+ {22, &ISslConnection::SetOption, "SetOption"},
{23, nullptr, "GetOption"},
{24, nullptr, "GetVerifyCertErrors"},
{25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,299 @@ public:
// clang-format on
RegisterHandlers(functions);
+
+ shared_data->connection_count++;
+ }
+
+ ~ISslConnection() {
+ shared_data->connection_count--;
+ if (fd_to_close.has_value()) {
+ const s32 fd = *fd_to_close;
+ if (!do_not_close_socket) {
+ LOG_ERROR(Service_SSL,
+ "do_not_close_socket was changed after setting socket; is this right?");
+ } else {
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ if (bsd) {
+ auto err = bsd->CloseImpl(fd);
+ if (err != Service::Sockets::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err);
+ }
+ }
+ }
+ }
}
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
+ std::unique_ptr<SSLConnectionBackend> backend;
+ std::optional<int> fd_to_close;
+ bool do_not_close_socket = false;
+ bool get_server_cert_chain = false;
+ std::shared_ptr<Network::SocketBase> socket;
+ bool did_set_host_name = false;
+ bool did_handshake = false;
+
+ ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
+ LOG_DEBUG(Service_SSL, "called, fd={}", fd);
+ ASSERT(!did_handshake);
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
+ s32 ret_fd;
+ // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
+ if (do_not_close_socket) {
+ auto res = bsd->DuplicateSocketImpl(fd);
+ if (!res.has_value()) {
+ LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ fd = *res;
+ fd_to_close = fd;
+ ret_fd = fd;
+ } else {
+ ret_fd = -1;
+ }
+ std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
+ if (!sock.has_value()) {
+ LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ socket = std::move(*sock);
+ backend->SetSocket(socket);
+ return ret_fd;
+ }
+
+ Result SetHostNameImpl(const std::string& hostname) {
+ LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
+ ASSERT(!did_handshake);
+ Result res = backend->SetHostName(hostname);
+ if (res == ResultSuccess) {
+ did_set_host_name = true;
+ }
+ return res;
+ }
+
+ Result SetVerifyOptionImpl(u32 option) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
+ return ResultSuccess;
+ }
+
+ Result SetIoModeImpl(u32 input_mode) {
+ auto mode = static_cast<IoMode>(input_mode);
+ ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
+ ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
+
+ const bool non_block = mode == IoMode::NonBlocking;
+ const Network::Errno error = socket->SetNonBlock(non_block);
+ if (error != Network::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
+ }
+ return ResultSuccess;
+ }
+
+ Result SetSessionCacheModeImpl(u32 mode) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
+ return ResultSuccess;
+ }
+
+ Result DoHandshakeImpl() {
+ ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
+ ASSERT_OR_EXECUTE_MSG(
+ did_set_host_name, { return ResultInternalError; },
+ "Expected SetHostName before DoHandshake");
+ Result res = backend->DoHandshake();
+ did_handshake = res.IsSuccess();
+ return res;
+ }
+
+ std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
+ struct Header {
+ u64 magic;
+ u32 count;
+ u32 pad;
+ };
+ struct EntryHeader {
+ u32 size;
+ u32 offset;
+ };
+ if (!get_server_cert_chain) {
+ // Just return the first one, unencoded.
+ ASSERT_OR_EXECUTE_MSG(
+ !certs.empty(), { return {}; }, "Should be at least one server cert");
+ return certs[0];
+ }
+ std::vector<u8> ret;
+ Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
+ size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
+ for (auto& cert : certs) {
+ EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
+ data_offset += cert.size();
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
+ reinterpret_cast<u8*>(&entry_header + 1));
+ }
+ for (auto& cert : certs) {
+ ret.insert(ret.end(), cert.begin(), cert.end());
+ }
+ return ret;
+ }
+
+ ResultVal<std::vector<u8>> ReadImpl(size_t size) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ std::vector<u8> res(size);
+ ResultVal<size_t> actual = backend->Read(res);
+ if (actual.Failed()) {
+ return actual.Code();
+ }
+ res.resize(*actual);
+ return res;
+ }
+
+ ResultVal<size_t> WriteImpl(std::span<const u8> data) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ return backend->Write(data);
+ }
+
+ ResultVal<s32> PendingImpl() {
+ LOG_WARNING(Service_SSL, "(STUBBED) called.");
+ return 0;
+ }
+
+ void SetSocketDescriptor(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const s32 fd = rp.Pop<s32>();
+ const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(-1));
+ }
+
+ void SetHostName(HLERequestContext& ctx) {
+ const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
+ const Result res = SetHostNameImpl(hostname);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetVerifyOption(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 option = rp.Pop<u32>();
+ const Result res = SetVerifyOptionImpl(option);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetIoMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetIoModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshake(HLERequestContext& ctx) {
+ const Result res = DoHandshakeImpl();
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshakeGetServerCert(HLERequestContext& ctx) {
+ struct OutputParameters {
+ u32 certs_size;
+ u32 certs_count;
+ };
+ static_assert(sizeof(OutputParameters) == 0x8);
+
+ const Result res = DoHandshakeImpl();
+ OutputParameters out{};
+ if (res == ResultSuccess) {
+ auto certs = backend->GetServerCerts();
+ if (certs.Succeeded()) {
+ const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
+ ctx.WriteBuffer(certs_buf);
+ out.certs_count = static_cast<u32>(certs->size());
+ out.certs_size = static_cast<u32>(certs_buf.size());
+ }
+ }
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(res);
+ rb.PushRaw(out);
+ }
+
+ void Read(HLERequestContext& ctx) {
+ const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ if (res.Succeeded()) {
+ rb.Push(static_cast<u32>(res->size()));
+ ctx.WriteBuffer(*res);
+ } else {
+ rb.Push(static_cast<u32>(0));
+ }
+ }
+
+ void Write(HLERequestContext& ctx) {
+ const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push(static_cast<u32>(res.ValueOr(0)));
+ }
+
+ void Pending(HLERequestContext& ctx) {
+ const ResultVal<s32> res = PendingImpl();
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(0));
+ }
+
+ void SetSessionCacheMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetSessionCacheModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetOption(HLERequestContext& ctx) {
+ struct Parameters {
+ OptionType option;
+ s32 value;
+ };
+ static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
+
+ IPC::RequestParser rp{ctx};
+ const auto parameters = rp.PopRaw<Parameters>();
+
+ switch (parameters.option) {
+ case OptionType::DoNotCloseSocket:
+ do_not_close_socket = static_cast<bool>(parameters.value);
+ break;
+ case OptionType::GetServerCertChain:
+ get_server_cert_chain = static_cast<bool>(parameters.value);
+ break;
+ default:
+ LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
+ parameters.value);
+ }
+
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(ResultSuccess);
+ }
};
class ISslContext final : public ServiceFramework<ISslContext> {
public:
explicit ISslContext(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslContext"}, ssl_version{version} {
+ : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
+ shared_data{std::make_shared<SslContextSharedData>()} {
static const FunctionInfo functions[] = {
{0, &ISslContext::SetOption, "SetOption"},
{1, nullptr, "GetOption"},
{2, &ISslContext::CreateConnection, "CreateConnection"},
- {3, nullptr, "GetConnectionCount"},
+ {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
{4, &ISslContext::ImportServerPki, "ImportServerPki"},
{5, &ISslContext::ImportClientPki, "ImportClientPki"},
{6, nullptr, "RemoveServerPki"},
@@ -111,6 +416,7 @@ public:
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
void SetOption(HLERequestContext& ctx) {
struct Parameters {
@@ -130,11 +436,24 @@ private:
}
void CreateConnection(HLERequestContext& ctx) {
- LOG_WARNING(Service_SSL, "(STUBBED) called");
+ LOG_WARNING(Service_SSL, "called");
+
+ auto backend_res = CreateSSLConnectionBackend();
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
+ rb.Push(backend_res.Code());
+ if (backend_res.Succeeded()) {
+ rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
+ std::move(*backend_res));
+ }
+ }
+
+ void GetConnectionCount(HLERequestContext& ctx) {
+ LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
+
+ IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess);
- rb.PushIpcInterface<ISslConnection>(system, ssl_version);
+ rb.Push(shared_data->connection_count);
}
void ImportServerPki(HLERequestContext& ctx) {