1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{Clock, UpstreamOAuthLink, UpstreamOAuthProvider, User};
10use mas_storage::{
11 Page, Pagination,
12 pagination::Node,
13 upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
14};
15use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use tracing::Instrument;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError,
26 filter::{Filter, StatementExt},
27 iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32pub struct PgUpstreamOAuthLinkRepository<'c> {
35 conn: &'c mut PgConnection,
36}
37
38impl<'c> PgUpstreamOAuthLinkRepository<'c> {
39 pub fn new(conn: &'c mut PgConnection) -> Self {
42 Self { conn }
43 }
44}
45
/// Row shape used when reading from the `upstream_oauth_links` table.
///
/// Rows are decoded into this struct and then converted into an
/// [`UpstreamOAuthLink`]. The `#[enum_def]` attribute generates the
/// `LinkLookupIden` enum that `list` uses to alias the selected columns.
#[derive(sqlx::FromRow)]
#[enum_def]
struct LinkLookup {
    /// Identifier of the link; also used as the pagination cursor.
    upstream_oauth_link_id: Uuid,
    /// Identifier of the upstream provider the link belongs to.
    upstream_oauth_provider_id: Uuid,
    /// Identifier of the associated user; `None` while the link is not
    /// associated to any user.
    user_id: Option<Uuid>,
    /// Subject stored for this link.
    subject: String,
    /// Optional human-readable account name stored for this link.
    human_account_name: Option<String>,
    /// Creation timestamp of the link.
    created_at: DateTime<Utc>,
}
56
57impl Node<Ulid> for LinkLookup {
58 fn cursor(&self) -> Ulid {
59 self.upstream_oauth_link_id.into()
60 }
61}
62
63impl From<LinkLookup> for UpstreamOAuthLink {
64 fn from(value: LinkLookup) -> Self {
65 UpstreamOAuthLink {
66 id: Ulid::from(value.upstream_oauth_link_id),
67 provider_id: Ulid::from(value.upstream_oauth_provider_id),
68 user_id: value.user_id.map(Ulid::from),
69 subject: value.subject,
70 human_account_name: value.human_account_name,
71 created_at: value.created_at,
72 }
73 }
74}
75
76impl Filter for UpstreamOAuthLinkFilter<'_> {
77 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
78 sea_query::Condition::all()
79 .add_option(self.user().map(|user| {
80 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
81 .eq(Uuid::from(user.id))
82 }))
83 .add_option(self.provider().map(|provider| {
84 Expr::col((
85 UpstreamOAuthLinks::Table,
86 UpstreamOAuthLinks::UpstreamOAuthProviderId,
87 ))
88 .eq(Uuid::from(provider.id))
89 }))
90 .add_option(self.provider_enabled().map(|enabled| {
91 Expr::col((
92 UpstreamOAuthLinks::Table,
93 UpstreamOAuthLinks::UpstreamOAuthProviderId,
94 ))
95 .eq(Expr::any(
96 Query::select()
97 .expr(Expr::col((
98 UpstreamOAuthProviders::Table,
99 UpstreamOAuthProviders::UpstreamOAuthProviderId,
100 )))
101 .from(UpstreamOAuthProviders::Table)
102 .and_where(
103 Expr::col((
104 UpstreamOAuthProviders::Table,
105 UpstreamOAuthProviders::DisabledAt,
106 ))
107 .is_null()
108 .eq(enabled),
109 )
110 .take(),
111 ))
112 }))
113 .add_option(self.subject().map(|subject| {
114 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
115 }))
116 }
117}
118
119#[async_trait]
120impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
121 type Error = DatabaseError;
122
123 #[tracing::instrument(
124 name = "db.upstream_oauth_link.lookup",
125 skip_all,
126 fields(
127 db.query.text,
128 upstream_oauth_link.id = %id,
129 ),
130 err,
131 )]
132 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
133 let res = sqlx::query_as!(
134 LinkLookup,
135 r#"
136 SELECT
137 upstream_oauth_link_id,
138 upstream_oauth_provider_id,
139 user_id,
140 subject,
141 human_account_name,
142 created_at
143 FROM upstream_oauth_links
144 WHERE upstream_oauth_link_id = $1
145 "#,
146 Uuid::from(id),
147 )
148 .traced()
149 .fetch_optional(&mut *self.conn)
150 .await?
151 .map(Into::into);
152
153 Ok(res)
154 }
155
156 #[tracing::instrument(
157 name = "db.upstream_oauth_link.find_by_subject",
158 skip_all,
159 fields(
160 db.query.text,
161 upstream_oauth_link.subject = subject,
162 %upstream_oauth_provider.id,
163 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
164 %upstream_oauth_provider.client_id,
165 ),
166 err,
167 )]
168 async fn find_by_subject(
169 &mut self,
170 upstream_oauth_provider: &UpstreamOAuthProvider,
171 subject: &str,
172 ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
173 let res = sqlx::query_as!(
174 LinkLookup,
175 r#"
176 SELECT
177 upstream_oauth_link_id,
178 upstream_oauth_provider_id,
179 user_id,
180 subject,
181 human_account_name,
182 created_at
183 FROM upstream_oauth_links
184 WHERE upstream_oauth_provider_id = $1
185 AND subject = $2
186 "#,
187 Uuid::from(upstream_oauth_provider.id),
188 subject,
189 )
190 .traced()
191 .fetch_optional(&mut *self.conn)
192 .await?
193 .map(Into::into);
194
195 Ok(res)
196 }
197
198 #[tracing::instrument(
199 name = "db.upstream_oauth_link.add",
200 skip_all,
201 fields(
202 db.query.text,
203 upstream_oauth_link.id,
204 upstream_oauth_link.subject = subject,
205 upstream_oauth_link.human_account_name = human_account_name,
206 %upstream_oauth_provider.id,
207 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
208 %upstream_oauth_provider.client_id,
209 ),
210 err,
211 )]
212 async fn add(
213 &mut self,
214 rng: &mut (dyn RngCore + Send),
215 clock: &dyn Clock,
216 upstream_oauth_provider: &UpstreamOAuthProvider,
217 subject: String,
218 human_account_name: Option<String>,
219 ) -> Result<UpstreamOAuthLink, Self::Error> {
220 let created_at = clock.now();
221 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
222 tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
223
224 sqlx::query!(
225 r#"
226 INSERT INTO upstream_oauth_links (
227 upstream_oauth_link_id,
228 upstream_oauth_provider_id,
229 user_id,
230 subject,
231 human_account_name,
232 created_at
233 ) VALUES ($1, $2, NULL, $3, $4, $5)
234 "#,
235 Uuid::from(id),
236 Uuid::from(upstream_oauth_provider.id),
237 &subject,
238 human_account_name.as_deref(),
239 created_at,
240 )
241 .traced()
242 .execute(&mut *self.conn)
243 .await?;
244
245 Ok(UpstreamOAuthLink {
246 id,
247 provider_id: upstream_oauth_provider.id,
248 user_id: None,
249 subject,
250 human_account_name,
251 created_at,
252 })
253 }
254
255 #[tracing::instrument(
256 name = "db.upstream_oauth_link.associate_to_user",
257 skip_all,
258 fields(
259 db.query.text,
260 %upstream_oauth_link.id,
261 %upstream_oauth_link.subject,
262 %user.id,
263 %user.username,
264 ),
265 err,
266 )]
267 async fn associate_to_user(
268 &mut self,
269 upstream_oauth_link: &UpstreamOAuthLink,
270 user: &User,
271 ) -> Result<(), Self::Error> {
272 sqlx::query!(
273 r#"
274 UPDATE upstream_oauth_links
275 SET user_id = $1
276 WHERE upstream_oauth_link_id = $2
277 "#,
278 Uuid::from(user.id),
279 Uuid::from(upstream_oauth_link.id),
280 )
281 .traced()
282 .execute(&mut *self.conn)
283 .await?;
284
285 Ok(())
286 }
287
288 #[tracing::instrument(
289 name = "db.upstream_oauth_link.list",
290 skip_all,
291 fields(
292 db.query.text,
293 ),
294 err,
295 )]
296 async fn list(
297 &mut self,
298 filter: UpstreamOAuthLinkFilter<'_>,
299 pagination: Pagination,
300 ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
301 let (sql, arguments) = Query::select()
302 .expr_as(
303 Expr::col((
304 UpstreamOAuthLinks::Table,
305 UpstreamOAuthLinks::UpstreamOAuthLinkId,
306 )),
307 LinkLookupIden::UpstreamOauthLinkId,
308 )
309 .expr_as(
310 Expr::col((
311 UpstreamOAuthLinks::Table,
312 UpstreamOAuthLinks::UpstreamOAuthProviderId,
313 )),
314 LinkLookupIden::UpstreamOauthProviderId,
315 )
316 .expr_as(
317 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
318 LinkLookupIden::UserId,
319 )
320 .expr_as(
321 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
322 LinkLookupIden::Subject,
323 )
324 .expr_as(
325 Expr::col((
326 UpstreamOAuthLinks::Table,
327 UpstreamOAuthLinks::HumanAccountName,
328 )),
329 LinkLookupIden::HumanAccountName,
330 )
331 .expr_as(
332 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
333 LinkLookupIden::CreatedAt,
334 )
335 .from(UpstreamOAuthLinks::Table)
336 .apply_filter(filter)
337 .generate_pagination(
338 (
339 UpstreamOAuthLinks::Table,
340 UpstreamOAuthLinks::UpstreamOAuthLinkId,
341 ),
342 pagination,
343 )
344 .build_sqlx(PostgresQueryBuilder);
345
346 let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
347 .traced()
348 .fetch_all(&mut *self.conn)
349 .await?;
350
351 let page = pagination.process(edges).map(UpstreamOAuthLink::from);
352
353 Ok(page)
354 }
355
356 #[tracing::instrument(
357 name = "db.upstream_oauth_link.count",
358 skip_all,
359 fields(
360 db.query.text,
361 ),
362 err,
363 )]
364 async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
365 let (sql, arguments) = Query::select()
366 .expr(
367 Expr::col((
368 UpstreamOAuthLinks::Table,
369 UpstreamOAuthLinks::UpstreamOAuthLinkId,
370 ))
371 .count(),
372 )
373 .from(UpstreamOAuthLinks::Table)
374 .apply_filter(filter)
375 .build_sqlx(PostgresQueryBuilder);
376
377 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
378 .traced()
379 .fetch_one(&mut *self.conn)
380 .await?;
381
382 count
383 .try_into()
384 .map_err(DatabaseError::to_invalid_operation)
385 }
386
387 #[tracing::instrument(
388 name = "db.upstream_oauth_link.remove",
389 skip_all,
390 fields(
391 db.query.text,
392 upstream_oauth_link.id,
393 upstream_oauth_link.provider_id,
394 %upstream_oauth_link.subject,
395 ),
396 err,
397 )]
398 async fn remove(
399 &mut self,
400 clock: &dyn Clock,
401 upstream_oauth_link: UpstreamOAuthLink,
402 ) -> Result<(), Self::Error> {
403 let span = tracing::info_span!(
406 "db.upstream_oauth_link.remove.unlink",
407 { DB_QUERY_TEXT } = tracing::field::Empty
408 );
409 sqlx::query!(
410 r#"
411 UPDATE upstream_oauth_authorization_sessions SET
412 upstream_oauth_link_id = NULL,
413 unlinked_at = $2
414 WHERE upstream_oauth_link_id = $1
415 "#,
416 Uuid::from(upstream_oauth_link.id),
417 clock.now()
418 )
419 .record(&span)
420 .execute(&mut *self.conn)
421 .instrument(span)
422 .await?;
423
424 let span = tracing::info_span!(
426 "db.upstream_oauth_link.remove.delete",
427 { DB_QUERY_TEXT } = tracing::field::Empty
428 );
429 let res = sqlx::query!(
430 r#"
431 DELETE FROM upstream_oauth_links
432 WHERE upstream_oauth_link_id = $1
433 "#,
434 Uuid::from(upstream_oauth_link.id),
435 )
436 .record(&span)
437 .execute(&mut *self.conn)
438 .instrument(span)
439 .await?;
440
441 DatabaseError::ensure_affected_rows(&res, 1)?;
442
443 Ok(())
444 }
445}