Crow  1.1
A C++ microframework for the web
 
Loading...
Searching...
No Matches
session.h
1#pragma once
2
3#include "crow/http_request.h"
4#include "crow/http_response.h"
5#include "crow/json.h"
6#include "crow/utility.h"
7#include "crow/middlewares/cookie_parser.h"
8
9#include <unordered_map>
10#include <unordered_set>
11#include <set>
12#include <queue>
13
14#include <memory>
15#include <string>
16#include <cstdio>
17#include <mutex>
18
19#include <fstream>
20#include <sstream>
21
22#include <type_traits>
23#include <functional>
24#include <chrono>
25
26#ifdef CROW_CAN_USE_CPP17
27#include <variant>
28#endif
29
namespace
{
    // Type mapping used by multi_value::get/set to pick the stored alternative.
    // NOTE(review): an anonymous namespace in a header gives every translation unit
    // its own copy of these aliases; harmless for aliases, but a named detail
    // namespace would be cleaner.

    // Convert all integer values to int64_t, except:
    //  - bool, which has its own alternative
    //  - uint64_t, because that could lead to overflow on conversion
    template<typename T>
    using wrap_integral_t = typename std::conditional<
      std::is_integral<T>::value && !std::is_same<bool, T>::value
        && !std::is_same<uint64_t, T>::value,
      int64_t, T>::type;

    // Convert character arrays and pointers (char[N], const char[N], char*, const char*)
    // to std::string.
    // const char* must be covered too: a string literal passed by value (e.g. to
    // context::set) decays to const char*, and storing that raw pointer could select
    // the bool alternative of the variant (pointer-to-bool is a standard conversion).
    template<typename T>
    using wrap_char_t = typename std::conditional<
      std::is_same<typename std::decay<T>::type, char*>::value ||
        std::is_same<typename std::decay<T>::type, const char*>::value,
      std::string, T>::type;

