TreeUtils.kt

package io.github.lishangbu.avalon.web.util

import org.slf4j.LoggerFactory
import java.lang.reflect.Field
import java.lang.reflect.Modifier
import java.util.function.BiConsumer
import java.util.function.Function
import java.util.function.Predicate

/**
 * Tree-structure utilities.
 *
 * Provides common tree operations: building a tree from a flat list, finding nodes,
 * computing the path to a node, flattening, filtering, traversal and depth calculation.
 *
 * Every recursive operation descends through a caller-supplied `childrenGetter`; a `null`
 * or empty children list is treated as a leaf. Recursion depth equals the tree depth.
 *
 * @author lishangbu
 * @since 2025/08/25
 */
object TreeUtils {
    /** Logger used to report reflective copy failures in [tryCreateCopy]. */
    private val log = LoggerFactory.getLogger(TreeUtils::class.java)

    /**
     * Builds a forest from a flat list of nodes.
     *
     * A node is a root when its parent id is `null`, or when no node in [list] has that id.
     * Every other node is attached to the node whose id equals its parent id.
     *
     * For each node that receives at least one child, [childrenSetter] is called exactly once.
     * It receives a fresh mutable list holding that node's children, in the order they appear
     * in [list]. Nodes without children in [list] are not passed to [childrenSetter].
     *
     * When several nodes share an id, children attach to the last of them (as with
     * [associateBy]).
     *
     * @param list flat list of nodes; `null` or empty yields an empty result
     * @param idGetter extracts a node's id
     * @param parentIdGetter extracts a node's parent id; `null` marks a root
     * @param childrenSetter stores the computed children list on a parent node
     * @return the root nodes, in input order
     */
    @JvmStatic
    fun <T : Any, I> buildTree(
        list: List<T>?,
        idGetter: Function<T, I>,
        parentIdGetter: Function<T, I>,
        childrenSetter: BiConsumer<T, List<T>>,
    ): List<T> {
        if (list.isNullOrEmpty()) {
            return emptyList()
        }

        val nodeMap = list.associateBy { idGetter.apply(it) }
        val roots = mutableListOf<T>()
        // Group children per parent id locally instead of appending to whatever list the node
        // currently exposes. A reflective `getChildren()` lookup silently kept only the last
        // child when the node had no such method, and it threw when that method returned an
        // immutable list.
        val childrenByParentId = LinkedHashMap<I, MutableList<T>>()

        for (node in list) {
            val parentId = parentIdGetter.apply(node)
            if (parentId != null && nodeMap.containsKey(parentId)) {
                childrenByParentId.getOrPut(parentId) { mutableListOf() }.add(node)
            } else {
                // Null parent id, or a parent that is not in the input: treat as a root.
                roots.add(node)
            }
        }

        for ((parentId, children) in childrenByParentId) {
            childrenSetter.accept(nodeMap.getValue(parentId), children)
        }
        return roots
    }

    /**
     * Finds the first node matching [predicate], searching depth-first in pre-order
     * (a node is tested before its children).
     *
     * @param tree root nodes; `null` or empty yields `null`
     * @param predicate match condition
     * @param childrenGetter returns a node's children (`null` means leaf)
     * @return the first matching node, or `null` when none matches
     */
    @JvmStatic
    fun <T : Any> findNode(
        tree: List<T>?,
        predicate: Predicate<T>,
        childrenGetter: Function<T, List<T>?>,
    ): T? {
        if (tree.isNullOrEmpty()) {
            return null
        }

        for (node in tree) {
            if (predicate.test(node)) {
                return node
            }
            val children = childrenGetter.apply(node)
            if (!children.isNullOrEmpty()) {
                val found = findNode(children, predicate, childrenGetter)
                if (found != null) {
                    return found
                }
            }
        }
        return null
    }

    /**
     * Collects every node matching [predicate], in depth-first pre-order.
     *
     * @param tree root nodes; `null` or empty yields an empty list
     * @param predicate match condition
     * @param childrenGetter returns a node's children (`null` means leaf)
     * @return all matching nodes; empty when none match
     */
    @JvmStatic
    fun <T : Any> findNodes(
        tree: List<T>?,
        predicate: Predicate<T>,
        childrenGetter: Function<T, List<T>?>,
    ): List<T> {
        val result = mutableListOf<T>()
        findNodesInternal(tree, predicate, childrenGetter, result)
        return result
    }

