OAuth2ErrorApiResultAuthenticationFailureHandler.kt

package io.github.lishangbu.avalon.oauth2.authorizationserver.web.authentication

import io.github.lishangbu.avalon.oauth2.common.log.AuthenticationLogRecord
import io.github.lishangbu.avalon.oauth2.common.log.AuthenticationLogRecorder
import io.github.lishangbu.avalon.oauth2.common.properties.Oauth2Properties
import io.github.lishangbu.avalon.web.result.DefaultErrorResultCode
import io.github.lishangbu.avalon.web.util.JsonResponseWriter
import jakarta.servlet.ServletException
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.springframework.http.HttpStatus
import org.springframework.security.core.AuthenticationException
import org.springframework.security.oauth2.core.OAuth2AuthenticationException
import org.springframework.security.oauth2.core.OAuth2ErrorCodes
import org.springframework.security.web.authentication.AuthenticationFailureHandler
import tools.jackson.databind.json.JsonMapper
import java.io.IOException
import java.nio.charset.StandardCharsets
import java.time.Instant
import java.util.*

/**
 * OAuth2 authentication failure handler.
 *
 * Converts an [AuthenticationException] into the unified JSON error response
 * (always written with HTTP 400) and hands an [AuthenticationLogRecord]
 * describing the failed attempt to the configured [AuthenticationLogRecorder].
 */
