Skip to content

Commit 25a27e9

Browse files
committed
Correctly manage calls in the GDA
This update: * Refactors code and simplyfies it * Correctly test the GDA generation works for calls * Removes block.original_next and uses always block.next * Correctly initializes the LABEL and JUMP_LABEL information * Tesst added Basically, the Identity set for the goes_to set must be removed from the previous block and added to the next one, one by one. The IdentitySet is converted to list (copy) because in place modification when in the loop produces unexpected behavior.
1 parent 9280471 commit 25a27e9

3 files changed

Lines changed: 169 additions & 102 deletions

File tree

arch/zx48k/optimizer/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
def init():
1616
global LABELS
17+
global JUMP_LABELS
18+
1719
LABELS.clear()
20+
JUMP_LABELS.clear()
1821

1922
LABELS['*START*'] = LabelInfo('*START*', 0, DummyBasicBlock(ALL_REGS, ALL_REGS)) # Special START BLOCK
2023
LABELS['*__END_PROGRAM*'] = LabelInfo('__END_PROGRAM', 0, DummyBasicBlock(ALL_REGS, list('bc')))
@@ -164,6 +167,7 @@ def initialize_memory(basic_block):
164167
"""
165168
global MEMORY
166169

170+
init()
167171
MEMORY = basic_block.mem
168172
get_labels(basic_block)
169173

@@ -174,8 +178,6 @@ def optimize(initial_memory):
174178
global BLOCKS
175179
global PROC_COUNTER
176180

177-
LABELS.clear()
178-
JUMP_LABELS.clear()
179181
del MEMORY[:]
180182
PROC_COUNTER = 0
181183

arch/zx48k/optimizer/basicblock.py

Lines changed: 85 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .labelinfo import LabelInfo
1010
from .helpers import ALL_REGS, END_PROGRAM_LABEL
1111
from .common import LABELS, JUMP_LABELS
12-
from .errors import OptimizerInvalidBasicBlockError
12+
from .errors import OptimizerInvalidBasicBlockError, OptimizerError
1313
from .cpustate import CPUState
1414
from ..peephole import engine
1515
from ..peephole import evaluator
@@ -27,15 +27,14 @@ def __init__(self, memory):
2727
""" Initializes the internal array of instructions.
2828
"""
2929
self.mem = None
30-
self.next = None # Which (if any) basic block follows this one in the code
30+
self.next = None # Which (if any) basic block follows this one in memory
3131
self.prev = None # Which (if any) basic block precedes to this one in the code
32-
self.original_next = None # Which block originally followed this one in the code, if any
3332
self.lock = False # True if this block is being accessed by other subroutine
3433
self.comes_from = IdentitySet() # A list/tuple containing possible jumps to this block
3534
self.goes_to = IdentitySet() # A list/tuple of possible block to jump from here
3635
self.modified = False # True if something has been changed during optimization
3736
self.calls = IdentitySet()
38-
self.label_goes = IdentitySet()
37+
self.label_goes = []
3938
self.ignored = False # True if this block can be ignored (it's useless)
4039
self.id = BasicBlock.__UNIQUE_ID
4140
self._bytes = None
@@ -156,7 +155,7 @@ def update_labels(self):
156155
for l in self.labels:
157156
LABELS[l].basic_block = self
158157

159-
def delete_from(self, basic_block):
158+
def delete_comes_from(self, basic_block):
160159
""" Removes the basic_block ptr from the list for "comes_from"
161160
if it exists. It also sets self.prev to None if it is basic_block.
162161
"""
@@ -168,19 +167,14 @@ def delete_from(self, basic_block):
168167

169168
self.lock = True
170169

171-
if self.prev is basic_block:
172-
if self.prev.next is self:
173-
self.prev.next = None
174-
self.prev = None
175-
176170
for i in range(len(self.comes_from)):
177171
if self.comes_from[i] is basic_block:
178172
self.comes_from.pop(i)
179173
break
180174

181175
self.lock = False
182176

183-
def delete_goes(self, basic_block):
177+
def delete_goes_to(self, basic_block):
184178
""" Removes the basic_block ptr from the list for "goes_to"
185179
if it exists. It also sets self.next to None if it is basic_block.
186180
"""
@@ -192,15 +186,10 @@ def delete_goes(self, basic_block):
192186

193187
self.lock = True
194188

195-
if self.next is basic_block:
196-
if self.next.prev is self:
197-
self.next.prev = None
198-
self.next = None
199-
200189
for i in range(len(self.goes_to)):
201190
if self.goes_to[i] is basic_block:
202191
self.goes_to.pop(i)
203-
basic_block.delete_from(self)
192+
basic_block.delete_comes_from(self)
204193
break
205194

206195
self.lock = False
@@ -244,35 +233,30 @@ def add_goes_to(self, basic_block):
244233

245234
def update_next_block(self):
246235
""" If the last instruction of this block is a JP, JR or RET (with no
247-
conditions) then the next and goes_to sets just contains a
236+
conditions) then goes_to set contains just a
248237
single block
249238
"""
250239
last = self.mem[-1]
251-
if last.inst not in ('ret', 'jp', 'jr') or last.condition_flag is not None:
240+
if last.inst not in {'djnz', 'jp', 'jr', 'call', 'ret', 'reti', 'retn', 'rst'}:
252241
return
253242

254-
if last.inst == 'ret':
243+
if last.inst in {'reti', 'retn'}:
255244
if self.next is not None:
256-
self.next.delete_from(self)
257-
self.delete_goes(self.next)
245+
self.next.delete_comes_from(self)
246+
return
247+
248+
if self.next is not None and last.condition_flag is None: # jp NNN, call NNN, rst, jr NNNN, ret
249+
self.next.delete_comes_from(self)
250+
251+
if last.inst == 'ret':
258252
return
259253

260254
if last.opers[0] not in LABELS.keys():
261255
__DEBUG__("INFO: %s is not defined. No optimization is done." % last.opers[0], 2)
262256
LABELS[last.opers[0]] = LabelInfo(last.opers[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
263257

264258
n_block = LABELS[last.opers[0]].basic_block
265-
if self.next is n_block:
266-
return
267-
268-
if self.next.prev == self:
269-
# The next basic block is not this one since it ends with a jump
270-
self.next.delete_from(self)
271-
self.delete_goes(self.next)
272-
273-
self.next = n_block
274-
self.next.add_comes_from(self)
275-
self.add_goes_to(self.next)
259+
self.add_goes_to(n_block)
276260

277261
def update_used_by_list(self):
278262
""" Every label has a set containing
@@ -288,81 +272,82 @@ def update_used_by_list(self):
288272
def clean_up_goes_to(self):
289273
for x in self.goes_to:
290274
if x is not self.next:
291-
self.delete_goes(x)
275+
self.delete_goes_to(x)
292276

