I have a generic multiprocessing worker class that takes items from a queue to be processed. A user of the worker class would need to pass a function that processes each item. However, some processing functions need setup-code.
The current implementation uses a generator function that has to be correctly implemented by the user to correctly perform the setup code only once, process the items from the queue, and handle the StopIteration exception raised when the worker finishes normally.
Can a more straightforward and reliable method be used to separate the setup code from the processing code and handle the exceptions raised by the worker?
Here is what I have:
import multiprocessing as mp
import typing
P = typing.Callable[[], typing.Generator[None, None, None]]
Q: typing.TypeAlias = "mp.Queue"
class Worker(mp.Process):
    """Process that forwards queue items into a user-supplied generator."""

    def __init__(self, queue: Q, processor: P):
        super().__init__()
        self.processor = processor
        self.queue = queue

    def run(self):
        """Prime the processor, then send it every item up to and including None."""
        gen = self.processor()
        next(gen)  # advance to the first ``yield`` so that send() is legal
        item = self.queue.get()
        gen.send(item)
        # The None sentinel is sent to the generator before the loop stops.
        while item is not None:
            item = self.queue.get()
            gen.send(item)
class WorkerPool:
    """Context manager owning ``n_workers`` Worker processes that share one queue."""

    def __init__(self, n_workers: int, processor_generator: P, queue: Q):
        self.queue = queue
        self.workers = [Worker(queue, processor_generator) for _ in range(n_workers)]

    def __enter__(self):
        """Start every worker process."""
        for proc in self.workers:
            proc.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the workers: gracefully on success, forcefully on error."""
        if exc_type is not None:
            # The with-body failed: kill the workers and let the exception propagate.
            self.terminate()
            return False
        # Clean exit: send the sentinels, then wait for the workers to finish.
        self.signal_end()
        self.join()
        return True

    def signal_end(self):
        """Put one ``None`` sentinel on the queue per worker."""
        for _ in self.workers:
            self.queue.put(None)

    def terminate(self):
        """Terminate every worker process."""
        for proc in self.workers:
            proc.terminate()

    def join(self):
        """Wait for every worker process to exit."""
        for proc in self.workers:
            proc.join()
class GeneratorWorkerManager:
    """Runs a WorkerPool and feeds it every item of ``item_generator``."""

    def __init__(
        self, item_generator: typing.Generator, processor_generator: P, n_workers: int
    ) -> None:
        work_queue: Q = mp.Queue()
        pool = WorkerPool(n_workers, processor_generator, work_queue)
        # The pool starts its workers on entry and shuts them down on exit.
        with pool:
            for item in item_generator:
                work_queue.put(item)
A user of the GeneratorWorkerManager class could do the following:
def processor():
    """Generator-based processor: print each item sent in until None arrives."""
    # All sorts of setup code possible, including a with-statement.
    while (item := (yield)) is not None:
        # process item
        print(item)
# Feed the integers 0-9 through a pool with a single worker process.
items = range(10)
GeneratorWorkerManager(items, processor, 1)
Where the worker would print 0 to 9. However, this relies on the user implementing the processor function correctly. The worker also raises a StopIteration exception when it finishes normally.
Is there a better way to use setup code and processing code in the same context?
The trick is to pass a function that returns a context manager. Here I've wrapped the open function in another function as an example.
import multiprocessing as mp
from typing import Callable, ContextManager, Iterable, TypeVar
I = TypeVar("I")  # type of the items travelling through the queue
S = TypeVar("S")  # type of the value produced by the setup context manager
class Worker(mp.Process):
    """Process that calls ``processor(setup_value, item)`` for each queued item."""

    def __init__(
        self,
        queue: "mp.Queue[I|None]",
        processor: Callable[[S, I], None],
        setup: Callable[[], ContextManager[S]],
    ):
        super().__init__()
        self.user_setup = setup
        self.processor = processor
        self.queue = queue

    def run(self):
        """Enter the setup context once, then consume items until a None sentinel."""
        with self.user_setup() as resource:
            # The sentinel itself is never handed to the processor.
            while (item := self.queue.get()) is not None:
                self.processor(resource, item)
class WorkerPool:
    """Context manager owning ``n_workers`` Worker processes that share one queue.

    On a clean exit every worker receives a ``None`` sentinel and is joined.
    If the with-body raises, the workers are terminated and then joined, and
    the exception propagates to the caller.
    """

    def __init__(
        self,
        n_workers: int,
        processor: Callable[[S, I], None],
        queue: "mp.Queue[I|None]",
        setup: Callable[[], ContextManager[S]],
    ):
        self.workers = [Worker(queue, processor, setup) for _ in range(n_workers)]
        self.queue = queue

    def __enter__(self):
        """Start every worker and return the pool so ``with ... as pool`` works."""
        for worker in self.workers:
            worker.start()
        return self

    def signal_end(self):
        """Put one ``None`` sentinel on the queue per worker."""
        for _ in self.workers:
            self.queue.put(None)

    def terminate(self):
        """Terminate every worker process."""
        for worker in self.workers:
            worker.terminate()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut the workers down; returns False when an exception is in flight."""
        if exc_type is None:
            self.signal_end()
            self.join()
            return True
        self.terminate()
        # Join after terminating so the dead processes are reaped instead of
        # lingering as zombies.
        self.join()
        return False

    def join(self):
        """Wait for every worker process to exit."""
        for worker in self.workers:
            worker.join()
