1 module scorpion.validation;
2 
import std.conv : to, ConvException;
import std.experimental.logger : traceImpl = trace;
import std.regex : PhobosRegex = Regex, regex, matchFirst;
import std.string : split, indexOf, replace;
import std.traits : isArray, Parameters;
import std.uri : decodeComponent, emailLength;

import asdf : serializeToJson, deserialize, DeserializationException;

import lighttp : ServerRequest, ServerResponse, StatusCodes;
13 
/**
 * Creates a custom validator attribute from the function `func`.
 *
 * `func` receives the value of the annotated field as its first argument,
 * followed by the extra arguments stored in the attribute, and returns
 * `true` when the value is valid.
 *
 * The attribute can be used in two ways:
 * - bare (`@Attr`), in which case the extra arguments keep their default values;
 * - instantiated with the extra arguments, optionally followed by an error
 *   message (`@Attr(args..., "message")`).
 *
 * Note that due to a linker bug `func` cannot be a lambda.
 * Example:
 * ---
 * // validates only if the value is equal to `test`, case-insensitively
 * bool validateExp(string value) {
 *    import std.string : toLower;
 *    return value.toLower() == "test";
 * }
 * alias Test = CustomValidation!validateExp;
 * ...
 * struct Form {
 * 
 *    @Test
 *    string test;
 * 
 * }
 * ---
 */
struct CustomValidation(alias func) {

	/// Marker whose presence identifies this type as a validator attribute.
	enum bool validator = true;

	/// Extra arguments forwarded to `func` after the value being validated.
	Parameters!func[1 .. $] args;

	/// Message reported when validation fails (empty unless given).
	string message;

	/// Returns whether `value` passes `func` together with the stored `args`.
	bool validate(T)(T value) {
		immutable bool passed = func(value, args);
		return passed;
	}

}
47 
/**
 * Validator attribute that accepts a string only when it is not empty,
 * i.e. its length is greater than zero.
 */
alias NotEmpty = CustomValidation!notEmptyValidator;
53 
/// Passes when `value` contains at least one code unit.
private bool notEmptyValidator(string value) {
	return value.length != 0;
}
55 
/**
 * Validator attribute that accepts a string only if its length is greater
 * than or equal to the given minimum. The length is the string's `.length`,
 * so it is counted in UTF-8 code units, not in characters.
 * Example:
 * ---
 * @Min(10)
 * string name;
 *
 * // same check, reporting "tooShort" as the error message
 * @Min(10, "tooShort")
 * string nick;
 * ---
 */
alias Min = CustomValidation!(minValidator);
65 
/// Passes when `value` holds at least `min` code units.
private bool minValidator(string value, size_t min) {
	return !(value.length < min);
}
67 
/**
 * Validator attribute that accepts a string only if its length is less
 * than or equal to the given maximum. The length is the string's `.length`,
 * so it is counted in UTF-8 code units, not in characters.
 * Example:
 * ---
 * @Max(100)
 * string name;
 *
 * // same check, reporting "tooLong" as the error message
 * @Max(100, "tooLong")
 * string nick;
 * ---
 */
alias Max = CustomValidation!(maxValidator);
77 
/// Passes when `value` holds at most `max` code units.
private bool maxValidator(string value, size_t max) {
	return !(value.length > max);
}
79 
/**
 * Validator attribute that accepts a number only when it is different
 * from 0. The value is handed to the validator as a `long`.
 */
alias NotZero = CustomValidation!notZeroValidator;
84 
/// Passes for every value except zero.
private bool notZeroValidator(long value) {
	return !(value == 0);
}
86 
/**
 * Validator attribute that accepts a string only when the entire value
 * matches the given regular expression; matching a substring is not enough.
 * Example:
 * ---
 * @Regex(`[a-zA-Z0-9]*`)
 * ---
 */
