RedisLoginFailureTracker.kt

package io.github.lishangbu.avalon.oauth2.authorizationserver.login

import io.github.lishangbu.avalon.oauth2.common.properties.Oauth2Properties
import org.springframework.data.redis.core.StringRedisTemplate
import org.springframework.data.redis.core.script.DefaultRedisScript
import java.time.Clock
import java.time.Duration

/**
 * Redis-backed login failure tracker.
 */
/**
 * Redis-backed [AbstractLoginFailureTracker].
 *
 * Per-user state lives in a Redis hash stored at `<keyPrefix>:<normalized username>`. The prefix is
 * [Oauth2Properties.loginFailureTrackerKeyPrefix] when it is set and non-blank, otherwise [DEFAULT_KEY_PREFIX].
 * The hash holds two fields:
 * - `failures`: failed attempts counted since the last reset;
 * - `lockUntil`: epoch milliseconds (computed from the injected [Clock]) until which the user is locked.
 *
 * Every read-modify-write of the hash is performed inside a single Lua script. Redis runs scripts
 * atomically, so concurrent attempts for the same user cannot interleave between reading and updating it.
 *
 * @param properties OAuth2 settings; may be `null`, in which case [DEFAULT_KEY_PREFIX] is used as key prefix.
 * @param stringRedisTemplate template used to run the scripts and to delete keys.
 * @param clock source of the current time used for lock timestamps; defaults to [Clock.systemUTC].
 */
class RedisLoginFailureTracker(
    properties: Oauth2Properties?,
    private val stringRedisTemplate: StringRedisTemplate,
    clock: Clock = Clock.systemUTC(),
) : AbstractLoginFailureTracker(properties, clock) {
    private val keyPrefix =
        properties
            ?.loginFailureTrackerKeyPrefix
            ?.takeIf { it.isNotBlank() }
            ?: DEFAULT_KEY_PREFIX

    // In Kotlin `Long::class.java` is the primitive `long` class. Pass the boxed `java.lang.Long` so that
    // Spring Data Redis maps the script result to an INTEGER reply rather than falling back to VALUE.
    private val failureScript =
        DefaultRedisScript<Long>().apply {
            setScriptText(FAILURE_SCRIPT)
            setResultType(Long::class.javaObjectType)
        }

    private val remainingLockScript =
        DefaultRedisScript<Long>().apply {
            setScriptText(REMAINING_LOCK_SCRIPT)
            setResultType(Long::class.javaObjectType)
        }

    /**
     * Returns how much longer [username] stays locked.
     *
     * @return the remaining lock time, or `null` when tracking is disabled, when [normalize] yields `null`
     *   for [username], or when no lock is currently active. An expired lock entry encountered here is
     *   deleted by the script.
     */
    override fun getRemainingLock(username: String?): Duration? {
        if (!isEnabled()) {
            return null
        }
        val key = storageKeyFor(username) ?: return null
        val remainingMillis =
            stringRedisTemplate.execute(
                remainingLockScript,
                listOf(key),
                now().toEpochMilli().toString(),
            ) ?: 0L

        return remainingMillis.takeIf { it > 0 }?.let(Duration::ofMillis)
    }

    /**
     * Records a failed login for [username].
     *
     * The attempt is ignored while a lock is active. Otherwise the failure counter is incremented; when it
     * reaches [maxFailures] the user is locked for [lockDuration] and the counter restarts at zero.
     *
     * @throws IllegalStateException if tracking is enabled but [lockDuration] is `null`.
     */
    override fun onFailure(username: String?) {
        if (!isEnabled()) {
            return
        }
        val key = storageKeyFor(username) ?: return
        val lockMillis =
            checkNotNull(lockDuration) {
                "lockDuration must not be null when login failure tracking is enabled"
            }.toMillis()
        stringRedisTemplate.execute(
            failureScript,
            listOf(key),
            now().toEpochMilli().toString(),
            maxFailures.toString(),
            lockMillis.toString(),
        )
    }

    /**
     * Deletes all tracked state (failure counter and any lock) for [username].
     *
     * Unlike the other operations this does not check [isEnabled], so the key is removed even while
     * tracking is disabled.
     */
    override fun onSuccess(username: String?) {
        val key = storageKeyFor(username) ?: return
        stringRedisTemplate.delete(key)
    }

    /** Normalizes [username] and maps it to its Redis key, or returns `null` if normalization yields `null`. */
    private fun storageKeyFor(username: String?): String? = normalize(username)?.let(::buildStorageKey)

    /** Builds the Redis key `<keyPrefix>:<username>` for an already-normalized [username]. */
    private fun buildStorageKey(username: String): String = "$keyPrefix:$username"

    companion object {
        /** Key prefix used when no non-blank prefix is configured. */
        const val DEFAULT_KEY_PREFIX: String = "oauth2:login-failure"

        /**
         * Records one failed attempt.
         *
         * KEYS[1] is the user's hash. ARGV[1] is the current epoch millis, ARGV[2] is the failure threshold
         * and ARGV[3] is the lock duration in millis.
         * - If `lockUntil` is still in the future, nothing is counted and the remaining lock millis are returned.
         * - Otherwise a stale entry is deleted and `failures` is incremented. On reaching the threshold the
         *   counter is reset to 0, `lockUntil` is set, the key is given a TTL equal to the lock duration, and
         *   the lock duration is returned. Below the threshold the script returns 0.
         *
         * NOTE(review): below the threshold the key is `PERSIST`ed, so counters for usernames that never get
         * locked have no expiry and remain in Redis until a successful login for that username. Consider a
         * TTL if failures are meant to age out or if arbitrary usernames can be submitted.
         */
        private const val FAILURE_SCRIPT: String =
            """
            local currentLockUntil = redis.call('HGET', KEYS[1], 'lockUntil')
            local nowMillis = tonumber(ARGV[1])
            local maxFailures = tonumber(ARGV[2])
            local lockDurationMillis = tonumber(ARGV[3])

            if currentLockUntil then
              local lockUntilMillis = tonumber(currentLockUntil)
              if lockUntilMillis and lockUntilMillis > nowMillis then
                return lockUntilMillis - nowMillis
              end
              redis.call('DEL', KEYS[1])
            end

            local failures = redis.call('HINCRBY', KEYS[1], 'failures', 1)
            if failures >= maxFailures then
              local lockUntilMillis = nowMillis + lockDurationMillis
              redis.call('HSET', KEYS[1], 'failures', 0, 'lockUntil', lockUntilMillis)
              redis.call('PEXPIRE', KEYS[1], lockDurationMillis)
              return lockDurationMillis
            end

            redis.call('HDEL', KEYS[1], 'lockUntil')
            redis.call('PERSIST', KEYS[1])
            return 0
            """

        /**
         * Returns the remaining lock time in millis for KEYS[1], given ARGV[1] = current epoch millis.
         *
         * Returns 0 when no lock is set. An expired or unparsable `lockUntil` also returns 0 and causes the
         * key to be deleted, which matches how [FAILURE_SCRIPT] treats such entries.
         */
        private const val REMAINING_LOCK_SCRIPT: String =
            """
            local currentLockUntil = redis.call('HGET', KEYS[1], 'lockUntil')
            if not currentLockUntil then
              return 0
            end

            local lockUntilMillis = tonumber(currentLockUntil)
            if not lockUntilMillis then
              redis.call('DEL', KEYS[1])
              return 0
            end

            local remaining = lockUntilMillis - tonumber(ARGV[1])
            if remaining <= 0 then
              redis.call('DEL', KEYS[1])
              return 0
            end
            return remaining
            """
    }
}