@@ -223,7 +223,7 @@ def comb(n: int, r: int, mod: int | None = None) -> int:
223223
224224
225225# 多次元配列作成
226- from typing import Any , List
226+ from typing import List , Any
227227
228228
229229def create_array1 (n : int , default : Any = 0 ) -> List [Any ]:
@@ -717,6 +717,70 @@ def op(a,b):
717717
718718vは配列の長さまたは、初期化する内容
719719"""
720+ from collections import defaultdict
721+ import math
722+
723+
724+ class WeightedTreeLCA :
725+ def __init__ (self , n ):
726+ """初期化: ノード数nの木を構築(0-indexed)"""
727+ self .n = n
728+ self .log = math .ceil (math .log2 (n )) + 1
729+ self .adj = defaultdict (list ) # 隣接リスト: {ノード: [(隣接ノード, 重み), ...]}
730+ self .depth = [0 ] * n # 各ノードの深さ
731+ self .dist = [0 ] * n # 根からの重み合計
732+ self .ancestor = [[- 1 ] * self .log for _ in range (n )] # ダブリングテーブル
733+
734+ def add_edge (self , u , v , w ):
735+ """辺を追加: uとvを重みwで接続"""
736+ self .adj [u ].append ((v , w ))
737+ self .adj [v ].append ((u , w ))
738+
739+ def dfs (self , u , parent , d , w ):
740+ """DFSで深さ、距離、親を計算"""
741+ self .depth [u ] = d
742+ self .dist [u ] = w
743+ for v , weight in self .adj [u ]:
744+ if v != parent :
745+ self .ancestor [v ][0 ] = u
746+ self .dfs (v , u , d + 1 , w + weight )
747+
748+ def build (self , root = 0 ):
749+ """ダブリングテーブルの構築"""
750+ # DFSで初期情報収集
751+ self .dfs (root , - 1 , 0 , 0 )
752+ # ダブリングテーブルを埋める
753+ for k in range (1 , self .log ):
754+ for u in range (self .n ):
755+ if self .ancestor [u ][k - 1 ] != - 1 :
756+ self .ancestor [u ][k ] = self .ancestor [self .ancestor [u ][k - 1 ]][k - 1 ]
757+
758+ def lca (self , u , v ):
759+ """ノードuとvのLCAを求める"""
760+ # 深さを揃える
761+ if self .depth [u ] < self .depth [v ]:
762+ u , v = v , u
763+ for k in range (self .log - 1 , - 1 , - 1 ):
764+ if (
765+ self .ancestor [u ][k ] != - 1
766+ and self .depth [self .ancestor [u ][k ]] >= self .depth [v ]
767+ ):
768+ u = self .ancestor [u ][k ]
769+ if u == v :
770+ return u
771+ # 同時にジャンプ
772+ for k in range (self .log - 1 , - 1 , - 1 ):
773+ if self .ancestor [u ][k ] != self .ancestor [v ][k ]:
774+ u = self .ancestor [u ][k ]
775+ v = self .ancestor [v ][k ]
776+ return self .ancestor [u ][0 ]
777+
778+ def get_distance (self , u , v ):
779+ """ノードuとvの間の距離(重みの合計)を求める"""
780+ lca_node = self .lca (u , v )
781+ return self .dist [u ] + self .dist [v ] - 2 * self .dist [lca_node ]
782+
783+
720784# グラフ構造
721785# 無向グラフ
722786from collections import deque
0 commit comments