// SnowflakeIdGenerator.kt

package io.github.lishangbu.avalon.jimmer.id

import org.babyfish.jimmer.sql.meta.UserIdGenerator
import java.lang.management.ManagementFactory
import java.net.InetAddress
import java.net.NetworkInterface
import java.util.concurrent.ThreadLocalRandom

/**
 * Primary-key generator for Jimmer.
 *
 * Produces snowflake-style 64-bit IDs; the algorithm is kept consistent with the
 * MyBatis-Plus `Sequence` implementation. Every generator instance delegates to the
 * single [Sequence] held in the companion object.
 */
class SnowflakeIdGenerator : UserIdGenerator<Long> {
    /**
     * Returns the next snowflake ID.
     *
     * @param entityType entity class the ID is requested for; not consulted by this generator
     * @return the newly generated ID
     */
    override fun generate(entityType: Class<*>): Long = sharedSequence.nextId()

    /**
     * Snowflake sequence state.
     *
     * ID layout, from high to low bits: milliseconds elapsed since [TWEPOCH],
     * a 5-bit datacenter ID, a 5-bit worker ID and a 12-bit per-millisecond sequence.
     *
     * @param inetAddress address whose network interface is used to derive the datacenter ID;
     * when `null`, the local host address is looked up instead
     */
    private class Sequence(
        inetAddress: InetAddress? = null,
    ) {
        /** Datacenter ID, derived from the MAC address of the selected network interface. */
        private val datacenterId: Long = resolveDatacenterId(inetAddress)

        /** Worker ID, derived from a hash of the datacenter ID and the JVM process ID. */
        private val workerId: Long = resolveWorkerId(datacenterId)

        /** Sequence number embedded in the most recently issued ID. */
        private var sequence = 0L

        /** Timestamp (epoch millis) of the most recently issued ID; -1 until the first call. */
        private var lastTimestamp = -1L

        /**
         * Generates the next ID.
         *
         * The method is synchronized on this instance, so concurrent callers are serialized.
         *
         * @return the composed ID (timestamp delta, datacenter ID, worker ID, sequence)
         * @throws IllegalStateException if the clock is more than 5 ms behind the last issued
         * timestamp, or is still behind it after the wait
         */
        @Synchronized
        fun nextId(): Long {
            var now = currentMillis()

            // Clock rollback, following the MyBatis-Plus logic: wait out a small rollback,
            // fail immediately on a large one.
            if (now < lastTimestamp) {
                val drift = lastTimestamp - now
                if (drift > MAX_TOLERATED_DRIFT_MS) {
                    error("Clock moved backwards. Refusing to generate id for $drift milliseconds")
                }
                // Sleep for twice the observed drift, then re-read the clock.
                Thread.sleep(drift shl 1)
                now = currentMillis()
                check(now >= lastTimestamp) {
                    "Clock moved backwards. Refusing to generate id for $drift milliseconds"
                }
            }

            if (now != lastTimestamp) {
                // New millisecond: start the sequence at a random 1 or 2 so that IDs do not
                // all share the same trailing value.
                sequence = ThreadLocalRandom.current().nextLong(1, 3)
            } else {
                // Same millisecond: advance the sequence; once it wraps to 0 the millisecond
                // is exhausted, so spin until the clock moves on.
                sequence = (sequence + 1) and SEQUENCE_MASK
                if (sequence == 0L) {
                    now = awaitNextMillis(lastTimestamp)
                }
            }

            lastTimestamp = now

            val timePart = (now - TWEPOCH) shl TIMESTAMP_SHIFT
            val nodePart = (datacenterId shl DATACENTER_ID_SHIFT) or (workerId shl WORKER_ID_SHIFT)
            return timePart or nodePart or sequence
        }

        /** Busy-waits until the clock reports a millisecond strictly after [previous] and returns it. */
        private fun awaitNextMillis(previous: Long): Long {
            var now: Long
            do {
                now = currentMillis()
            } while (now <= previous)
            return now
        }

        /** Current wall-clock time in epoch milliseconds. */
        private fun currentMillis(): Long = System.currentTimeMillis()

        private companion object {
            /** Custom epoch (ms), the same start timestamp MyBatis-Plus `Sequence` uses. */
            const val TWEPOCH = 1288834974657L

            /** Number of bits reserved for the worker ID. */
            const val WORKER_ID_BITS = 5

            /** Number of bits reserved for the datacenter ID. */
            const val DATACENTER_ID_BITS = 5

            /** Number of bits reserved for the per-millisecond sequence. */
            const val SEQUENCE_BITS = 12

            /** Largest worker ID that fits in [WORKER_ID_BITS] (31). */
            const val MAX_WORKER_ID = (1L shl WORKER_ID_BITS) - 1

            /** Largest datacenter ID that fits in [DATACENTER_ID_BITS] (31). */
            const val MAX_DATACENTER_ID = (1L shl DATACENTER_ID_BITS) - 1

            /** Mask selecting the low [SEQUENCE_BITS] bits (4095). */
            const val SEQUENCE_MASK = (1L shl SEQUENCE_BITS) - 1

            /** Left shift applied to the worker ID. */
            const val WORKER_ID_SHIFT = SEQUENCE_BITS

            /** Left shift applied to the datacenter ID. */
            const val DATACENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS

            /** Left shift applied to the timestamp delta. */
            const val TIMESTAMP_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS

            /** Largest clock rollback (ms) that is waited out instead of rejected. */
            const val MAX_TOLERATED_DRIFT_MS = 5L

            /**
             * Derives the datacenter ID from the MAC address of the interface bound to
             * [inetAddress] (or to the local host address when [inetAddress] is `null`).
             *
             * @return 1 when no interface is bound to the address, 0 when the interface has no
             * hardware address or any exception occurs, otherwise a value in
             * `0..MAX_DATACENTER_ID` computed from the last two MAC bytes
             */
            fun resolveDatacenterId(inetAddress: InetAddress?): Long =
                try {
                    val host = inetAddress ?: InetAddress.getLocalHost()
                    val nic = NetworkInterface.getByInetAddress(host)
                    val mac = nic?.hardwareAddress
                    when {
                        nic == null -> 1L
                        mac == null -> 0L
                        else -> {
                            // Second-to-last byte forms the low 8 bits, last byte the next 8 bits.
                            val low = mac[mac.size - 2].toLong() and 0x000000FFL
                            val high = (mac[mac.size - 1].toLong() shl 8) and 0x0000FF00L
                            ((low or high) shr 6) % (MAX_DATACENTER_ID + 1)
                        }
                    }
                } catch (_: Exception) {
                    // Best effort: any lookup or indexing failure falls back to datacenter 0.
                    0L
                }

            /**
             * Derives the worker ID by hashing the decimal text of [datacenterId] followed by
             * the JVM process ID.
             *
             * The PID is the part of the runtime MXBean name before `'@'`; when that part is not
             * an integer only the datacenter ID is hashed. A PID below 10 is replaced by a random
             * value in `[10, 4194304)`.
             *
             * @return a value in `0..MAX_WORKER_ID`
             */
            fun resolveWorkerId(datacenterId: Long): Long {
                val rawPid =
                    ManagementFactory
                        .getRuntimeMXBean()
                        .name
                        .substringBefore('@')
                        .toIntOrNull()
                val pid =
                    when {
                        rawPid == null -> null
                        rawPid < 10 -> ThreadLocalRandom.current().nextInt(10, 4194304)
                        else -> rawPid
                    }
                val seed = if (pid == null) "$datacenterId" else "$datacenterId$pid"
                // Keep the low 16 bits of the hash (always non-negative) before reducing to range.
                return (seed.hashCode().toLong() and 0xffffL) % (MAX_WORKER_ID + 1)
            }
        }
    }

    companion object {
        /** The single sequence used by every generator instance. */
        private val sharedSequence = Sequence()
    }
}