use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{CompatAccessToken, CompatSession};
use mas_storage::{compat::CompatAccessTokenRepository, Clock};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError};
/// An implementation of [`CompatAccessTokenRepository`] backed by a PostgreSQL
/// connection.
pub struct PgCompatAccessTokenRepository<'c> {
    // Exclusive borrow of the connection; held for the lifetime `'c` of the
    // repository.
    conn: &'c mut PgConnection,
}
impl<'c> PgCompatAccessTokenRepository<'c> {
    /// Create a new [`PgCompatAccessTokenRepository`] that mutably borrows
    /// `conn` for the lifetime `'c`.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
/// Row type used as the target of the `sqlx::query_as!` `SELECT` queries on
/// the `compat_access_tokens` table.
///
/// The field names are the same as the selected column names, so they should
/// not be renamed independently of the SQL. A row is turned into a
/// [`CompatAccessToken`] through the `From` implementation.
struct CompatAccessTokenLookup {
    compat_access_token_id: Uuid,
    access_token: String,
    created_at: DateTime<Utc>,
    // `None` when no expiration timestamp is stored for the token.
    expires_at: Option<DateTime<Utc>>,
    compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
    fn from(value: CompatAccessTokenLookup) -> Self {
        Self {
            id: value.compat_access_token_id.into(),
            session_id: value.compat_session_id.into(),
            token: value.access_token,
            created_at: value.created_at,
            expires_at: value.expires_at,
        }
    }
}
#[async_trait]
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
    type Error = DatabaseError;
    #[tracing::instrument(
        name = "db.compat_access_token.lookup",
        skip_all,
        fields(
            db.query.text,
            compat_session.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
        let res = sqlx::query_as!(
            CompatAccessTokenLookup,
            r#"
                SELECT compat_access_token_id
                     , access_token
                     , created_at
                     , expires_at
                     , compat_session_id
                FROM compat_access_tokens
                WHERE compat_access_token_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;
        let Some(res) = res else { return Ok(None) };
        Ok(Some(res.into()))
    }
    #[tracing::instrument(
        name = "db.compat_access_token.find_by_token",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn find_by_token(
        &mut self,
        access_token: &str,
    ) -> Result<Option<CompatAccessToken>, Self::Error> {
        let res = sqlx::query_as!(
            CompatAccessTokenLookup,
            r#"
                SELECT compat_access_token_id
                     , access_token
                     , created_at
                     , expires_at
                     , compat_session_id
                FROM compat_access_tokens
                WHERE access_token = $1
            "#,
            access_token,
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;
        let Some(res) = res else { return Ok(None) };
        Ok(Some(res.into()))
    }
    #[tracing::instrument(
        name = "db.compat_access_token.add",
        skip_all,
        fields(
            db.query.text,
            compat_access_token.id,
            %compat_session.id,
            user.id = %compat_session.user_id,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        compat_session: &CompatSession,
        token: String,
        expires_after: Option<Duration>,
    ) -> Result<CompatAccessToken, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
        let expires_at = expires_after.map(|expires_after| created_at + expires_after);
        sqlx::query!(
            r#"
                INSERT INTO compat_access_tokens
                    (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
                VALUES ($1, $2, $3, $4, $5)
            "#,
            Uuid::from(id),
            Uuid::from(compat_session.id),
            token,
            created_at,
            expires_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        Ok(CompatAccessToken {
            id,
            session_id: compat_session.id,
            token,
            created_at,
            expires_at,
        })
    }
    #[tracing::instrument(
        name = "db.compat_access_token.expire",
        skip_all,
        fields(
            db.query.text,
            %compat_access_token.id,
            compat_session.id = %compat_access_token.session_id,
        ),
        err,
    )]
    async fn expire(
        &mut self,
        clock: &dyn Clock,
        mut compat_access_token: CompatAccessToken,
    ) -> Result<CompatAccessToken, Self::Error> {
        let expires_at = clock.now();
        let res = sqlx::query!(
            r#"
                UPDATE compat_access_tokens
                SET expires_at = $2
                WHERE compat_access_token_id = $1
            "#,
            Uuid::from(compat_access_token.id),
            expires_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        DatabaseError::ensure_affected_rows(&res, 1)?;
        compat_access_token.expires_at = Some(expires_at);
        Ok(compat_access_token)
    }
}