]>
Commit | Line | Data |
---|---|---|
1 | //! # Proxy Module | |
2 | //! | |
3 | //! This module has the actual proxy functionality, exposed through | |
4 | //! `Server`. The proxy consists of a local unencrypted TCP stream | |
5 | //! and a remote TLS stream. Messages are passed between them via two | |
6 | //! threads. | |
7 | //! | |
8 | //! Each new client connection spawns two threads: | |
9 | //! - **Client to Server Thread**: Forwards data from client -> TLS server | |
//! - **Server to Client Thread**: Forwards data from TLS server -> client
11 | //! | |
//! Finally, the `Server` may be shut down by calling `.shutdown()`, which
//! stops accepting new connections and waits for the active ones to finish.
14 | //! | |
15 | //! # Example | |
16 | //! | |
//! ```ignore
//! use std::sync::Arc;
//! use crate::configuration::Proxy;
//! use crate::proxy::Server;
//!
//! let config = Arc::new(Proxy {
//!     protocol: "IMAP".to_string(),
//!     bind_address: "127.0.0.1".to_string(),
//!     local_port: 143,
//!     remote_host: "imap.example.com".to_string(),
//!     remote_port: 993,
//! });
//!
//! let mut server = Server::new(config);
//! // The server runs in a background thread. To shut down gracefully:
//! server.shutdown();
//! ```
33 | use log::{debug, error, info}; | |
34 | use native_tls::TlsConnector; | |
35 | use std::io::{ErrorKind, Read, Write}; | |
36 | use std::net::{TcpListener, TcpStream}; | |
37 | use std::sync::{ | |
38 | atomic::{AtomicBool, Ordering}, | |
39 | Arc, Mutex, | |
40 | }; | |
41 | use std::thread::{sleep, spawn, JoinHandle}; | |
42 | use std::time::Duration; | |
43 | ||
44 | use crate::configuration::Proxy; | |
45 | use crate::middleware::get_middleware; | |
46 | ||
/// A proxy server that listens for plaintext connections and forwards them
/// via TLS.
///
/// Creating a new `Server` spawns a dedicated thread that:
/// - Binds to a local port (non-blocking mode).
/// - Spawns additional threads for each incoming client connection.
/// - Manages connection-lifetime until it receives a shutdown signal.
///
/// The proxy thread is stopped only by [`Server::shutdown`]. Dropping a
/// `Server` without calling it detaches the thread, which keeps running in
/// the background.
#[derive(Debug)]
pub struct Server {
    /// Shared stop flag: `shutdown` stores `false` here to end the accept loop.
    running: Arc<AtomicBool>,
    /// Handle of the proxy thread; `None` once `shutdown` has joined it.
    thread_handle: Option<JoinHandle<()>>,
}
58 | ||
59 | impl Server { | |
60 | /// Creates a new `Server` for the given `Proxy` configuration and | |
61 | /// immediately starts it. | |
62 | /// | |
63 | /// # Arguments | |
64 | /// | |
65 | /// * `configuration` - Shared (Arc) `Proxy` | |
66 | /// | |
67 | /// # Returns | |
68 | /// | |
69 | /// A `Server` instance that will keep running until its `.shutdown()` | |
70 | /// method is called, or an error occurs. | |
71 | pub fn new(configuration: Arc<Proxy>) -> Self { | |
72 | let running = Arc::new(AtomicBool::new(true)); | |
73 | let running_clone = Arc::clone(&running); | |
74 | ||
75 | let thread_handle = spawn(move || { | |
76 | run_proxy(&configuration, &running_clone); | |
77 | }); | |
78 | ||
79 | Server { | |
80 | running, | |
81 | thread_handle: Some(thread_handle), | |
82 | } | |
83 | } | |
84 | ||
85 | /// Signals this proxy to stop accepting new connections and waits | |
86 | /// for all active connection threads to complete. | |
87 | pub fn shutdown(&mut self) { | |
88 | self.running.store(false, Ordering::SeqCst); | |
89 | if let Some(handle) = self.thread_handle.take() { | |
90 | let _ = handle.join(); | |
91 | } | |
92 | } | |
93 | } | |
94 | ||
95 | /// The main loop that listens for incoming (plaintext) connections on | |
96 | /// `configuration.bind_address:configuration.local_port`. | |
97 | fn run_proxy(configuration: &Arc<Proxy>, running: &Arc<AtomicBool>) { | |
98 | let listener = match TcpListener::bind(format!( | |
99 | "{}:{}", | |
100 | configuration.bind_address, configuration.local_port | |
101 | )) { | |
102 | Ok(l) => l, | |
103 | Err(e) => { | |
104 | error!("Failed to bind to port {}: {}", configuration.local_port, e); | |
105 | return; | |
106 | } | |
107 | }; | |
108 | listener.set_nonblocking(true).unwrap(); | |
109 | ||
110 | info!( | |
111 | "{} proxy listening on port {}:{}", | |
112 | configuration.protocol, configuration.bind_address, configuration.local_port | |
113 | ); | |
114 | ||
115 | // Keep track of active connections so we can join them on shutdown | |
116 | let mut active_threads = vec![]; | |
117 | ||
118 | while running.load(Ordering::SeqCst) { | |
119 | match listener.accept() { | |
120 | Ok((stream, address)) => { | |
121 | info!("New {} connection from {}", configuration.protocol, address); | |
122 | ||
123 | let configuration_clone = Arc::clone(configuration); | |
124 | let handle = spawn(move || { | |
125 | handle_client(stream, &configuration_clone); | |
126 | }); | |
127 | active_threads.push(handle); | |
128 | } | |
129 | Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { | |
130 | // No pending connection; sleep briefly then loop again | |
131 | sleep(Duration::from_millis(100)); | |
132 | continue; | |
133 | } | |
134 | Err(e) => { | |
135 | error!("Error accepting connection: {}", e); | |
136 | break; | |
137 | } | |
138 | } | |
139 | ||
140 | // Clean up any finished threads | |
141 | active_threads.retain(|thread| !thread.is_finished()); | |
142 | ||
143 | // Potential Improvement: Configure thread limit. | |
144 | if active_threads.len() >= 50 { | |
145 | sleep(Duration::from_millis(100)); | |
146 | } | |
147 | } | |
148 | ||
149 | // On shutdown, wait for all threads to finish | |
150 | for thread in active_threads { | |
151 | let _ = thread.join(); | |
152 | } | |
153 | } | |
154 | ||
155 | /// Handles a single client connection by bridging it (plaintext) to a TLS connection. | |
156 | fn handle_client(client_stream: TcpStream, configuration: &Arc<Proxy>) { | |
157 | if let Err(e) = client_stream.set_nonblocking(true) { | |
158 | error!("Failed to set client stream to nonblocking: {}", e); | |
159 | return; | |
160 | } | |
161 | ||
162 | let available_middleware = get_middleware(); | |
163 | let available_middleware_clone = Arc::clone(&available_middleware); | |
164 | ||
165 | let connector = match TlsConnector::new() { | |
166 | Ok(c) => c, | |
167 | Err(e) => { | |
168 | error!("Failed to create TLS connector: {}", e); | |
169 | return; | |
170 | } | |
171 | }; | |
172 | ||
173 | let remote_address = format!( | |
174 | "{}:{}", | |
175 | configuration.remote_host, configuration.remote_port | |
176 | ); | |
177 | let tcp_stream = match TcpStream::connect(&remote_address) { | |
178 | Ok(stream) => stream, | |
179 | Err(e) => { | |
180 | error!("Failed to connect to {}: {}", remote_address, e); | |
181 | return; | |
182 | } | |
183 | }; | |
184 | ||
185 | let tls_stream = match connector.connect(&configuration.remote_host, tcp_stream) { | |
186 | Ok(tls_stream) => tls_stream, | |
187 | Err(e) => { | |
188 | error!( | |
189 | "TLS handshake to {} failed: {}", | |
190 | configuration.remote_host, e | |
191 | ); | |
192 | return; | |
193 | } | |
194 | }; | |
195 | ||
196 | // The nonblocking needs to be set AFTER the TLS handshake is completed. | |
197 | // Otherwise the TLS handshake is interrupted. | |
198 | if let Err(e) = tls_stream.get_ref().set_nonblocking(true) { | |
199 | error!("Failed to set remote stream to nonblocking: {}", e); | |
200 | return; | |
201 | } | |
202 | ||
203 | let tls_stream = Arc::new(Mutex::new(tls_stream)); | |
204 | ||
205 | let client_stream_clone = match client_stream.try_clone() { | |
206 | Ok(s) => s, | |
207 | Err(e) => { | |
208 | error!("Failed to clone client stream: {}", e); | |
209 | return; | |
210 | } | |
211 | }; | |
212 | ||
213 | // Client to Server Thread | |
214 | let tls_stream_clone = Arc::clone(&tls_stream); | |
215 | let client_to_server = spawn(move || { | |
216 | debug!(">>> BEGIN"); | |
217 | let mut buffer = [0u8; 8192]; | |
218 | let mut client_reader = client_stream; | |
219 | loop { | |
220 | let bytes_read = match client_reader.read(&mut buffer) { | |
221 | Ok(0) => break, | |
222 | Ok(n) => n, | |
223 | Err(ref e) if e.kind() == ErrorKind::WouldBlock => { | |
224 | sleep(Duration::from_millis(10)); | |
225 | continue; | |
226 | } | |
227 | Err(error) => { | |
228 | debug!(">>> Error reading buffer {error}"); | |
229 | break; | |
230 | } | |
231 | }; | |
232 | ||
233 | let mut command = buffer[..bytes_read].to_vec(); | |
234 | ||
235 | if let Ok(mut guard) = available_middleware.lock() { | |
236 | for middleware in guard.iter_mut() { | |
237 | command = middleware.client_message(&command); | |
238 | } | |
239 | } | |
240 | ||
241 | let debug_str = String::from_utf8_lossy(&buffer[..bytes_read]) | |
242 | .replace('\n', "\\n") | |
243 | .replace('\r', "\\r") | |
244 | .replace('\t', "\\t"); | |
245 | debug!(">>> {}", debug_str); | |
246 | ||
247 | // Lock the TLS stream and write the data to server | |
248 | match tls_stream_clone.lock() { | |
249 | Ok(mut tls_guard) => { | |
250 | if let Err(error) = tls_guard.write_all(&command) { | |
251 | debug!(">>> Error writing to server: {error}"); | |
252 | break; | |
253 | } | |
254 | if let Err(error) = tls_guard.flush() { | |
255 | debug!(">>> Error flushing server connection: {error}"); | |
256 | break; | |
257 | } | |
258 | } | |
259 | Err(error) => { | |
260 | debug!(">>> Error acquiring TLS stream lock: {error}"); | |
261 | break; | |
262 | } | |
263 | } | |
264 | } | |
265 | }); | |
266 | ||
267 | // Server to Client Thread | |
268 | let tls_stream_clone = Arc::clone(&tls_stream); | |
269 | let server_to_client = spawn(move || { | |
270 | debug!("<<< BEGIN"); | |
271 | let mut buffer = [0u8; 8192]; | |
272 | let mut client_writer = client_stream_clone; | |
273 | loop { | |
274 | // Lock the TLS stream and read from the server | |
275 | let bytes_read = match tls_stream_clone.lock() { | |
276 | Ok(mut tls_guard) => match tls_guard.read(&mut buffer) { | |
277 | Ok(0) => break, // TLS server closed | |
278 | Ok(n) => n, | |
279 | Err(ref e) if e.kind() == ErrorKind::WouldBlock => { | |
280 | sleep(Duration::from_millis(10)); | |
281 | continue; | |
282 | } | |
283 | Err(error) => { | |
284 | debug!("<<< Error reading buffer {error}"); | |
285 | break; | |
286 | } | |
287 | }, | |
288 | Err(error) => { | |
289 | debug!("<<< Error Cloning TLS {error}"); | |
290 | break; | |
291 | } | |
292 | }; | |
293 | ||
294 | let mut command = buffer[..bytes_read].to_vec(); | |
295 | ||
296 | if let Ok(mut guard) = available_middleware_clone.lock() { | |
297 | for middleware in guard.iter_mut() { | |
298 | command = middleware.server_message(&command); | |
299 | } | |
300 | } | |
301 | ||
302 | let debug_str = String::from_utf8_lossy(&buffer[..bytes_read]) | |
303 | .replace('\n', "\\n") | |
304 | .replace('\r', "\\r") | |
305 | .replace('\t', "\\t"); | |
306 | debug!("<<< {}", debug_str); | |
307 | ||
308 | // Write decrypted data to client | |
309 | if client_writer.write_all(&command).is_err() { | |
310 | debug!("<<< ERR"); | |
311 | break; | |
312 | } | |
313 | if client_writer.flush().is_err() { | |
314 | debug!("<<< ERR"); | |
315 | break; | |
316 | } | |
317 | } | |
318 | }); | |
319 | ||
320 | // Wait for both directions to finish | |
321 | let _ = client_to_server.join(); | |
322 | let _ = server_to_client.join(); | |
323 | } |