]>
Commit | Line | Data |
---|---|---|
1 | //! # Proxy Module | |
2 | //! | |
3 | //! This module has the actual proxy functionality, exposed through | |
4 | //! `Server`. The proxy consists of a local unencrypted TCP stream | |
5 | //! and a remote TLS stream. Messages are passed between them via two | |
6 | //! threads. | |
7 | //! | |
8 | //! Each new client connection spawns two threads: | |
9 | //! - **Client to Server Thread**: Forwards data from client -> TLS server | |
//! - **Server to Client Thread**: Forwards data from TLS server -> client
11 | //! | |
//! Finally, the `Server` may be shut down by calling `.shutdown()`; this
//! stops it accepting new connections and waits for its threads to finish.
14 | //! | |
15 | //! # Example | |
16 | //! | |
17 | //! ``` | |
18 | //! use std::sync::Arc; | |
19 | //! use crate::configuration::Proxy; | |
20 | //! use crate::proxy::Server; | |
21 | //! | |
//! let config = Arc::new(Proxy {
//!     protocol: "IMAP".to_string(),
//!     bind_address: "127.0.0.1".to_string(),
//!     local_port: 143,
//!     remote_host: "imap.example.com".to_string(),
//!     remote_port: 993,
//! });
28 | //! | |
29 | //! let mut server = Server::new(config); | |
30 | //! // The server runs in a background thread. To shut down gracefully: | |
31 | //! server.shutdown(); | |
32 | //! ``` | |
33 | use log::{debug, error, info}; | |
34 | use native_tls::TlsConnector; | |
35 | use std::io::{ErrorKind, Read, Write}; | |
36 | use std::net::{TcpListener, TcpStream}; | |
37 | use std::sync::{ | |
38 | atomic::{AtomicBool, Ordering}, | |
39 | Arc, Mutex, | |
40 | }; | |
41 | use std::thread::{sleep, spawn, JoinHandle}; | |
42 | use std::time::Duration; | |
43 | ||
44 | use crate::configuration::Proxy; | |
45 | ||
46 | /// A proxy server that listens for plaintext connections and forwards them | |
47 | /// via TLS. | |
48 | /// | |
49 | /// Creating a new `Server` spawns a dedicated thread that: | |
50 | /// - Binds to a local port (non-blocking mode). | |
51 | /// - Spawns additional threads for each incoming client connection. | |
52 | /// - Manages connection-lifetime until it receives a shutdown signal. | |
53 | pub struct Server { | |
54 | running: Arc<AtomicBool>, | |
55 | thread_handle: Option<JoinHandle<()>>, | |
56 | } | |
57 | ||
58 | impl Server { | |
59 | /// Creates a new `Server` for the given `Proxy` configuration and | |
60 | /// immediately starts it. | |
61 | /// | |
62 | /// # Arguments | |
63 | /// | |
64 | /// * `configuration` - Shared (Arc) `Proxy` | |
65 | /// | |
66 | /// # Returns | |
67 | /// | |
68 | /// A `Server` instance that will keep running until its `.shutdown()` | |
69 | /// method is called, or an error occurs. | |
70 | pub fn new(configuration: Arc<Proxy>) -> Self { | |
71 | let running = Arc::new(AtomicBool::new(true)); | |
72 | let running_clone = Arc::clone(&running); | |
73 | ||
74 | let thread_handle = spawn(move || { | |
75 | run_proxy(&configuration, &running_clone); | |
76 | }); | |
77 | ||
78 | Server { | |
79 | running, | |
80 | thread_handle: Some(thread_handle), | |
81 | } | |
82 | } | |
83 | ||
84 | /// Signals this proxy to stop accepting new connections and waits | |
85 | /// for all active connection threads to complete. | |
86 | pub fn shutdown(&mut self) { | |
87 | self.running.store(false, Ordering::SeqCst); | |
88 | if let Some(handle) = self.thread_handle.take() { | |
89 | let _ = handle.join(); | |
90 | } | |
91 | } | |
92 | } | |
93 | ||
94 | /// The main loop that listens for incoming (plaintext) connections on | |
95 | /// `configuration.bind_address:configuration.local_port`. | |
96 | fn run_proxy(configuration: &Arc<Proxy>, running: &Arc<AtomicBool>) { | |
97 | let listener = match TcpListener::bind(format!( | |
98 | "{}:{}", | |
99 | configuration.bind_address, configuration.local_port | |
100 | )) { | |
101 | Ok(l) => l, | |
102 | Err(e) => { | |
103 | error!("Failed to bind to port {}: {}", configuration.local_port, e); | |
104 | return; | |
105 | } | |
106 | }; | |
107 | listener.set_nonblocking(true).unwrap(); | |
108 | ||
109 | info!( | |
110 | "{} proxy listening on port {}:{}", | |
111 | configuration.protocol, configuration.bind_address, configuration.local_port | |
112 | ); | |
113 | ||
114 | // Keep track of active connections so we can join them on shutdown | |
115 | let mut active_threads = vec![]; | |
116 | ||
117 | while running.load(Ordering::SeqCst) { | |
118 | match listener.accept() { | |
119 | Ok((stream, addr)) => { | |
120 | info!("New {} connection from {}", configuration.protocol, addr); | |
121 | ||
122 | let configuration_clone = Arc::clone(configuration); | |
123 | let handle = spawn(move || { | |
124 | handle_client(stream, &configuration_clone); | |
125 | }); | |
126 | active_threads.push(handle); | |
127 | } | |
128 | Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { | |
129 | // No pending connection; sleep briefly then loop again | |
130 | sleep(Duration::from_millis(100)); | |
131 | continue; | |
132 | } | |
133 | Err(e) => { | |
134 | error!("Error accepting connection: {}", e); | |
135 | break; | |
136 | } | |
137 | } | |
138 | ||
139 | // Clean up any finished threads | |
140 | active_threads.retain(|thread| !thread.is_finished()); | |
141 | ||
142 | // Potential Improvement: Configure thread limit. | |
143 | if active_threads.len() >= 50 { | |
144 | sleep(Duration::from_millis(100)); | |
145 | } | |
146 | } | |
147 | ||
148 | // On shutdown, wait for all threads to finish | |
149 | for thread in active_threads { | |
150 | let _ = thread.join(); | |
151 | } | |
152 | } | |
153 | ||
154 | /// Handles a single client connection by bridging it (plaintext) to a TLS connection. | |
155 | fn handle_client(client_stream: TcpStream, configuration: &Arc<Proxy>) { | |
156 | if let Err(e) = client_stream.set_nonblocking(true) { | |
157 | error!("Failed to set client stream to nonblocking: {}", e); | |
158 | return; | |
159 | } | |
160 | ||
161 | let connector = match TlsConnector::new() { | |
162 | Ok(c) => c, | |
163 | Err(e) => { | |
164 | error!("Failed to create TLS connector: {}", e); | |
165 | return; | |
166 | } | |
167 | }; | |
168 | ||
169 | let remote_addr = format!( | |
170 | "{}:{}", | |
171 | configuration.remote_host, configuration.remote_port | |
172 | ); | |
173 | let tcp_stream = match TcpStream::connect(&remote_addr) { | |
174 | Ok(stream) => stream, | |
175 | Err(e) => { | |
176 | error!("Failed to connect to {}: {}", remote_addr, e); | |
177 | return; | |
178 | } | |
179 | }; | |
180 | ||
181 | let tls_stream = match connector.connect(&configuration.remote_host, tcp_stream) { | |
182 | Ok(tls_stream) => tls_stream, | |
183 | Err(e) => { | |
184 | error!( | |
185 | "TLS handshake to {} failed: {}", | |
186 | configuration.remote_host, e | |
187 | ); | |
188 | return; | |
189 | } | |
190 | }; | |
191 | ||
192 | // The nonblocking needs to be set AFTER the TLS handshake is completed. | |
193 | // Otherwise the TLS handshake is interrupted. | |
194 | if let Err(e) = tls_stream.get_ref().set_nonblocking(true) { | |
195 | error!("Failed to set remote stream to nonblocking: {}", e); | |
196 | return; | |
197 | } | |
198 | ||
199 | let tls_stream = Arc::new(Mutex::new(tls_stream)); | |
200 | ||
201 | let client_stream_clone = match client_stream.try_clone() { | |
202 | Ok(s) => s, | |
203 | Err(e) => { | |
204 | error!("Failed to clone client stream: {}", e); | |
205 | return; | |
206 | } | |
207 | }; | |
208 | ||
209 | // Client to Server Thread | |
210 | let tls_stream_clone = Arc::clone(&tls_stream); | |
211 | let client_to_server = spawn(move || { | |
212 | debug!(">>> BEGIN"); | |
213 | let mut buffer = [0u8; 8192]; | |
214 | let mut client_reader = client_stream; | |
215 | loop { | |
216 | debug!(">"); | |
217 | let bytes_read = match client_reader.read(&mut buffer) { | |
218 | Ok(0) => break, | |
219 | Ok(n) => n, | |
220 | Err(ref e) if e.kind() == ErrorKind::WouldBlock => { | |
221 | sleep(Duration::from_millis(10)); | |
222 | continue; | |
223 | } | |
224 | Err(error) => { | |
225 | debug!(">>> Error reading buffer {error}"); | |
226 | break; | |
227 | } | |
228 | }; | |
229 | ||
230 | let debug_str = String::from_utf8_lossy(&buffer[..bytes_read]) | |
231 | .replace('\n', "\\n") | |
232 | .replace('\r', "\\r") | |
233 | .replace('\t', "\\t"); | |
234 | debug!(">>> {}", debug_str); | |
235 | ||
236 | // Lock the TLS stream and write the data to server | |
237 | match tls_stream_clone.lock() { | |
238 | Ok(mut tls_guard) => { | |
239 | if let Err(error) = tls_guard.write_all(&buffer[..bytes_read]) { | |
240 | debug!(">>> Error writing to server: {error}"); | |
241 | break; | |
242 | } | |
243 | if let Err(error) = tls_guard.flush() { | |
244 | debug!(">>> Error flushing server connection: {error}"); | |
245 | break; | |
246 | } | |
247 | } | |
248 | Err(error) => { | |
249 | debug!(">>> Error acquiring TLS stream lock: {error}"); | |
250 | break; | |
251 | } | |
252 | } | |
253 | } | |
254 | }); | |
255 | ||
256 | // Server to Client Thread | |
257 | let tls_stream_clone = Arc::clone(&tls_stream); | |
258 | let server_to_client = spawn(move || { | |
259 | debug!("<<< BEGIN"); | |
260 | let mut buffer = [0u8; 8192]; | |
261 | let mut client_writer = client_stream_clone; | |
262 | loop { | |
263 | debug!("<"); | |
264 | // Lock the TLS stream and read from the server | |
265 | let bytes_read = match tls_stream_clone.lock() { | |
266 | Ok(mut tls_guard) => match tls_guard.read(&mut buffer) { | |
267 | Ok(0) => break, // TLS server closed | |
268 | Ok(n) => n, | |
269 | Err(ref e) if e.kind() == ErrorKind::WouldBlock => { | |
270 | sleep(Duration::from_millis(10)); | |
271 | continue; | |
272 | } | |
273 | Err(error) => { | |
274 | debug!("<<< Error reading buffer {error}"); | |
275 | break; | |
276 | } | |
277 | }, | |
278 | Err(error) => { | |
279 | debug!("<<< Error Cloning TLS {error}"); | |
280 | break; | |
281 | } | |
282 | }; | |
283 | ||
284 | let debug_str = String::from_utf8_lossy(&buffer[..bytes_read]) | |
285 | .replace('\n', "\\n") | |
286 | .replace('\r', "\\r") | |
287 | .replace('\t', "\\t"); | |
288 | debug!("<<< {}", debug_str); | |
289 | ||
290 | // Write decrypted data to client | |
291 | if client_writer.write_all(&buffer[..bytes_read]).is_err() { | |
292 | debug!("<<< ERR"); | |
293 | break; | |
294 | } | |
295 | if client_writer.flush().is_err() { | |
296 | debug!("<<< ERR"); | |
297 | break; | |
298 | } | |
299 | } | |
300 | }); | |
301 | ||
302 | // Wait for both directions to finish | |
303 | let _ = client_to_server.join(); | |
304 | let _ = server_to_client.join(); | |
305 | } |