Crow  1.1
A C++ microframework for the web
session.h
1 #pragma once
2 
#include "crow/http_request.h"
#include "crow/http_response.h"
#include "crow/json.h"
#include "crow/utility.h"
#include "crow/middlewares/cookie_parser.h"

#include <unordered_map>
#include <unordered_set>
#include <set>
#include <queue>
#include <vector>

#include <memory>
#include <string>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <mutex>
#include <utility>

#include <fstream>
#include <sstream>

#include <type_traits>
#include <functional>
#include <chrono>

#ifdef CROW_CAN_USE_CPP17
#include <variant>
#endif
29 
namespace
{
    // Convert all integer values to int64_t.
    // bool is left untouched so that it stays a distinct value type.
    template<typename T>
    using wrap_integral_t = typename std::conditional<
      std::is_integral<T>::value && !std::is_same<bool, T>::value
        // except for uint64_t because that could lead to overflow on conversion
        && !std::is_same<uint64_t, T>::value,
      int64_t, T>::type;

    // Convert char[] / char* / const char* to std::string.
    // Arrays (including string literals bound by reference) decay to char*,
    // while a literal passed *by value* is deduced as const char*, so both
    // pointer forms have to be recognised.
    template<typename T>
    using wrap_char_t = typename std::conditional<
      std::is_same<typename std::decay<T>::type, char*>::value ||
        std::is_same<typename std::decay<T>::type, const char*>::value,
      std::string, T>::type;

