Skip to content

Commit c2365ca

Browse files
committed
Fix bug with -O4
Now correctly computes inc/dec (hl), among other things, ensuring memory cells always hold 8-bit values. A new test has been added.
1 parent 4b1511c commit c2365ca

7 files changed

Lines changed: 365 additions & 86 deletions

File tree

src/arch/zx48k/optimizer/cpustate.py

Lines changed: 156 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,138 @@
11
# -*- coding: utf-8 -*-
22

3+
import re
4+
from typing import Dict, List, Tuple
35
from collections import defaultdict
6+
47
from . import asm
58

6-
from .helpers import new_tmp_val, new_tmp_val16, HI16, LO16, HL_SEP
7-
from .helpers import is_unknown, is_unknown16, valnum, is_number
9+
from .helpers import new_tmp_val16, HI16, LO16, HL_SEP
10+
from .helpers import is_unknown, is_unknown8, is_unknown16, valnum
11+
from .helpers import is_number, is_label, new_tmp_val, new_tmp_val16_from_label
812
from .helpers import is_register, is_8bit_oper_register, is_16bit_composed_register
9-
from .helpers import get_L_from_unknown_value, idx_args, LO16_val
13+
from .helpers import get_L_from_unknown_value, get_H_from_unknown_value, idx_args, LO16_val
14+
from .helpers import get_orig_label_from_unknown16
15+
1016

17+
# Splits an address expression with a trailing signed decimal offset into two
# groups: (1) the base (identifier characters optionally followed by earlier
# +N/-N terms) and (2) the last +N/-N term. E.g. "label+2" -> ("label", "+2").
RE_OFFSET = re.compile(r'(^[*._a-zA-Z0-9]+(?:[+-]\d+)*)([+-]\d+)$')
1118

12-
class Flags(object):
19+
20+
class Flags:
    """ Plain container for the C, Z, P and S flag values.

    Every flag starts as None (presumably meaning "unknown" -- confirm
    against the CPUState flag properties).
    """
    def __init__(self):
        # Chained assignment binds the attributes left to right: C, Z, P, S.
        self.C = self.Z = self.P = self.S = None
1826

1927

20-
class CPUState(object):
28+
class Memory:
    """ Implements a memory representation, dealing with unknown values.

    Cells are stored in a dict keyed by address string (a number or a
    label expression such as "label+1"). Reading a cell that was never
    written creates an entry for it through the new_tmp_val factory.
    16-bit values are stored low byte first (low at addr, high at addr + 1).
    """
    def __init__(self):
        self.mem = defaultdict(new_tmp_val)

    def _get_hl_addr(self, addr: str) -> Tuple[str, str]:
        """ Returns the (low byte address, high byte address) pair for a
        16-bit value stored at addr. The high address is addr + 1,
        folded into any trailing numeric offset addr already has.
        """
        if is_number(addr):
            return addr, str(int(addr) + 1)

        ptr = RE_OFFSET.match(addr)
        if ptr is None:  # No trailing +N / -N offset: just append one
            return addr, f"{addr}+1"

        base, off, *_ = ptr.groups()
        off = int(off)
        if off == 0:  # "x+0" is the same cell as "x"
            return self._get_hl_addr(base)

        if off == -1:  # ("x-1") + 1 is simply "x"
            return addr, base

        return addr, "%s%+i" % (base, off + 1)

    def read_16_bit_value(self, addr: str) -> str:
        """ Returns the 16-bit value stored at addr.

        If both bytes are numbers the result is the decimal string of
        lo + 256 * hi. Otherwise the result is "<hi>|<lo>", or the value
        returned by get_orig_label_from_unknown16 for that string when
        it is not None (presumably the label the value was created
        from -- confirm against helpers).
        """
        addr_lo, addr_hi = self._get_hl_addr(addr)
        hi = self.mem[addr_hi]
        lo = self.mem[addr_lo]
        if is_number(hi) and is_number(lo):
            return str(int(lo) + 256 * int(hi))

        result = f"{hi}|{lo}"
        # The lookup must use the value just read; a constant empty string
        # would make the outcome independent of the memory contents.
        if (label_ := get_orig_label_from_unknown16(result)) is not None:
            return label_

        return result

    def write_16_bit_value(self, addr: str, value: str) -> None:
        """ Stores value as two 8-bit cells at addr (low) and addr + 1 (high).

        Numbers are split arithmetically. An unknown 16-bit value is
        split with the get_H/get_L helpers; anything else is first
        converted with new_tmp_val16_from_label.
        """
        if is_number(value):
            v_hi = str((int(value) >> 8) & 0xFF)
            v_lo = str(int(value) & 0xFF)
        else:
            if is_unknown16(value):
                v_ = value
            else:
                v_ = new_tmp_val16_from_label(value)
            v_hi = get_H_from_unknown_value(v_)
            v_lo = get_L_from_unknown_value(v_)

        addr_lo, addr_hi = self._get_hl_addr(addr)
        self.mem[addr_lo] = v_lo
        self.mem[addr_hi] = v_hi

    def read_8_bit_value(self, addr: str) -> str:
        """ Returns the content of the cell at addr. Numeric contents are
        masked to 8 bits; anything else is returned untouched.
        """
        addr_lo, _ = self._get_hl_addr(addr)
        lo = self.mem[addr_lo]
        if is_number(lo):
            return str(int(lo) & 0xFF)

        return lo

    def write_8_bit_value(self, addr: str, value: str) -> None:
        """ Stores value in the cell at addr, reduced to 8 bits.

        Numbers are masked with 0xFF; unknown 16-bit values and labels
        are reduced to their low part. Other values are stored as given.
        """
        if is_number(value):
            value = str(int(value) & 0xFF)
        elif is_unknown16(value):
            value = get_L_from_unknown_value(value)
        elif is_label(value):
            value = get_L_from_unknown_value(new_tmp_val16_from_label(value))

        addr_lo, _ = self._get_hl_addr(addr)
        self.mem[addr_lo] = value

    def update(self, **kwargs):
        """ Bulk-sets cells from keyword arguments (address=value).
        Values are stored as given, with no 8-bit reduction.
        """
        self.mem.update(kwargs)

    def keys(self) -> List[str]:
        """ Returns a list with the addresses currently stored.
        """
        return list(self.mem.keys())

    def values(self) -> List[str]:
        """ Returns a list with the cell contents currently stored.
        """
        return list(self.mem.values())

    def items(self) -> List[Tuple[str, str]]:
        """ Returns a list of (address, content) pairs.
        """
        return list(self.mem.items())

    def __iter__(self):
        return (x for x in self.mem)

    def __len__(self):
        return len(self.mem)

    def __getitem__(self, item):
        # Raw access to the underlying defaultdict: a missing address gets
        # a new entry from the default factory.
        return self.mem[item]