    // Upgrade to correct type for multi_variant use
    template<typename T>
    using wrap_mv_t = wrap_char_t<wrap_integral_t<T>>;
} // namespace
50
51namespace crow
52{
53 namespace session
54 {
55
56#ifdef CROW_CAN_USE_CPP17
57 using multi_value_types = black_magic::S<bool, int64_t, double, std::string>;
58
59 /// A multi_value is a safe variant wrapper with json conversion support
60 struct multi_value
61 {
62 json::wvalue json() const
63 {
64 // clang-format off
65 return std::visit([](auto arg) {
66 return json::wvalue(arg);
67 }, v_);
68 // clang-format on
69 }
70
71 static multi_value from_json(const json::rvalue&);
72
73 std::string string() const
74 {
75 // clang-format off
76 return std::visit([](auto arg) {
77 if constexpr (std::is_same_v<decltype(arg), std::string>)
78 return arg;
79 else
80 return std::to_string(arg);
81 }, v_);
82 // clang-format on
83 }
84
85 template<typename T, typename RT = wrap_mv_t<T>>
86 RT get(const T& fallback)
87 {
88 if (const RT* val = std::get_if<RT>(&v_)) return *val;
89 return fallback;
90 }
91
92 template<typename T, typename RT = wrap_mv_t<T>>
93 void set(T val)
94 {
95 v_ = RT(std::move(val));
96 }
97
98 typename multi_value_types::rebind<std::variant> v_;
99 };
100
101 inline multi_value multi_value::from_json(const json::rvalue& rv)
102 {
103 using namespace json;
104 switch (rv.t())
105 {
106 case type::Number:
107 {
108 if (rv.nt() == num_type::Floating_point || rv.nt() == num_type::Double_precision_floating_point)
109 return multi_value{rv.d()};
110 else if (rv.nt() == num_type::Unsigned_integer)
111 return multi_value{int64_t(rv.u())};
112 else
113 return multi_value{rv.i()};
114 }
115 case type::False: return multi_value{false};
116 case type::True: return multi_value{true};
117 case type::String: return multi_value{std::string(rv)};
118 default: return multi_value{false};
119 }
120 }
121#else
122 // Fallback for C++11/14 that uses a raw json::wvalue internally.
123 // This implementation consumes significantly more memory
124 // than the variant-based version
126 {
127 json::wvalue json() const { return v_; }
128
129 static multi_value from_json(const json::rvalue&);
130
131 std::string string() const { return v_.dump(); }
132
133 template<typename T, typename RT = wrap_mv_t<T>>
134 RT get(const T& fallback)
135 {
136 return json::wvalue_reader{v_}.get((const RT&)(fallback));
137 }
138
139 template<typename T, typename RT = wrap_mv_t<T>>
140 void set(T val)
141 {
142 v_ = RT(std::move(val));
143 }
144
145 json::wvalue v_;
146 };
147
148 inline multi_value multi_value::from_json(const json::rvalue& rv)
149 {
150 return {rv};
151 }
152#endif
153
/// Expiration tracker keeps track of soonest-to-expire keys
///
/// Each key is indexed twice: `times_` maps key -> expiration time, and `queue_`
/// holds (time, key) pairs ordered by time, so the soonest expiration is first.
struct ExpirationTracker
{
    using DataPair = std::pair<uint64_t /*time*/, std::string /*key*/>;

    /// Add key with time to tracker.
    /// If the key is already present, its expiration time is replaced.
    void add(std::string key, uint64_t time)
    {
        auto it = times_.find(key);
        if (it != times_.end())
        {
            // Drop the stale (old_time, key) entry before re-inserting.
            queue_.erase({it->second, key});
            it->second = time;
        }
        else
        {
            times_.emplace(key, time);
        }
        queue_.insert({time, std::move(key)});
    }

    /// Remove key from the tracker; unknown keys are ignored.
    void remove(const std::string& key)
    {
        auto it = times_.find(key);
        if (it != times_.end())
        {
            queue_.erase({it->second, key});
            times_.erase(it);
        }
    }

    /// Get expiration time of soonest-to-expire entry
    /// (uint64_t max when the tracker is empty).
    uint64_t peek_first() const
    {
        if (queue_.empty()) return std::numeric_limits<uint64_t>::max();
        return queue_.begin()->first;
    }

    /// Remove and return the key of the soonest-to-expire entry.
    /// Returns an empty string when the tracker is empty instead of
    /// dereferencing the end iterator.
    std::string pop_first()
    {
        if (queue_.empty()) return {};
        auto first = queue_.begin();
        std::string key = first->second;
        times_.erase(key);
        queue_.erase(first);
        return key;
    }

    using iterator = typename std::set<DataPair>::const_iterator;

    iterator begin() const { return queue_.cbegin(); }

    iterator end() const { return queue_.cend(); }

private:
    std::set<DataPair> queue_;
    std::unordered_map<std::string, uint64_t> times_;
};
205
206 /// CachedSessions are shared across requests
208 {
209 std::string session_id;
210 std::string requested_session_id; // session hasn't been created yet, but a key was requested
211
212 std::unordered_map<std::string, multi_value> entries;
213 std::unordered_set<std::string> dirty; // values that were changed after last load
214
215 void* store_data;
216 bool requested_refresh;
217
218 // number of references held - used for correctly destroying the cache.
219 // No need to be atomic, all SessionMiddleware accesses are synchronized
220 int referrers;
221 std::recursive_mutex mutex;
222 };
223 } // namespace session
224
225 // SessionMiddleware allows storing securely and easily small snippets of user information
226 template<typename Store>
228 {
229#ifdef CROW_CAN_USE_CPP17
230 using lock = std::scoped_lock<std::mutex>;
231 using rc_lock = std::scoped_lock<std::recursive_mutex>;
232#else
233 using lock = std::lock_guard<std::mutex>;
234 using rc_lock = std::lock_guard<std::recursive_mutex>;
235#endif
236
237 struct context
238 {
239 // Get a mutex for locking this session
240 std::recursive_mutex& mutex()
241 {
242 check_node();
243 return node->mutex;
244 }
245
246 // Check whether this session is already present
247 bool exists() { return bool(node); }
248
249 // Get a value by key or fallback if it doesn't exist or is of another type
250 template<typename F>
251 auto get(const std::string& key, const F& fallback = F())
252 // This trick lets the multi_value deduce the return type from the fallback
253 // which allows both:
254 // context.get<std::string>("key")
255 // context.get("key", "") -> char[] is transformed into string by multivalue
256 // to return a string
257 -> decltype(std::declval<session::multi_value>().get<F>(std::declval<F>()))
258 {
259 if (!node) return fallback;
260 rc_lock l(node->mutex);
261
262 auto it = node->entries.find(key);
263 if (it != node->entries.end()) return it->second.get<F>(fallback);
264 return fallback;
265 }
266
267 // Set a value by key
268 template<typename T>
269 void set(const std::string& key, T value)
270 {
271 check_node();
272 rc_lock l(node->mutex);
273
274 node->dirty.insert(key);
275 node->entries[key].set(std::move(value));
276 }
277
278 bool contains(const std::string& key)
279 {
280 if (!node) return false;
281 return node->entries.find(key) != node->entries.end();
282 }
283
284 // Atomically mutate a value with a function
285 template<typename Func>
286 void apply(const std::string& key, const Func& f)
287 {
288 using traits = utility::function_traits<Func>;
289 using arg = typename std::decay<typename traits::template arg<0>>::type;
290 using retv = typename std::decay<typename traits::result_type>::type;
291 check_node();
292 rc_lock l(node->mutex);
293 node->dirty.insert(key);
294 node->entries[key].set<retv>(f(node->entries[key].get(arg{})));
295 }
296
297 // Remove a value from the session
298 void remove(const std::string& key)
299 {
300 if (!node) return;
301 rc_lock l(node->mutex);
302 node->dirty.insert(key);
303 node->entries.erase(key);
304 }
305
306 // Format value by key as a string
307 std::string string(const std::string& key)
308 {
309 if (!node) return "";
310 rc_lock l(node->mutex);
311
312 auto it = node->entries.find(key);
313 if (it != node->entries.end()) return it->second.string();
314 return "";
315 }
316
317 // Get a list of keys present in session
318 std::vector<std::string> keys()
319 {
320 if (!node) return {};
321 rc_lock l(node->mutex);
322
323 std::vector<std::string> out;
324 for (const auto& p : node->entries)
325 out.push_back(p.first);
326 return out;
327 }
328
329 // Delay expiration by issuing another cookie with an updated expiration time
330 // and notifying the store
331 void refresh_expiration()
332 {
333 if (!node) return;
334 node->requested_refresh = true;
335 }
336
337 private:
338 friend struct SessionMiddleware;
339
340 void check_node()
341 {
342 if (!node) node = std::make_shared<session::CachedSession>();
343 }
344
345 std::shared_ptr<session::CachedSession> node;
346 };
347
348 template<typename... Ts>
351 int id_length,
352 Ts... ts):
353 id_length_(id_length),
354 cookie_(cookie),
355 store_(std::forward<Ts>(ts)...), mutex_(new std::mutex{})
356 {}
357
358 template<typename... Ts>
359 SessionMiddleware(Ts... ts):
361 CookieParser::Cookie("session").path("/").max_age(/*month*/ 30 * 24 * 60 * 60),
362 /*id_length */ 20, // around 10^34 possible combinations, but small enough to fit into SSO
363 std::forward<Ts>(ts)...)
364 {}
365
366 template<typename AllContext>
367 void before_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx)
368 {
369 lock l(*mutex_);
370
371 auto& cookies = all_ctx.template get<CookieParser>();
372 auto session_id = load_id(cookies);
373 if (session_id == "") return;
374
375 // search entry in cache
376 auto it = cache_.find(session_id);
377 if (it != cache_.end())
378 {
379 it->second->referrers++;
380 ctx.node = it->second;
381 return;
382 }
383
384 // check this is a valid entry before loading
385 if (!store_.contains(session_id)) return;
386
387 auto node = std::make_shared<session::CachedSession>();
388 node->session_id = session_id;
389 node->referrers = 1;
390
391 try
392 {
393 store_.load(*node);
394 }
395 catch (...)
396 {
397 CROW_LOG_ERROR << "Exception occurred during session load";
398 return;
399 }
400
401 ctx.node = node;
402 cache_[session_id] = node;
403 }
404
405 template<typename AllContext>
406 void after_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx)
407 {
408 lock l(*mutex_);
409 if (!ctx.node || --ctx.node->referrers > 0) return;
410 ctx.node->requested_refresh |= ctx.node->session_id == "";
411
412 // generate new id
413 if (ctx.node->session_id == "")
414 {
415 // check for requested id
416 ctx.node->session_id = std::move(ctx.node->requested_session_id);
417 if (ctx.node->session_id == "")
418 {
419 ctx.node->session_id = utility::random_alphanum(id_length_);
420 }
421 }
422 else
423 {
424 cache_.erase(ctx.node->session_id);
425 }
426
427 if (ctx.node->requested_refresh)
428 {
429 auto& cookies = all_ctx.template get<CookieParser>();
430 store_id(cookies, ctx.node->session_id);
431 }
432
433 try
434 {
435 store_.save(*ctx.node);
436 }
437 catch (...)
438 {
439 CROW_LOG_ERROR << "Exception occurred during session save";
440 return;
441 }
442 }
443
444 private:
445 std::string next_id()
446 {
447 std::string id;
448 do
449 {
450 id = utility::random_alphanum(id_length_);
451 } while (store_.contains(id));
452 return id;
453 }
454
455 std::string load_id(const CookieParser::context& cookies)
456 {
457 return cookies.get_cookie(cookie_.name());
458 }
459
460 void store_id(CookieParser::context& cookies, const std::string& session_id)
461 {
462 cookie_.value(session_id);
463 cookies.set_cookie(cookie_);
464 }
465
466 private:
467 int id_length_;
468
469 // prototype for cookie
470 CookieParser::Cookie cookie_;
471
472 Store store_;
473
474 // mutexes are immovable
475 std::unique_ptr<std::mutex> mutex_;
476 std::unordered_map<std::string, std::shared_ptr<session::CachedSession>> cache_;
477 };
478
479 /// InMemoryStore stores all entries in memory
481 {
482 // Load a value into the session cache.
483 // A load is always followed by a save, no loads happen consecutively
484 void load(session::CachedSession& cn)
485 {
486 // load & stores happen sequentially, so moving is safe
487 cn.entries = std::move(entries[cn.session_id]);
488 }
489
490 // Persist session data
491 void save(session::CachedSession& cn)
492 {
493 entries[cn.session_id] = std::move(cn.entries);
494 // cn.dirty is a list of changed keys since the last load
495 }
496
497 bool contains(const std::string& key)
498 {
499 return entries.count(key) > 0;
500 }
501
502 std::unordered_map<std::string, std::unordered_map<std::string, session::multi_value>> entries;
503 };
504
505 // FileStore stores all data as json files in a folder.
506 // Files are deleted after expiration. Expiration refreshes are automatically picked up.
508 {
509 FileStore(const std::string& folder, uint64_t expiration_seconds = /*month*/ 30 * 24 * 60 * 60):
510 path_(folder), expiration_seconds_(expiration_seconds)
511 {
512 std::ifstream ifs(get_filename(".expirations", false));
513
514 auto current_ts = chrono_time();
515 std::string key;
516 uint64_t time;
517 while (ifs >> key >> time)
518 {
519 if (current_ts > time)
520 {
521 evict(key);
522 }
523 else if (contains(key))
524 {
525 expirations_.add(key, time);
526 }
527 }
528 }
529
530 ~FileStore()
531 {
532 std::ofstream ofs(get_filename(".expirations", false), std::ios::trunc);
533 for (const auto& p : expirations_)
534 ofs << p.second << " " << p.first << "\n";
535 }
536
537 // Delete expired entries
538 // At most 3 to prevent freezes
539 void handle_expired()
540 {
541 int deleted = 0;
542 auto current_ts = chrono_time();
543 while (current_ts > expirations_.peek_first() && deleted < 3)
544 {
545 evict(expirations_.pop_first());
546 deleted++;
547 }
548 }
549
550 void load(session::CachedSession& cn)
551 {
552 handle_expired();
553
554 std::ifstream file(get_filename(cn.session_id));
555
556 std::stringstream buffer;
557 buffer << file.rdbuf() << std::endl;
558
559 for (const auto& p : json::load(buffer.str()))
560 cn.entries[p.key()] = session::multi_value::from_json(p);
561 }
562
563 void save(session::CachedSession& cn)
564 {
565 if (cn.requested_refresh)
566 expirations_.add(cn.session_id, chrono_time() + expiration_seconds_);
567 if (cn.dirty.empty()) return;
568
569 std::ofstream file(get_filename(cn.session_id));
570 json::wvalue jw;
571 for (const auto& p : cn.entries)
572 jw[p.first] = p.second.json();
573 file << jw.dump() << std::flush;
574 }
575
576 std::string get_filename(const std::string& key, bool suffix = true)
577 {
578 return utility::join_path(path_, key + (suffix ? ".json" : ""));
579 }
580
581 bool contains(const std::string& key)
582 {
583 std::ifstream file(get_filename(key));
584 return file.good();
585 }
586
587 void evict(const std::string& key)
588 {
589 std::remove(get_filename(key).c_str());
590 }
591
592 uint64_t chrono_time() const
593 {
594 return std::chrono::duration_cast<std::chrono::seconds>(
595 std::chrono::system_clock::now().time_since_epoch())
596 .count();
597 }
598
599 std::string path_;
600 uint64_t expiration_seconds_;
601 session::ExpirationTracker expirations_;
602 };
603
604} // namespace crow
JSON read value.
Definition json.h:276
JSON write value.
Definition json.h:1289
The main namespace of the library. This namespace defines the library's most important classes and functions.
Definition session.h:508
InMemoryStore stores all entries in memory.
Definition session.h:481
Definition session.h:238
Definition session.h:228
Definition json.h:2046
CachedSessions are shared across requests.
Definition session.h:208
Expiration tracker keeps track of soonest-to-expire keys.
Definition session.h:156
uint64_t peek_first() const
Get expiration time of soonest-to-expire entry.
Definition session.h:180
void add(std::string key, uint64_t time)
Definition session.h:161
Definition session.h:126