    // Upgrade to correct type for multi_variant use:
    // integers -> int64_t, C strings -> std::string, everything else unchanged.
    template<typename T>
    using wrap_mv_t = wrap_char_t<wrap_integral_t<T>>;
} // namespace
50 
51 namespace crow
52 {
53  namespace session
54  {
55 
56 #ifdef CROW_CAN_USE_CPP17
57  using multi_value_types = black_magic::S<bool, int64_t, double, std::string>;
58 
59  /// A multi_value is a safe variant wrapper with json conversion support
60  struct multi_value
61  {
62  json::wvalue json() const
63  {
64  // clang-format off
65  return std::visit([](auto arg) {
66  return json::wvalue(arg);
67  }, v_);
68  // clang-format on
69  }
70 
71  static multi_value from_json(const json::rvalue&);
72 
73  std::string string() const
74  {
75  // clang-format off
76  return std::visit([](auto arg) {
77  if constexpr (std::is_same_v<decltype(arg), std::string>)
78  return arg;
79  else
80  return std::to_string(arg);
81  }, v_);
82  // clang-format on
83  }
84 
85  template<typename T, typename RT = wrap_mv_t<T>>
86  RT get(const T& fallback)
87  {
88  if (const RT* val = std::get_if<RT>(&v_)) return *val;
89  return fallback;
90  }
91 
92  template<typename T, typename RT = wrap_mv_t<T>>
93  void set(T val)
94  {
95  v_ = RT(std::move(val));
96  }
97 
98  typename multi_value_types::rebind<std::variant> v_;
99  };
100 
101  inline multi_value multi_value::from_json(const json::rvalue& rv)
102  {
103  using namespace json;
104  switch (rv.t())
105  {
106  case type::Number:
107  {
108  if (rv.nt() == num_type::Floating_point || rv.nt() == num_type::Double_precision_floating_point)
109  return multi_value{rv.d()};
110  else if (rv.nt() == num_type::Unsigned_integer)
111  return multi_value{int64_t(rv.u())};
112  else
113  return multi_value{rv.i()};
114  }
115  case type::False: return multi_value{false};
116  case type::True: return multi_value{true};
117  case type::String: return multi_value{std::string(rv)};
118  default: return multi_value{false};
119  }
120  }
121 #else
122  // Fallback for C++11/14 that uses a raw json::wvalue internally.
123  // This implementation consumes significantly more memory
124  // than the variant-based version
125  struct multi_value
126  {
127  json::wvalue json() const { return v_; }
128 
129  static multi_value from_json(const json::rvalue&);
130 
131  std::string string() const { return v_.dump(); }
132 
133  template<typename T, typename RT = wrap_mv_t<T>>
134  RT get(const T& fallback)
135  {
136  return json::wvalue_reader{v_}.get((const RT&)(fallback));
137  }
138 
139  template<typename T, typename RT = wrap_mv_t<T>>
140  void set(T val)
141  {
142  v_ = RT(std::move(val));
143  }
144 
145  json::wvalue v_;
146  };
147 
148  inline multi_value multi_value::from_json(const json::rvalue& rv)
149  {
150  return {rv};
151  }
152 #endif
153 
/// Expiration tracker keeps track of soonest-to-expire keys.
///
/// Two containers are kept in sync: an ordered set of (time, key) pairs that
/// yields the soonest entry first, and a key -> time map used to locate a
/// key's pair inside the set.
struct ExpirationTracker
{
    using DataPair = std::pair<uint64_t /*time*/, std::string /*key*/>;

    /// Add key with time to tracker.
    /// If the key is already present, it will be updated
    void add(std::string key, uint64_t time)
    {
        remove(key); // no-op when the key is not tracked yet
        times_[key] = time;
        queue_.insert({time, std::move(key)});
    }

    /// Stop tracking key. Does nothing if the key is unknown.
    void remove(const std::string& key)
    {
        auto it = times_.find(key);
        if (it != times_.end())
        {
            queue_.erase({it->second, key});
            times_.erase(it);
        }
    }

    /// Get expiration time of soonest-to-expire entry.
    /// Returns the maximum uint64_t value when nothing is tracked.
    uint64_t peek_first() const
    {
        if (queue_.empty()) return std::numeric_limits<uint64_t>::max();
        return queue_.begin()->first;
    }

    /// Remove the soonest-to-expire entry and return its key.
    /// Returns an empty string when nothing is tracked.
    std::string pop_first()
    {
        // Dereferencing begin() of an empty set is undefined behavior
        if (queue_.empty()) return std::string();

        auto first = queue_.begin();
        std::string key = first->second;
        times_.erase(key);
        queue_.erase(first);
        return key;
    }

    using iterator = typename std::set<DataPair>::const_iterator;

    /// Iterate over (time, key) pairs, ordered by time and then by key
    iterator begin() const { return queue_.cbegin(); }

    iterator end() const { return queue_.cend(); }

private:
    std::set<DataPair> queue_;
    std::unordered_map<std::string, uint64_t> times_;
};
205 
206  /// CachedSessions are shared across requests
208  {
209  std::string session_id;
210  std::string requested_session_id; // session hasn't been created yet, but a key was requested
211 
212  std::unordered_map<std::string, multi_value> entries;
213  std::unordered_set<std::string> dirty; // values that were changed after last load
214 
215  void* store_data;
216  bool requested_refresh;
217 
218  // number of references held - used for correctly destroying the cache.
219  // No need to be atomic, all SessionMiddleware accesses are synchronized
220  int referrers;
221  std::recursive_mutex mutex;
222  };
223  } // namespace session
224 
225  // SessionMiddleware allows storing securely and easily small snippets of user information
226  template<typename Store>
228  {
229 #ifdef CROW_CAN_USE_CPP17
230  using lock = std::scoped_lock<std::mutex>;
231  using rc_lock = std::scoped_lock<std::recursive_mutex>;
232 #else
233  using lock = std::lock_guard<std::mutex>;
234  using rc_lock = std::lock_guard<std::recursive_mutex>;
235 #endif
236 
237  struct context
238  {
239  // Get a mutex for locking this session
240  std::recursive_mutex& mutex()
241  {
242  check_node();
243  return node->mutex;
244  }
245 
246  // Check whether this session is already present
247  bool exists() { return bool(node); }
248 
249  // Get a value by key or fallback if it doesn't exist or is of another type
250  template<typename F>
251  auto get(const std::string& key, const F& fallback = F())
252  // This trick lets the multi_value deduce the return type from the fallback
253  // which allows both:
254  // context.get<std::string>("key")
255  // context.get("key", "") -> char[] is transformed into string by multivalue
256  // to return a string
257  -> decltype(std::declval<session::multi_value>().get<F>(std::declval<F>()))
258  {
259  if (!node) return fallback;
260  rc_lock l(node->mutex);
261 
262  auto it = node->entries.find(key);
263  if (it != node->entries.end()) return it->second.get<F>(fallback);
264  return fallback;
265  }
266 
267  // Set a value by key
268  template<typename T>
269  void set(const std::string& key, T value)
270  {
271  check_node();
272  rc_lock l(node->mutex);
273 
274  node->dirty.insert(key);
275  node->entries[key].set(std::move(value));
276  }
277 
278  bool contains(const std::string& key)
279  {
280  if (!node) return false;
281  return node->entries.find(key) != node->entries.end();
282  }
283 
284  // Atomically mutate a value with a function
285  template<typename Func>
286  void apply(const std::string& key, const Func& f)
287  {
288  using traits = utility::function_traits<Func>;
289  using arg = typename std::decay<typename traits::template arg<0>>::type;
290  using retv = typename std::decay<typename traits::result_type>::type;
291  check_node();
292  rc_lock l(node->mutex);
293  node->dirty.insert(key);
294  node->entries[key].set<retv>(f(node->entries[key].get(arg{})));
295  }
296 
297  // Remove a value from the session
298  void remove(const std::string& key)
299  {
300  if (!node) return;
301  rc_lock l(node->mutex);
302  node->dirty.insert(key);
303  node->entries.erase(key);
304  }
305 
306  // Format value by key as a string
307  std::string string(const std::string& key)
308  {
309  if (!node) return "";
310  rc_lock l(node->mutex);
311 
312  auto it = node->entries.find(key);
313  if (it != node->entries.end()) return it->second.string();
314  return "";
315  }
316 
317  // Get a list of keys present in session
318  std::vector<std::string> keys()
319  {
320  if (!node) return {};
321  rc_lock l(node->mutex);
322 
323  std::vector<std::string> out;
324  for (const auto& p : node->entries)
325  out.push_back(p.first);
326  return out;
327  }
328 
329  // Delay expiration by issuing another cookie with an updated expiration time
330  // and notifying the store
331  void refresh_expiration()
332  {
333  if (!node) return;
334  node->requested_refresh = true;
335  }
336 
337  private:
338  friend struct SessionMiddleware;
339 
340  void check_node()
341  {
342  if (!node) node = std::make_shared<session::CachedSession>();
343  }
344 
345  std::shared_ptr<session::CachedSession> node;
346  };
347 
348  template<typename... Ts>
350  CookieParser::Cookie cookie,
351  int id_length,
352  Ts... ts):
353  id_length_(id_length),
354  cookie_(cookie),
355  store_(std::forward<Ts>(ts)...), mutex_(new std::mutex{})
356  {}
357 
358  template<typename... Ts>
359  SessionMiddleware(Ts... ts):
361  CookieParser::Cookie("session").path("/").max_age(/*month*/ 30 * 24 * 60 * 60),
362  /*id_length */ 20, // around 10^34 possible combinations, but small enough to fit into SSO
363  std::forward<Ts>(ts)...)
364  {}
365 
366  template<typename AllContext>
367  void before_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx)
368  {
369  lock l(*mutex_);
370 
371  auto& cookies = all_ctx.template get<CookieParser>();
372  auto session_id = load_id(cookies);
373  if (session_id == "") return;
374 
375  // search entry in cache
376  auto it = cache_.find(session_id);
377  if (it != cache_.end())
378  {
379  it->second->referrers++;
380  ctx.node = it->second;
381  return;
382  }
383 
384  // check this is a valid entry before loading
385  if (!store_.contains(session_id)) return;
386 
387  auto node = std::make_shared<session::CachedSession>();
388  node->session_id = session_id;
389  node->referrers = 1;
390 
391  try
392  {
393  store_.load(*node);
394  }
395  catch (...)
396  {
397  CROW_LOG_ERROR << "Exception occurred during session load";
398  return;
399  }
400 
401  ctx.node = node;
402  cache_[session_id] = node;
403  }
404 
405  template<typename AllContext>
406  void after_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx)
407  {
408  lock l(*mutex_);
409  if (!ctx.node || --ctx.node->referrers > 0) return;
410  ctx.node->requested_refresh |= ctx.node->session_id == "";
411 
412  // generate new id
413  if (ctx.node->session_id == "")
414  {
415  // check for requested id
416  ctx.node->session_id = std::move(ctx.node->requested_session_id);
417  if (ctx.node->session_id == "")
418  {
419  ctx.node->session_id = utility::random_alphanum(id_length_);
420  }
421  }
422  else
423  {
424  cache_.erase(ctx.node->session_id);
425  }
426 
427  if (ctx.node->requested_refresh)
428  {
429  auto& cookies = all_ctx.template get<CookieParser>();
430  store_id(cookies, ctx.node->session_id);
431  }
432 
433  try
434  {
435  store_.save(*ctx.node);
436  }
437  catch (...)
438  {
439  CROW_LOG_ERROR << "Exception occurred during session save";
440  return;
441  }
442  }
443 
444  private:
445  std::string next_id()
446  {
447  std::string id;
448  do
449  {
450  id = utility::random_alphanum(id_length_);
451  } while (store_.contains(id));
452  return id;
453  }
454 
455  std::string load_id(const CookieParser::context& cookies)
456  {
457  return cookies.get_cookie(cookie_.name());
458  }
459 
460  void store_id(CookieParser::context& cookies, const std::string& session_id)
461  {
462  cookie_.value(session_id);
463  cookies.set_cookie(cookie_);
464  }
465 
466  private:
467  int id_length_;
468 
469  // prototype for cookie
470  CookieParser::Cookie cookie_;
471 
472  Store store_;
473 
474  // mutexes are immovable
475  std::unique_ptr<std::mutex> mutex_;
476  std::unordered_map<std::string, std::shared_ptr<session::CachedSession>> cache_;
477  };
478 
479  /// InMemoryStore stores all entries in memory
481  {
482  // Load a value into the session cache.
483  // A load is always followed by a save, no loads happen consecutively
484  void load(session::CachedSession& cn)
485  {
486  // load & stores happen sequentially, so moving is safe
487  cn.entries = std::move(entries[cn.session_id]);
488  }
489 
490  // Persist session data
491  void save(session::CachedSession& cn)
492  {
493  entries[cn.session_id] = std::move(cn.entries);
494  // cn.dirty is a list of changed keys since the last load
495  }
496 
497  bool contains(const std::string& key)
498  {
499  return entries.count(key) > 0;
500  }
501 
502  std::unordered_map<std::string, std::unordered_map<std::string, session::multi_value>> entries;
503  };
504 
505  // FileStore stores all data as json files in a folder.
506  // Files are deleted after expiration. Expiration refreshes are automatically picked up.
507  struct FileStore
508  {
509  FileStore(const std::string& folder, uint64_t expiration_seconds = /*month*/ 30 * 24 * 60 * 60):
510  path_(folder), expiration_seconds_(expiration_seconds)
511  {
512  std::ifstream ifs(get_filename(".expirations", false));
513 
514  auto current_ts = chrono_time();
515  std::string key;
516  uint64_t time;
517  while (ifs >> key >> time)
518  {
519  if (current_ts > time)
520  {
521  evict(key);
522  }
523  else if (contains(key))
524  {
525  expirations_.add(key, time);
526  }
527  }
528  }
529 
530  ~FileStore()
531  {
532  std::ofstream ofs(get_filename(".expirations", false), std::ios::trunc);
533  for (const auto& p : expirations_)
534  ofs << p.second << " " << p.first << "\n";
535  }
536 
537  // Delete expired entries
538  // At most 3 to prevent freezes
539  void handle_expired()
540  {
541  int deleted = 0;
542  auto current_ts = chrono_time();
543  while (current_ts > expirations_.peek_first() && deleted < 3)
544  {
545  evict(expirations_.pop_first());
546  deleted++;
547  }
548  }
549 
550  void load(session::CachedSession& cn)
551  {
552  handle_expired();
553 
554  std::ifstream file(get_filename(cn.session_id));
555 
556  std::stringstream buffer;
557  buffer << file.rdbuf() << std::endl;
558 
559  for (const auto& p : json::load(buffer.str()))
560  cn.entries[p.key()] = session::multi_value::from_json(p);
561  }
562 
563  void save(session::CachedSession& cn)
564  {
565  if (cn.requested_refresh)
566  expirations_.add(cn.session_id, chrono_time() + expiration_seconds_);
567  if (cn.dirty.empty()) return;
568 
569  std::ofstream file(get_filename(cn.session_id));
570  json::wvalue jw;
571  for (const auto& p : cn.entries)
572  jw[p.first] = p.second.json();
573  file << jw.dump() << std::flush;
574  }
575 
576  std::string get_filename(const std::string& key, bool suffix = true)
577  {
578  return utility::join_path(path_, key + (suffix ? ".json" : ""));
579  }
580 
581  bool contains(const std::string& key)
582  {
583  std::ifstream file(get_filename(key));
584  return file.good();
585  }
586 
587  void evict(const std::string& key)
588  {
589  std::remove(get_filename(key).c_str());
590  }
591 
592  uint64_t chrono_time() const
593  {
594  return std::chrono::duration_cast<std::chrono::seconds>(
595  std::chrono::system_clock::now().time_since_epoch())
596  .count();
597  }
598 
599  std::string path_;
600  uint64_t expiration_seconds_;
601  session::ExpirationTracker expirations_;
602  };
603 
604 } // namespace crow
JSON read value.
Definition: json.h:276
JSON write value.
Definition: json.h:1289
The main namespace of the library. In this namespace are defined the most important classes and functions of the library.
Definition: session.h:508
InMemoryStore stores all entries in memory.
Definition: session.h:481
Definition: session.h:238
Definition: session.h:228
Definition: json.h:2046
CachedSessions are shared across requests.
Definition: session.h:208
Expiration tracker keeps track of soonest-to-expire keys.
Definition: session.h:156
uint64_t peek_first() const
Get expiration time of soonest-to-expire entry.
Definition: session.h:180
void add(std::string key, uint64_t time)
Definition: session.h:161
Definition: session.h:126