X-Git-Url: https://git.r.bdr.sh/rbdr/olden-mail/blobdiff_plain/dc3d68214e2cd4b5f86279a7fa7b7d64c341b446..66bcf7c1640b6d5c5a04928b31062c71c5a26791:/src/proxy.rs?ds=sidebyside diff --git a/src/proxy.rs b/src/proxy.rs index ce40e6e..33de440 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,86 +1,316 @@ +//! # Proxy Module +//! +//! This module has the actual proxy functionality, exposed through +//! `Server`. The proxy consists of a local unencrypted TCP stream +//! and a remote TLS stream. Messages are passed between them via two +//! threads. +//! +//! Each new client connection spawns two threads: +//! - **Client to Server Thread**: Forwards data from client -> TLS server +//! - **Serveer to Client Thread**: Forwards data from TLS server -> client +//! +//! Finally, the `Server` may be shutdown by calling `.shutdown()`, +//! this will stop new connections and wait for it to finish. +//! +//! # Example +//! +//! ``` +//! use std::sync::Arc; +//! use crate::configuration::Proxy; +//! use crate::proxy::Server; +//! +//! let config = Arc::new(Proxy { +//! protocol: "IMAP".to_string(), +//! local_port: 143, +//! remote_domain: "imap.example.com".to_string(), +//! remote_port: 993, +//! }); +//! +//! let mut server = Server::new(config); +//! // The server runs in a background thread. To shut down gracefully: +//! server.shutdown(); +//! ``` +use log::{debug, error, info}; use native_tls::TlsConnector; -use std::io::{Read, Write}; +use std::io::{ErrorKind, Read, Write}; use std::net::{TcpListener, TcpStream}; -use std::sync::Arc; -use std::thread::{spawn, JoinHandle}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, +}; +use std::thread::{sleep, spawn, JoinHandle}; +use std::time::Duration; -use crate::configuration::ProxyConfiguration; +use crate::configuration::Proxy; +use crate::middleware::{CLIENT_MIDDLEWARE, SERVER_MIDDLEWARE}; -pub fn create_proxy(configuration: Arc) -> JoinHandle<()> { - let cloned_configuration = Arc::clone(&configuration); - spawn(move || { - run_proxy(cloned_configuration); - }) +/// A proxy server that listens for plaintext connections and forwards them +/// via TLS. +/// +/// Creating a new `Server` spawns a dedicated thread that: +/// - Binds to a local port (non-blocking mode). +/// - Spawns additional threads for each incoming client connection. +/// - Manages connection-lifetime until it receives a shutdown signal. +pub struct Server { + running: Arc, + thread_handle: Option>, } -fn run_proxy(configuration: Arc) { - let listener = TcpListener::bind(format!("0.0.0.0:{}", configuration.local_port)).unwrap(); +impl Server { + /// Creates a new `Server` for the given `Proxy` configuration and + /// immediately starts it. + /// + /// # Arguments + /// + /// * `configuration` - Shared (Arc) `Proxy` + /// + /// # Returns + /// + /// A `Server` instance that will keep running until its `.shutdown()` + /// method is called, or an error occurs. + pub fn new(configuration: Arc) -> Self { + let running = Arc::new(AtomicBool::new(true)); + let running_clone = Arc::clone(&running); - println!("Proxy listening on port {}", configuration.local_port); + let thread_handle = spawn(move || { + run_proxy(&configuration, &running_clone); + }); - for stream in listener.incoming() { - match stream { - Ok(stream) => { - let cloned_configuration = Arc::clone(&configuration); - spawn(move || { - handle_client(stream, cloned_configuration); + Server { + running, + thread_handle: Some(thread_handle), + } + } + + /// Signals this proxy to stop accepting new connections and waits + /// for all active connection threads to complete. + pub fn shutdown(&mut self) { + self.running.store(false, Ordering::SeqCst); + if let Some(handle) = self.thread_handle.take() { + let _ = handle.join(); + } + } +} + +/// The main loop that listens for incoming (plaintext) connections on +/// `configuration.bind_address:configuration.local_port`. +fn run_proxy(configuration: &Arc, running: &Arc) { + let listener = match TcpListener::bind(format!( + "{}:{}", + configuration.bind_address, configuration.local_port + )) { + Ok(l) => l, + Err(e) => { + error!("Failed to bind to port {}: {}", configuration.local_port, e); + return; + } + }; + listener.set_nonblocking(true).unwrap(); + + info!( + "{} proxy listening on port {}:{}", + configuration.protocol, configuration.bind_address, configuration.local_port + ); + + // Keep track of active connections so we can join them on shutdown + let mut active_threads = vec![]; + + while running.load(Ordering::SeqCst) { + match listener.accept() { + Ok((stream, addr)) => { + info!("New {} connection from {}", configuration.protocol, addr); + + let configuration_clone = Arc::clone(configuration); + let handle = spawn(move || { + handle_client(stream, &configuration_clone); }); + active_threads.push(handle); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // No pending connection; sleep briefly then loop again + sleep(Duration::from_millis(100)); + continue; } Err(e) => { - eprintln!("Failed to accept connection: {}", e); + error!("Error accepting connection: {}", e); + break; } } + + // Clean up any finished threads + active_threads.retain(|thread| !thread.is_finished()); + + // Potential Improvement: Configure thread limit. + if active_threads.len() >= 50 { + sleep(Duration::from_millis(100)); + } + } + + // On shutdown, wait for all threads to finish + for thread in active_threads { + let _ = thread.join(); } } -fn handle_client(mut client_stream: TcpStream, configuration: Arc) { - let connector = TlsConnector::new().unwrap(); - let remote_stream = TcpStream::connect(format!( +/// Handles a single client connection by bridging it (plaintext) to a TLS connection. +fn handle_client(client_stream: TcpStream, configuration: &Arc) { + if let Err(e) = client_stream.set_nonblocking(true) { + error!("Failed to set client stream to nonblocking: {}", e); + return; + } + + let connector = match TlsConnector::new() { + Ok(c) => c, + Err(e) => { + error!("Failed to create TLS connector: {}", e); + return; + } + }; + + let remote_addr = format!( "{}:{}", - configuration.remote_domain, configuration.remote_port - )) - .unwrap(); - let mut remote_stream = connector - .connect(&configuration.remote_domain, remote_stream) - .unwrap(); - - let mut client_stream_clone = client_stream.try_clone().unwrap(); - let mut remote_stream_clone = remote_stream.get_ref().try_clone().unwrap(); - - let cloned_configuration = Arc::clone(&configuration); - spawn(move || { - forward_stream(&mut client_stream, &mut remote_stream, cloned_configuration); - }); - forward_stream( - &mut remote_stream_clone, - &mut client_stream_clone, - configuration, + configuration.remote_host, configuration.remote_port ); -} + let tcp_stream = match TcpStream::connect(&remote_addr) { + Ok(stream) => stream, + Err(e) => { + error!("Failed to connect to {}: {}", remote_addr, e); + return; + } + }; + + let tls_stream = match connector.connect(&configuration.remote_host, tcp_stream) { + Ok(tls_stream) => tls_stream, + Err(e) => { + error!( + "TLS handshake to {} failed: {}", + configuration.remote_host, e + ); + return; + } + }; + + // The nonblocking needs to be set AFTER the TLS handshake is completed. + // Otherwise the TLS handshake is interrupted. + if let Err(e) = tls_stream.get_ref().set_nonblocking(true) { + error!("Failed to set remote stream to nonblocking: {}", e); + return; + } + + let tls_stream = Arc::new(Mutex::new(tls_stream)); -fn forward_stream( - from: &mut R, - to: &mut W, - configuration: Arc, -) { - let mut buffer = [0; 4096]; - loop { - match from.read(&mut buffer) { - Ok(0) => break, // EOF - Ok(n) => { - if let Err(e) = to.write_all(&buffer[..n]) { - eprintln!("{} proxy write error: {}", configuration.protocol, e); + let client_stream_clone = match client_stream.try_clone() { + Ok(s) => s, + Err(e) => { + error!("Failed to clone client stream: {}", e); + return; + } + }; + + // Client to Server Thread + let tls_stream_clone = Arc::clone(&tls_stream); + let client_to_server = spawn(move || { + debug!(">>> BEGIN"); + let mut buffer = [0u8; 8192]; + let mut client_reader = client_stream; + loop { + let bytes_read = match client_reader.read(&mut buffer) { + Ok(0) => break, + Ok(n) => n, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + sleep(Duration::from_millis(10)); + continue; + } + Err(error) => { + debug!(">>> Error reading buffer {error}"); break; } - if let Err(e) = to.flush() { - eprintln!("{} proxy flush error: {}", configuration.protocol, e); + }; + + let mut command = buffer[..bytes_read].to_vec(); + + for middleware in CLIENT_MIDDLEWARE { + command = middleware(&command); + } + + let debug_str = String::from_utf8_lossy(&command) + .replace('\n', "\\n") + .replace('\r', "\\r") + .replace('\t', "\\t"); + debug!(">>> {}", debug_str); + + // Lock the TLS stream and write the data to server + match tls_stream_clone.lock() { + Ok(mut tls_guard) => { + if let Err(error) = tls_guard.write_all(&command) { + debug!(">>> Error writing to server: {error}"); + break; + } + if let Err(error) = tls_guard.flush() { + debug!(">>> Error flushing server connection: {error}"); + break; + } + } + Err(error) => { + debug!(">>> Error acquiring TLS stream lock: {error}"); break; } } - Err(e) => { - eprintln!("{} proxy read error: {}", configuration.protocol, e); + } + }); + + // Server to Client Thread + let tls_stream_clone = Arc::clone(&tls_stream); + let server_to_client = spawn(move || { + debug!("<<< BEGIN"); + let mut buffer = [0u8; 8192]; + let mut client_writer = client_stream_clone; + loop { + // Lock the TLS stream and read from the server + let bytes_read = match tls_stream_clone.lock() { + Ok(mut tls_guard) => match tls_guard.read(&mut buffer) { + Ok(0) => break, // TLS server closed + Ok(n) => n, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + sleep(Duration::from_millis(10)); + continue; + } + Err(error) => { + debug!("<<< Error reading buffer {error}"); + break; + } + }, + Err(error) => { + debug!("<<< Error Cloning TLS {error}"); + break; + } + }; + + let mut command = buffer[..bytes_read].to_vec(); + + for middleware in SERVER_MIDDLEWARE { + command = middleware(&command); + } + + let debug_str = String::from_utf8_lossy(&command) + .replace('\n', "\\n") + .replace('\r', "\\r") + .replace('\t', "\\t"); + debug!("<<< {}", debug_str); + + // Write decrypted data to client + if client_writer.write_all(&command).is_err() { + debug!("<<< ERR"); + break; + } + if client_writer.flush().is_err() { + debug!("<<< ERR"); break; } } - } + }); + + // Wait for both directions to finish + let _ = client_to_server.join(); + let _ = server_to_client.join(); }