class GeneratorWorkerManager:
    """Runs a WorkerPool and feeds it every item of ``item_generator``."""

    def __init__(
        self,
        item_generator: Iterable[I],
        processor: Callable[[S, I], None],
        n_workers: int,
        setup: Callable[[], ContextManager[S]],
    ) -> None:
        work_queue: "mp.Queue[I|None]" = mp.Queue()
        pool = WorkerPool(n_workers, processor, work_queue, setup)
        # The pool starts its workers on entry and shuts them down on exit.
        with pool:
            for item in item_generator:
                work_queue.put(item)
def custom_setup():
    """Open ``test.txt`` for writing; the file object doubles as the context manager."""
    handle = open("test.txt", "w", encoding="utf-8")
    return handle
def custom_processor(setup, item) -> None:
    """Write ``str(item)`` to the file-like object produced by the setup step."""
    setup.write(str(item))
# Items to process: the integers 0-9.
items = range(10)
if __name__ == "__main__":
    # The guard keeps the manager from being started again when this module is
    # imported rather than run as a script.
    mp.freeze_support()
    # One worker process; custom_setup supplies the open file to custom_processor.
    GeneratorWorkerManager(items, custom_processor, 1, custom_setup)
It is not possible to use a setup function that does not support the with statement, but it would be easy to extend this with two setup variables: one for context-manager setup and one for regular setup.
Related
Is there a Pool class for worker threads, similar to the multiprocessing module's Pool class?
I like for example the easy way to parallelize a map function
def long_running_func(p):
    # Delegates to a C function (not shown) that is expected to release the GIL.
    c_func_no_gil(p)
# Pool of four worker processes.
p = multiprocessing.Pool(4)
# Apply long_running_func to 0..99 and collect the results in input order.
xs = p.map(long_running_func, range(100))
however I would like to do it without the overhead of creating new processes.
I know about the GIL. However, in my usecase, the function will be an IO-bound C function for which the python wrapper will release the GIL before the actual function call.
Do I have to write my own threading pool?
I just found out that there actually is a thread-based Pool interface in the multiprocessing module, however it is hidden somewhat and not properly documented.
It can be imported via
from multiprocessing.pool import ThreadPool
It is implemented using a dummy Process class wrapping a python thread. This thread-based Process class can be found in multiprocessing.dummy which is mentioned briefly in the docs. This dummy module supposedly provides the whole multiprocessing interface based on threads.
In Python 3 you can use concurrent.futures.ThreadPoolExecutor, i.e.:
# Thread pool capped at ten worker threads.
executor = ThreadPoolExecutor(max_workers=10)
# submit() schedules my_function and hands back a Future for its result.
a = executor.submit(my_function)
See the docs for more info and examples.
Yes, and it seems to have (more or less) the same API.
import multiprocessing
# Sketch only: the dotted lines stand for code omitted by the author.
def worker(lnk):
    ....
def start_process():
    .....
....
# PROCESS selects a process-based or a thread-based pool; both are created
# with the same arguments.
if(PROCESS):
    pool = multiprocessing.Pool(processes=POOL_SIZE, initializer=start_process)
else:
    # NOTE(review): this may need an explicit `import multiprocessing.pool` - confirm.
    pool = multiprocessing.pool.ThreadPool(processes=POOL_SIZE,
                                           initializer=start_process)
pool.map(worker, inputs)
....
For something very simple and lightweight (slightly modified from here):
from Queue import Queue
from threading import Thread
class Worker(Thread):
    """Thread executing tasks from a given tasks queue"""

    def __init__(self, tasks):
        Thread.__init__(self)
        self.tasks = tasks
        self.daemon = True
        # The thread starts itself as soon as it is constructed.
        self.start()

    def run(self):
        while True:
            task = self.tasks.get()
            func, args, kargs = task
            try:
                func(*args, **kargs)
            except Exception as e:
                # A failing task is reported and the thread keeps running.
                print(e)
            finally:
                self.tasks.task_done()
class ThreadPool:
    """Pool of threads consuming tasks from a queue"""

    def __init__(self, num_threads):
        # The queue capacity equals the number of threads.
        task_queue = Queue(num_threads)
        self.tasks = task_queue
        for _ in range(num_threads):
            Worker(task_queue)

    def add_task(self, func, *args, **kargs):
        """Add a task to the queue"""
        task = (func, args, kargs)
        self.tasks.put(task)

    def wait_completion(self):
        """Wait for completion of all the tasks in the queue"""
        self.tasks.join()
