OAuth2EndpointUtils.kt

package io.github.lishangbu.avalon.oauth2.authorizationserver.util

import jakarta.servlet.http.HttpServletRequest
import org.springframework.security.oauth2.core.AuthorizationGrantType
import org.springframework.security.oauth2.core.OAuth2AuthenticationException
import org.springframework.security.oauth2.core.OAuth2Error
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames
import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
import java.net.URLDecoder
import java.nio.charset.StandardCharsets

/**
 * Helpers for OAuth 2.0 endpoint request handling: parameter extraction, PKCE token-request
 * detection and raising [OAuth2AuthenticationException]s for invalid parameters.
 */
object OAuth2EndpointUtils {
    /** Error URI for access token request errors (RFC 6749, section 5.2 "Error Response"). */
    const val ACCESS_TOKEN_REQUEST_ERROR_URI =
        "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"

    /**
     * Collects every request parameter, from the query string and the form body alike, into a
     * [MultiValueMap], keeping all values of repeated parameters.
     *
     * @param request the current HTTP request
     * @return a new map of parameter name to all of its values
     */
    @JvmStatic
    fun getParameters(request: HttpServletRequest): MultiValueMap<String, String> {
        val parameterMap = request.parameterMap
        val parameters = LinkedMultiValueMap<String, String>(parameterMap.size)
        parameterMap.forEach { (key, values) ->
            values.forEach { value -> parameters.add(key, value) }
        }
        return parameters
    }

    /**
     * Collects only the form (request body) parameters, i.e. every request parameter whose name
     * does NOT appear in the query string.
     *
     * The servlet parameter map merges query-string and body parameters, so their values cannot
     * be told apart. A name that appears in the query string is therefore dropped entirely
     * rather than partially trusted. This mirrors the requirement that token-request parameters
     * be sent form-encoded in the request body.
     *
     * @param request the current HTTP request
     * @return a new map of form parameter name to all of its values
     */
    @JvmStatic
    fun getFormParameters(request: HttpServletRequest): MultiValueMap<String, String> {
        val queryParameterNames = queryParameterNames(request.queryString)
        val parameters = LinkedMultiValueMap<String, String>()
        request.parameterMap.forEach { (key, values) ->
            // Exact (decoded) name match, so e.g. "code" in the query does not hide "code_verifier".
            if (key !in queryParameterNames) {
                values.forEach { value -> parameters.add(key, value) }
            }
        }
        return parameters
    }

    /**
     * Returns `true` for an authorization-code token request that carries a PKCE verifier.
     *
     * That means `grant_type` is `authorization_code` and both the `code` and `code_verifier`
     * parameters are present.
     *
     * @param request the current HTTP request
     */
    @JvmStatic
    fun matchesPkceTokenRequest(request: HttpServletRequest): Boolean =
        AuthorizationGrantType.AUTHORIZATION_CODE.value ==
            request.getParameter(OAuth2ParameterNames.GRANT_TYPE) &&
            request.getParameter(OAuth2ParameterNames.CODE) != null &&
            request.getParameter(PkceParameterNames.CODE_VERIFIER) != null

    /**
     * Throws an [OAuth2AuthenticationException] describing a problem with a single request parameter.
     *
     * @param errorCode OAuth 2.0 error code (e.g. `invalid_request`)
     * @param parameterName name of the offending parameter, embedded in the error description
     * @param errorUri URI documenting the error; defaults to [ACCESS_TOKEN_REQUEST_ERROR_URI]
     * @throws OAuth2AuthenticationException always
     */
    @JvmStatic
    @JvmOverloads
    fun throwError(
        errorCode: String,
        parameterName: String,
        errorUri: String = ACCESS_TOKEN_REQUEST_ERROR_URI,
    ): Nothing {
        val error = OAuth2Error(errorCode, "OAuth 2.0 Parameter: $parameterName", errorUri)
        throw OAuth2AuthenticationException(error)
    }

    /** Decoded names of the parameters in [queryString]; empty when there is no query string. */
    private fun queryParameterNames(queryString: String?): Set<String> {
        if (queryString.isNullOrEmpty()) return emptySet()
        return queryString
            .split('&')
            .filter { it.isNotEmpty() }
            .mapTo(HashSet()) { pair -> decodeQueryComponent(pair.substringBefore('=')) }
    }

    /** URL-decodes [raw] as UTF-8, falling back to the raw text if it contains malformed escapes. */
    private fun decodeQueryComponent(raw: String): String =
        try {
            URLDecoder.decode(raw, StandardCharsets.UTF_8)
        } catch (ex: IllegalArgumentException) {
            raw
        }
}