alias Regex = CustomValidation!regexValidator;
95 
/// Compiled `@Regex` patterns, keyed by the pattern text given to the
/// attribute. As a plain (non-`shared`) module-level variable it is
/// thread-local, so each thread compiles and caches its own copies.
private PhobosRegex!char[string] _cached;

/**
 * Returns whether the whole of `value` matches the regular expression `r`.
 *
 * `r` is compiled wrapped as `^(?:r)$` so that the engine itself looks for a
 * match spanning the entire input. Inspecting only the first unanchored match
 * is not enough: for the pattern `a|ab` and the input `ab` the leftmost-first
 * match is `a`, which would wrongly reject a value that does match.
 * The compiled pattern is cached per `r`.
 *
 * Throws: `std.regex.RegexException` if `r` is not a valid pattern.
 */
private bool regexValidator(string value, string r){
	PhobosRegex!char rg;
	if(auto cached = r in _cached) {
		rg = *cached;
	} else {
		// A non-capturing group keeps the numbering of the user's own groups.
		rg = regex("^(?:" ~ r ~ ")$");
		_cached[r] = rg;
	}
	auto result = matchFirst(value, rg);
	// The anchors already force a whole-input match; checking the surrounding
	// slices too keeps that guarantee independent of how `$` is interpreted.
	return !result.empty && result.pre.length == 0 && result.post.length == 0;
}
109 
/**
 * Validator attribute that accepts a string only when `std.uri.emailLength`
 * recognises the entire value, not just a prefix of it, as an e-mail address.
 */
alias Email = CustomValidation!emailValidator;
114 
/// Passes when the address found by `emailLength` spans all of `value`.
private bool emailValidator(string value) {
	// emailLength reports a negative length when no address is found.
	immutable matched = emailLength(value);
	return matched >= 0 && matched == value.length;
}
116 
/**
 * Marks a field as optional. While the field still holds its type's initial
 * value (`.init`) after the body has been read, for example because it was not
 * sent, errors reported by its validator attributes are discarded. Any other
 * value is validated normally.
 * Example:
 * ---
 * @Min(5)
 * @Optional
 * string text;
 * 
 * // ok: "" is the initial value of string, so @Min is not enforced
 * ""
 * 
 * // not ok (min is 5)
 * "test"
 * ---
 */
enum Optional;
134 
/**
 * Collects the errors found while validating a request, whether they come
 * from the validation process or were added by the controller.
 */
class Validation {

	/**
	 * A single validation error.
	 */
	static struct Error {

		/**
		 * Name of the field the error refers to, or `*` when the error is not
		 * tied to a single field (for instance a body that cannot be parsed)
		 * or concerns all of them.
		 */
		string field;

		/**
		 * Error message: either a machine-readable code (such as
		 * `conversionError` or `format`) or a human-readable sentence.
		 * For errors raised by a validator attribute this is the attribute's
		 * `message`, which is given after the validator's own arguments
		 * (e.g. `@Min(5, "tooShort")`).
		 */
		string message;

	}

	/**
	 * Errors collected so far, from the validation process or added by the
	 * controller.
	 */
	Error[] errors;

	/**
	 * Whether the validation succeeded, i.e. no error has been recorded.
	 */
	@property bool valid() pure nothrow @safe @nogc {
		return !errors.length;
	}

	/**
	 * When the validation failed, sets the response status to
	 * `400 Bad Request` and writes this object, serialized to JSON, as the
	 * response body. Does nothing when `valid` is true.
	 */
	void apply(ServerResponse response) {
		if(valid) return;
		response.status = StatusCodes.badRequest;
		response.body_ = serializeToJson(this);
	}

}
188 
/**
 * Decodes one `application/x-www-form-urlencoded` value: every `+` becomes a
 * space, then percent-escapes are resolved with `decodeComponent`, which
 * throws `std.uri.URIException` on malformed escapes.
 */