if __name__ == '__main__':
    from random import randrange
    from time import sleep
    # 100 random delays of 1 to 9 seconds each.
    delays = [randrange(1, 10) for i in range(100)]
    def wait_delay(d):
        print 'sleeping for (%d)sec' % d
        sleep(d)
    # Pool of 20 threads; one task is queued per delay.
    pool = ThreadPool(20)
    for i, d in enumerate(delays):
        pool.add_task(wait_delay, d)
    # Block until every queued task has been marked done.
    pool.wait_completion()
To support callbacks on task completion you can just add the callback to the task tuple.
To use a thread pool in Python, you can import this class:
from multiprocessing.dummy import Pool as ThreadPool
and then use it like this:
# Create a pool with `threads` worker threads.
pool = ThreadPool(threads)
# Apply `service` to every element of `tasks`; results keep the input order.
results = pool.map(service, tasks)
# No further work will be submitted; wait for the worker threads to finish.
pool.close()
pool.join()
return results
Here, threads is the number of threads that you want, and tasks is the list of tasks that is mapped over the service function.
Yes, there is a thread pool similar to the multiprocessing Pool; however, it is somewhat hidden and not properly documented. You can import it as follows:
from multiprocessing.pool import ThreadPool
Here is a simple example:
def test_multithread_stringio_read_csv(self):
    """Parse identical in-memory CSV files on a thread pool and compare the frames."""
    # see gh-11786
    max_row_range = 10000
    num_files = 100

    # Every file carries the same payload: one "i,i,i" row per i.
    rows = ['%d,%d,%d' % (i, i, i) for i in range(max_row_range)]
    payload = '\n'.join(rows).encode()
    files = [BytesIO(payload) for _ in range(num_files)]

    # read all files in many threads
    pool = ThreadPool(8)
    results = pool.map(self.read_csv, files)

    expected = results[0]
    for frame in results:
        tm.assert_frame_equal(expected, frame)
Here's the result I finally ended up using. It's a modified version of the classes by dgorissen above.
File: threadpool.py
from queue import Queue, Empty
import threading
from threading import Thread
class Worker(Thread):
    """Thread executing tasks from a given tasks queue.

    The thread is signalable: after signal_exit() is called, run() returns
    once the current task or the current queue wait (at most ``_TIMEOUT``
    seconds) is over.
    """

    _TIMEOUT = 2  # seconds to wait on the queue before re-checking the exit flag

    def __init__(self, tasks, th_num):
        Thread.__init__(self)
        self.tasks = tasks
        self.daemon, self.th_num = True, th_num
        self.done = threading.Event()
        # The thread starts itself as soon as it is constructed.
        self.start()

    def run(self):
        """Run queued ``(func, args, kwargs)`` tasks until the exit flag is set."""
        while not self.done.is_set():
            try:
                func, args, kwargs = self.tasks.get(block=True,
                                                    timeout=self._TIMEOUT)
            except Empty:
                # Nothing arrived within the timeout; re-check the exit flag.
                continue
            try:
                func(*args, **kwargs)
            except Exception as e:
                # A failing task is reported and the thread keeps running.
                print(e)
            finally:
                self.tasks.task_done()

    def signal_exit(self):
        """ Signal to thread to exit """
        self.done.set()
class ThreadPool:
    """Pool of threads consuming tasks from a queue"""

    def __init__(self, num_threads, tasks=None):
        """Start ``num_threads`` workers and enqueue any initial ``tasks``.

        ``tasks`` is an optional iterable of ``(func, args, kwargs)`` tuples,
        such as those built by create_task(). ``None`` means no initial tasks
        and avoids a mutable default argument.
        """
        self.tasks = Queue(num_threads)
        self.workers = []
        self.done = False
        self._init_workers(num_threads)
        # Workers are already running, so they drain the bounded queue while
        # the initial tasks are being added.
        for task in (tasks if tasks is not None else ()):
            self.tasks.put(task)

    def _init_workers(self, num_threads):
        """Create the worker threads; each one starts itself on construction."""
        for i in range(num_threads):
            self.workers.append(Worker(self.tasks, i))

    def add_task(self, func, *args, **kwargs):
        """Add a task to the queue"""
        self.tasks.put((func, args, kwargs))

    def _close_all_threads(self):
        """ Signal all threads to exit and lose the references to them """
        for workr in self.workers:
            workr.signal_exit()
        self.workers = []

    def wait_completion(self):
        """Wait for completion of all the tasks in the queue"""
        self.tasks.join()

    def __del__(self):
        self._close_all_threads()
def create_task(func, *args, **kwargs):
    """Bundle a callable and its arguments into a ``(func, args, kwargs)`` task tuple."""
    task = (func, args, kwargs)
    return task
