11# -*- coding: utf-8 -*-
22
3+ import re
4+ from typing import Dict , List , Tuple
35from collections import defaultdict
6+
47from . import asm
58
6- from .helpers import new_tmp_val , new_tmp_val16 , HI16 , LO16 , HL_SEP
7- from .helpers import is_unknown , is_unknown16 , valnum , is_number
9+ from .helpers import new_tmp_val16 , HI16 , LO16 , HL_SEP
10+ from .helpers import is_unknown , is_unknown8 , is_unknown16 , valnum
11+ from .helpers import is_number , is_label , new_tmp_val , new_tmp_val16_from_label
812from .helpers import is_register , is_8bit_oper_register , is_16bit_composed_register
9- from .helpers import get_L_from_unknown_value , idx_args , LO16_val
13+ from .helpers import get_L_from_unknown_value , get_H_from_unknown_value , idx_args , LO16_val
14+ from .helpers import get_orig_label_from_unknown16
15+
1016
17+ RE_OFFSET = re .compile (r'(^[*._a-zA-Z0-9]+(?:[+-]\d+)*)([+-]\d+)$' )
1118
12- class Flags (object ):
19+
class Flags:
    """ Holder for the four CPU flag values tracked by this module: C, Z, P and S.

    Every flag is initialised to None.
    NOTE(review): None presumably stands for "value unknown" -- confirm against the callers.
    """
    def __init__(self):
        # All four flags start out with the same initial value.
        self.C = self.Z = self.P = self.S = None
1826
1927
20- class CPUState (object ):
class Memory:
    """ Implements a memory representation, dealing with unknown values.

    Memory is modelled as a mapping from an address string to the string
    stored in that single cell. An address is either a number or a symbolic
    expression such as ``label`` or ``ix+5``. Reading a cell that was never
    written yields (and stores) a fresh value produced by ``new_tmp_val``.
    """
    def __init__(self):
        # defaultdict: a lookup of a missing address inserts new_tmp_val() for it
        self.mem: Dict[str, str] = defaultdict(new_tmp_val)

    def _get_hl_addr(self, addr: str) -> Tuple[str, str]:
        """ Returns the pair (low address, high address) of the two
        consecutive cells of a 16 bit value stored at ``addr``.

        :param addr: a numeric address or a symbolic one (optionally ending
            in a ``+N`` / ``-N`` offset).
        :return: tuple ``(addr_lo, addr_hi)`` where addr_hi denotes addr_lo + 1.
        """
        if is_number(addr):
            return addr, str(int(addr) + 1)

        ptr = RE_OFFSET.match(addr)
        if ptr is None:  # No trailing offset: plain symbolic address
            return addr, f"{addr}+1"

        base, off_str = ptr.groups()
        off = int(off_str)
        if off == 0:  # "base+0" is the same cell as "base"
            return self._get_hl_addr(base)

        if off == -1:  # "base-1" + 1 is just "base"
            return addr, base

        return addr, f"{base}{off + 1:+d}"

    def read_16_bit_value(self, addr: str) -> str:
        """ Reads the two consecutive cells starting at ``addr`` and returns
        them combined as a 16 bit value.

        :return: a decimal string if both cells hold numbers; otherwise the
            label returned by ``get_orig_label_from_unknown16`` for the
            composed value if there is one, else the composed
            ``<hi><HL_SEP><lo>`` string.
        """
        addr_lo, addr_hi = self._get_hl_addr(addr)
        # Keep the original lookup order (hi first): missing cells get their
        # temporary value created at lookup time.
        hi = self.mem[addr_hi]
        lo = self.mem[addr_lo]
        if is_number(hi) and is_number(lo):
            return str(int(lo) + 256 * int(hi))

        result = f"{hi}{HL_SEP}{lo}"
        # Look up the label using the value just composed (the lookup used to
        # receive an empty string, so it could never depend on what was read).
        if (label_ := get_orig_label_from_unknown16(result)) is not None:
            return label_

        return result

    def write_16_bit_value(self, addr: str, value: str) -> None:
        """ Stores ``value`` in the two consecutive cells starting at ``addr``
        (low part first).

        :param value: a number, an unknown 16 bit value, or anything else,
            which is converted with ``new_tmp_val16_from_label``.
        """
        if is_number(value):
            v_hi = str((int(value) >> 8) & 0xFF)
            v_lo = str(int(value) & 0xFF)
        else:
            v_ = value if is_unknown16(value) else new_tmp_val16_from_label(value)
            v_hi = get_H_from_unknown_value(v_)
            v_lo = get_L_from_unknown_value(v_)

        addr_lo, addr_hi = self._get_hl_addr(addr)
        self.mem[addr_lo] = v_lo
        self.mem[addr_hi] = v_hi

    def read_8_bit_value(self, addr: str) -> str:
        """ Returns the content of the single cell at ``addr``; numeric
        contents are masked to 8 bits.
        """
        addr_lo, _ = self._get_hl_addr(addr)
        lo = self.mem[addr_lo]
        if is_number(lo):
            return str(int(lo) & 0xFF)

        return lo

    def write_8_bit_value(self, addr: str, value: str) -> None:
        """ Stores ``value`` in the single cell at ``addr``.

        Numbers are masked to 8 bits; for unknown 16 bit values and labels
        only the low part is stored; any other value is stored unchanged.
        """
        if is_number(value):
            value = str(int(value) & 0xFF)
        elif is_unknown16(value):
            value = get_L_from_unknown_value(value)
        elif is_label(value):
            value = get_L_from_unknown_value(new_tmp_val16_from_label(value))

        addr_lo, _ = self._get_hl_addr(addr)
        self.mem[addr_lo] = value

    def update(self, **kwargs):
        """ Bulk-sets cells from ``address=value`` keyword arguments, stored as given.
        """
        self.mem.update(kwargs)

    def keys(self) -> List[str]:
        """ Returns a list with the addresses currently stored.
        """
        return list(self.mem.keys())

    def values(self) -> List[str]:
        """ Returns a list with the cell contents currently stored.
        """
        return list(self.mem.values())

    def items(self) -> List[Tuple[str, str]]:
        """ Returns a list of ``(address, content)`` pairs.
        """
        return list(self.mem.items())

    def __iter__(self):
        return iter(self.mem)

    def __len__(self):
        return len(self.mem)

    def __getitem__(self, item):
        # Raw cell access. As with any defaultdict lookup, a missing
        # address is created on the fly with a new temporary value.
        return self.mem[item]
