Branch data Line data Source code
1 : : // Copyright 2026 HPActor Contributors
2 : : //
3 : : // Licensed under the Apache License, Version 2.0 (the "License");
4 : : // you may not use this file except in compliance with the License.
5 : : // You may obtain a copy of the License at
6 : : //
7 : : // http://www.apache.org/licenses/LICENSE-2.0
8 : : //
9 : : // Unless required by applicable law or agreed to in writing, software
10 : : // distributed under the License is distributed on an "AS IS" BASIS,
11 : : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : // See the License for the specific language governing permissions and
13 : : // limitations under the License.
14 : :
15 : : #include <hpactor/net/tcp_transport.hpp>
16 : :
17 : : #include <hpactor/log/logger.hpp>
18 : :
19 : : #include <cstring>
20 : : #include <fcntl.h>
21 : : #include <netinet/in.h>
22 : : #include <netinet/tcp.h>
23 : : #include <sys/socket.h>
24 : : #include <sys/stat.h>
25 : : #include <sys/un.h>
26 : : #include <unistd.h>
27 : :
28 : : namespace hpactor {
29 : :
30 : : namespace net {
31 : :
32 : : // -----------------------------------------------------------------------------
33 : : // TcpTransport implementation
34 : : // -----------------------------------------------------------------------------
35 : :
36 : 13 : TcpTransport::TcpTransport(EndPoint endpoint, const TlsConfig& tls_config,
37 : 13 : const PoolConfig& pool_config, NodeRegistry* registry)
38 : 13 : : endpoint_(endpoint), acceptor_(&loop_),
39 : 13 : tls_context_(TlsContext::from_config(tls_config)),
40 : 26 : pool_config_(pool_config), registry_(registry) {
41 : : // Ensure UDS directory exists
42 : 13 : std::string uds_dir = "/tmp/hpactor";
43 : 13 : ::mkdir(uds_dir.c_str(), 0755); // Ignore error if exists
44 : :
45 : : // Set up completion callback to route send completions to TlsConnection
46 : 26 : completion_callback_ = [this](OpCompletion c) {
47 : 0 : if (c.type == OpType::Send) {
48 : 0 : auto it = connections_.find(c.fd);
49 : 0 : if (it != connections_.end()) {
50 : 0 : it->second->handle_send_completion(c.result);
51 : : }
52 : : }
53 : 13 : };
54 : 13 : loop_.set_completion_callback(completion_callback_);
55 : 13 : }
56 : :
57 : 13 : TcpTransport::~TcpTransport() {
58 : 13 : stop_listening();
59 : : // Abort all connection pools
60 : 17 : for (auto& [ep, pool] : pools_) {
61 : 4 : pool->abort();
62 : : }
63 : 13 : }
64 : :
65 : : std::shared_ptr<ConnectionPool>
66 : 18 : TcpTransport::get_or_create_pool(EndPoint remote_endpoint) {
67 : 18 : auto it = pools_.find(remote_endpoint);
68 : 18 : if (it != pools_.end()) {
69 : 13 : return it->second;
70 : : }
71 : : auto pool =
72 : 5 : std::make_shared<ConnectionPool>(remote_endpoint, pool_config_, &loop_);
73 : 5 : pools_[remote_endpoint] = pool;
74 : : // Set RPC handler if one has been registered
75 : 5 : if (rpc_handler_) {
76 : 0 : pool->set_rpc_handler(rpc_handler_);
77 : : }
78 : : // Set actor message handler if one has been registered
79 : 5 : if (actor_msg_handler_) {
80 : 0 : pool->set_actor_message_handler(actor_msg_handler_);
81 : : }
82 : 5 : return pool;
83 : 5 : }
84 : :
85 : 18 : ConnectionPtr TcpTransport::connect(EndPoint remote_endpoint,
86 : : const std::string& /*host*/, uint16_t port) {
87 : : // Build sockaddr from remote_endpoint — supports both IPv4 and IPv6
88 : 18 : struct sockaddr_storage addr_storage{};
89 : 18 : socklen_t addr_len = 0;
90 : 18 : int family = 0;
91 : :
92 : 18 : if (auto* ipv4 = std::get_if<Ipv4Endpoint>(&remote_endpoint)) {
93 : 18 : family = AF_INET;
94 : 18 : auto* sa = reinterpret_cast<struct sockaddr_in*>(&addr_storage);
95 : 18 : sa->sin_family = AF_INET;
96 : 18 : sa->sin_port = htons(port);
97 : 18 : sa->sin_addr.s_addr = ipv4->addr;
98 : 18 : addr_len = sizeof(struct sockaddr_in);
99 : 0 : } else if (auto* ipv6 = std::get_if<Ipv6Endpoint>(&remote_endpoint)) {
100 : 0 : family = AF_INET6;
101 : 0 : auto* sa = reinterpret_cast<struct sockaddr_in6*>(&addr_storage);
102 : 0 : sa->sin6_family = AF_INET6;
103 : 0 : sa->sin6_port = htons(port);
104 : 0 : std::memcpy(sa->sin6_addr.s6_addr, ipv6->addr.data(), 16);
105 : 0 : sa->sin6_flowinfo = 0;
106 : 0 : sa->sin6_scope_id = 0;
107 : 0 : addr_len = sizeof(struct sockaddr_in6);
108 : : } else {
109 : 0 : return nullptr;
110 : : }
111 : :
112 : 18 : int fd = ::socket(family, SOCK_STREAM, 0);
113 : 18 : if (fd < 0) {
114 : 0 : return nullptr;
115 : : }
116 : :
117 : : // Set TCP_NODELAY
118 : 18 : int nodelay = 1;
119 : 18 : setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay));
120 : :
121 : : // Set non-blocking
122 : 18 : int flags = fcntl(fd, F_GETFL, 0);
123 : 18 : fcntl(fd, F_SETFL, flags | O_NONBLOCK);
124 : :
125 : 18 : int result = ::connect(fd, reinterpret_cast<struct sockaddr*>(&addr_storage),
126 : : addr_len);
127 : 18 : if (result < 0 && errno != EINPROGRESS) {
128 : 0 : ::close(fd);
129 : 0 : return nullptr;
130 : : }
131 : :
132 : 18 : auto pool = get_or_create_pool(remote_endpoint);
133 : 18 : bool use_tls = pool_config_.use_tls;
134 : 18 : ConnectionPtr conn;
135 : :
136 : : // Create connection in Connecting state — event loop registration is
137 : : // deferred until the non-blocking connect completes.
138 : 18 : if (use_tls) {
139 : : auto tls_conn = TlsConnection::create_client(endpoint_, remote_endpoint,
140 : 0 : &tls_context_, &loop_);
141 : 0 : tls_conn->set_fd(fd);
142 : 0 : tls_conn->set_ready_handler(
143 : 0 : [pool](ConnectionPtr c) { pool->on_connection_ready(c); });
144 : 0 : tls_conn->set_error_handler([pool](ConnectionPtr c, const error& e) {
145 : 0 : pool->on_connection_error(c, e);
146 : 0 : });
147 : 0 : tls_conn->set_frame_handler([pool](StreamBuffer data) {
148 : 0 : pool->on_frame_received(std::move(data));
149 : 0 : });
150 : 0 : conn = tls_conn;
151 : 0 : } else {
152 : : auto plain_conn = WireFrameConnection::create_connecting_client(
153 : 18 : fd, endpoint_, remote_endpoint, &loop_);
154 : 36 : plain_conn->set_ready_handler(
155 : 36 : [pool](ConnectionPtr c) { pool->on_connection_ready(c); });
156 : 18 : plain_conn->set_error_handler([pool](ConnectionPtr c, const error& e) {
157 : 0 : pool->on_connection_error(c, e);
158 : 0 : });
159 : 18 : plain_conn->set_frame_handler([pool](StreamBuffer data) {
160 : 0 : pool->on_frame_received(std::move(data));
161 : 0 : });
162 : 18 : conn = plain_conn;
163 : 18 : }
164 : :
165 : : // Add to pool and track early so the write_handler can find the connection
166 : 18 : pool->add_connection(conn);
167 : 18 : register_connection(conn, fd);
168 : :
169 : : // Handle non-blocking connect completion
170 : 18 : if (result < 0 && errno == EINPROGRESS) {
171 : 18 : if (!complete_connect(fd, use_tls)) {
172 : 0 : unregister_connection(fd);
173 : 0 : return nullptr;
174 : : }
175 : : } else {
176 : : // Connected immediately (e.g. localhost) — set up read handler directly
177 : 0 : if (use_tls) {
178 : 0 : TlsConnection::setup_after_connect(
179 : 0 : std::static_pointer_cast<TlsConnection>(conn));
180 : 0 : static_cast<TlsConnection*>(conn.get())->start_client_handshake();
181 : : } else {
182 : 0 : WireFrameConnection::setup_after_connect(
183 : 0 : std::static_pointer_cast<WireFrameConnection>(conn));
184 : : }
185 : : }
186 : :
187 : 18 : return conn;
188 : 18 : }
189 : :
190 : 3 : ConnectionPtr TcpTransport::connect(EndPoint remote_endpoint) {
191 : 3 : if (registry_ == nullptr) {
192 : 1 : return nullptr; // No registry configured
193 : : }
194 : :
195 : 2 : NodeEndpoint* ep = registry_->get(remote_endpoint);
196 : 2 : if (ep == nullptr) {
197 : 1 : return nullptr; // Unknown node
198 : : }
199 : :
200 : : // Check if UDS path is available for this endpoint
201 : 1 : if (!ep->identity.uds_path.empty()) {
202 : 1 : return connect_unix_domain(remote_endpoint, ep->identity.uds_path);
203 : : }
204 : :
205 : : // Resolve hostname to IP if needed
206 : 0 : std::string ip = host_resolver_.resolve(ep->identity.host);
207 : :
208 : : // Connect to resolved IP:port
209 : 0 : return connect(remote_endpoint, ip, ep->tcp_port);
210 : 0 : }
211 : :
212 : 2 : ConnectionPtr TcpTransport::connect_unix_domain(EndPoint remote_endpoint,
213 : : const std::string& socket_path) {
214 : 2 : int fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
215 : 2 : if (fd < 0) {
216 : 0 : return nullptr;
217 : : }
218 : :
219 : : // Set non-blocking
220 : 2 : int flags = fcntl(fd, F_GETFL, 0);
221 : 2 : fcntl(fd, F_SETFL, flags | O_NONBLOCK);
222 : :
223 : 2 : struct sockaddr_un addr{};
224 : 2 : addr.sun_family = AF_UNIX;
225 : 2 : std::strncpy(addr.sun_path, socket_path.c_str(), sizeof(addr.sun_path) - 1);
226 : :
227 : : int result =
228 : 2 : ::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr));
229 : 2 : if (result < 0 && errno != EINPROGRESS) {
230 : 2 : ::close(fd);
231 : 2 : return nullptr;
232 : : }
233 : :
234 : : // Note: TLS over UDS is not supported per design decision.
235 : : // UDS connections always use WireFrameConnection.
236 : 0 : auto pool = get_or_create_pool(remote_endpoint);
237 : : auto plain_conn = WireFrameConnection::create_connecting_client(
238 : 0 : fd, endpoint_, remote_endpoint, &loop_);
239 : 0 : plain_conn->set_ready_handler(
240 : 0 : [pool](ConnectionPtr c) { pool->on_connection_ready(c); });
241 : 0 : plain_conn->set_error_handler([pool](ConnectionPtr c, const error& e) {
242 : 0 : pool->on_connection_error(c, e);
243 : 0 : });
244 : 0 : plain_conn->set_frame_handler([pool](StreamBuffer data) {
245 : 0 : pool->on_frame_received(std::move(data));
246 : 0 : });
247 : :
248 : 0 : pool->add_connection(plain_conn);
249 : 0 : register_connection(plain_conn, fd);
250 : :
251 : : // Handle non-blocking connect completion
252 : 0 : if (result < 0 && errno == EINPROGRESS) {
253 : 0 : if (!complete_connect(fd, /*use_tls=*/false)) {
254 : 0 : unregister_connection(fd);
255 : 0 : return nullptr;
256 : : }
257 : : } else {
258 : 0 : WireFrameConnection::setup_after_connect(plain_conn);
259 : : }
260 : :
261 : 0 : return plain_conn;
262 : 0 : }
263 : :
264 : 18 : bool TcpTransport::complete_connect(int fd, bool use_tls) {
265 : : // Ensure event loop backend is started
266 : 18 : if (!loop_.is_running()) {
267 : 5 : loop_.run();
268 : : }
269 : :
270 : : // Shared flag to coordinate between write_handler and timeout —
271 : : // whichever fires first wins; the other becomes a no-op.
272 : 18 : auto done = std::make_shared<bool>(false);
273 : :
274 : : // Register for Write events — the fd becomes writable when the
275 : : // non-blocking TCP handshake completes (success or error).
276 : 18 : loop_.add_fd(fd, EventLoop::Event::Write);
277 : :
278 : 18 : loop_.set_write_handler(fd, [this, fd, use_tls, done](int event_fd) {
279 : : (void)event_fd;
280 : 0 : if (*done)
281 : 0 : return;
282 : 0 : *done = true;
283 : :
284 : 0 : loop_.clear_write_handler(fd);
285 : :
286 : 0 : int so_error = 0;
287 : 0 : socklen_t len = sizeof(so_error);
288 : 0 : ::getsockopt(fd, SOL_SOCKET, SO_ERROR, &so_error, &len);
289 : :
290 : 0 : auto it = connections_.find(fd);
291 : 0 : if (it == connections_.end())
292 : 0 : return;
293 : :
294 : 0 : if (so_error != 0) {
295 : 0 : it->second->set_state(ConnectionState::Error);
296 : 0 : HPACTOR_LOG_ERROR(log::LogCategory::kNetwork, ActorId{0}, 0,
297 : : "connection error");
298 : 0 : return;
299 : : }
300 : :
301 : : // Connect succeeded — complete the per-type setup
302 : 0 : if (use_tls) {
303 : 0 : TlsConnection::setup_after_connect(
304 : 0 : std::static_pointer_cast<TlsConnection>(it->second));
305 : 0 : static_cast<TlsConnection*>(it->second.get())->start_client_handshake();
306 : : } else {
307 : 0 : WireFrameConnection::setup_after_connect(
308 : 0 : std::static_pointer_cast<WireFrameConnection>(it->second));
309 : : }
310 : : });
311 : :
312 : : // Timeout via event loop timer — fires after 5s if the connect
313 : : // hasn't completed, preventing a stale write_handler leak.
314 : 18 : loop_.run_after(
315 : 36 : [this, fd, done]() {
316 : 0 : if (*done)
317 : 0 : return;
318 : 0 : *done = true;
319 : :
320 : 0 : loop_.clear_write_handler(fd);
321 : :
322 : 0 : auto it = connections_.find(fd);
323 : 0 : if (it != connections_.end() &&
324 : 0 : it->second->state() == ConnectionState::Connecting) {
325 : 0 : it->second->set_state(ConnectionState::Error);
326 : : }
327 : : },
328 : : 5000);
329 : :
330 : 36 : return true;
331 : 18 : }
332 : :
333 : 3 : void TcpTransport::listen(uint16_t port) {
334 : 3 : acceptor_.set_accept_handler(
335 : 3 : [this](int client_fd, EndPoint ep) { handle_accept(client_fd, ep); });
336 : 6 : acceptor_.listen(port);
337 : 3 : }
338 : :
339 : 23 : void TcpTransport::stop_listening() {
340 : 23 : acceptor_.close();
341 : 23 : }
342 : :
343 : 0 : bool TcpTransport::try_send(const ActorAddress& target, const StreamBuffer& encoded) {
344 : 0 : auto pool = get_or_create_pool(target.endpoint);
345 : 0 : return pool->try_send(target, encoded);
346 : 0 : }
347 : :
348 : 5 : bool TcpTransport::is_connected(EndPoint remote_endpoint) const {
349 : 5 : auto it = pools_.find(remote_endpoint);
350 : 5 : if (it != pools_.end()) {
351 : 2 : return it->second->is_connected();
352 : : }
353 : 3 : return false;
354 : : }
355 : :
356 : 1 : void TcpTransport::close_connection(EndPoint remote_endpoint) {
357 : 1 : auto it = pools_.find(remote_endpoint);
358 : 1 : if (it != pools_.end()) {
359 : 1 : it->second->abort();
360 : 1 : pools_.erase(it);
361 : : }
362 : 1 : }
363 : :
364 : 0 : void TcpTransport::set_rpc_handler(rpc_response_handler handler) {
365 : : // Store handler and apply to all existing pools
366 : 0 : rpc_handler_ = std::move(handler);
367 : 0 : for (auto& [ep, pool] : pools_) {
368 : 0 : pool->set_rpc_handler(rpc_handler_);
369 : : }
370 : 0 : }
371 : :
372 : 18 : void TcpTransport::register_connection(ConnectionPtr conn, int fd) {
373 : 18 : connections_[fd] = conn;
374 : 18 : }
375 : :
376 : 0 : void TcpTransport::unregister_connection(int fd) {
377 : 0 : connections_.erase(fd);
378 : 0 : }
379 : :
380 : 0 : void TcpTransport::handle_accept(int client_fd, EndPoint remote_endpoint) {
381 : : // Get or create pool for the connecting endpoint
382 : 0 : auto pool = get_or_create_pool(remote_endpoint);
383 : :
384 : 0 : ConnectionPtr conn;
385 : 0 : if (pool_config_.use_tls) {
386 : : auto tls_conn = TlsConnection::create_server(
387 : 0 : client_fd, endpoint_, remote_endpoint, &tls_context_, &loop_);
388 : 0 : tls_conn->set_frame_handler([pool](StreamBuffer data) {
389 : 0 : pool->on_frame_received(std::move(data));
390 : 0 : });
391 : 0 : tls_conn->set_error_handler([pool](ConnectionPtr c, const error& e) {
392 : 0 : pool->on_connection_error(c, e);
393 : 0 : });
394 : 0 : conn = tls_conn;
395 : 0 : } else {
396 : : auto plain_conn = WireFrameConnection::create_as_server(
397 : 0 : client_fd, endpoint_, remote_endpoint, &loop_);
398 : 0 : plain_conn->set_frame_handler([pool](StreamBuffer data) {
399 : 0 : pool->on_frame_received(std::move(data));
400 : 0 : });
401 : 0 : plain_conn->set_error_handler([pool](ConnectionPtr c, const error& e) {
402 : 0 : pool->on_connection_error(c, e);
403 : 0 : });
404 : 0 : conn = plain_conn;
405 : 0 : }
406 : :
407 : 0 : pool->add_connection(conn);
408 : 0 : register_connection(conn, client_fd);
409 : 0 : }
410 : :
411 : 0 : std::string TcpTransport::derive_uds_path(const std::string& node_id) const {
412 : : // Sanitize node_id: replace colons with underscores
413 : 0 : std::string sanitized = node_id;
414 : 0 : for (char& c : sanitized) {
415 : 0 : if (c == ':')
416 : 0 : c = '_';
417 : : }
418 : 0 : return "/tmp/hpactor/" + sanitized + ".sock";
419 : 0 : }
420 : :
421 : : } // namespace net
422 : : } // namespace hpactor
|