Skip to content

Commit c927e37

Browse files
committed
refact: clean up arch.z80.optimizer code
1 parent ef7df2d commit c927e37

8 files changed

Lines changed: 414 additions & 402 deletions

File tree

src/arch/z80/optimizer/__init__.py

Lines changed: 1 addition & 243 deletions
Original file line numberDiff line numberDiff line change
@@ -1,243 +1 @@
1-
# -*- coding: utf-8 -*-
2-
3-
from src.api.config import OPTIONS
4-
from src.api.debug import __DEBUG__
5-
from src.api.utils import flatten_list
6-
7-
from ..peephole import engine
8-
from . import basicblock
9-
from .basicblock import DummyBasicBlock
10-
from .common import JUMP_LABELS, LABELS, MEMORY
11-
from .helpers import ALL_REGS, END_PROGRAM_LABEL
12-
from .labelinfo import LabelInfo
13-
from .patterns import RE_LABEL, RE_PRAGMA
14-
15-
16-
def init():
17-
LABELS.clear()
18-
JUMP_LABELS.clear()
19-
20-
LABELS["*START*"] = LabelInfo("*START*", 0, DummyBasicBlock(ALL_REGS, ALL_REGS)) # Special START BLOCK
21-
LABELS["*__END_PROGRAM*"] = LabelInfo("__END_PROGRAM", 0, DummyBasicBlock(ALL_REGS, list("bc")))
22-
23-
# SOME Global modules initialization
24-
LABELS["__ADDF"] = LabelInfo("__ADDF", 0, DummyBasicBlock(ALL_REGS, list("aedbc")))
25-
LABELS["__SUBF"] = LabelInfo("__SUBF", 0, DummyBasicBlock(ALL_REGS, list("aedbc")))
26-
LABELS["__DIVF"] = LabelInfo("__DIVF", 0, DummyBasicBlock(ALL_REGS, list("aedbc")))
27-
LABELS["__MULF"] = LabelInfo("__MULF", 0, DummyBasicBlock(ALL_REGS, list("aedbc")))
28-
LABELS["__GEF"] = LabelInfo("__GEF", 0, DummyBasicBlock(ALL_REGS, list("aedbc")))
29-
LABELS["__GTF"] = LabelInfo("__GTF", 0, DummyBasicBlock(ALL_REGS, list("aedbc")))
30-
LABELS["__EQF"] = LabelInfo("__EQF", 0, DummyBasicBlock(ALL_REGS, list("aedbc")))
31-
LABELS["__STOREF"] = LabelInfo("__STOREF", 0, DummyBasicBlock(ALL_REGS, list("hlaedbc")))
32-
LABELS["PRINT_AT"] = LabelInfo("PRINT_AT", 0, DummyBasicBlock(ALL_REGS, list("a")))
33-
LABELS["INK"] = LabelInfo("INK", 0, DummyBasicBlock(ALL_REGS, list("a")))
34-
LABELS["INK_TMP"] = LabelInfo("INK_TMP", 0, DummyBasicBlock(ALL_REGS, list("a")))
35-
LABELS["PAPER"] = LabelInfo("PAPER", 0, DummyBasicBlock(ALL_REGS, list("a")))
36-
LABELS["PAPER_TMP"] = LabelInfo("PAPER_TMP", 0, DummyBasicBlock(ALL_REGS, list("a")))
37-
LABELS["RND"] = LabelInfo("RND", 0, DummyBasicBlock(ALL_REGS, []))
38-
LABELS["INKEY"] = LabelInfo("INKEY", 0, DummyBasicBlock(ALL_REGS, []))
39-
LABELS["PLOT"] = LabelInfo("PLOT", 0, DummyBasicBlock(ALL_REGS, ["a"]))
40-
LABELS["DRAW"] = LabelInfo("DRAW", 0, DummyBasicBlock(ALL_REGS, ["h", "l"]))
41-
LABELS["DRAW3"] = LabelInfo("DRAW3", 0, DummyBasicBlock(ALL_REGS, list("abcde")))
42-
LABELS["__ARRAY"] = LabelInfo("__ARRAY", 0, DummyBasicBlock(ALL_REGS, ["h", "l"]))
43-
LABELS["__MEMCPY"] = LabelInfo("__MEMCPY", 0, DummyBasicBlock(list("bcdefhl"), list("bcdehl")))
44-
LABELS["__PLOADF"] = LabelInfo("__PLOADF", 0, DummyBasicBlock(ALL_REGS, ALL_REGS)) # Special START BLOCK
45-
LABELS["__PSTOREF"] = LabelInfo("__PSTOREF", 0, DummyBasicBlock(ALL_REGS, ALL_REGS)) # Special START BLOCK
46-
47-
48-
def cleanupmem(initial_memory):
49-
"""Cleans up initial memory. Each label must be
50-
ALONE. Each instruction must have an space, etc...
51-
"""
52-
i = 0
53-
while i < len(initial_memory):
54-
tmp = initial_memory[i]
55-
match = RE_LABEL.match(tmp)
56-
if not match:
57-
i += 1
58-
continue
59-
60-
if tmp.rstrip() == match.group():
61-
i += 1
62-
continue
63-
64-
initial_memory[i] = tmp[match.end() :]
65-
initial_memory.insert(i, match.group())
66-
i += 1
67-
68-
69-
def cleanup_local_labels(block):
70-
"""Traverses memory, to make any local label a unique
71-
global one. At this point there's only a single code
72-
block
73-
"""
74-
global PROC_COUNTER
75-
76-
stack = [[]]
77-
hashes = [{}]
78-
stackprc = [PROC_COUNTER]
79-
used = [{}] # List of hashes of unresolved labels per scope
80-
81-
MEMORY = block.mem
82-
83-
for cell in MEMORY:
84-
if cell.inst.upper() == "PROC":
85-
stack += [[]]
86-
hashes += [{}]
87-
stackprc += [PROC_COUNTER]
88-
used += [{}]
89-
PROC_COUNTER += 1
90-
continue
91-
92-
if cell.inst.upper() == "ENDP":
93-
if len(stack) > 1: # There might be unbalanced stack due to syntax errors
94-
for label in used[-1].keys():
95-
if label in stack[-1]:
96-
newlabel = hashes[-1][label]
97-
for cell in used[-1][label]:
98-
cell.replace_label(label, newlabel)
99-
100-
stack.pop()
101-
hashes.pop()
102-
stackprc.pop()
103-
used.pop()
104-
continue
105-
106-
tmp = cell.asm.asm
107-
if tmp.upper()[:5] == "LOCAL":
108-
tmp = tmp[5:].split(",")
109-
for lbl in tmp:
110-
lbl = lbl.strip()
111-
if lbl in stack[-1]:
112-
continue
113-
stack[-1] += [lbl]
114-
hashes[-1][lbl] = "PROC%i." % stackprc[-1] + lbl
115-
if used[-1].get(lbl, None) is None:
116-
used[-1][lbl] = []
117-
118-
cell.asm = ";" + cell.asm # Remove it
119-
continue
120-
121-
if cell.is_label:
122-
label = cell.inst
123-
for i in range(len(stack) - 1, -1, -1):
124-
if label in stack[i]:
125-
label = hashes[i][label]
126-
cell.asm = label + ":"
127-
break
128-
continue
129-
130-
for label in cell.used_labels:
131-
labelUsed = False
132-
for i in range(len(stack) - 1, -1, -1):
133-
if label in stack[i]:
134-
newlabel = hashes[i][label]
135-
cell.replace_label(label, newlabel)
136-
labelUsed = True
137-
break
138-
139-
if not labelUsed:
140-
if used[-1].get(label, None) is None:
141-
used[-1][label] = []
142-
143-
used[-1][label] += [cell]
144-
145-
for i in range(len(MEMORY) - 1, -1, -1):
146-
if MEMORY[i].asm.asm[0] == ";":
147-
MEMORY.pop(i)
148-
149-
block.mem = MEMORY
150-
block.asm = [x.asm for x in MEMORY if len(x.asm)]
151-
152-
153-
def get_labels(basic_block):
154-
"""Traverses memory, to annotate all the labels in the global
155-
LABELS table
156-
"""
157-
for i, cell in enumerate(basic_block):
158-
if cell.is_label:
159-
label = cell.inst
160-
LABELS[label] = LabelInfo(label, cell.addr, basic_block, i) # Stores it globally
161-
162-
163-
def initialize_memory(basic_block):
164-
"""Initializes global memory array with the one in the main (initial) basic_block"""
165-
global MEMORY
166-
167-
init()
168-
MEMORY = basic_block.mem
169-
get_labels(basic_block)
170-
171-
172-
def optimize(initial_memory: list[str]) -> str:
173-
"""This will remove useless instructions"""
174-
global BLOCKS
175-
global PROC_COUNTER
176-
177-
del MEMORY[:]
178-
PROC_COUNTER = 0
179-
180-
cleanupmem(initial_memory)
181-
if OPTIONS.optimization_level <= 2: # if -O2 or lower, do nothing and return
182-
return "\n".join(x for x in initial_memory if not RE_PRAGMA.match(x))
183-
184-
basicblock.BasicBlock.clean_asm_args = OPTIONS.optimization_level > 3
185-
bb = basicblock.BasicBlock(initial_memory)
186-
cleanup_local_labels(bb)
187-
initialize_memory(bb)
188-
189-
BLOCKS = basic_blocks = basicblock.get_basic_blocks(bb) # 1st partition the Basic Blocks
190-
191-
for b in basic_blocks:
192-
__DEBUG__("--- BASIC BLOCK: {} ---".format(b.id), 1)
193-
__DEBUG__("Code:\n" + "\n".join(" {}".format(x) for x in b.code), 1)
194-
__DEBUG__("Requires: {}".format(b.requires()), 1)
195-
__DEBUG__("Destroys: {}".format(b.destroys()), 1)
196-
__DEBUG__("Label goes: {}".format(b.label_goes), 1)
197-
__DEBUG__("Comes from: {}".format([x.id for x in b.comes_from]), 1)
198-
__DEBUG__("Goes to: {}".format([x.id for x in b.goes_to]), 1)
199-
__DEBUG__("Next: {}".format(b.next.id if b.next is not None else None), 1)
200-
__DEBUG__("Size: {} Time: {}".format(b.sizeof, b.max_tstates), 1)
201-
__DEBUG__("--- END ---", 1)
202-
203-
LABELS["*START*"].basic_block.add_goes_to(basic_blocks[0])
204-
LABELS["*START*"].basic_block.next = basic_blocks[0]
205-
206-
basic_blocks[0].prev = LABELS["*START*"].basic_block
207-
if END_PROGRAM_LABEL in LABELS:
208-
LABELS[END_PROGRAM_LABEL].basic_block.add_goes_to(LABELS["*__END_PROGRAM*"].basic_block)
209-
210-
# In O3 we simplify the graph by reducing jumps over jumps
211-
for label in JUMP_LABELS:
212-
block = LABELS[label].basic_block
213-
if isinstance(block, DummyBasicBlock):
214-
continue
215-
216-
# The instruction that starts this block must be one of jr / jp
217-
first = block.get_next_exec_instruction()
218-
if first is None or first.inst not in ("jp", "jr"):
219-
continue
220-
221-
for blk in list(LABELS[label].used_by):
222-
if not first.condition_flag or blk[-1].condition_flag == first.condition_flag:
223-
new_label = first.opers[0]
224-
blk[-1].asm = blk[-1].code.replace(label, new_label)
225-
block.delete_comes_from(blk)
226-
LABELS[label].used_by.remove(blk)
227-
LABELS[new_label].used_by.add(blk)
228-
blk.add_goes_to(LABELS[new_label].basic_block)
229-
230-
for x in basic_blocks:
231-
x.compute_cpu_state()
232-
233-
filtered_patterns_list = [p for p in engine.PATTERNS if OPTIONS.optimization_level >= p.level >= 3]
234-
for x in basic_blocks:
235-
x.optimize(filtered_patterns_list)
236-
237-
for x in basic_blocks:
238-
if x.comes_from == [] and len([y for y in JUMP_LABELS if x is LABELS[y].basic_block]):
239-
x.ignored = True
240-
241-
return "\n".join(
242-
[y for y in flatten_list([x.code for x in basic_blocks if not x.ignored]) if not RE_PRAGMA.match(y)]
243-
)
1+
from .main import init, optimize

