Crow  1.1
A C++ microframework for the web
 
Loading...
Searching...
No Matches
session.h
#pragma once

#include "crow/http_request.h"
#include "crow/http_response.h"
#include "crow/json.h"
#include "crow/utility.h"
#include "crow/middlewares/cookie_parser.h"

#include <unordered_map>
#include <unordered_set>
#include <set>
#include <queue>
#include <vector>

#include <memory>
#include <string>
#include <cstdio>
#include <cstdint>
#include <limits>
#include <mutex>
#include <stdexcept>
#include <utility>

#include <fstream>
#include <sstream>

#include <type_traits>
#include <functional>
#include <chrono>

#include <variant>
namespace
{
    // Convert all integer values (except bool) to int64_t so every integral
    // type maps onto the single int64_t alternative of the session variant.
    // uint64_t is deliberately left alone: converting it could overflow, and
    // since the variant has no uint64_t alternative its use fails to compile.
    template<typename T>
    using wrap_integral_t = typename std::conditional<
      std::is_integral<T>::value && !std::is_same<bool, T>::value
        && !std::is_same<uint64_t, T>::value,
      int64_t, T>::type;

    // Convert char[] / const char[] / char* / const char* to std::string.
    // String literals passed by value decay to `const char*`; without mapping
    // that case too, assigning one to the variant could select the `bool`
    // alternative (pointer-to-bool conversion) on older standard libraries.
    template<typename T>
    using wrap_char_t = typename std::conditional<
      std::is_same<typename std::decay<T>::type, char*>::value
        || std::is_same<typename std::decay<T>::type, const char*>::value,
      std::string, T>::type;