private string convertValue(string value) {
	// Replace '+' first so that an encoded plus sign ("%2B") survives decoding.
	auto spaced = value.replace("+", " ");
	return decodeComponent(spaced);
}
192 
/**
 * Reads the query parameter `param` from the request URL and converts it to `T`.
 *
 * Values are URL-decoded (`+` as space, then percent-decoding) and converted
 * with `std.conv.to`.
 * - For array types other than strings, every occurrence of the parameter is
 *   converted to the element type and appended in order. An absent parameter
 *   yields an empty array without error.
 * - For any other type, only the first occurrence is used and the parameter is
 *   required.
 *
 * Params:
 *   param = name of the query parameter
 *   request = request whose URL is read
 *   response = its status is set to `400 Bad Request` when the parameter is
 *              missing (non-array `T` only), is not correctly percent-encoded,
 *              or cannot be converted
 * Returns: the converted value, or `T.init` on failure.
 */
T validateParam(T)(string param, ServerRequest request, ServerResponse response) {
	import std.traits : isSomeString;
	import std.uri : URIException;

	// NOTE(review): assumes an absent parameter indexes to an empty list rather
	// than throwing; confirm against lighttp's URL type.
	auto values = request.url.queryParams[param];
	// Strings of any character width are single values, not arrays of chars.
	static if(isArray!T && !isSomeString!T && !is(T : string)) {
		T ret;
		foreach(value ; values) {
			try {
				ret ~= to!(typeof(T.init[0]))(convertValue(value));
			} catch(ConvException) {
				response.status = StatusCodes.badRequest;
				return T.init;
			} catch(URIException) {
				// A malformed %-escape is a client error, like a failed conversion.
				response.status = StatusCodes.badRequest;
				return T.init;
			}
		}
		return ret;
	} else {
		if(values.length) {
			try {
				return to!T(convertValue(values[0]));
			} catch(ConvException) {
				// fall through to the bad-request response below
			} catch(URIException) {
				// fall through to the bad-request response below
			}
		}
		response.status = StatusCodes.badRequest;
		return T.init;
	}
}
216 
/// Request body formats understood by `validateBody`.
enum ContentType {
	formUrlencoded, /// `application/x-www-form-urlencoded`, also used as the fallback
	json            /// `application/json`
}
223 
/**
 * Builds a `T` from the request body and checks it against the validator
 * attributes declared on its members.
 *
 * The body format is chosen from the `content-type` header. A media type of
 * `application/json` (compared case-insensitively, ignoring parameters such
 * as `; charset=utf-8`) is deserialized with asdf. Anything else, including a
 * missing header, is parsed as `application/x-www-form-urlencoded`.
 *
 * Form body handling:
 * - Each `key=value` pair whose key names an assignable member of `T` is
 *   URL-decoded and converted with `std.conv.to`.
 * - Array members (other than strings) receive one element per occurrence of
 *   the key, as in `validateParam`. Other members take the value as a whole.
 * - Unknown keys and empty `&`-separated segments are ignored.
 *
 * Validator attributes (see `CustomValidation`) are then run on every
 * assignable member. Errors of an `@Optional` member are dropped while the
 * member still equals its type's `.init`.
 *
 * Params:
 *   request = request whose headers and body are read
 *   response = not used by this function
 *   validation = receives every error found (`field` is the member name, or
 *                `*` for errors concerning the whole body)
 * Returns: the decoded value, or `T.init` (`null` for classes) when the body
 *   cannot be parsed or a value cannot be decoded or converted. Validator
 *   failures do not change the returned value; they are only recorded in
 *   `validation`.
 */
