11# -*- coding: utf-8 -*-
22from __future__ import annotations
33
4- from typing import Final , Iterable , Iterator , List
4+ from typing import TYPE_CHECKING , Final , Iterable , Iterator , Sequence
55
6- import src .api . config
7- import src .arch . z80 . backend . common
6+ from src .api import errmsg
7+ from src .api . config import OPTIONS
88from src .api .debug import __DEBUG__
9- from src .api .utils import sfirst
9+ from src .api .utils import flatten_list , sfirst
10+ from src .arch .z80 .backend .common import ASMS
1011from src .arch .z80 .peephole import evaluator
1112
12- from . import helpers
1313from .cpustate import CPUState
14- from .helpers import ALL_REGS
14+ from .helpers import (
15+ ALL_REGS ,
16+ dict_intersection ,
17+ idx_args ,
18+ is_16bit_oper_register ,
19+ new_tmp_val ,
20+ simplify_asm_args ,
21+ single_registers ,
22+ )
1523from .labelinfo import LabelInfo
24+ from .labels_dict import LabelsDict
1625from .memcell import MemCell
1726from .patterns import RE_ID_OR_NUMBER
1827
28+ if TYPE_CHECKING :
29+ from .main import Optimizer
30+
1931__all__ = "BasicBlock" , "DummyBasicBlock"
2032
2133
22- class BasicBlock (Iterable [MemCell ]):
34+ class BasicBlock (Sequence [MemCell ]):
2335 """A Class describing a basic block"""
2436
2537 __UNIQUE_ID = 0
@@ -29,24 +41,24 @@ def __new__(cls, *args, **kwargs):
2941 cls .__UNIQUE_ID += 1
3042 return super ().__new__ (cls )
3143
32- def __init__ (self , memory : Iterable [str ]) :
44+ def __init__ (self , memory : Iterable [str ], optimizer : Optimizer ) -> None :
3345 """Initializes the internal array of instructions."""
34- self .mem : List [MemCell ] = []
46+ self .optimizer = optimizer
47+ self .mem : list [MemCell ] = []
3548 self .next : BasicBlock | None = None # Which (if any) basic block follows this one in memory
3649 self .prev : BasicBlock | None = None # Which (if any) basic block precedes to this one in the code
3750 self .lock = False # True if this block is being accessed by other subroutine
3851 self .comes_from : set [BasicBlock ] = set () # A list/tuple containing possible jumps to this block
3952 self .goes_to : set [BasicBlock ] = set () # A list/tuple of possible block to jump from here
4053 self .modified = False # True if something has been changed during optimization
4154 self .called_by : set [BasicBlock ] = set ()
42- self .label_goes = []
4355 self .ignored = False # True if this block can be ignored (it's useless)
4456 self .id : Final [int ] = BasicBlock .__UNIQUE_ID
4557 self ._bytes = None
4658 self ._sizeof = None
4759 self ._max_tstates = None
4860 self .optimized = False # True if this block was already optimized
49- self .code = memory
61+ self .code = list ( memory )
5062 self .cpu = CPUState ()
5163
5264 def __hash__ (self ) -> int :
@@ -61,7 +73,7 @@ def __str__(self) -> str:
6173 def __repr__ (self ) -> str :
6274 return "<{}: id: {}, len: {}>" .format (self .__class__ .__name__ , self .id , len (self ))
6375
64- def __getitem__ (self , key ) -> MemCell | list [ MemCell ] :
76+ def __getitem__ (self , key ):
6577 return self .mem [key ]
6678
6779 def __setitem__ (self , key , value : MemCell ):
@@ -74,31 +86,36 @@ def __iter__(self) -> Iterator[MemCell]:
7486 for mem in self .mem :
7587 yield mem
7688
89+ @property
90+ def jump_labels (self ) -> set [str ]:
91+ return self .optimizer .JUMP_LABELS
92+
93+ @property
94+ def opt_labels (self ) -> LabelsDict :
95+ return self .optimizer .LABELS
96+
7797 def pop (self , i : int ) -> MemCell :
7898 self ._bytes = None
7999 self ._sizeof = None
80100 self ._max_tstates = None
81101 return self .mem .pop (i )
82102
83- def insert (self , i : int , value : str ):
84- memcell = MemCell (value , i )
85- self .mem .insert (i , memcell )
86- self ._bytes = None
87- self ._sizeof = None
88- self ._max_tstates = None
89-
90103 @property
91- def code (self ) -> List [str ]:
104+ def code (self ) -> list [str ]:
92105 return [x .code for x in self .mem ]
93106
94107 @code .setter
95108 def code (self , value : Iterable [str ]):
109+ self ._set_code (value )
110+
111+ def _set_code (self , value : Iterable [str ]) -> None :
96112 assert isinstance (value , Iterable )
97- assert all (isinstance (x , str ) for x in value )
113+ mems = tuple (value )
114+ assert all (isinstance (x , str ) for x in mems )
98115 if self .clean_asm_args :
99- self .mem = [MemCell (helpers . simplify_asm_args (asm ), i ) for i , asm in enumerate (value )]
116+ self .mem = [MemCell (simplify_asm_args (asm ), i ) for i , asm in enumerate (mems )]
100117 else :
101- self .mem = [MemCell (asm , i ) for i , asm in enumerate (value )]
118+ self .mem = [MemCell (asm , i ) for i , asm in enumerate (mems )]
102119
103120 self ._bytes = None
104121 self ._sizeof = None
@@ -136,15 +153,15 @@ def labels(self) -> tuple[str, ...]:
136153 memory"""
137154 return tuple (cell .inst for cell in self .mem if cell .is_label )
138155
139- def get_first_partition_idx (self , jump_labels : set [ str ] ) -> int | None :
156+ def get_first_partition_idx (self ) -> int | None :
140157 """Returns the first position where this block can be
141158 partitioned or None if there's no such point
142159 """
143160 for i , mem in enumerate (self ):
144- if i > 0 and mem .is_label and mem .inst in jump_labels :
161+ if i > 0 and mem .is_label and mem .inst in self . jump_labels :
145162 return i
146163
147- if (mem .is_ender or mem .code in src . arch . z80 . backend . common . ASMS ) and i < len (self ) - 1 :
164+ if (mem .is_ender or mem .code in ASMS ) and i < len (self ) - 1 :
148165 return i + 1
149166
150167 return None
@@ -209,7 +226,7 @@ def add_goes_to(self, basic_block: BasicBlock | None) -> None:
209226 self .goes_to .add (basic_block )
210227 basic_block .comes_from .add (self )
211228
212- def update_next_block (self , labels : dict [ str , LabelInfo ] ) -> None :
229+ def update_next_block (self ) -> None :
213230 """If the last instruction of this block is a JP, JR or RET (with no
214231 conditions) then goes_to set contains just a
215232 single block
@@ -231,14 +248,16 @@ def update_next_block(self, labels: dict[str, LabelInfo]) -> None:
231248 if last .inst == "ret" :
232249 return
233250
234- if last .opers [0 ] not in labels . keys () :
251+ if last .opers [0 ] not in self . opt_labels :
235252 __DEBUG__ ("INFO: %s is not defined. No optimization is done." % last .opers [0 ], 2 )
236- labels [last .opers [0 ]] = LabelInfo (last .opers [0 ], 0 , DummyBasicBlock (ALL_REGS , ALL_REGS ))
253+ self .opt_labels [last .opers [0 ]] = LabelInfo (
254+ last .opers [0 ], 0 , DummyBasicBlock (ALL_REGS , ALL_REGS , self .optimizer )
255+ )
237256
238- n_block = labels [last .opers [0 ]].basic_block
257+ n_block = self . opt_labels [last .opers [0 ]].basic_block
239258 self .add_goes_to (n_block )
240259
241- def is_used (self , regs : list [str ], i : int , top : int | None = None ) -> bool :
260+ def is_used (self , regs : Sequence [str ], i : int , top : int | None = None ) -> bool :
242261 """Checks whether any of the given regs are required from the given point
243262 to the end or not.
244263 """
@@ -249,8 +268,8 @@ def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:
249268 top = len (self ) if top is None else top + 1
250269
251270 if regs and regs [0 ][0 ] == "(" and regs [0 ][- 1 ] == ")" : # A memory address
252- r16 = helpers . single_registers (regs [0 ][1 :- 1 ]) if helpers . is_16bit_oper_register (regs [0 ][1 :- 1 ]) else []
253- ix = helpers . single_registers (helpers . idx_args (regs [0 ][1 :- 1 ])[0 ]) if helpers . idx_args (regs [0 ][1 :- 1 ]) else []
271+ r16 = single_registers (regs [0 ][1 :- 1 ]) if is_16bit_oper_register (regs [0 ][1 :- 1 ]) else []
272+ ix = single_registers (idx_args (regs [0 ][1 :- 1 ])[0 ]) if idx_args (regs [0 ][1 :- 1 ]) else [] # type: ignore
254273
255274 rr = set (r16 + ix )
256275 mem_vars = set ([] if rr else RE_ID_OR_NUMBER .findall (regs [0 ]))
@@ -274,7 +293,7 @@ def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:
274293
275294 return True
276295
277- regs = src . api . utils . flatten_list ([helpers . single_registers (x ) for x in regs ]) # make a copy
296+ regs = flatten_list ([single_registers (x ) for x in regs ]) # make a copy
278297 for ii in range (i , top ):
279298 if any (r in regs for r in self .mem [ii ].requires ):
280299 return True
@@ -385,11 +404,11 @@ def guesses_initial_state_from_origin_blocks(self) -> tuple[dict[str, str], dict
385404 return {}, {}
386405
387406 regs = sfirst (self .comes_from ).cpu .regs
388- mems = sfirst (self .comes_from ).cpu .mem
407+ mems = dict ( sfirst (self .comes_from ).cpu .mem )
389408
390409 for blk in list (self .comes_from )[1 :]:
391- regs = helpers . dict_intersection (regs , blk .cpu .regs )
392- mems = helpers . dict_intersection (mems , blk .cpu .mem )
410+ regs = dict_intersection (regs , blk .cpu .regs )
411+ mems = dict_intersection (mems , blk .cpu .mem )
393412
394413 return regs , mems
395414
@@ -417,12 +436,12 @@ def optimize(self, patterns_list):
417436 # monkey-patches some functions in this optimizer level (> 2)
418437 evaluator .UNARY ["GVAL" ] = lambda x : self .cpu .get (x )
419438 evaluator .UNARY ["FLAGVAL" ] = lambda x : {
420- "c" : str (self .cpu .C ) if self .cpu .C is not None else helpers . new_tmp_val (),
421- "z" : str (self .cpu .Z ) if self .cpu .Z is not None else helpers . new_tmp_val (),
422- }.get (x .lower (), helpers . new_tmp_val ())
439+ "c" : str (self .cpu .C ) if self .cpu .C is not None else new_tmp_val (),
440+ "z" : str (self .cpu .Z ) if self .cpu .Z is not None else new_tmp_val (),
441+ }.get (x .lower (), new_tmp_val ())
423442 evaluator .UNARY ["IS_REQUIRED" ] = lambda x : self .is_used ([x ], i + len (p .patt ))
424443
425- if src . api . config . OPTIONS .optimization_level > 3 :
444+ if OPTIONS .optimization_level > 3 :
426445 regs , mems = self .guesses_initial_state_from_origin_blocks ()
427446 else :
428447 regs , mems = {}, {}
@@ -447,8 +466,8 @@ def optimize(self, patterns_list):
447466 new_code = list (code )
448467 matched = new_code [i : i + len (p .patt )]
449468 new_code [i : i + len (p .patt )] = p .template .filter (match )
450- src . api . errmsg .info ("pattern applied [{}:{}]" .format ("%03i" % p .flag , p .fname ))
451- src . api . debug . __DEBUG__ ("matched: \n {}" .format ("\n " .join (matched )), level = 1 )
469+ errmsg .info ("pattern applied [{}:{}]" .format ("%03i" % p .flag , p .fname ))
470+ __DEBUG__ ("matched: \n {}" .format ("\n " .join (matched )), level = 1 )
452471 changed = new_code != code
453472 if changed :
454473 code = new_code
@@ -470,8 +489,8 @@ class DummyBasicBlock(BasicBlock):
470489 about what registers uses an destroys
471490 """
472491
473- def __init__ (self , destroys : Iterable [str ], requires : Iterable [str ]) :
474- BasicBlock .__init__ (self , [])
492+ def __init__ (self , destroys : Iterable [str ], requires : Iterable [str ], optimizer : Optimizer ) -> None :
493+ BasicBlock .__init__ (self , [], optimizer )
475494 self .__destroys = tuple (destroys )
476495 self .__requires = set (requires )
477496 self .code = ["ret" ]
0 commit comments