120+
121+
122+
class CPUState:
21123
""" A class storing registers value information (CPU State).
22124
"""
125+
mem: Memory
126+
stack: List[str]
127+
regs: Dict[str, str]
128+
_flags: Tuple[Flags, Flags]
129+
_16bit: Dict[str, str]
130+
23131
def __init__(self):
    """ Builds the 8-bit -> 16-bit register name map and then sets the
    initial state by calling reset().
    """
    # Maps each 8-bit register name to the 16-bit register that contains it.
    # The high half of ix is 'ixh' (reset() reads self.regs['ixh']).
    self._16bit = {
        'b': 'bc', 'c': 'bc', 'd': 'de', 'e': 'de', 'h': 'hl', 'l': 'hl',
        "b'": "bc'", "c'": "bc'", "d'": "de'", "e'": "de'", "h'": "hl'", "l'": "hl'",
        'ixh': 'ix', 'ixl': 'ix', 'iyh': 'iy', 'iyl': 'iy',
        'a': 'af', "a'": "af'", 'f': 'af', "f'": "af'",
    }
    self.reset()
30137

31138
@property
@@ -100,7 +207,7 @@ def S(self, val):
100207
else:
101208
self.regs['f'] = new_tmp_val()
102209

103-
def reset(self, regs=None, mems=None):
210+
def reset(self, regs=None, mems: Memory = None):
104211
""" Initial state
105212
"""
106213
if regs is None:
@@ -111,8 +218,8 @@ def reset(self, regs=None, mems=None):
111218

112219
self.regs = {}
113220
self.stack = []
114-
self.mem = defaultdict(new_tmp_val16) # Dict of label -> value in memory
115-
self._flags = [Flags(), Flags()]
221+
self.mem = Memory() # Dict of label -> value in memory
222+
self._flags = Flags(), Flags()
116223

117224
# # Memory for IX / IY accesses
118225
self.ix_ptr = set()
@@ -136,11 +243,6 @@ def reset(self, regs=None, mems=None):
136243
self.regs['ix'] = '{}{}{}'.format(self.regs['ixh'], HL_SEP, self.regs['ixl'])
137244
self.regs['iy'] = '{}{}{}'.format(self.regs['iyh'], HL_SEP, self.regs['iyl'])
138245

139-
self._16bit = {'b': 'bc', 'c': 'bc', 'd': 'de', 'e': 'de', 'h': 'hl', 'l': 'hl',
140-
"b'": "bc'", "c'": "bc'", "d'": "de'", "e'": "de'", "h'": "hl'", "l'": "hl'",
141-
'ixy': 'ix', 'ixl': 'ix', 'iyh': 'iy', 'iyl': 'iy', 'a': 'af', "a'": "af'",
142-
'f': 'af', "f'": "af'"}
143-
144246
self.regs.update(**regs)
145247
self.mem.update(**mems)
146248