    /** Recursive helper for [findNodes]: appends matching nodes to [result]. */
    private fun <T : Any> findNodesInternal(
        nodes: List<T>?,
        predicate: Predicate<T>,
        childrenGetter: Function<T, List<T>?>,
        result: MutableList<T>,
    ) {
        if (nodes.isNullOrEmpty()) {
            return
        }

        for (node in nodes) {
            if (predicate.test(node)) {
                result += node
            }
            val children = childrenGetter.apply(node)
            if (!children.isNullOrEmpty()) {
                findNodesInternal(children, predicate, childrenGetter, result)
            }
        }
    }

    /**
     * Returns the path from a root to the first node matching [targetPredicate]
     * (depth-first, pre-order).
     *
     * @param tree root nodes
     * @param targetPredicate identifies the target node
     * @param childrenGetter returns a node's children (`null` means leaf)
     * @return nodes from the root down to and including the target; empty when no node matches
     */
    @JvmStatic
    fun <T : Any> getNodePath(
        tree: List<T>?,
        targetPredicate: Predicate<T>,
        childrenGetter: Function<T, List<T>?>,
    ): List<T> {
        val path = mutableListOf<T>()
        findPath(tree, targetPredicate, childrenGetter, path)
        return path
    }

    /**
     * Recursive helper for [getNodePath].
     *
     * Pushes each visited node onto [path] and pops it again (backtracking) when the target is
     * not in its subtree.
     *
     * @return `true` when the target was found; [path] then holds the full path
     */
    private fun <T : Any> findPath(
        nodes: List<T>?,
        targetPredicate: Predicate<T>,
        childrenGetter: Function<T, List<T>?>,
        path: MutableList<T>,
    ): Boolean {
        if (nodes.isNullOrEmpty()) {
            return false
        }

        for (node in nodes) {
            path += node
            if (targetPredicate.test(node)) {
                return true
            }
            val children = childrenGetter.apply(node)
            if (
                !children.isNullOrEmpty() &&
                findPath(children, targetPredicate, childrenGetter, path)
            ) {
                return true
            }
            // Target not under this node: backtrack. removeAt avoids the JDK 21 removeLast clash.
            path.removeAt(path.lastIndex)
        }
        return false
    }

    /**
     * Flattens the tree into a list in depth-first pre-order (each parent precedes its children).
     *
     * @param tree root nodes; `null` or empty yields an empty list
     * @param childrenGetter returns a node's children (`null` means leaf)
     * @return all nodes of the tree
     */
    @JvmStatic
    fun <T : Any> flattenTree(
        tree: List<T>?,
        childrenGetter: Function<T, List<T>?>,
    ): List<T> {
        val result = mutableListOf<T>()
        flattenTreeInternal(tree, childrenGetter, result)
        return result
    }

    /**
     * Returns a filtered copy of the tree.
     *
     * A node is kept when it matches [predicate] itself, or when at least one of its descendants
     * is kept. Kept nodes are shallow copies produced by [tryCreateCopy]. Each copy's children
     * are set through [childrenSetter]:
     * - the filtered children when the original children were non-empty;
     * - an empty list when they were empty;
     * - `null` when they were `null`.
     *
     * NOTE: when a node cannot be copied reflectively (for example, its class has no usable
     * no-arg constructor), the original node is used in place of the copy. [childrenSetter] is
     * then applied to that original node, so the input tree is modified in place for it.
     *
     * @param tree root nodes; `null` or empty yields an empty list
     * @param predicate keep condition, evaluated on the original node
     * @param childrenGetter returns a node's children (`null` means leaf)
     * @param childrenSetter replaces a node's children list
     * @return the filtered root nodes
     */
    @JvmStatic
    fun <T : Any> filterTree(
        tree: List<T>?,
        predicate: Predicate<T>,
        childrenGetter: Function<T, List<T>?>,
        childrenSetter: BiConsumer<T, List<T>?>,
    ): List<T> {
        if (tree.isNullOrEmpty()) {
            return emptyList()
        }

        val result = mutableListOf<T>()
        for (node in tree) {
            val copyNode = tryCreateCopy(node)
            val children = childrenGetter.apply(node)
            if (!children.isNullOrEmpty()) {
                val filteredChildren =
                    filterTree(children, predicate, childrenGetter, childrenSetter)
                childrenSetter.accept(copyNode, filteredChildren)
            } else {
                // Preserve the null-vs-empty distinction of the original children.
                childrenSetter.accept(
                    copyNode,
                    if (children == null) null else emptyList(),
                )
            }

            val copyNodeChildren = childrenGetter.apply(copyNode)
            if (predicate.test(node) || !copyNodeChildren.isNullOrEmpty()) {
                result += copyNode
            }
        }
        return result
    }

