diff --git a/Cargo.lock b/Cargo.lock index 96a7984..9df2a7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "0.7.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4f55bd91a0978cbfd91c457a164bab8b4001c833b7f323132c0a4e1922dd44e" +dependencies = [ + "memchr", +] + [[package]] name = "async-attributes" version = "1.1.2" @@ -289,6 +298,12 @@ dependencies = [ "log", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.132" @@ -367,12 +382,31 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "regex" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.6.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" + [[package]] name = "simple-auth" version = "0.1.0" dependencies = [ "async-std", "json", + "lazy_static", + "regex", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 6a3e1a1..20e5bef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,8 @@ edition = "2021" [dependencies] json = "0.12.4" +lazy_static = "1.4.0" +regex = "1" [dependencies.async-std] version = "1.6" diff --git a/src/handlers/request.rs b/src/handlers/request.rs index 7fa8af4..74a7b75 100644 --- a/src/handlers/request.rs +++ b/src/handlers/request.rs @@ -4,15 +4,23 @@ //! NOTE: only few parts of the specification has been implemented use json; +use lazy_static::lazy_static; +use regex::Regex; type RequestParts = (String, String, String); -// only two methods accepted -const HTTP_METHODS: [&'static str; 2] = ["POST", "GET"]; +// TODO: put this const in a conf file ? +const HTTP_METHODS: [&'static str; 1] = ["POST"]; +const HTTP_TARGETS: [&'static str; 3] = ["/validate/", "/get/", "/refresh/"]; + +lazy_static! { + static ref HTTP_VERSION_REGEX: Regex = Regex::new("^HTTP/(1.1|2)$").unwrap(); +} #[derive(Debug)] pub enum HTTPVersion { Http1, + Http2, Unknown, } @@ -20,11 +28,23 @@ impl Into for HTTPVersion { fn into(self) -> String { match self { Self::Http1 => "HTTP/1.1".to_string(), + Self::Http2 => "HTTP/2".to_string(), Self::Unknown => "UNKNOWN".to_string(), } } } +// TODO: not really satifying... could accept `String` too +impl From<&String> for HTTPVersion { + fn from(http_version: &String) -> Self { + match http_version.as_str() { + "HTTP/1.1" => Self::Http1, + "HTTP/2" => Self::Http2, + _ => Self::Unknown, + } + } +} + #[derive(Debug)] pub struct HTTPStartLine { pub method: String, @@ -50,28 +70,55 @@ impl HTTPStartLine { return Err("unable to parse the start correctly"); } - if !Self::check_method(parts[0].to_string()) { + let method = parts[0].to_string(); + let target = parts[1].to_string(); + let version = parts[2].to_string(); + + if !Self::check_method(&method) { return Err("method validation failed, bad method"); } + if !Self::check_target(&target) { + return Err("target validation failed, unvalid target"); + } + + if !Self::check_version(&version) { + return Err("http version validation failed, unknown version"); + } + // TODO: parse correctly the different parts (using regex ?) Ok(HTTPStartLine::new( - parts[0].to_string(), - parts[1].to_string(), - HTTPVersion::Http1, + method, + target, + HTTPVersion::from(&version), )) } /// check_method checks if the start_line method is in a predefined HTTP method list - fn check_method(method: String) -> bool { + fn check_method(method: &String) -> bool { for m in HTTP_METHODS.iter() { - if m.to_string() == method { + if m.to_string() == *method { return true; } } false } + /// check_target checks if the start_line target is in a predefined HTTP target whitelist + fn check_target(target: &String) -> bool { + for t in HTTP_TARGETS.iter() { + if t.to_string() == *target { + return true; + } + } + false + } + + fn check_version(version: &String) -> bool { + println!("version : {}", version); + HTTP_VERSION_REGEX.is_match(version) + } + fn is_valid(&self) -> bool { return self.method != "" && self.target != ""; } @@ -133,48 +180,22 @@ impl HTTPRequest { HTTPRequest { start_line, body } } - /// get mandatory request informations : - /// * start_line - /// * headers - fn get_request_mandats(request_parts: Vec<&str>) -> Result<(String, String), String> { - let headers_sline: Vec<&str> = request_parts[0].split("\r\n").collect(); - match headers_sline.len() { - 0 => return Err("request does not contain start_line or headers".to_string()), - 1 => return Ok((headers_sline[0].to_string(), "".to_string())), - // TODO: check if in the spec it must be 2 or 3 ! - 2 | 3 => return Ok((headers_sline[0].to_string(), headers_sline[1].to_string())), - _ => return Err("bad start_line headers parsing".to_string()), - } - } - /// split correctly the incoming request in order to get : /// * start_line /// * headers /// * data (if exists) fn get_request_parts(request: &str) -> Result { // separate the body part from the start_line and the headers - let request_parts: Vec<&str> = request.split("\r\n\r\n").collect(); - if request_parts.len() == 0 { + let request_parts: Vec<&str> = request.split("\r\n").collect(); + println!("request_parts : {:?}", request_parts); + if request_parts.len() < 3 { return Err("request has no enough informations to be correctly parsed".to_string()); } - - match request_parts.len() { - 0 => { - return Err("request has no enough informations to be correctly parsed".to_string()) - } - 1 => match HTTPRequest::get_request_mandats(request_parts) { - Ok(v) => return Ok((v.0, v.1, "".to_string())), - Err(e) => return Err(e), - }, - 2 => { - let body = request_parts[1]; - match HTTPRequest::get_request_mandats(request_parts) { - Ok(v) => return Ok((v.0, v.1, body.to_string())), - Err(e) => return Err(e), - } - } - _ => return Err("bad incoming request, impossible to parse".to_string()), - } + Ok(( + request_parts[0].to_string(), + request_parts[1].to_string(), + request_parts[2].to_string(), + )) } /// parse parses the request by spliting the incoming request with the separator `\r\n` @@ -244,18 +265,34 @@ fn test_handle_request() { is_valid: bool, } - let test_cases: [(&str, Expect); 7] = [ + let test_cases: [(String, Expect); 10] = [ ( - "GET / HTTP/1.1\r\n\r\n", + "POST /get/ HTTP/1.1\r\n\r\n".to_string(), Expect { - start_line: "GET / HTTP/1.1".to_string(), + start_line: "POST /get/ HTTP/1.1".to_string(), body: None, is_valid: true, }, ), + ( + "POST /refresh/ HTTP/2\r\n\r\n".to_string(), + Expect { + start_line: "POST /refresh/ HTTP/2".to_string(), + body: None, + is_valid: true, + }, + ), + ( + "GET / HTTP/1.1\r\n\r\n".to_string(), + Expect { + start_line: " UNKNOWN".to_string(), + body: None, + is_valid: false, + }, + ), // intentionally add HTTP with no version number ( - "OPTIONS /admin/2 HTTP/\r\nContent-Type: application/json\r\n", + "OPTIONS /admin/2 HTTP/\r\nContent-Type: application/json\r\n".to_string(), Expect { start_line: " UNKNOWN".to_string(), body: None, @@ -263,7 +300,7 @@ fn test_handle_request() { }, ), ( - "POST HTTP", + "POST HTTP".to_string(), Expect { start_line: " UNKNOWN".to_string(), body: None, @@ -271,7 +308,7 @@ fn test_handle_request() { } ), ( - "", + "".to_string(), Expect { start_line: " UNKNOWN".to_string(), body: None, @@ -279,7 +316,7 @@ fn test_handle_request() { } ), ( - "fjlqskjd /oks?id=65 HTTP/2\r\n\r\n", + "fjlqskjd /oks?id=65 HTTP/2\r\n\r\n".to_string(), Expect { start_line: " UNKNOWN".to_string(), body: None, @@ -287,7 +324,7 @@ fn test_handle_request() { } ), ( - " ", + " ".to_string(), Expect { start_line: " UNKNOWN".to_string(), body: None, @@ -295,17 +332,25 @@ fn test_handle_request() { } ), ( - r#"lm //// skkss\r\ndkldklkdl\r\n"{"access_token":"AAAAAAAAAAAA.BBBBBBBBBB.CCCCCCCCCC","refresh_token": "DDDDDDDDDDD.EEEEEEEEEEE.FFFFF"}""#, + r#"lm //// skkss\r\ndkldklkdl\r\n"{"access_token":"AAAAAAAAAAAA.BBBBBBBBBB.CCCCCCCCCC","refresh_token": "DDDDDDDDDDD.EEEEEEEEEEE.FFFFF"}""#.to_string(), Expect { start_line: " UNKNOWN".to_string(), body: Some(r#"{"access_token":"AAAAAAAAAAAA.BBBBBBBBBB.CCCCCCCCCC","refresh_token": "DDDDDDDDDDD.EEEEEEEEEEE.FFFFF"}"#.to_string()), is_valid: false, } ), + ( + format!("{}\r\nuselessheaders\r\n{}", "POST /refresh/ HTTP/1.1", r#"{"access_token": "toto", "refresh_token": "tutu"}"#), + Expect { + start_line: "POST /refresh/ HTTP/1.1".to_string(), + body: Some(r#"{"access_token":"toto","refresh_token":"tutu"}"#.to_string()), + is_valid: true, + } + ), ]; for (request, expect) in test_cases { - let http_request = HTTPRequest::from(request); + let http_request = HTTPRequest::from(request.as_str()); println!("{:?}", http_request); assert_eq!(expect.is_valid, http_request.is_valid()); @@ -314,7 +359,7 @@ fn test_handle_request() { match http_request.body { Some(v) => { - assert_eq!(expect.body.unwrap(), v.data) + assert_eq!(expect.body.unwrap(), v.data.dump()) } None => continue, } @@ -345,13 +390,13 @@ fn test_http_method() { let test_cases: Vec<(String, bool)> = vec![ ("POST".to_string(), true), ("POST ".to_string(), false), - ("GET".to_string(), true), + ("GET".to_string(), false), ("get".to_string(), false), ("qsdqsfqsf/".to_string(), false), ("OPTIONS".to_string(), false), ]; for (method, is_valid) in test_cases { - assert_eq!(is_valid, HTTPStartLine::check_method(method)); + assert_eq!(is_valid, HTTPStartLine::check_method(&method)); } }