@@ -194,14 +296,21 @@ def shift_idx_regs_refs(self, r, offset):
194296
for i in range(-128, 128):
195297
idx = '%s%+i' % (r, i)
196298
old_idx = '%s%+i' % (r, offset + i)
197-
self.mem[idx] = new_tmp_val() if offset + i > 127 else self.mem[old_idx]
299+
self.mem.write_8_bit_value(
300+
idx,
301+
new_tmp_val() if offset + i > 127 else self.mem.read_8_bit_value(old_idx)
302+
)
198303
else:
199304
for i in range(127, -129, -1):
200305
idx = '%s%+i' % (r, i)
201306
old_idx = '%s%+i' % (r, offset + i)
202-
self.mem[idx] = new_tmp_val() if offset + i < -128 else self.mem[old_idx]
307+
self.mem.write_8_bit_value(
308+
idx,
309+
new_tmp_val() if offset + i < -128 else self.mem.read_8_bit_value(old_idx)
310+
)
203311

204312
def set(self, r, val):
313+
orig_val = val
205314
val = self.get(val)
206315
is_num = is_number(val)
207316

@@ -225,10 +334,7 @@ def set(self, r, val):
225334

226335
if r in {'(hl)', '(bc)', '(de)'}: # ld (bc|de|hl), val
227336
r = self.regs[r[1:-1]]
228-
if r in self.mem and val == self.mem[r]:
229-
return # Already set
230-
231-
self.mem[r] = val
337+
self.mem.write_8_bit_value(r, val)
232338
return
233339

234340
if r[0] == '(': # (mem) <- r => store in memory address
@@ -238,10 +344,17 @@ def set(self, r, val):
238344
r = "{}{}{}".format(*idx)
239345
self.ix_ptr.add(idx)
240346
val = LO16_val(val)
347+
self.mem.write_8_bit_value(r, val)
348+
return
241349

242-
if r in self.mem and val == self.mem[r]:
243-
return # the same value to the same pos does nothing... (strong assumption: NON-VOLATILE)
244-
self.mem[r] = val
350+
if is_8bit_oper_register(orig_val):
351+
self.mem.write_8_bit_value(r, val)
352+
return
353+
354+
if is_unknown8(val):
355+
val = f'{new_tmp_val()}{HL_SEP}{val}'
356+
357+
self.mem.write_16_bit_value(r, val)
245358
return
246359

247360
if is_8bit_oper_register(r):
@@ -270,11 +383,17 @@ def set(self, r, val):
270383

271384
# a 16 bit reg
272385
assert r in self.regs
273-
assert is_num or is_unknown16(val), "val '{}' is not a number nor an unknown16".format(val)
386+
387+
if is_unknown8(val):
388+
val = f'{new_tmp_val()}{HL_SEP}{val}'
389+
assert is_num or is_unknown16(val) or is_label(val), "val '{}' is neither a number, nor a label" \
390+
" nor an unknown16".format(val)
274391

275392
self.regs[r] = val
276393
if is_16bit_composed_register(r): # sp register is not included. Special case
277394
if not is_num:
395+
if is_label(val):
396+
val = new_tmp_val16_from_label(val)
278397
self.regs[HI16(r)], self.regs[LO16(r)] = val.split(HL_SEP)
279398
else:
280399
val = valnum(val)
@@ -296,32 +415,29 @@ def get(self, r):
296415

297416
if r.lower() in {'(hl)', '(bc)', '(de)'}:
298417
i = self.regs[r.lower()[1:-1]]
299-
return self.mem[i]
418+
return self.mem.read_8_bit_value(i)
300419

301420
if r[:1] == '(':
302421
v_ = r[1:-1].strip()
303422
idx = idx_args(v_)
304423
if idx is not None:
305424
v_ = "{}{}{}".format(*idx)
306425
self.ix_ptr.add(idx)
426+
return self.mem.read_8_bit_value(v_)
307427

308-
val = self.mem[v_]
309-
if idx is not None:
310-
self.mem[v_] = val = LO16_val(val)
311-
312-
return val
428+
return self.mem.read_16_bit_value(v_)
313429

314430
if is_number(r):
315431
return str(valnum(r))
316432

317433
if is_unknown(r):
318434
return r
319435

320-
r = r.lower()
321-
if not is_register(r):
322-
return None
436+
r_ = r.lower()
437+
if not is_register(r_): # If it's not a register, it *must* be a label
438+
return r
323439

324-
return self.regs[r]
440+
return self.regs[r_]
325441

326442
def getv(self, r):
327443
""" Like the above, but returns the <int> value or None.
@@ -375,7 +491,7 @@ def inc(self, r):
375491
self.set_flag(None)
376492

377493
r_ = r[1:-1]
378-
self.mem[self.get(r_)] = str(v_)
494+
self.mem.write_8_bit_value(self.get(r_), str(v_))
379495
return
380496

381497
if self.getv(r) is not None:
@@ -411,7 +527,7 @@ def dec(self, r):
411527
self.set_flag(None)
412528

413529
r_ = r[1:-1]
414-
self.mem[self.get(r_)] = str(v_)
530+
self.mem.write_8_bit_value(self.get(r_), str(v_))
415531
return
416532

417533
if self.getv(r) is not None:

0 commit comments

Comments
 (0)