diff options
-rw-r--r-- | src/core/core.cpp | 2 | ||||
-rw-r--r-- | src/core/internal_network/network.cpp | 89 | ||||
-rw-r--r-- | src/core/internal_network/network.h | 3 |
3 files changed, 91 insertions, 3 deletions
diff --git a/src/core/core.cpp b/src/core/core.cpp index f075ae7fa..2d6e61398 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -406,6 +406,7 @@ struct System::Impl { gpu_core->NotifyShutdown(); } + Network::CancelPendingSocketOperations(); kernel.SuspendApplication(true); if (services) { services->KillNVNFlinger(); @@ -427,6 +428,7 @@ struct System::Impl { debugger.reset(); kernel.Shutdown(); memory.Reset(); + Network::RestartSocketOperations(); if (auto room_member = room_network.GetRoomMember().lock()) { Network::GameInfo game_info{}; diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp index 5d28300e6..a983f23ea 100644 --- a/src/core/internal_network/network.cpp +++ b/src/core/internal_network/network.cpp @@ -48,15 +48,32 @@ enum class CallType { using socklen_t = int; +SOCKET interrupt_socket = static_cast<SOCKET>(-1); + +void InterruptSocketOperations() { + closesocket(interrupt_socket); +} + +void AcknowledgeInterrupt() { + interrupt_socket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); +} + void Initialize() { WSADATA wsa_data; (void)WSAStartup(MAKEWORD(2, 2), &wsa_data); + + AcknowledgeInterrupt(); } void Finalize() { + InterruptSocketOperations(); WSACleanup(); } +SOCKET GetInterruptSocket() { + return interrupt_socket; +} + sockaddr TranslateFromSockAddrIn(SockAddrIn input) { sockaddr_in result; @@ -157,9 +174,42 @@ constexpr int SD_RECEIVE = SHUT_RD; constexpr int SD_SEND = SHUT_WR; constexpr int SD_BOTH = SHUT_RDWR; -void Initialize() {} +int interrupt_pipe_fd[2] = {-1, -1}; -void Finalize() {} +void Initialize() { + if (pipe(interrupt_pipe_fd) != 0) { + LOG_ERROR(Network, "Failed to create interrupt pipe!"); + } + int flags = fcntl(interrupt_pipe_fd[0], F_GETFL); + ASSERT_MSG(fcntl(interrupt_pipe_fd[0], F_SETFL, flags | O_NONBLOCK) == 0, + "Failed to set nonblocking state for interrupt pipe"); +} + +void Finalize() { + if (interrupt_pipe_fd[0] >= 0) { + close(interrupt_pipe_fd[0]); + } + if (interrupt_pipe_fd[1] >= 0) { + close(interrupt_pipe_fd[1]); + } +} + +void InterruptSocketOperations() { + u8 value = 0; + ASSERT(write(interrupt_pipe_fd[1], &value, sizeof(value)) == 1); +} + +void AcknowledgeInterrupt() { + u8 value = 0; + ssize_t ret = read(interrupt_pipe_fd[0], &value, sizeof(value)); + if (ret != 1 && errno != EAGAIN && errno != EWOULDBLOCK) { + LOG_ERROR(Network, "Failed to acknowledge interrupt on shutdown"); + } +} + +SOCKET GetInterruptSocket() { + return interrupt_pipe_fd[0]; +} sockaddr TranslateFromSockAddrIn(SockAddrIn input) { sockaddr_in result; @@ -490,6 +540,14 @@ NetworkInstance::~NetworkInstance() { Finalize(); } +void CancelPendingSocketOperations() { + InterruptSocketOperations(); +} + +void RestartSocketOperations() { + AcknowledgeInterrupt(); +} + std::optional<IPv4Address> GetHostIPv4Address() { const auto network_interface = Network::GetSelectedNetworkInterface(); if (!network_interface.has_value()) { @@ -560,7 +618,14 @@ std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { return result; }); - const int result = WSAPoll(host_pollfds.data(), static_cast<ULONG>(num), timeout); + host_pollfds.push_back(WSAPOLLFD{ + .fd = GetInterruptSocket(), + .events = POLLIN, + .revents = 0, + }); + + const int result = + WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), timeout); if (result == 0) { ASSERT(std::all_of(host_pollfds.begin(), host_pollfds.end(), [](WSAPOLLFD fd) { return fd.revents == 0; })); @@ -627,6 +692,24 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { sockaddr_in addr; socklen_t addrlen = sizeof(addr); + + std::vector<WSAPOLLFD> host_pollfds{ + WSAPOLLFD{fd, POLLIN, 0}, + WSAPOLLFD{GetInterruptSocket(), POLLIN, 0}, + }; + + while (true) { + const int pollres = + WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), -1); + if (host_pollfds[1].revents != 0) { + // Interrupt signaled before a client could be accepted, break + return {AcceptResult{}, Errno::AGAIN}; + } + if (pollres > 0) { + break; + } + } + const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); if (new_socket == INVALID_SOCKET) { diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h index c7e20ae34..b7b7d773a 100644 --- a/src/core/internal_network/network.h +++ b/src/core/internal_network/network.h @@ -96,6 +96,9 @@ public: ~NetworkInstance(); }; +void CancelPendingSocketOperations(); +void RestartSocketOperations(); + #ifdef _WIN32 constexpr IPv4Address TranslateIPv4(in_addr addr) { auto& bytes = addr.S_un.S_un_b; |