summaryrefslogtreecommitdiffstats
path: root/src/core/hle/kernel/synchronization.cpp
blob: a7e3fbe92f2d03e1e1d80723f6e7657292d35402 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Copyright 2020 yuzu Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.

#include <algorithm>
#include <iterator>

#include "core/core.h"
#include "core/hle/kernel/errors.h"
#include "core/hle/kernel/handle_table.h"
#include "core/hle/kernel/kernel.h"
#include "core/hle/kernel/scheduler.h"
#include "core/hle/kernel/synchronization.h"
#include "core/hle/kernel/synchronization_object.h"
#include "core/hle/kernel/thread.h"
#include "core/hle/kernel/time_manager.h"

namespace Kernel {

/// Binds this synchronization helper to @p system_. The kernel and the current scheduler are
/// looked up through it on every SignalObject/WaitFor call.
/// The parameter carries a trailing underscore so it does not shadow the `system` member (-Wshadow).
Synchronization::Synchronization(Core::System& system_) : system{system_} {}

/// Wakes the threads waiting on @p obj, but only if @p obj currently reports itself as signaled.
///
/// The whole operation runs under the scheduler lock. Waiters handled here:
///  - Each waiter whose scheduling status is Paused gets @p obj recorded as its signaling
///    object with RESULT_SUCCESS. It is then resumed, and its time event is cancelled through
///    the kernel's TimeManager.
///  - Waiters in any other scheduling state are skipped, but they are still dropped when the
///    waiter list is cleared at the end.
/// If @p obj is not signaled, nothing is changed.
void Synchronization::SignalObject(SynchronizationObject& obj) const {
    auto& kernel = system.Kernel();
    SchedulerLock lock(kernel);
    auto& time_manager = kernel.TimeManager();

    if (!obj.IsSignaled()) {
        return;
    }

    for (auto waiter : obj.GetWaitingThreads()) {
        if (waiter->GetSchedulingStatus() != ThreadSchedStatus::Paused) {
            continue;
        }
        waiter->SetSynchronizationResults(&obj, RESULT_SUCCESS);
        waiter->ResumeFromWait();
        time_manager.CancelTimeEvent(waiter.get());
    }

    obj.ClearWaitingThreads();
}

std::pair<ResultCode, Handle> Synchronization::WaitFor(
    std::vector<std::shared_ptr<SynchronizationObject>>& sync_objects, s64 nano_seconds) {
    auto& kernel = system.Kernel();
    auto* const thread = system.CurrentScheduler().GetCurrentThread();
    Handle event_handle = InvalidHandle;
    {
        SchedulerLockAndSleep lock(kernel, event_handle, thread, nano_seconds);
        const auto itr =
            std::find_if(sync_objects.begin(), sync_objects.end(),
                         [thread](const std::shared_ptr<SynchronizationObject>& object) {
                             return object->IsSignaled();
                         });

        if (itr != sync_objects.end()) {
            // We found a ready object, acquire it and set the result value
            SynchronizationObject* object = itr->get();
            object->Acquire(thread);
            const u32 index = static_cast<s32>(std::distance(sync_objects.begin(), itr));
            lock.CancelSleep();
            return {RESULT_SUCCESS, index};
        }

        if (nano_seconds == 0) {
            lock.CancelSleep();
            return {RESULT_TIMEOUT, InvalidHandle};
        }

        /// TODO(Blinkhawk): Check for termination pending

        if (thread->IsSyncCancelled()) {
            thread->SetSyncCancelled(false);
            lock.CancelSleep();
            return {ERR_SYNCHRONIZATION_CANCELED, InvalidHandle};
        }

        for (auto& object : sync_objects) {
            object->AddWaitingThread(SharedFrom(thread));
        }

        thread->SetSynchronizationObjects(&sync_objects);
        thread->SetSynchronizationResults(nullptr, RESULT_TIMEOUT);
        thread->SetStatus(ThreadStatus::WaitSynch);
        thread->SetWaitingSync(true);
    }
    thread->SetWaitingSync(false);

    if (event_handle != InvalidHandle) {
        auto& time_manager = kernel.TimeManager();
        time_manager.UnscheduleTimeEvent(event_handle);
    }

    {
        SchedulerLock lock(kernel);
        ResultCode signaling_result = thread->GetSignalingResult();
        SynchronizationObject* signaling_object = thread->GetSignalingObject();
        thread->SetSynchronizationObjects(nullptr);
        for (auto& obj : sync_objects) {
            obj->RemoveWaitingThread(SharedFrom(thread));
        }
        if (signaling_result == RESULT_SUCCESS) {
            const auto itr = std::find_if(
                sync_objects.begin(), sync_objects.end(),
                [signaling_object](const std::shared_ptr<SynchronizationObject>& object) {
                    return object.get() == signaling_object;
                });
            ASSERT(itr != sync_objects.end());
            signaling_object->Acquire(thread);
            const u32 index = static_cast<s32>(std::distance(sync_objects.begin(), itr));
            return {RESULT_SUCCESS, index};
        }
        return {signaling_result, -1};
    }
}

} // namespace Kernel