11from src .api .debug import __DEBUG__
22
33from .basicblock import BasicBlock , DummyBasicBlock
4- from .common import JUMP_LABELS , LABELS
54from .helpers import ALL_REGS
65from .labelinfo import LabelInfo
6+ from .labels_dict import LabelsDict
77
88__all__ = ("get_basic_blocks" ,)
99
1010
11- def split_block (block : BasicBlock , start_of_new_block : int ) -> tuple [BasicBlock , BasicBlock ]:
11+ def _split_block (block : BasicBlock , start_of_new_block : int , labels : LabelsDict ) -> tuple [BasicBlock , BasicBlock ]:
1212 assert 0 <= start_of_new_block < len (block ), f"Invalid split pos: { start_of_new_block } "
1313 new_block = BasicBlock ([])
1414 new_block .mem = block .mem [start_of_new_block :]
@@ -28,9 +28,9 @@ def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock,
2828 block .add_goes_to (new_block )
2929
3030 for i , mem in enumerate (new_block ):
31- if mem .is_label and mem .inst in LABELS :
32- LABELS [mem .inst ].basic_block = new_block
33- LABELS [mem .inst ].position = i
31+ if mem .is_label and mem .inst in labels :
32+ labels [mem .inst ].basic_block = new_block
33+ labels [mem .inst ].position = i
3434
3535 if block [- 1 ].is_ender :
3636 if not block [- 1 ].condition_flag : # If it's an unconditional jp, jr, call, ret
@@ -39,28 +39,32 @@ def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock,
3939 return block , new_block
4040
4141
42- def compute_calls (basic_blocks : list [BasicBlock ], jump_labels : set [str ]) -> None :
42+ def _compute_calls (
43+ basic_blocks : list [BasicBlock ],
44+ labels : LabelsDict ,
45+ jump_labels : set [str ],
46+ ) -> None :
4347 calling_blocks : dict [BasicBlock , BasicBlock ] = {}
4448
4549 # Compute which blocks use jump labels
4650 for bb in basic_blocks :
47- if bb [- 1 ].is_ender and (op := bb [- 1 ].branch_arg ) in LABELS :
48- LABELS [op ].used_by .add (bb )
51+ if bb [- 1 ].is_ender and (op := bb [- 1 ].branch_arg ) in labels :
52+ labels [op ].used_by .add (bb )
4953
5054 # For these blocks, add the referenced block in the goes_to
5155 for label in jump_labels :
52- for bb in LABELS [label ].used_by :
53- bb .add_goes_to (LABELS [label ].basic_block )
56+ for bb in labels [label ].used_by :
57+ bb .add_goes_to (labels [label ].basic_block )
5458
5559 # Annotate which blocks uses call (which should be the last instruction)
5660 for bb in basic_blocks :
5761 if bb [- 1 ].inst != "call" :
5862 continue
5963
6064 op = bb [- 1 ].branch_arg
61- if op in LABELS :
62- LABELS [op ].basic_block .called_by .add (bb )
63- calling_blocks [bb ] = LABELS [op ].basic_block
65+ if op in labels :
66+ labels [op ].basic_block .called_by .add (bb )
67+ calling_blocks [bb ] = labels [op ].basic_block
6468
6569 # For the annotated blocks, trace their goes_to, and their goes_to from
6670 # their goes_to and so on, until ret (unconditional or not) is found, and
@@ -91,7 +95,7 @@ def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None
9195 pending .add ((caller , bb .next ))
9296
9397
94- def get_jump_labels (main_basic_block : BasicBlock ) -> set [str ]:
98+ def _get_jump_labels (main_basic_block : BasicBlock , labels : LabelsDict ) -> set [str ]:
9599 """Given the main basic block (which contain the entire program), populate
96100 the global JUMP_LABEL set with LABELS used by CALL, JR, JP (i.e JP LABEL0)
97101 Also updates the global LABELS index with the pertinent information.
@@ -103,8 +107,8 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
103107
104108 for i , mem in enumerate (main_basic_block ):
105109 if mem .is_label :
106- LABELS .pop (mem .inst )
107- LABELS [mem .inst ] = LabelInfo (
110+ labels .pop (mem .inst )
111+ labels [mem .inst ] = LabelInfo (
108112 label = mem .inst , addr = i , basic_block = main_basic_block , position = i # Unknown yet
109113 )
110114 continue
@@ -118,28 +122,32 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
118122
119123 jump_labels .add (lbl )
120124
121- if lbl not in LABELS :
125+ if lbl not in labels :
122126 __DEBUG__ (f"INFO: { lbl } is not defined. No optimization is done." , 2 )
123- LABELS [lbl ] = LabelInfo (lbl , 0 , DummyBasicBlock (ALL_REGS , ALL_REGS ))
127+ labels [lbl ] = LabelInfo (lbl , 0 , DummyBasicBlock (ALL_REGS , ALL_REGS ))
124128
125129 return jump_labels
126130
127131
128- def get_basic_blocks (block : BasicBlock ) -> list [BasicBlock ]:
132+ def get_basic_blocks (
133+ block : BasicBlock ,
134+ labels : LabelsDict ,
135+ jump_labels : set [str ],
136+ ) -> list [BasicBlock ]:
129137 """If a block is not partitionable, returns a list with the same block.
130138 Otherwise, returns a list with the resulting blocks.
131139 """
132140 result : list [BasicBlock ] = [block ]
133- JUMP_LABELS .clear ()
134- JUMP_LABELS .update (get_jump_labels (block ))
141+ jump_labels .clear ()
142+ jump_labels .update (_get_jump_labels (block , labels ))
135143
136144 # Split basic blocks per label or branch instruction
137- split_pos = block .get_first_partition_idx ()
145+ split_pos = block .get_first_partition_idx (jump_labels )
138146 while split_pos is not None :
139- _ , block = split_block (block , split_pos )
147+ _ , block = _split_block (block , split_pos , labels )
140148 result .append (block )
141- split_pos = block .get_first_partition_idx ()
149+ split_pos = block .get_first_partition_idx (jump_labels )
142150
143- compute_calls (result , JUMP_LABELS )
151+ _compute_calls (result , labels , jump_labels )
144152
145153 return result
0 commit comments