JdbcLoginFailureTracker.kt
package io.github.lishangbu.avalon.oauth2.authorizationserver.login
import io.github.lishangbu.avalon.oauth2.common.properties.Oauth2Properties
import org.springframework.dao.DuplicateKeyException
import org.springframework.jdbc.core.JdbcTemplate
import org.springframework.jdbc.core.RowMapper
import org.springframework.transaction.PlatformTransactionManager
import org.springframework.transaction.support.TransactionTemplate
import java.sql.ResultSet
import java.sql.Timestamp
import java.time.Clock
import java.time.Duration
import java.time.Instant
/**
* JDBC-backed login failure tracker.
*/
class JdbcLoginFailureTracker(
properties: Oauth2Properties?,
private val jdbcTemplate: JdbcTemplate,
transactionManager: PlatformTransactionManager,
clock: Clock = Clock.systemUTC(),
) : AbstractLoginFailureTracker(properties, clock) {
private val tableName = validateTableName(properties?.loginFailureTrackerJdbcTableName ?: DEFAULT_TABLE_NAME)
private val transactionTemplate = TransactionTemplate(transactionManager)
private val rowMapper =
RowMapper { resultSet: ResultSet, _: Int ->
LoginFailureState(
failures = resultSet.getInt("failure_count"),
lockUntil = resultSet.getTimestamp("lock_until")?.toInstant(),
)
}
override fun getRemainingLock(username: String?): Duration? {
if (!isEnabled()) {
return null
}
val key = normalize(username) ?: return null
val state = findState(key) ?: return null
val currentTime = now()
val remainingLock = getRemainingLock(state, currentTime)
if (remainingLock == null && state.lockUntil != null) {
jdbcTemplate.update(
"delete from $tableName where username = ? and lock_until <= ?",
key,
Timestamp.from(currentTime),
)
}
return remainingLock
}
override fun onFailure(username: String?) {
if (!isEnabled()) {
return
}
val key = normalize(username) ?: return
repeat(2) { attempt ->
try {
transactionTemplate.executeWithoutResult {
val currentTime = now()
val state = findStateForUpdate(key)
if (state == null) {
insertState(key, nextState(null, currentTime), currentTime)
return@executeWithoutResult
}
if (getRemainingLock(state, currentTime) != null) {
return@executeWithoutResult
}
updateState(key, nextState(state, currentTime), currentTime)
}
return
} catch (ex: DuplicateKeyException) {
if (attempt == 1) {
throw ex
}
}
}
}
override fun onSuccess(username: String?) {
val key = normalize(username) ?: return
jdbcTemplate.update("delete from $tableName where username = ?", key)
}
private fun findState(username: String): LoginFailureState? =
jdbcTemplate
.query(
"""
select failure_count, lock_until
from $tableName
where username = ?
""".trimIndent(),
rowMapper,
username,
).firstOrNull()
private fun findStateForUpdate(username: String): LoginFailureState? =
jdbcTemplate
.query(
"""
select failure_count, lock_until
from $tableName
where username = ?
for update
""".trimIndent(),
rowMapper,
username,
).firstOrNull()
private fun insertState(
username: String,
state: LoginFailureState,
currentTime: Instant,
) {
jdbcTemplate.update(
"""
insert into $tableName (
username,
failure_count,
lock_until,
created_at,
updated_at
) values (?, ?, ?, ?, ?)
""".trimIndent(),
username,
state.failures,
state.lockUntil?.let(Timestamp::from),
Timestamp.from(currentTime),
Timestamp.from(currentTime),
)
}
private fun updateState(
username: String,
state: LoginFailureState,
currentTime: Instant,
) {
jdbcTemplate.update(
"""
update $tableName
set failure_count = ?, lock_until = ?, updated_at = ?
where username = ?
""".trimIndent(),
state.failures,
state.lockUntil?.let(Timestamp::from),
Timestamp.from(currentTime),
username,
)
}
private fun validateTableName(tableName: String): String {
require(TABLE_NAME_PATTERN.matches(tableName)) {
"Invalid JDBC login failure tracker table name: $tableName"
}
return tableName
}
companion object {
const val DEFAULT_TABLE_NAME: String = "oauth2_login_failure"
val TABLE_NAME_PATTERN: Regex = Regex("[A-Za-z0-9_]+")
}
}