RedisIdempotentStore.kt
package io.github.lishangbu.avalon.idempotent.store
import org.springframework.data.redis.core.StringRedisTemplate
import org.springframework.data.redis.core.script.DefaultRedisScript
import tools.jackson.databind.json.JsonMapper
import java.time.Duration
/**
* Redis implementation of [IdempotentStore].
*/
class RedisIdempotentStore(
private val stringRedisTemplate: StringRedisTemplate,
private val jsonMapper: JsonMapper,
) : IdempotentStore {
private val completeScript =
DefaultRedisScript<Long>().apply {
setScriptText(COMPLETE_SCRIPT)
setResultType(Long::class.java)
}
private val releaseScript =
DefaultRedisScript<Long>().apply {
setScriptText(RELEASE_SCRIPT)
setResultType(Long::class.java)
}
private val renewScript =
DefaultRedisScript<Long>().apply {
setScriptText(RENEW_SCRIPT)
setResultType(Long::class.java)
}
override fun acquire(
key: String,
token: String,
processingTtl: Duration,
): IdempotentStore.AcquireResult {
val pendingValue = serialize(StoredValue.processing(token))
if (stringRedisTemplate.opsForValue().setIfAbsent(key, pendingValue, processingTtl) == true) {
return IdempotentStore.AcquireResult.Acquired
}
repeat(2) {
val existingValue = stringRedisTemplate.opsForValue().get(key)
if (existingValue != null) {
return when (deserialize(existingValue).status) {
StoredStatus.PROCESSING -> {
IdempotentStore.AcquireResult.Processing
}
StoredStatus.SUCCEEDED -> {
IdempotentStore.AcquireResult.Completed(
cachedValue = deserialize(existingValue).cachedValue,
)
}
}
}
if (stringRedisTemplate.opsForValue().setIfAbsent(key, pendingValue, processingTtl) == true) {
return IdempotentStore.AcquireResult.Acquired
}
}
return IdempotentStore.AcquireResult.Processing
}
override fun complete(
key: String,
token: String,
cachedValue: String?,
ttl: Duration,
): Boolean =
stringRedisTemplate.execute(
completeScript,
listOf(key),
token,
serialize(StoredValue.succeeded(token, cachedValue)),
ttl.toMillis().toString(),
) == 1L
override fun release(
key: String,
token: String,
): Boolean = stringRedisTemplate.execute(releaseScript, listOf(key), token) == 1L
override fun renew(
key: String,
token: String,
processingTtl: Duration,
): Boolean =
stringRedisTemplate.execute(
renewScript,
listOf(key),
token,
processingTtl.toMillis().toString(),
) == 1L
private fun serialize(value: StoredValue): String = jsonMapper.writeValueAsString(value)
private fun deserialize(value: String): StoredValue = jsonMapper.readValue(value, StoredValue::class.java)
private data class StoredValue(
val status: StoredStatus,
val token: String,
val cachedValue: String? = null,
) {
companion object {
fun processing(token: String): StoredValue = StoredValue(status = StoredStatus.PROCESSING, token = token)
fun succeeded(
token: String,
cachedValue: String?,
): StoredValue = StoredValue(status = StoredStatus.SUCCEEDED, token = token, cachedValue = cachedValue)
}
}
private enum class StoredStatus {
PROCESSING,
SUCCEEDED,
}
companion object {
private const val COMPLETE_SCRIPT: String =
"""
local current = redis.call('GET', KEYS[1])
if not current then
return 0
end
local decoded = cjson.decode(current)
if decoded['status'] ~= 'PROCESSING' or decoded['token'] ~= ARGV[1] then
return 0
end
redis.call('SET', KEYS[1], ARGV[2], 'PX', ARGV[3])
return 1
"""
private const val RELEASE_SCRIPT: String =
"""
local current = redis.call('GET', KEYS[1])
if not current then
return 0
end
local decoded = cjson.decode(current)
if decoded['token'] ~= ARGV[1] then
return 0
end
redis.call('DEL', KEYS[1])
return 1
"""
private const val RENEW_SCRIPT: String =
"""
local current = redis.call('GET', KEYS[1])
if not current then
return 0
end
local decoded = cjson.decode(current)
if decoded['status'] ~= 'PROCESSING' or decoded['token'] ~= ARGV[1] then
return 0
end
redis.call('PEXPIRE', KEYS[1], ARGV[2])
return 1
"""
}
}