ReferenceOAuth2AccessTokenGenerator.kt

package io.github.lishangbu.avalon.oauth2.authorizationserver.token

import io.github.lishangbu.avalon.oauth2.authorizationserver.keygen.UuidKeyGenerator
import org.springframework.security.core.Authentication
import org.springframework.security.crypto.keygen.StringKeyGenerator
import org.springframework.security.oauth2.core.ClaimAccessor
import org.springframework.security.oauth2.core.OAuth2AccessToken
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient
import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenClaimsContext
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenClaimsSet
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator
import java.time.Instant
import java.util.*

/**
 * 引用型访问令牌生成器
 *
 * 生成携带声明集的 reference access token
 */
/**
 * Reference (opaque) access token generator.
 *
 * Issues an [OAuth2AccessToken] whose value is an opaque key from [UuidKeyGenerator] while still
 * carrying the full claim set, exposed through [ClaimAccessor].
 *
 * It only handles contexts that request an [OAuth2TokenType.ACCESS_TOKEN] for a client configured
 * with [OAuth2TokenFormat.REFERENCE]. For any other context [generate] returns `null`.
 *
 * @param accessTokenCustomizer optional hook that may add, change or remove claims before the claim
 *   set is built. It mirrors the customizer supported by Spring's own `OAuth2AccessTokenGenerator`.
 *   When `null` (the default), claims are produced exactly as without a customizer.
 */
class ReferenceOAuth2AccessTokenGenerator(
    private val accessTokenCustomizer: OAuth2TokenCustomizer<OAuth2TokenClaimsContext>? = null,
) : OAuth2TokenGenerator<OAuth2AccessToken> {
    /** Source of the opaque token value handed to the client. */
    private val accessTokenGenerator: StringKeyGenerator = UuidKeyGenerator()

    /**
     * Generates a reference access token for [context].
     *
     * @param context the token context; its registered client and principal are read to build the claims
     * @return the generated token, or `null` if [context] does not request an access token or the
     *   client's access-token format is not [OAuth2TokenFormat.REFERENCE]
     */
    override fun generate(context: OAuth2TokenContext): OAuth2AccessToken? {
        // Check the token type first so the registered client is only touched for supported requests.
        if (
            OAuth2TokenType.ACCESS_TOKEN != context.tokenType ||
            OAuth2TokenFormat.REFERENCE !=
            context.registeredClient.tokenSettings.accessTokenFormat
        ) {
            return null
        }

        val issuer = context.authorizationServerContext?.issuer
        val registeredClient: RegisteredClient = context.registeredClient
        val authorizedScopes: Set<String> = context.authorizedScopes ?: emptySet()

        val issuedAt = Instant.now()
        val expiresAt = issuedAt.plus(registeredClient.tokenSettings.accessTokenTimeToLive)

        val claimsBuilder = OAuth2TokenClaimsSet.builder()
        // `iss` is optional: emit it only when the server context provides a non-blank issuer.
        issuer?.takeIf(String::isNotBlank)?.let(claimsBuilder::issuer)
        claimsBuilder
            .subject(context.getPrincipal<Authentication>().name)
            .audience(listOf(registeredClient.clientId))
            .issuedAt(issuedAt)
            .expiresAt(expiresAt)
            .notBefore(issuedAt)
            .id(UUID.randomUUID().toString())
        if (authorizedScopes.isNotEmpty()) {
            claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes)
        }

        // Let the optional customizer adjust the claims before they are frozen into a claim set.
        accessTokenCustomizer?.customize(
            buildClaimsContext(context, registeredClient, claimsBuilder, authorizedScopes),
        )
        val accessTokenClaimsSet = claimsBuilder.build()

        return OAuth2AccessTokenClaims(
            OAuth2AccessToken.TokenType.BEARER,
            accessTokenGenerator.generateKey(),
            accessTokenClaimsSet.issuedAt,
            accessTokenClaimsSet.expiresAt,
            LinkedHashSet(authorizedScopes),
            LinkedHashMap(accessTokenClaimsSet.claims),
        )
    }

    /**
     * Builds the [OAuth2TokenClaimsContext] passed to [accessTokenCustomizer].
     *
     * It wraps [claimsBuilder] so customizer changes land in the claims being generated.
     */
    private fun buildClaimsContext(
        context: OAuth2TokenContext,
        registeredClient: RegisteredClient,
        claimsBuilder: OAuth2TokenClaimsSet.Builder,
        authorizedScopes: Set<String>,
    ): OAuth2TokenClaimsContext {
        val builder = OAuth2TokenClaimsContext.with(claimsBuilder)
            .registeredClient(registeredClient)
            .principal(context.getPrincipal<Authentication>())
            .authorizedScopes(authorizedScopes)
            .tokenType(context.tokenType)
        // Optional entries are copied only when present, because the context builder does not accept nulls.
        context.authorizationServerContext?.let { builder.authorizationServerContext(it) }
        context.authorizationGrantType?.let { builder.authorizationGrantType(it) }
        context.authorization?.let { builder.authorization(it) }
        context.getAuthorizationGrant<Authentication>()?.let { builder.authorizationGrant(it) }
        return builder.build()
    }

    /**
     * Access token that also exposes the claims it was generated with through [ClaimAccessor].
     *
     * @param claims the token's claims; returned as-is by [getClaims]
     */
    private class OAuth2AccessTokenClaims(
        tokenType: OAuth2AccessToken.TokenType,
        tokenValue: String,
        issuedAt: Instant?,
        expiresAt: Instant?,
        scopes: Set<String>,
        private val claims: Map<String, Any>,
    ) : OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, scopes),
        ClaimAccessor {
        /** Returns the claims this token was generated with. */
        override fun getClaims(): Map<String, Any> = claims
    }
}