121+
122+ class CPUState :
21123 """ A class storing registers value information (CPU State).
22124 """
125+ mem : Memory
126+ stack : List [str ]
127+ regs : Dict [str , str ]
128+ _flags : Tuple [Flags , Flags ]
129+ _16bit : Dict [str , str ]
130+
23131 def __init__ (self ):
24- self .regs = None
25- self .stack = None
26- self .mem = None
27- self ._flags = None
28- self ._16bit = None
132+ self ._16bit = {'b' : 'bc' , 'c' : 'bc' , 'd' : 'de' , 'e' : 'de' , 'h' : 'hl' , 'l' : 'hl' ,
133+ "b'" : "bc'" , "c'" : "bc'" , "d'" : "de'" , "e'" : "de'" , "h'" : "hl'" , "l'" : "hl'" ,
134+ 'ixy' : 'ix' , 'ixl' : 'ix' , 'iyh' : 'iy' , 'iyl' : 'iy' , 'a' : 'af' , "a'" : "af'" ,
135+ 'f' : 'af' , "f'" : "af'" }
29136 self .reset ()
30137
31138 @property
@@ -100,7 +207,7 @@ def S(self, val):
100207 else :
101208 self .regs ['f' ] = new_tmp_val ()
102209
103- def reset (self , regs = None , mems = None ):
210+ def reset (self , regs = None , mems : Memory = None ):
104211 """ Initial state
105212 """
106213 if regs is None :
@@ -111,8 +218,8 @@ def reset(self, regs=None, mems=None):
111218
112219 self .regs = {}
113220 self .stack = []
114- self .mem = defaultdict ( new_tmp_val16 ) # Dict of label -> value in memory
115- self ._flags = [ Flags (), Flags ()]
221+ self .mem = Memory ( ) # Dict of label -> value in memory
222+ self ._flags = Flags (), Flags ()
116223
117224 # # Memory for IX / IY accesses
118225 self .ix_ptr = set ()
@@ -136,11 +243,6 @@ def reset(self, regs=None, mems=None):
136243 self .regs ['ix' ] = '{}{}{}' .format (self .regs ['ixh' ], HL_SEP , self .regs ['ixl' ])
137244 self .regs ['iy' ] = '{}{}{}' .format (self .regs ['iyh' ], HL_SEP , self .regs ['iyl' ])
138245
139- self ._16bit = {'b' : 'bc' , 'c' : 'bc' , 'd' : 'de' , 'e' : 'de' , 'h' : 'hl' , 'l' : 'hl' ,
140- "b'" : "bc'" , "c'" : "bc'" , "d'" : "de'" , "e'" : "de'" , "h'" : "hl'" , "l'" : "hl'" ,
141- 'ixy' : 'ix' , 'ixl' : 'ix' , 'iyh' : 'iy' , 'iyl' : 'iy' , 'a' : 'af' , "a'" : "af'" ,
142- 'f' : 'af' , "f'" : "af'" }
143-
144246 self .regs .update (** regs )
145247 self .mem .update (** mems )
146248
@@ -194,14 +296,21 @@ def shift_idx_regs_refs(self, r, offset):
194296 for i in range (- 128 , 128 ):
195297 idx = '%s%+i' % (r , i )
196298 old_idx = '%s%+i' % (r , offset + i )
197- self .mem [idx ] = new_tmp_val () if offset + i > 127 else self .mem [old_idx ]
299+ self .mem .write_8_bit_value (
300+ idx ,
301+ new_tmp_val () if offset + i > 127 else self .mem .read_8_bit_value (old_idx )
302+ )
198303 else :
199304 for i in range (127 , - 129 , - 1 ):
200305 idx = '%s%+i' % (r , i )
201306 old_idx = '%s%+i' % (r , offset + i )
202- self .mem [idx ] = new_tmp_val () if offset + i < - 128 else self .mem [old_idx ]
307+ self .mem .write_8_bit_value (
308+ idx ,
309+ new_tmp_val () if offset + i < - 128 else self .mem .read_8_bit_value (old_idx )
310+ )
203311
204312 def set (self , r , val ):
313+ orig_val = val
205314 val = self .get (val )
206315 is_num = is_number (val )
207316
@@ -225,10 +334,7 @@ def set(self, r, val):
225334
226335 if r in {'(hl)' , '(bc)' , '(de)' }: # ld (bc|de|hl), val
227336 r = self .regs [r [1 :- 1 ]]
228- if r in self .mem and val == self .mem [r ]:
229- return # Already set
230-
231- self .mem [r ] = val
337+ self .mem .write_8_bit_value (r , val )
232338 return
233339
234340 if r [0 ] == '(' : # (mem) <- r => store in memory address
@@ -238,10 +344,17 @@ def set(self, r, val):
238344 r = "{}{}{}" .format (* idx )
239345 self .ix_ptr .add (idx )
240346 val = LO16_val (val )
347+ self .mem .write_8_bit_value (r , val )
348+ return
349+
350+ if is_8bit_oper_register (orig_val ):
351+ self .mem .write_8_bit_value (r , val )
352+ return
353+
354+ if is_unknown8 (val ):
355+ val = f'{ new_tmp_val ()} { HL_SEP } { val } '
241356
242- if r in self .mem and val == self .mem [r ]:
243- return # the same value to the same pos does nothing... (strong assumption: NON-VOLATILE)
244- self .mem [r ] = val
357+ self .mem .write_16_bit_value (r , val )
245358 return
246359
247360 if is_8bit_oper_register (r ):
@@ -270,11 +383,17 @@ def set(self, r, val):
270383
271384 # a 16 bit reg
272385 assert r in self .regs
273- assert is_num or is_unknown16 (val ), "val '{}' is not a number nor an unknown16" .format (val )
386+
387+ if is_unknown8 (val ):
388+ val = f'{ new_tmp_val ()} { HL_SEP } { val } '
389+ assert is_num or is_unknown16 (val ) or is_label (val ), "val '{}' is neither a number, nor a label" \
390+ " nor an unknown16" .format (val )
274391
275392 self .regs [r ] = val
276393 if is_16bit_composed_register (r ): # sp register is not included. Special case
277394 if not is_num :
395+ if is_label (val ):
396+ val = new_tmp_val16_from_label (val )
278397 self .regs [HI16 (r )], self .regs [LO16 (r )] = val .split (HL_SEP )
279398 else :
280399 val = valnum (val )
@@ -296,32 +415,29 @@ def get(self, r):
296415
297416 if r .lower () in {'(hl)' , '(bc)' , '(de)' }:
298417 i = self .regs [r .lower ()[1 :- 1 ]]
299- return self .mem [ i ]
418+ return self .mem . read_8_bit_value ( i )
300419
301420 if r [:1 ] == '(' :
302421 v_ = r [1 :- 1 ].strip ()
303422 idx = idx_args (v_ )
304423 if idx is not None :
305424 v_ = "{}{}{}" .format (* idx )
306425 self .ix_ptr .add (idx )
426+ return self .mem .read_8_bit_value (v_ )
307427
308- val = self .mem [v_ ]
309- if idx is not None :
310- self .mem [v_ ] = val = LO16_val (val )
311-
312- return val
428+ return self .mem .read_16_bit_value (v_ )
313429
314430 if is_number (r ):
315431 return str (valnum (r ))
316432
317433 if is_unknown (r ):
318434 return r
319435
320- r = r .lower ()
321- if not is_register (r ):
322- return None
436+ r_ = r .lower ()
437+ if not is_register (r_ ): # If it's not a register, it *must* be a label
438+ return r
323439
324- return self .regs [r ]
440+ return self .regs [r_ ]
325441
326442 def getv (self , r ):
327443 """ Like the above, but returns the <int> value or None.
@@ -365,15 +481,17 @@ def inc(self, r):
365481 """
366482 if not is_register (r ):
367483 if r [0 ] == '(' : # a memory position, basically: inc(hl)
368- r_ = r [1 :- 1 ].strip ()
369- v_ = self .getv (self .mem .get (r_ , None ))
484+ v_ = self .getv (r )
370485 if v_ is not None :
371486 v_ = (v_ + 1 ) & 0xFF
372- self .mem [r_ ] = str (v_ )
373487 self .Z = int (v_ == 0 ) # HINT: This might be improved
374488 self .C = int (v_ == 0 )
375489 else :
376- self .mem [r_ ] = new_tmp_val ()
490+ v_ = new_tmp_val ()
491+ self .set_flag (None )
492+
493+ r_ = r [1 :- 1 ]
494+ self .mem .write_8_bit_value (self .get (r_ ), str (v_ ))
377495 return
378496
379497 if self .getv (r ) is not None :
@@ -399,16 +517,17 @@ def dec(self, r):
399517 """ Does dec on the register and precomputes flags
400518 """
401519 if not is_register (r ):
402- if r [0 ] == '(' : # a memory position, basically: inc(hl)
403- r_ = r [1 :- 1 ].strip ()
404- v_ = self .getv (self .mem .get (r_ , None ))
405- if v_ is not None :
406- v_ = (v_ - 1 ) & 0xFF
407- self .mem [r_ ] = str (v_ )
408- self .Z = int (v_ == 0 ) # HINT: This might be improved
409- self .C = int (v_ == 0xFF )
410- else :
411- self .mem [r_ ] = new_tmp_val ()
520+ v_ = self .getv (r )
521+ if v_ is not None :
522+ v_ = (v_ - 1 ) & 0xFF
523+ self .Z = int (v_ == 0 ) # HINT: This might be improved
524+ self .C = int (v_ == 0xFF )
525+ else :
526+ v_ = new_tmp_val ()
527+ self .set_flag (None )
528+
529+ r_ = r [1 :- 1 ]
530+ self .mem .write_8_bit_value (self .get (r_ ), str (v_ ))
412531 return
413532
414533 if self .getv (r ) is not None :
0 commit comments