class OAuth2ErrorApiResultAuthenticationFailureHandler
    @JvmOverloads
    constructor(
        authenticationLogRecorder: AuthenticationLogRecorder? = AuthenticationLogRecorder.noop(),
        /** OAuth2 properties; when `null`, the username request parameter defaults to `username`. */
        private val oauth2Properties: Oauth2Properties? = null,
        /** JSON mapper passed to [JsonResponseWriter] when writing the error response. */
        private val jsonMapper: JsonMapper,
    ) : AuthenticationFailureHandler {
        /** Class logger. */
        private val logger: Log = LogFactory.getLog(javaClass)

        /** Authentication log recorder; falls back to the no-op recorder when `null` was supplied. */
        private val authenticationLogRecorder: AuthenticationLogRecorder =
            authenticationLogRecorder ?: AuthenticationLogRecorder.noop()

        /**
         * Handles a failed authentication: records the attempt, writes the JSON error
         * response, and logs a warning when the exception is not an
         * [OAuth2AuthenticationException].
         *
         * @param request the request whose authentication failed
         * @param response the response the error body is written to
         * @param authenticationException the exception describing the failure
         */
        @Throws(IOException::class, ServletException::class)
        override fun onAuthenticationFailure(
            request: HttpServletRequest,
            response: HttpServletResponse,
            authenticationException: AuthenticationException,
        ) {
            val resolved = resolveError(authenticationException)
            recordAuthenticationFailure(request, resolved)
            writeFailedResponse(response, resolved.code, resolved.description)

            if (authenticationException !is OAuth2AuthenticationException && logger.isWarnEnabled) {
                logger.warn(
                    "${AuthenticationException::class.java.simpleName} must be of type " +
                        "${OAuth2AuthenticationException::class.java.name} but was " +
                        authenticationException.javaClass.name,
                )
            }
        }

        /**
         * Extracts an error code and a sanitized description from the exception.
         *
         * For an [OAuth2AuthenticationException] the code comes from its OAuth2 error; when
         * that error carries no usable description, the exception message is used instead,
         * unless the message merely repeats the error code. For any other exception the
         * simple class name serves as the code and the message as the description.
         *
         * @return the resolved code/description pair; both are `null` for a `null` exception
         */
        private fun resolveError(authenticationException: AuthenticationException?): ResolvedError {
            if (authenticationException == null) {
                return ResolvedError(null, null)
            }
            if (authenticationException !is OAuth2AuthenticationException) {
                return ResolvedError(
                    authenticationException.javaClass.simpleName,
                    sanitizeDescription(authenticationException.message),
                )
            }
            val oauth2Error = authenticationException.error
            val errorCode = oauth2Error?.errorCode
            val description =
                oauth2Error?.description?.takeIf { normalize(it) != null }
                    // Fall back to the exception message, but not when it is just the error code again.
                    ?: normalize(authenticationException.message)
                        ?.takeUnless { it.equals(normalize(errorCode), ignoreCase = true) }
            return ResolvedError(errorCode, sanitizeDescription(description))
        }

        /**
         * Builds an [AuthenticationLogRecord] for the failed attempt and passes it to the recorder.
         *
         * Recording is best-effort: any exception raised while building or storing the record
         * is logged at WARN level and swallowed so it cannot prevent the error response.
         */
        private fun recordAuthenticationFailure(
            request: HttpServletRequest,
            error: ResolvedError,
        ) {
            try {
                val record =
                    AuthenticationLogRecord(
                        normalize(request.getParameter(resolveUsernameParameterName())),
                        resolveClientId(request),
                        normalize(request.getParameter("grant_type")),
                        resolveClientIp(request),
                        normalize(request.getHeader("User-Agent")),
                        false, // presumably the "success" flag; confirm against AuthenticationLogRecord
                        resolveMessage(error.code, error.description),
                        Instant.now(),
                    )
                authenticationLogRecorder.record(record)
            } catch (ex: Exception) {
                if (logger.isWarnEnabled) {
                    logger.warn("Failed to record authentication log", ex)
                }
            }
        }

        /**
         * Picks the single message used for both the log record and the response body:
         * the sanitized description when present, otherwise the error code, except that
         * a bare `invalid_grant` code is never used as a message.
         *
         * @return the message, or `null` when nothing suitable is available
         */
        private fun resolveMessage(
            code: String?,
            description: String?,
        ): String? =
            sanitizeDescription(description)
                ?: normalize(code)?.takeUnless { it == OAuth2ErrorCodes.INVALID_GRANT }

        /** Returns the configured username parameter name, or `username` when none is configured. */
        private fun resolveUsernameParameterName(): String = normalize(oauth2Properties?.usernameParameterName) ?: "username"

        /**
         * Resolves the client id from the `client_id` request parameter or, failing that,
         * from the user part of an HTTP Basic `Authorization` header.
         *
         * @return the client id, or `null` when it cannot be determined
         */
        private fun resolveClientId(request: HttpServletRequest): String? {
            normalize(request.getParameter("client_id"))?.let { return it }

            val authorization = request.getHeader("Authorization") ?: return null
            val basicPrefix = "Basic "
            // Authentication scheme names are case-insensitive, so "basic" and "BASIC" must match too.
            if (!authorization.startsWith(basicPrefix, ignoreCase = true)) {
                return null
            }
            val base64Credentials = authorization.substring(basicPrefix.length).trim()
            if (base64Credentials.isEmpty()) {
                return null
            }
            return try {
                val credentials = String(Base64.getDecoder().decode(base64Credentials), StandardCharsets.UTF_8)
                val delimiterIndex = credentials.indexOf(':')
                // A missing delimiter or an empty user part yields no client id.
                if (delimiterIndex > 0) normalize(credentials.substring(0, delimiterIndex)) else null
            } catch (ex: IllegalArgumentException) {
                // Thrown by the decoder for input that is not valid Base64.
                if (logger.isTraceEnabled) {
                    logger.trace("Failed to decode client credentials", ex)
                }
                null
            }
        }

        /**
         * Resolves the client IP, preferring the first entry of `X-Forwarded-For`, then
         * `X-Real-IP`, then the remote address of the connection.
         *
         * NOTE(review): both headers are supplied by the caller; the value is only reliable
         * when a trusted proxy in front of this service overwrites them. Confirm deployment.
         */
        private fun resolveClientIp(request: HttpServletRequest): String? {
            val forwardedFor = request.getHeader("X-Forwarded-For")
            if (!forwardedFor.isNullOrBlank()) {
                return normalize(forwardedFor.substringBefore(','))
            }
            val realIp = request.getHeader("X-Real-IP")
            if (!realIp.isNullOrBlank()) {
                return normalize(realIp)
            }
            return normalize(request.remoteAddr)
        }

        /** Trims the value and returns `null` when it is `null` or empty after trimming. */
        private fun normalize(value: String?): String? = value?.trim()?.ifEmpty { null }

        /**
         * Writes the failure response with HTTP 400 and the `BAD_REQUEST` result code.
         *
         * The message chosen by [resolveMessage] is appended when one exists; otherwise
         * the response is written without a message.
         */
        private fun writeFailedResponse(
            response: HttpServletResponse,
            errorCode: String?,
            errorDescription: String?,
        ) {
            val message = resolveMessage(errorCode, errorDescription)
            if (message == null) {
                JsonResponseWriter.writeFailedResponse(
                    response,
                    jsonMapper,
                    HttpStatus.BAD_REQUEST,
                    DefaultErrorResultCode.BAD_REQUEST,
                )
            } else {
                JsonResponseWriter.writeFailedResponse(
                    response,
                    jsonMapper,
                    HttpStatus.BAD_REQUEST,
                    DefaultErrorResultCode.BAD_REQUEST,
                    message,
                )
            }
        }

        /**
         * Strips a leading `invalid_grant` marker from an error description.
         *
         * A description that is blank or exactly `invalid_grant` (ignoring case) yields `null`.
         * A leading `[invalid_grant]`, `invalid_grant:` or bare `invalid_grant` (ignoring case)
         * is removed and the remainder is trimmed; any other description is returned trimmed.
         */
        private fun sanitizeDescription(description: String?): String? {
            val normalized = normalize(description) ?: return null
            val lower = normalized.lowercase(Locale.ROOT)
            val invalidGrant = OAuth2ErrorCodes.INVALID_GRANT
            if (lower == invalidGrant) {
                return null
            }
            // Order matters: the bare code is a prefix of the colon form, so it is tried last.
            val matchedPrefix =
                listOf("[$invalidGrant]", "$invalidGrant:", invalidGrant)
                    .firstOrNull { lower.startsWith(it) }
                    ?: return normalized
            return normalize(normalized.substring(matchedPrefix.length))
        }

        /** Error code and description extracted from an authentication exception. */
        private data class ResolvedError(
            /** Error code; may be `null`. */
            val code: String?,
            /** Sanitized error description; may be `null`. */
            val description: String?,
        )
    }