src/arch/z80/optimizer/basicblock.py

Lines changed: 2 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from src.arch.z80.optimizer.patterns import RE_ID_OR_NUMBER
1717
from src.arch.z80.peephole import evaluator
1818

19+
__all__ = "BasicBlock", "DummyBasicBlock"
20+
1921

2022
class BasicBlock(Iterable[MemCell]):
2123
"""A Class describing a basic block"""
@@ -480,140 +482,3 @@ def requires(self, i: int = 0, end_=None) -> set[str]:
480482

481483
def is_used(self, regs: Iterable[str], i: int, top: int | None = None) -> bool:
482484
return len([x for x in regs if x in self.__requires]) > 0
483-
484-
485-
def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock, BasicBlock]:
486-
assert 0 <= start_of_new_block < len(block), f"Invalid split pos: {start_of_new_block}"
487-
new_block = BasicBlock([])
488-
new_block.mem = block.mem[start_of_new_block:]
489-
block.mem = block.mem[:start_of_new_block]
490-
491-
new_block.next = block.next
492-
block.next = new_block
493-
new_block.prev = block
494-
495-
if new_block.next is not None:
496-
new_block.next.prev = new_block
497-
498-
for blk in list(block.goes_to):
499-
block.delete_goes_to(blk)
500-
new_block.add_goes_to(blk)
501-
502-
block.add_goes_to(new_block)
503-
504-
for i, mem in enumerate(new_block):
505-
if mem.is_label and mem.inst in LABELS:
506-
LABELS[mem.inst].basic_block = new_block
507-
LABELS[mem.inst].position = i
508-
509-
if block[-1].is_ender:
510-
if not block[-1].condition_flag: # If it's an unconditional jp, jr, call, ret
511-
block.delete_goes_to(block.next)
512-
513-
return block, new_block
514-
515-
516-
def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None:
517-
calling_blocks: dict[BasicBlock, BasicBlock] = {}
518-
519-
# Compute which blocks use jump labels
520-
for bb in basic_blocks:
521-
if bb[-1].is_ender and (op := bb[-1].branch_arg) in LABELS:
522-
LABELS[op].used_by.add(bb)
523-
524-
# For these blocks, add the referenced block in the goes_to
525-
for label in jump_labels:
526-
for bb in LABELS[label].used_by:
527-
bb.add_goes_to(LABELS[label].basic_block)
528-
529-
# Annotate which blocks uses call (which should be the last instruction)
530-
for bb in basic_blocks:
531-
if bb[-1].inst != "call":
532-
continue
533-
534-
op = bb[-1].branch_arg
535-
if op in LABELS:
536-
LABELS[op].basic_block.called_by.add(bb)
537-
calling_blocks[bb] = LABELS[op].basic_block
538-
539-
# For the annotated blocks, trace their goes_to, and their goes_to from
540-
# their goes_to and so on, until ret (unconditional or not) is found, and
541-
# save that block in a set for later
542-
visited: set[tuple[BasicBlock, BasicBlock]] = set()
543-
pending: set[tuple[BasicBlock, BasicBlock]] = set(calling_blocks.items())
544-
545-
while pending:
546-
caller, bb = pending.pop()
547-
if (caller, bb) in visited:
548-
continue
549-
550-
visited.add((caller, bb))
551-
552-
if not bb[-1].is_ender: # if it does not branch, search in the next block
553-
pending.add((caller, bb.next))
554-
continue
555-
556-
if bb[-1].inst in {"ret", "reti", "retn"}:
557-
if bb[-1].condition_flag:
558-
pending.add((caller, bb.next))
559-
560-
bb.add_goes_to(caller.next)
561-
continue
562-
563-
if bb[-1].inst in {"call", "rst"}: # A call from this block
564-
if bb[-1].condition_flag: # if it has conditions, it can return from the next block
565-
pending.add((caller, bb.next))
566-
567-
568-
def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
569-
"""Given the main basic block (which contain the entire program), populate
570-
the global JUMP_LABEL set with LABELS used by CALL, JR, JP (i.e JP LABEL0)
571-
Also updates the global LABELS index with the pertinent information.
572-
573-
Any BasicBlock containing a JUMP_LABEL in any position which is not the initial
574-
one (0 position) must be split at that point into two basic blocks.
575-
"""
576-
jump_labels: set[str] = set()
577-
578-
for i, mem in enumerate(main_basic_block):
579-
if mem.is_label:
580-
LABELS.pop(mem.inst)
581-
LABELS[mem.inst] = LabelInfo(
582-
label=mem.inst, addr=i, basic_block=main_basic_block, position=i # Unknown yet
583-
)
584-
continue
585-
586-
if not mem.is_ender:
587-
continue
588-
589-
lbl = mem.branch_arg
590-
if lbl is None:
591-
continue
592-
593-
jump_labels.add(lbl)
594-
595-
if lbl not in LABELS:
596-
__DEBUG__(f"INFO: {lbl} is not defined. No optimization is done.", 2)
597-
LABELS[lbl] = LabelInfo(lbl, 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
598-
599-
return jump_labels
600-
601-
602-
def get_basic_blocks(block: BasicBlock) -> list[BasicBlock]:
603-
"""If a block is not partitionable, returns a list with the same block.
604-
Otherwise, returns a list with the resulting blocks.
605-
"""
606-
result: list[BasicBlock] = [block]
607-
JUMP_LABELS.clear()
608-
JUMP_LABELS.update(get_jump_labels(block))
609-
610-
# Split basic blocks per label or branch instruction
611-
split_pos = block.get_first_partition_idx()
612-
while split_pos is not None:
613-
_, block = split_block(block, split_pos)
614-
result.append(block)
615-
split_pos = block.get_first_partition_idx()
616-
617-
compute_calls(result, JUMP_LABELS)
618-
619-
return result

0 commit comments

Comments
 (0)