293277
def clean_up_comes_from(self):
294278
for x in self.comes_from:
295279
if x is not self.prev:
296-
self.delete_from(x)
280+
self.delete_comes_from(x)
297281

298282
def update_goes_and_comes(self):
299283
""" Once the block is a Basic one, check the last instruction and updates
300284
goes_to and comes_from set of the receivers.
301285
Note: jp, jr and ret are already done in update_next_block()
302286
"""
303-
# Remove any block from the comes_from and goes_to list except the PREVIOUS and NEXT
304287
if not len(self):
305288
return
306289

307-
if self.mem[-1].inst == 'ret':
308-
return # subroutine returns are updated from CALLer blocks
309-
310-
self.update_used_by_list()
311-
312-
if not self.mem[-1].is_ender:
313-
return
314-
315290
last = self.mem[-1]
316291
inst = last.inst
317292
oper = last.opers
318293
cond = last.condition_flag
319294

320-
if oper and oper[0] not in LABELS.keys():
321-
__DEBUG__("INFO: %s is not defined. No optimization is done." % oper[0], 1)
322-
LABELS[oper[0]] = LabelInfo(oper[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
295+
if not last.is_ender:
296+
return
323297

324-
if inst == 'djnz' or inst in ('jp', 'jr') and cond is not None:
325-
if oper[0] in LABELS:
326-
self.add_goes_to(LABELS[oper[0]].basic_block)
298+
if cond is None:
299+
self.delete_goes_to(self.next)
327300

328-
elif inst in ('jp', 'jr') and cond is None:
329-
if oper[0] in LABELS:
330-
self.delete_goes(self.next)
331-
self.next = LABELS[oper[0]].basic_block
332-
self.add_goes_to(self.next)
301+
if last.inst in {'ret', 'reti', 'retn'} and cond is None:
302+
return # subroutine returns are updated from CALLer blocks
333303

334-
elif inst == 'call':
335-
if cond is None:
336-
self.delete_goes(self.next)
337-
self.next = LABELS[oper[0]].basic_block
304+
if oper and oper[0]:
305+
if oper[0] not in LABELS.keys():
306+
__DEBUG__("INFO: %s is not defined. No optimization is done." % oper[0], 1)
307+
LABELS[oper[0]] = LabelInfo(oper[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
338308

339-
LABELS[oper[0]].basic_block.add_comes_from(self)
340-
stack = [LABELS[oper[0]].basic_block]
341-
bbset = IdentitySet()
309+
self.add_goes_to(LABELS[oper[0]].basic_block)
342310

343-
while stack:
344-
bb = stack.pop(0)
311+
if inst in {'djnz', 'jp', 'jr'}:
312+
return
345313

346-
while bb is not None:
347-
if bb in bbset:
348-
break
314+
assert inst in ('call', 'rst')
315+
316+
if self.next is None:
317+
raise OptimizerError("Unexpected NULL next block")
318+
319+
final_blk = self.next # The block all the final returns should go to
320+
stack = [LABELS[oper[0]].basic_block]
321+
bbset = IdentitySet()
349322

350-
bbset.add(bb)
351-
if len(bb):
352-
bb1 = bb[-1]
353-
if bb1.inst == 'ret':
354-
bb.add_goes_to(self.next)
355-
if bb1.condition_flag is None: # 'ret'
356-
break
323+
while stack:
324+
bb = stack.pop(0)
325+
while True:
326+
if bb is None:
327+
DummyBasicBlock(ALL_REGS, ALL_REGS)
357328

358-
if bb1.inst in ('jp', 'jr') and bb1.condition_flag is not None: # jp/jr nc/nz/.. LABEL
359-
if bb1.opers[0] in LABELS: # some labels does not exist (e.g. immediate numeric addresses)
360-
stack += [LABELS[bb1.opers[0]].basic_block]
329+
if bb in bbset:
330+
break
331+
332+
bbset.add(bb)
333+
334+
if isinstance(bb, DummyBasicBlock):
335+
bb.add_goes_to(final_blk)
336+
break
361337

362-
bb = bb.next # next contiguous block
338+
if bb:
339+
bb1 = bb[-1]
340+
if bb1.inst in {'ret', 'reti', 'retn'}:
341+
bb.add_goes_to(final_blk)
342+
if bb1.condition_flag is None: # 'ret'
343+
break
344+
elif bb1.inst in ('jp', 'jr') and bb1.condition_flag is not None: # jp/jr nc/nz/.. LABEL
345+
if bb1.opers[0] in LABELS: # some labels does not exist (e.g. immediate numeric addresses)
346+
stack.append(LABELS[bb1.opers[0]].basic_block)
347+
else:
348+
raise OptimizerError("Unknown block label '{}'".format(bb1.opers[0]))
363349

364-
if cond is None:
365-
self.calls.add(LABELS[oper[0]].basic_block)
350+
bb = bb.next # next contiguous block
366351

367352
def is_used(self, regs, i, top=None):
368353
""" Checks whether any of the given regs are required from the given point
@@ -516,7 +501,7 @@ def guesses_initial_state_from_origin_blocks(self):
516501
regs = self.comes_from[0].cpu.regs
517502
mems = self.comes_from[0].cpu.mem
518503

519-
for blk in self.comes_from:
504+
for blk in self.comes_from[1:]:
520505
regs = helpers.dict_intersection(regs, blk.cpu.regs)
521506
mems = helpers.dict_intersection(mems, blk.cpu.mem)
522507

@@ -619,24 +604,25 @@ def block_partition(block, i):
619604
lbl_info.basic_block = new_block
620605
lbl_info.position -= len(block)
621606

622-
new_block.goes_to = block.goes_to
623-
block.goes_to = IdentitySet()
607+
for b_ in list(block.goes_to):
608+
block.delete_goes_to(b_)
609+
new_block.add_goes_to(b_)
624610

625611
new_block.label_goes = block.label_goes
626612
block.label_goes = []
627613

628-
new_block.next = new_block.original_next = block.original_next
614+
new_block.next = block.next
629615
new_block.prev = block
616+
block.next = new_block
630617
new_block.add_comes_from(block)
631618

632619
if new_block.next is not None:
633620
new_block.next.prev = new_block
634-
new_block.next.add_comes_from(new_block)
635-
new_block.next.delete_from(block)
621+
if block in new_block.next.comes_from:
622+
new_block.next.delete_comes_from(block)
623+
new_block.next.add_comes_from(new_block)
636624

637-
block.next = block.original_next = new_block
638625
block.update_next_block()
639-
block.add_goes_to(new_block)
640626

641627
return block, new_block
642628

@@ -653,23 +639,27 @@ def get_basic_blocks(block):
653639
block = new_block
654640
new_block = None
655641

656-
for i in range(len(block) - 1):
657-
if i and block.mem[i].code == EDP: # END_PROGRAM label always starts a basic block
642+
for i, mem in enumerate(block[:-1]):
643+
if i and mem.code == EDP: # END_PROGRAM label always starts a basic block
658644
block, new_block = block_partition(block, i - 1)
659645
LABELS[END_PROGRAM_LABEL].basic_block = new_block
660646
break
661647

662-
if block.mem[i].is_ender:
648+
if mem.is_ender:
663649
block, new_block = block_partition(block, i)
664-
op = block.mem[i].opers
650+
if not mem.condition_flag:
651+
block.delete_goes_to(new_block)
665652

666-
for l in op:
653+
for l in mem.opers:
667654
if l in LABELS.keys():
668655
JUMP_LABELS.add(l)
669656
block.label_goes.append(l)
670657
break
671658

672-
if block.mem[i].code in arch.zx48k.backend.ASMS:
659+
if mem.is_label and mem.code[:-1] not in LABELS:
660+
raise OptimizerError("Missing label '{}' in labels list".format(mem.code[:-1]))
661+
662+
if mem.code in arch.zx48k.backend.ASMS: # An inline ASM block
673663
block, new_block = block_partition(block, max(0, i - 1))
674664
break
675665

@@ -683,8 +673,7 @@ def get_basic_blocks(block):
683673
must_partition = False
684674
# This label must point to the beginning of blk, just before the code
685675
# Otherwise we must partition it (must_partition = True)
686-
for i in range(len(blk)):
687-
cell = blk.mem[i]
676+
for i, cell in enumerate(blk):
688677
if cell.inst == label:
689678
break # already starts with this label
690679

@@ -707,10 +696,6 @@ def get_basic_blocks(block):
707696
result.insert(j, block_)
708697
result.insert(j + 1, new_block_)
709698

710-
for b in result:
711-
b.clean_up_comes_from()
712-
b.clean_up_goes_to()
713-
714699
for b in result:
715700
b.update_goes_and_comes()
716701

0 commit comments

Comments
 (0)