To use the pool
from random import randrange
from time import sleep
# 30 random delays of 1 to 9 seconds each.
delays = [randrange(1, 10) for i in range(30)]
def wait_delay(d):
    print('sleeping for (%d)sec' % d)
    sleep(d)
# Pool of 20 threads; one task is queued per delay.
pool = ThreadPool(20)
for i, d in enumerate(delays):
    pool.add_task(wait_delay, d)
# Block until every queued task has been marked done.
pool.wait_completion()
Another way is to submit the work to a thread pool executor:
import concurrent.futures
# Sketch only: `cpus`, `arg1`, `arg2` and the trailing dots are placeholders.
with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
    for i in range(10):
        # submit() takes the callable first, followed by its arguments.
        a = executor.submit(arg1, arg2,....)
The overhead of creating the new processes is minimal, especially when it's just 4 of them. I doubt this is a performance hot spot of your application. Keep it simple, optimize where you have to and where profiling results point to.
There is no built in thread based pool. However, it can be very quick to implement a producer/consumer queue with the Queue class.
From:
https://docs.python.org/2/library/queue.html
from threading import Thread
from Queue import Queue
def worker():
    """Consume items from the module-level queue ``q`` forever."""
    while True:
        do_work(q.get())
        # Tell the queue this item is finished so that q.join() can return.
        q.task_done()
# Queue shared by the producer loop below and the worker threads.
q = Queue()
for i in range(num_worker_threads):
    t = Thread(target=worker)
    # Daemon threads do not keep the interpreter alive at exit.
    t.daemon = True
    t.start()
for item in source():
    q.put(item)
q.join() # block until all tasks are done
If you don't mind executing other's code, here's mine:
Note: There is a lot of extra code you may want to remove [added for better clarification and demonstration of how it works]
Note: Python naming conventions were used for method names and variable names instead of camelCase.
Working procedure:
MultiThread class will initiate with no of instances of threads by sharing lock, work queue, exit flag and results.
SingleThread will be started by MultiThread once it creates all instances.
We can add works using MultiThread (It will take care of locking).
SingleThreads will process work queue using a lock in middle.
Once your work is done, you can destroy all threads with shared boolean value.
Here, work can be anything. It can automatically import (uncomment import line) and process module using given arguments.
Results will be added to results and we can get using get_results
Code:
import threading
import queue
class SingleThread(threading.Thread):
    """Worker thread that executes Work items taken from a shared queue.

    Each item must expose ``module``, ``operation``, ``args`` and ``kwargs``.
    The thread calls ``getattr(module, operation)(*args, **kwargs)`` and
    appends the stringified result to the shared ``results`` list. It keeps
    polling until ``exit_flag`` becomes truthy.
    """

    _IDLE_SLEEP = 0.01  # seconds to back off when the queue is empty

    def __init__(self, name, work_queue, lock, exit_flag, results):
        threading.Thread.__init__(self)
        self.name = name
        self.work_queue = work_queue
        self.lock = lock
        self.exit_flag = exit_flag
        self.results = results

    def run(self):
        import time  # local import: only needed for the idle back-off

        while not self.exit_flag:
            # Hold the shared lock only while dequeuing; ``with`` releases it
            # even if the queue operation raises.
            with self.lock:
                try:
                    work = self.work_queue.get_nowait()
                except queue.Empty:
                    work = None
            if work is None:
                # Sleep briefly instead of spinning at full speed on the lock.
                time.sleep(self._IDLE_SLEEP)
                continue
            module, operation, args, kwargs = work.module, work.operation, work.args, work.kwargs
            print("Processing : " + operation + " with parameters " + str(args) + " and " + str(kwargs) + " by " + self.name + "\n")
            result = str(getattr(module, operation)(*args, **kwargs))
            print("Result : " + result + " for operation " + operation + " and input " + str(args) + " " + str(kwargs))
            self.results.append(result)
class MultiThread:
    """Owns a set of SingleThread workers sharing one queue, lock, exit flag and results list."""

    def __init__(self, no_of_threads):
        """Create and start ``no_of_threads`` workers named Thread1..ThreadN."""
        self.exit_flag = bool_instance()
        self.queue_lock = threading.Lock()
        self.threads = []
        self.work_queue = queue.Queue()
        self.results = []
        for index in range(no_of_threads):
            thread = SingleThread(
                "Thread" + str(index + 1),
                self.work_queue,
                self.queue_lock,
                self.exit_flag,
                self.results,
            )
            thread.start()
            self.threads.append(thread)

    def add_work(self, work):
        """Enqueue one Work item while holding the shared lock."""
        with self.queue_lock:
            # Use the public put(): the private _put() bypasses the queue's
            # own locking and bookkeeping.
            self.work_queue.put(work)

    def destroy(self):
        """Raise the shared exit flag and wait for every worker to finish."""
        self.exit_flag.value = True
        for thread in self.threads:
            thread.join()

    def get_results(self):
        """Return the shared list of stringified results."""
        return self.results
