Pulling in skeleton code
[ejpi] / src / util / concurrent.py
index 503a1b4..a6499fe 100644 (file)
@@ -7,6 +7,76 @@ import errno
 import time
 import functools
 import contextlib
+import logging
+
+import misc
+
+
+_moduleLogger = logging.getLogger(__name__)
+
+
+class AsyncLinearExecution(object):
+
+       def __init__(self, pool, func):
+               self._pool = pool
+               self._func = func
+               self._run = None
+
+       def start(self, *args, **kwds):
+               assert self._run is None, "Task already started"
+               self._run = self._func(*args, **kwds)
+               trampoline, args, kwds = self._run.send(None) # priming the function
+               self._pool.add_task(
+                       trampoline,
+                       args,
+                       kwds,
+                       self.on_success,
+                       self.on_error,
+               )
+
+       @misc.log_exception(_moduleLogger)
+       def on_success(self, result):
+               _moduleLogger.debug("Processing success for: %r", self._func)
+               try:
+                       trampoline, args, kwds = self._run.send(result)
+               except StopIteration, e:
+                       pass
+               else:
+                       self._pool.add_task(
+                               trampoline,
+                               args,
+                               kwds,
+                               self.on_success,
+                               self.on_error,
+                       )
+
+       @misc.log_exception(_moduleLogger)
+       def on_error(self, error):
+               _moduleLogger.debug("Processing error for: %r", self._func)
+               try:
+                       trampoline, args, kwds = self._run.throw(error)
+               except StopIteration, e:
+                       pass
+               else:
+                       self._pool.add_task(
+                               trampoline,
+                               args,
+                               kwds,
+                               self.on_success,
+                               self.on_error,
+                       )
+
+       def __repr__(self):
+               return "<async %s at 0x%x>" % (self._func.__name__, id(self))
+
+       def __hash__(self):
+               return hash(self._func)
+
+       def __eq__(self, other):
+               return self._func == other._func
+
+       def __ne__(self, other):
+               return self._func != other._func
 
 
 def synchronized(lock):