Branch data Line data Source code
1 : : // Copyright 2026 HPActor Contributors
2 : : //
3 : : // Licensed under the Apache License, Version 2.0 (the "License");
4 : : // you may not use this file except in compliance with the License.
5 : : // You may obtain a copy of the License at
6 : : //
7 : : // http://www.apache.org/licenses/LICENSE-2.0
8 : : //
9 : : // Unless required by applicable law or agreed to in writing, software
10 : : // distributed under the License is distributed on an "AS IS" BASIS,
11 : : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : // See the License for the specific language governing permissions and
13 : : // limitations under the License.
14 : :
15 : : #include <hpactor/net/frame.hpp>
16 : : #include <hpactor/rpc/rpc_channel.hpp>
17 : : #include <hpactor/sched/scheduler.hpp>
18 : :
19 : : namespace hpactor {
20 : :
21 : : // -----------------------------------------------------------------------------
22 : : // RpcFuture implementation
23 : : // -----------------------------------------------------------------------------
24 : : template <typename T>
25 : 14 : RpcFuture<T>::RpcFuture(std::future<result<T>> inner,
26 : : std::chrono::milliseconds timeout)
27 : 14 : : inner_(std::move(inner)), timeout_(timeout) {}
28 : :
29 : 3 : template <typename T> result<T> RpcFuture<T>::get() {
30 : 3 : if (!inner_.valid()) {
31 : 0 : return result<T>::make(error(errors::unknown, "future not valid"));
32 : : }
33 : :
34 : 3 : auto status = inner_.wait_for(timeout_);
35 : 3 : if (status == std::future_status::timeout) {
36 : 2 : return result<T>::make(error(errors::timeout, "RPC call timed out"));
37 : : }
38 : :
39 : 2 : return inner_.get();
40 : : }
41 : :
42 : : // Explicit instantiations
43 : : template class RpcFuture<StreamBuffer>;
44 : :
45 : : // -----------------------------------------------------------------------------
46 : : // RpcChannel implementation
47 : : // -----------------------------------------------------------------------------
48 : 4 : RpcChannel::RpcChannel(net::Transport* transport, sched::IScheduler* scheduler)
49 : 4 : : transport_(transport), scheduler_(scheduler) {}
50 : :
51 : 0 : void RpcChannel::abort() {
52 : 0 : std::lock_guard<std::mutex> lock(mutex_);
53 : 0 : for (auto& [id, call] : pending_) {
54 : 0 : if (!call->ready_.load(std::memory_order_acquire)) {
55 : 0 : call->promise.set_value(
56 : 0 : result<StreamBuffer>::make(error(errors::unknown, "RPC channel "
57 : : "aborted")));
58 : 0 : call->ready_.store(true, std::memory_order_release);
59 : : }
60 : : }
61 : 0 : pending_.clear();
62 : 0 : }
63 : :
64 : 1 : void RpcChannel::on_response(const RpcResponseFrame& response) {
65 : 1 : std::unique_ptr<PendingCall> call;
66 : 1 : uint64_t key = response.msg_id.value();
67 : : {
68 : 1 : std::lock_guard<std::mutex> lock(mutex_);
69 : 1 : auto it = pending_.find(key);
70 : 1 : if (it == pending_.end()) {
71 : 0 : return;
72 : : }
73 : 1 : call = std::move(it->second);
74 : 1 : pending_.erase(it);
75 : 1 : }
76 : :
77 : 1 : call->ready_.store(true, std::memory_order_release);
78 : 2 : call->promise.set_value(
79 : 2 : result<StreamBuffer>::make(StreamBuffer(response.payload)));
80 : 1 : }
81 : :
82 : 1 : void RpcChannel::on_timeout(MessageId msg_id) {
83 : 1 : PendingCall* call_ptr = nullptr;
84 : 1 : uint64_t key = msg_id.value();
85 : : {
86 : 1 : std::lock_guard<std::mutex> lock(mutex_);
87 : 1 : auto it = pending_.find(key);
88 : 1 : if (it == pending_.end()) {
89 : 0 : return;
90 : : }
91 : 1 : call_ptr = it->second.get();
92 : 1 : }
93 : :
94 : 1 : if (call_ptr->retry_count < call_ptr->max_retries) {
95 : 1 : call_ptr->retry_count++;
96 : 1 : schedule_retry(call_ptr);
97 : : } else {
98 : 0 : call_ptr->ready_.store(true, std::memory_order_release);
99 : 0 : call_ptr->promise.set_value(
100 : 0 : result<StreamBuffer>::make(error(errors::timeout, "RPC call timed "
101 : : "out")));
102 : 0 : std::lock_guard<std::mutex> lock(mutex_);
103 : 0 : pending_.erase(key);
104 : 0 : }
105 : : }
106 : :
107 : 1 : void RpcChannel::schedule_retry(PendingCall* call) {
108 : 1 : int64_t delay_ns = call->timeout.count() * 1000000;
109 : 1 : scheduler_->schedule_after(
110 : 1 : [this, msg_id = call->msg_id]() { on_timeout(msg_id); }, delay_ns);
111 : 1 : send_request(*call, true);
112 : 1 : }
113 : :
114 : 14 : void RpcChannel::send_request(PendingCall& call, bool is_retry) {
115 : 14 : net::WireFrame frame;
116 : 14 : net::to_proto(frame.pb_frame.mutable_sender(), ActorAddress{});
117 : 14 : net::to_proto(frame.pb_frame.mutable_receiver(), call.target);
118 : 14 : frame.pb_frame.set_payload(
119 : 14 : reinterpret_cast<const char*>(call.encoded_request.data()),
120 : : call.encoded_request.size());
121 : 14 : frame.pb_frame.set_message_id(call.msg_id.value());
122 : 14 : frame.pb_frame.set_flags(net::WireFrame::RpcRequest);
123 : 14 : if (is_retry) {
124 : 1 : frame.pb_frame.set_flags(frame.pb_frame.flags() |
125 : : net::WireFrame::RpcIdempotent);
126 : : }
127 : 14 : if (call.has_trace_context) {
128 : 0 : net::to_proto(frame.pb_frame.mutable_trace_context(), call.trace_context);
129 : : }
130 : :
131 : 14 : StreamBuffer encoded = frame.encode();
132 : 14 : transport_->send(call.target, encoded);
133 : 14 : }
134 : :
135 : 13 : RpcFuture<StreamBuffer> RpcChannel::call_raw(const ActorAddress& target,
136 : : const StreamBuffer& encoded_request,
137 : : std::chrono::milliseconds timeout_ms) {
138 : 13 : return call_raw(target, encoded_request, timeout_ms, nullptr);
139 : : }
140 : :
141 : 13 : RpcFuture<StreamBuffer> RpcChannel::call_raw(const ActorAddress& target,
142 : : const StreamBuffer& encoded_request,
143 : : std::chrono::milliseconds timeout_ms,
144 : : const TraceContext* parent_context) {
145 : 13 : MessageId msg_id = generate_message_id();
146 : :
147 : 13 : auto promise_ptr = std::make_shared<std::promise<result<StreamBuffer>>>();
148 : 13 : auto future = promise_ptr->get_future();
149 : :
150 : : auto* call_ptr =
151 : : new PendingCall{.msg_id = msg_id,
152 : : .target = target,
153 : : .encoded_request = encoded_request,
154 : : .timeout = timeout_ms,
155 : : .retry_count = 0,
156 : : .max_retries = 5,
157 : 13 : .promise = std::move(*promise_ptr),
158 : 13 : .enqueued_at = std::chrono::steady_clock::now(),
159 : 39 : .ready_ = false};
160 : 13 : if (parent_context != nullptr && parent_context->valid()) {
161 : 0 : call_ptr->has_trace_context = true;
162 : 0 : call_ptr->trace_context = *parent_context;
163 : : }
164 : :
165 : 13 : uint64_t key = msg_id.value();
166 : : {
167 : 13 : std::lock_guard<std::mutex> lock(mutex_);
168 : 13 : pending_.emplace(key, std::unique_ptr<PendingCall>(call_ptr));
169 : 13 : }
170 : :
171 : 13 : send_request(*call_ptr, false);
172 : :
173 : 13 : int64_t delay_ns = timeout_ms.count() * 1000000;
174 : 14 : scheduler_->schedule_after([this, msg_id]() { on_timeout(msg_id); }, delay_ns);
175 : :
176 : 13 : return RpcFuture<StreamBuffer>(std::move(future), timeout_ms);
177 : 13 : }
178 : :
179 : : } // namespace hpactor
|