Skip to content

Commit fbebabf

Browse files
authored
Merge pull request #674 from boriel/refact/convert_optimizer_into_class
refact: convert optimizer module into class
2 parents 0e8017c + 473b17c commit fbebabf

13 files changed

Lines changed: 453 additions & 331 deletions

File tree

src/api/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def get_absolute_filename_path(fname: str) -> str:
8484
return os.path.realpath(os.path.expanduser(fname))
8585

8686

87-
def get_relative_filename_path(fname: str, current_dir: str = None) -> str:
87+
def get_relative_filename_path(fname: str, current_dir: str | None = None) -> str:
8888
"""Given an absolute path, returns it relative to the current directory,
8989
that is, if the file is in the same folder or any of it children, only
9090
the path from the current folder onwards is returned. Otherwise, the

src/arch/interface/optimizer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class OptimizerInterface(ABC):
5+
"""Implements the Peephole Optimizer"""
6+
7+
@abstractmethod
8+
def init(self) -> None:
9+
pass
10+
11+
@abstractmethod
12+
def optimize(self, initial_memory: list[str]) -> str:
13+
"""This will remove useless instructions"""
14+
pass

src/arch/z80/optimizer/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .main import init, optimize
1+
from .main import Optimizer
22

3-
__all__ = "init", "optimize"
3+
__all__ = ("Optimizer",)

src/arch/z80/optimizer/basicblock.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from src.arch.z80.peephole import evaluator
1111

1212
from . import helpers
13-
from .common import JUMP_LABELS, LABELS
1413
from .cpustate import CPUState
1514
from .helpers import ALL_REGS
1615
from .labelinfo import LabelInfo
@@ -137,12 +136,12 @@ def labels(self) -> tuple[str, ...]:
137136
memory"""
138137
return tuple(cell.inst for cell in self.mem if cell.is_label)
139138

140-
def get_first_partition_idx(self) -> int | None:
139+
def get_first_partition_idx(self, jump_labels: set[str]) -> int | None:
141140
"""Returns the first position where this block can be
142141
partitioned or None if there's no such point
143142
"""
144143
for i, mem in enumerate(self):
145-
if i > 0 and mem.is_label and mem.inst in JUMP_LABELS:
144+
if i > 0 and mem.is_label and mem.inst in jump_labels:
146145
return i
147146

148147
if (mem.is_ender or mem.code in src.arch.z80.backend.common.ASMS) and i < len(self) - 1:
@@ -210,7 +209,7 @@ def add_goes_to(self, basic_block: BasicBlock | None) -> None:
210209
self.goes_to.add(basic_block)
211210
basic_block.comes_from.add(self)
212211

