AuthorizationEndpointResponseHandler.kt

package io.github.lishangbu.avalon.oauth2.authorizationserver.web.authentication

import jakarta.servlet.ServletException
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.springframework.http.converter.HttpMessageConverter
import org.springframework.http.server.ServletServerHttpResponse
import org.springframework.security.core.Authentication
import org.springframework.security.oauth2.core.*
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationContext
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken
import org.springframework.security.web.authentication.AuthenticationSuccessHandler
import java.io.IOException
import java.time.temporal.ChronoUnit
import java.util.function.Consumer

/**
 * Authentication success handler for the authorization endpoint.
 *
 * Expects the successful [Authentication] to be an [OAuth2AccessTokenAuthenticationToken] and
 * turns its tokens into a standard [OAuth2AccessTokenResponse]. That response is written directly
 * to the servlet response through an [OAuth2AccessTokenResponseHttpMessageConverter].
 *
 * An optional customizer (see [setAccessTokenResponseCustomizer]) can adjust the response builder
 * before the response is built.
 */
class AuthorizationEndpointResponseHandler : AuthenticationSuccessHandler {
    /** Logger used for the unsupported-authentication error and the customization trace. */
    private val logger: Log = LogFactory.getLog(javaClass)

    /** Writes the finished [OAuth2AccessTokenResponse] onto the HTTP response. */
    private val accessTokenResponseConverter: HttpMessageConverter<OAuth2AccessTokenResponse> =
        OAuth2AccessTokenResponseHttpMessageConverter()

    /** Optional hook that may adjust the response builder; `null` means no customization. */
    private var responseCustomizer: Consumer<OAuth2AccessTokenAuthenticationContext>? = null

    /**
     * Writes the access token response for a successful authentication.
     *
     * @param request the current request (not read by this handler)
     * @param response the servlet response the token response is written to
     * @param authentication expected to be an [OAuth2AccessTokenAuthenticationToken]
     * @throws OAuth2AuthenticationException carrying [OAuth2ErrorCodes.SERVER_ERROR] when
     *   [authentication] is of any other type
     */
    @Throws(IOException::class, ServletException::class)
    override fun onAuthenticationSuccess(
        request: HttpServletRequest,
        response: HttpServletResponse,
        authentication: Authentication,
    ) {
        val tokenAuthentication =
            authentication as? OAuth2AccessTokenAuthenticationToken
                ?: rejectUnsupportedAuthentication(authentication)

        val responseBuilder = newResponseBuilder(tokenAuthentication)

        val customizer = responseCustomizer
        if (customizer != null) {
            // The customizer receives the same builder that is built below,
            // so any changes it makes show up in the final response.
            customizer.accept(
                OAuth2AccessTokenAuthenticationContext
                    .with(tokenAuthentication)
                    .accessTokenResponse(responseBuilder)
                    .build(),
            )
            if (logger.isTraceEnabled) {
                logger.trace("Customized access token response")
            }
        }

        // Passing a null content type leaves the media type choice to the converter.
        accessTokenResponseConverter.write(responseBuilder.build(), null, ServletServerHttpResponse(response))
    }

    /**
     * Creates a response builder pre-populated from [tokenAuthentication].
     *
     * The token value, type and scopes are always set. The expiry (in seconds) is set only when
     * both the issue and expiry instants are known. The refresh token and the additional
     * parameters are set only when present.
     */
    private fun newResponseBuilder(
        tokenAuthentication: OAuth2AccessTokenAuthenticationToken,
    ): OAuth2AccessTokenResponse.Builder {
        val token = tokenAuthentication.accessToken
        val builder =
            OAuth2AccessTokenResponse
                .withToken(token.tokenValue)
                .tokenType(token.tokenType)
                .scopes(token.scopes)

        val issuedAt = token.issuedAt
        val expiresAt = token.expiresAt
        if (issuedAt != null && expiresAt != null) {
            builder.expiresIn(ChronoUnit.SECONDS.between(issuedAt, expiresAt))
        }

        tokenAuthentication.refreshToken?.let { refresh -> builder.refreshToken(refresh.tokenValue) }

        val extraParameters = tokenAuthentication.additionalParameters
        if (!extraParameters.isNullOrEmpty()) {
            builder.additionalParameters(extraParameters)
        }
        return builder
    }

    /**
     * Logs (at error level, when enabled) and rejects an authentication that is not an
     * [OAuth2AccessTokenAuthenticationToken].
     *
     * @throws OAuth2AuthenticationException always, carrying an [OAuth2ErrorCodes.SERVER_ERROR] error
     */
    private fun rejectUnsupportedAuthentication(authentication: Authentication): Nothing {
        if (logger.isErrorEnabled) {
            logger.error(
                "${Authentication::class.java.simpleName} must be of type " +
                    "${OAuth2AccessTokenAuthenticationToken::class.java.name} but was " +
                    authentication.javaClass.name,
            )
        }
        throw OAuth2AuthenticationException(
            OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, "Unable to process the access token response.", null),
        )
    }

    /**
     * Registers the hook that is invoked with an [OAuth2AccessTokenAuthenticationContext] before
     * the access token response is built. Replaces any previously registered customizer.
     */
    fun setAccessTokenResponseCustomizer(
        accessTokenResponseCustomizer: Consumer<OAuth2AccessTokenAuthenticationContext>,
    ) {
        responseCustomizer = accessTokenResponseCustomizer
    }
}