1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 Authentication, AuthenticationMethod, BrowserSession, Clock, Password,
13 UpstreamOAuthAuthorizationSession, User,
14};
15use mas_storage::{
16 Page, Pagination,
17 pagination::Node,
18 user::{BrowserSessionFilter, BrowserSessionRepository},
19};
20use rand::RngCore;
21use sea_query::{Expr, PostgresQueryBuilder, Query};
22use sea_query_binder::SqlxBinder;
23use sqlx::PgConnection;
24use ulid::Ulid;
25use uuid::Uuid;
26
27use crate::{
28 DatabaseError, DatabaseInconsistencyError,
29 filter::StatementExt,
30 iden::{UpstreamOAuthAuthorizationSessions, UserSessionAuthentications, UserSessions, Users},
31 pagination::QueryBuilderExt,
32 tracing::ExecuteExt,
33};
34
35pub struct PgBrowserSessionRepository<'c> {
38 conn: &'c mut PgConnection,
39}
40
41impl<'c> PgBrowserSessionRepository<'c> {
42 pub fn new(conn: &'c mut PgConnection) -> Self {
45 Self { conn }
46 }
47}
48
/// One row of a browser-session lookup: a `user_sessions` row joined with the
/// `users` row that owns it.
///
/// The field names are load-bearing, so do not rename them:
/// - `sqlx::FromRow` and `query_as!` map result columns to fields by name, so
///   they must match the column aliases used by the `lookup` query.
/// - `#[sea_query::enum_def]` derives the `SessionLookupIden` identifiers from
///   these names. `list` uses those identifiers as its column aliases.
// Every field shares the `user_` prefix (mirroring the SQL aliases), which
// `clippy::struct_field_names` would otherwise flag.
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    // Columns coming from `user_sessions`
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    user_session_finished_at: Option<DateTime<Utc>>,
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    // Columns coming from `users`
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
    user_is_guest: bool,
}
67
68impl Node<Ulid> for SessionLookup {
69 fn cursor(&self) -> Ulid {
70 self.user_id.into()
71 }
72}
73
74impl TryFrom<SessionLookup> for BrowserSession {
75 type Error = DatabaseInconsistencyError;
76
77 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
78 let id = Ulid::from(value.user_id);
79 let user = User {
80 id,
81 username: value.user_username,
82 sub: id.to_string(),
83 created_at: value.user_created_at,
84 locked_at: value.user_locked_at,
85 deactivated_at: value.user_deactivated_at,
86 can_request_admin: value.user_can_request_admin,
87 is_guest: value.user_is_guest,
88 };
89
90 Ok(BrowserSession {
91 id: value.user_session_id.into(),
92 user,
93 created_at: value.user_session_created_at,
94 finished_at: value.user_session_finished_at,
95 user_agent: value.user_session_user_agent,
96 last_active_at: value.user_session_last_active_at,
97 last_active_ip: value.user_session_last_active_ip,
98 })
99 }
100}
101
/// One row of `user_session_authentications`, as selected by
/// `get_last_authentication`.
///
/// Field names must match the selected column names, because `query_as!`
/// maps columns to fields by name. At most one of `user_password_id` and
/// `upstream_oauth_authorization_session_id` may be set. The
/// `TryFrom<AuthenticationLookup>` conversion rejects rows where both are set.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    /// Set when the session was authenticated with a password.
    user_password_id: Option<Uuid>,
    /// Set when the session was authenticated through an upstream OAuth 2.0
    /// authorization session.
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
108
109impl TryFrom<AuthenticationLookup> for Authentication {
110 type Error = DatabaseInconsistencyError;
111
112 fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
113 let id = Ulid::from(value.user_session_authentication_id);
114 let authentication_method = match (
115 value.user_password_id.map(Into::into),
116 value
117 .upstream_oauth_authorization_session_id
118 .map(Into::into),
119 ) {
120 (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
121 (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
122 upstream_oauth2_session_id,
123 },
124 (None, None) => AuthenticationMethod::Unknown,
125 _ => {
126 return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
127 }
128 };
129
130 Ok(Authentication {
131 id,
132 created_at: value.created_at,
133 authentication_method,
134 })
135 }
136}
137
138impl crate::filter::Filter for BrowserSessionFilter<'_> {
139 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
140 sea_query::Condition::all()
141 .add_option(self.user().map(|user| {
142 Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
143 }))
144 .add_option(self.state().map(|state| {
145 if state.is_active() {
146 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
147 } else {
148 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
149 }
150 }))
151 .add_option(self.last_active_after().map(|last_active_after| {
152 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
153 }))
154 .add_option(self.last_active_before().map(|last_active_before| {
155 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
156 }))
157 .add_option(self.authenticated_by_upstream_sessions().map(|filter| {
158 let join_expr = Expr::col((
161 UserSessionAuthentications::Table,
162 UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
163 ))
164 .eq(Expr::col((
165 UpstreamOAuthAuthorizationSessions::Table,
166 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
167 )));
168
169 Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
170 Query::select()
171 .expr(Expr::col((
172 UserSessionAuthentications::Table,
173 UserSessionAuthentications::UserSessionId,
174 )))
175 .from(UserSessionAuthentications::Table)
176 .inner_join(UpstreamOAuthAuthorizationSessions::Table, join_expr)
177 .apply_filter(filter)
178 .take(),
179 )
180 }))
181 }
182}
183
184#[async_trait]
185impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
186 type Error = DatabaseError;
187
188 #[tracing::instrument(
189 name = "db.browser_session.lookup",
190 skip_all,
191 fields(
192 db.query.text,
193 user_session.id = %id,
194 ),
195 err,
196 )]
197 async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
198 let res = sqlx::query_as!(
199 SessionLookup,
200 r#"
201 SELECT s.user_session_id
202 , s.created_at AS "user_session_created_at"
203 , s.finished_at AS "user_session_finished_at"
204 , s.user_agent AS "user_session_user_agent"
205 , s.last_active_at AS "user_session_last_active_at"
206 , s.last_active_ip AS "user_session_last_active_ip: IpAddr"
207 , u.user_id
208 , u.username AS "user_username"
209 , u.created_at AS "user_created_at"
210 , u.locked_at AS "user_locked_at"
211 , u.deactivated_at AS "user_deactivated_at"
212 , u.can_request_admin AS "user_can_request_admin"
213 , u.is_guest AS "user_is_guest"
214 FROM user_sessions s
215 INNER JOIN users u
216 USING (user_id)
217 WHERE s.user_session_id = $1
218 "#,
219 Uuid::from(id),
220 )
221 .traced()
222 .fetch_optional(&mut *self.conn)
223 .await?;
224
225 let Some(res) = res else { return Ok(None) };
226
227 Ok(Some(res.try_into()?))
228 }
229
230 #[tracing::instrument(
231 name = "db.browser_session.add",
232 skip_all,
233 fields(
234 db.query.text,
235 %user.id,
236 user_session.id,
237 ),
238 err,
239 )]
240 async fn add(
241 &mut self,
242 rng: &mut (dyn RngCore + Send),
243 clock: &dyn Clock,
244 user: &User,
245 user_agent: Option<String>,
246 ) -> Result<BrowserSession, Self::Error> {
247 let created_at = clock.now();
248 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
249 tracing::Span::current().record("user_session.id", tracing::field::display(id));
250
251 sqlx::query!(
252 r#"
253 INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
254 VALUES ($1, $2, $3, $4)
255 "#,
256 Uuid::from(id),
257 Uuid::from(user.id),
258 created_at,
259 user_agent.as_deref(),
260 )
261 .traced()
262 .execute(&mut *self.conn)
263 .await?;
264
265 let session = BrowserSession {
266 id,
267 user: user.clone(),
269 created_at,
270 finished_at: None,
271 user_agent,
272 last_active_at: None,
273 last_active_ip: None,
274 };
275
276 Ok(session)
277 }
278
279 #[tracing::instrument(
280 name = "db.browser_session.finish",
281 skip_all,
282 fields(
283 db.query.text,
284 %user_session.id,
285 ),
286 err,
287 )]
288 async fn finish(
289 &mut self,
290 clock: &dyn Clock,
291 mut user_session: BrowserSession,
292 ) -> Result<BrowserSession, Self::Error> {
293 let finished_at = clock.now();
294 let res = sqlx::query!(
295 r#"
296 UPDATE user_sessions
297 SET finished_at = $1
298 WHERE user_session_id = $2
299 "#,
300 finished_at,
301 Uuid::from(user_session.id),
302 )
303 .traced()
304 .execute(&mut *self.conn)
305 .await?;
306
307 user_session.finished_at = Some(finished_at);
308
309 DatabaseError::ensure_affected_rows(&res, 1)?;
310
311 Ok(user_session)
312 }
313
314 #[tracing::instrument(
315 name = "db.browser_session.finish_bulk",
316 skip_all,
317 fields(
318 db.query.text,
319 ),
320 err,
321 )]
322 async fn finish_bulk(
323 &mut self,
324 clock: &dyn Clock,
325 filter: BrowserSessionFilter<'_>,
326 ) -> Result<usize, Self::Error> {
327 let finished_at = clock.now();
328 let (sql, arguments) = sea_query::Query::update()
329 .table(UserSessions::Table)
330 .value(UserSessions::FinishedAt, finished_at)
331 .apply_filter(filter)
332 .build_sqlx(PostgresQueryBuilder);
333
334 let res = sqlx::query_with(&sql, arguments)
335 .traced()
336 .execute(&mut *self.conn)
337 .await?;
338
339 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
340 }
341
342 #[tracing::instrument(
343 name = "db.browser_session.list",
344 skip_all,
345 fields(
346 db.query.text,
347 ),
348 err,
349 )]
350 async fn list(
351 &mut self,
352 filter: BrowserSessionFilter<'_>,
353 pagination: Pagination,
354 ) -> Result<Page<BrowserSession>, Self::Error> {
355 let (sql, arguments) = sea_query::Query::select()
356 .expr_as(
357 Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
358 SessionLookupIden::UserSessionId,
359 )
360 .expr_as(
361 Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
362 SessionLookupIden::UserSessionCreatedAt,
363 )
364 .expr_as(
365 Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
366 SessionLookupIden::UserSessionFinishedAt,
367 )
368 .expr_as(
369 Expr::col((UserSessions::Table, UserSessions::UserAgent)),
370 SessionLookupIden::UserSessionUserAgent,
371 )
372 .expr_as(
373 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
374 SessionLookupIden::UserSessionLastActiveAt,
375 )
376 .expr_as(
377 Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
378 SessionLookupIden::UserSessionLastActiveIp,
379 )
380 .expr_as(
381 Expr::col((Users::Table, Users::UserId)),
382 SessionLookupIden::UserId,
383 )
384 .expr_as(
385 Expr::col((Users::Table, Users::Username)),
386 SessionLookupIden::UserUsername,
387 )
388 .expr_as(
389 Expr::col((Users::Table, Users::CreatedAt)),
390 SessionLookupIden::UserCreatedAt,
391 )
392 .expr_as(
393 Expr::col((Users::Table, Users::LockedAt)),
394 SessionLookupIden::UserLockedAt,
395 )
396 .expr_as(
397 Expr::col((Users::Table, Users::DeactivatedAt)),
398 SessionLookupIden::UserDeactivatedAt,
399 )
400 .expr_as(
401 Expr::col((Users::Table, Users::CanRequestAdmin)),
402 SessionLookupIden::UserCanRequestAdmin,
403 )
404 .expr_as(
405 Expr::col((Users::Table, Users::IsGuest)),
406 SessionLookupIden::UserIsGuest,
407 )
408 .from(UserSessions::Table)
409 .inner_join(
410 Users::Table,
411 Expr::col((UserSessions::Table, UserSessions::UserId))
412 .equals((Users::Table, Users::UserId)),
413 )
414 .apply_filter(filter)
415 .generate_pagination(
416 (UserSessions::Table, UserSessions::UserSessionId),
417 pagination,
418 )
419 .build_sqlx(PostgresQueryBuilder);
420
421 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
422 .traced()
423 .fetch_all(&mut *self.conn)
424 .await?;
425
426 let page = pagination
427 .process(edges)
428 .try_map(BrowserSession::try_from)?;
429
430 Ok(page)
431 }
432
433 #[tracing::instrument(
434 name = "db.browser_session.count",
435 skip_all,
436 fields(
437 db.query.text,
438 ),
439 err,
440 )]
441 async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
442 let (sql, arguments) = sea_query::Query::select()
443 .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
444 .from(UserSessions::Table)
445 .apply_filter(filter)
446 .build_sqlx(PostgresQueryBuilder);
447
448 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
449 .traced()
450 .fetch_one(&mut *self.conn)
451 .await?;
452
453 count
454 .try_into()
455 .map_err(DatabaseError::to_invalid_operation)
456 }
457
458 #[tracing::instrument(
459 name = "db.browser_session.authenticate_with_password",
460 skip_all,
461 fields(
462 db.query.text,
463 %user_session.id,
464 %user_password.id,
465 user_session_authentication.id,
466 ),
467 err,
468 )]
469 async fn authenticate_with_password(
470 &mut self,
471 rng: &mut (dyn RngCore + Send),
472 clock: &dyn Clock,
473 user_session: &BrowserSession,
474 user_password: &Password,
475 ) -> Result<Authentication, Self::Error> {
476 let created_at = clock.now();
477 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
478 tracing::Span::current().record(
479 "user_session_authentication.id",
480 tracing::field::display(id),
481 );
482
483 sqlx::query!(
484 r#"
485 INSERT INTO user_session_authentications
486 (user_session_authentication_id, user_session_id, created_at, user_password_id)
487 VALUES ($1, $2, $3, $4)
488 "#,
489 Uuid::from(id),
490 Uuid::from(user_session.id),
491 created_at,
492 Uuid::from(user_password.id),
493 )
494 .traced()
495 .execute(&mut *self.conn)
496 .await?;
497
498 Ok(Authentication {
499 id,
500 created_at,
501 authentication_method: AuthenticationMethod::Password {
502 user_password_id: user_password.id,
503 },
504 })
505 }
506
507 #[tracing::instrument(
508 name = "db.browser_session.authenticate_with_upstream",
509 skip_all,
510 fields(
511 db.query.text,
512 %user_session.id,
513 %upstream_oauth_session.id,
514 user_session_authentication.id,
515 ),
516 err,
517 )]
518 async fn authenticate_with_upstream(
519 &mut self,
520 rng: &mut (dyn RngCore + Send),
521 clock: &dyn Clock,
522 user_session: &BrowserSession,
523 upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
524 ) -> Result<Authentication, Self::Error> {
525 let created_at = clock.now();
526 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
527 tracing::Span::current().record(
528 "user_session_authentication.id",
529 tracing::field::display(id),
530 );
531
532 sqlx::query!(
533 r#"
534 INSERT INTO user_session_authentications
535 (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
536 VALUES ($1, $2, $3, $4)
537 "#,
538 Uuid::from(id),
539 Uuid::from(user_session.id),
540 created_at,
541 Uuid::from(upstream_oauth_session.id),
542 )
543 .traced()
544 .execute(&mut *self.conn)
545 .await?;
546
547 Ok(Authentication {
548 id,
549 created_at,
550 authentication_method: AuthenticationMethod::UpstreamOAuth2 {
551 upstream_oauth2_session_id: upstream_oauth_session.id,
552 },
553 })
554 }
555
556 #[tracing::instrument(
557 name = "db.browser_session.get_last_authentication",
558 skip_all,
559 fields(
560 db.query.text,
561 %user_session.id,
562 ),
563 err,
564 )]
565 async fn get_last_authentication(
566 &mut self,
567 user_session: &BrowserSession,
568 ) -> Result<Option<Authentication>, Self::Error> {
569 let authentication = sqlx::query_as!(
570 AuthenticationLookup,
571 r#"
572 SELECT user_session_authentication_id
573 , created_at
574 , user_password_id
575 , upstream_oauth_authorization_session_id
576 FROM user_session_authentications
577 WHERE user_session_id = $1
578 ORDER BY created_at DESC
579 LIMIT 1
580 "#,
581 Uuid::from(user_session.id),
582 )
583 .traced()
584 .fetch_optional(&mut *self.conn)
585 .await?;
586
587 let Some(authentication) = authentication else {
588 return Ok(None);
589 };
590
591 let authentication = Authentication::try_from(authentication)?;
592 Ok(Some(authentication))
593 }
594
595 #[tracing::instrument(
596 name = "db.browser_session.record_batch_activity",
597 skip_all,
598 fields(
599 db.query.text,
600 ),
601 err,
602 )]
603 async fn record_batch_activity(
604 &mut self,
605 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
606 ) -> Result<(), Self::Error> {
607 activities.sort_unstable();
610 let mut ids = Vec::with_capacity(activities.len());
611 let mut last_activities = Vec::with_capacity(activities.len());
612 let mut ips = Vec::with_capacity(activities.len());
613
614 for (id, last_activity, ip) in activities {
615 ids.push(Uuid::from(id));
616 last_activities.push(last_activity);
617 ips.push(ip);
618 }
619
620 let res = sqlx::query!(
621 r#"
622 UPDATE user_sessions
623 SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
624 , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
625 FROM (
626 SELECT *
627 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
628 AS t(user_session_id, last_active_at, last_active_ip)
629 ) AS t
630 WHERE user_sessions.user_session_id = t.user_session_id
631 "#,
632 &ids,
633 &last_activities,
634 &ips as &[Option<IpAddr>],
635 )
636 .traced()
637 .execute(&mut *self.conn)
638 .await?;
639
640 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
641
642 Ok(())
643 }
644}