// OAuth2AccessTokenApiResultResponseAuthenticationSuccessHandler.kt

package io.github.lishangbu.avalon.oauth2.authorizationserver.web.authentication

import io.github.lishangbu.avalon.oauth2.common.log.AuthenticationLogRecord
import io.github.lishangbu.avalon.oauth2.common.log.AuthenticationLogRecorder
import io.github.lishangbu.avalon.oauth2.common.properties.Oauth2Properties
import io.github.lishangbu.avalon.web.util.JsonResponseWriter
import jakarta.servlet.ServletException
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.springframework.security.core.Authentication
import org.springframework.security.oauth2.core.*
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationContext
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken
import org.springframework.security.web.authentication.AuthenticationSuccessHandler
import tools.jackson.databind.json.JsonMapper
import java.io.IOException
import java.security.Principal
import java.time.Instant
import java.time.temporal.ChronoUnit
import java.util.function.Consumer

/**
 * Success handler for OAuth2 access token responses.
 *
 * After a successful token request it builds the token response body, writes it through
 * [JsonResponseWriter.writeSuccessResponse], and hands an [AuthenticationLogRecord] describing
 * the request to the configured [AuthenticationLogRecorder].
 *
 * @param authenticationLogRecorder recorder for authentication log entries; `null` falls back to
 *   [AuthenticationLogRecorder.noop]
 * @param authorizationService optional service used to look up the stored authorization for the
 *   issued access token when resolving the user name
 * @param oauth2Properties optional properties supplying the name of the user name request parameter
 * @param jsonMapper mapper passed to [JsonResponseWriter] when writing the response
 */
