华容道高性能计算引擎
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

116 lines
3.7 KiB

#!/usr/bin/env python3
import io
import numpy as np
from lxml import etree
from dataclasses import dataclass
from abc import ABC, abstractmethod
import matplotlib.colors as mcolors
@dataclass
class Config:
    """Settings shared by every GraphML rendering step.

    Only ``is_yed`` and ``colors`` must be supplied; every other field has a
    default value.
    """

    # When true, the yEd namespace is registered on the root <graphml> element.
    is_yed: bool
    # Palette of color strings; not read in this section (usage assumed elsewhere).
    colors: list[str]
    # Forwarded to lxml's ``pretty_print`` when the document is written out.
    pretty_xml: bool = True

    # ---- yEd node appearance ----
    # Namespace URI used for yEd's graphics elements.
    yed_xmlns: str = 'http://www.yworks.com/xml/graphml'
    # Value written to <y:Shape type="...">.
    yed_node_type: str = 'ellipse'
    # Node size, written to the <y:Geometry> element.
    yed_node_width: int = 50
    yed_node_height: int = 50
    # Label font size, written to <y:NodeLabel fontSize="...">.
    yed_node_font_size: int = 10
class INode(ABC):
    """Interface for a graph node that renders itself as a GraphML element.

    Also provides helpers that subclasses can use to emit yEd node graphics.
    """

    @staticmethod
    def _add_yed_key(graphml: etree.Element) -> None:
        """Declare the ``info`` key that carries yEd node graphics.

        :param graphml: root ``<graphml>`` element; a ``<key>`` child is appended.
        """
        etree.SubElement(graphml, 'key', attrib={
            'id': 'info',
            'for': 'node',
            'yfiles.type': 'nodegraphics',
        })

    @staticmethod
    def _yed_render(cfg: Config, color: str, text: str) -> etree.Element:
        """Build a ``<data key="info">`` element describing a yEd shape node.

        :param cfg: supplies the yEd namespace, shape type, node size and font size.
        :param color: fill color, written verbatim to ``<y:Fill color="...">``.
        :param text: label text placed inside the node.
        :return: a new, detached ``<data>`` element for the caller to attach.
        """
        yed_ns = f'{{{cfg.yed_xmlns}}}'
        info = etree.Element('data', attrib={'key': 'info'})
        shape = etree.SubElement(info, f'{yed_ns}ShapeNode')
        etree.SubElement(shape, f'{yed_ns}Fill', attrib={'color': color})
        etree.SubElement(shape, f'{yed_ns}Shape', attrib={'type': cfg.yed_node_type})
        # Each dimension now comes from its matching config field; previously the
        # two were swapped, so non-square nodes were drawn transposed.
        etree.SubElement(shape, f'{yed_ns}Geometry', attrib={
            'height': str(cfg.yed_node_height),
            'width': str(cfg.yed_node_width),
        })
        label = etree.SubElement(shape, f'{yed_ns}NodeLabel', attrib={
            'fontSize': str(cfg.yed_node_font_size),
            'modelName': 'internal',
        })
        label.text = text
        return info

    @staticmethod
    @abstractmethod
    def add_keys(graphml: etree.Element, cfg: Config) -> None:
        """Append the ``<key>`` declarations this node type needs to ``graphml``."""

    @abstractmethod
    def render(self, cfg: Config) -> etree.Element:
        """Return this node as an element to be appended to the ``<graph>``."""
class IEdge(ABC):
    """Interface for a graph edge that renders itself as a GraphML element."""

    @staticmethod
    @abstractmethod
    def add_keys(graphml: etree.Element, cfg: Config) -> None:
        """Append the ``<key>`` declarations this edge type needs to ``graphml``."""

    @abstractmethod
    def render(self, cfg: Config) -> etree.Element:
        """Return this edge as an element to be appended to the ``<graph>``."""
class GraphML:
    """Serialize a list of nodes and a list of edges into a GraphML document.

    The ``<key>`` declarations are taken only from the first node and the first
    edge, so this assumes all nodes (and all edges) share the same key set.
    """

    def __init__(self, tag: str, nodes: list[INode], edges: list[IEdge]):
        """Store the graph contents.

        :param tag: value of the ``id`` attribute on the ``<graph>`` element.
        :param nodes: node objects, rendered in order; must be non-empty.
        :param edges: edge objects, rendered in order; must be non-empty.
        :raises ValueError: if ``nodes`` or ``edges`` is empty.
        """
        # Explicit check rather than ``assert`` (which is stripped under -O):
        # the first node and the first edge are required to emit the keys.
        if not nodes or not edges:
            raise ValueError('GraphML requires at least one node and one edge')
        self.__tag = tag
        self.__nodes = nodes
        self.__edges = edges

    @staticmethod
    def __nsmap(cfg: Config) -> dict[str | None, str]:
        """Return the namespace map for the root element; adds ``y`` only for yEd."""
        return {
            None: 'http://graphml.graphdrawing.org/xmlns',
            'xsi': 'http://www.w3.org/2001/XMLSchema-instance',
            **({'y': cfg.yed_xmlns} if cfg.is_yed else {}),
        }

    @staticmethod
    def build_colors(num: int, bwr: list[str]) -> list[str]:
        """Sample ``num`` evenly spaced colors from a gradient through ``bwr``.

        :param num: number of colors to return.
        :param bwr: matplotlib color specs used as the gradient anchors, in order.
        :return: ``num`` colors as ``'#rrggbb'`` hex strings.
        """
        cmap = mcolors.LinearSegmentedColormap.from_list('custom_bwr', bwr)
        return [mcolors.to_hex(cmap(x)) for x in np.linspace(0, 1, num)]

    def __build_graphml(self, cfg: Config) -> etree.Element:
        """Build the complete ``<graphml>`` element tree."""
        graphml = etree.Element('graphml', nsmap=self.__nsmap(cfg))
        graphml.set(
            '{http://www.w3.org/2001/XMLSchema-instance}schemaLocation',
            'http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd'
        )
        # GraphML expects the <key> declarations before the <graph> element.
        self.__nodes[0].add_keys(graphml, cfg)
        self.__edges[0].add_keys(graphml, cfg)
        graph = etree.SubElement(graphml, 'graph', id=self.__tag, edgedefault='undirected')
        # Plain loops: elements are appended for their side effect, not collected.
        for node in self.__nodes:
            graph.append(node.render(cfg))
        for edge in self.__edges:
            graph.append(edge.render(cfg))
        return graphml

    def save_graphml(self, file: str, cfg: Config) -> None:
        """Render the graph and write it to ``file`` as UTF-8 XML with CRLF line endings.

        :param file: output path; an existing file is overwritten.
        :param cfg: rendering options (``pretty_xml`` controls indentation).
        """
        xml_tree = etree.ElementTree(self.__build_graphml(cfg))
        buffer = io.BytesIO()
        xml_tree.write(buffer, pretty_print=cfg.pretty_xml, xml_declaration=True, encoding='utf-8')
        # NOTE(review): reconstructed as an LF -> CRLF conversion; confirm this is
        # the intended line ending. Byte-level replace is safe on UTF-8 because
        # 0x0A never occurs inside a multi-byte sequence.
        content = buffer.getvalue().replace(b'\n', b'\r\n')
        # Binary mode for two reasons. First, the bytes match the declared UTF-8
        # encoding instead of the locale default. Second, Windows text mode
        # cannot turn the inserted '\r\n' into '\r\r\n'.
        with open(file, 'wb') as fp:
            fp.write(content)