Skip to content

Commit 2f9b1a3

Browse files
committed
add: lca weighted library created by Grok
1 parent 3b486f0 commit 2f9b1a3

File tree

3 files changed

+128
-2
lines changed

3 files changed

+128
-2
lines changed

code/main.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def comb(n: int, r: int, mod: int | None = None) -> int:
223223

224224

225225
# 多次元配列作成
226-
from typing import Any, List
226+
from typing import List, Any
227227

228228

229229
def create_array1(n: int, default: Any = 0) -> List[Any]:
@@ -717,6 +717,70 @@ def op(a,b):
717717
718718
vは配列の長さまたは、初期化する内容
719719
"""
720+
from collections import defaultdict
721+
import math
722+
723+
724+
class WeightedTreeLCA:
725+
def __init__(self, n):
726+
"""初期化: ノード数nの木を構築(0-indexed)"""
727+
self.n = n
728+
self.log = math.ceil(math.log2(n)) + 1
729+
self.adj = defaultdict(list) # 隣接リスト: {ノード: [(隣接ノード, 重み), ...]}
730+
self.depth = [0] * n # 各ノードの深さ
731+
self.dist = [0] * n # 根からの重み合計
732+
self.ancestor = [[-1] * self.log for _ in range(n)] # ダブリングテーブル
733+
734+
def add_edge(self, u, v, w):
735+
"""辺を追加: uとvを重みwで接続"""
736+
self.adj[u].append((v, w))
737+
self.adj[v].append((u, w))
738+
739+
def dfs(self, u, parent, d, w):
740+
"""DFSで深さ、距離、親を計算"""
741+
self.depth[u] = d
742+
self.dist[u] = w
743+
for v, weight in self.adj[u]:
744+
if v != parent:
745+
self.ancestor[v][0] = u
746+
self.dfs(v, u, d + 1, w + weight)
747+
748+
def build(self, root=0):
749+
"""ダブリングテーブルの構築"""
750+
# DFSで初期情報収集
751+
self.dfs(root, -1, 0, 0)
752+
# ダブリングテーブルを埋める
753+
for k in range(1, self.log):
754+
for u in range(self.n):
755+
if self.ancestor[u][k - 1] != -1:
756+
self.ancestor[u][k] = self.ancestor[self.ancestor[u][k - 1]][k - 1]
757+
758+
def lca(self, u, v):
759+
"""ノードuとvのLCAを求める"""
760+
# 深さを揃える
761+
if self.depth[u] < self.depth[v]:
762+
u, v = v, u
763+
for k in range(self.log - 1, -1, -1):
764+
if (
765+
self.ancestor[u][k] != -1
766+
and self.depth[self.ancestor[u][k]] >= self.depth[v]
767+
):
768+
u = self.ancestor[u][k]
769+
if u == v:
770+
return u
771+
# 同時にジャンプ
772+
for k in range(self.log - 1, -1, -1):
773+
if self.ancestor[u][k] != self.ancestor[v][k]:
774+
u = self.ancestor[u][k]
775+
v = self.ancestor[v][k]
776+
return self.ancestor[u][0]
777+
778+
def get_distance(self, u, v):
779+
"""ノードuとvの間の距離(重みの合計)を求める"""
780+
lca_node = self.lca(u, v)
781+
return self.dist[u] + self.dist[v] - 2 * self.dist[lca_node]
782+
783+
720784
# グラフ構造
721785
# 無向グラフ
722786
from collections import deque

libs/lca_weight.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from collections import defaultdict
2+
import math
3+
4+
5+
class WeightedTreeLCA:
6+
def __init__(self, n):
7+
"""初期化: ノード数nの木を構築(0-indexed)"""
8+
self.n = n
9+
self.log = math.ceil(math.log2(n)) + 1
10+
self.adj = defaultdict(list) # 隣接リスト: {ノード: [(隣接ノード, 重み), ...]}
11+
self.depth = [0] * n # 各ノードの深さ
12+
self.dist = [0] * n # 根からの重み合計
13+
self.ancestor = [[-1] * self.log for _ in range(n)] # ダブリングテーブル
14+
15+
def add_edge(self, u, v, w):
16+
"""辺を追加: uとvを重みwで接続"""
17+
self.adj[u].append((v, w))
18+
self.adj[v].append((u, w))
19+
20+
def dfs(self, u, parent, d, w):
21+
"""DFSで深さ、距離、親を計算"""
22+
self.depth[u] = d
23+
self.dist[u] = w
24+
for v, weight in self.adj[u]:
25+
if v != parent:
26+
self.ancestor[v][0] = u
27+
self.dfs(v, u, d + 1, w + weight)
28+
29+
def build(self, root=0):
30+
"""ダブリングテーブルの構築"""
31+
# DFSで初期情報収集
32+
self.dfs(root, -1, 0, 0)
33+
# ダブリングテーブルを埋める
34+
for k in range(1, self.log):
35+
for u in range(self.n):
36+
if self.ancestor[u][k - 1] != -1:
37+
self.ancestor[u][k] = self.ancestor[self.ancestor[u][k - 1]][k - 1]
38+
39+
def lca(self, u, v):
40+
"""ノードuとvのLCAを求める"""
41+
# 深さを揃える
42+
if self.depth[u] < self.depth[v]:
43+
u, v = v, u
44+
for k in range(self.log - 1, -1, -1):
45+
if (
46+
self.ancestor[u][k] != -1
47+
and self.depth[self.ancestor[u][k]] >= self.depth[v]
48+
):
49+
u = self.ancestor[u][k]
50+
if u == v:
51+
return u
52+
# 同時にジャンプ
53+
for k in range(self.log - 1, -1, -1):
54+
if self.ancestor[u][k] != self.ancestor[v][k]:
55+
u = self.ancestor[u][k]
56+
v = self.ancestor[v][k]
57+
return self.ancestor[u][0]
58+
59+
def get_distance(self, u, v):
60+
"""ノードuとvの間の距離(重みの合計)を求める"""
61+
lca_node = self.lca(u, v)
62+
return self.dist[u] + self.dist[v] - 2 * self.dist[lca_node]

merge_file.bash

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ echo "新しいmain.py作成完了"
88
# /bin/cat python/<filename>.py >> code/main.py
99

1010
lib_path
11-
for file_name in "import.py" "math_func.py" "array_create.py" "binary_search.py" "modint.py" "standard_input.py" "yn_func.py" "grid.py" "coordinates_to_id.py" "dijkstra.py" "get_path.py" "dp.py" "coordinate_compression.py" "memo.py" "graph.py" "unionfind.py" "potential_unionfind.py" "trie.py" "bit.py" "dis_lib.py" "alias.py" "utils.py"; do
11+
for file_name in "import.py" "math_func.py" "array_create.py" "binary_search.py" "modint.py" "standard_input.py" "yn_func.py" "grid.py" "coordinates_to_id.py" "dijkstra.py" "get_path.py" "dp.py" "coordinate_compression.py" "memo.py" "lca_weight.py" "graph.py" "unionfind.py" "potential_unionfind.py" "trie.py" "bit.py" "dis_lib.py" "alias.py" "utils.py"; do
1212
lib_path="libs/${file_name}"
1313
cat $lib_path >>code/main.py
1414
done

0 commit comments

Comments
 (0)