    // Upgrade to correct type for multi_variant use
    template<typename T>
    using wrap_mv_t = wrap_char_t<wrap_integral_t<T>>;
} // namespace
48
49namespace crow
50{
51 namespace session
52 {
53
54 using multi_value_types = black_magic::S<bool, int64_t, double, std::string>;
55
56 /// A multi_value is a safe variant wrapper with json conversion support
58 {
59 json::wvalue json() const
60 {
61 // clang-format off
62 return std::visit([](auto arg) {
63 return json::wvalue(arg);
64 }, v_);
65 // clang-format on
66 }
67
68 static multi_value from_json(const json::rvalue&);
69
70 std::string string() const
71 {
72 // clang-format off
73 return std::visit([](auto arg) {
74 if constexpr (std::is_same_v<decltype(arg), std::string>)
75 return arg;
76 else
77 return std::to_string(arg);
78 }, v_);
79 // clang-format on
80 }
81
82 template<typename T, typename RT = wrap_mv_t<T>>
83 RT get(const T& fallback)
84 {
85 if (const RT* val = std::get_if<RT>(&v_)) return *val;
86 return fallback;
87 }
88
89 template<typename T, typename RT = wrap_mv_t<T>>
90 void set(T val)
91 {
92 v_ = RT(std::move(val));
93 }
94
95 typename multi_value_types::rebind<std::variant> v_;
96 };
97
98 inline multi_value multi_value::from_json(const json::rvalue& rv)
99 {
100 using namespace json;
101 switch (rv.t())
102 {
103 case type::Number:
104 {
105 if (rv.nt() == num_type::Floating_point || rv.nt() == num_type::Double_precision_floating_point)
106 return multi_value{rv.d()};
107 else if (rv.nt() == num_type::Unsigned_integer)
108 return multi_value{int64_t(rv.u())};
109 else
110 return multi_value{rv.i()};
111 }
112 case type::False: return multi_value{false};
113 case type::True: return multi_value{true};
114 case type::String: return multi_value{std::string(rv)};
115 default: return multi_value{false};
116 }
117 }
118
/// Expiration tracker keeps track of soonest-to-expire keys.
///
/// Every key is stored twice: in a set ordered by (time, key), to find the
/// soonest expiration, and in a key -> time map, to locate a key's entry.
struct ExpirationTracker
{
    using DataPair = std::pair<uint64_t /*time*/, std::string /*key*/>;

    /// Add key with time to tracker.
    /// If the key is already present, its time is updated.
    void add(std::string key, uint64_t time)
    {
        remove(key); // drops any previous entry; no-op for new keys
        times_[key] = time;
        queue_.insert({time, std::move(key)});
    }

    /// Stop tracking `key`. Unknown keys are ignored.
    void remove(const std::string& key)
    {
        auto it = times_.find(key);
        if (it == times_.end()) return;
        queue_.erase({it->second, key});
        times_.erase(it);
    }

    /// Get expiration time of soonest-to-expire entry,
    /// or the maximum uint64_t value when nothing is tracked.
    uint64_t peek_first() const
    {
        if (queue_.empty()) return std::numeric_limits<uint64_t>::max();
        return queue_.begin()->first;
    }

    /// Remove and return the soonest-to-expire key.
    /// Returns an empty string when nothing is tracked.
    std::string pop_first()
    {
        if (queue_.empty()) return {};
        auto first = queue_.begin();
        std::string key = first->second;
        times_.erase(key);
        queue_.erase(first);
        return key;
    }

    using iterator = std::set<DataPair>::const_iterator;

    /// Iterate over (time, key) pairs in ascending time order.
    iterator begin() const { return queue_.cbegin(); }

    iterator end() const { return queue_.cend(); }

private:
    std::set<DataPair> queue_;
    std::unordered_map<std::string, uint64_t> times_;
};
170
171 /// CachedSessions are shared across requests
173 {
174 std::string session_id;
175 std::string requested_session_id; // session hasn't been created yet, but a key was requested
176
177 std::unordered_map<std::string, multi_value> entries;
178 std::unordered_set<std::string> dirty; // values that were changed after last load
179
180 void* store_data;
181 bool requested_refresh;
182
183 // number of references held - used for correctly destroying the cache.
184 // No need to be atomic, all SessionMiddleware accesses are synchronized
185 int referrers;
186 std::recursive_mutex mutex;
187 };
188 } // namespace session
189
190 // SessionMiddleware allows storing securely and easily small snippets of user information
191 template<typename Store>
193 {
194 using lock = std::scoped_lock<std::mutex>;
195 using rc_lock = std::scoped_lock<std::recursive_mutex>;
196
197 struct context
198 {
199 // Get a mutex for locking this session
200 std::recursive_mutex& mutex()
201 {
202 check_node();
203 return node->mutex;
204 }
205
206 // Check whether this session is already present
207 bool exists() { return bool(node); }
208
209 // Get a value by key or fallback if it doesn't exist or is of another type
210 template<typename F>
211 auto get(const std::string& key, const F& fallback = F())
212 // This trick lets the multi_value deduce the return type from the fallback
213 // which allows both:
214 // context.get<std::string>("key")
215 // context.get("key", "") -> char[] is transformed into string by multivalue
216 // to return a string
217 -> decltype(std::declval<session::multi_value>().get<F>(std::declval<F>()))
218 {
219 if (!node) return fallback;
220 rc_lock l(node->mutex);
221
222 auto it = node->entries.find(key);
223 if (it != node->entries.end()) return it->second.get<F>(fallback);
224 return fallback;
225 }
226
227 // Set a value by key
228 template<typename T>
229 void set(const std::string& key, T value)
230 {
231 check_node();
232 rc_lock l(node->mutex);
233
234 node->dirty.insert(key);
235 node->entries[key].set(std::move(value));
236 }
237
238 bool contains(const std::string& key)
239 {
240 if (!node) return false;
241 return node->entries.find(key) != node->entries.end();
242 }
243
244 // Atomically mutate a value with a function
245 template<typename Func>
246 void apply(const std::string& key, const Func& f)
247 {
248 using traits = utility::function_traits<Func>;
249 using arg = typename std::decay<typename traits::template arg<0>>::type;
250 using retv = typename std::decay<typename traits::result_type>::type;
251 check_node();
252 rc_lock l(node->mutex);
253 node->dirty.insert(key);
254 node->entries[key].set<retv>(f(node->entries[key].get(arg{})));
255 }
256
257 // Remove a value from the session
258 void remove(const std::string& key)
259 {
260 if (!node) return;
261 rc_lock l(node->mutex);
262 node->dirty.insert(key);
263 node->entries.erase(key);
264 }
265
266 // Format value by key as a string
267 std::string string(const std::string& key)
268 {
269 if (!node) return "";
270 rc_lock l(node->mutex);
271
272 auto it = node->entries.find(key);
273 if (it != node->entries.end()) return it->second.string();
274 return "";
275 }
276
277 // Get a list of keys present in session
278 std::vector<std::string> keys()
279 {
280 if (!node) return {};
281 rc_lock l(node->mutex);
282
283 std::vector<std::string> out;
284 for (const auto& p : node->entries)
285 out.push_back(p.first);
286 return out;
287 }
288
289 // Delay expiration by issuing another cookie with an updated expiration time
290 // and notifying the store
291 void refresh_expiration()
292 {
293 if (!node) return;
294 node->requested_refresh = true;
295 }
296
297 private:
298 friend struct SessionMiddleware;
299
300 void check_node()
301 {
302 if (!node) node = std::make_shared<session::CachedSession>();
303 }
304
305 std::shared_ptr<session::CachedSession> node;
306 };
307
308 template<typename... Ts>
311 int id_length,
312 Ts... ts):
313 id_length_(id_length),
314 cookie_(cookie),
315 store_(std::forward<Ts>(ts)...), mutex_(new std::mutex{})
316 {}
317
318 template<typename... Ts>
319 SessionMiddleware(Ts... ts):
321 CookieParser::Cookie("session").path("/").max_age(/*month*/ 30 * 24 * 60 * 60),
322 /*id_length */ 20, // around 10^34 possible combinations, but small enough to fit into SSO
323 std::forward<Ts>(ts)...)
324 {}
325
326 template<typename AllContext>
327 void before_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx)
328 {
329 lock l(*mutex_);
330
331 auto& cookies = all_ctx.template get<CookieParser>();
332 auto session_id = load_id(cookies);
333 if (session_id == "") return;
334
335 // search entry in cache
336 auto it = cache_.find(session_id);
337 if (it != cache_.end())
338 {
339 it->second->referrers++;
340 ctx.node = it->second;
341 return;
342 }
343
344 // check this is a valid entry before loading
345 if (!store_.contains(session_id)) return;
346
347 auto node = std::make_shared<session::CachedSession>();
348 node->session_id = session_id;
349 node->referrers = 1;
350
351 try
352 {
353 store_.load(*node);
354 }
355 catch (...)
356 {
357 CROW_LOG_ERROR << "Exception occurred during session load";
358 return;
359 }
360
361 ctx.node = node;
362 cache_[session_id] = node;
363 }
364
365 template<typename AllContext>
366 void after_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx)
367 {
368 lock l(*mutex_);
369 if (!ctx.node || --ctx.node->referrers > 0) return;
370 ctx.node->requested_refresh |= ctx.node->session_id == "";
371
372 // generate new id
373 if (ctx.node->session_id == "")
374 {
375 // check for requested id
376 ctx.node->session_id = std::move(ctx.node->requested_session_id);
377 if (ctx.node->session_id == "")
378 {
379 ctx.node->session_id = utility::random_alphanum(id_length_);
380 }
381 }
382 else
383 {
384 cache_.erase(ctx.node->session_id);
385 }
386
387 if (ctx.node->requested_refresh)
388 {
389 auto& cookies = all_ctx.template get<CookieParser>();
390 store_id(cookies, ctx.node->session_id);
391 }
392
393 try
394 {
395 store_.save(*ctx.node);
396 }
397 catch (...)
398 {
399 CROW_LOG_ERROR << "Exception occurred during session save";
400 return;
401 }
402 }
403
404 private:
405 std::string next_id()
406 {
407 std::string id;
408 do
409 {
410 id = utility::random_alphanum(id_length_);
411 } while (store_.contains(id));
412 return id;
413 }
414
415 std::string load_id(const CookieParser::context& cookies)
416 {
417 return cookies.get_cookie(cookie_.name());
418 }
419
420 void store_id(CookieParser::context& cookies, const std::string& session_id)
421 {
422 cookie_.value(session_id);
423 cookies.set_cookie(cookie_);
424 }
425
426 private:
427 int id_length_;
428
429 // prototype for cookie
430 CookieParser::Cookie cookie_;
431
432 Store store_;
433
434 // mutexes are immovable
435 std::unique_ptr<std::mutex> mutex_;
436 std::unordered_map<std::string, std::shared_ptr<session::CachedSession>> cache_;
437 };
438
439 /// InMemoryStore stores all entries in memory
441 {
442 // Load a value into the session cache.
443 // A load is always followed by a save, no loads happen consecutively
444 void load(session::CachedSession& cn)
445 {
446 // load & stores happen sequentially, so moving is safe
447 cn.entries = std::move(entries[cn.session_id]);
448 }
449
450 // Persist session data
451 void save(session::CachedSession& cn)
452 {
453 entries[cn.session_id] = std::move(cn.entries);
454 // cn.dirty is a list of changed keys since the last load
455 }
456
457 bool contains(const std::string& key)
458 {
459 return entries.count(key) > 0;
460 }
461
462 std::unordered_map<std::string, std::unordered_map<std::string, session::multi_value>> entries;
463 };
464
465 // FileStore stores all data as json files in a folder.
466 // Files are deleted after expiration. Expiration refreshes are automatically picked up.
468 {
469 FileStore(const std::string& folder, uint64_t expiration_seconds = /*month*/ 30 * 24 * 60 * 60):
470 path_(folder), expiration_seconds_(expiration_seconds)
471 {
472 std::ifstream ifs(get_filename(".expirations", false));
473
474 auto current_ts = chrono_time();
475 std::string key;
476 uint64_t time;
477 while (ifs >> key >> time)
478 {
479 if (current_ts > time)
480 {
481 evict(key);
482 }
483 else if (contains(key))
484 {
485 expirations_.add(key, time);
486 }
487 }
488 }
489
490 ~FileStore()
491 {
492 std::ofstream ofs(get_filename(".expirations", false), std::ios::trunc);
493 for (const auto& p : expirations_)
494 ofs << p.second << " " << p.first << "\n";
495 }
496
497 // Delete expired entries
498 // At most 3 to prevent freezes
499 void handle_expired()
500 {
501 int deleted = 0;
502 auto current_ts = chrono_time();
503 while (current_ts > expirations_.peek_first() && deleted < 3)
504 {
505 evict(expirations_.pop_first());
506 deleted++;
507 }
508 }
509
510 void load(session::CachedSession& cn)
511 {
512 handle_expired();
513
514 std::ifstream file(get_filename(cn.session_id));
515
516 std::stringstream buffer;
517 buffer << file.rdbuf() << std::endl;
518
519 for (const auto& p : json::load(buffer.str()))
520 cn.entries[p.key()] = session::multi_value::from_json(p);
521 }
522
523 void save(session::CachedSession& cn)
524 {
525 if (cn.requested_refresh)
526 expirations_.add(cn.session_id, chrono_time() + expiration_seconds_);
527 if (cn.dirty.empty()) return;
528
529 std::ofstream file(get_filename(cn.session_id));
530 json::wvalue jw;
531 for (const auto& p : cn.entries)
532 jw[p.first] = p.second.json();
533 file << jw.dump() << std::flush;
534 }
535
536 std::string get_filename(const std::string& key, bool suffix = true)
537 {
538 return utility::join_path(path_, key + (suffix ? ".json" : ""));
539 }
540
541 bool contains(const std::string& key)
542 {
543 std::ifstream file(get_filename(key));
544 return file.good();
545 }
546
547 void evict(const std::string& key)
548 {
549 std::remove(get_filename(key).c_str());
550 }
551
552 uint64_t chrono_time() const
553 {
554 return std::chrono::duration_cast<std::chrono::seconds>(
555 std::chrono::system_clock::now().time_since_epoch())
556 .count();
557 }
558
559 std::string path_;
560 uint64_t expiration_seconds_;
561 session::ExpirationTracker expirations_;
562 };
563
564} // namespace crow
JSON read value.
Definition json.h:278
int64_t i() const
The integer value.
Definition json.h:400
double d() const
The double precision floating-point number value.
Definition json.h:433
uint64_t u() const
The unsigned integer value.
Definition json.h:417
num_type nt() const
The number type of the JSON value.
Definition json.h:388
type t() const
The type of the JSON value.
Definition json.h:376
JSON write value.
Definition json.h:1291
The main namespace of the library. In this namespace are defined the most important classes and functions of the library.
Definition session.h:468
InMemoryStore stores all entries in memory.
Definition session.h:441
Definition session.h:198
Definition session.h:193
CachedSessions are shared across requests.
Definition session.h:173
Expiration tracker keeps track of soonest-to-expire keys.
Definition session.h:121
uint64_t peek_first() const
Get expiration time of soonest-to-expire entry.
Definition session.h:145
void add(std::string key, uint64_t time)
Definition session.h:126
A multi_value is a safe variant wrapper with json conversion support.
Definition session.h:58