class Work:
    """One unit of work: call ``operation`` (a method name) on ``module`` with the given arguments."""

    def __init__(self, module, operation, args, kwargs=None):
        self.module = module
        self.operation = operation
        self.args = args
        # A fresh dict per instance: a shared ``{}`` default would leak
        # mutations from one Work item into the next.
        self.kwargs = {} if kwargs is None else kwargs
class SimpleOperations:
    """Small demo operations that can be scheduled by name."""

    def sum(self, *args):
        """Return the sum of all arguments after converting each one to int."""
        return sum(int(arg) for arg in args)

    # Must be a static method: it takes no ``self``, yet it is looked up on an
    # instance, which would otherwise bind the instance to ``a``.
    @staticmethod
    def mul(a, b, c=0):
        """Return ``int(a) * int(b) + int(c)``."""
        return int(a) * int(b) + int(c)
class bool_instance:
    """Mutable boolean holder; ``value`` is its only settable attribute."""

    def __init__(self, value=False):
        # Goes through __setattr__, so the initial value is validated as well.
        self.value = value

    def __setattr__(self, key, value):
        if key != "value":
            raise AttributeError("Only value can be set!")
        if not isinstance(value, bool):
            raise AttributeError("Only True/False can be set!")
        # Write straight into the instance dict to avoid re-entering __setattr__.
        vars(self)[key] = value

    def __bool__(self):
        return self.value
if __name__ == "__main__":
    # Start five worker threads and queue one initial multiplication.
    multi_thread = MultiThread(5)
    multi_thread.add_work(Work(SimpleOperations(), "mul", [2, 3], {"c":4}))
    # Read "<operation> <arg> ..." lines from stdin until "break" is entered.
    while True:
        data_input = input()
        if data_input == "":
            pass
        elif data_input == "break":
            break
        else:
            # The first word is the operation name; the rest are its arguments.
            work = data_input.split()
            multi_thread.add_work(Work(SimpleOperations(), work[0], work[1:], {}))
    # Stop the workers and print everything they produced.
    multi_thread.destroy()
    print(multi_thread.get_results())
I'm setting up a library to spawn as many threads as I can and have them complete any task they are assigned to. When unit testing the creation of the object containing 'n' threads and trying to iterate over each one at the object level, I am running into an infinite loop.
I fixed the pickling issue with race conditions, but now when trying to output how many threads were created in a for loop, when debugging I encounter an infinite loop.
from __future__ import print_function
try:
import sys
from threading import Thread, Lock, current_thread
from queue import Queue
except ImportError:
raise ImportError
class Worker(Thread):
    """ Thread executing tasks from a given tasks queue """

    # One lock shared by all Worker instances. Lock is a factory, so it must be
    # instantiated before acquire/release can be used.
    # NOTE: tasks therefore run one at a time across all workers.
    _task_lock = Lock()

    def __init__(self, tasks):
        Thread.__init__(self)
        self.tasks = tasks
        self.daemon = True
        # The thread starts itself as soon as it is constructed.
        self.start()

    def run(self):
        """Run queued ``(func, args, kargs)`` tasks forever.

        An exception raised by a task propagates and ends this thread, but
        the task is still marked done first so that ``Queue.join()`` callers
        are not left blocked.
        """
        while True:
            func, args, kargs = self.tasks.get()
            try:
                # Acquire the shared lock for the duration of the task; the
                # with-statement releases it even when the task raises.
                with self._task_lock:
                    func(*args, **kargs)
            finally:
                # Every get() must be matched by exactly one task_done().
                self.tasks.task_done()
class SpawnThreads:
    """Pool of threads consuming tasks from a queue."""
    def __init__(self, num_threads: int):
        # Queue whose capacity equals the number of threads.
        self.tasks = Queue(num_threads)
        self.num_threads = num_threads
        # Each Worker starts itself in its constructor; no reference is kept here.
        for _ in range(self.num_threads):
            Worker(self.tasks)
    def __iter__(self):
        # The pool acts as its own iterator.
        return self
    def __next__(self):
        # NOTE: next_value is a local that restarts at 0 on every call, so for a
        # positive num_threads this always returns num_threads and never raises
        # StopIteration.
        next_value = 0
        while next_value < self.num_threads:
            try:
                next_value += 1
            except Exception:
                raise StopIteration
        return next_value
    def add_task(self, func, *args, **kargs):
        """Add a task to the queue."""
        self.tasks.put((func, args, kargs))
    def task_list(self, func, args_list):
        """Add a list of tasks to the queue."""
        # Each element of args_list is passed to func as one positional argument.
        for args in args_list:
            self.add_task(func, args)
    def wait_completion(self):
        """Wait for completion of all the tasks in the queue."""
        self.tasks.join()
    def get_qsize(self):
        """Return the approximate size of the queue."""
        return self.tasks.qsize()
    def get_current_thread(self):
        """Return the thread object of the caller."""
        return current_thread()
