+//! # Proxy Module
+//!
+//! This module has the actual proxy functionality, exposed through
+//! `ProxyServer`. The proxy consists of a local unencrypted TCP stream
+//! and a remote TLS stream. Messages are passed between them via two
+//! threads.
+//!
+//! Each new client connection is handled on its own thread, which spawns
+//! two forwarding threads:
+//! - **Client to Server Thread**: Forwards data from client -> TLS server
+//! - **Server to Client Thread**: Forwards data from TLS server -> client
+//!
+//! Finally, the `ProxyServer` may be shut down by calling `.shutdown()`.
+//! This stops it from accepting new connections and then blocks until
+//! every active connection has finished.
+//!
+//! # Example
+//!
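+//! (Marked `ignore`: doctests are compiled as a separate crate, so the
+//! `crate::` paths below would not resolve there.)
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//! use crate::configuration::ProxyConfiguration;
+//! use crate::proxy::ProxyServer;
+//!
+//! let config = Arc::new(ProxyConfiguration {
+//!     protocol: "IMAP".to_string(),
+//!     bind_address: "127.0.0.1".to_string(),
+//!     local_port: 143,
+//!     remote_host: "imap.example.com".to_string(),
+//!     remote_port: 993,
+//! });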
+//!
+//! let mut server = ProxyServer::new(config);
+//! // The server runs in a background thread. To shut down gracefully:
+//! server.shutdown();
+//! ```
+use log::{debug, error, info};
use native_tls::TlsConnector;
-use std::io::{Read, Write};
+use std::io::{ErrorKind, Read, Write};
use std::net::{TcpListener, TcpStream};
-use std::sync::Arc;
-use std::thread::{spawn, JoinHandle};
+use std::sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc, Mutex,
+};
+use std::thread::{sleep, spawn, JoinHandle};
+use std::time::Duration;
use crate::configuration::ProxyConfiguration;
-pub fn create_proxy(configuration: Arc<ProxyConfiguration>) -> JoinHandle<()> {
- let cloned_configuration = Arc::clone(&configuration);
- spawn(move || {
- run_proxy(cloned_configuration);
- })
+
+/// A proxy server that accepts plaintext TCP connections and forwards each
+/// of them to the configured remote host over TLS.
+///
+/// Creating a `ProxyServer` with [`ProxyServer::new`] spawns a dedicated
+/// listener thread that:
+/// - binds to `bind_address:local_port` and polls it in non-blocking mode;
+/// - spawns a handler thread for every accepted client connection;
+/// - stops accepting once [`ProxyServer::shutdown`] is called, then waits for
+///   all handler threads to finish.
+#[derive(Debug)]
+pub struct ProxyServer {
+    /// Cleared by `shutdown()`; the listener loop keeps running while it is `true`.
+    running: Arc<AtomicBool>,
+    /// Handle of the listener thread; `None` once `shutdown()` has joined it.
+    thread_handle: Option<JoinHandle<()>>,
 }
-fn run_proxy(configuration: Arc<ProxyConfiguration>) {
- let listener = TcpListener::bind(format!("0.0.0.0:{}", configuration.local_port)).unwrap();
- println!("Proxy listening on port {}", configuration.local_port);
- for stream in listener.incoming() {
- match stream {
- Ok(stream) => {
- let cloned_configuration = Arc::clone(&configuration);
- spawn(move || {
- handle_client(stream, cloned_configuration);
- });
- }
- Err(e) => {
- eprintln!("Failed to accept connection: {}", e);
- }
- }
- }
+
+impl ProxyServer {
+    /// Creates a new `ProxyServer` for the given `ProxyConfiguration` and
+    /// immediately starts it on a background listener thread.
+    ///
+    /// Binding happens on that background thread, so a bind failure is only
+    /// reported through the log; the returned value can still be shut down
+    /// normally.
+    ///
+    /// # Arguments
+    ///
+    /// * `configuration` - Shared (Arc) `ProxyConfiguration`
+    ///
+    /// # Returns
+    ///
+    /// A `ProxyServer` instance that keeps running until its `.shutdown()`
+    /// method is called, or until the listener fails to start (for example
+    /// because the port cannot be bound).
+    pub fn new(configuration: Arc<ProxyConfiguration>) -> Self {
+        let running = Arc::new(AtomicBool::new(true));
+        let listener_running = Arc::clone(&running);
+
+        let thread_handle = spawn(move || {
+            run_proxy(configuration, listener_running);
+        });
+
+        ProxyServer {
+            running,
+            thread_handle: Some(thread_handle),
+        }
+    }
+
+    /// Signals this proxy to stop accepting new connections and waits for the
+    /// listener thread, which in turn waits for every active connection
+    /// thread, to finish.
+    ///
+    /// Active connections are not interrupted, so this call blocks until each
+    /// of them has ended. Calling `shutdown()` again afterwards is a no-op.
+    pub fn shutdown(&mut self) {
+        self.running.store(false, Ordering::SeqCst);
+        if let Some(handle) = self.thread_handle.take() {
+            if handle.join().is_err() {
+                error!("Proxy listener thread panicked");
+            }
+        }
+    }
+}
+
+/// How long the accept loop sleeps when no connection is pending, and after
+/// an accept error, before polling the listener again.
+const ACCEPT_POLL_INTERVAL: Duration = Duration::from_millis(100);
+
+/// Once this many connection threads are alive, the accept loop pauses for
+/// an extra `ACCEPT_POLL_INTERVAL` after every accepted connection. This only
+/// slows down new connections; it is not a hard cap.
+// Potential improvement: make this configurable.
+const CONNECTION_THROTTLE_THRESHOLD: usize = 50;
+
+/// The main loop that listens for incoming (plaintext) connections on
+/// `configuration.bind_address:configuration.local_port`.
+///
+/// The listener is polled in non-blocking mode so that clearing `running` is
+/// noticed within about `ACCEPT_POLL_INTERVAL`. Every accepted connection is
+/// handled by `handle_client` on its own thread; once `running` is cleared,
+/// this function joins all of those threads before returning.
+///
+/// Failing to bind or configure the listener is logged and ends the loop.
+/// Errors from individual `accept()` calls are logged and the loop keeps
+/// serving.
+fn run_proxy(configuration: Arc<ProxyConfiguration>, running: Arc<AtomicBool>) {
+    let bind_addr = format!(
+        "{}:{}",
+        configuration.bind_address, configuration.local_port
+    );
+    let listener = match TcpListener::bind(&bind_addr) {
+        Ok(listener) => listener,
+        Err(e) => {
+            error!("Failed to bind to {}: {}", bind_addr, e);
+            return;
+        }
+    };
+    // Non-blocking accept lets the loop below re-check `running` periodically
+    // instead of blocking forever inside `accept()`.
+    if let Err(e) = listener.set_nonblocking(true) {
+        error!("Failed to set listener on {} to nonblocking: {}", bind_addr, e);
+        return;
+    }
+
+    info!("{} proxy listening on {}", configuration.protocol, bind_addr);
+
+    // Keep track of active connections so we can join them on shutdown.
+    let mut active_threads: Vec<JoinHandle<()>> = Vec::new();
+
+    while running.load(Ordering::SeqCst) {
+        match listener.accept() {
+            Ok((stream, addr)) => {
+                info!("New {} connection from {}", configuration.protocol, addr);
+                let client_configuration = Arc::clone(&configuration);
+                active_threads.push(spawn(move || handle_client(stream, client_configuration)));
+            }
+            Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
+                // No pending connection; sleep briefly then poll again.
+                sleep(ACCEPT_POLL_INTERVAL);
+                continue;
+            }
+            Err(e) => {
+                // Accept errors (aborted handshakes, fd exhaustion, ...) are
+                // usually transient; keep serving instead of silently stopping
+                // the whole proxy. The sleep avoids a hot error loop.
+                error!("Error accepting connection: {}", e);
+                sleep(ACCEPT_POLL_INTERVAL);
+            }
+        }
+
+        // Drop handles of connection threads that have already finished.
+        active_threads.retain(|thread| !thread.is_finished());
+
+        if active_threads.len() >= CONNECTION_THROTTLE_THRESHOLD {
+            sleep(ACCEPT_POLL_INTERVAL);
+        }
+    }
+
+    info!(
+        "{} proxy on {} stopped accepting; waiting for {} connection thread(s)",
+        configuration.protocol,
+        bind_addr,
+        active_threads.len()
+    );
+    for thread in active_threads {
+        if thread.join().is_err() {
+            error!("A {} connection thread panicked", configuration.protocol);
+        }
+    }
 }
-fn handle_client(mut client_stream: TcpStream, configuration: Arc<ProxyConfiguration>) {
- let connector = TlsConnector::new().unwrap();
- let remote_stream = TcpStream::connect(format!(
- "{}:{}",
- configuration.remote_domain, configuration.remote_port
- ))
- .unwrap();
- let mut remote_stream = connector
- .connect(&configuration.remote_domain, remote_stream)
- .unwrap();
-
- let mut client_stream_clone = client_stream.try_clone().unwrap();
- let mut remote_stream_clone = remote_stream.get_ref().try_clone().unwrap();
-
- let cloned_configuration = Arc::clone(&configuration);
- spawn(move || {
- forward_stream(&mut client_stream, &mut remote_stream, cloned_configuration);
- });
- forward_stream(
- &mut remote_stream_clone,
- &mut client_stream_clone,
- configuration,
- );
-}
-fn forward_stream<R: Read, W: Write>(
- from: &mut R,
- to: &mut W,
- configuration: Arc<ProxyConfiguration>,
-) {
- let mut buffer = [0; 4096];
- loop {
- match from.read(&mut buffer) {
- Ok(0) => break, // EOF
- Ok(n) => {
- if let Err(e) = to.write_all(&buffer[..n]) {
- eprintln!("{} proxy write error: {}", configuration.protocol, e);
- break;
- }
- if let Err(e) = to.flush() {
- eprintln!("{} proxy flush error: {}", configuration.protocol, e);
- break;
- }
- }
- Err(e) => {
- eprintln!("{} proxy read error: {}", configuration.protocol, e);
- break;
- }
- }
- }
+
+/// How long a forwarding thread sleeps when its non-blocking socket has no
+/// data ready, or cannot accept more data yet.
+const STREAM_POLL_INTERVAL: Duration = Duration::from_millis(10);
+
+/// Size of the copy buffer used by each forwarding direction.
+const FORWARD_BUFFER_SIZE: usize = 8192;
+
+/// The TLS side of a proxied connection. It is shared by both forwarding
+/// threads because a `TlsStream` cannot be split into read and write halves.
+type SharedTlsStream = Arc<Mutex<native_tls::TlsStream<TcpStream>>>;
+
+/// Handles a single client connection by bridging it (plaintext) to a TLS
+/// connection to `configuration.remote_host:configuration.remote_port`.
+///
+/// After the TLS handshake both sockets are switched to non-blocking mode
+/// and two threads forward data, one per direction. As soon as either
+/// direction ends (EOF or error) both threads stop, which closes both
+/// connections; half-closed connections are not propagated.
+///
+/// Every failure is logged and ends the connection; nothing here panics on
+/// I/O errors.
+fn handle_client(client_stream: TcpStream, configuration: Arc<ProxyConfiguration>) {
+    if let Err(e) = client_stream.set_nonblocking(true) {
+        error!("Failed to set client stream to nonblocking: {}", e);
+        return;
+    }
+
+    let tls_stream = match connect_remote(&configuration) {
+        Some(stream) => stream,
+        None => return,
+    };
+
+    // The nonblocking needs to be set AFTER the TLS handshake is completed.
+    // Otherwise the TLS handshake is interrupted.
+    if let Err(e) = tls_stream.get_ref().set_nonblocking(true) {
+        error!("Failed to set remote stream to nonblocking: {}", e);
+        return;
+    }
+
+    let client_writer = match client_stream.try_clone() {
+        Ok(s) => s,
+        Err(e) => {
+            error!("Failed to clone client stream: {}", e);
+            return;
+        }
+    };
+    let client_reader = client_stream;
+
+    let tls: SharedTlsStream = Arc::new(Mutex::new(tls_stream));
+    // Set by whichever direction finishes first so that the other one stops
+    // too, instead of polling until its own peer eventually closes.
+    let finished = Arc::new(AtomicBool::new(false));
+
+    // Client to Server Thread
+    let c2s_tls = Arc::clone(&tls);
+    let c2s_finished = Arc::clone(&finished);
+    let client_to_server = spawn(move || {
+        pump(
+            ">>>",
+            &c2s_finished,
+            |buf| (&client_reader).read(buf),
+            |chunk| lock_tls(&c2s_tls)?.write(chunk),
+            || lock_tls(&c2s_tls)?.flush(),
+        );
+    });
+
+    // Server to Client Thread
+    let server_to_client = spawn(move || {
+        pump(
+            "<<<",
+            &finished,
+            |buf| lock_tls(&tls)?.read(buf),
+            // `Write for &TcpStream` lets both closures share the socket.
+            |chunk| (&client_writer).write(chunk),
+            || (&client_writer).flush(),
+        );
+    });
+
+    // Wait for both directions to finish.
+    if client_to_server.join().is_err() {
+        error!("{} client-to-server thread panicked", configuration.protocol);
+    }
+    if server_to_client.join().is_err() {
+        error!("{} server-to-client thread panicked", configuration.protocol);
+    }
+}
+
+/// Opens a TCP connection to `remote_host:remote_port` and completes the TLS
+/// handshake on the still-blocking socket. Logs and returns `None` on any
+/// failure.
+fn connect_remote(
+    configuration: &ProxyConfiguration,
+) -> Option<native_tls::TlsStream<TcpStream>> {
+    let connector = match TlsConnector::new() {
+        Ok(c) => c,
+        Err(e) => {
+            error!("Failed to create TLS connector: {}", e);
+            return None;
+        }
+    };
+
+    let remote_addr = format!(
+        "{}:{}",
+        configuration.remote_host, configuration.remote_port
+    );
+    let tcp_stream = match TcpStream::connect(&remote_addr) {
+        Ok(stream) => stream,
+        Err(e) => {
+            error!("Failed to connect to {}: {}", remote_addr, e);
+            return None;
+        }
+    };
+
+    match connector.connect(&configuration.remote_host, tcp_stream) {
+        Ok(tls_stream) => Some(tls_stream),
+        Err(e) => {
+            error!(
+                "TLS handshake to {} failed: {}",
+                configuration.remote_host, e
+            );
+            None
+        }
+    }
+}
+
+/// Locks the shared TLS stream, mapping a poisoned mutex to an I/O error so
+/// callers can treat it like any other connection failure.
+fn lock_tls(
+    tls: &SharedTlsStream,
+) -> std::io::Result<std::sync::MutexGuard<'_, native_tls::TlsStream<TcpStream>>> {
+    tls.lock()
+        .map_err(|_| std::io::Error::new(ErrorKind::Other, "TLS stream lock poisoned"))
+}
+
+/// Forwards one direction of a connection until EOF, an I/O error, or until
+/// the opposite direction has set `finished`.
+///
+/// `read`, `write` and `flush` are expected to operate on non-blocking
+/// streams. Each call is retried after `STREAM_POLL_INTERVAL` while it
+/// reports `WouldBlock`; because the sleep happens between calls, a closure
+/// that locks a shared stream has released it before sleeping. On exit
+/// `finished` is set so the opposite direction stops as well. `tag`
+/// prefixes every debug log line.
+fn pump(
+    tag: &str,
+    finished: &AtomicBool,
+    mut read: impl FnMut(&mut [u8]) -> std::io::Result<usize>,
+    mut write: impl FnMut(&[u8]) -> std::io::Result<usize>,
+    mut flush: impl FnMut() -> std::io::Result<()>,
+) {
+    debug!("{tag} BEGIN");
+    let mut buffer = [0u8; FORWARD_BUFFER_SIZE];
+    loop {
+        let bytes_read = match poll_io(finished, || read(&mut buffer[..])) {
+            // The opposite direction has already ended.
+            None => break,
+            Some(Ok(0)) => {
+                debug!("{tag} EOF");
+                break;
+            }
+            Some(Ok(n)) => n,
+            Some(Err(error)) => {
+                debug!("{tag} Error reading buffer {error}");
+                break;
+            }
+        };
+
+        let chunk = &buffer[..bytes_read];
+        log_payload(tag, chunk);
+
+        match write_all_polling(finished, chunk, &mut write) {
+            None => break,
+            Some(Ok(())) => {}
+            Some(Err(error)) => {
+                debug!("{tag} Error writing: {error}");
+                break;
+            }
+        }
+        match poll_io(finished, &mut flush) {
+            None => break,
+            Some(Ok(())) => {}
+            Some(Err(error)) => {
+                debug!("{tag} Error flushing: {error}");
+                break;
+            }
+        }
+    }
+    finished.store(true, Ordering::SeqCst);
+    debug!("{tag} END");
+}
+
+/// Runs a non-blocking I/O operation until it completes.
+///
+/// `op` is retried after `STREAM_POLL_INTERVAL` while it returns
+/// `WouldBlock`, and immediately on `Interrupted`. Returns `None` without
+/// calling `op` again once `finished` is set.
+fn poll_io<T>(
+    finished: &AtomicBool,
+    mut op: impl FnMut() -> std::io::Result<T>,
+) -> Option<std::io::Result<T>> {
+    loop {
+        if finished.load(Ordering::SeqCst) {
+            return None;
+        }
+        match op() {
+            Err(ref e) if e.kind() == ErrorKind::WouldBlock => sleep(STREAM_POLL_INTERVAL),
+            Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
+            result => return Some(result),
+        }
+    }
+}
+
+/// Writes all of `data` through `write`, advancing past partial writes and
+/// retrying `WouldBlock` (plain `write_all` gives up on `WouldBlock`, which
+/// non-blocking sockets return whenever the peer applies backpressure).
+/// Returns `None` if `finished` is set before everything has been written.
+fn write_all_polling(
+    finished: &AtomicBool,
+    mut data: &[u8],
+    mut write: impl FnMut(&[u8]) -> std::io::Result<usize>,
+) -> Option<std::io::Result<()>> {
+    while !data.is_empty() {
+        let result = poll_io(finished, || write(data))?;
+        match result {
+            Ok(0) => {
+                return Some(Err(std::io::Error::new(
+                    ErrorKind::WriteZero,
+                    "failed to write whole buffer",
+                )));
+            }
+            Ok(n) => data = &data[n..],
+            Err(error) => return Some(Err(error)),
+        }
+    }
+    Some(Ok(()))
+}
+
+/// Logs a chunk of proxied data at debug level with `\n`, `\r` and `\t`
+/// escaped.
+///
+/// NOTE: this is the plaintext application data and may contain credentials
+/// (e.g. IMAP `LOGIN` commands); only enable debug logging where that is
+/// acceptable.
+fn log_payload(tag: &str, data: &[u8]) {
+    // Skip the lossy UTF-8 conversion and escaping unless debug logs are on.
+    if log::log_enabled!(log::Level::Debug) {
+        let escaped = String::from_utf8_lossy(data)
+            .replace('\n', "\\n")
+            .replace('\r', "\\r")
+            .replace('\t', "\\t");
+        debug!("{} {}", tag, escaped);
+    }
+}
 }