// DefaultOAuth2AuthorizationService.kt

package io.github.lishangbu.avalon.auth.service.impl

import io.github.lishangbu.avalon.auth.entity.OauthAuthorization
import io.github.lishangbu.avalon.auth.repository.Oauth2AuthorizationRepository
import org.springframework.dao.DataRetrievalFailureException
import org.springframework.security.jackson.SecurityJacksonModules
import org.springframework.security.oauth2.core.*
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames
import org.springframework.security.oauth2.core.oidc.OidcIdToken
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository
import org.springframework.stereotype.Service
import org.springframework.transaction.annotation.Transactional
import tools.jackson.core.type.TypeReference
import tools.jackson.databind.json.JsonMapper
import java.time.Instant

/**
 * Repository-backed [OAuth2AuthorizationService] implementation.
 *
 * Converts between the framework's [OAuth2Authorization] and the [OauthAuthorization]
 * entity, and delegates saving, lookup and deletion to [Oauth2AuthorizationRepository].
 *
 * @author lishangbu
 * @since 2025/11/30
 */
@Service
class DefaultOAuth2AuthorizationService(
    /** Repository used to store, look up and delete authorization entities. */
    private val oauth2AuthorizationRepository: Oauth2AuthorizationRepository,
    /** Repository used to resolve the registered client referenced by a stored authorization. */
    private val registeredClientRepository: RegisteredClientRepository,
) : OAuth2AuthorizationService {
    /**
     * Converts [authorization] to its entity form and saves it through the repository.
     *
     * @param authorization the authorization to persist
     */
    @Transactional(rollbackFor = [Exception::class])
    override fun save(authorization: OAuth2Authorization) {
        val entity = toEntity(authorization)
        oauth2AuthorizationRepository.save(entity)
    }

    /**
     * Deletes the stored authorization that has the same id as [authorization].
     *
     * @param authorization the authorization to remove
     */
    @Transactional(rollbackFor = [Exception::class])
    override fun remove(authorization: OAuth2Authorization) {
        oauth2AuthorizationRepository.deleteById(authorization.id)
    }

    /**
     * Looks up an authorization by its identifier.
     *
     * @param id the authorization id; must not be blank
     * @return the restored authorization, or `null` when the repository has no match
     * @throws IllegalArgumentException if [id] is blank
     */
    override fun findById(id: String): OAuth2Authorization? {
        require(id.isNotBlank()) { "id cannot be empty" }
        return oauth2AuthorizationRepository.findNullable(id)?.let(::toObject)
    }

    /**
     * Looks up an authorization by one of its token values.
     *
     * When [tokenType] is `null` the lookup is delegated to the repository's
     * `loadByTokenValue`; otherwise the repository finder matching the token type is used.
     * An unrecognised token type yields `null`.
     *
     * @param token the token value to search for; must not be blank
     * @param tokenType the kind of token [token] represents, or `null` if unknown
     * @return the restored authorization, or `null` when nothing matches
     * @throws IllegalArgumentException if [token] is blank
     */
    override fun findByToken(
        token: String,
        tokenType: OAuth2TokenType?,
    ): OAuth2Authorization? {
        require(token.isNotBlank()) { "token cannot be empty" }

        val entity =
            when (tokenType?.value) {
                null -> oauth2AuthorizationRepository.loadByTokenValue(token)
                OAuth2ParameterNames.STATE -> oauth2AuthorizationRepository.findByState(token)
                OAuth2ParameterNames.CODE -> oauth2AuthorizationRepository.findByAuthorizationCodeValue(token)
                OAuth2ParameterNames.ACCESS_TOKEN -> oauth2AuthorizationRepository.findByAccessTokenValue(token)
                OAuth2ParameterNames.REFRESH_TOKEN -> oauth2AuthorizationRepository.findByRefreshTokenValue(token)
                OidcParameterNames.ID_TOKEN -> oauth2AuthorizationRepository.findByOidcIdTokenValue(token)
                OAuth2ParameterNames.USER_CODE -> oauth2AuthorizationRepository.findByUserCodeValue(token)
                OAuth2ParameterNames.DEVICE_CODE -> oauth2AuthorizationRepository.findByDeviceCodeValue(token)
                else -> null
            }

        return entity?.let(::toObject)
    }

    /**
     * Rebuilds an [OAuth2Authorization] from its stored entity form.
     *
     * Each token is restored only when its value column is non-null. The access token type
     * is taken from the stored `accessTokenType` (see [resolveAccessTokenType]) rather than
     * being assumed to be Bearer.
     *
     * @param entity the stored authorization
     * @return the restored authorization
     * @throws IllegalArgumentException if the entity lacks its id, registered client id,
     * principal name or authorization grant type
     * @throws DataRetrievalFailureException if the referenced registered client cannot be found
     */
    private fun toObject(entity: OauthAuthorization): OAuth2Authorization {
        val registeredClientId =
            requireNotNull(entity.registeredClientId) { "registeredClientId cannot be null" }
        val authorizationId = requireNotNull(entity.id) { "id cannot be null" }
        val principalName = requireNotNull(entity.principalName) { "principalName cannot be null" }
        val authorizationGrantType =
            requireNotNull(entity.authorizationGrantType) {
                "authorizationGrantType cannot be null"
            }

        val registeredClient: RegisteredClient =
            registeredClientRepository.findById(registeredClientId)
                ?: throw DataRetrievalFailureException(
                    "The RegisteredClient with id '$registeredClientId' was not found in the Oauth2RegisteredClientRepository.",
                )

        val builder =
            OAuth2Authorization
                .withRegisteredClient(registeredClient)
                .id(authorizationId)
                .principalName(principalName)
                .authorizationGrantType(resolveAuthorizationGrantType(authorizationGrantType))
                .authorizedScopes(readCommaDelimitedSet(entity.authorizedScopes))

        readAttributes(entity.attributes)?.let { storedAttributes ->
            builder.attributes { attributesHolder -> attributesHolder.putAll(storedAttributes) }
        }
        // The dedicated state column is applied after the attribute map, so it wins on conflict.
        entity.state?.let { state -> builder.attribute(OAuth2ParameterNames.STATE, state) }

        entity.authorizationCodeValue?.let { tokenValue ->
            restoreToken(
                builder,
                OAuth2AuthorizationCode(
                    tokenValue,
                    entity.authorizationCodeIssuedAt,
                    entity.authorizationCodeExpiresAt,
                ),
                readAttributes(entity.authorizationCodeMetadata),
            )
        }

        entity.accessTokenValue?.let { tokenValue ->
            restoreToken(
                builder,
                OAuth2AccessToken(
                    resolveAccessTokenType(entity.accessTokenType),
                    tokenValue,
                    entity.accessTokenIssuedAt,
                    entity.accessTokenExpiresAt,
                    readCommaDelimitedSet(entity.accessTokenScopes),
                ),
                readAttributes(entity.accessTokenMetadata),
            )
        }

        entity.refreshTokenValue?.let { tokenValue ->
            restoreToken(
                builder,
                OAuth2RefreshToken(
                    tokenValue,
                    entity.refreshTokenIssuedAt,
                    entity.refreshTokenExpiresAt,
                ),
                readAttributes(entity.refreshTokenMetadata),
            )
        }

        entity.oidcIdTokenValue?.let { tokenValue ->
            // The ID token claims are derived from the stored metadata, so parse it first.
            val oidcIdTokenMetadata = readAttributes(entity.oidcIdTokenMetadata)
            restoreToken(
                builder,
                OidcIdToken(
                    tokenValue,
                    entity.oidcIdTokenIssuedAt,
                    entity.oidcIdTokenExpiresAt,
                    extractOidcIdTokenClaims(oidcIdTokenMetadata),
                ),
                oidcIdTokenMetadata,
            )
        }

        entity.userCodeValue?.let { tokenValue ->
            restoreToken(
                builder,
                OAuth2UserCode(
                    tokenValue,
                    entity.userCodeIssuedAt,
                    entity.userCodeExpiresAt,
                ),
                readAttributes(entity.userCodeMetadata),
            )
        }

        entity.deviceCodeValue?.let { tokenValue ->
            restoreToken(
                builder,
                OAuth2DeviceCode(
                    tokenValue,
                    entity.deviceCodeIssuedAt,
                    entity.deviceCodeExpiresAt,
                ),
                readAttributes(entity.deviceCodeMetadata),
            )
        }

        return builder.build()
    }

    /**
     * Registers [restored] on [builder] and copies [metadata], when present, into the
     * token's metadata map.
     */
    private fun restoreToken(
        builder: OAuth2Authorization.Builder,
        restored: OAuth2Token,
        metadata: Map<String, Any>?,
    ) {
        builder.token(restored) { metadataHolder ->
            if (metadata != null) {
                metadataHolder.putAll(metadata)
            }
        }
    }

    /**
     * Flattens an [OAuth2Authorization] into the entity that is persisted.
     *
     * Scope sets are stored as comma-joined strings, attribute and metadata maps as JSON.
     * Fields belonging to a token the authorization does not carry are set to `null`.
     *
     * @param authorization the authorization to convert
     * @return the entity representation of [authorization]
     */
    private fun toEntity(authorization: OAuth2Authorization): OauthAuthorization {
        val authorizationCodeSnapshot =
            toTokenSnapshot(authorization.getToken(OAuth2AuthorizationCode::class.java))
        val accessToken =
            authorization.getToken(OAuth2AccessToken::class.java)
        val accessTokenSnapshot = toTokenSnapshot(accessToken)
        val refreshTokenSnapshot =
            toTokenSnapshot(authorization.getToken(OAuth2RefreshToken::class.java))
        val oidcIdTokenSnapshot =
            toTokenSnapshot(authorization.getToken(OidcIdToken::class.java))
        val userCodeSnapshot =
            toTokenSnapshot(authorization.getToken(OAuth2UserCode::class.java))
        val deviceCodeSnapshot =
            toTokenSnapshot(authorization.getToken(OAuth2DeviceCode::class.java))

        return OauthAuthorization {
            id = authorization.id
            registeredClientId = authorization.registeredClientId
            principalName = authorization.principalName
            authorizationGrantType = authorization.authorizationGrantType.value
            authorizedScopes = authorization.authorizedScopes.joinToString(",")
            attributes = writeAttributes(authorization.attributes)
            state = authorization.getAttribute(OAuth2ParameterNames.STATE)

            authorizationCodeValue = authorizationCodeSnapshot.value
            authorizationCodeIssuedAt = authorizationCodeSnapshot.issuedAt
            authorizationCodeExpiresAt = authorizationCodeSnapshot.expiresAt
            authorizationCodeMetadata = writeAttributes(authorizationCodeSnapshot.metadata)

            accessTokenValue = accessTokenSnapshot.value
            accessTokenIssuedAt = accessTokenSnapshot.issuedAt
            accessTokenExpiresAt = accessTokenSnapshot.expiresAt
            accessTokenMetadata = writeAttributes(accessTokenSnapshot.metadata)
            accessTokenScopes = accessToken?.token?.scopes?.joinToString(",")
            accessTokenType = accessToken?.token?.tokenType?.value

            refreshTokenValue = refreshTokenSnapshot.value
            refreshTokenIssuedAt = refreshTokenSnapshot.issuedAt
            refreshTokenExpiresAt = refreshTokenSnapshot.expiresAt
            refreshTokenMetadata = writeAttributes(refreshTokenSnapshot.metadata)

            oidcIdTokenValue = oidcIdTokenSnapshot.value
            oidcIdTokenIssuedAt = oidcIdTokenSnapshot.issuedAt
            oidcIdTokenExpiresAt = oidcIdTokenSnapshot.expiresAt
            oidcIdTokenMetadata = writeAttributes(oidcIdTokenSnapshot.metadata)

            userCodeValue = userCodeSnapshot.value
            userCodeIssuedAt = userCodeSnapshot.issuedAt
            userCodeExpiresAt = userCodeSnapshot.expiresAt
            userCodeMetadata = writeAttributes(userCodeSnapshot.metadata)

            deviceCodeValue = deviceCodeSnapshot.value
            deviceCodeIssuedAt = deviceCodeSnapshot.issuedAt
            deviceCodeExpiresAt = deviceCodeSnapshot.expiresAt
            deviceCodeMetadata = writeAttributes(deviceCodeSnapshot.metadata)
        }
    }

    /**
     * Captures the persistable fields of [token].
     *
     * @param token the token holder, or `null` when the authorization has no such token
     * @return a snapshot of the token, or an all-`null` snapshot when [token] is `null`
     */
    private fun toTokenSnapshot(token: OAuth2Authorization.Token<*>?): TokenSnapshot {
        if (token == null) {
            return TokenSnapshot()
        }
        val oAuth2Token: OAuth2Token = token.token
        return TokenSnapshot(
            value = oAuth2Token.tokenValue,
            issuedAt = oAuth2Token.issuedAt,
            expiresAt = oAuth2Token.expiresAt,
            metadata = token.metadata,
        )
    }

    /**
     * Parses a JSON object into an attribute map.
     *
     * @param json the stored JSON text
     * @return the parsed map, or `null` when [json] is `null` or blank
     */
    private fun readAttributes(json: String?): Map<String, Any>? {
        if (json.isNullOrBlank()) {
            return null
        }
        return mapper.readValue(json, object : TypeReference<Map<String, Any>>() {})
    }

    /**
     * Serializes an attribute map to JSON.
     *
     * @param attributes the map to serialize
     * @return the JSON text, or `null` when [attributes] is `null` or empty
     */
    private fun writeAttributes(attributes: Map<String, Any>?): String? {
        if (attributes.isNullOrEmpty()) {
            return null
        }
        return mapper.writeValueAsString(attributes)
    }

    /**
     * Splits a comma-delimited string into an insertion-ordered set.
     *
     * Empty segments are dropped. [toEntity] stores an empty scope set as `""`, and
     * splitting that naively would produce a set holding a single empty string.
     *
     * @param value the comma-delimited text, possibly `null`
     * @return the distinct non-empty segments in order of appearance; empty when [value] is
     * `null` or contains no non-empty segment
     */
    private fun readCommaDelimitedSet(value: String?): LinkedHashSet<String> =
        value
            ?.splitToSequence(',')
            ?.filter { segment -> segment.isNotEmpty() }
            ?.toCollection(linkedSetOf())
            ?: linkedSetOf()

    /**
     * Derives the OIDC ID token claims from the token's stored metadata.
     *
     * If the metadata holds a map under [OAuth2Authorization.Token.CLAIMS_METADATA_NAME],
     * its entries with a `String` key and a non-null value are returned. Null-valued
     * entries are skipped because the result type does not admit null values. Otherwise the
     * metadata map itself is returned.
     *
     * @param metadata the stored ID token metadata
     * @return the claims map; empty when [metadata] is `null` or empty
     */
    private fun extractOidcIdTokenClaims(metadata: Map<String, Any>?): Map<String, Any> {
        if (metadata.isNullOrEmpty()) {
            return emptyMap()
        }
        val claims = metadata[OAuth2Authorization.Token.CLAIMS_METADATA_NAME]
        if (claims is Map<*, *>) {
            return buildMap {
                for ((claimName, claimValue) in claims) {
                    if (claimName is String && claimValue != null) {
                        put(claimName, claimValue)
                    }
                }
            }
        }
        return metadata
    }

    /** Persistable fields of a single token; every field is `null` when the token is absent. */
    private data class TokenSnapshot(
        /** Token value. */
        val value: String? = null,
        /** Time the token was issued. */
        val issuedAt: Instant? = null,
        /** Time the token expires. */
        val expiresAt: Instant? = null,
        /** Metadata attached to the token. */
        val metadata: Map<String, Any>? = null,
    )

    companion object {
        /** JSON mapper configured with the Spring Security Jackson modules. */
        private val mapper: JsonMapper =
            JsonMapper
                .builder()
                .addModules(SecurityJacksonModules.getModules(DefaultOAuth2AuthorizationService::class.java.classLoader))
                .build()

        /**
         * Maps a stored grant type value to an [AuthorizationGrantType].
         *
         * The predefined constants are reused for the well-known values; any other value is
         * wrapped in a new [AuthorizationGrantType].
         */
        private fun resolveAuthorizationGrantType(
            authorizationGrantType: String,
        ): AuthorizationGrantType =
            when (authorizationGrantType) {
                AuthorizationGrantType.AUTHORIZATION_CODE.value -> AuthorizationGrantType.AUTHORIZATION_CODE
                AuthorizationGrantType.CLIENT_CREDENTIALS.value -> AuthorizationGrantType.CLIENT_CREDENTIALS
                AuthorizationGrantType.REFRESH_TOKEN.value -> AuthorizationGrantType.REFRESH_TOKEN
                AuthorizationGrantType.DEVICE_CODE.value -> AuthorizationGrantType.DEVICE_CODE
                else -> AuthorizationGrantType(authorizationGrantType)
            }

        /**
         * Maps a stored access token type value to an [OAuth2AccessToken.TokenType].
         *
         * DPoP is matched case-insensitively. Any other value, including `null`, falls back
         * to Bearer, so rows without a stored type are still restored as Bearer tokens.
         */
        private fun resolveAccessTokenType(accessTokenType: String?): OAuth2AccessToken.TokenType =
            if (OAuth2AccessToken.TokenType.DPOP.value.equals(accessTokenType, ignoreCase = true)) {
                OAuth2AccessToken.TokenType.DPOP
            } else {
                OAuth2AccessToken.TokenType.BEARER
            }
    }
}