    /**
     * Creates a shallow copy of [node] via reflection.
     *
     * The copy is instantiated with the class's no-arg constructor. Every non-static field
     * declared on the class or its superclasses (excluding `Object`) is then copied, except
     * fields named `children`.
     *
     * @return the copy, or [node] itself if any step throws (the failure is logged)
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T : Any> tryCreateCopy(node: T): T =
        try {
            val copyNode = node::class.java.getDeclaredConstructor().newInstance() as T
            for (field in getAllFields(node::class.java)) {
                if (Modifier.isStatic(field.modifiers)) {
                    continue
                }
                field.isAccessible = true
                if (field.name != "children") {
                    field.set(copyNode, field.get(node))
                }
            }
            copyNode
        } catch (ex: Exception) {
            // Trailing Throwable argument makes SLF4J log the stack trace as well.
            log.error("Failed to create or copy node instance: {}", ex.message, ex)
            node
        }

    /**
     * Visits every node in depth-first pre-order, passing its depth to [action].
     *
     * @param tree root nodes; `null` or empty means nothing is visited
     * @param action receives each node and its level (roots are level 0)
     * @param childrenGetter returns a node's children (`null` means leaf)
     */
    @JvmStatic
    fun <T : Any> traverseTree(
        tree: List<T>?,
        action: BiConsumer<T, Int>,
        childrenGetter: Function<T, List<T>?>,
    ) {
        traverseTreeInternal(tree, action, childrenGetter, 0)
    }

    /**
     * Computes the maximum depth of the tree.
     *
     * @param tree root nodes
     * @param childrenGetter returns a node's children (`null` means leaf)
     * @return 0 for a `null` or empty tree; 1 when no root has children; otherwise the number of
     *   nodes on the longest root-to-leaf path
     */
    @JvmStatic
    fun <T : Any> getMaxDepth(
        tree: List<T>?,
        childrenGetter: Function<T, List<T>?>,
    ): Int {
        if (tree.isNullOrEmpty()) {
            return 0
        }

        var maxDepth = 0
        for (node in tree) {
            val children = childrenGetter.apply(node)
            if (!children.isNullOrEmpty()) {
                maxDepth = maxOf(maxDepth, getMaxDepth(children, childrenGetter))
            }
        }
        return maxDepth + 1
    }

    /**
     * Finds the first node (depth-first, pre-order) whose id equals [id].
     *
     * @param tree root nodes
     * @param id id to look for, compared with `==`
     * @param idGetter extracts a node's id
     * @param childrenGetter returns a node's children (`null` means leaf)
     * @return the matching node, or `null`
     */
    @JvmStatic
    fun <T : Any, I> findNodeById(
        tree: List<T>?,
        id: I,
        idGetter: Function<T, I>,
        childrenGetter: Function<T, List<T>?>,
    ): T? =
        findNode(
            tree,
            Predicate { node -> idGetter.apply(node) == id },
            childrenGetter,
        )

    /**
     * Returns all fields declared on [clazz] and its superclasses, stopping before `Object`.
     * Fields of the class itself come first.
     */
    private fun getAllFields(clazz: Class<*>): List<Field> {
        val fields = mutableListOf<Field>()
        fields += clazz.declaredFields
        val superClass = clazz.superclass
        if (superClass != null && superClass != Any::class.java) {
            fields += getAllFields(superClass)
        }
        return fields
    }

    /** Recursive helper for [traverseTree]: visits [nodes] at [level] and their subtrees. */
    private fun <T : Any> traverseTreeInternal(
        nodes: List<T>?,
        action: BiConsumer<T, Int>,
        childrenGetter: Function<T, List<T>?>,
        level: Int,
    ) {
        if (nodes.isNullOrEmpty()) {
            return
        }
        for (node in nodes) {
            action.accept(node, level)
            val children = childrenGetter.apply(node)
            if (!children.isNullOrEmpty()) {
                traverseTreeInternal(children, action, childrenGetter, level + 1)
            }
        }
    }

    /** Recursive helper for [flattenTree]: appends [nodes] and their subtrees to [result]. */
    private fun <T : Any> flattenTreeInternal(
        nodes: List<T>?,
        childrenGetter: Function<T, List<T>?>,
        result: MutableList<T>,
    ) {
        if (nodes.isNullOrEmpty()) {
            return
        }
        for (node in nodes) {
            result += node
            val children = childrenGetter.apply(node)
            if (!children.isNullOrEmpty()) {
                flattenTreeInternal(children, childrenGetter, result)
            }
        }
    }
}