213-
def update_next_block(self):
212+
def update_next_block(self, labels: dict[str, LabelInfo]) -> None:
214213
"""If the last instruction of this block is a JP, JR or RET (with no
215214
conditions) then goes_to set contains just a
216215
single block
@@ -232,11 +231,11 @@ def update_next_block(self):
232231
if last.inst == "ret":
233232
return
234233

235-
if last.opers[0] not in LABELS.keys():
234+
if last.opers[0] not in labels.keys():
236235
__DEBUG__("INFO: %s is not defined. No optimization is done." % last.opers[0], 2)
237-
LABELS[last.opers[0]] = LabelInfo(last.opers[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
236+
labels[last.opers[0]] = LabelInfo(last.opers[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
238237

239-
n_block = LABELS[last.opers[0]].basic_block
238+
n_block = labels[last.opers[0]].basic_block
240239
self.add_goes_to(n_block)
241240

242241
def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:

src/arch/z80/optimizer/common.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

src/arch/z80/optimizer/flow_graph.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from src.api.debug import __DEBUG__
22

33
from .basicblock import BasicBlock, DummyBasicBlock
4-
from .common import JUMP_LABELS, LABELS
54
from .helpers import ALL_REGS
65
from .labelinfo import LabelInfo
6+
from .labels_dict import LabelsDict
77

88
__all__ = ("get_basic_blocks",)
99

1010

11-
def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock, BasicBlock]:
11+
def _split_block(block: BasicBlock, start_of_new_block: int, labels: LabelsDict) -> tuple[BasicBlock, BasicBlock]:
1212
assert 0 <= start_of_new_block < len(block), f"Invalid split pos: {start_of_new_block}"
1313
new_block = BasicBlock([])
1414
new_block.mem = block.mem[start_of_new_block:]
@@ -28,9 +28,9 @@ def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock,
2828
block.add_goes_to(new_block)
2929

3030
for i, mem in enumerate(new_block):
31-
if mem.is_label and mem.inst in LABELS:
32-
LABELS[mem.inst].basic_block = new_block
33-
LABELS[mem.inst].position = i
31+
if mem.is_label and mem.inst in labels:
32+
labels[mem.inst].basic_block = new_block
33+
labels[mem.inst].position = i
3434

3535
if block[-1].is_ender:
3636
if not block[-1].condition_flag: # If it's an unconditional jp, jr, call, ret
@@ -39,28 +39,32 @@ def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock,
3939
return block, new_block
4040

4141

42-
def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None:
42+
def _compute_calls(
43+
basic_blocks: list[BasicBlock],
44+
labels: LabelsDict,
45+
jump_labels: set[str],
46+
) -> None:
4347
calling_blocks: dict[BasicBlock, BasicBlock] = {}
4448

4549
# Compute which blocks use jump labels
4650
for bb in basic_blocks:
47-
if bb[-1].is_ender and (op := bb[-1].branch_arg) in LABELS:
48-
LABELS[op].used_by.add(bb)
51+
if bb[-1].is_ender and (op := bb[-1].branch_arg) in labels:
52+
labels[op].used_by.add(bb)
4953

5054
# For these blocks, add the referenced block in the goes_to
5155
for label in jump_labels:
52-
for bb in LABELS[label].used_by:
53-
bb.add_goes_to(LABELS[label].basic_block)
56+
for bb in labels[label].used_by:
57+
bb.add_goes_to(labels[label].basic_block)
5458

5559
# Annotate which blocks uses call (which should be the last instruction)
5660
for bb in basic_blocks:
5761
if bb[-1].inst != "call":
5862
continue
5963

6064
op = bb[-1].branch_arg
61-
if op in LABELS:
62-
LABELS[op].basic_block.called_by.add(bb)
63-
calling_blocks[bb] = LABELS[op].basic_block
65+
if op in labels:
66+
labels[op].basic_block.called_by.add(bb)
67+
calling_blocks[bb] = labels[op].basic_block
6468

6569
# For the annotated blocks, trace their goes_to, and their goes_to from
6670
# their goes_to and so on, until ret (unconditional or not) is found, and
@@ -91,7 +95,7 @@ def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None
9195
pending.add((caller, bb.next))
9296

9397

94-
def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
98+
def _get_jump_labels(main_basic_block: BasicBlock, labels: LabelsDict) -> set[str]:
9599
"""Given the main basic block (which contain the entire program), populate
96100
the global JUMP_LABEL set with LABELS used by CALL, JR, JP (i.e JP LABEL0)
97101
Also updates the global LABELS index with the pertinent information.
@@ -103,8 +107,8 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
103107

