From 77c3922897f18dd88fa46c219d8b7b62e354fdfe Mon Sep 17 00:00:00 2001 From: Dnomd343 Date: Sat, 28 Jun 2025 19:16:34 +0800 Subject: [PATCH] perf: simplify graph combine algorithm --- misc/all-graph/04-combine_layout.py | 183 +++++++++++----------------- misc/all-graph/compare.py | 18 +-- 2 files changed, 84 insertions(+), 117 deletions(-) diff --git a/misc/all-graph/04-combine_layout.py b/misc/all-graph/04-combine_layout.py index af6a262..6274ea8 100755 --- a/misc/all-graph/04-combine_layout.py +++ b/misc/all-graph/04-combine_layout.py @@ -5,132 +5,97 @@ import igraph as ig import multiprocessing -def split_layer(graph: ig.Graph, step_a: int, step_b: int) -> tuple[list[set[int]], list[set[int]]]: - - def extend_from(node: ig.Vertex) -> tuple[set[ig.Vertex], set[ig.Vertex]]: - assert node['step'] == step_a - - scan_a2b = True - union_a, union_b = set(), set() - curr_set, next_set = set([node]), set() - - while curr_set: - for layout in curr_set: - for neigh in layout.neighbors(): - if scan_a2b and neigh['step'] == step_b and neigh not in union_b: - next_set.add(neigh) - elif not scan_a2b and neigh['step'] == step_a and neigh not in union_a: - next_set.add(neigh) - - union_a.update(curr_set if scan_a2b else next_set) - union_b.update(next_set if scan_a2b else curr_set) - scan_a2b = not scan_a2b - curr_set = next_set - next_set = set() - - return union_a, union_b - - assert step_a + 1 == step_b - - layer_a = set(x for x in graph.vs if x['step'] == step_a) - layer_b = set(x for x in graph.vs if x['step'] == step_b) - layer_num_a, layer_num_b = len(layer_a), len(layer_b) - assert layer_num_a > 0 and layer_num_b > 0 - - data_a: list[set[int]] = [] - data_b: list[set[int]] = [] - special_set = set() - while layer_a: - union_a, union_b = extend_from(layer_a.pop()) - if len(union_b) == 0: - assert len(union_a) == 1 - special_set.update(union_a) - continue - layer_a -= union_a - layer_b -= union_b - data_a.append(set(x.index for x in union_a)) - data_b.append(set(x.index for x in union_b)) - data_a.append(set(x.index for x in special_set)) - - assert len(layer_a) == 0 and len(layer_b) == 0 - assert sum(len(x) for x in data_a) == layer_num_a - assert sum(len(x) for x in data_b) == layer_num_b - return data_a, [x for x in data_b if x] - - -def build_multi_set(unions_a: list[set[int]], unions_b: list[set[int]]) -> list[set[int]]: - assert set(y for x in unions_a for y in x) == set(y for x in unions_b for y in x) - - release = [] - for curr in unions_a: - for other in unions_b: - mid = curr.intersection(other) - if mid: - release.append(mid) - curr -= mid - other -= mid - assert len(curr) == 0 +def split_adjacent_layers(graph: ig.Graph, step: int) -> tuple[list[set[int]], list[set[int]]]: + layouts = graph.vs.select(step_in=[step, step + 1]) + code_map = {x['code']: x.index for x in layouts} + to_index = lambda iter: {code_map[x['code']] for x in iter} + + layer_curr, layer_next = [], [] + g_focus = graph.subgraph(layouts) + isolated = g_focus.vs.select(_degree=0) + if isolated: + assert {x['step'] for x in isolated} == {step} + layer_curr = [to_index(isolated)] + g_focus.delete_vertices(isolated) + + for component in g_focus.connected_components(): + component = [g_focus.vs[x] for x in component] + layer_curr.append(to_index(x for x in component if x['step'] == step)) + layer_next.append(to_index(x for x in component if x['step'] == step + 1)) + return layer_curr, layer_next + + +def apply_layer_unions(unions_a: list[set[int]], unions_b: list[set[int]]) -> list[set[int]]: + layer_data = {x for u in unions_a for x in u} + assert layer_data == {x for u in unions_b for x in u} + + unions = [] + for curr_union in unions_a: + for other_union in unions_b: + if union := curr_union.intersection(other_union): + unions.append(union) + curr_union -= union + other_union -= union + assert len(curr_union) == 0 assert set(len(x) for x in unions_a) == {0} assert set(len(x) for x in unions_b) == {0} - return release + assert layer_data == {x for u in unions for x in u} + return unions -def do_split(g: ig.Graph) -> ig.Graph: - max_step = max(x['step'] for x in g.vs) - - layer_data = [[] for _ in range(max_step + 1)] - layer_data[0].append([set(x.index for x in g.vs if x['step'] == 0)]) +def build_all_unions(graph: ig.Graph) -> list[set[int]]: + max_step = max(graph.vs['step']) + layer_unions = [[{x.index for x in graph.vs if x['step'] == 0}]] for step in range(0, max_step): - data_a, data_b = split_layer(g, step, step + 1) - layer_data[step].append(data_a) - layer_data[step + 1].append(data_b) - layer_data[max_step].append([set(x.index for x in g.vs if x['step'] == max_step)]) + layer_unions.extend(list(split_adjacent_layers(graph, step))) + layer_unions.append([{x.index for x in graph.vs if x['step'] == max_step}]) + assert len(layer_unions) == (max_step + 1) * 2 - assert len(layer_data) == max_step + 1 - assert set(len(x) for x in layer_data) == {2} + all_unions = [] + for idx in range(0, len(layer_unions), 2): + all_unions.extend(apply_layer_unions(*layer_unions[idx:idx + 2])) + for unions in all_unions: + assert len(unions) > 0 + assert len(set(graph.vs[x]['step'] for x in unions)) == 1 + return sorted(all_unions, key=lambda u: min(graph.vs[x]['code'] for x in u)) - unions = {} - for step in range(0, max_step + 1): - layer_unions = build_multi_set(layer_data[step][0], layer_data[step][1]) - for union in layer_unions: - assert len(set(g.vs[x]['step'] for x in union)) == 1 - codes = [g.vs[x]['code'] for x in union] - unions[min(codes)] = union - assert sorted(y for x in unions.values() for y in x) == list(range(g.vcount())) +def combine_graph(graph: ig.Graph) -> ig.Graph: + unions = build_all_unions(graph) + union_idx = sorted((x, idx) for idx, u in enumerate(unions) for x in u) - combine_info = [-1 for _ in range(g.vcount())] - for index, key in enumerate(sorted(unions)): - for x in unions[key]: - combine_info[x] = index + combine_idx = [x for _, x in union_idx] + assert len(combine_idx) == graph.vcount() + assert set(combine_idx) == set(range(len(unions))) - assert len(combine_info) == g.vcount() - assert set(combine_info) == set(range(len(unions))) + id_len = len(str(len(unions) - 1)) + graph.vs['id'] = [f'U{x:0{id_len}}' for x in combine_idx] - g.contract_vertices(combine_info, combine_attrs={'step': 'first', 'code': list}) - assert set(x.is_loop() for x in g.es) == {False} - g.simplify(multiple=True) - return g + graph.contract_vertices(combine_idx, combine_attrs={'id': 'first', 'step': 'first', 'code': list}) + assert [int(x.removeprefix('U')) for x in graph.vs['id']] == list(range(len(unions))) + assert not any(x.is_loop() for x in graph.es) + graph.simplify(multiple=True) + return graph def do_combine(input: str, output: str) -> None: print(f'Start combining: {input}') - g = do_split(ig.Graph.Read_Pickle(input)) - g.write_pickle(output) - g_mod = g.copy() - for x in g_mod.vs: - x['code'] = '+'.join(x['code']) - g_mod = do_split(g_mod) + g_raw = (graph := combine_graph(ig.Graph.Read_Pickle(input))).copy() + graph.vs['codes'] = graph.vs['code'] + del graph.vs['code'] + graph.write_pickle(output) # save combined graph + + g_raw.vs['code'] = g_raw.vs['id'] # modify as origin format + g_mod = combine_graph(g_raw.copy()) - assert g.vcount() == g_mod.vcount() - assert g.ecount() == g_mod.ecount() - for index in range(g.vcount()): - assert len(g_mod.vs[index]['code']) == 1 - assert g.vs[index]['step'] == g_mod.vs[index]['step'] - assert '+'.join(g.vs[index]['code']) == g_mod.vs[index]['code'][0] - assert g.isomorphic(g_mod) + assert g_raw.vcount() == g_mod.vcount() + assert g_raw.ecount() == g_mod.ecount() + assert all(x['code'] == [x['id']] for x in g_mod.vs) + assert g_raw.vs['step'] == g_mod.vs['step'] + assert g_raw.vs['code'] == g_mod.vs['id'] + assert g_raw.isomorphic(g_mod) def combine_all(ig_dir: str, output_dir: str) -> None: @@ -141,6 +106,6 @@ def combine_all(ig_dir: str, output_dir: str) -> None: pool.join() -if __name__ == "__main__": +if __name__ == '__main__': os.makedirs('output-combine', exist_ok=True) combine_all('output-ig', 'output-combine') diff --git a/misc/all-graph/compare.py b/misc/all-graph/compare.py index 33f16be..d065e59 100644 --- a/misc/all-graph/compare.py +++ b/misc/all-graph/compare.py @@ -7,14 +7,18 @@ import igraph as ig def load_legacy(file: str) -> ig.Graph: g = ig.Graph.Read_Pickle(file) for node in g.vs: - node['code'] = sorted(node['code']) + assert sorted(node['code']) == node['code'] + node['codes'] = node['code'] + del g.vs['code'] return g def load_modern(file: str) -> ig.Graph: g = ig.Graph.Read_Pickle(file) + assert [int(x.removeprefix('U')) for x in g.vs['id']] == list(range(g.vcount())) for node in g.vs: - assert sorted(node['code']) == sorted(node['code']) + assert sorted(node['codes']) == node['codes'] + del g.vs['id'] return g @@ -26,15 +30,13 @@ def compare(g1: ig.Graph, g2: ig.Graph) -> None: assert {len(x.attributes()) for x in g1.es} == {0} assert {len(x.attributes()) for x in g2.es} == {0} - data_a = {min(x['code']): x.attributes() for x in g1.vs} - data_b = {min(x['code']): x.attributes() for x in g2.vs} + data_a = {min(x['codes']): x.attributes() for x in g1.vs} + data_b = {min(x['codes']): x.attributes() for x in g2.vs} assert data_a == data_b if __name__ == '__main__': - for name in sorted(os.listdir('output-combine')): - if '_' not in name: - continue - g1 = load_legacy(f'combined/{name.split('_')[1]}') + for name in sorted(os.listdir('output-combine-raw')): + g1 = load_legacy(f'output-combine-raw/{name}') g2 = load_modern(f'output-combine/{name}') compare(g1, g2)