]> git.r.bdr.sh - rbdr/olden-mail/blobdiff - src/proxy.rs
Apply formatting
[rbdr/olden-mail] / src / proxy.rs
index ce40e6ed1ebb1bc5dec98fa693d35141589b7ead..33de44039615959fdc360b56767d601bbbba58d7 100644 (file)
+//! # Proxy Module
+//!
+//! This module has the actual proxy functionality, exposed through
+//! `Server`. The proxy consists of a local unencrypted TCP stream
+//! and a remote TLS stream. Messages are passed between them via two
+//! threads.
+//!
+//! Each new client connection spawns two threads:
+//! - **Client to Server Thread**: Forwards data from client -> TLS server
+//! - **Server to Client Thread**: Forwards data from TLS server -> client
+//!
+//! Finally, the `Server` may be shut down by calling `.shutdown()`;
+//! this stops accepting new connections and waits for the active
+//! connection threads to finish.
+//!
+//! # Example
+//!
+//! ```
+//! use std::sync::Arc;
+//! use crate::configuration::Proxy;
+//! use crate::proxy::Server;
+//!
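+//! let config = Arc::new(Proxy {
+//!     protocol: "IMAP".to_string(),
+//!     // NOTE(review): `bind_address` is assumed to be a `String`; confirm
+//!     // against the `configuration::Proxy` definition.
+//!     bind_address: "127.0.0.1".to_string(),
+//!     local_port: 143,
+//!     remote_host: "imap.example.com".to_string(),
+//!     remote_port: 993,
+//! });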
+//!
+//! let mut server = Server::new(config);
+//! // The server runs in a background thread. To shut down gracefully:
+//! server.shutdown();
+//! ```
+use log::{debug, error, info};
 use native_tls::TlsConnector;
 use native_tls::TlsConnector;
-use std::io::{Read, Write};
+use std::io::{ErrorKind, Read, Write};
 use std::net::{TcpListener, TcpStream};
 use std::net::{TcpListener, TcpStream};
-use std::sync::Arc;
-use std::thread::{spawn, JoinHandle};
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc, Mutex,
+};
+use std::thread::{sleep, spawn, JoinHandle};
+use std::time::Duration;
 
 
-use crate::configuration::ProxyConfiguration;
+use crate::configuration::Proxy;
+use crate::middleware::{CLIENT_MIDDLEWARE, SERVER_MIDDLEWARE};
 
 
-pub fn create_proxy(configuration: Arc<ProxyConfiguration>) -> JoinHandle<()> {
-    let cloned_configuration = Arc::clone(&configuration);
-    spawn(move || {
-        run_proxy(cloned_configuration);
-    })
+/// A proxy server that listens for plaintext connections and forwards them
+/// via TLS.
+///
+/// Creating a new `Server` spawns a dedicated thread that:
+/// - Binds to a local port (non-blocking mode).
+/// - Spawns additional threads for each incoming client connection.
+/// - Manages connection-lifetime until it receives a shutdown signal.
+pub struct Server {
+    running: Arc<AtomicBool>,
+    thread_handle: Option<JoinHandle<()>>,
 }
 
 }
 
-fn run_proxy(configuration: Arc<ProxyConfiguration>) {
-    let listener = TcpListener::bind(format!("0.0.0.0:{}", configuration.local_port)).unwrap();
+impl Server {
+    /// Creates a new `Server` for the given `Proxy` configuration and
+    /// immediately starts it.
+    ///
+    /// # Arguments
+    ///
+    /// * `configuration` - Shared (Arc) `Proxy`
+    ///
+    /// # Returns
+    ///
+    /// A `Server` instance that will keep running until its `.shutdown()`
+    /// method is called, or an error occurs.
+    pub fn new(configuration: Arc<Proxy>) -> Self {
+        let running = Arc::new(AtomicBool::new(true));
+        let running_clone = Arc::clone(&running);
 
 
-    println!("Proxy listening on port {}", configuration.local_port);
+        let thread_handle = spawn(move || {
+            run_proxy(&configuration, &running_clone);
+        });
 
 
-    for stream in listener.incoming() {
-        match stream {
-            Ok(stream) => {
-                let cloned_configuration = Arc::clone(&configuration);
-                spawn(move || {
-                    handle_client(stream, cloned_configuration);
+        Server {
+            running,
+            thread_handle: Some(thread_handle),
+        }
+    }
+
+    /// Signals this proxy to stop accepting new connections and waits
+    /// for all active connection threads to complete.
+    pub fn shutdown(&mut self) {
+        self.running.store(false, Ordering::SeqCst);
+        if let Some(handle) = self.thread_handle.take() {
+            let _ = handle.join();
+        }
+    }
+}
+
+/// The main loop that listens for incoming (plaintext) connections on
+/// `configuration.bind_address:configuration.local_port`.
+fn run_proxy(configuration: &Arc<Proxy>, running: &Arc<AtomicBool>) {
+    let listener = match TcpListener::bind(format!(
+        "{}:{}",
+        configuration.bind_address, configuration.local_port
+    )) {
+        Ok(l) => l,
+        Err(e) => {
+            error!("Failed to bind to port {}: {}", configuration.local_port, e);
+            return;
+        }
+    };
+    listener.set_nonblocking(true).unwrap();
+
+    info!(
+        "{} proxy listening on port {}:{}",
+        configuration.protocol, configuration.bind_address, configuration.local_port
+    );
+
+    // Keep track of active connections so we can join them on shutdown
+    let mut active_threads = vec![];
+
+    while running.load(Ordering::SeqCst) {
+        match listener.accept() {
+            Ok((stream, addr)) => {
+                info!("New {} connection from {}", configuration.protocol, addr);
+
+                let configuration_clone = Arc::clone(configuration);
+                let handle = spawn(move || {
+                    handle_client(stream, &configuration_clone);
                 });
                 });
+                active_threads.push(handle);
+            }
+            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+                // No pending connection; sleep briefly then loop again
+                sleep(Duration::from_millis(100));
+                continue;
             }
             Err(e) => {
             }
             Err(e) => {
-                eprintln!("Failed to accept connection: {}", e);
+                error!("Error accepting connection: {}", e);
+                break;
             }
         }
             }
         }
+
+        // Clean up any finished threads
+        active_threads.retain(|thread| !thread.is_finished());
+
+        // Potential Improvement: Configure thread limit.
+        if active_threads.len() >= 50 {
+            sleep(Duration::from_millis(100));
+        }
+    }
+
+    // On shutdown, wait for all threads to finish
+    for thread in active_threads {
+        let _ = thread.join();
     }
 }
 
     }
 }
 
-fn handle_client(mut client_stream: TcpStream, configuration: Arc<ProxyConfiguration>) {
-    let connector = TlsConnector::new().unwrap();
-    let remote_stream = TcpStream::connect(format!(
+/// Handles a single client connection by bridging it (plaintext) to a TLS connection.
+fn handle_client(client_stream: TcpStream, configuration: &Arc<Proxy>) {
+    if let Err(e) = client_stream.set_nonblocking(true) {
+        error!("Failed to set client stream to nonblocking: {}", e);
+        return;
+    }
+
+    let connector = match TlsConnector::new() {
+        Ok(c) => c,
+        Err(e) => {
+            error!("Failed to create TLS connector: {}", e);
+            return;
+        }
+    };
+
+    let remote_addr = format!(
         "{}:{}",
         "{}:{}",
-        configuration.remote_domain, configuration.remote_port
-    ))
-    .unwrap();
-    let mut remote_stream = connector
-        .connect(&configuration.remote_domain, remote_stream)
-        .unwrap();
-
-    let mut client_stream_clone = client_stream.try_clone().unwrap();
-    let mut remote_stream_clone = remote_stream.get_ref().try_clone().unwrap();
-
-    let cloned_configuration = Arc::clone(&configuration);
-    spawn(move || {
-        forward_stream(&mut client_stream, &mut remote_stream, cloned_configuration);
-    });
-    forward_stream(
-        &mut remote_stream_clone,
-        &mut client_stream_clone,
-        configuration,
+        configuration.remote_host, configuration.remote_port
     );
     );
-}
+    let tcp_stream = match TcpStream::connect(&remote_addr) {
+        Ok(stream) => stream,
+        Err(e) => {
+            error!("Failed to connect to {}: {}", remote_addr, e);
+            return;
+        }
+    };
+
+    let tls_stream = match connector.connect(&configuration.remote_host, tcp_stream) {
+        Ok(tls_stream) => tls_stream,
+        Err(e) => {
+            error!(
+                "TLS handshake to {} failed: {}",
+                configuration.remote_host, e
+            );
+            return;
+        }
+    };
+
+    // Non-blocking mode must only be enabled AFTER the TLS handshake has
+    // completed; otherwise the handshake is interrupted.
+    if let Err(e) = tls_stream.get_ref().set_nonblocking(true) {
+        error!("Failed to set remote stream to nonblocking: {}", e);
+        return;
+    }
+
+    let tls_stream = Arc::new(Mutex::new(tls_stream));
 
 
-fn forward_stream<R: Read, W: Write>(
-    from: &mut R,
-    to: &mut W,
-    configuration: Arc<ProxyConfiguration>,
-) {
-    let mut buffer = [0; 4096];
-    loop {
-        match from.read(&mut buffer) {
-            Ok(0) => break, // EOF
-            Ok(n) => {
-                if let Err(e) = to.write_all(&buffer[..n]) {
-                    eprintln!("{} proxy write error: {}", configuration.protocol, e);
+    let client_stream_clone = match client_stream.try_clone() {
+        Ok(s) => s,
+        Err(e) => {
+            error!("Failed to clone client stream: {}", e);
+            return;
+        }
+    };
+
+    // Client to Server Thread
+    let tls_stream_clone = Arc::clone(&tls_stream);
+    let client_to_server = spawn(move || {
+        debug!(">>> BEGIN");
+        let mut buffer = [0u8; 8192];
+        let mut client_reader = client_stream;
+        loop {
+            let bytes_read = match client_reader.read(&mut buffer) {
+                Ok(0) => break,
+                Ok(n) => n,
+                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
+                    sleep(Duration::from_millis(10));
+                    continue;
+                }
+                Err(error) => {
+                    debug!(">>> Error reading buffer {error}");
                     break;
                 }
                     break;
                 }
-                if let Err(e) = to.flush() {
-                    eprintln!("{} proxy flush error: {}", configuration.protocol, e);
+            };
+
+            let mut command = buffer[..bytes_read].to_vec();
+
+            for middleware in CLIENT_MIDDLEWARE {
+                command = middleware(&command);
+            }
+
+            let debug_str = String::from_utf8_lossy(&command)
+                .replace('\n', "\\n")
+                .replace('\r', "\\r")
+                .replace('\t', "\\t");
+            debug!(">>> {}", debug_str);
+
+            // Lock the TLS stream and write the data to server
+            match tls_stream_clone.lock() {
+                Ok(mut tls_guard) => {
+                    if let Err(error) = tls_guard.write_all(&command) {
+                        debug!(">>> Error writing to server: {error}");
+                        break;
+                    }
+                    if let Err(error) = tls_guard.flush() {
+                        debug!(">>> Error flushing server connection: {error}");
+                        break;
+                    }
+                }
+                Err(error) => {
+                    debug!(">>> Error acquiring TLS stream lock: {error}");
                     break;
                 }
             }
                     break;
                 }
             }
-            Err(e) => {
-                eprintln!("{} proxy read error: {}", configuration.protocol, e);
+        }
+    });
+
+    // Server to Client Thread
+    let tls_stream_clone = Arc::clone(&tls_stream);
+    let server_to_client = spawn(move || {
+        debug!("<<< BEGIN");
+        let mut buffer = [0u8; 8192];
+        let mut client_writer = client_stream_clone;
+        loop {
+            // Lock the TLS stream and read from the server
+            let bytes_read = match tls_stream_clone.lock() {
+                Ok(mut tls_guard) => match tls_guard.read(&mut buffer) {
+                    Ok(0) => break, // TLS server closed
+                    Ok(n) => n,
+                    Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
+                        sleep(Duration::from_millis(10));
+                        continue;
+                    }
+                    Err(error) => {
+                        debug!("<<< Error reading buffer {error}");
+                        break;
+                    }
+                },
+                Err(error) => {
+                    debug!("<<< Error Cloning TLS {error}");
+                    break;
+                }
+            };
+
+            let mut command = buffer[..bytes_read].to_vec();
+
+            for middleware in SERVER_MIDDLEWARE {
+                command = middleware(&command);
+            }
+
+            let debug_str = String::from_utf8_lossy(&command)
+                .replace('\n', "\\n")
+                .replace('\r', "\\r")
+                .replace('\t', "\\t");
+            debug!("<<< {}", debug_str);
+
+            // Write decrypted data to client
+            if client_writer.write_all(&command).is_err() {
+                debug!("<<< ERR");
+                break;
+            }
+            if client_writer.flush().is_err() {
+                debug!("<<< ERR");
                 break;
             }
         }
                 break;
             }
         }
-    }
+    });
+
+    // Wait for both directions to finish
+    let _ = client_to_server.join();
+    let _ = server_to_client.join();
 }
 }