mas_storage_pg/upstream_oauth2/provider.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{Clock, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11    Page, Pagination,
12    pagination::Node,
13    upstream_oauth2::{
14        UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
15    },
16};
17use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
18use rand::RngCore;
19use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
20use sea_query_binder::SqlxBinder;
21use sqlx::{PgConnection, types::Json};
22use tracing::{Instrument, info_span};
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27    DatabaseError, DatabaseInconsistencyError,
28    filter::{Filter, StatementExt},
29    iden::UpstreamOAuthProviders,
30    pagination::QueryBuilderExt,
31    tracing::ExecuteExt,
32};
33
34/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
35/// connection
36pub struct PgUpstreamOAuthProviderRepository<'c> {
37    conn: &'c mut PgConnection,
38}
39
40impl<'c> PgUpstreamOAuthProviderRepository<'c> {
41    /// Create a new [`PgUpstreamOAuthProviderRepository`] from an active
42    /// PostgreSQL connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
/// One raw row of the `upstream_oauth_providers` table.
///
/// Most enum-like columns are kept as plain text here and only parsed when the
/// row is converted into an [`UpstreamOAuthProvider`] through `TryFrom`.
/// `#[enum_def]` generates `ProviderLookupIden`, which the `sea_query`-based
/// listing uses to alias selected columns onto these exact field names, so
/// field names must stay in sync with that query.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    // Primary key; converted into the model `id` and used as the pagination cursor.
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Text-encoded scope, parsed during conversion.
    scope: String,
    client_id: String,
    // Copied through unchanged; nothing in this struct decrypts it.
    encrypted_client_secret: Option<String>,
    // Algorithm / auth-method columns below are parsed during conversion.
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `None` while the provider is enabled; set by `disable`, cleared by `upsert`.
    disabled_at: Option<DateTime<Utc>>,
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    // Optional endpoint overrides, parsed during conversion.
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    // Text-encoded modes, parsed during conversion.
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // JSON list of (name, value) pairs; a NULL column becomes an empty list.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    // Text-encoded, parsed during conversion.
    on_backchannel_logout: String,
}
77
78impl Node<Ulid> for ProviderLookup {
79    fn cursor(&self) -> Ulid {
80        self.upstream_oauth_provider_id.into()
81    }
82}
83
84impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
85    type Error = DatabaseInconsistencyError;
86
87    fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
88        let id = value.upstream_oauth_provider_id.into();
89        let scope = value.scope.parse().map_err(|e| {
90            DatabaseInconsistencyError::on("upstream_oauth_providers")
91                .column("scope")
92                .row(id)
93                .source(e)
94        })?;
95        let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
96            DatabaseInconsistencyError::on("upstream_oauth_providers")
97                .column("token_endpoint_auth_method")
98                .row(id)
99                .source(e)
100        })?;
101        let token_endpoint_signing_alg = value
102            .token_endpoint_signing_alg
103            .map(|x| x.parse())
104            .transpose()
105            .map_err(|e| {
106                DatabaseInconsistencyError::on("upstream_oauth_providers")
107                    .column("token_endpoint_signing_alg")
108                    .row(id)
109                    .source(e)
110            })?;
111        let id_token_signed_response_alg =
112            value.id_token_signed_response_alg.parse().map_err(|e| {
113                DatabaseInconsistencyError::on("upstream_oauth_providers")
114                    .column("id_token_signed_response_alg")
115                    .row(id)
116                    .source(e)
117            })?;
118
119        let userinfo_signed_response_alg = value
120            .userinfo_signed_response_alg
121            .map(|x| x.parse())
122            .transpose()
123            .map_err(|e| {
124                DatabaseInconsistencyError::on("upstream_oauth_providers")
125                    .column("userinfo_signed_response_alg")
126                    .row(id)
127                    .source(e)
128            })?;
129
130        let authorization_endpoint_override = value
131            .authorization_endpoint_override
132            .map(|x| x.parse())
133            .transpose()
134            .map_err(|e| {
135                DatabaseInconsistencyError::on("upstream_oauth_providers")
136                    .column("authorization_endpoint_override")
137                    .row(id)
138                    .source(e)
139            })?;
140
141        let token_endpoint_override = value
142            .token_endpoint_override
143            .map(|x| x.parse())
144            .transpose()
145            .map_err(|e| {
146                DatabaseInconsistencyError::on("upstream_oauth_providers")
147                    .column("token_endpoint_override")
148                    .row(id)
149                    .source(e)
150            })?;
151
152        let userinfo_endpoint_override = value
153            .userinfo_endpoint_override
154            .map(|x| x.parse())
155            .transpose()
156            .map_err(|e| {
157                DatabaseInconsistencyError::on("upstream_oauth_providers")
158                    .column("userinfo_endpoint_override")
159                    .row(id)
160                    .source(e)
161            })?;
162
163        let jwks_uri_override = value
164            .jwks_uri_override
165            .map(|x| x.parse())
166            .transpose()
167            .map_err(|e| {
168                DatabaseInconsistencyError::on("upstream_oauth_providers")
169                    .column("jwks_uri_override")
170                    .row(id)
171                    .source(e)
172            })?;
173
174        let discovery_mode = value.discovery_mode.parse().map_err(|e| {
175            DatabaseInconsistencyError::on("upstream_oauth_providers")
176                .column("discovery_mode")
177                .row(id)
178                .source(e)
179        })?;
180
181        let pkce_mode = value.pkce_mode.parse().map_err(|e| {
182            DatabaseInconsistencyError::on("upstream_oauth_providers")
183                .column("pkce_mode")
184                .row(id)
185                .source(e)
186        })?;
187
188        let response_mode = value
189            .response_mode
190            .map(|x| x.parse())
191            .transpose()
192            .map_err(|e| {
193                DatabaseInconsistencyError::on("upstream_oauth_providers")
194                    .column("response_mode")
195                    .row(id)
196                    .source(e)
197            })?;
198
199        let additional_authorization_parameters = value
200            .additional_parameters
201            .map(|Json(x)| x)
202            .unwrap_or_default();
203
204        let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
205            DatabaseInconsistencyError::on("upstream_oauth_providers")
206                .column("on_backchannel_logout")
207                .row(id)
208                .source(e)
209        })?;
210
211        Ok(UpstreamOAuthProvider {
212            id,
213            issuer: value.issuer,
214            human_name: value.human_name,
215            brand_name: value.brand_name,
216            scope,
217            client_id: value.client_id,
218            encrypted_client_secret: value.encrypted_client_secret,
219            token_endpoint_auth_method,
220            token_endpoint_signing_alg,
221            id_token_signed_response_alg,
222            fetch_userinfo: value.fetch_userinfo,
223            userinfo_signed_response_alg,
224            created_at: value.created_at,
225            disabled_at: value.disabled_at,
226            claims_imports: value.claims_imports.0,
227            authorization_endpoint_override,
228            token_endpoint_override,
229            userinfo_endpoint_override,
230            jwks_uri_override,
231            discovery_mode,
232            pkce_mode,
233            response_mode,
234            additional_authorization_parameters,
235            forward_login_hint: value.forward_login_hint,
236            on_backchannel_logout,
237        })
238    }
239}
240
241impl Filter for UpstreamOAuthProviderFilter<'_> {
242    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
243        sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
244            Expr::col((
245                UpstreamOAuthProviders::Table,
246                UpstreamOAuthProviders::DisabledAt,
247            ))
248            .is_null()
249            .eq(enabled)
250        }))
251    }
252}
253
254#[async_trait]
255impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
256    type Error = DatabaseError;
257
258    #[tracing::instrument(
259        name = "db.upstream_oauth_provider.lookup",
260        skip_all,
261        fields(
262            db.query.text,
263            upstream_oauth_provider.id = %id,
264        ),
265        err,
266    )]
267    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
268        let res = sqlx::query_as!(
269            ProviderLookup,
270            r#"
271                SELECT
272                    upstream_oauth_provider_id,
273                    issuer,
274                    human_name,
275                    brand_name,
276                    scope,
277                    client_id,
278                    encrypted_client_secret,
279                    token_endpoint_signing_alg,
280                    token_endpoint_auth_method,
281                    id_token_signed_response_alg,
282                    fetch_userinfo,
283                    userinfo_signed_response_alg,
284                    created_at,
285                    disabled_at,
286                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
287                    jwks_uri_override,
288                    authorization_endpoint_override,
289                    token_endpoint_override,
290                    userinfo_endpoint_override,
291                    discovery_mode,
292                    pkce_mode,
293                    response_mode,
294                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
295                    forward_login_hint,
296                    on_backchannel_logout
297                FROM upstream_oauth_providers
298                WHERE upstream_oauth_provider_id = $1
299            "#,
300            Uuid::from(id),
301        )
302        .traced()
303        .fetch_optional(&mut *self.conn)
304        .await?;
305
306        let res = res
307            .map(UpstreamOAuthProvider::try_from)
308            .transpose()
309            .map_err(DatabaseError::from)?;
310
311        Ok(res)
312    }
313
314    #[tracing::instrument(
315        name = "db.upstream_oauth_provider.add",
316        skip_all,
317        fields(
318            db.query.text,
319            upstream_oauth_provider.id,
320            upstream_oauth_provider.issuer = params.issuer,
321            upstream_oauth_provider.client_id = %params.client_id,
322        ),
323        err,
324    )]
325    async fn add(
326        &mut self,
327        rng: &mut (dyn RngCore + Send),
328        clock: &dyn Clock,
329        params: UpstreamOAuthProviderParams,
330    ) -> Result<UpstreamOAuthProvider, Self::Error> {
331        let created_at = clock.now();
332        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
333        tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
334
335        sqlx::query!(
336            r#"
337            INSERT INTO upstream_oauth_providers (
338                upstream_oauth_provider_id,
339                issuer,
340                human_name,
341                brand_name,
342                scope,
343                token_endpoint_auth_method,
344                token_endpoint_signing_alg,
345                id_token_signed_response_alg,
346                fetch_userinfo,
347                userinfo_signed_response_alg,
348                client_id,
349                encrypted_client_secret,
350                claims_imports,
351                authorization_endpoint_override,
352                token_endpoint_override,
353                userinfo_endpoint_override,
354                jwks_uri_override,
355                discovery_mode,
356                pkce_mode,
357                response_mode,
358                forward_login_hint,
359                on_backchannel_logout,
360                created_at
361            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
362                      $12, $13, $14, $15, $16, $17, $18, $19, $20,
363                      $21, $22, $23)
364        "#,
365            Uuid::from(id),
366            params.issuer.as_deref(),
367            params.human_name.as_deref(),
368            params.brand_name.as_deref(),
369            params.scope.to_string(),
370            params.token_endpoint_auth_method.to_string(),
371            params
372                .token_endpoint_signing_alg
373                .as_ref()
374                .map(ToString::to_string),
375            params.id_token_signed_response_alg.to_string(),
376            params.fetch_userinfo,
377            params
378                .userinfo_signed_response_alg
379                .as_ref()
380                .map(ToString::to_string),
381            &params.client_id,
382            params.encrypted_client_secret.as_deref(),
383            Json(&params.claims_imports) as _,
384            params
385                .authorization_endpoint_override
386                .as_ref()
387                .map(ToString::to_string),
388            params
389                .token_endpoint_override
390                .as_ref()
391                .map(ToString::to_string),
392            params
393                .userinfo_endpoint_override
394                .as_ref()
395                .map(ToString::to_string),
396            params.jwks_uri_override.as_ref().map(ToString::to_string),
397            params.discovery_mode.as_str(),
398            params.pkce_mode.as_str(),
399            params.response_mode.as_ref().map(ToString::to_string),
400            params.forward_login_hint,
401            params.on_backchannel_logout.as_str(),
402            created_at,
403        )
404        .traced()
405        .execute(&mut *self.conn)
406        .await?;
407
408        Ok(UpstreamOAuthProvider {
409            id,
410            issuer: params.issuer,
411            human_name: params.human_name,
412            brand_name: params.brand_name,
413            scope: params.scope,
414            client_id: params.client_id,
415            encrypted_client_secret: params.encrypted_client_secret,
416            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
417            token_endpoint_auth_method: params.token_endpoint_auth_method,
418            id_token_signed_response_alg: params.id_token_signed_response_alg,
419            fetch_userinfo: params.fetch_userinfo,
420            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
421            created_at,
422            disabled_at: None,
423            claims_imports: params.claims_imports,
424            authorization_endpoint_override: params.authorization_endpoint_override,
425            token_endpoint_override: params.token_endpoint_override,
426            userinfo_endpoint_override: params.userinfo_endpoint_override,
427            jwks_uri_override: params.jwks_uri_override,
428            discovery_mode: params.discovery_mode,
429            pkce_mode: params.pkce_mode,
430            response_mode: params.response_mode,
431            additional_authorization_parameters: params.additional_authorization_parameters,
432            on_backchannel_logout: params.on_backchannel_logout,
433            forward_login_hint: params.forward_login_hint,
434        })
435    }
436
437    #[tracing::instrument(
438        name = "db.upstream_oauth_provider.delete_by_id",
439        skip_all,
440        fields(
441            db.query.text,
442            upstream_oauth_provider.id = %id,
443        ),
444        err,
445    )]
446    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
447        // Delete the authorization sessions first, as they have a foreign key
448        // constraint on the links and the providers.
449        {
450            let span = info_span!(
451                "db.oauth2_client.delete_by_id.authorization_sessions",
452                upstream_oauth_provider.id = %id,
453                { DB_QUERY_TEXT } = tracing::field::Empty,
454            );
455            sqlx::query!(
456                r#"
457                    DELETE FROM upstream_oauth_authorization_sessions
458                    WHERE upstream_oauth_provider_id = $1
459                "#,
460                Uuid::from(id),
461            )
462            .record(&span)
463            .execute(&mut *self.conn)
464            .instrument(span)
465            .await?;
466        }
467
468        // Delete the links next, as they have a foreign key constraint on the
469        // providers.
470        {
471            let span = info_span!(
472                "db.oauth2_client.delete_by_id.links",
473                upstream_oauth_provider.id = %id,
474                { DB_QUERY_TEXT } = tracing::field::Empty,
475            );
476            sqlx::query!(
477                r#"
478                    DELETE FROM upstream_oauth_links
479                    WHERE upstream_oauth_provider_id = $1
480                "#,
481                Uuid::from(id),
482            )
483            .record(&span)
484            .execute(&mut *self.conn)
485            .instrument(span)
486            .await?;
487        }
488
489        let res = sqlx::query!(
490            r#"
491                DELETE FROM upstream_oauth_providers
492                WHERE upstream_oauth_provider_id = $1
493            "#,
494            Uuid::from(id),
495        )
496        .traced()
497        .execute(&mut *self.conn)
498        .await?;
499
500        DatabaseError::ensure_affected_rows(&res, 1)
501    }
502
503    #[tracing::instrument(
504        name = "db.upstream_oauth_provider.add",
505        skip_all,
506        fields(
507            db.query.text,
508            upstream_oauth_provider.id = %id,
509            upstream_oauth_provider.issuer = params.issuer,
510            upstream_oauth_provider.client_id = %params.client_id,
511        ),
512        err,
513    )]
514    async fn upsert(
515        &mut self,
516        clock: &dyn Clock,
517        id: Ulid,
518        params: UpstreamOAuthProviderParams,
519    ) -> Result<UpstreamOAuthProvider, Self::Error> {
520        let created_at = clock.now();
521
522        let created_at = sqlx::query_scalar!(
523            r#"
524                INSERT INTO upstream_oauth_providers (
525                    upstream_oauth_provider_id,
526                    issuer,
527                    human_name,
528                    brand_name,
529                    scope,
530                    token_endpoint_auth_method,
531                    token_endpoint_signing_alg,
532                    id_token_signed_response_alg,
533                    fetch_userinfo,
534                    userinfo_signed_response_alg,
535                    client_id,
536                    encrypted_client_secret,
537                    claims_imports,
538                    authorization_endpoint_override,
539                    token_endpoint_override,
540                    userinfo_endpoint_override,
541                    jwks_uri_override,
542                    discovery_mode,
543                    pkce_mode,
544                    response_mode,
545                    additional_parameters,
546                    forward_login_hint,
547                    ui_order,
548                    on_backchannel_logout,
549                    created_at
550                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
551                          $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
552                          $21, $22, $23, $24, $25)
553                ON CONFLICT (upstream_oauth_provider_id)
554                    DO UPDATE
555                    SET
556                        issuer = EXCLUDED.issuer,
557                        human_name = EXCLUDED.human_name,
558                        brand_name = EXCLUDED.brand_name,
559                        scope = EXCLUDED.scope,
560                        token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
561                        token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
562                        id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
563                        fetch_userinfo = EXCLUDED.fetch_userinfo,
564                        userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
565                        disabled_at = NULL,
566                        client_id = EXCLUDED.client_id,
567                        encrypted_client_secret = EXCLUDED.encrypted_client_secret,
568                        claims_imports = EXCLUDED.claims_imports,
569                        authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
570                        token_endpoint_override = EXCLUDED.token_endpoint_override,
571                        userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
572                        jwks_uri_override = EXCLUDED.jwks_uri_override,
573                        discovery_mode = EXCLUDED.discovery_mode,
574                        pkce_mode = EXCLUDED.pkce_mode,
575                        response_mode = EXCLUDED.response_mode,
576                        additional_parameters = EXCLUDED.additional_parameters,
577                        forward_login_hint = EXCLUDED.forward_login_hint,
578                        ui_order = EXCLUDED.ui_order,
579                        on_backchannel_logout = EXCLUDED.on_backchannel_logout
580                RETURNING created_at
581            "#,
582            Uuid::from(id),
583            params.issuer.as_deref(),
584            params.human_name.as_deref(),
585            params.brand_name.as_deref(),
586            params.scope.to_string(),
587            params.token_endpoint_auth_method.to_string(),
588            params
589                .token_endpoint_signing_alg
590                .as_ref()
591                .map(ToString::to_string),
592            params.id_token_signed_response_alg.to_string(),
593            params.fetch_userinfo,
594            params
595                .userinfo_signed_response_alg
596                .as_ref()
597                .map(ToString::to_string),
598            &params.client_id,
599            params.encrypted_client_secret.as_deref(),
600            Json(&params.claims_imports) as _,
601            params
602                .authorization_endpoint_override
603                .as_ref()
604                .map(ToString::to_string),
605            params
606                .token_endpoint_override
607                .as_ref()
608                .map(ToString::to_string),
609            params
610                .userinfo_endpoint_override
611                .as_ref()
612                .map(ToString::to_string),
613            params.jwks_uri_override.as_ref().map(ToString::to_string),
614            params.discovery_mode.as_str(),
615            params.pkce_mode.as_str(),
616            params.response_mode.as_ref().map(ToString::to_string),
617            Json(&params.additional_authorization_parameters) as _,
618            params.forward_login_hint,
619            params.ui_order,
620            params.on_backchannel_logout.as_str(),
621            created_at,
622        )
623        .traced()
624        .fetch_one(&mut *self.conn)
625        .await?;
626
627        Ok(UpstreamOAuthProvider {
628            id,
629            issuer: params.issuer,
630            human_name: params.human_name,
631            brand_name: params.brand_name,
632            scope: params.scope,
633            client_id: params.client_id,
634            encrypted_client_secret: params.encrypted_client_secret,
635            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
636            token_endpoint_auth_method: params.token_endpoint_auth_method,
637            id_token_signed_response_alg: params.id_token_signed_response_alg,
638            fetch_userinfo: params.fetch_userinfo,
639            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
640            created_at,
641            disabled_at: None,
642            claims_imports: params.claims_imports,
643            authorization_endpoint_override: params.authorization_endpoint_override,
644            token_endpoint_override: params.token_endpoint_override,
645            userinfo_endpoint_override: params.userinfo_endpoint_override,
646            jwks_uri_override: params.jwks_uri_override,
647            discovery_mode: params.discovery_mode,
648            pkce_mode: params.pkce_mode,
649            response_mode: params.response_mode,
650            additional_authorization_parameters: params.additional_authorization_parameters,
651            forward_login_hint: params.forward_login_hint,
652            on_backchannel_logout: params.on_backchannel_logout,
653        })
654    }
655
656    #[tracing::instrument(
657        name = "db.upstream_oauth_provider.disable",
658        skip_all,
659        fields(
660            db.query.text,
661            %upstream_oauth_provider.id,
662        ),
663        err,
664    )]
665    async fn disable(
666        &mut self,
667        clock: &dyn Clock,
668        mut upstream_oauth_provider: UpstreamOAuthProvider,
669    ) -> Result<UpstreamOAuthProvider, Self::Error> {
670        let disabled_at = clock.now();
671        let res = sqlx::query!(
672            r#"
673                UPDATE upstream_oauth_providers
674                SET disabled_at = $2
675                WHERE upstream_oauth_provider_id = $1
676            "#,
677            Uuid::from(upstream_oauth_provider.id),
678            disabled_at,
679        )
680        .traced()
681        .execute(&mut *self.conn)
682        .await?;
683
684        DatabaseError::ensure_affected_rows(&res, 1)?;
685
686        upstream_oauth_provider.disabled_at = Some(disabled_at);
687
688        Ok(upstream_oauth_provider)
689    }
690
691    #[tracing::instrument(
692        name = "db.upstream_oauth_provider.list",
693        skip_all,
694        fields(
695            db.query.text,
696        ),
697        err,
698    )]
699    async fn list(
700        &mut self,
701        filter: UpstreamOAuthProviderFilter<'_>,
702        pagination: Pagination,
703    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
704        let (sql, arguments) = Query::select()
705            .expr_as(
706                Expr::col((
707                    UpstreamOAuthProviders::Table,
708                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
709                )),
710                ProviderLookupIden::UpstreamOauthProviderId,
711            )
712            .expr_as(
713                Expr::col((
714                    UpstreamOAuthProviders::Table,
715                    UpstreamOAuthProviders::Issuer,
716                )),
717                ProviderLookupIden::Issuer,
718            )
719            .expr_as(
720                Expr::col((
721                    UpstreamOAuthProviders::Table,
722                    UpstreamOAuthProviders::HumanName,
723                )),
724                ProviderLookupIden::HumanName,
725            )
726            .expr_as(
727                Expr::col((
728                    UpstreamOAuthProviders::Table,
729                    UpstreamOAuthProviders::BrandName,
730                )),
731                ProviderLookupIden::BrandName,
732            )
733            .expr_as(
734                Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
735                ProviderLookupIden::Scope,
736            )
737            .expr_as(
738                Expr::col((
739                    UpstreamOAuthProviders::Table,
740                    UpstreamOAuthProviders::ClientId,
741                )),
742                ProviderLookupIden::ClientId,
743            )
744            .expr_as(
745                Expr::col((
746                    UpstreamOAuthProviders::Table,
747                    UpstreamOAuthProviders::EncryptedClientSecret,
748                )),
749                ProviderLookupIden::EncryptedClientSecret,
750            )
751            .expr_as(
752                Expr::col((
753                    UpstreamOAuthProviders::Table,
754                    UpstreamOAuthProviders::TokenEndpointSigningAlg,
755                )),
756                ProviderLookupIden::TokenEndpointSigningAlg,
757            )
758            .expr_as(
759                Expr::col((
760                    UpstreamOAuthProviders::Table,
761                    UpstreamOAuthProviders::TokenEndpointAuthMethod,
762                )),
763                ProviderLookupIden::TokenEndpointAuthMethod,
764            )
765            .expr_as(
766                Expr::col((
767                    UpstreamOAuthProviders::Table,
768                    UpstreamOAuthProviders::IdTokenSignedResponseAlg,
769                )),
770                ProviderLookupIden::IdTokenSignedResponseAlg,
771            )
772            .expr_as(
773                Expr::col((
774                    UpstreamOAuthProviders::Table,
775                    UpstreamOAuthProviders::FetchUserinfo,
776                )),
777                ProviderLookupIden::FetchUserinfo,
778            )
779            .expr_as(
780                Expr::col((
781                    UpstreamOAuthProviders::Table,
782                    UpstreamOAuthProviders::UserinfoSignedResponseAlg,
783                )),
784                ProviderLookupIden::UserinfoSignedResponseAlg,
785            )
786            .expr_as(
787                Expr::col((
788                    UpstreamOAuthProviders::Table,
789                    UpstreamOAuthProviders::CreatedAt,
790                )),
791                ProviderLookupIden::CreatedAt,
792            )
793            .expr_as(
794                Expr::col((
795                    UpstreamOAuthProviders::Table,
796                    UpstreamOAuthProviders::DisabledAt,
797                )),
798                ProviderLookupIden::DisabledAt,
799            )
800            .expr_as(
801                Expr::col((
802                    UpstreamOAuthProviders::Table,
803                    UpstreamOAuthProviders::ClaimsImports,
804                )),
805                ProviderLookupIden::ClaimsImports,
806            )
807            .expr_as(
808                Expr::col((
809                    UpstreamOAuthProviders::Table,
810                    UpstreamOAuthProviders::JwksUriOverride,
811                )),
812                ProviderLookupIden::JwksUriOverride,
813            )
814            .expr_as(
815                Expr::col((
816                    UpstreamOAuthProviders::Table,
817                    UpstreamOAuthProviders::TokenEndpointOverride,
818                )),
819                ProviderLookupIden::TokenEndpointOverride,
820            )
821            .expr_as(
822                Expr::col((
823                    UpstreamOAuthProviders::Table,
824                    UpstreamOAuthProviders::AuthorizationEndpointOverride,
825                )),
826                ProviderLookupIden::AuthorizationEndpointOverride,
827            )
828            .expr_as(
829                Expr::col((
830                    UpstreamOAuthProviders::Table,
831                    UpstreamOAuthProviders::UserinfoEndpointOverride,
832                )),
833                ProviderLookupIden::UserinfoEndpointOverride,
834            )
835            .expr_as(
836                Expr::col((
837                    UpstreamOAuthProviders::Table,
838                    UpstreamOAuthProviders::DiscoveryMode,
839                )),
840                ProviderLookupIden::DiscoveryMode,
841            )
842            .expr_as(
843                Expr::col((
844                    UpstreamOAuthProviders::Table,
845                    UpstreamOAuthProviders::PkceMode,
846                )),
847                ProviderLookupIden::PkceMode,
848            )
849            .expr_as(
850                Expr::col((
851                    UpstreamOAuthProviders::Table,
852                    UpstreamOAuthProviders::ResponseMode,
853                )),
854                ProviderLookupIden::ResponseMode,
855            )
856            .expr_as(
857                Expr::col((
858                    UpstreamOAuthProviders::Table,
859                    UpstreamOAuthProviders::AdditionalParameters,
860                )),
861                ProviderLookupIden::AdditionalParameters,
862            )
863            .expr_as(
864                Expr::col((
865                    UpstreamOAuthProviders::Table,
866                    UpstreamOAuthProviders::ForwardLoginHint,
867                )),
868                ProviderLookupIden::ForwardLoginHint,
869            )
870            .expr_as(
871                Expr::col((
872                    UpstreamOAuthProviders::Table,
873                    UpstreamOAuthProviders::OnBackchannelLogout,
874                )),
875                ProviderLookupIden::OnBackchannelLogout,
876            )
877            .from(UpstreamOAuthProviders::Table)
878            .apply_filter(filter)
879            .generate_pagination(
880                (
881                    UpstreamOAuthProviders::Table,
882                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
883                ),
884                pagination,
885            )
886            .build_sqlx(PostgresQueryBuilder);
887
888        let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
889            .traced()
890            .fetch_all(&mut *self.conn)
891            .await?;
892
893        let page = pagination
894            .process(edges)
895            .try_map(UpstreamOAuthProvider::try_from)?;
896
897        return Ok(page);
898    }
899
900    #[tracing::instrument(
901        name = "db.upstream_oauth_provider.count",
902        skip_all,
903        fields(
904            db.query.text,
905        ),
906        err,
907    )]
908    async fn count(
909        &mut self,
910        filter: UpstreamOAuthProviderFilter<'_>,
911    ) -> Result<usize, Self::Error> {
912        let (sql, arguments) = Query::select()
913            .expr(
914                Expr::col((
915                    UpstreamOAuthProviders::Table,
916                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
917                ))
918                .count(),
919            )
920            .from(UpstreamOAuthProviders::Table)
921            .apply_filter(filter)
922            .build_sqlx(PostgresQueryBuilder);
923
924        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
925            .traced()
926            .fetch_one(&mut *self.conn)
927            .await?;
928
929        count
930            .try_into()
931            .map_err(DatabaseError::to_invalid_operation)
932    }
933
934    #[tracing::instrument(
935        name = "db.upstream_oauth_provider.all_enabled",
936        skip_all,
937        fields(
938            db.query.text,
939        ),
940        err,
941    )]
942    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
943        let res = sqlx::query_as!(
944            ProviderLookup,
945            r#"
946                SELECT
947                    upstream_oauth_provider_id,
948                    issuer,
949                    human_name,
950                    brand_name,
951                    scope,
952                    client_id,
953                    encrypted_client_secret,
954                    token_endpoint_signing_alg,
955                    token_endpoint_auth_method,
956                    id_token_signed_response_alg,
957                    fetch_userinfo,
958                    userinfo_signed_response_alg,
959                    created_at,
960                    disabled_at,
961                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
962                    jwks_uri_override,
963                    authorization_endpoint_override,
964                    token_endpoint_override,
965                    userinfo_endpoint_override,
966                    discovery_mode,
967                    pkce_mode,
968                    response_mode,
969                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
970                    forward_login_hint,
971                    on_backchannel_logout
972                FROM upstream_oauth_providers
973                WHERE disabled_at IS NULL
974                ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
975            "#,
976        )
977        .traced()
978        .fetch_all(&mut *self.conn)
979        .await?;
980
981        let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
982        Ok(res?)
983    }
984}