mas_storage_pg/upstream_oauth2/
link.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{Clock, UpstreamOAuthLink, UpstreamOAuthProvider, User};
10use mas_storage::{
11    Page, Pagination,
12    pagination::Node,
13    upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
14};
15use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use tracing::Instrument;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError,
26    filter::{Filter, StatementExt},
27    iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
33/// connection
34pub struct PgUpstreamOAuthLinkRepository<'c> {
35    conn: &'c mut PgConnection,
36}
37
38impl<'c> PgUpstreamOAuthLinkRepository<'c> {
39    /// Create a new [`PgUpstreamOAuthLinkRepository`] from an active PostgreSQL
40    /// connection
41    pub fn new(conn: &'c mut PgConnection) -> Self {
42        Self { conn }
43    }
44}
45
46#[derive(sqlx::FromRow)]
47#[enum_def]
48struct LinkLookup {
49    upstream_oauth_link_id: Uuid,
50    upstream_oauth_provider_id: Uuid,
51    user_id: Option<Uuid>,
52    subject: String,
53    human_account_name: Option<String>,
54    created_at: DateTime<Utc>,
55}
56
57impl Node<Ulid> for LinkLookup {
58    fn cursor(&self) -> Ulid {
59        self.upstream_oauth_link_id.into()
60    }
61}
62
63impl From<LinkLookup> for UpstreamOAuthLink {
64    fn from(value: LinkLookup) -> Self {
65        UpstreamOAuthLink {
66            id: Ulid::from(value.upstream_oauth_link_id),
67            provider_id: Ulid::from(value.upstream_oauth_provider_id),
68            user_id: value.user_id.map(Ulid::from),
69            subject: value.subject,
70            human_account_name: value.human_account_name,
71            created_at: value.created_at,
72        }
73    }
74}
75
76impl Filter for UpstreamOAuthLinkFilter<'_> {
77    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
78        sea_query::Condition::all()
79            .add_option(self.user().map(|user| {
80                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
81                    .eq(Uuid::from(user.id))
82            }))
83            .add_option(self.provider().map(|provider| {
84                Expr::col((
85                    UpstreamOAuthLinks::Table,
86                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
87                ))
88                .eq(Uuid::from(provider.id))
89            }))
90            .add_option(self.provider_enabled().map(|enabled| {
91                Expr::col((
92                    UpstreamOAuthLinks::Table,
93                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
94                ))
95                .eq(Expr::any(
96                    Query::select()
97                        .expr(Expr::col((
98                            UpstreamOAuthProviders::Table,
99                            UpstreamOAuthProviders::UpstreamOAuthProviderId,
100                        )))
101                        .from(UpstreamOAuthProviders::Table)
102                        .and_where(
103                            Expr::col((
104                                UpstreamOAuthProviders::Table,
105                                UpstreamOAuthProviders::DisabledAt,
106                            ))
107                            .is_null()
108                            .eq(enabled),
109                        )
110                        .take(),
111                ))
112            }))
113            .add_option(self.subject().map(|subject| {
114                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
115            }))
116    }
117}
118
119#[async_trait]
120impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
121    type Error = DatabaseError;
122
123    #[tracing::instrument(
124        name = "db.upstream_oauth_link.lookup",
125        skip_all,
126        fields(
127            db.query.text,
128            upstream_oauth_link.id = %id,
129        ),
130        err,
131    )]
132    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
133        let res = sqlx::query_as!(
134            LinkLookup,
135            r#"
136                SELECT
137                    upstream_oauth_link_id,
138                    upstream_oauth_provider_id,
139                    user_id,
140                    subject,
141                    human_account_name,
142                    created_at
143                FROM upstream_oauth_links
144                WHERE upstream_oauth_link_id = $1
145            "#,
146            Uuid::from(id),
147        )
148        .traced()
149        .fetch_optional(&mut *self.conn)
150        .await?
151        .map(Into::into);
152
153        Ok(res)
154    }
155
156    #[tracing::instrument(
157        name = "db.upstream_oauth_link.find_by_subject",
158        skip_all,
159        fields(
160            db.query.text,
161            upstream_oauth_link.subject = subject,
162            %upstream_oauth_provider.id,
163            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
164            %upstream_oauth_provider.client_id,
165        ),
166        err,
167    )]
168    async fn find_by_subject(
169        &mut self,
170        upstream_oauth_provider: &UpstreamOAuthProvider,
171        subject: &str,
172    ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
173        let res = sqlx::query_as!(
174            LinkLookup,
175            r#"
176                SELECT
177                    upstream_oauth_link_id,
178                    upstream_oauth_provider_id,
179                    user_id,
180                    subject,
181                    human_account_name,
182                    created_at
183                FROM upstream_oauth_links
184                WHERE upstream_oauth_provider_id = $1
185                  AND subject = $2
186            "#,
187            Uuid::from(upstream_oauth_provider.id),
188            subject,
189        )
190        .traced()
191        .fetch_optional(&mut *self.conn)
192        .await?
193        .map(Into::into);
194
195        Ok(res)
196    }
197
198    #[tracing::instrument(
199        name = "db.upstream_oauth_link.add",
200        skip_all,
201        fields(
202            db.query.text,
203            upstream_oauth_link.id,
204            upstream_oauth_link.subject = subject,
205            upstream_oauth_link.human_account_name = human_account_name,
206            %upstream_oauth_provider.id,
207            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
208            %upstream_oauth_provider.client_id,
209        ),
210        err,
211    )]
212    async fn add(
213        &mut self,
214        rng: &mut (dyn RngCore + Send),
215        clock: &dyn Clock,
216        upstream_oauth_provider: &UpstreamOAuthProvider,
217        subject: String,
218        human_account_name: Option<String>,
219    ) -> Result<UpstreamOAuthLink, Self::Error> {
220        let created_at = clock.now();
221        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
222        tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
223
224        sqlx::query!(
225            r#"
226                INSERT INTO upstream_oauth_links (
227                    upstream_oauth_link_id,
228                    upstream_oauth_provider_id,
229                    user_id,
230                    subject,
231                    human_account_name,
232                    created_at
233                ) VALUES ($1, $2, NULL, $3, $4, $5)
234            "#,
235            Uuid::from(id),
236            Uuid::from(upstream_oauth_provider.id),
237            &subject,
238            human_account_name.as_deref(),
239            created_at,
240        )
241        .traced()
242        .execute(&mut *self.conn)
243        .await?;
244
245        Ok(UpstreamOAuthLink {
246            id,
247            provider_id: upstream_oauth_provider.id,
248            user_id: None,
249            subject,
250            human_account_name,
251            created_at,
252        })
253    }
254
255    #[tracing::instrument(
256        name = "db.upstream_oauth_link.associate_to_user",
257        skip_all,
258        fields(
259            db.query.text,
260            %upstream_oauth_link.id,
261            %upstream_oauth_link.subject,
262            %user.id,
263            %user.username,
264        ),
265        err,
266    )]
267    async fn associate_to_user(
268        &mut self,
269        upstream_oauth_link: &UpstreamOAuthLink,
270        user: &User,
271    ) -> Result<(), Self::Error> {
272        sqlx::query!(
273            r#"
274                UPDATE upstream_oauth_links
275                SET user_id = $1
276                WHERE upstream_oauth_link_id = $2
277            "#,
278            Uuid::from(user.id),
279            Uuid::from(upstream_oauth_link.id),
280        )
281        .traced()
282        .execute(&mut *self.conn)
283        .await?;
284
285        Ok(())
286    }
287
288    #[tracing::instrument(
289        name = "db.upstream_oauth_link.list",
290        skip_all,
291        fields(
292            db.query.text,
293        ),
294        err,
295    )]
296    async fn list(
297        &mut self,
298        filter: UpstreamOAuthLinkFilter<'_>,
299        pagination: Pagination,
300    ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
301        let (sql, arguments) = Query::select()
302            .expr_as(
303                Expr::col((
304                    UpstreamOAuthLinks::Table,
305                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
306                )),
307                LinkLookupIden::UpstreamOauthLinkId,
308            )
309            .expr_as(
310                Expr::col((
311                    UpstreamOAuthLinks::Table,
312                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
313                )),
314                LinkLookupIden::UpstreamOauthProviderId,
315            )
316            .expr_as(
317                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
318                LinkLookupIden::UserId,
319            )
320            .expr_as(
321                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
322                LinkLookupIden::Subject,
323            )
324            .expr_as(
325                Expr::col((
326                    UpstreamOAuthLinks::Table,
327                    UpstreamOAuthLinks::HumanAccountName,
328                )),
329                LinkLookupIden::HumanAccountName,
330            )
331            .expr_as(
332                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
333                LinkLookupIden::CreatedAt,
334            )
335            .from(UpstreamOAuthLinks::Table)
336            .apply_filter(filter)
337            .generate_pagination(
338                (
339                    UpstreamOAuthLinks::Table,
340                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
341                ),
342                pagination,
343            )
344            .build_sqlx(PostgresQueryBuilder);
345
346        let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
347            .traced()
348            .fetch_all(&mut *self.conn)
349            .await?;
350
351        let page = pagination.process(edges).map(UpstreamOAuthLink::from);
352
353        Ok(page)
354    }
355
356    #[tracing::instrument(
357        name = "db.upstream_oauth_link.count",
358        skip_all,
359        fields(
360            db.query.text,
361        ),
362        err,
363    )]
364    async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
365        let (sql, arguments) = Query::select()
366            .expr(
367                Expr::col((
368                    UpstreamOAuthLinks::Table,
369                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
370                ))
371                .count(),
372            )
373            .from(UpstreamOAuthLinks::Table)
374            .apply_filter(filter)
375            .build_sqlx(PostgresQueryBuilder);
376
377        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
378            .traced()
379            .fetch_one(&mut *self.conn)
380            .await?;
381
382        count
383            .try_into()
384            .map_err(DatabaseError::to_invalid_operation)
385    }
386
387    #[tracing::instrument(
388        name = "db.upstream_oauth_link.remove",
389        skip_all,
390        fields(
391            db.query.text,
392            upstream_oauth_link.id,
393            upstream_oauth_link.provider_id,
394            %upstream_oauth_link.subject,
395        ),
396        err,
397    )]
398    async fn remove(
399        &mut self,
400        clock: &dyn Clock,
401        upstream_oauth_link: UpstreamOAuthLink,
402    ) -> Result<(), Self::Error> {
403        // Unlink the authorization sessions first, as they have a foreign key
404        // constraint on the links.
405        let span = tracing::info_span!(
406            "db.upstream_oauth_link.remove.unlink",
407            { DB_QUERY_TEXT } = tracing::field::Empty
408        );
409        sqlx::query!(
410            r#"
411                UPDATE upstream_oauth_authorization_sessions SET
412                    upstream_oauth_link_id = NULL,
413                    unlinked_at = $2
414                WHERE upstream_oauth_link_id = $1
415            "#,
416            Uuid::from(upstream_oauth_link.id),
417            clock.now()
418        )
419        .record(&span)
420        .execute(&mut *self.conn)
421        .instrument(span)
422        .await?;
423
424        // Then delete the link itself
425        let span = tracing::info_span!(
426            "db.upstream_oauth_link.remove.delete",
427            { DB_QUERY_TEXT } = tracing::field::Empty
428        );
429        let res = sqlx::query!(
430            r#"
431                DELETE FROM upstream_oauth_links
432                WHERE upstream_oauth_link_id = $1
433            "#,
434            Uuid::from(upstream_oauth_link.id),
435        )
436        .record(&span)
437        .execute(&mut *self.conn)
438        .instrument(span)
439        .await?;
440
441        DatabaseError::ensure_affected_rows(&res, 1)?;
442
443        Ok(())
444    }
445}