// DefaultRegisteredClientRepository.kt

package io.github.lishangbu.avalon.auth.service.impl

import io.github.lishangbu.avalon.auth.entity.OauthRegisteredClient
import io.github.lishangbu.avalon.auth.repository.Oauth2RegisteredClientRepository
import org.springframework.boot.convert.DurationStyle
import org.springframework.security.oauth2.core.AuthorizationGrantType
import org.springframework.security.oauth2.core.ClientAuthenticationMethod
import org.springframework.security.oauth2.jose.jws.MacAlgorithm
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository
import org.springframework.security.oauth2.server.authorization.settings.ClientSettings
import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat
import org.springframework.security.oauth2.server.authorization.settings.TokenSettings
import org.springframework.stereotype.Component
import java.time.temporal.ChronoUnit

/**
 * Adapter that exposes the persisted [OauthRegisteredClient] entities through Spring
 * Authorization Server's [RegisteredClientRepository] contract.
 *
 * Converts between [RegisteredClient] and the persistent entity in both directions.
 * Multi-valued attributes (authentication methods, grant types, redirect URIs, scopes) are
 * stored as comma-delimited strings, and durations are stored in their `toString()` form.
 *
 * @author lishangbu
 * @since 2025/8/17
 */
@Component
class DefaultRegisteredClientRepository(
    /** Persistence repository for OAuth2 registered clients. */
    private val oauth2RegisteredClientRepository: Oauth2RegisteredClientRepository,
) : RegisteredClientRepository {
    /**
     * Persists the given registered client.
     *
     * @param registeredClient the client to save
     */
    override fun save(registeredClient: RegisteredClient) {
        oauth2RegisteredClientRepository.save(toEntity(registeredClient))
    }

    /**
     * Looks up a registered client by its primary identifier.
     *
     * @param id the registration identifier
     * @return the matching client, or `null` when none is found
     * @throws IllegalArgumentException if [id] is blank
     */
    override fun findById(id: String): RegisteredClient? {
        require(id.isNotBlank()) { "id cannot be empty" }
        return oauth2RegisteredClientRepository.findNullable(id)?.let(::toObject)
    }

    /**
     * Looks up a registered client by its OAuth2 client identifier.
     *
     * @param clientId the OAuth2 `client_id`
     * @return the matching client, or `null` when none is found
     * @throws IllegalArgumentException if [clientId] is blank
     */
    override fun findByClientId(clientId: String): RegisteredClient? {
        require(clientId.isNotBlank()) { "clientId cannot be empty" }
        return oauth2RegisteredClientRepository
            .findByClientId(clientId)
            ?.let(::toObject)
    }

    /**
     * Converts a persisted entity into a [RegisteredClient].
     *
     * @param client the persisted entity
     * @return the equivalent registered client
     */
    private fun toObject(client: OauthRegisteredClient): RegisteredClient {
        val authenticationMethodValues = readCommaDelimitedSet(client.clientAuthenticationMethods)
        val grantTypeValues = readCommaDelimitedSet(client.authorizationGrantTypes)
        val redirectUriValues = readCommaDelimitedSet(client.redirectUris)
        val postLogoutRedirectUriValues = readCommaDelimitedSet(client.postLogoutRedirectUris)
        val scopeValues = readCommaDelimitedSet(client.scopes)

        return RegisteredClient
            .withId(client.id)
            .clientId(client.clientId)
            .clientIdIssuedAt(client.clientIdIssuedAt)
            .clientSecret(client.clientSecret)
            .clientSecretExpiresAt(client.clientSecretExpiresAt)
            .clientName(client.clientName)
            .clientAuthenticationMethods { methods ->
                authenticationMethodValues.forEach {
                    methods.add(resolveClientAuthenticationMethod(it))
                }
            }.authorizationGrantTypes { grantTypes ->
                grantTypeValues.forEach {
                    grantTypes.add(resolveAuthorizationGrantType(it))
                }
            }.redirectUris { uris -> uris.addAll(redirectUriValues) }
            .postLogoutRedirectUris { uris -> uris.addAll(postLogoutRedirectUriValues) }
            .scopes { scopes -> scopes.addAll(scopeValues) }
            .clientSettings(buildClientSettings(client))
            .tokenSettings(buildTokenSettings(client))
            .build()
    }

    /**
     * Builds the [ClientSettings] for an entity.
     *
     * Only columns that are non-null override the builder's defaults.
     *
     * @param client the persisted entity
     * @return the client settings
     */
    private fun buildClientSettings(client: OauthRegisteredClient): ClientSettings {
        val settings = ClientSettings.builder()
        client.requireProofKey?.let { settings.requireProofKey(it) }
        client.requireAuthorizationConsent?.let { settings.requireAuthorizationConsent(it) }
        client.jwkSetUrl?.let { settings.jwkSetUrl(it) }

        val signingAlgorithmName = client.tokenEndpointAuthenticationSigningAlgorithm
        if (signingAlgorithmName != null) {
            // The stored name may denote either an asymmetric or a MAC algorithm; try the
            // former first. A name that matches neither is ignored.
            val signatureAlgorithm = SignatureAlgorithm.from(signingAlgorithmName)
            if (signatureAlgorithm != null) {
                settings.tokenEndpointAuthenticationSigningAlgorithm(signatureAlgorithm)
            } else {
                val macAlgorithm = MacAlgorithm.from(signingAlgorithmName)
                if (macAlgorithm != null) {
                    settings.tokenEndpointAuthenticationSigningAlgorithm(macAlgorithm)
                }
            }
        }

        client.x509CertificateSubjectDn?.let { settings.x509CertificateSubjectDN(it) }
        return settings.build()
    }

    /**
     * Builds the [TokenSettings] for an entity.
     *
     * Only columns that are non-null override the builder's defaults.
     *
     * @param client the persisted entity
     * @return the token settings
     */
    private fun buildTokenSettings(client: OauthRegisteredClient): TokenSettings {
        val settings = TokenSettings.builder()
        client.reuseRefreshTokens?.let { settings.reuseRefreshTokens(it) }
        client.x509CertificateBoundAccessTokens?.let {
            settings.x509CertificateBoundAccessTokens(it)
        }
        client.authorizationCodeTimeToLive?.let {
            settings.authorizationCodeTimeToLive(parseTimeToLive(it))
        }
        client.accessTokenTimeToLive?.let {
            settings.accessTokenTimeToLive(parseTimeToLive(it))
        }
        client.accessTokenFormat?.let { settings.accessTokenFormat(OAuth2TokenFormat(it)) }
        client.deviceCodeTimeToLive?.let {
            settings.deviceCodeTimeToLive(parseTimeToLive(it))
        }
        client.refreshTokenTimeToLive?.let {
            settings.refreshTokenTimeToLive(parseTimeToLive(it))
        }
        client.idTokenSignatureAlgorithm?.let { algorithmName ->
            // An algorithm name that cannot be resolved is ignored.
            SignatureAlgorithm.from(algorithmName)?.let { algorithm ->
                settings.idTokenSignatureAlgorithm(algorithm)
            }
        }
        return settings.build()
    }

    /**
     * Parses a stored time-to-live value, using seconds as the unit passed to [DurationStyle].
     *
     * @param value the stored duration text
     * @return the parsed duration
     */
    private fun parseTimeToLive(value: String) = DurationStyle.detectAndParse(value, ChronoUnit.SECONDS)

    /**
     * Converts a [RegisteredClient] into its persistent entity.
     *
     * @param registeredClient the registered client
     * @return the entity to persist
     */
    private fun toEntity(registeredClient: RegisteredClient): OauthRegisteredClient {
        val clientAuthenticationMethodValues =
            registeredClient.clientAuthenticationMethods.map { it.value }
        val authorizationGrantTypeValues = registeredClient.authorizationGrantTypes.map { it.value }
        val registeredClientSettings = registeredClient.clientSettings
        val registeredClientTokenSettings = registeredClient.tokenSettings

        return OauthRegisteredClient {
            id = registeredClient.id
            clientId = registeredClient.clientId
            clientIdIssuedAt = registeredClient.clientIdIssuedAt
            clientSecret = registeredClient.clientSecret
            clientSecretExpiresAt = registeredClient.clientSecretExpiresAt
            clientName = registeredClient.clientName

            // Multi-valued attributes are flattened to comma-delimited strings.
            this.clientAuthenticationMethods = clientAuthenticationMethodValues.joinToString(",")
            this.authorizationGrantTypes = authorizationGrantTypeValues.joinToString(",")
            redirectUris = registeredClient.redirectUris.joinToString(",")
            postLogoutRedirectUris = registeredClient.postLogoutRedirectUris.joinToString(",")
            scopes = registeredClient.scopes.joinToString(",")

            // Client settings.
            requireProofKey = registeredClientSettings?.isRequireProofKey
            requireAuthorizationConsent =
                registeredClientSettings?.isRequireAuthorizationConsent
            tokenEndpointAuthenticationSigningAlgorithm =
                registeredClientSettings?.tokenEndpointAuthenticationSigningAlgorithm?.name
            jwkSetUrl = registeredClientSettings?.jwkSetUrl
            x509CertificateSubjectDn = registeredClientSettings?.x509CertificateSubjectDN

            // Token settings; durations are stored using their toString() form.
            reuseRefreshTokens = registeredClientTokenSettings?.isReuseRefreshTokens
            x509CertificateBoundAccessTokens =
                registeredClientTokenSettings?.isX509CertificateBoundAccessTokens
            authorizationCodeTimeToLive =
                registeredClientTokenSettings?.authorizationCodeTimeToLive?.toString()
            accessTokenTimeToLive =
                registeredClientTokenSettings?.accessTokenTimeToLive?.toString()
            accessTokenFormat = registeredClientTokenSettings?.accessTokenFormat?.value
            deviceCodeTimeToLive = registeredClientTokenSettings?.deviceCodeTimeToLive?.toString()
            refreshTokenTimeToLive =
                registeredClientTokenSettings?.refreshTokenTimeToLive?.toString()
            idTokenSignatureAlgorithm =
                registeredClientTokenSettings?.idTokenSignatureAlgorithm?.name
        }
    }

    /**
     * Splits a comma-delimited column value into an insertion-ordered set.
     *
     * Tokens are trimmed and empty tokens are dropped, so a `null` or empty column (which is
     * what [toEntity] writes for an empty collection) yields an empty set rather than a set
     * containing a single empty string.
     *
     * @param value the stored comma-delimited text, possibly `null`
     * @return the distinct non-empty tokens in their original order
     */
    private fun readCommaDelimitedSet(value: String?): LinkedHashSet<String> =
        value
            ?.splitToSequence(',')
            ?.map { it.trim() }
            ?.filter { it.isNotEmpty() }
            ?.toCollection(linkedSetOf<String>())
            ?: linkedSetOf()

    companion object {
        /**
         * Resolves a stored grant type value to an [AuthorizationGrantType].
         *
         * The listed standard values map to their predefined constants; any other value is
         * wrapped in a new instance.
         *
         * @param authorizationGrantType the stored grant type value
         * @return the corresponding grant type
         */
        private fun resolveAuthorizationGrantType(
            authorizationGrantType: String,
        ): AuthorizationGrantType =
            when (authorizationGrantType) {
                AuthorizationGrantType.AUTHORIZATION_CODE.value -> AuthorizationGrantType.AUTHORIZATION_CODE
                AuthorizationGrantType.CLIENT_CREDENTIALS.value -> AuthorizationGrantType.CLIENT_CREDENTIALS
                AuthorizationGrantType.REFRESH_TOKEN.value -> AuthorizationGrantType.REFRESH_TOKEN
                else -> AuthorizationGrantType(authorizationGrantType)
            }

        /**
         * Resolves a stored authentication method value to a [ClientAuthenticationMethod].
         *
         * The listed standard values map to their predefined constants; any other value is
         * wrapped in a new instance.
         *
         * @param clientAuthenticationMethod the stored authentication method value
         * @return the corresponding authentication method
         */
        private fun resolveClientAuthenticationMethod(
            clientAuthenticationMethod: String,
        ): ClientAuthenticationMethod =
            when (clientAuthenticationMethod) {
                ClientAuthenticationMethod.CLIENT_SECRET_BASIC.value -> ClientAuthenticationMethod.CLIENT_SECRET_BASIC
                ClientAuthenticationMethod.CLIENT_SECRET_POST.value -> ClientAuthenticationMethod.CLIENT_SECRET_POST
                ClientAuthenticationMethod.NONE.value -> ClientAuthenticationMethod.NONE
                else -> ClientAuthenticationMethod(clientAuthenticationMethod)
            }
    }
}