This is the unittest to evaluate the creation of the thread spawning object and iterate and access each individual thread.
import pytest
import unittest
import SpawnThreads
#unittest
class TestThreadFactory(unittest.TestCase):
    def test_spawn_threads(self):
        """Create a pool of five threads and print each item it yields."""
        # NOTE(review): `import SpawnThreads` binds a module; calling it here
        # presumably needs `from <module> import SpawnThreads` - confirm.
        workforce = SpawnThreads(5)
        self.assertIsNotNone(workforce)
        print(workforce)
        for w in workforce:
            print(w)
The expected output should be the address space/object, which is each thread (5 total).
More Specifically I would like to see this result 5 times in the console:
<ThreadFactory.thread_factory.SpawnThreads object at 0x0000024E851F5208>
I am getting the integer 5 returned infinitely instead of the 5 addresses of threads.
Your __next__ is definitionally always going to return 5 and never end. __next__ isn't a generator function, there is no state on entry aside from what's on self. So you always loop until next_value (a stateless local variable) is equal to self.num_threads (a value which never changes) and return it; your __next__ could simplify to just return self.num_threads, with no chance of StopIteration ever being raised (thus the infinite loop).
If you want it to return different values (specifically, each of your workers), you'll need state for that:
class SpawnThreads:
    """Pool of threads consuming tasks from a queue."""
    def __init__(self, num_threads: int):
        self.tasks = Queue(num_threads)
        self.next_value = 0 # Initial next_value in instance state
        self.num_threads = num_threads
        # Store list of workers
        self.workers = [Worker(self.tasks) for _ in range(self.num_threads)]
    def __iter__(self):
        # The pool is its own iterator, so it can be iterated only once.
        return self
    def __next__(self):
        """Return the next stored worker, or raise StopIteration when exhausted."""
        # Check if iteration has finished
        if self.next_value >= self.num_threads:
            raise StopIteration
        retval = self.workers[self.next_value] # Save value to return to local
        self.next_value += 1 # Increment state for next use
        return retval # Return value
Those last three lines could be replaced with an alternative tricksy approach to avoid the local variable if you really care:
# The return value is evaluated first; the finally clause then advances the
# index before the function actually returns.
try:
    return self.workers[self.next_value]
finally:
    self.next_value += 1
Or even better, you could use Python built-ins to do the work for you:
class SpawnThreads:
    """Pool of threads consuming tasks from a queue."""
    def __init__(self, num_threads: int):
        self.tasks = Queue(num_threads)
        self.num_threads = num_threads
        self.workers = [Worker(self.tasks) for _ in range(self.num_threads)]
        self.next_worker_iter = iter(self.workers) # Iterates workers
    def __iter__(self):
        # The pool is its own iterator, so it can be iterated only once.
        return self
    def __next__(self):
        """Return the next worker from the stored list iterator."""
        # Let the list iterator do the work of maintaining state,
        # raising StopIteration, etc.
        return next(self.next_worker_iter)
This approach is simpler, faster, and as a bonus, thread-safe, at least on CPython (if two threads iterate the same SpawnThreads instance, each of the workers will be produced exactly once, rather than values potentially being skipped or repeated).
If the goal is to make an iterable (can be iterated multiple times), not an iterator (can be iterated once from beginning to end and never again), the simplest solution is to make __iter__ return an iterator itself, removing the need for __next__ entirely:
class SpawnThreads:
    """Pool of threads consuming tasks from a queue."""

    def __init__(self, num_threads: int):
        self.tasks = Queue(num_threads)
        self.num_threads = num_threads
        self.workers = [Worker(self.tasks) for _ in range(self.num_threads)]

    def __iter__(self):
        """Return a fresh iterator over the workers on every call."""
        # Makes this a generator function that produces each Worker once
        yield from self.workers
        # Alternatively, replace the line above with:
        #     return iter(self.workers)
        # though that exposes more implementation details than anonymous generators
The problem is this method:
def __next__(self):
    # next_value is a local: it restarts at 0 on every call and keeps no
    # state between calls.
    next_value = 0
    while next_value < self.num_threads:
        try:
            next_value += 1
        except Exception:
            raise StopIteration
    return next_value
It's basically the same as
def __next__(self):
    # Same result as the counting-loop version for any non-negative num_threads.
    return self.num_threads
As you see there isn't any iterator state, it will just return the same number forever. next_value += 1 will never throw an exception, next_value is just an integer.
To achieve what you want, just store the threads in a container and return an iterator to that container. Modify SpawnThreads:
def __init__(self, num_threads: int):
    """Create the task queue and keep a reference to every spawned Worker."""
    self.tasks = Queue(num_threads)
    self.num_threads = num_threads
    self.threads = []
    # Workers are appended one by one as they are created.
    self.threads.extend(Worker(self.tasks) for _ in range(self.num_threads))
def __iter__(self):
    # A new list iterator per call, so the pool can be iterated repeatedly.
    return iter(self.threads)