class OAuth2AccessTokenApiResultResponseAuthenticationSuccessHandler
    @JvmOverloads
    constructor(
        authenticationLogRecorder: AuthenticationLogRecorder? = AuthenticationLogRecorder.noop(),
        /** Authorization service used to look up stored authorizations by access token. */
        private val authorizationService: OAuth2AuthorizationService? = null,
        /** OAuth2 properties (source of the user name parameter name). */
        private val oauth2Properties: Oauth2Properties? = null,
        jsonMapper: JsonMapper,
    ) : AuthenticationSuccessHandler {
        /** Logger. */
        private val logger: Log = LogFactory.getLog(javaClass)

        /** Optional customizer applied to the access token response builder before it is built. */
        private var accessTokenResponseCustomizer: Consumer<OAuth2AccessTokenAuthenticationContext>? =
            null

        /** Authentication log recorder; never null (falls back to the no-op recorder). */
        private val authenticationLogRecorder: AuthenticationLogRecorder =
            authenticationLogRecorder ?: AuthenticationLogRecorder.noop()

        /** JSON mapper used when writing the response. */
        // The parameter is already non-null for Kotlin callers; the check is retained as a defensive guard.
        private val jsonMapper: JsonMapper =
            requireNotNull(jsonMapper) { "jsonMapper cannot be null" }

        /**
         * Sets the customizer that is invoked with an [OAuth2AccessTokenAuthenticationContext]
         * (carrying the response builder) before the access token response is built.
         *
         * Without this setter the customizer field could never be assigned, leaving the
         * customization step in [onAuthenticationSuccess] unreachable.
         *
         * @param accessTokenResponseCustomizer the customizer to apply to each token response
         */
        fun setAccessTokenResponseCustomizer(
            accessTokenResponseCustomizer: Consumer<OAuth2AccessTokenAuthenticationContext>,
        ) {
            this.accessTokenResponseCustomizer = accessTokenResponseCustomizer
        }

        /**
         * Writes the access token response as JSON and records the successful authentication.
         *
         * @param request the current request; read for grant type, user name, client IP and user agent
         * @param response the response the token body is written to
         * @param authentication the authentication result; must be an [OAuth2AccessTokenAuthenticationToken]
         * @throws OAuth2AuthenticationException with a `server_error` code when [authentication] is of
         *   any other type
         */
        @Throws(IOException::class, ServletException::class)
        override fun onAuthenticationSuccess(
            request: HttpServletRequest,
            response: HttpServletResponse,
            authentication: Authentication,
        ) {
            if (authentication !is OAuth2AccessTokenAuthenticationToken) {
                if (logger.isErrorEnabled) {
                    logger.error(
                        "${Authentication::class.java.simpleName} must be of type " +
                            "${OAuth2AccessTokenAuthenticationToken::class.java.name} but was " +
                            authentication.javaClass.name,
                    )
                }
                throw OAuth2AuthenticationException(
                    OAuth2Error(
                        OAuth2ErrorCodes.SERVER_ERROR,
                        "Unable to process the access token response.",
                        null,
                    ),
                )
            }

            val accessToken = authentication.accessToken
            val refreshToken: OAuth2RefreshToken? = authentication.refreshToken
            val additionalParameters = authentication.additionalParameters

            val builder =
                OAuth2AccessTokenResponse
                    .withToken(accessToken.tokenValue)
                    .tokenType(accessToken.tokenType)
                    .scopes(accessToken.scopes)
            val issuedAt = accessToken.issuedAt
            val expiresAt = accessToken.expiresAt
            if (issuedAt != null && expiresAt != null) {
                builder.expiresIn(ChronoUnit.SECONDS.between(issuedAt, expiresAt))
            }
            if (refreshToken != null) {
                builder.refreshToken(refreshToken.tokenValue)
            }
            if (additionalParameters.isNotEmpty()) {
                builder.additionalParameters(additionalParameters)
            }

            // Read the mutable field once so the null check and the call see the same instance.
            val customizer = accessTokenResponseCustomizer
            if (customizer != null) {
                val accessTokenAuthenticationContext =
                    OAuth2AccessTokenAuthenticationContext
                        .with(authentication)
                        .accessTokenResponse(builder)
                        .build()
                customizer.accept(accessTokenAuthenticationContext)
                if (logger.isTraceEnabled) {
                    logger.trace("Customized access token response")
                }
            }

            val accessTokenResponse = builder.build()
            // The log entry is recorded before the body is written; recording never throws.
            recordAuthenticationSuccess(request, authentication)
            JsonResponseWriter.writeSuccessResponse(
                response,
                jsonMapper,
                buildTokenResponseBody(accessTokenResponse),
            )
        }

        /**
         * Records the authentication log entry for this token issuance.
         *
         * Best effort: any exception raised while resolving fields or recording is logged at
         * WARN level and swallowed so that it cannot break the token response.
         */
        private fun recordAuthenticationSuccess(
            request: HttpServletRequest,
            accessTokenAuthentication: OAuth2AccessTokenAuthenticationToken,
        ) {
            try {
                val principal = accessTokenAuthentication.principal
                val grantType =
                    resolveGrantType(request, accessTokenAuthentication.additionalParameters)
                val username = resolveUsername(principal, request, accessTokenAuthentication, grantType)
                val clientId = resolveClientId(accessTokenAuthentication, principal)
                val record =
                    AuthenticationLogRecord(
                        normalize(username),
                        normalize(clientId),
                        normalize(grantType),
                        resolveClientIp(request),
                        normalize(request.getHeader("User-Agent")),
                        true, // presumably the success flag; confirm against AuthenticationLogRecord
                        null, // presumably the failure reason, absent on success; confirm
                        Instant.now(),
                    )
                authenticationLogRecorder.record(record)
            } catch (ex: Exception) {
                if (logger.isWarnEnabled) {
                    logger.warn("Failed to record authentication log", ex)
                }
            }
        }

        /**
         * Resolves the user name for the log entry.
         *
         * Returns `null` for the client credentials grant. Otherwise the first non-blank value
         * wins, in this order: the user name request parameter, the principal name of the stored
         * authorization, then the authentication principal (ignored when it is a client
         * authentication token).
         */
        private fun resolveUsername(
            principal: Any?,
            request: HttpServletRequest,
            accessTokenAuthentication: OAuth2AccessTokenAuthenticationToken,
            grantType: String?,
        ): String? {
            if (AuthorizationGrantType.CLIENT_CREDENTIALS.value == grantType) {
                return null
            }
            return normalize(request.getParameter(resolveUsernameParameterName()))
                ?: resolveAuthorizationUsername(accessTokenAuthentication)
                ?: when (principal) {
                    null, is OAuth2ClientAuthenticationToken -> null
                    // Authentication extends Principal, so this branch covers both.
                    is Principal -> normalize(principal.name)
                    else -> normalize(principal.toString())
                }
        }

        /**
         * Looks up the stored authorization for the issued access token and returns its principal
         * name, or `null` when no authorization service is configured or nothing is found.
         */
        private fun resolveAuthorizationUsername(
            accessTokenAuthentication: OAuth2AccessTokenAuthenticationToken?,
        ): String? {
            val service = authorizationService ?: return null
            val tokenValue = accessTokenAuthentication?.accessToken?.tokenValue ?: return null
            val authorization =
                service.findByToken(tokenValue, OAuth2TokenType.ACCESS_TOKEN) ?: return null
            return normalize(authorization.principalName)
        }

        /**
         * Resolves the client ID from the registered client of the authentication result, falling
         * back to the name of a client authentication principal.
         */
        private fun resolveClientId(
            accessTokenAuthentication: OAuth2AccessTokenAuthenticationToken,
            principal: Any?,
        ): String? {
            val registeredClient = accessTokenAuthentication.registeredClient
            return when {
                registeredClient != null -> normalize(registeredClient.clientId)
                principal is OAuth2ClientAuthenticationToken -> normalize(principal.name)
                else -> null
            }
        }

        /**
         * Resolves the grant type from the `grant_type` request parameter, falling back to a
         * string `grant_type` entry in the additional parameters.
         *
         * Both sources are trimmed so the result compares reliably against grant type constants.
         */
        private fun resolveGrantType(
            request: HttpServletRequest,
            additionalParameters: Map<String, Any>?,
        ): String? =
            normalize(request.getParameter("grant_type"))
                ?: normalize(additionalParameters?.get("grant_type") as? String)

        /** Returns the configured user name parameter name, defaulting to `username`. */
        private fun resolveUsernameParameterName(): String =
            normalize(oauth2Properties?.usernameParameterName) ?: "username"

        /**
         * Resolves the client IP: first entry of `X-Forwarded-For`, then `X-Real-IP`, then the
         * remote address of the request.
         *
         * A blank candidate (including a blank first `X-Forwarded-For` entry) falls through to
         * the next source instead of yielding `null`.
         *
         * NOTE(review): both headers can be supplied by the caller, so the value is only
         * trustworthy when a trusted proxy sets or overwrites them.
         */
        private fun resolveClientIp(request: HttpServletRequest): String? =
            normalize(request.getHeader("X-Forwarded-For")?.substringBefore(','))
                ?: normalize(request.getHeader("X-Real-IP"))
                ?: normalize(request.remoteAddr)

        /** Trims the value and maps null, empty and whitespace-only strings to `null`. */
        private fun normalize(value: String?): String? = value?.trim()?.takeIf { it.isNotEmpty() }

        /**
         * Builds the token response body as an insertion-ordered map.
         *
         * Keys written: `access_token`, `token_type`, `expires_in` (when both timestamps are
         * present), `refresh_token` (when present), `scope` (space-joined, when non-empty),
         * followed by all additional parameters, which overwrite earlier keys of the same name.
         */
        private fun buildTokenResponseBody(
            accessTokenResponse: OAuth2AccessTokenResponse,
        ): Map<String, Any> {
            val body: MutableMap<String, Any> = LinkedHashMap()
            val accessToken = accessTokenResponse.accessToken
            body["access_token"] = accessToken.tokenValue
            body["token_type"] = accessToken.tokenType.value
            val issuedAt = accessToken.issuedAt
            val expiresAt = accessToken.expiresAt
            if (issuedAt != null && expiresAt != null) {
                body["expires_in"] = ChronoUnit.SECONDS.between(issuedAt, expiresAt)
            }
            accessTokenResponse.refreshToken?.let { refreshToken ->
                body["refresh_token"] = refreshToken.tokenValue
            }
            val scopes = accessToken.scopes
            if (!scopes.isNullOrEmpty()) {
                body["scope"] = scopes.joinToString(" ")
            }
            val additionalParameters = accessTokenResponse.additionalParameters
            if (!additionalParameters.isNullOrEmpty()) {
                body.putAll(additionalParameters)
            }
            return body
        }
    }