104108
for i, mem in enumerate(main_basic_block):
105109
if mem.is_label:
106-
LABELS.pop(mem.inst)
107-
LABELS[mem.inst] = LabelInfo(
110+
labels.pop(mem.inst)
111+
labels[mem.inst] = LabelInfo(
108112
label=mem.inst, addr=i, basic_block=main_basic_block, position=i # Unknown yet
109113
)
110114
continue
@@ -118,28 +122,32 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
118122

119123
jump_labels.add(lbl)
120124

121-
if lbl not in LABELS:
125+
if lbl not in labels:
122126
__DEBUG__(f"INFO: {lbl} is not defined. No optimization is done.", 2)
123-
LABELS[lbl] = LabelInfo(lbl, 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
127+
labels[lbl] = LabelInfo(lbl, 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
124128

125129
return jump_labels
126130

127131

128-
def get_basic_blocks(block: BasicBlock) -> list[BasicBlock]:
132+
def get_basic_blocks(
133+
block: BasicBlock,
134+
labels: LabelsDict,
135+
jump_labels: set[str],
136+
) -> list[BasicBlock]:
129137
"""If a block is not partitionable, returns a list with the same block.
130138
Otherwise, returns a list with the resulting blocks.
131139
"""
132140
result: list[BasicBlock] = [block]
133-
JUMP_LABELS.clear()
134-
JUMP_LABELS.update(get_jump_labels(block))
141+
jump_labels.clear()
142+
jump_labels.update(_get_jump_labels(block, labels))
135143

136144
# Split basic blocks per label or branch instruction
137-
split_pos = block.get_first_partition_idx()
145+
split_pos = block.get_first_partition_idx(jump_labels)
138146
while split_pos is not None:
139-
_, block = split_block(block, split_pos)
147+
_, block = _split_block(block, split_pos, labels)
140148
result.append(block)
141-
split_pos = block.get_first_partition_idx()
149+
split_pos = block.get_first_partition_idx(jump_labels)
142150

143-
compute_calls(result, JUMP_LABELS)
151+
_compute_calls(result, labels, jump_labels)
144152

145153
return result

src/arch/z80/optimizer/helpers.py

Lines changed: 97 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,117 @@
11
# -*- coding: utf-8 -*-
22

3-
from typing import Any, Iterable, TypeVar
3+
from typing import Any, Final, Iterable, TypeVar, cast
4+
5+
from . import patterns
6+
7+
__all__ = (
8+
"ALL_REGS",
9+
"END_PROGRAM_LABEL",
10+
"init",
11+
"new_tmp_val",
12+
"new_tmp_val16",
13+
"new_tmp_val16_from_label",
14+
"is_unknown",
15+
"is_unknown8",
16+
"is_unknown16",
17+
"get_orig_label_from_unknown16",
18+
"get_L_from_unknown_value",
19+
"get_H_from_unknown_value",
20+
"is_mem_access",
21+
"is_number",
22+
"is_label",
23+
"valnum",
24+
"simplify_arg",
25+
"simplify_asm_args",
26+
"is_register",
27+
"is_8bit_normal_register",
28+
"is_8bit_idx_register",
29+
"is_8bit_oper_register",
30+
"is_16bit_normal_register",
31+
"is_16bit_idx_register",
32+
"is_16bit_composed_register",
33+
"is_16bit_oper_register",
34+
"LO16",
35+
"HI16",
36+
"single_registers",
37+
"idx_args",
38+
"LO16_val",
39+
"HI16_val",
40+
"dict_intersection",
41+
)
442

5-
from . import common, patterns
643

744
T = TypeVar("T")
845
K = TypeVar("K")
946

1047

1148
# All 'single' registers (even f FLAG one). SP is not decomposable so it's 'single' already
12-
ALL_REGS = {"a", "b", "c", "d", "e", "f", "h", "l", "ixh", "ixl", "iyh", "iyl", "r", "i", "sp"}
49+
ALL_REGS: Final[frozenset[str]] = frozenset(
50+
[
51+
"a",
52+
"b",
53+
"c",
54+
"d",
55+
"e",
56+
"f",
57+
"h",
58+
"l",
59+
"ixh",
60+
"ixl",
61+
"iyh",
62+
"iyl",
63+
"r",
64+
"i",
65+
"sp",
66+
]
67+
)
1368

1469
# The set of all registers as they can appear in any instruction as operands
15-
REGS_OPER_SET = {
16-
"a",
17-
"b",
18-
"c",
19-
"d",
20-
"e",
21-
"h",
22-
"l",
23-
"bc",
24-
"de",
25-
"hl",
26-
"sp",
27-
"ix",
28-
"iy",
29-
"ixh",
30-
"ixl",
31-
"iyh",
32-
"iyl",
33-
"af",
34-
"af'",
35-
"i",
36-
"r",
37-
}
70+
REGS_OPER_SET: Final[frozenset[str]] = frozenset(
71+
[
72+
"a",
73+
"b",
74+
"c",
75+
"d",
76+
"e",
77+
"h",
78+
"l",
79+
"bc",
80+
"de",
81+
"hl",
82+
"sp",
83+
"ix",
84+
"iy",
85+
"ixh",
86+
"ixl",
87+
"iyh",
88+
"iyl",
89+
"af",
90+
"af'",
91+
"i",
92+
"r",
93+
]
94+
)
3895

3996
# Instructions that marks the end of a basic block (any branching instruction)
40-
BLOCK_ENDERS = {"jr", "jp", "call", "ret", "reti", "retn", "djnz", "rst"}
97+
BLOCK_ENDERS: Final[frozenset[str]] = frozenset(["jr", "jp", "call", "ret", "reti", "retn", "djnz", "rst"])
98+
UNKNOWN_PREFIX: Final[str] = "*UNKNOWN_"
99+
END_PROGRAM_LABEL: Final[str] = "__END_PROGRAM" # Label for end program
100+
HL_SEP: Final[str] = "|" # Hi/Low separator
101+
_RAND_COUNT: int = 0 # Counter for unknown values
41102

42-
UNKNOWN_PREFIX = "*UNKNOWN_"
43-
END_PROGRAM_LABEL = "__END_PROGRAM" # Label for end program
44-
HL_SEP = "|" # Hi/Low separator
103+
104+
def init() -> None:
105+
global _RAND_COUNT
106+
_RAND_COUNT = 0
45107

46108

47109
def new_tmp_val() -> str:
48110
"""Generates an 8-bit unknown value"""
49-
common.RAND_COUNT += 1
50-
return f"{UNKNOWN_PREFIX}{common.RAND_COUNT}"
111+
global _RAND_COUNT
112+
113+
_RAND_COUNT += 1
114+
return f"{UNKNOWN_PREFIX}{_RAND_COUNT}"
51115

52116

53117
def new_tmp_val16() -> str:
@@ -390,7 +454,7 @@ def LO16_val(x: int | str | None) -> str:
390454
if not is_unknown(x):
391455
return new_tmp_val()
392456

393-
return x.split(HL_SEP)[-1]
457+
return cast(str, x).split(HL_SEP)[-1]
394458

395459

396460
def HI16_val(x: int | str | None) -> str:

0 commit comments

Comments
 (0)