# remove the __next__() method
I like the default python multiprocessing.Pool, but it's still a pain that it isn't easy to show the current progress being made during the pool's execution. In lieu of that, I attempted to create my own custom multiprocess pool mapper, and it looks like this:
from multiprocessing import Process, Pool, cpu_count
from iterable_queue import IterableQueue
def _proc_action(f, in_queue, out_queue):
    """Apply ``f`` to every value from ``in_queue`` and put each result on ``out_queue``.

    KeyboardInterrupt and EOFError end the loop silently.
    """
    try:
        for item in in_queue:
            result = f(item)
            out_queue.put(result)
    except (KeyboardInterrupt, EOFError):
        return
def progress_pool_map(f, ls, n_procs=cpu_count()):
    """Yield ``f(elem)`` results computed by ``n_procs`` worker processes.

    Results are yielded in the order they come off the output queue. A
    running count is printed after every 1000 results. ``ls`` is iterated
    twice: once to enqueue the inputs and once to count the results.
    """
    in_queue = IterableQueue()
    out_queue = IterableQueue()
    err = None
    try:
        procs = [
            Process(target=_proc_action, args=(f, in_queue, out_queue))
            for _ in range(n_procs)
        ]
        for proc in procs:
            proc.start()
        for elem in ls:
            in_queue.put(elem)
        in_queue.close()
        for done, _ in enumerate(ls, start=1):
            elem = next(out_queue)
            if done % 1000 == 0:
                print(done)
            yield elem
        out_queue.close()
    except (KeyboardInterrupt, EOFError) as e:
        in_queue.close()
        out_queue.close()
        print("Joining processes")
        for proc in procs:
            proc.join()
        print("Closing processes")
        for proc in procs:
            proc.close()
        # Keep the exception: the name ``e`` is unbound once this block ends.
        err = e
    if err:
        raise err
It works fairly well, and prints a value to the console for every 1000 items processed. The progress display itself is something I can worry about in future. Right now, however, my issue is that when cancelled, the operation does anything but fail gracefully. When I try to interrupt the map, it hangs on Joining Processes, and never makes it to Closing Processes. If I try hitting Ctrl+C again, it causes an infinite spew of BrokenPipeErrors to fill the console until I send an EOF and stop my program.
Here's iterable_queue.py, for reference;
from multiprocessing.queues import Queue
from multiprocessing import get_context, Value
import queue
class QueueClosed(Exception):
    """Raised by IterableQueue when it is used after being closed."""
    pass
class IterableQueue(Queue):
    """Multiprocessing queue that can be closed and iterated until it is closed.

    Items travel as ``(value, is_open)`` pairs; ``(None, False)`` is the
    closing sentinel. A consumer that reads the sentinel puts it back so that
    other consumers see it too.
    """

    def __init__(self, maxsize=0, *, ctx=None):
        super().__init__(
            maxsize=maxsize,
            ctx=ctx if ctx is not None else get_context()
        )
        # Closed flag held in a multiprocessing Value.
        # NOTE(review): confirm this attribute reaches child processes under
        # the "spawn" start method.
        self.closed = Value('b', False)

    def close(self):
        """Mark the queue closed, enqueue the sentinel and close the underlying queue."""
        with self.closed.get_lock():
            if not self.closed.value:
                self.closed.value = True
                super().put((None, False))
                # throws BrokenPipeError in another thread without this sleep in between
                # terrible hack, must fix at some point
                import time
                time.sleep(0.01)
                super().close()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return self.get()
        except QueueClosed:
            raise StopIteration

    def get(self, *args, **kwargs):
        """Return the next value; raise QueueClosed once the queue is closed."""
        try:
            result, is_open = super().get(*args, **kwargs)
        except OSError:
            raise QueueClosed
        if not is_open:
            # Put the sentinel back for any other consumer.
            super().put((None, False))
            raise QueueClosed
        return result

    def __bool__(self):
        # Truthy once the queue has been closed.
        return bool(self.closed.value)

    def put(self, val, *args, **kwargs):
        """Enqueue ``val``; raise QueueClosed if the queue is already closed."""
        with self.closed.get_lock():
            if self.closed.value:
                raise QueueClosed
            super().put((val, True), *args, **kwargs)

    def get_nowait(self):
        return self.get(block=False)

    def put_nowait(self, val):
        """Enqueue ``val`` without blocking (previously took no value and could not work)."""
        return self.put(val, block=False)

    def empty_remaining(self, block=False):
        """Yield queued values until the queue is empty or closed."""
        try:
            while True:
                yield self.get(block=block)
        except (queue.Empty, QueueClosed):
            pass

    def clear(self):
        """Discard every value currently retrievable from the queue."""
        for _ in self.empty_remaining():
            pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
