Branch data Line data Source code
1 : : // Copyright 2026 HPActor Contributors
2 : : //
3 : : // Licensed under the Apache License, Version 2.0 (the "License");
4 : : // you may not use this file except in compliance with the License.
5 : : // You may obtain a copy of the License at
6 : : //
7 : : // http://www.apache.org/licenses/LICENSE-2.0
8 : : //
9 : : // Unless required by applicable law or agreed to in writing, software
10 : : // distributed under the License is distributed on an "AS IS" BASIS,
11 : : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : // See the License for the specific language governing permissions and
13 : : // limitations under the License.
14 : :
15 : : #include <hpactor/sched/timing_wheel.hpp>
16 : :
17 : : #include <chrono>
18 : :
19 : : namespace hpactor::sched {
20 : :
21 : 212 : TimingWheel::TimingWheel(int64_t tick_ns, uint32_t num_levels)
22 : 424 : : tick_ns_(tick_ns), num_levels_(num_levels), levels_(num_levels),
23 : 212 : current_time_(std::chrono::steady_clock::now().time_since_epoch().count()) {
24 : : // Initialize each level
25 : : // Level 0: 256 buckets of tick_ns each
26 : : // Level 1: 256 buckets of 256 * tick_ns each
27 : : // etc.
28 : 1060 : for (uint32_t level = 0; level < num_levels; ++level) {
29 : 848 : uint32_t num_buckets = 256;
30 : :
31 : 848 : levels_[level].buckets.resize(num_buckets);
32 : 848 : levels_[level].num_buckets = num_buckets;
33 : 848 : levels_[level].mask = num_buckets - 1;
34 : : }
35 : 212 : }
36 : :
37 : 212 : TimingWheel::~TimingWheel() {
38 : : // Cancel all timers
39 : 212 : for (Timer* timer : all_timers_) {
40 : 0 : timer->cancelled = true;
41 : 0 : delete timer;
42 : : }
43 : 212 : }
44 : :
45 : 7 : uint64_t TimingWheel::schedule(int64_t delay_ns, TimerCallback callback) {
46 : 21 : return schedule_at(current_time_.load(std::memory_order_relaxed) + delay_ns,
47 : 14 : std::move(callback));
48 : : }
49 : :
50 : 7 : uint64_t TimingWheel::schedule_at(int64_t expire_ns, TimerCallback callback) {
51 : 7 : return add_timer_internal(expire_ns, std::move(callback));
52 : : }
53 : :
54 : : uint64_t
55 : 7 : TimingWheel::add_timer_internal(int64_t expire_ns, TimerCallback callback) {
56 : 7 : auto* timer = new Timer;
57 : 7 : timer->expire_ns = expire_ns;
58 : 7 : timer->id = next_timer_id_.fetch_add(1);
59 : 7 : timer->callback = std::move(callback);
60 : 7 : timer->cancelled = false;
61 : :
62 : 7 : insert_timer(timer);
63 : 7 : return timer->id;
64 : : }
65 : :
66 : 3 : bool TimingWheel::cancel(uint64_t timer_id) {
67 : 3 : Timer* timer = remove_timer(timer_id);
68 : 3 : if (timer) {
69 : 1 : timer->cancelled = true;
70 : 1 : return true;
71 : : }
72 : 2 : return false;
73 : : }
74 : :
75 : 7 : void TimingWheel::insert_timer(Timer* timer) {
76 : 7 : int64_t expire = timer->expire_ns;
77 : 7 : int64_t now = current_time_.load(std::memory_order_relaxed);
78 : :
79 : 7 : expire = std::max(expire, now);
80 : :
81 : : // Calculate which level and bucket
82 : 7 : int64_t diff = expire - now;
83 : 7 : uint32_t level = 0;
84 : :
85 : : // Find the appropriate level for this timer
86 : : // Level covers tick_ns * 256^(level+1) time range
87 : 13 : for (uint32_t l = 0; l < num_levels_; ++l) {
88 : 13 : int64_t level_range = tick_ns_;
89 : 32 : for (uint32_t k = 0; k <= l; ++k) {
90 : 19 : level_range *= 256;
91 : : }
92 : 13 : if (diff < level_range) {
93 : 7 : level = l;
94 : 7 : break;
95 : : }
96 : : }
97 : :
98 : : // Calculate bucket index for this level
99 : 7 : int64_t level_offset = now / tick_ns_;
100 : 13 : for (uint32_t l = 0; l < level; ++l) {
101 : 6 : level_offset /= 256;
102 : : }
103 : 7 : uint32_t bucket = static_cast<uint32_t>(level_offset) & levels_[level].mask;
104 : :
105 : 7 : timer->id |= (static_cast<uint64_t>(level) << 48); // Store level in high
106 : : // bits
107 : 7 : levels_[level].buckets[bucket].push_back(timer);
108 : 7 : }
109 : :
110 : 3 : TimingWheel::Timer* TimingWheel::remove_timer(uint64_t timer_id) {
111 : 3 : uint32_t level = static_cast<uint32_t>(timer_id >> 48);
112 : 3 : timer_id &= 0xFFFFFFFFFFFFULL; // Mask to get actual ID
113 : :
114 : 3 : if (level >= num_levels_) {
115 : 0 : return nullptr;
116 : : }
117 : :
118 : : // Search all buckets at this level for the timer
119 : : // This is O(buckets) but timers at higher levels are fewer
120 : 707 : for (auto& bucket : levels_[level].buckets) {
121 : 706 : for (auto it = bucket.begin(); it != bucket.end(); ++it) {
122 : 2 : if ((*it)->id == timer_id) {
123 : 1 : Timer* timer = *it;
124 : 1 : bucket.erase(it);
125 : 1 : return timer;
126 : : }
127 : : }
128 : : }
129 : 2 : return nullptr;
130 : : }
131 : :
132 : 19839 : uint32_t TimingWheel::advance(int64_t now_ns) {
133 : 19839 : int64_t old_time = current_time_.load(std::memory_order_relaxed);
134 : 19839 : if (now_ns <= old_time) {
135 : 0 : return 0;
136 : : }
137 : :
138 : 19839 : current_time_.store(now_ns, std::memory_order_relaxed);
139 : :
140 : 19839 : uint32_t fired = 0;
141 : :
142 : : // Process all levels
143 : 99195 : for (uint32_t level = 0; level < num_levels_; ++level) {
144 : 79356 : uint64_t level_offset = static_cast<uint64_t>(old_time / tick_ns_);
145 : 198390 : for (uint32_t l = 0; l < level; ++l) {
146 : 119034 : level_offset /= 256;
147 : : }
148 : : uint32_t start_bucket =
149 : 79356 : static_cast<uint32_t>(level_offset) & levels_[level].mask;
150 : :
151 : 79356 : uint64_t end_offset = static_cast<uint64_t>(now_ns / tick_ns_);
152 : 198390 : for (uint32_t l = 0; l < level; ++l) {
153 : 119034 : end_offset /= 256;
154 : : }
155 : : uint32_t end_bucket =
156 : 79356 : static_cast<uint32_t>(end_offset) & levels_[level].mask;
157 : :
158 : : // Process buckets from start to end (wrapping around)
159 : 79356 : uint32_t num_buckets = levels_[level].num_buckets;
160 : 79356 : uint32_t count =
161 : 79356 : ((end_bucket - start_bucket + num_buckets) % num_buckets) + 1;
162 : :
163 : 181104 : for (uint32_t i = 0; i < count; ++i) {
164 : 101748 : uint32_t bucket_idx = (start_bucket + i) % num_buckets;
165 : 101748 : auto& bucket = levels_[level].buckets[bucket_idx];
166 : :
167 : 101762 : for (auto it = bucket.begin(); it != bucket.end();) {
168 : 14 : Timer* timer = *it;
169 : 14 : if (timer->cancelled) {
170 : 0 : it = bucket.erase(it);
171 : 0 : delete timer;
172 : 6 : continue;
173 : : }
174 : :
175 : 14 : if (timer->expire_ns <= now_ns) {
176 : : // Timer expired, fire it
177 : 0 : it = bucket.erase(it);
178 : 0 : timer->callback();
179 : 0 : delete timer;
180 : 0 : ++fired;
181 : : } else {
182 : : // Timer not yet expired, might need to cascade to lower
183 : : // level
184 : 14 : if (level > 0) {
185 : : // Recalculate which bucket this timer should be in
186 : : // at this (lower) level
187 : 6 : uint32_t lower_level = level - 1;
188 : 6 : int64_t lower_offset = now_ns / tick_ns_;
189 : 12 : for (uint32_t l = 0; l <= lower_level; ++l) {
190 : 6 : lower_offset /= 256;
191 : : }
192 : : uint32_t lower_bucket =
193 : 6 : static_cast<uint32_t>(lower_offset) &
194 : 6 : levels_[lower_level].mask;
195 : :
196 : 6 : timer->id &= 0xFFFFFFFFFFFFULL;
197 : 6 : timer->id |= (static_cast<uint64_t>(lower_level) << 48);
198 : 6 : bucket.erase(it);
199 : 6 : levels_[lower_level].buckets[lower_bucket].push_back(timer);
200 : 6 : it = bucket.begin(); // Reset since we erased
201 : 6 : continue;
202 : 6 : }
203 : 8 : ++it;
204 : : }
205 : : }
206 : : }
207 : : }
208 : :
209 : 19839 : return fired;
210 : : }
211 : :
212 : 0 : bool TimingWheel::empty() const {
213 : 0 : for (const auto& level : levels_) {
214 : 0 : for (const auto& bucket : level.buckets) {
215 : 0 : if (!bucket.empty()) {
216 : 0 : return false;
217 : : }
218 : : }
219 : : }
220 : 0 : return true;
221 : : }
222 : :
223 : 0 : size_t TimingWheel::size() const {
224 : 0 : size_t count = 0;
225 : 0 : for (const auto& level : levels_) {
226 : 0 : for (const auto& bucket : level.buckets) {
227 : 0 : count += bucket.size();
228 : : }
229 : : }
230 : 0 : return count;
231 : : }
232 : :
233 : : } // namespace hpactor::sched
|