LCOV - code coverage report
Current view: top level - src/net - tls_connection.cpp (source / functions) Coverage Total Hit
Test: HPActor Coverage Lines: 43.3 % 356 154
Test Date: 2026-05-20 02:24:49 Functions: 57.9 % 38 22
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: - 0 0

             Branch data     Line data    Source code
       1                 :             : // Copyright 2026 HPActor Contributors
       2                 :             : //
       3                 :             : // Licensed under the Apache License, Version 2.0 (the "License");
       4                 :             : // you may not use this file except in compliance with the License.
       5                 :             : // You may obtain a copy of the License at
       6                 :             : //
       7                 :             : //     http://www.apache.org/licenses/LICENSE-2.0
       8                 :             : //
       9                 :             : // Unless required by applicable law or agreed to in writing, software
      10                 :             : // distributed under the License is distributed on an "AS IS" BASIS,
      11                 :             : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12                 :             : // See the License for the specific language governing permissions and
      13                 :             : // limitations under the License.
      14                 :             : 
      15                 :             : #include <hpactor/net/tls_connection.hpp>
      16                 :             : 
      17                 :             : #include <hpactor/log/logger.hpp>
      18                 :             : #include <hpactor/net/event_loop.hpp>
      19                 :             : 
      20                 :             : #include <cstring>
      21                 :             : #include <openssl/aes.h>
      22                 :             : #include <openssl/evp.h>
      23                 :             : #include <openssl/rand.h>
      24                 :             : #include <sys/socket.h>
      25                 :             : #include <unistd.h>
      26                 :             : 
      27                 :             : namespace hpactor {
      28                 :             : 
      29                 :             : namespace net {
      30                 :             : 
      31                 :             : namespace {
      32                 :             : 
      33                 :             : // Format a TLS message with type and payload
      34                 :           6 : StreamBuffer format_tls_message(TlsMessageType type, const StreamBuffer& payload) {
      35                 :           6 :     StreamBuffer msg;
      36                 :           6 :     msg.push_back(static_cast<uint8_t>(type));
      37                 :             :     // Add length as 3 bytes (TLS style)
      38                 :           6 :     msg.push_back(static_cast<uint8_t>((payload.size() >> 16) & 0xFF));
      39                 :           6 :     msg.push_back(static_cast<uint8_t>((payload.size() >> 8) & 0xFF));
      40                 :           6 :     msg.push_back(static_cast<uint8_t>(payload.size() & 0xFF));
      41                 :           6 :     msg.insert(msg.end(), payload.begin(), payload.end());
      42                 :           6 :     return msg;
      43                 :             : }
      44                 :             : 
      45                 :             : // Parse a TLS message - extract payload from formatted message
      46                 :           3 : StreamBuffer parse_tls_payload(const uint8_t* data, size_t len, size_t& consumed) {
      47                 :           3 :     consumed = 0;
      48                 :           3 :     if (len < 4) {
      49                 :           0 :         return StreamBuffer{};
      50                 :             :     }
      51                 :           3 :     size_t payload_len = (static_cast<size_t>(data[1]) << 16) |
      52                 :           3 :                          (static_cast<size_t>(data[2]) << 8) |
      53                 :           3 :                          static_cast<size_t>(data[3]);
      54                 :           3 :     size_t total_len = 4 + payload_len;
      55                 :           3 :     if (len < total_len) {
      56                 :           0 :         return StreamBuffer{};
      57                 :             :     }
      58                 :           3 :     StreamBuffer payload(data + 4, data + total_len);
      59                 :           3 :     consumed = total_len;
      60                 :           3 :     return payload;
      61                 :           3 : }
      62                 :             : 
      63                 :             : } // anonymous namespace
      64                 :             : 
      65                 :          20 : TlsConnection::TlsConnection(int fd, EndPoint local_endpoint,
      66                 :             :                              EndPoint remote_endpoint, TlsContext* tls_context,
      67                 :          20 :                              EventLoop* loop)
      68                 :             :     : Connection(fd, local_endpoint, remote_endpoint, loop),
      69                 :          20 :       tls_context_(tls_context) {
      70                 :          20 :     read_buffer_.reserve(kReadChunkSize);
      71                 :          20 :     write_buffer_.reserve(kReadChunkSize);
      72                 :             :     // Generate random client nonce
      73                 :          40 :     RAND_bytes(client_nonce_.data(), static_cast<int>(kNonceSize));
      74                 :          20 : }
      75                 :             : 
      76                 :          40 : TlsConnection::~TlsConnection() {
      77                 :          20 :     close();
      78                 :          40 : }
      79                 :             : 
      80                 :             : TlsConnectionPtr
      81                 :          13 : TlsConnection::create_client(EndPoint local_endpoint, EndPoint remote_endpoint,
      82                 :             :                              TlsContext* tls_context, EventLoop* loop) {
      83                 :             :     auto conn = std::shared_ptr<TlsConnection>(new TlsConnection(
      84                 :          13 :         -1, local_endpoint, remote_endpoint, tls_context, loop));
      85                 :          13 :     conn->set_state(ConnectionState::Connecting);
      86                 :          13 :     conn->is_server_ = false;
      87                 :          13 :     return conn;
      88                 :             : }
      89                 :             : 
      90                 :             : TlsConnectionPtr
      91                 :           7 : TlsConnection::create_server(int socket_fd, EndPoint local_endpoint,
      92                 :             :                              EndPoint remote_endpoint, TlsContext* tls_context,
      93                 :             :                              EventLoop* loop) {
      94                 :             :     auto conn = std::shared_ptr<TlsConnection>(new TlsConnection(
      95                 :           7 :         socket_fd, local_endpoint, remote_endpoint, tls_context, loop));
      96                 :           7 :     conn->set_state(ConnectionState::Connected);
      97                 :           7 :     conn->is_server_ = true;
      98                 :             :     // Server waits for client hello
      99                 :           7 :     conn->set_handshake_state(TlsHandshakeState::WaitingForServerHello);
     100                 :           7 :     conn->set_session_state(TlsSessionState::Handshake);
     101                 :             : 
     102                 :             :     // Register fd with event loop for read events
     103                 :           7 :     if (loop && socket_fd >= 0) {
     104                 :           7 :         loop->add_fd(socket_fd, EventLoop::Event::Read);
     105                 :           7 :         if (loop->supports_read_handler()) {
     106                 :           7 :             std::weak_ptr<TlsConnection> weak_conn = conn;
     107                 :           7 :             loop->set_read_handler(socket_fd, [weak_conn](int /*event_fd*/) {
     108                 :           0 :                 if (auto self = weak_conn.lock()) {
     109                 :           0 :                     self->handle_read();
     110                 :           0 :                 }
     111                 :           0 :             });
     112                 :           7 :         }
     113                 :             :     }
     114                 :             : 
     115                 :           7 :     return conn;
     116                 :             : }
     117                 :             : 
     118                 :           1 : void TlsConnection::set_ready_handler(std::function<void(ConnectionPtr)> handler) {
     119                 :           1 :     ready_handler_ = std::move(handler);
     120                 :           1 : }
     121                 :             : 
     122                 :           1 : void TlsConnection::set_frame_handler(frame_handler handler) {
     123                 :           1 :     frame_handler_ = std::move(handler);
     124                 :           1 : }
     125                 :             : 
     126                 :           2 : void TlsConnection::set_error_handler(
     127                 :             :     std::function<void(ConnectionPtr, const error&)> handler) {
     128                 :           2 :     error_handler_ = std::move(handler);
     129                 :           2 : }
     130                 :             : 
     131                 :           3 : void TlsConnection::set_send_completion_handler(std::function<void(int result)> handler) {
     132                 :           3 :     send_completion_handler_ = std::move(handler);
     133                 :           3 : }
     134                 :             : 
     135                 :           0 : void TlsConnection::set_fd(int fd) {
     136                 :           0 :     fd_ = fd;
     137                 :             :     // Event loop registration is deferred — the caller must verify the
     138                 :             :     // non-blocking connect completed, then call setup_after_connect().
     139                 :           0 : }
     140                 :             : 
     141                 :           0 : void TlsConnection::setup_after_connect(TlsConnectionPtr conn) {
     142                 :           0 :     conn->set_state(ConnectionState::Connected);
     143                 :             : 
     144                 :           0 :     auto* loop = conn->event_loop();
     145                 :           0 :     int fd = conn->fd();
     146                 :           0 :     if (loop && fd >= 0) {
     147                 :           0 :         loop->add_fd(fd, EventLoop::Event::Read);
     148                 :           0 :         if (loop->supports_read_handler()) {
     149                 :           0 :             std::weak_ptr<TlsConnection> weak_conn = conn;
     150                 :           0 :             loop->set_read_handler(fd, [weak_conn](int /*event_fd*/) {
     151                 :           0 :                 if (auto self = weak_conn.lock()) {
     152                 :           0 :                     self->handle_read();
     153                 :           0 :                 }
     154                 :           0 :             });
     155                 :           0 :         }
     156                 :             :     }
     157                 :           0 : }
     158                 :             : 
     159                 :           6 : void TlsConnection::start_client_handshake() {
     160                 :           6 :     if (is_server_)
     161                 :           0 :         return;
     162                 :             : 
     163                 :             :     // Generate client nonce if not already done
     164                 :           6 :     bool all_zero = true;
     165                 :           6 :     for (auto b : client_nonce_) {
     166                 :           6 :         if (b != 0) {
     167                 :           6 :             all_zero = false;
     168                 :           6 :             break;
     169                 :             :         }
     170                 :             :     }
     171                 :           6 :     if (all_zero) {
     172                 :           0 :         RAND_bytes(client_nonce_.data(), static_cast<int>(kNonceSize));
     173                 :             :     }
     174                 :             : 
     175                 :           6 :     set_handshake_state(TlsHandshakeState::WaitingForServerHello);
     176                 :           6 :     set_state(ConnectionState::Handshake);
     177                 :             : 
     178                 :           6 :     StreamBuffer client_hello = build_client_hello();
     179                 :           6 :     send_raw(client_hello);
     180                 :           6 : }
     181                 :             : 
     182                 :           3 : void TlsConnection::handle_read() {
     183                 :             :     // Read from fd into accumulation buffer
     184                 :             :     while (true) {
     185                 :           6 :         uint8_t* area = read_buffer_.reserve_tail(kReadChunkSize);
     186                 :           6 :         ssize_t n = ::read(fd_, area, kReadChunkSize);
     187                 :           6 :         if (n > 0) {
     188                 :           3 :             read_buffer_.commit_tail(static_cast<size_t>(n));
     189                 :           3 :         } else if (n == 0) {
     190                 :           0 :             break;
     191                 :             :         } else {
     192                 :           3 :             if (errno == EAGAIN || errno == EWOULDBLOCK)
     193                 :             :                 break;
     194                 :           0 :             break;
     195                 :             :         }
     196                 :           3 :     }
     197                 :             : 
     198                 :           3 :     process_buffer();
     199                 :           3 : }
     200                 :             : 
     201                 :           3 : void TlsConnection::process_buffer() {
     202                 :             :     // Process complete messages in buffer
     203                 :           3 :     while (read_buffer_.size() >= 4) {
     204                 :           3 :         size_t consumed = 0;
     205                 :             :         StreamBuffer payload =
     206                 :           3 :             parse_tls_payload(read_buffer_.data(), read_buffer_.size(), consumed);
     207                 :           3 :         if (consumed == 0) {
     208                 :           0 :             break; // Wait for more data
     209                 :             :         }
     210                 :           3 :         read_buffer_.consume(consumed);
     211                 :             : 
     212                 :           3 :         if (session_state_ == TlsSessionState::Encrypted) {
     213                 :             :             // Decrypt and deliver to frame handler
     214                 :           0 :             StreamBuffer plaintext = decrypt_aes(payload);
     215                 :           0 :             if (!plaintext.empty() && frame_handler_) {
     216                 :           0 :                 frame_handler_(std::move(plaintext));
     217                 :             :             }
     218                 :           0 :         } else {
     219                 :             :             // Handle handshake messages
     220                 :           3 :             TlsMessageType msg_type = static_cast<TlsMessageType>(payload[0]);
     221                 :           3 :             StreamBuffer msg_payload(payload.begin() + 1, payload.end());
     222                 :             : 
     223                 :           3 :             switch (msg_type) {
     224                 :           0 :                 case TlsMessageType::ServerHello:
     225                 :           0 :                     handle_server_hello(msg_payload);
     226                 :           0 :                     break;
     227                 :           1 :                 case TlsMessageType::Certificate:
     228                 :           1 :                     handle_certificate(msg_payload);
     229                 :           1 :                     break;
     230                 :           0 :                 case TlsMessageType::CertificateVerify:
     231                 :           0 :                     handle_certificate_verify(msg_payload);
     232                 :           0 :                     break;
     233                 :           0 :                 case TlsMessageType::Finished:
     234                 :           0 :                     handle_finished(msg_payload);
     235                 :           0 :                     break;
     236                 :           2 :                 default:
     237                 :           2 :                     set_handshake_state(TlsHandshakeState::Error);
     238                 :           2 :                     break;
     239                 :             :             }
     240                 :           3 :         }
     241                 :             : 
     242                 :             :         // Check if we should stop processing
     243                 :           3 :         if (handshake_state_ == TlsHandshakeState::Error ||
     244                 :           0 :             session_state_ == TlsSessionState::Error) {
     245                 :             :             break;
     246                 :             :         }
     247                 :           3 :     }
     248                 :           3 : }
     249                 :             : 
     250                 :           1 : void TlsConnection::send(const StreamBuffer& frame_data) {
     251                 :           1 :     if (session_state_ == TlsSessionState::Encrypted) {
     252                 :           0 :         HPACTOR_LOG_TRACE(
     253                 :             :             log::LogCategory::kNetwork, ActorId{0}, 0, "network frame sent",
     254                 :             :             log::field("bytes", static_cast<uint64_t>(frame_data.size())));
     255                 :           0 :         StreamBuffer encrypted = encrypt_aes(frame_data);
     256                 :           0 :         send_raw(format_tls_message(TlsMessageType::Finished, encrypted));
     257                 :           0 :     }
     258                 :           1 : }
     259                 :             : 
     260                 :          22 : void TlsConnection::close() {
     261                 :          22 :     if (fd_ >= 0) {
     262                 :           7 :         if (loop_) {
     263                 :           7 :             loop_->clear_read_handler(fd_);
     264                 :           7 :             loop_->remove_fd(fd_);
     265                 :             :         }
     266                 :           7 :         ::close(fd_);
     267                 :           7 :         fd_ = -1;
     268                 :             :     }
     269                 :          22 :     set_state(ConnectionState::Disconnected);
     270                 :          22 :     HPACTOR_LOG_DEBUG(log::LogCategory::kNetwork, ActorId{0}, 0,
     271                 :             :                       "connection closed");
     272                 :          22 : }
     273                 :             : 
     274                 :           6 : StreamBuffer TlsConnection::build_client_hello() {
     275                 :           6 :     StreamBuffer payload;
     276                 :             :     // Message type
     277                 :           6 :     payload.push_back(static_cast<uint8_t>(TlsMessageType::ClientHello));
     278                 :             :     // Client nonce (32 bytes)
     279                 :           6 :     payload.insert(payload.end(), client_nonce_.begin(), client_nonce_.end());
     280                 :             :     // Public key from TLS context
     281                 :           6 :     if (!tls_context_) {
     282                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     283                 :           0 :         return StreamBuffer{};
     284                 :             :     }
     285                 :           6 :     const StreamBuffer& pub_key = tls_context_->public_key();
     286                 :           6 :     payload.insert(payload.end(), pub_key.begin(), pub_key.end());
     287                 :             : 
     288                 :           6 :     StreamBuffer msg = format_tls_message(TlsMessageType::ClientHello, payload);
     289                 :           6 :     handshake_messages_.insert(handshake_messages_.end(), msg.begin(), msg.end());
     290                 :           6 :     return msg;
     291                 :           6 : }
     292                 :             : 
     293                 :           0 : StreamBuffer TlsConnection::build_certificate() {
     294                 :           0 :     StreamBuffer payload;
     295                 :             :     // Certificate data from TLS context
     296                 :           0 :     if (!tls_context_) {
     297                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     298                 :           0 :         return StreamBuffer{};
     299                 :             :     }
     300                 :           0 :     const StreamBuffer& cert = tls_context_->certificate();
     301                 :           0 :     payload.insert(payload.end(), cert.begin(), cert.end());
     302                 :             : 
     303                 :           0 :     StreamBuffer msg = format_tls_message(TlsMessageType::Certificate, payload);
     304                 :           0 :     handshake_messages_.insert(handshake_messages_.end(), msg.begin(), msg.end());
     305                 :           0 :     return msg;
     306                 :           0 : }
     307                 :             : 
     308                 :           0 : StreamBuffer TlsConnection::build_certificate_verify(const Nonce& challenge) {
     309                 :           0 :     StreamBuffer payload;
     310                 :             :     // Sign the challenge nonce with our private key
     311                 :           0 :     StreamBuffer data_to_sign(challenge.begin(), challenge.end());
     312                 :           0 :     if (!tls_context_) {
     313                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     314                 :           0 :         return StreamBuffer{};
     315                 :             :     }
     316                 :           0 :     StreamBuffer signature = tls_context_->sign_data(data_to_sign);
     317                 :           0 :     payload.insert(payload.end(), signature.begin(), signature.end());
     318                 :             : 
     319                 :             :     StreamBuffer msg =
     320                 :           0 :         format_tls_message(TlsMessageType::CertificateVerify, payload);
     321                 :           0 :     handshake_messages_.insert(handshake_messages_.end(), msg.begin(), msg.end());
     322                 :           0 :     return msg;
     323                 :           0 : }
     324                 :             : 
     325                 :           0 : StreamBuffer TlsConnection::build_finished() {
     326                 :           0 :     StreamBuffer payload;
     327                 :             :     // Compute verify_data using PRF
     328                 :             :     StreamBuffer verify_data =
     329                 :           0 :         prf_sha256(master_secret_, "finished", handshake_messages_);
     330                 :           0 :     payload.insert(payload.end(), verify_data.begin(), verify_data.end());
     331                 :             : 
     332                 :           0 :     return format_tls_message(TlsMessageType::Finished, payload);
     333                 :           0 : }
     334                 :             : 
     335                 :           0 : void TlsConnection::handle_server_hello(const StreamBuffer& data) {
     336                 :           0 :     if (handshake_state_ != TlsHandshakeState::WaitingForServerHello) {
     337                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     338                 :           0 :         return;
     339                 :             :     }
     340                 :             : 
     341                 :           0 :     if (data.size() < kNonceSize) {
     342                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     343                 :           0 :         return;
     344                 :             :     }
     345                 :             : 
     346                 :             :     // Extract server nonce
     347                 :           0 :     std::memcpy(server_nonce_.data(), data.data(), kNonceSize);
     348                 :             : 
     349                 :             :     // Send our certificate
     350                 :           0 :     StreamBuffer cert_msg = build_certificate();
     351                 :           0 :     send_raw(cert_msg);
     352                 :             : 
     353                 :             :     // Generate pre_master_secret
     354                 :           0 :     pre_master_secret_.resize(48);
     355                 :           0 :     RAND_bytes(pre_master_secret_.data(), 48);
     356                 :             : 
     357                 :           0 :     set_handshake_state(TlsHandshakeState::WaitingForCertificate);
     358                 :           0 : }
     359                 :             : 
     360                 :           1 : void TlsConnection::handle_certificate(const StreamBuffer& data) {
     361                 :           1 :     if (handshake_state_ != TlsHandshakeState::WaitingForCertificate) {
     362                 :           1 :         set_handshake_state(TlsHandshakeState::Error);
     363                 :           1 :         return;
     364                 :             :     }
     365                 :             : 
     366                 :           0 :     if (!tls_context_) {
     367                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     368                 :           0 :         return;
     369                 :             :     }
     370                 :             : 
     371                 :             :     // Verify the certificate
     372                 :           0 :     auto result = tls_context_->verify_certificate(data);
     373                 :           0 :     if (result != TlsContext::CertVerifyResult::Ok) {
     374                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     375                 :           0 :         return;
     376                 :             :     }
     377                 :             : 
     378                 :             :     // Send certificate verify message
     379                 :           0 :     StreamBuffer verify_msg = build_certificate_verify(server_nonce_);
     380                 :           0 :     send_raw(verify_msg);
     381                 :             : 
     382                 :           0 :     set_handshake_state(TlsHandshakeState::WaitingForCertificateVerify);
     383                 :           0 : }
     384                 :             : 
     385                 :           0 : void TlsConnection::handle_certificate_verify(const StreamBuffer& data) {
     386                 :             :     (void)data;
     387                 :           0 :     if (handshake_state_ != TlsHandshakeState::WaitingForCertificateVerify) {
     388                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     389                 :           0 :         return;
     390                 :             :     }
     391                 :             : 
     392                 :             :     // Derive session keys
     393                 :           0 :     derive_session_keys(pre_master_secret_, client_nonce_, server_nonce_);
     394                 :             : 
     395                 :           0 :     set_handshake_state(TlsHandshakeState::WaitingForFinished);
     396                 :             : }
     397                 :             : 
     398                 :           0 : void TlsConnection::handle_finished(const StreamBuffer& data) {
     399                 :             :     (void)data;
     400                 :           0 :     if (handshake_state_ != TlsHandshakeState::WaitingForFinished) {
     401                 :           0 :         set_handshake_state(TlsHandshakeState::Error);
     402                 :           0 :         return;
     403                 :             :     }
     404                 :             : 
     405                 :             :     // Verify the finished message
     406                 :             :     // In a full implementation, we would verify the verify_data here
     407                 :             : 
     408                 :           0 :     set_handshake_state(TlsHandshakeState::HandshakeComplete);
     409                 :           0 :     set_session_state(TlsSessionState::Encrypted);
     410                 :           0 :     set_state(ConnectionState::Connected);
     411                 :             : 
     412                 :           0 :     HPACTOR_LOG_DEBUG(log::LogCategory::kNetwork, ActorId{0}, 0,
     413                 :             :                       "connection opened");
     414                 :             : 
     415                 :             :     // Notify ready handler
     416                 :           0 :     if (ready_handler_) {
     417                 :             :         TlsConnectionPtr self =
     418                 :           0 :             std::enable_shared_from_this<TlsConnection>::shared_from_this();
     419                 :           0 :         ready_handler_(self);
     420                 :           0 :     }
     421                 :             : }
     422                 :             : 
     423                 :           0 : void TlsConnection::derive_session_keys(const StreamBuffer& pre_master_secret,
     424                 :             :                                         const Nonce& client_nonce,
     425                 :             :                                         const Nonce& server_nonce) {
     426                 :           0 :     StreamBuffer random_data;
     427                 :           0 :     random_data.insert(random_data.end(), client_nonce.begin(), client_nonce.end());
     428                 :           0 :     random_data.insert(random_data.end(), server_nonce.begin(), server_nonce.end());
     429                 :           0 :     master_secret_ = prf_sha256(pre_master_secret, "master secret", random_data);
     430                 :             : 
     431                 :           0 :     random_data.clear();
     432                 :           0 :     random_data.insert(random_data.end(), server_nonce.begin(), server_nonce.end());
     433                 :           0 :     random_data.insert(random_data.end(), client_nonce.begin(), client_nonce.end());
     434                 :             :     StreamBuffer key_block =
     435                 :           0 :         prf_sha256(master_secret_, "key expansion", random_data);
     436                 :             : 
     437                 :           0 :     session_key_.assign(key_block.begin(), key_block.begin() + 32);
     438                 :           0 :     session_iv_.assign(key_block.begin() + 32, key_block.begin() + 48);
     439                 :           0 : }
     440                 :             : 
     441                 :             : namespace {
     442                 :             : 
     443                 :             : // HMAC-SHA256 using EVP_Q_mac (OpenSSL 3.0 compatible)
     444                 :           0 : StreamBuffer hmac_sha256(const StreamBuffer& key, const StreamBuffer& data) {
     445                 :           0 :     constexpr size_t hash_size = 32; // SHA256 output size
     446                 :           0 :     StreamBuffer out(hash_size, 0);
     447                 :           0 :     size_t out_len = hash_size;
     448                 :             : 
     449                 :             :     // Use EVP_Q_mac with HMAC algorithm and SHA256 digest
     450                 :           0 :     EVP_Q_mac(nullptr, "HMAC", nullptr, "SHA256", nullptr, key.data(), key.size(),
     451                 :             :               data.data(), data.size(), out.data(), out.size(), &out_len);
     452                 :             : 
     453                 :           0 :     out.resize(out_len);
     454                 :           0 :     return out;
     455                 :             : }
     456                 :             : 
     457                 :             : } // anonymous namespace
     458                 :             : 
     459                 :           0 : StreamBuffer TlsConnection::prf_sha256(const StreamBuffer& secret, const char* label,
     460                 :             :                                        const StreamBuffer& data) {
     461                 :           0 :     StreamBuffer result;
     462                 :           0 :     StreamBuffer label_seed;
     463                 :           0 :     label_seed.insert(label_seed.end(), label, label + std::strlen(label));
     464                 :           0 :     label_seed.insert(label_seed.end(), data.begin(), data.end());
     465                 :             : 
     466                 :           0 :     StreamBuffer a = label_seed;
     467                 :             : 
     468                 :           0 :     while (result.size() < 48) { // Generate enough for master_secret + key
     469                 :             :                                  // expansion
     470                 :             :         // A(i) = HMAC(secret, A(i-1))
     471                 :           0 :         a = hmac_sha256(secret, a);
     472                 :             : 
     473                 :             :         // HMAC(secret, A(i) + label_seed)
     474                 :           0 :         StreamBuffer a_label_seed = a;
     475                 :           0 :         a_label_seed.insert(a_label_seed.end(), label_seed.begin(),
     476                 :             :                             label_seed.end());
     477                 :           0 :         StreamBuffer h = hmac_sha256(secret, a_label_seed);
     478                 :             : 
     479                 :           0 :         result.insert(result.end(), h.begin(), h.end());
     480                 :           0 :     }
     481                 :           0 :     return result;
     482                 :           0 : }
     483                 :             : 
     484                 :           0 : StreamBuffer TlsConnection::encrypt_aes(const StreamBuffer& plaintext) {
     485                 :           0 :     StreamBuffer ciphertext;
     486                 :           0 :     EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new();
     487                 :           0 :     if (!ctx)
     488                 :           0 :         return ciphertext;
     489                 :             : 
     490                 :             :     unsigned char iv[16];
     491                 :           0 :     std::memcpy(iv, session_iv_.data(), 16);
     492                 :             : 
     493                 :           0 :     EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), nullptr, session_key_.data(), iv);
     494                 :           0 :     int len = 0;
     495                 :           0 :     ciphertext.resize(plaintext.size() + AES_BLOCK_SIZE);
     496                 :           0 :     EVP_EncryptUpdate(ctx, ciphertext.data(), &len, plaintext.data(),
     497                 :           0 :                       static_cast<int>(plaintext.size()));
     498                 :           0 :     int ciphertext_len = len;
     499                 :           0 :     EVP_EncryptFinal_ex(ctx, ciphertext.data() + len, &len);
     500                 :           0 :     ciphertext_len += len;
     501                 :           0 :     ciphertext.resize(static_cast<size_t>(ciphertext_len));
     502                 :             : 
     503                 :           0 :     EVP_CIPHER_CTX_free(ctx);
     504                 :           0 :     return ciphertext;
     505                 :             : }
     506                 :             : 
     507                 :           0 : StreamBuffer TlsConnection::decrypt_aes(const StreamBuffer& ciphertext) {
     508                 :           0 :     StreamBuffer plaintext;
     509                 :           0 :     EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new();
     510                 :           0 :     if (!ctx)
     511                 :           0 :         return plaintext;
     512                 :             : 
     513                 :             :     unsigned char iv[16];
     514                 :           0 :     std::memcpy(iv, session_iv_.data(), 16);
     515                 :             : 
     516                 :           0 :     EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), nullptr, session_key_.data(), iv);
     517                 :           0 :     int len = 0;
     518                 :           0 :     plaintext.resize(ciphertext.size());
     519                 :           0 :     EVP_DecryptUpdate(ctx, plaintext.data(), &len, ciphertext.data(),
     520                 :           0 :                       static_cast<int>(ciphertext.size()));
     521                 :           0 :     int plaintext_len = len;
     522                 :           0 :     EVP_DecryptFinal_ex(ctx, plaintext.data() + len, &len);
     523                 :           0 :     plaintext_len += len;
     524                 :           0 :     plaintext.resize(static_cast<size_t>(plaintext_len));
     525                 :             : 
     526                 :           0 :     EVP_CIPHER_CTX_free(ctx);
     527                 :           0 :     return plaintext;
     528                 :             : }
     529                 :             : 
     530                 :          19 : void TlsConnection::set_handshake_state(TlsHandshakeState new_state) {
     531                 :          19 :     handshake_state_ = new_state;
     532                 :          19 :     if (new_state == TlsHandshakeState::Error) {
     533                 :           6 :         session_state_ = TlsSessionState::Error;
     534                 :           6 :         set_state(ConnectionState::Error);
     535                 :           6 :         HPACTOR_LOG_ERROR(log::LogCategory::kNetwork, ActorId{0}, 0,
     536                 :             :                           "TLS handshake failure");
     537                 :             :     }
     538                 :          19 : }
     539                 :             : 
     540                 :           7 : void TlsConnection::set_session_state(TlsSessionState new_state) {
     541                 :           7 :     session_state_ = new_state;
     542                 :           7 : }
     543                 :             : 
     544                 :           6 : void TlsConnection::send_raw(const StreamBuffer& data) {
     545                 :           6 :     if (fd_ < 0 || !loop_)
     546                 :           6 :         return;
     547                 :             : 
     548                 :             :     // Append data to write buffer
     549                 :           0 :     write_buffer_.append(data.data(), data.size());
     550                 :             : 
     551                 :             :     // If already sending, wait for completion
     552                 :           0 :     if (is_sending_)
     553                 :           0 :         return;
     554                 :             : 
     555                 :           0 :     flush_write_buffer();
     556                 :             : }
     557                 :             : 
     558                 :           0 : void TlsConnection::flush_write_buffer() {
     559                 :           0 :     if (fd_ < 0 || loop_ == nullptr || write_buffer_.empty()) {
     560                 :           0 :         return;
     561                 :             :     }
     562                 :             : 
     563                 :           0 :     is_sending_ = true;
     564                 :             : 
     565                 :             :     struct iovec iov;
     566                 :           0 :     iov.iov_base = write_buffer_.data();
     567                 :           0 :     iov.iov_len = write_buffer_.size();
     568                 :             : 
     569                 :             :     // Use async_send - completion will be delivered via loop's completion
     570                 :             :     // callback
     571                 :           0 :     loop_->backend()->async_send(fd_, &iov, 1, ActorId(0),
     572                 :             :                                  static_cast<uint32_t>(OpType::Send));
     573                 :             : }
     574                 :             : 
     575                 :           8 : void TlsConnection::handle_send_completion(int result) {
     576                 :           8 :     if (send_completion_handler_) {
     577                 :           4 :         send_completion_handler_(result);
     578                 :             :     }
     579                 :           8 :     is_sending_ = false;
     580                 :             : 
     581                 :           8 :     if (result < 0) {
     582                 :             :         // Send error - close connection
     583                 :           3 :         set_handshake_state(TlsHandshakeState::Error);
     584                 :           3 :         return;
     585                 :             :     }
     586                 :             : 
     587                 :             :     // Remove sent StreamBuffer from write buffer
     588                 :           5 :     if (static_cast<size_t>(result) >= write_buffer_.size()) {
     589                 :           5 :         write_buffer_.clear();
     590                 :             :     } else {
     591                 :           0 :         write_buffer_.consume(static_cast<size_t>(result));
     592                 :             :     }
     593                 :             : 
     594                 :             :     // If more data to send, continue flushing
     595                 :           5 :     if (!write_buffer_.empty()) {
     596                 :           0 :         flush_write_buffer();
     597                 :             :     }
     598                 :             : }
     599                 :             : 
     600                 :             : } // namespace net
     601                 :             : } // namespace hpactor
        

Generated by: LCOV version 2.0-1