1111compute the maximum XOR.
1212"""
1313
14+ from __future__ import annotations
15+
1416
1517class TrieNode :
1618 """Node of the Bitwise Trie."""
1719
1820 def __init__ (self ) -> None :
19- self .child = [ None , None ] # child[0] for bit 0, child[1] for bit 1
21+ self .child : list [ TrieNode | None ] = [ None , None ]
2022
2123
2224class BitwiseTrieMaxXOR :
@@ -30,11 +32,17 @@ def __init__(self) -> None:
3032 self .root = TrieNode ()
3133
3234 def insert (self , num : int ) -> None :
33- """Insert a number into the trie."""
35+ """
36+ Insert a number into the trie.
37+
38+ >>> trie = BitwiseTrieMaxXOR()
39+ >>> trie.insert(5)
40+ >>> trie.insert(10)
41+ """
3442 node = self .root
3543 for i in range (31 , - 1 , - 1 ): # 32-bit integers
3644 bit = (num >> i ) & 1
37- if not node .child [bit ]:
45+ if node .child [bit ] is None :
3846 node .child [bit ] = TrieNode ()
3947 node = node .child [bit ]
4048
@@ -47,21 +55,41 @@ def query_max_xor(self, num: int) -> int:
4755
4856 Returns:
4957 int: Maximum XOR value achievable with `num`.
58+
59+ >>> trie = BitwiseTrieMaxXOR()
60+ >>> trie.insert(5)
61+ >>> trie.insert(10)
62+ >>> trie.query_max_xor(5)
63+ 15
64+ >>> trie.query_max_xor(10)
65+ 15
5066 """
5167 node = self .root
5268 max_xor = 0
5369 for i in range (31 , - 1 , - 1 ):
5470 bit = (num >> i ) & 1
5571 toggle = 1 - bit
56- if node .child [toggle ]:
72+ if node .child [toggle ] is not None :
5773 max_xor |= 1 << i
5874 node = node .child [toggle ]
5975 else :
6076 node = node .child [bit ]
6177 return max_xor
6278
6379 def find_maximum_xor (self , nums : list [int ]) -> int :
64- """Compute maximum XOR of any two numbers in `nums`."""
80+ """
81+ Compute maximum XOR of any two numbers in `nums`.
82+
83+ >>> solver = BitwiseTrieMaxXOR()
84+ >>> solver.find_maximum_xor([3, 10, 5, 25, 2, 8])
85+ 28
86+ >>> solver.find_maximum_xor([42])
87+ 0
88+ >>> solver.find_maximum_xor([8, 1])
89+ 9
90+ >>> solver.find_maximum_xor([0, 0, 0])
91+ 0
92+ """
6593 if not nums :
6694 return 0
6795 for num in nums :
@@ -70,8 +98,13 @@ def find_maximum_xor(self, nums: list[int]) -> int:
7098
7199
72100if __name__ == "__main__" :
73- print ( "************ Testing Bitwise Trie Maximum XOR Algorithm ************ \n " )
101+ import doctest
74102
103+ doctest .testmod ()
104+ print ("All doctests passed!" )
105+
106+ # Optional: full test suite
107+ print ("************ Testing Bitwise Trie Maximum XOR Algorithm ************\n " )
75108 test_cases = [
76109 ([3 , 10 , 5 , 25 , 2 , 8 ], 28 ),
77110 ([42 ], 0 ),
@@ -86,7 +119,7 @@ def find_maximum_xor(self, nums: list[int]) -> int:
86119 ]
87120
88121 for idx , (nums , expected ) in enumerate (test_cases , 1 ):
89- solver = BitwiseTrieMaxXOR () # Reset trie for each test case
122+ solver = BitwiseTrieMaxXOR ()
90123 result = solver .find_maximum_xor (nums )
91124 print (f"Testcase { idx } : Expected={ expected } , Got={ result } " )
92125 assert result == expected , f"Testcase { idx } failed!"
0 commit comments