From 35501ba41cefb6f103a96f032c22c14f4fce1e96 Mon Sep 17 00:00:00 2001 From: Liam Date: Sun, 17 Dec 2023 19:54:29 -0500 Subject: k_server_session: process for guest servers --- src/core/hle/kernel/k_server_session.cpp | 1382 ++++++++++++++++++++++++------ src/core/hle/kernel/k_server_session.h | 15 +- src/core/hle/kernel/message_buffer.h | 20 +- src/core/hle/kernel/svc/svc_ipc.cpp | 6 +- src/core/hle/kernel/svc_results.h | 2 + src/core/hle/service/server_manager.cpp | 2 +- 6 files changed, 1156 insertions(+), 271 deletions(-) diff --git a/src/core/hle/kernel/k_server_session.cpp b/src/core/hle/kernel/k_server_session.cpp index e33a88e24..db3a07f91 100644 --- a/src/core/hle/kernel/k_server_session.cpp +++ b/src/core/hle/kernel/k_server_session.cpp @@ -8,6 +8,7 @@ #include "common/common_types.h" #include "common/logging/log.h" #include "common/scope_exit.h" +#include "common/scratch_buffer.h" #include "core/core.h" #include "core/core_timing.h" #include "core/hle/kernel/k_client_port.h" @@ -29,12 +30,140 @@ namespace Kernel { namespace { +constexpr inline size_t PointerTransferBufferAlignment = 0x10; +constexpr inline size_t ReceiveListDataSize = + MessageBuffer::MessageHeader::ReceiveListCountType_CountMax * + MessageBuffer::ReceiveListEntry::GetDataSize() / sizeof(u32); + +using ThreadQueueImplForKServerSessionRequest = KThreadQueue; + +static thread_local Common::ScratchBuffer temp_buffer; + +class ReceiveList { +public: + static constexpr int GetEntryCount(const MessageBuffer::MessageHeader& header) { + const auto count = header.GetReceiveListCount(); + switch (count) { + case MessageBuffer::MessageHeader::ReceiveListCountType_None: + return 0; + case MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer: + return 0; + case MessageBuffer::MessageHeader::ReceiveListCountType_ToSingleBuffer: + return 1; + default: + return count - MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset; + } + } + + explicit ReceiveList(const u32* dst_msg, uint64_t dst_address, + KProcessPageTable& dst_page_table, + const MessageBuffer::MessageHeader& dst_header, + const MessageBuffer::SpecialHeader& dst_special_header, size_t msg_size, + size_t out_offset, s32 dst_recv_list_idx, bool is_tls) { + m_recv_list_count = dst_header.GetReceiveListCount(); + m_msg_buffer_end = dst_address + sizeof(u32) * out_offset; + m_msg_buffer_space_end = dst_address + msg_size; + + // NOTE: Nintendo calculates the receive list index here using the special header. + // We pre-calculate it in the caller, and pass it as a parameter. + (void)dst_special_header; + + const u32* recv_list = dst_msg + dst_recv_list_idx; + const auto entry_count = GetEntryCount(dst_header); + + if (is_tls) { + // Messages from TLS to TLS are contained within one page. + std::memcpy(m_data.data(), recv_list, + entry_count * MessageBuffer::ReceiveListEntry::GetDataSize()); + } else { + // If any buffer is not from TLS, perform a normal read instead. + uint64_t cur_addr = dst_address + dst_recv_list_idx * sizeof(u32); + dst_page_table.GetMemory().ReadBlock( + cur_addr, m_data.data(), + entry_count * MessageBuffer::ReceiveListEntry::GetDataSize()); + } + } + + bool IsIndex() const { + return m_recv_list_count > + static_cast(MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset); + } + + bool IsToMessageBuffer() const { + return m_recv_list_count == + MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer; + } + + void GetBuffer(uint64_t& out, size_t size, int& key) const { + switch (m_recv_list_count) { + case MessageBuffer::MessageHeader::ReceiveListCountType_None: { + out = 0; + break; + } + case MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer: { + const uint64_t buf = + Common::AlignUp(m_msg_buffer_end + key, PointerTransferBufferAlignment); + + if ((buf < buf + size) && (buf + size <= m_msg_buffer_space_end)) { + out = buf; + key = static_cast(buf + size - m_msg_buffer_end); + } else { + out = 0; + } + break; + } + case MessageBuffer::MessageHeader::ReceiveListCountType_ToSingleBuffer: { + const MessageBuffer::ReceiveListEntry entry(m_data[0], m_data[1]); + const uint64_t buf = + Common::AlignUp(entry.GetAddress() + key, PointerTransferBufferAlignment); + + const uint64_t entry_addr = entry.GetAddress(); + const size_t entry_size = entry.GetSize(); + + if ((buf < buf + size) && (entry_addr < entry_addr + entry_size) && + (buf + size <= entry_addr + entry_size)) { + out = buf; + key = static_cast(buf + size - entry_addr); + } else { + out = 0; + } + break; + } + default: { + if (key < m_recv_list_count - + static_cast( + MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset)) { + const MessageBuffer::ReceiveListEntry entry(m_data[2 * key + 0], + m_data[2 * key + 1]); + + const uintptr_t entry_addr = entry.GetAddress(); + const size_t entry_size = entry.GetSize(); + + if ((entry_addr < entry_addr + entry_size) && (entry_size >= size)) { + out = entry_addr; + } + } else { + out = 0; + } + break; + } + } + } + +private: + std::array m_data; + s32 m_recv_list_count; + uint64_t m_msg_buffer_end; + uint64_t m_msg_buffer_space_end; +}; + template -Result ProcessMessageSpecialData(KProcess& dst_process, KProcess& src_process, KThread& src_thread, - MessageBuffer& dst_msg, const MessageBuffer& src_msg, - MessageBuffer::SpecialHeader& src_special_header) { +Result ProcessMessageSpecialData(s32& offset, KProcess& dst_process, KProcess& src_process, + KThread& src_thread, const MessageBuffer& dst_msg, + const MessageBuffer& src_msg, + const MessageBuffer::SpecialHeader& src_special_header) { // Copy the special header to the destination. - s32 offset = dst_msg.Set(src_special_header); + offset = dst_msg.Set(src_special_header); // Copy the process ID. if (src_special_header.GetHasProcessId()) { @@ -110,6 +239,93 @@ Result ProcessMessageSpecialData(KProcess& dst_process, KProcess& src_process, K R_RETURN(result); } +Result ProcessReceiveMessagePointerDescriptors(int& offset, int& pointer_key, + KProcessPageTable& dst_page_table, + KProcessPageTable& src_page_table, + const MessageBuffer& dst_msg, + const MessageBuffer& src_msg, + const ReceiveList& dst_recv_list, bool dst_user) { + // Get the offset at the start of processing. + const int cur_offset = offset; + + // Get the pointer desc. + MessageBuffer::PointerDescriptor src_desc(src_msg, cur_offset); + offset += static_cast(MessageBuffer::PointerDescriptor::GetDataSize() / sizeof(u32)); + + // Extract address/size. + const uint64_t src_pointer = src_desc.GetAddress(); + const size_t recv_size = src_desc.GetSize(); + uint64_t recv_pointer = 0; + + // Process the buffer, if it has a size. + if (recv_size > 0) { + // If using indexing, set index. + if (dst_recv_list.IsIndex()) { + pointer_key = src_desc.GetIndex(); + } + + // Get the buffer. + dst_recv_list.GetBuffer(recv_pointer, recv_size, pointer_key); + R_UNLESS(recv_pointer != 0, ResultOutOfResource); + + // Perform the pointer data copy. + // TODO: KProcessPageTable::CopyMemoryFromHeapToHeapWithoutCheckDestination + // TODO: KProcessPageTable::CopyMemoryFromLinearToUser + + temp_buffer.resize_destructive(recv_size); + src_page_table.GetMemory().ReadBlock(src_pointer, temp_buffer.data(), recv_size); + dst_page_table.GetMemory().WriteBlock(recv_pointer, temp_buffer.data(), recv_size); + } + + // Set the output descriptor. + dst_msg.Set(cur_offset, MessageBuffer::PointerDescriptor(reinterpret_cast(recv_pointer), + recv_size, src_desc.GetIndex())); + + R_SUCCEED(); +} + +constexpr Result GetMapAliasMemoryState(KMemoryState& out, + MessageBuffer::MapAliasDescriptor::Attribute attr) { + switch (attr) { + case MessageBuffer::MapAliasDescriptor::Attribute::Ipc: + out = KMemoryState::Ipc; + break; + case MessageBuffer::MapAliasDescriptor::Attribute::NonSecureIpc: + out = KMemoryState::NonSecureIpc; + break; + case MessageBuffer::MapAliasDescriptor::Attribute::NonDeviceIpc: + out = KMemoryState::NonDeviceIpc; + break; + default: + R_THROW(ResultInvalidCombination); + } + + R_SUCCEED(); +} + +constexpr Result GetMapAliasTestStateAndAttributeMask(u32& out_state, u32& out_attr_mask, + KMemoryState state) { + switch (state) { + case KMemoryState::Ipc: + out_state = static_cast(KMemoryState::FlagCanUseIpc); + out_attr_mask = static_cast(KMemoryAttribute::Uncached | + KMemoryAttribute::DeviceShared | KMemoryAttribute::Locked); + break; + case KMemoryState::NonSecureIpc: + out_state = static_cast(KMemoryState::FlagCanUseNonSecureIpc); + out_attr_mask = static_cast(KMemoryAttribute::Uncached | KMemoryAttribute::Locked); + break; + case KMemoryState::NonDeviceIpc: + out_state = static_cast(KMemoryState::FlagCanUseNonDeviceIpc); + out_attr_mask = static_cast(KMemoryAttribute::Uncached | KMemoryAttribute::Locked); + break; + default: + R_THROW(ResultInvalidCombination); + } + + R_SUCCEED(); +} + void CleanupSpecialData(KProcess& dst_process, u32* dst_msg_ptr, size_t dst_buffer_size) { // Parse the message. const MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size); @@ -144,273 +360,653 @@ void CleanupSpecialData(KProcess& dst_process, u32* dst_msg_ptr, size_t dst_buff } } -} // namespace +Result CleanupServerHandles(KernelCore& kernel, uint64_t message, size_t buffer_size, + KPhysicalAddress message_paddr) { + // Server is assumed to be current thread. + KThread& thread = GetCurrentThread(kernel); -using ThreadQueueImplForKServerSessionRequest = KThreadQueue; + // Get the linear message pointer. + u32* msg_ptr; + if (message) { + msg_ptr = kernel.System().DeviceMemory().GetPointer(message_paddr); + } else { + msg_ptr = GetCurrentMemory(kernel).GetPointer(thread.GetTlsAddress()); + buffer_size = MessageBufferSize; + message = GetInteger(thread.GetTlsAddress()); + } -KServerSession::KServerSession(KernelCore& kernel) - : KSynchronizationObject{kernel}, m_lock{m_kernel} {} + // Parse the message. + const MessageBuffer msg(msg_ptr, buffer_size); + const MessageBuffer::MessageHeader header(msg); + const MessageBuffer::SpecialHeader special_header(msg, header); -KServerSession::~KServerSession() = default; + // Check that the size is big enough. + R_UNLESS(MessageBuffer::GetMessageBufferSize(header, special_header) <= buffer_size, + ResultInvalidCombination); + + // If there's a special header, there may be move handles we need to close. + if (header.GetHasSpecialHeader()) { + // Determine the offset to the start of handles. + auto offset = msg.GetSpecialDataIndex(header, special_header); + if (special_header.GetHasProcessId()) { + offset += static_cast(sizeof(u64) / sizeof(u32)); + } + if (auto copy_count = special_header.GetCopyHandleCount(); copy_count > 0) { + offset += static_cast((sizeof(Svc::Handle) * copy_count) / sizeof(u32)); + } -void KServerSession::Destroy() { - m_parent->OnServerClosed(); + // Get the handle table. + auto& handle_table = thread.GetOwnerProcess()->GetHandleTable(); - this->CleanupRequests(); + // Close the handles. + for (auto i = 0; i < special_header.GetMoveHandleCount(); ++i) { + handle_table.Remove(msg.GetHandle(offset)); + offset += static_cast(sizeof(Svc::Handle) / sizeof(u32)); + } + } - m_parent->Close(); + R_SUCCEED(); } -void KServerSession::OnClientClosed() { - KScopedLightLock lk{m_lock}; - - // Handle any pending requests. - KSessionRequest* prev_request = nullptr; - while (true) { - // Declare variables for processing the request. - KSessionRequest* request = nullptr; - KEvent* event = nullptr; - KThread* thread = nullptr; - bool cur_request = false; - bool terminate = false; - - // Get the next request. - { - KScopedSchedulerLock sl{m_kernel}; +Result CleanupServerMap(KSessionRequest* request, KProcess* server_process) { + // If there's no server process, there's nothing to clean up. + R_SUCCEED_IF(server_process == nullptr); - if (m_current_request != nullptr && m_current_request != prev_request) { - // Set the request, open a reference as we process it. - request = m_current_request; - request->Open(); - cur_request = true; + // Get the page table. + auto& server_page_table = server_process->GetPageTable(); - // Get thread and event for the request. - thread = request->GetThread(); - event = request->GetEvent(); + // Cleanup Send mappings. + for (size_t i = 0; i < request->GetSendCount(); ++i) { + R_TRY(server_page_table.CleanupForIpcServer(request->GetSendServerAddress(i), + request->GetSendSize(i), + request->GetSendMemoryState(i))); + } - // If the thread is terminating, handle that. - if (thread->IsTerminationRequested()) { - request->ClearThread(); - request->ClearEvent(); - terminate = true; - } + // Cleanup Receive mappings. + for (size_t i = 0; i < request->GetReceiveCount(); ++i) { + R_TRY(server_page_table.CleanupForIpcServer(request->GetReceiveServerAddress(i), + request->GetReceiveSize(i), + request->GetReceiveMemoryState(i))); + } - prev_request = request; - } else if (!m_request_list.empty()) { - // Pop the request from the front of the list. - request = std::addressof(m_request_list.front()); - m_request_list.pop_front(); + // Cleanup Exchange mappings. + for (size_t i = 0; i < request->GetExchangeCount(); ++i) { + R_TRY(server_page_table.CleanupForIpcServer(request->GetExchangeServerAddress(i), + request->GetExchangeSize(i), + request->GetExchangeMemoryState(i))); + } - // Get thread and event for the request. - thread = request->GetThread(); - event = request->GetEvent(); - } - } + R_SUCCEED(); +} - // If there are no requests, we're done. - if (request == nullptr) { - break; - } +Result CleanupClientMap(KSessionRequest* request, KProcessPageTable* client_page_table) { + // If there's no client page table, there's nothing to clean up. + R_SUCCEED_IF(client_page_table == nullptr); - // All requests must have threads. - ASSERT(thread != nullptr); + // Cleanup Send mappings. + for (size_t i = 0; i < request->GetSendCount(); ++i) { + R_TRY(client_page_table->CleanupForIpcClient(request->GetSendClientAddress(i), + request->GetSendSize(i), + request->GetSendMemoryState(i))); + } - // Ensure that we close the request when done. - SCOPE_EXIT({ request->Close(); }); + // Cleanup Receive mappings. + for (size_t i = 0; i < request->GetReceiveCount(); ++i) { + R_TRY(client_page_table->CleanupForIpcClient(request->GetReceiveClientAddress(i), + request->GetReceiveSize(i), + request->GetReceiveMemoryState(i))); + } - // If we're terminating, close a reference to the thread and event. - if (terminate) { - thread->Close(); - if (event != nullptr) { - event->Close(); - } - } + // Cleanup Exchange mappings. + for (size_t i = 0; i < request->GetExchangeCount(); ++i) { + R_TRY(client_page_table->CleanupForIpcClient(request->GetExchangeClientAddress(i), + request->GetExchangeSize(i), + request->GetExchangeMemoryState(i))); + } - // If we need to, reply. - if (event != nullptr && !cur_request) { - // There must be no mappings. - ASSERT(request->GetSendCount() == 0); - ASSERT(request->GetReceiveCount() == 0); - ASSERT(request->GetExchangeCount() == 0); + R_SUCCEED(); +} - // // Get the process and page table. - // KProcess *client_process = thread->GetOwnerProcess(); - // auto& client_pt = client_process->GetPageTable(); +Result CleanupMap(KSessionRequest* request, KProcess* server_process, + KProcessPageTable* client_page_table) { + // Cleanup the server map. + R_TRY(CleanupServerMap(request, server_process)); - // // Reply to the request. - // ReplyAsyncError(client_process, request->GetAddress(), request->GetSize(), - // ResultSessionClosed); + // Cleanup the client map. + R_TRY(CleanupClientMap(request, client_page_table)); - // // Unlock the buffer. - // // NOTE: Nintendo does not check the result of this. - // client_pt.UnlockForIpcUserBuffer(request->GetAddress(), request->GetSize()); + R_SUCCEED(); +} - // Signal the event. - event->Signal(); +Result ProcessReceiveMessageMapAliasDescriptors(int& offset, KProcessPageTable& dst_page_table, + KProcessPageTable& src_page_table, + const MessageBuffer& dst_msg, + const MessageBuffer& src_msg, + KSessionRequest* request, KMemoryPermission perm, + bool send) { + // Get the offset at the start of processing. + const int cur_offset = offset; + + // Get the map alias descriptor. + MessageBuffer::MapAliasDescriptor src_desc(src_msg, cur_offset); + offset += static_cast(MessageBuffer::MapAliasDescriptor::GetDataSize() / sizeof(u32)); + + // Extract address/size. + const KProcessAddress src_address = src_desc.GetAddress(); + const size_t size = src_desc.GetSize(); + KProcessAddress dst_address = 0; + + // Determine the result memory state. + KMemoryState dst_state; + R_TRY(GetMapAliasMemoryState(dst_state, src_desc.GetAttribute())); + + // Process the buffer, if it has a size. + if (size > 0) { + // Set up the source pages for ipc. + R_TRY(dst_page_table.SetupForIpc(std::addressof(dst_address), size, src_address, + src_page_table, perm, dst_state, send)); + + // Ensure that we clean up on failure. + ON_RESULT_FAILURE { + dst_page_table.CleanupForIpcServer(dst_address, size, dst_state); + src_page_table.CleanupForIpcClient(src_address, size, dst_state); + }; + + // Push the appropriate mapping. + if (perm == KMemoryPermission::UserRead) { + R_TRY(request->PushSend(src_address, dst_address, size, dst_state)); + } else if (send) { + R_TRY(request->PushExchange(src_address, dst_address, size, dst_state)); + } else { + R_TRY(request->PushReceive(src_address, dst_address, size, dst_state)); } } - // Notify. - this->NotifyAvailable(ResultSessionClosed); -} - -bool KServerSession::IsSignaled() const { - ASSERT(KScheduler::IsSchedulerLockedByCurrentThread(m_kernel)); - - // If the client is closed, we're always signaled. - if (m_parent->IsClientClosed()) { - return true; - } + // Set the output descriptor. + dst_msg.Set(cur_offset, + MessageBuffer::MapAliasDescriptor(reinterpret_cast(GetInteger(dst_address)), + size, src_desc.GetAttribute())); - // Otherwise, we're signaled if we have a request and aren't handling one. - return !m_request_list.empty() && m_current_request == nullptr; + R_SUCCEED(); } -Result KServerSession::OnRequest(KSessionRequest* request) { - // Create the wait queue. - ThreadQueueImplForKServerSessionRequest wait_queue{m_kernel}; +Result ReceiveMessage(KernelCore& kernel, bool& recv_list_broken, uint64_t dst_message_buffer, + size_t dst_buffer_size, KPhysicalAddress dst_message_paddr, + KThread& src_thread, uint64_t src_message_buffer, size_t src_buffer_size, + KServerSession* session, KSessionRequest* request) { + // Prepare variables for receive. + KThread& dst_thread = GetCurrentThread(kernel); + KProcess& dst_process = *(dst_thread.GetOwnerProcess()); + KProcess& src_process = *(src_thread.GetOwnerProcess()); + auto& dst_page_table = dst_process.GetPageTable(); + auto& src_page_table = src_process.GetPageTable(); + + // NOTE: Session is used only for debugging, and so may go unused. + (void)session; + + // The receive list is initially not broken. + recv_list_broken = false; + + // Set the server process for the request. + request->SetServerProcess(std::addressof(dst_process)); + + // Determine the message buffers. + u32 *dst_msg_ptr, *src_msg_ptr; + bool dst_user, src_user; + + if (dst_message_buffer) { + dst_msg_ptr = kernel.System().DeviceMemory().GetPointer(dst_message_paddr); + dst_user = true; + } else { + dst_msg_ptr = dst_page_table.GetMemory().GetPointer(dst_thread.GetTlsAddress()); + dst_buffer_size = MessageBufferSize; + dst_message_buffer = GetInteger(dst_thread.GetTlsAddress()); + dst_user = false; + } - { - // Lock the scheduler. - KScopedSchedulerLock sl{m_kernel}; + if (src_message_buffer) { + // NOTE: Nintendo does not check the result of this GetPhysicalAddress call. + src_msg_ptr = src_page_table.GetMemory().GetPointer(src_message_buffer); + src_user = true; + } else { + src_msg_ptr = src_page_table.GetMemory().GetPointer(src_thread.GetTlsAddress()); + src_buffer_size = MessageBufferSize; + src_message_buffer = GetInteger(src_thread.GetTlsAddress()); + src_user = false; + } - // Ensure that we can handle new requests. - R_UNLESS(!m_parent->IsServerClosed(), ResultSessionClosed); + // Parse the headers. + const MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size); + const MessageBuffer src_msg(src_msg_ptr, src_buffer_size); + const MessageBuffer::MessageHeader dst_header(dst_msg); + const MessageBuffer::MessageHeader src_header(src_msg); + const MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header); + const MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); + + // Get the end of the source message. + const size_t src_end_offset = + MessageBuffer::GetRawDataIndex(src_header, src_special_header) + src_header.GetRawCount(); + + // Ensure that the headers fit. + R_UNLESS(MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) <= dst_buffer_size, + ResultInvalidCombination); + R_UNLESS(MessageBuffer::GetMessageBufferSize(src_header, src_special_header) <= src_buffer_size, + ResultInvalidCombination); + + // Ensure the receive list offset is after the end of raw data. + if (dst_header.GetReceiveListOffset()) { + R_UNLESS(dst_header.GetReceiveListOffset() >= + MessageBuffer::GetRawDataIndex(dst_header, dst_special_header) + + dst_header.GetRawCount(), + ResultInvalidCombination); + } - // Check that we're not terminating. - R_UNLESS(!GetCurrentThread(m_kernel).IsTerminationRequested(), ResultTerminationRequested); + // Ensure that the destination buffer is big enough to receive the source. + R_UNLESS(dst_buffer_size >= src_end_offset * sizeof(u32), ResultMessageTooLarge); + + // Get the receive list. + const s32 dst_recv_list_idx = + MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header); + ReceiveList dst_recv_list(dst_msg_ptr, dst_message_buffer, dst_page_table, dst_header, + dst_special_header, dst_buffer_size, src_end_offset, + dst_recv_list_idx, !dst_user); + + // Ensure that the source special header isn't invalid. + const bool src_has_special_header = src_header.GetHasSpecialHeader(); + if (src_has_special_header) { + // Sending move handles from client -> server is not allowed. + R_UNLESS(src_special_header.GetMoveHandleCount() == 0, ResultInvalidCombination); + } - // Get whether we're empty. - const bool was_empty = m_request_list.empty(); + // Prepare for further processing. + int pointer_key = 0; + int offset = dst_msg.Set(src_header); - // Add the request to the list. - request->Open(); - m_request_list.push_back(*request); + // Set up a guard to make sure that we end up in a clean state on error. + ON_RESULT_FAILURE { + // Cleanup mappings. + CleanupMap(request, std::addressof(dst_process), std::addressof(src_page_table)); - // If we were empty, signal. - if (was_empty) { - this->NotifyAvailable(); + // Cleanup special data. + if (src_header.GetHasSpecialHeader()) { + CleanupSpecialData(dst_process, dst_msg_ptr, dst_buffer_size); } - // If we have a request event, this is asynchronous, and we don't need to wait. - R_SUCCEED_IF(request->GetEvent() != nullptr); + // Cleanup the header if the receive list isn't broken. + if (!recv_list_broken) { + dst_msg.Set(dst_header); + if (dst_header.GetHasSpecialHeader()) { + dst_msg.Set(dst_special_header); + } + } + }; + + // Process any special data. + if (src_header.GetHasSpecialHeader()) { + // After we process, make sure we track whether the receive list is broken. + SCOPE_EXIT({ + if (offset > dst_recv_list_idx) { + recv_list_broken = true; + } + }); - // This is a synchronous request, so we should wait for our request to complete. - GetCurrentThread(m_kernel).SetWaitReasonForDebugging(ThreadWaitReasonForDebugging::IPC); - GetCurrentThread(m_kernel).BeginWait(std::addressof(wait_queue)); + // Process special data. + R_TRY(ProcessMessageSpecialData(offset, dst_process, src_process, src_thread, + dst_msg, src_msg, src_special_header)); } - return GetCurrentThread(m_kernel).GetWaitResult(); -} + // Process any pointer buffers. + for (auto i = 0; i < src_header.GetPointerCount(); ++i) { + // After we process, make sure we track whether the receive list is broken. + SCOPE_EXIT({ + if (offset > dst_recv_list_idx) { + recv_list_broken = true; + } + }); -Result KServerSession::SendReply(bool is_hle) { - // Lock the session. - KScopedLightLock lk{m_lock}; + R_TRY(ProcessReceiveMessagePointerDescriptors( + offset, pointer_key, dst_page_table, src_page_table, dst_msg, src_msg, dst_recv_list, + dst_user && dst_header.GetReceiveListCount() == + MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer)); + } - // Get the request. - KSessionRequest* request; - { - KScopedSchedulerLock sl{m_kernel}; + // Process any map alias buffers. + for (auto i = 0; i < src_header.GetMapAliasCount(); ++i) { + // After we process, make sure we track whether the receive list is broken. + SCOPE_EXIT({ + if (offset > dst_recv_list_idx) { + recv_list_broken = true; + } + }); - // Get the current request. - request = m_current_request; - R_UNLESS(request != nullptr, ResultInvalidState); + // We process in order send, recv, exch. Buffers after send (recv/exch) are ReadWrite. + const KMemoryPermission perm = (i >= src_header.GetSendCount()) + ? KMemoryPermission::UserReadWrite + : KMemoryPermission::UserRead; - // Clear the current request, since we're processing it. - m_current_request = nullptr; - if (!m_request_list.empty()) { - this->NotifyAvailable(); - } - } + // Buffer is send if it is send or exch. + const bool send = (i < src_header.GetSendCount()) || + (i >= src_header.GetSendCount() + src_header.GetReceiveCount()); - // Close reference to the request once we're done processing it. - SCOPE_EXIT({ request->Close(); }); + R_TRY(ProcessReceiveMessageMapAliasDescriptors(offset, dst_page_table, src_page_table, + dst_msg, src_msg, request, perm, send)); + } - // Extract relevant information from the request. - const uintptr_t client_message = request->GetAddress(); - const size_t client_buffer_size = request->GetSize(); - KThread* client_thread = request->GetThread(); - KEvent* event = request->GetEvent(); + // Process any raw data. + if (const auto raw_count = src_header.GetRawCount(); raw_count != 0) { + // After we process, make sure we track whether the receive list is broken. + SCOPE_EXIT({ + if (offset + raw_count > dst_recv_list_idx) { + recv_list_broken = true; + } + }); - // Check whether we're closed. - const bool closed = (client_thread == nullptr || m_parent->IsClientClosed()); + // Get the offset and size. + const size_t offset_words = offset * sizeof(u32); + const size_t raw_size = raw_count * sizeof(u32); - Result result = ResultSuccess; - if (!closed) { - // If we're not closed, send the reply. - if (is_hle) { - // HLE servers write directly to a pointer to the thread command buffer. Therefore - // the reply has already been written in this case. + if (!dst_user && !src_user) { + // Fast case is TLS -> TLS, do raw memcpy if we can. + std::memcpy(dst_msg_ptr + offset, src_msg_ptr + offset, raw_size); } else { - Core::Memory::Memory& memory{client_thread->GetOwnerProcess()->GetMemory()}; - KThread* server_thread = GetCurrentThreadPointer(m_kernel); - KProcess& src_process = *client_thread->GetOwnerProcess(); - KProcess& dst_process = *server_thread->GetOwnerProcess(); - UNIMPLEMENTED_IF(server_thread->GetOwnerProcess() != client_thread->GetOwnerProcess()); - - auto* src_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); - auto* dst_msg_buffer = memory.GetPointer(client_message); - std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size); - - // Translate special header ad-hoc. - MessageBuffer src_msg(src_msg_buffer, client_buffer_size); - MessageBuffer::MessageHeader src_header(src_msg); - MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); - if (src_header.GetHasSpecialHeader()) { - MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size); - result = ProcessMessageSpecialData(dst_process, src_process, *server_thread, - dst_msg, src_msg, src_special_header); - if (R_FAILED(result)) { - CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size); - } - } + // Copy the memory. + temp_buffer.resize_destructive(raw_size); + src_page_table.GetMemory().ReadBlock(src_message_buffer + offset_words, + temp_buffer.data(), raw_size); + dst_page_table.GetMemory().WriteBlock(dst_message_buffer + offset_words, + temp_buffer.data(), raw_size); } - } else { - result = ResultSessionClosed; } - // Select a result for the client. - Result client_result = result; - if (closed && R_SUCCEEDED(result)) { - result = ResultSessionClosed; - client_result = ResultSessionClosed; - } else { - result = ResultSuccess; + // We succeeded! + R_SUCCEED(); +} + +Result ProcessSendMessageReceiveMapping(KProcessPageTable& src_page_table, + KProcessPageTable& dst_page_table, + KProcessAddress client_address, + KProcessAddress server_address, size_t size, + KMemoryState src_state) { + // If the size is zero, there's nothing to process. + R_SUCCEED_IF(size == 0); + + // Get the memory state and attribute mask to test. + u32 test_state; + u32 test_attr_mask; + R_TRY(GetMapAliasTestStateAndAttributeMask(test_state, test_attr_mask, src_state)); + + // Determine buffer extents. + KProcessAddress aligned_dst_start = Common::AlignDown(GetInteger(client_address), PageSize); + KProcessAddress aligned_dst_end = Common::AlignUp(GetInteger(client_address) + size, PageSize); + KProcessAddress mapping_dst_start = Common::AlignUp(GetInteger(client_address), PageSize); + KProcessAddress mapping_dst_end = + Common::AlignDown(GetInteger(client_address) + size, PageSize); + + KProcessAddress mapping_src_end = + Common::AlignDown(GetInteger(server_address) + size, PageSize); + + // If the start of the buffer is unaligned, handle that. + if (aligned_dst_start != mapping_dst_start) { + ASSERT(client_address < mapping_dst_start); + const size_t copy_size = std::min(size, mapping_dst_start - client_address); + temp_buffer.resize_destructive(copy_size); + src_page_table.GetMemory().ReadBlock(client_address, temp_buffer.data(), copy_size); + dst_page_table.GetMemory().WriteBlock(server_address, temp_buffer.data(), copy_size); } - // If there's a client thread, update it. - if (client_thread != nullptr) { - if (event != nullptr) { - // // Get the client process/page table. - // KProcess *client_process = client_thread->GetOwnerProcess(); - // KProcessPageTable *client_page_table = std::addressof(client_process->PageTable()); + // If the end of the buffer is unaligned, handle that. + if (mapping_dst_end < aligned_dst_end && + (aligned_dst_start == mapping_dst_start || aligned_dst_start < mapping_dst_end)) { + const size_t copy_size = client_address + size - mapping_dst_end; + temp_buffer.resize_destructive(copy_size); + src_page_table.GetMemory().ReadBlock(mapping_src_end, temp_buffer.data(), copy_size); + dst_page_table.GetMemory().WriteBlock(mapping_dst_end, temp_buffer.data(), copy_size); + } - // // If we need to, reply with an async error. - // if (R_FAILED(client_result)) { - // ReplyAsyncError(client_process, client_message, client_buffer_size, - // client_result); - // } + R_SUCCEED(); +} - // // Unlock the client buffer. - // // NOTE: Nintendo does not check the result of this. - // client_page_table->UnlockForIpcUserBuffer(client_message, client_buffer_size); +Result ProcessSendMessagePointerDescriptors(int& offset, int& pointer_key, + KProcessPageTable& src_page_table, + KProcessPageTable& dst_page_table, + const MessageBuffer& dst_msg, + const MessageBuffer& src_msg, + const ReceiveList& dst_recv_list, bool dst_user) { + // Get the offset at the start of processing. + const int cur_offset = offset; + + // Get the pointer desc. + MessageBuffer::PointerDescriptor src_desc(src_msg, cur_offset); + offset += static_cast(MessageBuffer::PointerDescriptor::GetDataSize() / sizeof(u32)); + + // Extract address/size. + const uint64_t src_pointer = src_desc.GetAddress(); + const size_t recv_size = src_desc.GetSize(); + uint64_t recv_pointer = 0; + + // Process the buffer, if it has a size. + if (recv_size > 0) { + // If using indexing, set index. + if (dst_recv_list.IsIndex()) { + pointer_key = src_desc.GetIndex(); + } - // Signal the event. - event->Signal(); - } else { - // End the client thread's wait. - KScopedSchedulerLock sl{m_kernel}; + // Get the buffer. + dst_recv_list.GetBuffer(recv_pointer, recv_size, pointer_key); + R_UNLESS(recv_pointer != 0, ResultOutOfResource); - if (!client_thread->IsTerminationRequested()) { - client_thread->EndWait(client_result); + // Perform the pointer data copy. + temp_buffer.resize_destructive(recv_size); + src_page_table.GetMemory().ReadBlock(src_pointer, temp_buffer.data(), recv_size); + dst_page_table.GetMemory().WriteBlock(recv_pointer, temp_buffer.data(), recv_size); + } + + // Set the output descriptor. + dst_msg.Set(cur_offset, MessageBuffer::PointerDescriptor(reinterpret_cast(recv_pointer), + recv_size, src_desc.GetIndex())); + + R_SUCCEED(); +} + +Result SendMessage(KernelCore& kernel, uint64_t src_message_buffer, size_t src_buffer_size, + KPhysicalAddress src_message_paddr, KThread& dst_thread, + uint64_t dst_message_buffer, size_t dst_buffer_size, KServerSession* session, + KSessionRequest* request) { + // Prepare variables for send. + KThread& src_thread = GetCurrentThread(kernel); + KProcess& dst_process = *(dst_thread.GetOwnerProcess()); + KProcess& src_process = *(src_thread.GetOwnerProcess()); + auto& dst_page_table = dst_process.GetPageTable(); + auto& src_page_table = src_process.GetPageTable(); + + // NOTE: Session is used only for debugging, and so may go unused. + (void)session; + + // Determine the message buffers. + u32 *dst_msg_ptr, *src_msg_ptr; + bool dst_user, src_user; + + if (dst_message_buffer) { + // NOTE: Nintendo does not check the result of this GetPhysicalAddress call. + dst_msg_ptr = dst_page_table.GetMemory().GetPointer(dst_message_buffer); + dst_user = true; + } else { + dst_msg_ptr = dst_page_table.GetMemory().GetPointer(dst_thread.GetTlsAddress()); + dst_buffer_size = MessageBufferSize; + dst_message_buffer = GetInteger(dst_thread.GetTlsAddress()); + dst_user = false; + } + + if (src_message_buffer) { + src_msg_ptr = src_page_table.GetMemory().GetPointer(src_message_buffer); + src_user = true; + } else { + src_msg_ptr = src_page_table.GetMemory().GetPointer(src_thread.GetTlsAddress()); + src_buffer_size = MessageBufferSize; + src_message_buffer = GetInteger(src_thread.GetTlsAddress()); + src_user = false; + } + + // Parse the headers. + const MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size); + const MessageBuffer src_msg(src_msg_ptr, src_buffer_size); + const MessageBuffer::MessageHeader dst_header(dst_msg); + const MessageBuffer::MessageHeader src_header(src_msg); + const MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header); + const MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); + + // Get the end of the source message. + const size_t src_end_offset = + MessageBuffer::GetRawDataIndex(src_header, src_special_header) + src_header.GetRawCount(); + + // Declare variables for processing. + int offset = 0; + int pointer_key = 0; + bool processed_special_data = false; + + // Send the message. + { + // Make sure that we end up in a clean state on error. + ON_RESULT_FAILURE { + // Cleanup special data. + if (processed_special_data) { + if (src_header.GetHasSpecialHeader()) { + CleanupSpecialData(dst_process, dst_msg_ptr, dst_buffer_size); + } + } else { + CleanupServerHandles(kernel, src_user ? src_message_buffer : 0, src_buffer_size, + src_message_paddr); + } + + // Cleanup mappings. + CleanupMap(request, std::addressof(src_process), std::addressof(dst_page_table)); + }; + + // Ensure that the headers fit. + R_UNLESS(MessageBuffer::GetMessageBufferSize(src_header, src_special_header) <= + src_buffer_size, + ResultInvalidCombination); + R_UNLESS(MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) <= + dst_buffer_size, + ResultInvalidCombination); + + // Ensure the receive list offset is after the end of raw data. + if (dst_header.GetReceiveListOffset()) { + R_UNLESS(dst_header.GetReceiveListOffset() >= + MessageBuffer::GetRawDataIndex(dst_header, dst_special_header) + + dst_header.GetRawCount(), + ResultInvalidCombination); + } + + // Ensure that the destination buffer is big enough to receive the source. + R_UNLESS(dst_buffer_size >= src_end_offset * sizeof(u32), ResultMessageTooLarge); + + // Replies must have no buffers. + R_UNLESS(src_header.GetSendCount() == 0, ResultInvalidCombination); + R_UNLESS(src_header.GetReceiveCount() == 0, ResultInvalidCombination); + R_UNLESS(src_header.GetExchangeCount() == 0, ResultInvalidCombination); + + // Get the receive list. + const s32 dst_recv_list_idx = + MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header); + ReceiveList dst_recv_list(dst_msg_ptr, dst_message_buffer, dst_page_table, dst_header, + dst_special_header, dst_buffer_size, src_end_offset, + dst_recv_list_idx, !dst_user); + + // Handle any receive buffers. + for (size_t i = 0; i < request->GetReceiveCount(); ++i) { + R_TRY(ProcessSendMessageReceiveMapping( + src_page_table, dst_page_table, request->GetReceiveClientAddress(i), + request->GetReceiveServerAddress(i), request->GetReceiveSize(i), + request->GetReceiveMemoryState(i))); + } + + // Handle any exchange buffers. + for (size_t i = 0; i < request->GetExchangeCount(); ++i) { + R_TRY(ProcessSendMessageReceiveMapping( + src_page_table, dst_page_table, request->GetExchangeClientAddress(i), + request->GetExchangeServerAddress(i), request->GetExchangeSize(i), + request->GetExchangeMemoryState(i))); + } + + // Set the header. + offset = dst_msg.Set(src_header); + + // Process any special data. + ASSERT(GetCurrentThreadPointer(kernel) == std::addressof(src_thread)); + processed_special_data = true; + if (src_header.GetHasSpecialHeader()) { + R_TRY(ProcessMessageSpecialData(offset, dst_process, src_process, src_thread, + dst_msg, src_msg, src_special_header)); + } + + // Process any pointer buffers. + for (auto i = 0; i < src_header.GetPointerCount(); ++i) { + R_TRY(ProcessSendMessagePointerDescriptors( + offset, pointer_key, src_page_table, dst_page_table, dst_msg, src_msg, + dst_recv_list, + dst_user && + dst_header.GetReceiveListCount() == + MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer)); + } + + // Clear any map alias buffers. + for (auto i = 0; i < src_header.GetMapAliasCount(); ++i) { + offset = dst_msg.Set(offset, MessageBuffer::MapAliasDescriptor()); + } + + // Process any raw data. + if (const auto raw_count = src_header.GetRawCount(); raw_count != 0) { + // Get the offset and size. + const size_t offset_words = offset * sizeof(u32); + const size_t raw_size = raw_count * sizeof(u32); + + if (!dst_user && !src_user) { + // Fast case is TLS -> TLS, do raw memcpy if we can. + std::memcpy(dst_msg_ptr + offset, src_msg_ptr + offset, raw_size); + } else { + // Copy the memory. + temp_buffer.resize_destructive(raw_size); + src_page_table.GetMemory().ReadBlock(src_message_buffer + offset_words, + temp_buffer.data(), raw_size); + dst_page_table.GetMemory().WriteBlock(dst_message_buffer + offset_words, + temp_buffer.data(), raw_size); } } } - R_RETURN(result); + // Perform (and validate) any remaining cleanup. + R_RETURN(CleanupMap(request, std::addressof(src_process), std::addressof(dst_page_table))); } -Result KServerSession::ReceiveRequest(std::shared_ptr* out_context, +void ReplyAsyncError(KProcess* to_process, uint64_t to_msg_buf, size_t to_msg_buf_size, + Result result) { + // Convert the address to a linear pointer. + u32* to_msg = to_process->GetMemory().GetPointer(to_msg_buf); + + // Set the error. + MessageBuffer msg(to_msg, to_msg_buf_size); + msg.SetAsyncResult(result); +} + +} // namespace + +KServerSession::KServerSession(KernelCore& kernel) + : KSynchronizationObject{kernel}, m_lock{m_kernel} {} + +KServerSession::~KServerSession() = default; + +void KServerSession::Destroy() { + m_parent->OnServerClosed(); + + this->CleanupRequests(); + + m_parent->Close(); +} + +Result KServerSession::ReceiveRequest(uintptr_t server_message, uintptr_t server_buffer_size, + KPhysicalAddress server_message_paddr, + std::shared_ptr* out_context, std::weak_ptr manager) { // Lock the session. KScopedLightLock lk{m_lock}; @@ -449,52 +1045,242 @@ Result KServerSession::ReceiveRequest(std::shared_ptrGetAddress(); + uint64_t client_message = request->GetAddress(); size_t client_buffer_size = request->GetSize(); - // bool recv_list_broken = false; - - if (!client_message) { - client_message = GetInteger(client_thread->GetTlsAddress()); - client_buffer_size = MessageBufferSize; - } + bool recv_list_broken = false; // Receive the message. - Core::Memory::Memory& memory{client_thread->GetOwnerProcess()->GetMemory()}; + Result result = ResultSuccess; + if (out_context != nullptr) { // HLE request. + if (!client_message) { + client_message = GetInteger(client_thread->GetTlsAddress()); + } + Core::Memory::Memory& memory{client_thread->GetOwnerProcess()->GetMemory()}; u32* cmd_buf{reinterpret_cast(memory.GetPointer(client_message))}; *out_context = std::make_shared(m_kernel, memory, this, client_thread); (*out_context)->SetSessionRequestManager(manager); (*out_context) ->PopulateFromIncomingCommandBuffer(*client_thread->GetOwnerProcess(), cmd_buf); + // We succeeded. + R_SUCCEED(); } else { - KThread* server_thread = GetCurrentThreadPointer(m_kernel); - KProcess& src_process = *client_thread->GetOwnerProcess(); - KProcess& dst_process = *server_thread->GetOwnerProcess(); - UNIMPLEMENTED_IF(client_thread->GetOwnerProcess() != server_thread->GetOwnerProcess()); - - auto* src_msg_buffer = memory.GetPointer(client_message); - auto* dst_msg_buffer = memory.GetPointer(server_thread->GetTlsAddress()); - std::memcpy(dst_msg_buffer, src_msg_buffer, client_buffer_size); - - // Translate special header ad-hoc. - // TODO: fix this mess - MessageBuffer src_msg(src_msg_buffer, client_buffer_size); - MessageBuffer::MessageHeader src_header(src_msg); - MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); - if (src_header.GetHasSpecialHeader()) { - MessageBuffer dst_msg(dst_msg_buffer, client_buffer_size); - Result res = ProcessMessageSpecialData(dst_process, src_process, *client_thread, - dst_msg, src_msg, src_special_header); - if (R_FAILED(res)) { - CleanupSpecialData(dst_process, dst_msg_buffer, client_buffer_size); + result = ReceiveMessage(m_kernel, recv_list_broken, server_message, server_buffer_size, + server_message_paddr, *client_thread, client_message, + client_buffer_size, this, request); + } + + // Handle cleanup on receive failure. + if (R_FAILED(result)) { + // Cache the result to return it to the client. + const Result result_for_client = result; + + // Clear the current request. + { + KScopedSchedulerLock sl(m_kernel); + ASSERT(m_current_request == request); + m_current_request = nullptr; + if (!m_request_list.empty()) { + this->NotifyAvailable(); + } + } + + // Reply to the client. + { + // After we reply, close our reference to the request. + SCOPE_EXIT({ request->Close(); }); + + // Get the event to check whether the request is async. + if (KEvent* event = request->GetEvent(); event != nullptr) { + // The client sent an async request. + KProcess* client = client_thread->GetOwnerProcess(); + auto& client_pt = client->GetPageTable(); + + // Send the async result. + if (R_FAILED(result_for_client)) { + ReplyAsyncError(client, client_message, client_buffer_size, result_for_client); + } + + // Unlock the client buffer. + // NOTE: Nintendo does not check the result of this. + client_pt.UnlockForIpcUserBuffer(client_message, client_buffer_size); + + // Signal the event. + event->Signal(); + } else { + // End the client thread's wait. + KScopedSchedulerLock sl(m_kernel); + + if (!client_thread->IsTerminationRequested()) { + client_thread->EndWait(result_for_client); + } } } + + // Set the server result. + if (recv_list_broken) { + result = ResultReceiveListBroken; + } else { + result = ResultNotFound; + } } - // We succeeded. - R_SUCCEED(); + R_RETURN(result); +} + +Result KServerSession::SendReply(uintptr_t server_message, uintptr_t server_buffer_size, + KPhysicalAddress server_message_paddr, bool is_hle) { + // Lock the session. + KScopedLightLock lk{m_lock}; + + // Get the request. + KSessionRequest* request; + { + KScopedSchedulerLock sl{m_kernel}; + + // Get the current request. + request = m_current_request; + R_UNLESS(request != nullptr, ResultInvalidState); + + // Clear the current request, since we're processing it. + m_current_request = nullptr; + if (!m_request_list.empty()) { + this->NotifyAvailable(); + } + } + + // Close reference to the request once we're done processing it. + SCOPE_EXIT({ request->Close(); }); + + // Extract relevant information from the request. + const uint64_t client_message = request->GetAddress(); + const size_t client_buffer_size = request->GetSize(); + KThread* client_thread = request->GetThread(); + KEvent* event = request->GetEvent(); + + // Check whether we're closed. + const bool closed = (client_thread == nullptr || m_parent->IsClientClosed()); + + Result result = ResultSuccess; + if (!closed) { + // If we're not closed, send the reply. + if (is_hle) { + // HLE servers write directly to a pointer to the thread command buffer. Therefore + // the reply has already been written in this case. + } else { + result = SendMessage(m_kernel, server_message, server_buffer_size, server_message_paddr, + *client_thread, client_message, client_buffer_size, this, request); + } + } else if (!is_hle) { + // Otherwise, we'll need to do some cleanup. + KProcess* server_process = request->GetServerProcess(); + KProcess* client_process = + (client_thread != nullptr) ? client_thread->GetOwnerProcess() : nullptr; + KProcessPageTable* client_page_table = + (client_process != nullptr) ? std::addressof(client_process->GetPageTable()) : nullptr; + + // Cleanup server handles. + result = CleanupServerHandles(m_kernel, server_message, server_buffer_size, + server_message_paddr); + + // Cleanup mappings. + Result cleanup_map_result = CleanupMap(request, server_process, client_page_table); + + // If we successfully cleaned up handles, use the map cleanup result as our result. + if (R_SUCCEEDED(result)) { + result = cleanup_map_result; + } + } + + // Select a result for the client. + Result client_result = result; + if (closed && R_SUCCEEDED(result)) { + result = ResultSessionClosed; + client_result = ResultSessionClosed; + } else { + result = ResultSuccess; + } + + // If there's a client thread, update it. + if (client_thread != nullptr) { + if (event != nullptr) { + // Get the client process/page table. + KProcess* client_process = client_thread->GetOwnerProcess(); + KProcessPageTable* client_page_table = std::addressof(client_process->GetPageTable()); + + // If we need to, reply with an async error. + if (R_FAILED(client_result)) { + ReplyAsyncError(client_process, client_message, client_buffer_size, client_result); + } + + // Unlock the client buffer. + // NOTE: Nintendo does not check the result of this. + client_page_table->UnlockForIpcUserBuffer(client_message, client_buffer_size); + + // Signal the event. + event->Signal(); + } else { + // End the client thread's wait. + KScopedSchedulerLock sl{m_kernel}; + + if (!client_thread->IsTerminationRequested()) { + client_thread->EndWait(client_result); + } + } + } + + R_RETURN(result); +} + +Result KServerSession::OnRequest(KSessionRequest* request) { + // Create the wait queue. + ThreadQueueImplForKServerSessionRequest wait_queue{m_kernel}; + + { + // Lock the scheduler. + KScopedSchedulerLock sl{m_kernel}; + + // Ensure that we can handle new requests. + R_UNLESS(!m_parent->IsServerClosed(), ResultSessionClosed); + + // Check that we're not terminating. + R_UNLESS(!GetCurrentThread(m_kernel).IsTerminationRequested(), ResultTerminationRequested); + + // Get whether we're empty. + const bool was_empty = m_request_list.empty(); + + // Add the request to the list. + request->Open(); + m_request_list.push_back(*request); + + // If we were empty, signal. + if (was_empty) { + this->NotifyAvailable(); + } + + // If we have a request event, this is asynchronous, and we don't need to wait. + R_SUCCEED_IF(request->GetEvent() != nullptr); + + // This is a synchronous request, so we should wait for our request to complete. + GetCurrentThread(m_kernel).SetWaitReasonForDebugging(ThreadWaitReasonForDebugging::IPC); + GetCurrentThread(m_kernel).BeginWait(std::addressof(wait_queue)); + } + + return GetCurrentThread(m_kernel).GetWaitResult(); +} + +bool KServerSession::IsSignaled() const { + ASSERT(KScheduler::IsSchedulerLockedByCurrentThread(m_kernel)); + + // If the client is closed, we're always signaled. + if (m_parent->IsClientClosed()) { + return true; + } + + // Otherwise, we're signaled if we have a request and aren't handling one. + return !m_request_list.empty() && m_current_request == nullptr; } void KServerSession::CleanupRequests() { @@ -527,31 +1313,30 @@ void KServerSession::CleanupRequests() { SCOPE_EXIT({ request->Close(); }); // Extract relevant information from the request. - // const uintptr_t client_message = request->GetAddress(); - // const size_t client_buffer_size = request->GetSize(); + const uint64_t client_message = request->GetAddress(); + const size_t client_buffer_size = request->GetSize(); KThread* client_thread = request->GetThread(); KEvent* event = request->GetEvent(); - // KProcess *server_process = request->GetServerProcess(); - // KProcess *client_process = (client_thread != nullptr) ? - // client_thread->GetOwnerProcess() : nullptr; - // KProcessPageTable *client_page_table = (client_process != nullptr) ? - // std::addressof(client_process->GetPageTable()) - // : nullptr; + KProcess* server_process = request->GetServerProcess(); + KProcess* client_process = + (client_thread != nullptr) ? client_thread->GetOwnerProcess() : nullptr; + KProcessPageTable* client_page_table = + (client_process != nullptr) ? std::addressof(client_process->GetPageTable()) : nullptr; // Cleanup the mappings. - // Result result = CleanupMap(request, server_process, client_page_table); + Result result = CleanupMap(request, server_process, client_page_table); // If there's a client thread, update it. if (client_thread != nullptr) { if (event != nullptr) { - // // We need to reply async. - // ReplyAsyncError(client_process, client_message, client_buffer_size, - // (R_SUCCEEDED(result) ? ResultSessionClosed : result)); + // We need to reply async. + ReplyAsyncError(client_process, client_message, client_buffer_size, + (R_SUCCEEDED(result) ? ResultSessionClosed : result)); - // // Unlock the client buffer. + // Unlock the client buffer. // NOTE: Nintendo does not check the result of this. - // client_page_table->UnlockForIpcUserBuffer(client_message, client_buffer_size); + client_page_table->UnlockForIpcUserBuffer(client_message, client_buffer_size); // Signal the event. event->Signal(); @@ -567,4 +1352,97 @@ void KServerSession::CleanupRequests() { } } +void KServerSession::OnClientClosed() { + KScopedLightLock lk{m_lock}; + + // Handle any pending requests. + KSessionRequest* prev_request = nullptr; + while (true) { + // Declare variables for processing the request. + KSessionRequest* request = nullptr; + KEvent* event = nullptr; + KThread* thread = nullptr; + bool cur_request = false; + bool terminate = false; + + // Get the next request. + { + KScopedSchedulerLock sl{m_kernel}; + + if (m_current_request != nullptr && m_current_request != prev_request) { + // Set the request, open a reference as we process it. + request = m_current_request; + request->Open(); + cur_request = true; + + // Get thread and event for the request. + thread = request->GetThread(); + event = request->GetEvent(); + + // If the thread is terminating, handle that. + if (thread->IsTerminationRequested()) { + request->ClearThread(); + request->ClearEvent(); + terminate = true; + } + + prev_request = request; + } else if (!m_request_list.empty()) { + // Pop the request from the front of the list. + request = std::addressof(m_request_list.front()); + m_request_list.pop_front(); + + // Get thread and event for the request. + thread = request->GetThread(); + event = request->GetEvent(); + } + } + + // If there are no requests, we're done. + if (request == nullptr) { + break; + } + + // All requests must have threads. + ASSERT(thread != nullptr); + + // Ensure that we close the request when done. + SCOPE_EXIT({ request->Close(); }); + + // If we're terminating, close a reference to the thread and event. + if (terminate) { + thread->Close(); + if (event != nullptr) { + event->Close(); + } + } + + // If we need to, reply. + if (event != nullptr && !cur_request) { + // There must be no mappings. + ASSERT(request->GetSendCount() == 0); + ASSERT(request->GetReceiveCount() == 0); + ASSERT(request->GetExchangeCount() == 0); + + // Get the process and page table. + KProcess* client_process = thread->GetOwnerProcess(); + auto& client_pt = client_process->GetPageTable(); + + // Reply to the request. + ReplyAsyncError(client_process, request->GetAddress(), request->GetSize(), + ResultSessionClosed); + + // Unlock the buffer. + // NOTE: Nintendo does not check the result of this. + client_pt.UnlockForIpcUserBuffer(request->GetAddress(), request->GetSize()); + + // Signal the event. + event->Signal(); + } + } + + // Notify. + this->NotifyAvailable(ResultSessionClosed); +} + } // namespace Kernel diff --git a/src/core/hle/kernel/k_server_session.h b/src/core/hle/kernel/k_server_session.h index 403891919..2876c231b 100644 --- a/src/core/hle/kernel/k_server_session.h +++ b/src/core/hle/kernel/k_server_session.h @@ -49,14 +49,21 @@ public: bool IsSignaled() const override; void OnClientClosed(); - /// TODO: flesh these out to match the real kernel Result OnRequest(KSessionRequest* request); - Result SendReply(bool is_hle = false); - Result ReceiveRequest(std::shared_ptr* out_context = nullptr, + Result SendReply(uintptr_t server_message, uintptr_t server_buffer_size, + KPhysicalAddress server_message_paddr, bool is_hle = false); + Result ReceiveRequest(uintptr_t server_message, uintptr_t server_buffer_size, + KPhysicalAddress server_message_paddr, + std::shared_ptr* out_context = nullptr, std::weak_ptr manager = {}); Result SendReplyHLE() { - return SendReply(true); + R_RETURN(this->SendReply(0, 0, 0, true)); + } + + Result ReceiveRequestHLE(std::shared_ptr* out_context, + std::weak_ptr manager) { + R_RETURN(this->ReceiveRequest(0, 0, 0, out_context, manager)); } private: diff --git a/src/core/hle/kernel/message_buffer.h b/src/core/hle/kernel/message_buffer.h index 75b275310..d528a9bb3 100644 --- a/src/core/hle/kernel/message_buffer.h +++ b/src/core/hle/kernel/message_buffer.h @@ -18,13 +18,13 @@ public: static constexpr inline u64 NullTag = 0; public: - enum class ReceiveListCountType : u32 { - None = 0, - ToMessageBuffer = 1, - ToSingleBuffer = 2, + enum ReceiveListCountType : u32 { + ReceiveListCountType_None = 0, + ReceiveListCountType_ToMessageBuffer = 1, + ReceiveListCountType_ToSingleBuffer = 2, - CountOffset = 2, - CountMax = 13, + ReceiveListCountType_CountOffset = 2, + ReceiveListCountType_CountMax = 13, }; private: @@ -591,16 +591,16 @@ public: // Add the size of the receive list. const auto count = hdr.GetReceiveListCount(); switch (count) { - case MessageHeader::ReceiveListCountType::None: + case MessageHeader::ReceiveListCountType_None: break; - case MessageHeader::ReceiveListCountType::ToMessageBuffer: + case MessageHeader::ReceiveListCountType_ToMessageBuffer: break; - case MessageHeader::ReceiveListCountType::ToSingleBuffer: + case MessageHeader::ReceiveListCountType_ToSingleBuffer: msg_size += ReceiveListEntry::GetDataSize(); break; default: msg_size += (static_cast(count) - - static_cast(MessageHeader::ReceiveListCountType::CountOffset)) * + static_cast(MessageHeader::ReceiveListCountType_CountOffset)) * ReceiveListEntry::GetDataSize(); break; } diff --git a/src/core/hle/kernel/svc/svc_ipc.cpp b/src/core/hle/kernel/svc/svc_ipc.cpp index 47a3e7bb0..85cc4f561 100644 --- a/src/core/hle/kernel/svc/svc_ipc.cpp +++ b/src/core/hle/kernel/svc/svc_ipc.cpp @@ -48,8 +48,7 @@ Result ReplyAndReceiveImpl(KernelCore& kernel, int32_t* out_index, uintptr_t mes }; // Send the reply. - R_TRY(session->SendReply()); - // R_TRY(session->SendReply(message, buffer_size, message_paddr)); + R_TRY(session->SendReply(message, buffer_size, message_paddr)); } // Receive a message. @@ -85,8 +84,7 @@ Result ReplyAndReceiveImpl(KernelCore& kernel, int32_t* out_index, uintptr_t mes if (R_SUCCEEDED(result)) { KServerSession* session = objs[index]->DynamicCast(); if (session != nullptr) { - // result = session->ReceiveRequest(message, buffer_size, message_paddr); - result = session->ReceiveRequest(); + result = session->ReceiveRequest(message, buffer_size, message_paddr); if (ResultNotFound == result) { continue; } diff --git a/src/core/hle/kernel/svc_results.h b/src/core/hle/kernel/svc_results.h index e1ad78607..38e71d516 100644 --- a/src/core/hle/kernel/svc_results.h +++ b/src/core/hle/kernel/svc_results.h @@ -38,7 +38,9 @@ constexpr Result ResultInvalidState{ErrorModule::Kernel, 125}; constexpr Result ResultReservedUsed{ErrorModule::Kernel, 126}; constexpr Result ResultPortClosed{ErrorModule::Kernel, 131}; constexpr Result ResultLimitReached{ErrorModule::Kernel, 132}; +constexpr Result ResultReceiveListBroken{ErrorModule::Kernel, 258}; constexpr Result ResultOutOfAddressSpace{ErrorModule::Kernel, 259}; +constexpr Result ResultMessageTooLarge{ErrorModule::Kernel, 260}; constexpr Result ResultInvalidId{ErrorModule::Kernel, 519}; } // namespace Kernel diff --git a/src/core/hle/service/server_manager.cpp b/src/core/hle/service/server_manager.cpp index 6808247a9..ec599dac2 100644 --- a/src/core/hle/service/server_manager.cpp +++ b/src/core/hle/service/server_manager.cpp @@ -372,7 +372,7 @@ Result ServerManager::OnSessionEvent(Kernel::KServerSession* session, // Try to receive a message. std::shared_ptr context; - rc = session->ReceiveRequest(&context, manager); + rc = session->ReceiveRequestHLE(&context, manager); // If the session has been closed, we're done. if (rc == Kernel::ResultSessionClosed) { -- cgit v1.2.3