libraries: kernel: alternative sema implementation

This commit is contained in:
psucien 2024-06-22 20:41:00 +02:00
parent 6bbf8aa1c2
commit 8376813a67
1 changed files with 34 additions and 102 deletions

View File

@ -13,128 +13,60 @@
namespace Libraries::Kernel {
using ListBaseHook =
boost::intrusive::list_base_hook<boost::intrusive::link_mode<boost::intrusive::normal_link>>;
class Semaphore {
public:
Semaphore(s32 init_count, s32 max_count, const char* name, bool is_fifo)
: name{name}, token_count{init_count}, max_count{max_count}, is_fifo{is_fifo} {}
bool Wait(bool can_block, s32 need_count, u64* timeout) {
if (HasAvailableTokens(need_count)) {
return true;
}
if (!can_block) {
if (need_count < 1 || need_count > max_count) {
return false;
}
// Create waiting thread object and add it into the list of waiters.
WaitingThread waiter{need_count, is_fifo};
AddWaiter(waiter);
SCOPE_EXIT {
PopWaiter(waiter);
};
std::unique_lock<std::mutex> lock{mutex};
auto pred = [this, need_count] { return token_count >= need_count; };
if (!timeout) {
cond.wait(lock, pred);
token_count -= need_count;
} else {
const auto time_ms = std::chrono::microseconds(*timeout);
const auto start = std::chrono::high_resolution_clock::now();
// Perform the wait.
return waiter.Wait(timeout);
bool result = cond.wait_for(lock, time_ms, pred);
auto end = std::chrono::high_resolution_clock::now();
if (result) {
const auto delta =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
const auto time_remain = time_ms - delta;
*timeout = time_remain.count();
token_count -= need_count;
} else {
*timeout = 0;
return false;
}
}
return true;
}
bool Signal(s32 signal_count) {
std::scoped_lock lk{mutex};
std::unique_lock<std::mutex> lock{mutex};
if (token_count + signal_count > max_count) {
return false;
}
token_count += signal_count;
// Wake up threads in order of priority.
for (auto& waiter : wait_list) {
if (waiter.need_count > token_count) {
continue;
}
token_count -= waiter.need_count;
waiter.cv.notify_one();
}
token_count += signal_count;
cond.notify_all();
return true;
}
private:
struct WaitingThread : public ListBaseHook {
std::mutex mutex;
std::condition_variable cv;
u32 priority;
s32 need_count;
explicit WaitingThread(s32 need_count, bool is_fifo) : need_count{need_count} {
if (is_fifo) {
return;
}
// Retrieve calling thread priority for sorting into waiting threads list.
s32 policy;
sched_param param;
pthread_getschedparam(pthread_self(), &policy, &param);
priority = param.sched_priority;
}
bool Wait(u64* timeout) {
std::unique_lock lk{mutex};
if (!timeout) {
// Wait indefinitely until we are woken up.
cv.wait(lk);
return true;
}
// Wait until timeout runs out, recording how much remaining time there was.
const auto start = std::chrono::high_resolution_clock::now();
const auto status = cv.wait_for(lk, std::chrono::microseconds(*timeout));
const auto end = std::chrono::high_resolution_clock::now();
const auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
*timeout -= time;
return status != std::cv_status::timeout;
}
bool operator<(const WaitingThread& other) const {
return priority < other.priority;
}
};
void AddWaiter(WaitingThread& waiter) {
std::scoped_lock lk{mutex};
// Insert at the end of the list for FIFO order.
if (is_fifo) {
wait_list.push_back(waiter);
return;
}
// Find the first with priority less then us and insert right before it.
auto it = wait_list.begin();
while (it != wait_list.end() && it->priority > waiter.priority) {
it++;
}
wait_list.insert(it, waiter);
}
void PopWaiter(WaitingThread& waiter) {
std::scoped_lock lk{mutex};
wait_list.erase(WaitingThreads::s_iterator_to(waiter));
}
bool HasAvailableTokens(s32 need_count) {
std::scoped_lock lk{mutex};
if (token_count >= need_count) {
token_count -= need_count;
return true;
}
return false;
}
using WaitingThreads =
boost::intrusive::list<WaitingThread, boost::intrusive::base_hook<ListBaseHook>,
boost::intrusive::constant_time_size<false>>;
WaitingThreads wait_list;
std::string name;
std::atomic<s32> token_count;
s32 token_count;
std::mutex mutex;
std::condition_variable cond;
s32 max_count;
bool is_fifo;
};