Bringing doctests upto snuff with Python26
[ejpi] / src / libraries / recipes / algorithms.py
index a2dc356..5da8b80 100644 (file)
@@ -87,51 +87,37 @@ def iterstep(iterator, n):
                        iterator.next()
 
 
-def itergroup(iterator, count, to_container=tuple):
+def itergroup(iterator, count, padValue = None):
        """
        Iterate in groups of 'count' values. If there
        aren't enough values, the last result is padded with
        None.
 
        >>> for val in itergroup([1, 2, 3, 4, 5, 6], 3):
-       ...     print val
+       ...     print tuple(val)
        (1, 2, 3)
        (4, 5, 6)
-       >>> for val in itergroup([1, 2, 3, 4, 5, 6], 3, list):
-       ...     print val
+       >>> for val in itergroup([1, 2, 3, 4, 5, 6], 3):
+       ...     print list(val)
        [1, 2, 3]
        [4, 5, 6]
        >>> for val in itergroup([1, 2, 3, 4, 5, 6, 7], 3):
-       ...     print val
+       ...     print tuple(val)
        (1, 2, 3)
        (4, 5, 6)
        (7, None, None)
        >>> for val in itergroup("123456", 3):
-       ...     print val
+       ...     print tuple(val)
        ('1', '2', '3')
        ('4', '5', '6')
-       >>> for val in itergroup("123456", 3, lambda i: "".join(s for s in i if s is not None)):
-       ...     print repr(val)
+       >>> for val in itergroup("123456", 3):
+       ...     print repr("".join(val))
        '123'
        '456'
        """
-
-       iterator = iter(iterator)
-       values_left = [True]
-
-       def values():
-               values_left[0] = False
-               for x in range(count):
-                       try:
-                               yield iterator.next()
-                               values_left[0] = True
-                       except StopIteration:
-                               yield None
-       while True:
-               value = to_container(values())
-               if not values_left[0]:
-                       raise StopIteration
-               yield value
+       paddedIterator = itertools.chain(iterator, itertools.repeat(padValue, count-1))
+       nIterators = (paddedIterator, ) * count
+       return itertools.izip(*nIterators)
 
 
 def xzip(*iterators):