1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Clock, Session, SessionState, User};
12use mas_storage::{
13 Page, Pagination,
14 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15 pagination::Node,
16};
17use oauth2_types::scope::{Scope, ScopeToken};
18use rand::RngCore;
19use sea_query::{
20 Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
21 extension::postgres::PgExpr,
22};
23use sea_query_binder::SqlxBinder;
24use sqlx::PgConnection;
25use ulid::Ulid;
26use uuid::Uuid;
27
28use crate::{
29 DatabaseError, DatabaseInconsistencyError,
30 filter::{Filter, StatementExt},
31 iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
32 pagination::QueryBuilderExt,
33 tracing::ExecuteExt,
34};
35
36pub struct PgOAuth2SessionRepository<'c> {
38 conn: &'c mut PgConnection,
39}
40
41impl<'c> PgOAuth2SessionRepository<'c> {
42 pub fn new(conn: &'c mut PgConnection) -> Self {
45 Self { conn }
46 }
47}
48
49#[derive(sqlx::FromRow)]
50#[enum_def]
51struct OAuthSessionLookup {
52 oauth2_session_id: Uuid,
53 user_id: Option<Uuid>,
54 user_session_id: Option<Uuid>,
55 oauth2_client_id: Uuid,
56 scope_list: Vec<String>,
57 created_at: DateTime<Utc>,
58 finished_at: Option<DateTime<Utc>>,
59 user_agent: Option<String>,
60 last_active_at: Option<DateTime<Utc>>,
61 last_active_ip: Option<IpAddr>,
62 human_name: Option<String>,
63}
64
65impl Node<Ulid> for OAuthSessionLookup {
66 fn cursor(&self) -> Ulid {
67 self.oauth2_session_id.into()
68 }
69}
70
71impl TryFrom<OAuthSessionLookup> for Session {
72 type Error = DatabaseInconsistencyError;
73
74 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
75 let id = Ulid::from(value.oauth2_session_id);
76 let scope: Result<Scope, _> = value
77 .scope_list
78 .iter()
79 .map(|s| s.parse::<ScopeToken>())
80 .collect();
81 let scope = scope.map_err(|e| {
82 DatabaseInconsistencyError::on("oauth2_sessions")
83 .column("scope")
84 .row(id)
85 .source(e)
86 })?;
87
88 let state = match value.finished_at {
89 None => SessionState::Valid,
90 Some(finished_at) => SessionState::Finished { finished_at },
91 };
92
93 Ok(Session {
94 id,
95 state,
96 created_at: value.created_at,
97 client_id: value.oauth2_client_id.into(),
98 user_id: value.user_id.map(Ulid::from),
99 user_session_id: value.user_session_id.map(Ulid::from),
100 scope,
101 user_agent: value.user_agent,
102 last_active_at: value.last_active_at,
103 last_active_ip: value.last_active_ip,
104 human_name: value.human_name,
105 })
106 }
107}
108
109impl Filter for OAuth2SessionFilter<'_> {
110 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
111 sea_query::Condition::all()
112 .add_option(self.user().map(|user| {
113 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
114 }))
115 .add_option(self.client().map(|client| {
116 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
117 .eq(Uuid::from(client.id))
118 }))
119 .add_option(self.client_kind().map(|client_kind| {
120 let static_clients = Query::select()
124 .expr(Expr::col((
125 OAuth2Clients::Table,
126 OAuth2Clients::OAuth2ClientId,
127 )))
128 .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
129 .from(OAuth2Clients::Table)
130 .take();
131 if client_kind.is_static() {
132 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
133 .eq(Expr::any(static_clients))
134 } else {
135 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
136 .ne(Expr::all(static_clients))
137 }
138 }))
139 .add_option(self.device().map(|device| -> SimpleExpr {
140 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
141 Condition::any()
142 .add(
143 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
144 OAuth2Sessions::Table,
145 OAuth2Sessions::ScopeList,
146 )))),
147 )
148 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
149 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
150 )))
151 .into()
152 } else {
153 Expr::val(false).into()
155 }
156 }))
157 .add_option(self.browser_session().map(|browser_session| {
158 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
159 .eq(Uuid::from(browser_session.id))
160 }))
161 .add_option(self.browser_session_filter().map(|browser_session_filter| {
162 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
163 Query::select()
164 .expr(Expr::col((
165 UserSessions::Table,
166 UserSessions::UserSessionId,
167 )))
168 .apply_filter(browser_session_filter)
169 .from(UserSessions::Table)
170 .take(),
171 )
172 }))
173 .add_option(self.state().map(|state| {
174 if state.is_active() {
175 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
176 } else {
177 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
178 }
179 }))
180 .add_option(self.scope().map(|scope| {
181 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
182 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
183 }))
184 .add_option(self.any_user().map(|any_user| {
185 if any_user {
186 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
187 } else {
188 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
189 }
190 }))
191 .add_option(self.last_active_after().map(|last_active_after| {
192 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
193 .gt(last_active_after)
194 }))
195 .add_option(self.last_active_before().map(|last_active_before| {
196 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
197 .lt(last_active_before)
198 }))
199 }
200}
201
202#[async_trait]
203impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
204 type Error = DatabaseError;
205
206 #[tracing::instrument(
207 name = "db.oauth2_session.lookup",
208 skip_all,
209 fields(
210 db.query.text,
211 session.id = %id,
212 ),
213 err,
214 )]
215 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
216 let res = sqlx::query_as!(
217 OAuthSessionLookup,
218 r#"
219 SELECT oauth2_session_id
220 , user_id
221 , user_session_id
222 , oauth2_client_id
223 , scope_list
224 , created_at
225 , finished_at
226 , user_agent
227 , last_active_at
228 , last_active_ip as "last_active_ip: IpAddr"
229 , human_name
230 FROM oauth2_sessions
231
232 WHERE oauth2_session_id = $1
233 "#,
234 Uuid::from(id),
235 )
236 .traced()
237 .fetch_optional(&mut *self.conn)
238 .await?;
239
240 let Some(session) = res else { return Ok(None) };
241
242 Ok(Some(session.try_into()?))
243 }
244
245 #[tracing::instrument(
246 name = "db.oauth2_session.add",
247 skip_all,
248 fields(
249 db.query.text,
250 %client.id,
251 session.id,
252 session.scope = %scope,
253 ),
254 err,
255 )]
256 async fn add(
257 &mut self,
258 rng: &mut (dyn RngCore + Send),
259 clock: &dyn Clock,
260 client: &Client,
261 user: Option<&User>,
262 user_session: Option<&BrowserSession>,
263 scope: Scope,
264 ) -> Result<Session, Self::Error> {
265 let created_at = clock.now();
266 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
267 tracing::Span::current().record("session.id", tracing::field::display(id));
268
269 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
270
271 sqlx::query!(
272 r#"
273 INSERT INTO oauth2_sessions
274 ( oauth2_session_id
275 , user_id
276 , user_session_id
277 , oauth2_client_id
278 , scope_list
279 , created_at
280 )
281 VALUES ($1, $2, $3, $4, $5, $6)
282 "#,
283 Uuid::from(id),
284 user.map(|u| Uuid::from(u.id)),
285 user_session.map(|s| Uuid::from(s.id)),
286 Uuid::from(client.id),
287 &scope_list,
288 created_at,
289 )
290 .traced()
291 .execute(&mut *self.conn)
292 .await?;
293
294 Ok(Session {
295 id,
296 state: SessionState::Valid,
297 created_at,
298 user_id: user.map(|u| u.id),
299 user_session_id: user_session.map(|s| s.id),
300 client_id: client.id,
301 scope,
302 user_agent: None,
303 last_active_at: None,
304 last_active_ip: None,
305 human_name: None,
306 })
307 }
308
309 #[tracing::instrument(
310 name = "db.oauth2_session.finish_bulk",
311 skip_all,
312 fields(
313 db.query.text,
314 ),
315 err,
316 )]
317 async fn finish_bulk(
318 &mut self,
319 clock: &dyn Clock,
320 filter: OAuth2SessionFilter<'_>,
321 ) -> Result<usize, Self::Error> {
322 let finished_at = clock.now();
323 let (sql, arguments) = Query::update()
324 .table(OAuth2Sessions::Table)
325 .value(OAuth2Sessions::FinishedAt, finished_at)
326 .apply_filter(filter)
327 .build_sqlx(PostgresQueryBuilder);
328
329 let res = sqlx::query_with(&sql, arguments)
330 .traced()
331 .execute(&mut *self.conn)
332 .await?;
333
334 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
335 }
336
337 #[tracing::instrument(
338 name = "db.oauth2_session.finish",
339 skip_all,
340 fields(
341 db.query.text,
342 %session.id,
343 %session.scope,
344 client.id = %session.client_id,
345 ),
346 err,
347 )]
348 async fn finish(
349 &mut self,
350 clock: &dyn Clock,
351 session: Session,
352 ) -> Result<Session, Self::Error> {
353 let finished_at = clock.now();
354 let res = sqlx::query!(
355 r#"
356 UPDATE oauth2_sessions
357 SET finished_at = $2
358 WHERE oauth2_session_id = $1
359 "#,
360 Uuid::from(session.id),
361 finished_at,
362 )
363 .traced()
364 .execute(&mut *self.conn)
365 .await?;
366
367 DatabaseError::ensure_affected_rows(&res, 1)?;
368
369 session
370 .finish(finished_at)
371 .map_err(DatabaseError::to_invalid_operation)
372 }
373
374 #[tracing::instrument(
375 name = "db.oauth2_session.list",
376 skip_all,
377 fields(
378 db.query.text,
379 ),
380 err,
381 )]
382 async fn list(
383 &mut self,
384 filter: OAuth2SessionFilter<'_>,
385 pagination: Pagination,
386 ) -> Result<Page<Session>, Self::Error> {
387 let (sql, arguments) = Query::select()
388 .expr_as(
389 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
390 OAuthSessionLookupIden::Oauth2SessionId,
391 )
392 .expr_as(
393 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
394 OAuthSessionLookupIden::UserId,
395 )
396 .expr_as(
397 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
398 OAuthSessionLookupIden::UserSessionId,
399 )
400 .expr_as(
401 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
402 OAuthSessionLookupIden::Oauth2ClientId,
403 )
404 .expr_as(
405 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
406 OAuthSessionLookupIden::ScopeList,
407 )
408 .expr_as(
409 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
410 OAuthSessionLookupIden::CreatedAt,
411 )
412 .expr_as(
413 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
414 OAuthSessionLookupIden::FinishedAt,
415 )
416 .expr_as(
417 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
418 OAuthSessionLookupIden::UserAgent,
419 )
420 .expr_as(
421 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
422 OAuthSessionLookupIden::LastActiveAt,
423 )
424 .expr_as(
425 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
426 OAuthSessionLookupIden::LastActiveIp,
427 )
428 .expr_as(
429 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
430 OAuthSessionLookupIden::HumanName,
431 )
432 .from(OAuth2Sessions::Table)
433 .apply_filter(filter)
434 .generate_pagination(
435 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
436 pagination,
437 )
438 .build_sqlx(PostgresQueryBuilder);
439
440 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
441 .traced()
442 .fetch_all(&mut *self.conn)
443 .await?;
444
445 let page = pagination.process(edges).try_map(Session::try_from)?;
446
447 Ok(page)
448 }
449
450 #[tracing::instrument(
451 name = "db.oauth2_session.count",
452 skip_all,
453 fields(
454 db.query.text,
455 ),
456 err,
457 )]
458 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
459 let (sql, arguments) = Query::select()
460 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
461 .from(OAuth2Sessions::Table)
462 .apply_filter(filter)
463 .build_sqlx(PostgresQueryBuilder);
464
465 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
466 .traced()
467 .fetch_one(&mut *self.conn)
468 .await?;
469
470 count
471 .try_into()
472 .map_err(DatabaseError::to_invalid_operation)
473 }
474
475 #[tracing::instrument(
476 name = "db.oauth2_session.record_batch_activity",
477 skip_all,
478 fields(
479 db.query.text,
480 ),
481 err,
482 )]
483 async fn record_batch_activity(
484 &mut self,
485 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
486 ) -> Result<(), Self::Error> {
487 activities.sort_unstable();
490 let mut ids = Vec::with_capacity(activities.len());
491 let mut last_activities = Vec::with_capacity(activities.len());
492 let mut ips = Vec::with_capacity(activities.len());
493
494 for (id, last_activity, ip) in activities {
495 ids.push(Uuid::from(id));
496 last_activities.push(last_activity);
497 ips.push(ip);
498 }
499
500 let res = sqlx::query!(
501 r#"
502 UPDATE oauth2_sessions
503 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
504 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
505 FROM (
506 SELECT *
507 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
508 AS t(oauth2_session_id, last_active_at, last_active_ip)
509 ) AS t
510 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
511 "#,
512 &ids,
513 &last_activities,
514 &ips as &[Option<IpAddr>],
515 )
516 .traced()
517 .execute(&mut *self.conn)
518 .await?;
519
520 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
521
522 Ok(())
523 }
524
525 #[tracing::instrument(
526 name = "db.oauth2_session.record_user_agent",
527 skip_all,
528 fields(
529 db.query.text,
530 %session.id,
531 %session.scope,
532 client.id = %session.client_id,
533 session.user_agent = user_agent,
534 ),
535 err,
536 )]
537 async fn record_user_agent(
538 &mut self,
539 mut session: Session,
540 user_agent: String,
541 ) -> Result<Session, Self::Error> {
542 let res = sqlx::query!(
543 r#"
544 UPDATE oauth2_sessions
545 SET user_agent = $2
546 WHERE oauth2_session_id = $1
547 "#,
548 Uuid::from(session.id),
549 &*user_agent,
550 )
551 .traced()
552 .execute(&mut *self.conn)
553 .await?;
554
555 session.user_agent = Some(user_agent);
556
557 DatabaseError::ensure_affected_rows(&res, 1)?;
558
559 Ok(session)
560 }
561
562 #[tracing::instrument(
563 name = "repository.oauth2_session.set_human_name",
564 skip(self),
565 fields(
566 client.id = %session.client_id,
567 session.human_name = ?human_name,
568 ),
569 err,
570 )]
571 async fn set_human_name(
572 &mut self,
573 mut session: Session,
574 human_name: Option<String>,
575 ) -> Result<Session, Self::Error> {
576 let res = sqlx::query!(
577 r#"
578 UPDATE oauth2_sessions
579 SET human_name = $2
580 WHERE oauth2_session_id = $1
581 "#,
582 Uuid::from(session.id),
583 human_name.as_deref(),
584 )
585 .traced()
586 .execute(&mut *self.conn)
587 .await?;
588
589 session.human_name = human_name;
590
591 DatabaseError::ensure_affected_rows(&res, 1)?;
592
593 Ok(session)
594 }
595}