2025-07-24 14:40:53 +0000 UTC

Minimum Score After Removals on a Tree

Code

class Solution:
    """Find the best way to split a tree into three parts by removing two edges.

    Each part is valued by the XOR of the numbers on its nodes, and the score
    of a split is the largest part value minus the smallest part value.
    """

    def calc(self, part1: int, part2: int, part3: int) -> int:
        """Return the difference between the largest and smallest of three values."""
        return max(part1, part2, part3) - min(part1, part2, part3)

    def minimumScore(self, nums: list[int], edges: list[list[int]]) -> int:
        """Return the minimum score over all ways of removing two distinct edges.

        Args:
            nums: ``nums[i]`` is the value stored on node ``i``.
            edges: Undirected edges ``[a, b]``; assumed to form a tree over
                nodes ``0 .. len(nums) - 1``.

        Returns:
            The smallest ``max - min`` of the three component XOR values.
            When there are fewer than three nodes no pair of edges can be
            removed and ``float("inf")`` is returned as a sentinel.

        The tree is walked iteratively (explicit stack) rather than
        recursively, so path-shaped trees do not hit the interpreter's
        recursion limit. Running time is O(n^2) in the number of nodes.
        """
        length = len(nums)
        res = float("inf")
        if length < 3:
            # Fewer than two edges exist, so there is nothing to evaluate.
            return res

        neighbors: list[list[int]] = [[] for _ in range(length)]
        for node_1, node_2 in edges:
            neighbors[node_1].append(node_2)
            neighbors[node_2].append(node_1)

        total = 0
        for num in nums:
            total ^= num

        # Iterative preorder from node 0. Because a node's children are pushed
        # on top of the stack, every subtree occupies one contiguous slice of
        # ``order`` that starts at the subtree's root.
        parent = [-1] * length
        order: list[int] = []
        stack = [0]
        while stack:
            node = stack.pop()
            order.append(node)
            for child in neighbors[node]:
                if child != parent[node]:
                    parent[child] = node
                    stack.append(child)

        position = [0] * length
        for index, node in enumerate(order):
            position[node] = index

        # Indexed by preorder position: XOR of the subtree rooted there and
        # the exclusive end of that subtree's slice in ``order``.
        count = len(order)
        sub_xor = [nums[node] for node in order]
        end = [index + 1 for index in range(count)]
        # Descendants always sit at higher positions, so a reverse sweep sees
        # each subtree fully accumulated before folding it into its parent.
        for index in range(count - 1, 0, -1):
            up = position[parent[order[index]]]
            sub_xor[up] ^= sub_xor[index]
            end[up] = max(end[up], end[index])

        # Every non-root position identifies the edge to its parent. With
        # first < second the second node can never be an ancestor of the
        # first, so the two cut subtrees are either nested or disjoint.
        for first in range(1, count):
            xor_first = sub_xor[first]
            end_first = end[first]
            rest = total ^ xor_first  # everything outside the first subtree
            for second in range(first + 1, count):
                xor_second = sub_xor[second]
                if second < end_first:
                    # Second subtree lies inside the first one.
                    score = self.calc(xor_second, xor_first ^ xor_second, rest)
                else:
                    # The two subtrees do not overlap.
                    score = self.calc(xor_first, xor_second, rest ^ xor_second)
                if score < res:
                    res = score
        return res