Skip to content

Commit c4f8e98

Browse files
committed
feat: add SVG text format and use formatter in Sphinx directive
- Register "svg" format in Formatter (DOT → SVG decoded as str) - Refactor Sphinx directive to use formatter.render() for both SVG and Mermaid instead of calling DotGraphMachine/MermaidGraphMachine directly - Update _prepare_svg and _resolve_target to work with str (not bytes)
1 parent 7089a17 commit c4f8e98

3 files changed

Lines changed: 70 additions & 37 deletions

File tree

statemachine/contrib/diagram/formatter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ def _render_dot(machine_or_class: "MachineRef") -> str:
106106
return DotGraphMachine(machine_or_class).get_graph().to_string() # type: ignore[no-any-return]
107107

108108

109+
@formatter.register_format("svg")
110+
def _render_svg(machine_or_class: "MachineRef") -> str:
111+
from statemachine.contrib.diagram import DotGraphMachine
112+
113+
svg_bytes: bytes = DotGraphMachine(machine_or_class).get_graph().create_svg() # type: ignore[attr-defined]
114+
return svg_bytes.decode("utf-8")
115+
116+
109117
@formatter.register_format("mermaid")
110118
def _render_mermaid(machine_or_class: "MachineRef") -> str:
111119
from statemachine.contrib.diagram import MermaidGraphMachine

statemachine/contrib/diagram/sphinx_ext.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _parse_events(value: str) -> list[str]:
3939

4040

4141
# Match the outer <svg ...>...</svg> element, stripping XML prologue/DOCTYPE.
42-
_SVG_TAG_RE = re.compile(rb"(<svg\b.*</svg>)", re.DOTALL)
42+
_SVG_TAG_RE = re.compile(r"(<svg\b.*</svg>)", re.DOTALL)
4343

4444
# Match fixed width/height attributes (e.g. width="702pt" height="170pt").
4545
_SVG_WIDTH_RE = re.compile(r'\bwidth="([^"]*(?:pt|px))"')
@@ -79,7 +79,7 @@ def run(self) -> list[nodes.Node]:
7979
qualname = self.arguments[0]
8080

8181
try:
82-
from statemachine.contrib.diagram import DotGraphMachine
82+
from statemachine.contrib.diagram import formatter
8383
from statemachine.contrib.diagram import import_sm
8484

8585
sm_class = import_sm(qualname)
@@ -101,11 +101,10 @@ def run(self) -> list[nodes.Node]:
101101
output_format = self.options.get("format", "").strip().lower()
102102

103103
if output_format == "mermaid":
104-
return self._run_mermaid(machine, qualname)
104+
return self._run_mermaid(machine, formatter, qualname)
105105

106106
try:
107-
graph = DotGraphMachine(machine).get_graph()
108-
svg_bytes: bytes = graph.create_svg() # type: ignore[attr-defined]
107+
svg_text = formatter.render(machine, "svg")
109108
except Exception as exc:
110109
return [
111110
self.state_machine.reporter.warning(
@@ -114,12 +113,12 @@ def run(self) -> list[nodes.Node]:
114113
)
115114
]
116115

117-
svg_tag, intrinsic_width, intrinsic_height = self._prepare_svg(svg_bytes)
116+
svg_tag, intrinsic_width, intrinsic_height = self._prepare_svg(svg_text)
118117
svg_styles = self._build_svg_styles(intrinsic_width, intrinsic_height)
119118
svg_tag = svg_tag.replace("<svg ", f"<svg {svg_styles} ", 1)
120119

121120
alt_text = html_mod.escape(self.options.get("alt", qualname.rsplit(".", 1)[-1]))
122-
target = self._resolve_target(svg_bytes)
121+
target = self._resolve_target(svg_text)
123122

124123
img_html = f'<div role="img" aria-label="{alt_text}">{svg_tag}</div>'
125124
if target:
@@ -149,12 +148,10 @@ def run(self) -> list[nodes.Node]:
149148

150149
return [raw_node]
151150