T validateBody(T)(ServerRequest request, ServerResponse response, Validation validation) {
	import std.string : strip;
	import std.traits : isSomeString;
	import std.uni : toLower;
	import std.uri : URIException;

	T ret;
	static if(is(T == class)) ret = new T();
	auto contentType = {
		auto ptr = "content-type" in request.headers;
		if(ptr is null) return ContentType.formUrlencoded;
		// Media types are case-insensitive and may carry parameters
		// ("application/json; charset=utf-8"), so compare the bare type only.
		immutable semicolon = (*ptr).indexOf(';');
		auto mediaType = semicolon < 0 ? *ptr : (*ptr)[0..semicolon];
		switch(mediaType.strip().toLower()) with(ContentType) {
			case "application/json": return json;
			default: return formUrlencoded;
		}
	}();
	if(contentType == ContentType.formUrlencoded) {
		foreach(keyValue ; split(request.body_, "&")) {
			// Empty segments ("a=1&&b=2", or a trailing '&') carry no field.
			if(keyValue.length == 0) continue;
			immutable eq = keyValue.indexOf("=");
			if(eq <= 0) {
				trace(request, "malformatted x-www-form-urlencoded");
				validation.errors ~= Validation.Error("*", "format");
				return T.init;
			}
			immutable key = keyValue[0..eq];
			foreach(immutable member ; __traits(allMembers, T)) {
				static if(__traits(compiles, mixin("ret." ~ member ~ "=typeof(ret." ~ member ~ ").init"))) {
					if(key == member) {
						alias M = typeof(__traits(getMember, T, member));
						string value;
						try {
							value = convertValue(keyValue[eq+1..$]);
						} catch(URIException) {
							// Malformed %-escapes in untrusted input must not escape as an exception.
							trace(request, "field ", member, " is not correctly percent-encoded");
							validation.errors ~= Validation.Error(key, "conversionError");
							return T.init;
						}
						try {
							static if(isArray!M && !isSomeString!M && !is(M : string)) {
								// One element per occurrence of the key, as in validateParam.
								mixin("ret." ~ member ~ " ~= to!(typeof(M.init[0]))(value);");
							} else {
								// Strings and scalars take the whole value.
								mixin("ret." ~ member ~ " = to!M(value);");
							}
						} catch(ConvException) {
							trace(request, "field ", member, " with value '", value, "' cannot be converted to ", M.stringof);
							validation.errors ~= Validation.Error(key, "conversionError");
							return T.init;
						}
					}
				}
			}
		}
	} else if(contentType == ContentType.json) {
		// NOTE(review): confirm that asdf reports JSON syntax errors through
		// DeserializationException; any other exception type escapes this handler.
		try ret = deserialize!T(request.body_);
		catch(DeserializationException) {
			trace(request, "invalid JSON");
			validation.errors ~= Validation.Error("*", "format");
			return T.init;
		}
	}
	// Run the validator attributes of every assignable member.
	foreach(immutable member ; __traits(allMembers, T)) {
		static if(__traits(compiles, mixin("ret." ~ member ~ "=T.init." ~ member))) {{
			Validation.Error[] errors;
			bool optional = false;
			foreach(attr ; __traits(getAttributes, __traits(getMember, T, member))) {
				static if(is(attr == Optional)) {
					optional = true;
				} else static if(is(typeof(attr.validator))) {
					static if(__traits(compiles, attr())) {
						// Bare attribute (e.g. `@NotEmpty`): a type, default-constructed here.
						if(!attr().validate(mixin("ret." ~ member))) {
							errors ~= Validation.Error(member, "");
						}
					} else {
						// Instantiated attribute (e.g. `@Min(5, "msg")`): carries its own message.
						if(!attr.validate(mixin("ret." ~ member))) {
							errors ~= Validation.Error(member, attr.message);
						}
					}
				}
			}
			// Errors of an @Optional member still at its initial value are not reported.
			if(!optional || mixin("ret." ~ member) != typeof(mixin("ret." ~ member)).init) validation.errors ~= errors;
		}}
	}
	return ret;
}
295 
/**
 * Emits a trace-level log line for `request`: its address, HTTP method and
 * URL path, followed by `args`.
 */
private void trace(Args...)(ServerRequest request, Args args) {
	traceImpl(
		request.address, " ",
		request.method, " ",
		request.url.path, ": ",
		args
	);
}