Joshua Lochner committed
Commit • 07690ba
1 Parent(s): 63f1925
Improve parallel execution of functions
src/utils.py +153 -65
src/utils.py
CHANGED
@@ -1,92 +1,180 @@
 import re
-
+
 import os
+import signal
+import logging
+import sys
+from time import sleep, time
+from random import random, randint
+from multiprocessing import JoinableQueue, Event, Process
+from queue import Empty
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
 
+def re_findall(pattern, string):
+    return [m.groupdict() for m in re.finditer(pattern, string)]
 
+
+class Task:
     def __init__(self, function, *args, **kwargs) -> None:
         self.function = function
         self.args = args
         self.kwargs = kwargs
 
+    def run(self):
+        return self.function(*self.args, **self.kwargs)
+
+
 
+class CallbackGenerator:
+    def __init__(self, generator, callback):
+        self.generator = generator
+        self.callback = callback
+
+    def __iter__(self):
+        if self.callback is not None and callable(self.callback):
+            for t in self.generator:
+                self.callback(t)
+                yield t
+        else:
+            yield from self.generator
+
+
+
+def start_worker(q: JoinableQueue, stop_event: Event):  # TODO make class?
+    logger.info('Starting worker...')
+    while True:
+        if stop_event.is_set():
+            logger.info('Worker exiting because of stop_event')
+            break
+        # We set a timeout so we loop past 'stop_event' even if the queue is empty
+        try:
+            task = q.get(timeout=.01)
+        except Empty:
+            # Run next iteration of loop
+            continue
+
+        # Exit if end of queue
+        if task is None:
+            logger.info('Worker exiting because of None on queue')
+            q.task_done()
+            break
+
+        try:
+            task.run()  # Do the task
+        except:  # Will also catch KeyboardInterrupt
+            logger.exception(f'Failed to process task {task}')
+            # Can implement some kind of retry handling here
+        finally:
+            q.task_done()
 
-class
+class InterruptibleTaskPool:
+
+    # https://the-fonz.gitlab.io/posts/python-multiprocessing/
     def __init__(self,
+                 tasks=None,
                  num_workers=None,
+
+                 callback=None,  # Fired on start
+                 max_queue_size=1,
+                 grace_period=2,
+                 kill_period=30,
+                 ):
+
+        self.tasks = CallbackGenerator(
+            [] if tasks is None else tasks, callback)
         self.num_workers = os.cpu_count() if num_workers is None else num_workers
-        self.loop = asyncio.get_event_loop() if loop is None else loop
-        self.shutdown_message = shutdown_message
 
-        self.
+        self.max_queue_size = max_queue_size
+        self.grace_period = grace_period
+        self.kill_period = kill_period
 
+        # The JoinableQueue has an internal counter that increments when an item is put on the queue and
+        # decrements when q.task_done() is called. This allows us to wait until it's empty using .join()
+        self.queue = JoinableQueue(maxsize=self.max_queue_size)
+        # This is a process-safe version of the 'panic' variable shown above
+        self.stop_event = Event()
 
+        # n_workers: Start this many processes
+        # max_queue_size: If queue exceeds this size, block when putting items on the queue
+        # grace_period: Send SIGINT to processes if they don't exit within this time after SIGINT/SIGTERM
+        # kill_period: Send SIGKILL to processes if they don't exit after this many seconds
 
-        self.on_job_complete = on_job_complete
-        self.raise_after_interrupt = raise_after_interrupt
+        # self.on_task_complete = on_task_complete
+        # self.raise_after_interrupt = raise_after_interrupt
 
+    def __enter__(self):
+        self.start()
+        return self
 
-        return job
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        pass
 
-    def
-        self.
-
-    def
-        try:
-            tasks = [
-                # creating task starts coroutine
-                asyncio.ensure_future(self._sync_to_async(job))
-                for job in self.jobs
-            ]
-
-            # https://stackoverflow.com/a/42097478
-            self.loop.run_until_complete(
-                asyncio.gather(*tasks, return_exceptions=True)
-            )
-
-        except KeyboardInterrupt:
-            # Optionally show a message if the shutdown may take a while
-            print(self.shutdown_message, flush=True)
-
-            # Do not show `asyncio.CancelledError` exceptions during shutdown
-            # (a lot of these may be generated, skip this if you prefer to see them)
-            def shutdown_exception_handler(loop, context):
-                if "exception" not in context \
-                        or not isinstance(context["exception"], asyncio.CancelledError):
-                    loop.default_exception_handler(context)
-            self.loop.set_exception_handler(shutdown_exception_handler)
-
-            # Handle shutdown gracefully by waiting for all tasks to be cancelled
-            cancelled_tasks = asyncio.gather(
-                *asyncio.all_tasks(loop=self.loop), loop=self.loop, return_exceptions=True)
-            cancelled_tasks.add_done_callback(lambda t: self.loop.stop())
-            cancelled_tasks.cancel()
-
-            # Keep the event loop running until it is either destroyed or all
-            # tasks have really terminated
-            while not cancelled_tasks.done() and not self.loop.is_closed():
-                self.loop.run_forever()
-
-            if self.raise_after_interrupt:
-                raise
-        finally:
-            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
-            self.loop.close()
-
+    def start(self) -> None:
+        def handler(signalname):
+            """
+            Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
+            Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
+            """
+            def f(signal_received, frame):
+                raise KeyboardInterrupt(f'{signalname} received')
+            return f
+
+        # This will be inherited by the child process if it is forked (not spawned)
+        signal.signal(signal.SIGINT, handler('SIGINT'))
+        signal.signal(signal.SIGTERM, handler('SIGTERM'))
+
+        procs = []
+
+        for i in range(self.num_workers):
+            # Make it a daemon process so it is definitely terminated when this process exits,
+            # might be overkill but is a nice feature. See
+            # https://docs.python.org/3.8/library/multiprocessing.html#multiprocessing.Process.daemon
+            p = Process(name=f'Worker-{i:02d}', daemon=True,
+                        target=start_worker, args=(self.queue, self.stop_event))
+            procs.append(p)
+            p.start()
+
+        try:
+            # Put tasks on queue
+            for task in self.tasks:
+                logger.info(f'Put task {task} on queue')
+                self.queue.put(task)
+
+            # Put exit tasks on queue
+            for i in range(self.num_workers):
+                self.queue.put(None)
+
+            # Wait until all tasks are processed
+            self.queue.join()
+
+        except KeyboardInterrupt:
+            logger.warning('Caught KeyboardInterrupt! Setting stop event...')
+            # raise  # TODO add option
+        finally:
+            self.stop_event.set()
+            t = time()
+            # Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
+            # .is_alive() also implicitly joins the process (good practice in linux)
+            while alive_procs := [p for p in procs if p.is_alive()]:
+                if time() > t + self.grace_period:
+                    for p in alive_procs:
+                        os.kill(p.pid, signal.SIGINT)
+                        logger.warning(f'Sending SIGINT to {p}')
+                elif time() > t + self.kill_period:
+                    for p in alive_procs:
+                        logger.warning(f'Sending SIGKILL to {p}')
+                        # Queues and other inter-process communication primitives can break when
+                        # process is killed, but we don't care here
+                        p.kill()
+                sleep(.01)
+
+        sleep(.1)
+        for p in procs:
+            logger.info(f'Process status: {p}')