152-
def _run_mermaid(self, machine: object, qualname: str) -> list[nodes.Node]:
151+
def _run_mermaid(self, machine: object, formatter: Any, qualname: str) -> list[nodes.Node]:
153152
"""Render a Mermaid diagram using sphinxcontrib-mermaid's node type."""
154153
try:
155-
from statemachine.contrib.diagram import MermaidGraphMachine
156-
157-
mermaid_src = MermaidGraphMachine(machine).get_mermaid()
154+
mermaid_src = formatter.render(machine, "mermaid")
158155
except Exception as exc:
159156
return [
160157
self.state_machine.reporter.warning(
@@ -190,10 +187,10 @@ def _run_mermaid(self, machine: object, qualname: str) -> list[nodes.Node]:
190187
self.add_name(node)
191188
return [node]
192189

193-
def _prepare_svg(self, svg_bytes: bytes) -> tuple[str, str, str]:
190+
def _prepare_svg(self, svg_text: str) -> tuple[str, str, str]:
194191
"""Extract the ``<svg>`` element and its intrinsic dimensions."""
195-
match = _SVG_TAG_RE.search(svg_bytes)
196-
svg_tag = match.group(1).decode("utf-8") if match else svg_bytes.decode("utf-8")
192+
match = _SVG_TAG_RE.search(svg_text)
193+
svg_tag = match.group(1) if match else svg_text
197194

198195
width_match = _SVG_WIDTH_RE.search(svg_tag)
199196
height_match = _SVG_HEIGHT_RE.search(svg_tag)
@@ -235,7 +232,7 @@ def _build_svg_styles(self, intrinsic_width: str, intrinsic_height: str) -> str:
235232

236233
return f'style="{"; ".join(parts)}"'
237234

238-
def _resolve_target(self, svg_bytes: bytes) -> str:
235+
def _resolve_target(self, svg_text: str) -> str:
239236
"""Return the href for the wrapper ``<a>`` tag, if any.
240237
241238
When ``:target:`` is given without a value (or as empty string), the
@@ -258,8 +255,8 @@ def _resolve_target(self, svg_bytes: bytes) -> str:
258255
outdir = os.path.join(self.env.app.outdir, "_images")
259256
os.makedirs(outdir, exist_ok=True)
260257
outpath = os.path.join(outdir, filename)
261-
with open(outpath, "wb") as f:
262-
f.write(svg_bytes)
258+
with open(outpath, "w", encoding="utf-8") as f:
259+
f.write(svg_text)
263260

264261
return f"/_images/{filename}"
265262

tests/test_contrib_diagram.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -889,47 +889,47 @@ def _make_directive(self, options=None):
889889
return directive
890890

891891
def test_strips_xml_prologue(self):
892-
svg_bytes = (
893-
b'<?xml version="1.0"?>\n<!DOCTYPE svg>\n'
894-
b'<svg width="100pt" height="50pt" viewBox="0 0 100 50">'
895-
b"<circle/></svg>"
892+
svg_text = (
893+
'<?xml version="1.0"?>\n<!DOCTYPE svg>\n'
894+
'<svg width="100pt" height="50pt" viewBox="0 0 100 50">'
895+
"<circle/></svg>"
896896
)
897897
directive = self._make_directive()
898-
svg_tag, _, _ = directive._prepare_svg(svg_bytes)
898+
svg_tag, _, _ = directive._prepare_svg(svg_text)
899899

900900
assert not svg_tag.startswith("<?xml")
901901
assert svg_tag.startswith("<svg")
902902
assert "</svg>" in svg_tag
903903

904904
def test_extracts_intrinsic_dimensions(self):
905-
svg_bytes = b'<svg width="702pt" height="170pt"><rect/></svg>'
905+
svg_text = '<svg width="702pt" height="170pt"><rect/></svg>'
906906
directive = self._make_directive()
907-
_, w, h = directive._prepare_svg(svg_bytes)
907+
_, w, h = directive._prepare_svg(svg_text)
908908

909909
assert w == "702pt"
910910
assert h == "170pt"
911911

912912
def test_removes_fixed_dimensions(self):
913-
svg_bytes = b'<svg width="702pt" height="170pt" viewBox="0 0 702 170"><rect/></svg>'
913+
svg_text = '<svg width="702pt" height="170pt" viewBox="0 0 702 170"><rect/></svg>'
914914
directive = self._make_directive()
915-
svg_tag, _, _ = directive._prepare_svg(svg_bytes)
915+
svg_tag, _, _ = directive._prepare_svg(svg_text)
916916

917917
assert 'width="702pt"' not in svg_tag
918918
assert 'height="170pt"' not in svg_tag
919919
assert "viewBox" in svg_tag
920920

921921
def test_handles_no_dimensions(self):
922-
svg_bytes = b'<svg viewBox="0 0 100 50"><rect/></svg>'
922+
svg_text = '<svg viewBox="0 0 100 50"><rect/></svg>'
923923
directive = self._make_directive()
924-
_, w, h = directive._prepare_svg(svg_bytes)
924+
_, w, h = directive._prepare_svg(svg_text)
925925

926926
assert w == ""
927927
assert h == ""
928928

929929
def test_handles_px_dimensions(self):
930-
svg_bytes = b'<svg width="200px" height="100px"><rect/></svg>'
930+
svg_text = '<svg width="200px" height="100px"><rect/></svg>'
931931
directive = self._make_directive()
932-
_, w, h = directive._prepare_svg(svg_bytes)
932+
_, w, h = directive._prepare_svg(svg_text)
933933

934934
assert w == "200px"
935935
assert h == "100px"
@@ -1030,15 +1030,15 @@ def _make_directive(self, options=None, tmp_path=None):
10301030

10311031
def test_no_target_option(self):
10321032
directive = self._make_directive()
1033-
assert directive._resolve_target(b"<svg/>") == ""
1033+
assert directive._resolve_target("<svg/>") == ""
10341034

10351035
def test_explicit_target_url(self):
10361036
directive = self._make_directive({"target": "https://example.com/diagram.svg"})
1037-
assert directive._resolve_target(b"<svg/>") == "https://example.com/diagram.svg"
1037+
assert directive._resolve_target("<svg/>") == "https://example.com/diagram.svg"
10381038

10391039
def test_empty_target_generates_file(self, tmp_path):
10401040
directive = self._make_directive({"target": ""}, tmp_path=tmp_path)
1041-
svg_data = b"<svg><rect/></svg>"
1041+
svg_data = "<svg><rect/></svg>"
10421042
result = directive._resolve_target(svg_data)
10431043

10441044
assert result.startswith("/_images/statemachine-")
@@ -1048,21 +1048,21 @@ def test_empty_target_generates_file(self, tmp_path):
10481048
images_dir = tmp_path / "_images"
10491049
svg_files = list(images_dir.glob("statemachine-*.svg"))
10501050
assert len(svg_files) == 1
1051-
assert svg_files[0].read_bytes() == svg_data
1051+
assert svg_files[0].read_text(encoding="utf-8") == svg_data
10521052

10531053
def test_empty_target_deterministic_filename(self, tmp_path):
10541054
"""Same qualname + events produces the same filename."""
10551055
directive1 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path)
10561056
directive2 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path)
1057-
result1 = directive1._resolve_target(b"<svg>1</svg>")
1058-
result2 = directive2._resolve_target(b"<svg>2</svg>")
1057+
result1 = directive1._resolve_target("<svg>1</svg>")
1058+
result2 = directive2._resolve_target("<svg>2</svg>")
10591059
assert result1 == result2
10601060

10611061
def test_different_events_different_filename(self, tmp_path):
10621062
"""Different events produce different filenames."""
10631063
d1 = self._make_directive({"target": "", "events": "a"}, tmp_path=tmp_path)
10641064
d2 = self._make_directive({"target": "", "events": "b"}, tmp_path=tmp_path)
1065-
assert d1._resolve_target(b"<svg/>") != d2._resolve_target(b"<svg/>")
1065+
assert d1._resolve_target("<svg/>") != d2._resolve_target("<svg/>")
10661066

10671067

10681068
class TestDirectiveRun:
@@ -1347,6 +1347,33 @@ def test_render_dot(self):
13471347
result = formatter.render(TrafficLightMachine, "dot")
13481348
assert result.startswith("digraph TrafficLightMachine {")
13491349

1350+
def test_render_svg(self):
1351+
from statemachine.contrib.diagram import formatter
1352+
1353+
from tests.examples.traffic_light_machine import TrafficLightMachine
1354+
1355+
result = formatter.render(TrafficLightMachine, "svg")
1356+
assert isinstance(result, str)
1357+
assert "<svg" in result
1358+
assert "green" in result
1359+
1360+
def test_render_svg_instance(self):
1361+
from statemachine.contrib.diagram import formatter
1362+
1363+
from tests.examples.traffic_light_machine import TrafficLightMachine
1364+
1365+
sm = TrafficLightMachine()
1366+
result = formatter.render(sm, "svg")
1367+
assert "<svg" in result
1368+
# Active state should be highlighted
1369+
assert "turquoise" in result
1370+
1371+
def test_format_svg(self):
1372+
from tests.examples.traffic_light_machine import TrafficLightMachine
1373+
1374+
result = f"{TrafficLightMachine:svg}"
1375+
assert "<svg" in result
1376+
13501377
def test_render_md(self):
13511378
from statemachine.contrib.diagram import formatter
13521379

@@ -1398,6 +1425,7 @@ def test_supported_formats(self):
13981425

13991426
fmts = formatter.supported_formats()
14001427
assert "dot" in fmts
1428+
assert "svg" in fmts
14011429
assert "mermaid" in fmts
14021430
assert "md" in fmts
14031431
assert "markdown" in fmts

0 commit comments

Comments
 (0)