I was under the impression that python passed objects by reference, so logically (I thought), passing a dictionary to a thread's worker function would allow me to set new key-value pairs that would be preserved after the worker function had returned. Unfortunately, it seems as though I am wrong!
Without further ado, here's a test case:
from Queue import Queue
from threading import Thread
def work_fn(dictionary, dfield):
    """Set ``dictionary[dfield]`` to True, mutating the caller's object in place."""
    dictionary[dfield] = True
class Worker(Thread):
    """Thread executing tasks from a given tasks queue"""

    def __init__(self, tasks):
        Thread.__init__(self)
        self.tasks = tasks
        self.daemon = True
        # The thread starts itself as soon as it is constructed.
        self.start()

    def run(self):
        pending = self.tasks
        while True:
            func, args, kargs = pending.get()
            try:
                func(*args, **kargs)
            except Exception as e:
                # A failing task is reported and the thread keeps running.
                print(e)
            pending.task_done()
class ThreadPool(object):
    """Pool of threads consuming tasks from a queue"""

    def __init__(self, num_threads):
        # The queue capacity equals the number of threads.
        task_queue = Queue(num_threads)
        self.tasks = task_queue
        for _ in range(num_threads):
            Worker(task_queue)

    def add_task(self, func, *args, **kargs):
        """Add a task to the queue"""
        task = (func, args, kargs)
        self.tasks.put(task)

    def wait_completion(self):
        """Wait for completion of all the tasks in the queue"""
        self.tasks.join()
if __name__ == '__main__':
    pool = ThreadPool(4)
    # Ten independent dicts; task i sets key str(i) in the i-th dict.
    data = [{} for _ in range(10)]
    for i, d in enumerate(data):
        pool.add_task(work_fn, d, str(i))
    # Block until every queued task has been marked done.
    pool.wait_completion()
    for d in data:
        print d.keys()
What is actually going on here?
What should I be doing differently?
I don't see where you are waiting for the tasks to finish before printing the results. It seems like you need to do the collection of results in a subsequent loop after spawning the jobs.
Python's futures package allows us to enjoy ThreadPoolExecutor and ProcessPoolExecutor for doing tasks in parallel.
However, for debugging it is sometimes useful to temporarily replace the true parallelism with a dummy one, which carries out the tasks in a serial way in the main thread, without spawning any threads or processes.
Is there anywhere an implementation of a DummyExecutor?
Something like this should do it:
from concurrent.futures import Future, Executor
from threading import Lock
class DummyExecutor(Executor):
    """Executor that runs each submitted call synchronously in the calling thread.

    submit() returns an already-completed Future holding either the call's
    result or the exception it raised.
    """

    def __init__(self):
        self._shutdown = False
        self._shutdownLock = Lock()

    def submit(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` immediately and return a finished Future.

        Raises:
            RuntimeError: if the executor has already been shut down.
        """
        # Hold the lock only for the shutdown check. Running fn under the
        # non-reentrant lock would deadlock a task that submits to this
        # executor itself.
        with self._shutdownLock:
            if self._shutdown:
                raise RuntimeError('cannot schedule new futures after shutdown')
        f = Future()
        try:
            result = fn(*args, **kwargs)
        except BaseException as e:
            f.set_exception(e)
        else:
            f.set_result(result)
        return f

    def shutdown(self, wait=True, *, cancel_futures=False):
        """Refuse further submissions.

        ``wait`` and ``cancel_futures`` are accepted for signature
        compatibility with Executor.shutdown and have no effect, because no
        work is ever pending.
        """
        with self._shutdownLock:
            self._shutdown = True
if __name__ == '__main__':
    def fnc(err):
        # Raise when err is truthy, otherwise return "ok".
        if err:
            raise Exception("test")
        else:
            return "ok"
    ex = DummyExecutor()
    # Both calls return a Future: the first holds the exception, the second "ok".
    print(ex.submit(fnc, True))
    print(ex.submit(fnc, False))
    ex.shutdown()
    ex.submit(fnc, True) # raises exception
locking is probably not needed in this case, but can't hurt to have it.
Use this to mock your ThreadPoolExecutor
class MockThreadPoolExecutor:
    """Stand-in for ThreadPoolExecutor that runs every call inline.

    submit() returns the call's result directly, not a Future.
    """

    def __init__(self, **kwargs):
        # Keyword arguments such as max_workers are accepted and ignored.
        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Nothing to clean up; exceptions from the with-body propagate.
        return None

    def submit(self, fn, *args, **kwargs):
        # execute functions in series without creating threads
        # for easier unit testing
        return fn(*args, **kwargs)

    def shutdown(self, wait=True):
        return None
if __name__ == "__main__":
    # NOTE: this local helper shadows the built-in sum() within this module.
    def sum(a, b):
        return a + b
    with MockThreadPoolExecutor(max_workers=3) as executor:
        # The mock's submit() returns plain results, so this list holds
        # values rather than Future objects.
        future_result = list()
        for i in range(5):
            future_result.append(executor.submit(sum, i + 1, i + 2))