From 531572b411a4a311cb38bcf09a2c95559ef068aa Mon Sep 17 00:00:00 2001 From: Liam Date: Sat, 26 Aug 2023 18:18:13 -0400 Subject: internal_network: cancel pending socket operations on application process termination --- src/core/core.cpp | 2 + src/core/internal_network/network.cpp | 86 +++++++++++++++++++++++++++++++++-- src/core/internal_network/network.h | 3 ++ 3 files changed, 88 insertions(+), 3 deletions(-) (limited to 'src/core') diff --git a/src/core/core.cpp b/src/core/core.cpp index 2f67e60a9..9c5246a56 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -404,6 +404,7 @@ struct System::Impl { gpu_core->NotifyShutdown(); } + Network::CancelPendingSocketOperations(); kernel.SuspendApplication(true); if (services) { services->KillNVNFlinger(); @@ -425,6 +426,7 @@ struct System::Impl { debugger.reset(); kernel.Shutdown(); memory.Reset(); + Network::RestartSocketOperations(); if (auto room_member = room_network.GetRoomMember().lock()) { Network::GameInfo game_info{}; diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp index 5d28300e6..ef5e5d013 100644 --- a/src/core/internal_network/network.cpp +++ b/src/core/internal_network/network.cpp @@ -48,15 +48,32 @@ enum class CallType { using socklen_t = int; +SOCKET interrupt_socket = static_cast(-1); + +void InterruptSocketOperations() { + closesocket(interrupt_socket); +} + +void AcknowledgeInterrupt() { + interrupt_socket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); +} + void Initialize() { WSADATA wsa_data; (void)WSAStartup(MAKEWORD(2, 2), &wsa_data); + + AcknowledgeInterrupt(); } void Finalize() { + InterruptSocketOperations(); WSACleanup(); } +SOCKET GetInterruptSocket() { + return interrupt_socket; +} + sockaddr TranslateFromSockAddrIn(SockAddrIn input) { sockaddr_in result; @@ -157,9 +174,39 @@ constexpr int SD_RECEIVE = SHUT_RD; constexpr int SD_SEND = SHUT_WR; constexpr int SD_BOTH = SHUT_RDWR; -void Initialize() {} +int interrupt_pipe_fd[2] = {-1, -1}; + +void Initialize() { + if (pipe(interrupt_pipe_fd) != 0) { + LOG_ERROR(Network, "Failed to create interrupt pipe!"); + } + int flags = fcntl(interrupt_pipe_fd[0], F_GETFL); + ASSERT_MSG(fcntl(interrupt_pipe_fd[0], F_SETFL, flags | O_NONBLOCK) == 0, + "Failed to set nonblocking state for interrupt pipe"); +} + +void Finalize() { + if (interrupt_pipe_fd[0] >= 0) { + close(interrupt_pipe_fd[0]); + } + if (interrupt_pipe_fd[1] >= 0) { + close(interrupt_pipe_fd[1]); + } +} + +void InterruptSocketOperations() { + u8 value = 0; + ASSERT(write(interrupt_pipe_fd[1], &value, sizeof(value)) == 1); +} -void Finalize() {} +void AcknowledgeInterrupt() { + u8 value = 0; + read(interrupt_pipe_fd[0], &value, sizeof(value)); +} + +SOCKET GetInterruptSocket() { + return interrupt_pipe_fd[0]; +} sockaddr TranslateFromSockAddrIn(SockAddrIn input) { sockaddr_in result; @@ -490,6 +537,14 @@ NetworkInstance::~NetworkInstance() { Finalize(); } +void CancelPendingSocketOperations() { + InterruptSocketOperations(); +} + +void RestartSocketOperations() { + AcknowledgeInterrupt(); +} + std::optional GetHostIPv4Address() { const auto network_interface = Network::GetSelectedNetworkInterface(); if (!network_interface.has_value()) { @@ -560,7 +615,14 @@ std::pair Poll(std::vector& pollfds, s32 timeout) { return result; }); - const int result = WSAPoll(host_pollfds.data(), static_cast(num), timeout); + host_pollfds.push_back(WSAPOLLFD{ + .fd = GetInterruptSocket(), + .events = POLLIN, + .revents = 0, + }); + + const int result = + WSAPoll(host_pollfds.data(), static_cast(host_pollfds.size()), timeout); if (result == 0) { ASSERT(std::all_of(host_pollfds.begin(), host_pollfds.end(), [](WSAPOLLFD fd) { return fd.revents == 0; })); @@ -627,6 +689,24 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { std::pair Socket::Accept() { sockaddr_in addr; socklen_t addrlen = sizeof(addr); + + std::vector host_pollfds{ + WSAPOLLFD{fd, POLLIN, 0}, + WSAPOLLFD{GetInterruptSocket(), POLLIN, 0}, + }; + + while (true) { + const int pollres = + WSAPoll(host_pollfds.data(), static_cast(host_pollfds.size()), -1); + if (host_pollfds[1].revents != 0) { + // Interrupt signaled before a client could be accepted, break + return {AcceptResult{}, Errno::AGAIN}; + } + if (pollres > 0) { + break; + } + } + const SOCKET new_socket = accept(fd, reinterpret_cast(&addr), &addrlen); if (new_socket == INVALID_SOCKET) { diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h index c7e20ae34..b7b7d773a 100644 --- a/src/core/internal_network/network.h +++ b/src/core/internal_network/network.h @@ -96,6 +96,9 @@ public: ~NetworkInstance(); }; +void CancelPendingSocketOperations(); +void RestartSocketOperations(); + #ifdef _WIN32 constexpr IPv4Address TranslateIPv4(in_addr addr) { auto& bytes = addr.S_un.S_un_b; -- cgit v1.2.3