Skip to content

Commit 16da829

Browse files
committed
Identity set fixes and refact
1 parent ca7d8a9 commit 16da829

1 file changed

Lines changed: 22 additions & 18 deletions

File tree

identityset.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,31 @@
22
# -*- coding: utf-8 -*-
33
# vim:ts=4:et:
44

5-
import collections
6-
75

86
class IdentitySet(object):
97
""" This set implementation only adds items
108
if they are not exactly the same (same reference)
119
preserving its order (OrderedDict). Allows deleting by ith-index.
1210
"""
13-
def __init__(self, l=None):
11+
def __init__(self, elems=None):
1412
self.elems = []
1513
self._elems = set()
16-
if l is not None:
17-
self.add(l)
14+
self.update(elems or [])
1815

19-
def add(self, l):
20-
if not isinstance(l, collections.Iterable):
21-
l = [l]
22-
self.elems.extend(x for x in l if x not in self._elems)
23-
self._elems.update(x for x in l)
16+
def add(self, elem):
17+
self.elems.append(elem)
18+
self._elems.add(elem)
2419

25-
def remove(self, l):
26-
if not isinstance(l, collections.Iterable):
27-
l = [l]
20+
def remove(self, elem):
21+
""" Removes an element if it exits. Otherwise does nothing.
22+
Returns if the element was removed.
23+
"""
24+
if elem in self._elems:
25+
self._elems.remove(elem)
26+
self.elems = [x for x in self.elems if x in self._elems]
27+
return True
2828

29-
self._elems.difference_update(l)
30-
self.elems = [x for x in self.elems if x not in self._elems]
29+
return False
3130

3231
def __len__(self):
3332
return len(self.elems)
@@ -45,11 +44,16 @@ def __delitem__(self, key):
4544
self.pop(self.elems.index(key))
4645

4746
def intersection(self, other):
48-
return IdentitySet([x for x in self.elems if x in self._elems.intersection(other)])
47+
return IdentitySet(self._elems.intersection(other))
4948

5049
def union(self, other):
5150
return IdentitySet(self.elems + [x for x in other])
5251

5352
def pop(self, i):
54-
tmp = self.elems.pop(i)
55-
self._elems.remove(tmp)
53+
result = self.elems.pop(i)
54+
self._elems.remove(result)
55+
return result
56+
57+
def update(self, elems):
58+
self.elems.extend(x for x in elems if x not in self._elems)
59+
self._elems.update(x for x in elems)

0 commit comments

Comments
 (0)