cbensimon HF staff commited on
Commit
f0582f1
·
1 Parent(s): 8d11afc
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
spaces/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ import sys
5
+
6
+
7
+ if sys.version_info.minor < 8: # pragma: no cover
8
+ raise RuntimeError("Importing PySpaces requires Python 3.8+")
9
+
10
+
11
+ # Prevent gradio from importing spaces
12
+ if (gr := sys.modules.get('gradio')) is not None: # pragma: no cover
13
+ try:
14
+ gr.Blocks
15
+ except AttributeError:
16
+ raise ImportError
17
+
18
+
19
+ from .zero.decorator import GPU
20
+ from .gradio import gradio_auto_wrap
21
+ from .gradio import disable_gradio_auto_wrap
22
+ from .gradio import enable_gradio_auto_wrap
23
+
24
+
25
+ __all__ = [
26
+ 'GPU',
27
+ 'gradio_auto_wrap',
28
+ 'disable_gradio_auto_wrap',
29
+ 'enable_gradio_auto_wrap',
30
+ ]
spaces/config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from pathlib import Path
7
+
8
+ from .utils import boolean
9
+
10
+
11
+ ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')
12
+
13
+
14
+ class Settings:
15
+ def __init__(self):
16
+ self.zero_gpu = boolean(
17
+ os.getenv('SPACES_ZERO_GPU'))
18
+ self.zero_device_api_url = (
19
+ os.getenv('SPACES_ZERO_DEVICE_API_URL'))
20
+ self.gradio_auto_wrap = boolean(
21
+ os.getenv('SPACES_GRADIO_AUTO_WRAP'))
22
+ self.zero_patch_torch_device = boolean(
23
+ os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
24
+ self.zero_gpu_v2 = boolean(
25
+ os.getenv('ZEROGPU_V2'))
26
+ self.zerogpu_offload_dir = (
27
+ os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT))
28
+
29
+
30
+ Config = Settings()
31
+
32
+
33
+ if Config.zero_gpu:
34
+ assert Config.zero_device_api_url is not None, (
35
+ 'SPACES_ZERO_DEVICE_API_URL env must be set '
36
+ 'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
37
+ )
spaces/gradio.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ from typing import Callable
6
+ from typing import Generator
7
+ from typing import TypeVar
8
+ from typing import overload
9
+ from typing_extensions import ParamSpec
10
+
11
+ from .config import Config
12
+ from .zero.decorator import GPU
13
+
14
+
15
+ Param = ParamSpec('Param')
16
+ Res = TypeVar('Res')
17
+
18
+
19
+ gradio_auto_wrap_enabled = Config.gradio_auto_wrap
20
+
21
+
22
+ def disable_gradio_auto_wrap():
23
+ global gradio_auto_wrap_enabled
24
+ gradio_auto_wrap_enabled = False
25
+
26
+ def enable_gradio_auto_wrap():
27
+ global gradio_auto_wrap_enabled
28
+ gradio_auto_wrap_enabled = True
29
+
30
+
31
+ @overload
32
+ def gradio_auto_wrap(
33
+ task:
34
+ Callable[Param, Res],
35
+ ) -> Callable[Param, Res]:
36
+ ...
37
+ @overload
38
+ def gradio_auto_wrap(
39
+ task:
40
+ None,
41
+ ) -> None:
42
+ ...
43
+ def gradio_auto_wrap(
44
+ task:
45
+ Callable[Param, Res]
46
+ | None,
47
+ ) -> (Callable[Param, Res]
48
+ | None):
49
+ """
50
+ """
51
+ if not gradio_auto_wrap_enabled:
52
+ return task
53
+ if not callable(task):
54
+ return task
55
+ return GPU(task) # type: ignore
spaces/utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import ctypes
6
+ import sys
7
+ from functools import lru_cache as cache
8
+ from functools import partial
9
+
10
+ import multiprocessing
11
+ from multiprocessing.queues import SimpleQueue as _SimpleQueue
12
+ from pathlib import Path
13
+ from pickle import PicklingError
14
+ from typing import Callable
15
+ from typing import TypeVar
16
+
17
+
18
+ GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"
19
+
20
+
21
+ T = TypeVar('T')
22
+
23
+
24
+ @cache
25
+ def self_cgroup_device_path() -> str:
26
+ cgroup_content = Path('/proc/self/cgroup').read_text()
27
+ for line in cgroup_content.strip().split('\n'):
28
+ contents = line.split(':devices:')
29
+ if len(contents) != 2:
30
+ continue # pragma: no cover
31
+ return contents[1]
32
+ raise Exception # pragma: no cover
33
+
34
+
35
+ if sys.version_info.minor < 9: # pragma: no cover
36
+ _SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore
37
+
38
+ class SimpleQueue(_SimpleQueue[T]):
39
+ def __init__(self, *args):
40
+ super().__init__(*args, ctx=multiprocessing.get_context('fork'))
41
+ def put(self, obj: T):
42
+ try:
43
+ super().put(obj)
44
+ except PicklingError:
45
+ raise # pragma: no cover
46
+ # https://bugs.python.org/issue29187
47
+ except Exception as e:
48
+ message = str(e)
49
+ if not "pickle" in message:
50
+ raise # pragma: no cover
51
+ raise PicklingError(message)
52
+ def close(self): # Python 3.8 static typing trick
53
+ super().close() # type: ignore
54
+ def wlock_release(self):
55
+ if (lock := getattr(self, '_wlock', None)) is None:
56
+ return # pragma: no cover
57
+ try:
58
+ lock.release()
59
+ except ValueError:
60
+ pass
61
+
62
+
63
+ def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
64
+ def drop(*args):
65
+ return fn()
66
+ return drop
67
+
68
+
69
+ def boolean(value: str | None) -> bool:
70
+ return value is not None and value.lower() in ("1", "t", "true")
71
+
72
+
73
+ def gradio_request_var():
74
+ try:
75
+ from gradio.context import LocalContext
76
+ except ImportError: # pragma: no cover
77
+ raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
78
+ return LocalContext.request
79
+
80
+
81
+ def malloc_trim():
82
+ ctypes.CDLL("libc.so.6").malloc_trim(0)
83
+
84
+
85
+ debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
spaces/zero/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from pathlib import Path
5
+
6
+ from ..config import Config
7
+
8
+
9
+ if Config.zero_gpu:
10
+
11
+ from . import gradio
12
+ from . import torch
13
+
14
+ if torch.is_in_bad_fork():
15
+ raise RuntimeError(
16
+ "CUDA has been initialized before importing the `spaces` package"
17
+ )
18
+
19
+ torch.patch()
20
+ gradio.one_launch(torch.pack)
21
+ Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)
spaces/zero/api.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Synced with huggingface/pyspaces:spaces/zero/api.py
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from datetime import timedelta
7
+ from typing import Any
8
+ from typing import Generator
9
+ from typing import Literal
10
+ from typing import NamedTuple
11
+ from typing import Optional
12
+ from typing import overload
13
+
14
+ import httpx
15
+ from pydantic import BaseModel
16
+ from typing_extensions import assert_never
17
+
18
+
19
+ AllowToken = str
20
+ NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
21
+ NvidiaUUID = str
22
+ CGroupPath = str
23
+ VisitorId = str
24
+ Score = float
25
+
26
+ AuthLevel = Literal['regular', 'pro']
27
+
28
+
29
+ AUTHENTICATED_HEADER = 'X-Authenticated'
30
+
31
+
32
+ class ScheduleResponse(BaseModel):
33
+ idle: bool
34
+ nvidiaIndex: int
35
+ nvidiaUUID: str
36
+ allowToken: str
37
+
38
+
39
+ class QuotaInfos(BaseModel):
40
+ left: int
41
+ wait: timedelta
42
+
43
+
44
+ class ReportUsageMonitoringParams(NamedTuple):
45
+ nvidia_index: int
46
+ visitor_id: str
47
+ duration: timedelta
48
+
49
+
50
+ class QueueEvent(BaseModel):
51
+ event: Literal['ping', 'failed', 'succeeded']
52
+ data: Optional[ScheduleResponse] = None
53
+
54
+
55
+ def sse_parse(text: str):
56
+ event, *data = text.strip().splitlines()
57
+ assert event.startswith('event:')
58
+ event = event[6:].strip()
59
+ if event in ('ping', 'failed'):
60
+ return QueueEvent(event=event)
61
+ assert event == 'succeeded'
62
+ (data,) = data
63
+ assert data.startswith('data:')
64
+ data = data[5:].strip()
65
+ return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
66
+
67
+
68
+ def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
69
+ for text in res.iter_text():
70
+ if len(text) == 0:
71
+ break # pragma: no cover
72
+ try:
73
+ yield sse_parse(text)
74
+ except GeneratorExit:
75
+ res.close()
76
+ break
77
+
78
+
79
+ class APIClient:
80
+
81
+ def __init__(self, client: httpx.Client):
82
+ self.client = client
83
+
84
+ def startup_report(self) -> httpx.codes:
85
+ res = self.client.post('/startup-report')
86
+ return httpx.codes(res.status_code)
87
+
88
+ def schedule(
89
+ self,
90
+ cgroup_path: str,
91
+ task_id: int = 0,
92
+ token: str | None = None,
93
+ duration_seconds: int | None = None,
94
+ enable_queue: bool = True,
95
+ ):
96
+ params: dict[str, str | int | bool] = {
97
+ 'cgroupPath': cgroup_path,
98
+ 'taskId': task_id,
99
+ 'enableQueue': enable_queue,
100
+ }
101
+ if duration_seconds is not None:
102
+ params['durationSeconds'] = duration_seconds
103
+ if token is not None:
104
+ params['token'] = token
105
+ res = self.client.send(
106
+ request=self.client.build_request(
107
+ method='POST',
108
+ url='/schedule',
109
+ params=params,
110
+ ),
111
+ stream=True,
112
+ )
113
+ status = httpx.codes(res.status_code)
114
+ auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
115
+ if (status is not httpx.codes.OK and
116
+ status is not httpx.codes.TOO_MANY_REQUESTS
117
+ ):
118
+ res.close()
119
+ return status, auth
120
+ if "text/event-stream" in res.headers['content-type']:
121
+ return sse_stream(res), auth
122
+ res.read()
123
+ if status is httpx.codes.TOO_MANY_REQUESTS:
124
+ return QuotaInfos(**res.json()), auth # pragma: no cover
125
+ if status is httpx.codes.OK:
126
+ return ScheduleResponse(**res.json()), auth
127
+ assert_never(status)
128
+
129
+ def allow(
130
+ self,
131
+ allow_token: str,
132
+ pid: int,
133
+ ):
134
+ res = self.client.post('/allow', params={
135
+ 'allowToken': allow_token,
136
+ 'pid': pid,
137
+ })
138
+ return httpx.codes(res.status_code)
139
+
140
+ def release(
141
+ self,
142
+ allow_token: str,
143
+ fail: bool = False,
144
+ ) -> httpx.codes:
145
+ res = self.client.post('/release', params={
146
+ 'allowToken': allow_token,
147
+ 'fail': fail,
148
+ })
149
+ return httpx.codes(res.status_code)
150
+
151
+ def get_queue_size(self) -> int:
152
+ res = self.client.get('/queue-size')
153
+ assert res.status_code == 200, res.status_code
154
+ size = res.json()
155
+ assert isinstance(size, int)
156
+ return size
spaces/zero/client.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import time
7
+ import warnings
8
+ from datetime import timedelta
9
+
10
+ import gradio as gr
11
+ import httpx
12
+ from packaging import version
13
+ from typing_extensions import assert_never
14
+
15
+ from .. import utils
16
+ from ..config import Config
17
+ from .api import APIClient
18
+ from .api import AuthLevel
19
+ from .api import QuotaInfos
20
+ from .api import ScheduleResponse
21
+ from .gradio import HTMLError
22
+ from .gradio import get_event
23
+ from .gradio import supports_auth
24
+
25
+
26
+ TOKEN_HEADER = 'X-IP-Token'
27
+ DEFAULT_SCHEDULE_DURATION = 60
28
+
29
+ QUOTA_MESSAGE = "You have exceeded your GPU quota"
30
+ UNUSED_MESSAGE = "GPU device not used"
31
+ NO_GPU_MESSAGE_REGULAR = "No GPU was available"
32
+ NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60s"
33
+
34
+ SIGNUP_ON_HF_TXT = "Create a free account"
35
+ SIGNUP_ON_HF_URL = "https://huggingface.co/join"
36
+ SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro"
37
+ SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription"
38
+
39
+
40
+ def api_client():
41
+ assert Config.zero_device_api_url is not None
42
+ httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
43
+ return APIClient(httpx_client)
44
+
45
+
46
+ def startup_report():
47
+ retries, max_retries = 0, 2
48
+ client = api_client()
49
+ while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
50
+ time.sleep(1)
51
+ if (retries := retries + 1) > max_retries:
52
+ raise RuntimeError("Error while initializing ZeroGPU: NotFound")
53
+ if status is not httpx.codes.OK: # pragma: no cover
54
+ raise RuntimeError("Error while initializing ZeroGPU: Unknown")
55
+
56
+
57
+ def html_string(html_contents: str, text_contents: str): # pragma: no cover
58
+ class HTMLString(str):
59
+ def __str__(self):
60
+ return text_contents
61
+ return HTMLString(html_contents)
62
+
63
+
64
+ def _toast_action(
65
+ auth: AuthLevel | None,
66
+ supports_html: bool,
67
+ pro_message: str,
68
+ unlogged_desc: str,
69
+ logged_desc: str,
70
+ ending: str,
71
+ ) -> tuple[str, str]: # pragma: no cover
72
+ if not supports_auth() or auth == 'pro':
73
+ return pro_message, pro_message
74
+ html = ""
75
+ link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL
76
+ text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT
77
+ desc = unlogged_desc if auth is None else logged_desc
78
+ desc += f" {ending}."
79
+ style = ";".join([
80
+ "white-space: nowrap",
81
+ "text-underline-offset: 2px",
82
+ "color: var(--body-text-color)",
83
+ ])
84
+ if supports_html:
85
+ html += f'<a style="{style}" href="{link}">'
86
+ html += text
87
+ if supports_html:
88
+ html += '</a>'
89
+ html += f" {desc}"
90
+ markdown = f'[{text}]({link}) {desc}'
91
+ return html, markdown
92
+
93
+
94
+ def schedule(
95
+ task_id: int,
96
+ request: gr.Request | None = None,
97
+ duration: timedelta | None = None,
98
+ _first_attempt: bool = True,
99
+ ) -> ScheduleResponse:
100
+
101
+ if not (gradio_version := version.parse(gr.__version__)).major >= 4: # pragma: no cover
102
+ raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")
103
+
104
+ GRADIO_HTML_TOASTS = gradio_version >= version.Version('4.39')
105
+
106
+ res, auth = api_client().schedule(
107
+ cgroup_path=utils.self_cgroup_device_path(),
108
+ task_id=task_id,
109
+ token=_get_token(request),
110
+ duration_seconds=duration.seconds if duration is not None else None,
111
+ )
112
+
113
+ if isinstance(res, ScheduleResponse):
114
+ return res
115
+
116
+ if isinstance(res, QuotaInfos): # pragma: no cover
117
+ requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
118
+ if res.wait < timedelta(0):
119
+ raise gr.Error(
120
+ f"The requested GPU duration ({requested}s) "
121
+ f"is larger than the maximum allowed"
122
+ )
123
+ else:
124
+ gpu = "Pro GPU" if auth == 'pro' else ("free GPU" if auth == 'regular' else "GPU")
125
+ message = (
126
+ f"You have exceeded your {gpu} quota "
127
+ f"({requested}s requested vs. {res.left}s left)."
128
+ )
129
+ details_html, details_markdown = _toast_action(
130
+ auth=auth,
131
+ supports_html=GRADIO_HTML_TOASTS,
132
+ pro_message=f"Try again in {res.wait}",
133
+ unlogged_desc="to get more",
134
+ logged_desc="to get 5x more",
135
+ ending="usage quota",
136
+ )
137
+ message_html = f"{message} {details_html}"
138
+ message_text = f"{message} {details_markdown}"
139
+ raise HTMLError(html_string(message_html, message_text))
140
+
141
+ if not isinstance(res, httpx.codes): # pragma: no cover
142
+ gr.Info("Waiting for a GPU to become available")
143
+ # TODO: Sign-up message if not authenticated (after some time ?)
144
+ connection_event = get_event()
145
+ if connection_event is None and request is not None:
146
+ warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
147
+ while True:
148
+ try:
149
+ event = next(res)
150
+ except StopIteration:
151
+ raise RuntimeError("Unexpected end of stream")
152
+ except httpx.RemoteProtocolError:
153
+ if not _first_attempt:
154
+ raise RuntimeError("Error while re-trying after queue disconnect")
155
+ return schedule(task_id, request, duration, _first_attempt=False)
156
+ if event.event == 'ping':
157
+ if connection_event is not None and not connection_event.alive:
158
+ res.close()
159
+ raise RuntimeError("Connection closed by visitor while queueing")
160
+ continue
161
+ if event.event == 'failed':
162
+ details_html, details_markdown = _toast_action(
163
+ auth=auth,
164
+ supports_html=GRADIO_HTML_TOASTS,
165
+ pro_message="Retry later",
166
+ unlogged_desc="to get a higher",
167
+ logged_desc="to get the highest",
168
+ ending="priority in ZeroGPU queues",
169
+ )
170
+ message_html = f"{NO_GPU_MESSAGE_INQUEUE}. {details_html}"
171
+ message_text = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}"
172
+ raise HTMLError(html_string(message_html, message_text))
173
+ if event.event == 'succeeded':
174
+ assert event.data is not None
175
+ if connection_event is not None and not connection_event.alive:
176
+ release(event.data.allowToken)
177
+ raise RuntimeError("Connection closed by visitor on queue success")
178
+ gr.Info("Successfully acquired a GPU")
179
+ return event.data
180
+
181
+ if res is httpx.codes.SERVICE_UNAVAILABLE:
182
+ raise gr.Error(NO_GPU_MESSAGE_REGULAR)
183
+
184
+ # TODO: Find a way to log 'detail' response field
185
+ raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
186
+
187
+
188
+ def allow(allow_token: str) -> None:
189
+ pid = os.getpid()
190
+ assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
191
+ assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK
192
+
193
+
194
+ def release(
195
+ allow_token: str, *,
196
+ fail: bool = False,
197
+ allow_404: bool = False,
198
+ ) -> None:
199
+
200
+ res = api_client().release(
201
+ allow_token=allow_token,
202
+ fail=fail,
203
+ )
204
+
205
+ if res is httpx.codes.NO_CONTENT: # pragma: no cover
206
+ try:
207
+ gr.Warning(UNUSED_MESSAGE)
208
+ except AttributeError:
209
+ pass
210
+ warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
211
+ return None
212
+
213
+ if res is httpx.codes.NOT_FOUND:
214
+ if not allow_404:
215
+ warnings.warn("ZeroGPU API /release warning: 404 Not Found")
216
+ return None
217
+
218
+ if httpx.codes.is_success(res):
219
+ return None
220
+
221
+ # TODO: Find a way to log 'detail' response field
222
+ # TODO: Only raise in dev environment. Simply warn in production ?
223
+ raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
224
+
225
+
226
+ def _get_token(request: gr.Request | None) -> str | None:
227
+
228
+ if request is None:
229
+ return None
230
+
231
+ headers = getattr(request, 'headers', None)
232
+ if headers is None or not hasattr(headers, '__dict__'):
233
+ raise gr.Error("Internal Gradio error")
234
+
235
+ # Compatibility trick
236
+ if not hasattr(headers, 'get'):
237
+ headers = headers.__dict__ # pragma: no cover
238
+
239
+ return headers.get(TOKEN_HEADER.lower())
spaces/zero/decorator.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ import sys
7
+ import warnings
8
+ from datetime import timedelta
9
+ from functools import partial
10
+ from typing import Callable
11
+ from typing import TypeVar
12
+ from typing import overload
13
+ from typing_extensions import ParamSpec
14
+ from typing_extensions import Unpack
15
+
16
+ from ..config import Config
17
+ from .types import DynamicDuration
18
+ from .types import EmptyKwargs
19
+
20
+
21
+ P = ParamSpec('P')
22
+ R = TypeVar('R')
23
+
24
+
25
+ decorated_cache: dict[Callable, Callable] = {}
26
+
27
+
28
+ @overload
29
+ def GPU(
30
+ task: None = None, *,
31
+ duration: DynamicDuration[P] = None,
32
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
33
+ ...
34
+ @overload
35
+ def GPU(
36
+ task: Callable[P, R], *,
37
+ duration: DynamicDuration[P] = None,
38
+ ) -> Callable[P, R]:
39
+ ...
40
+ def GPU(
41
+ task: Callable[P, R] | None = None, *,
42
+ duration: DynamicDuration[P] = None,
43
+ **kwargs: Unpack[EmptyKwargs],
44
+ ) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
45
+ """
46
+ ZeroGPU decorator
47
+
48
+ Basic usage:
49
+ ```
50
+ @spaces.GPU
51
+ def fn(...):
52
+ # CUDA is available here
53
+ pass
54
+ ```
55
+
56
+ With custom duration:
57
+ ```
58
+ @spaces.GPU(duration=45) # Expressed in seconds
59
+ def fn(...):
60
+ # CUDA is available here
61
+ pass
62
+ ```
63
+
64
+ Args:
65
+ task (`Callable | None`): Python function that requires CUDA
66
+ duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
67
+
68
+ Returns:
69
+ `Callable`: GPU-ready function
70
+ """
71
+ if "enable_queue" in kwargs:
72
+ warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
73
+ if task is None:
74
+ return partial(_GPU, duration=duration)
75
+ return _GPU(task, duration)
76
+
77
+
78
+ def _GPU(
79
+ task: Callable[P, R],
80
+ duration: DynamicDuration[P],
81
+ ) -> Callable[P, R]:
82
+
83
+ if not Config.zero_gpu:
84
+ return task
85
+
86
+ from . import client
87
+ from .wrappers import regular_function_wrapper
88
+ from .wrappers import generator_function_wrapper
89
+
90
+ if sys.version_info.minor < 9: # pragma: no cover
91
+ raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")
92
+
93
+ if task in decorated_cache:
94
+ # TODO: Assert same duration ?
95
+ return decorated_cache[task] # type: ignore
96
+
97
+ if inspect.iscoroutinefunction(task):
98
+ raise NotImplementedError
99
+
100
+ if inspect.isgeneratorfunction(task):
101
+ decorated = generator_function_wrapper(task, duration)
102
+ else:
103
+ decorated = regular_function_wrapper(task, duration)
104
+
105
+ setattr(decorated, 'zerogpu', None)
106
+
107
+ client.startup_report()
108
+ decorated_cache.update({
109
+ task: decorated,
110
+ decorated: decorated,
111
+ })
112
+
113
+ return decorated # type: ignore
spaces/zero/gradio.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ from functools import wraps
6
+ from packaging import version
7
+ from typing import Callable
8
+ from typing import NamedTuple
9
+ from typing import TYPE_CHECKING
10
+ import warnings
11
+
12
+ import gradio as gr
13
+ from gradio.context import Context
14
+ from gradio.context import LocalContext
15
+ from gradio.helpers import Progress
16
+ from gradio.helpers import TrackedIterable
17
+ from gradio.queueing import Queue
18
+ from typing_extensions import ParamSpec
19
+
20
+ from ..utils import SimpleQueue
21
+ from .types import GeneratorResQueueResult
22
+ from .types import GradioQueueEvent
23
+ from .types import RegularResQueueResult
24
+
25
+
26
+ QUEUE_RPC_METHODS = [
27
+ "set_progress",
28
+ "log_message",
29
+ ]
30
+
31
+
32
+ class GradioPartialContext(NamedTuple):
33
+ event_id: str | None
34
+ in_event_listener: bool
35
+ progress: Progress | None
36
+
37
+ @staticmethod
38
+ def get():
39
+ TrackedIterable.__reduce__ = tracked_iterable__reduce__
40
+ return GradioPartialContext(
41
+ event_id=LocalContext.event_id.get(),
42
+ in_event_listener=LocalContext.in_event_listener.get(),
43
+ progress=LocalContext.progress.get(),
44
+ )
45
+
46
+ @staticmethod
47
+ def apply(context: 'GradioPartialContext'):
48
+ LocalContext.event_id.set(context.event_id)
49
+ LocalContext.in_event_listener.set(context.in_event_listener)
50
+ LocalContext.progress.set(context.progress)
51
+
52
+
53
+ def get_queue_instance():
54
+ blocks = LocalContext.blocks.get()
55
+ if blocks is None: # pragma: no cover
56
+ return None
57
+ return blocks._queue
58
+
59
+
60
+ def get_event():
61
+ queue = get_queue_instance()
62
+ event_id = LocalContext.event_id.get()
63
+ if queue is None:
64
+ return None
65
+ if event_id is None: # pragma: no cover
66
+ return None
67
+ for job in queue.active_jobs:
68
+ if job is None: # pragma: no cover
69
+ continue
70
+ for event in job:
71
+ if event._id == event_id:
72
+ return event
73
+
74
+
75
+ def get_server_port() -> int | None:
76
+ from_request_context = True
77
+ if (blocks := LocalContext.blocks.get()) is None: # Request
78
+ from_request_context = False
79
+ if (blocks := Context.root_block) is None: # Caching
80
+ return None
81
+ if (server := getattr(blocks, 'server', None)) is None: # pragma: no cover (Gradio 4)
82
+ if from_request_context:
83
+ warnings.warn("Gradio: No blocks.server inside a request") # pragma: no cover
84
+ return -1
85
+ if TYPE_CHECKING:
86
+ assert (server := blocks.server)
87
+ return server.config.port
88
+
89
+
90
+ def try_process_queue_event(method_name: str, *args, **kwargs):
91
+ queue = get_queue_instance()
92
+ if queue is None: # pragma: no cover
93
+ warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
94
+ return
95
+ method = getattr(queue, method_name, None)
96
+ assert callable(method)
97
+ method(*args, **kwargs)
98
+
99
+
100
+ def patch_gradio_queue(
101
+ res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
102
+ ):
103
+
104
+ def rpc_method(method_name: str):
105
+ def method(*args, **kwargs):
106
+ if args and isinstance(args[0], Queue):
107
+ args = args[1:] # drop `self`
108
+ res_queue.put(GradioQueueEvent(method_name, args, kwargs))
109
+ return method
110
+
111
+ for method_name in QUEUE_RPC_METHODS:
112
+ if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
113
+ warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
114
+ continue
115
+ if not callable(method): # pragma: no cover
116
+ warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
117
+ continue
118
+ setattr(Queue, method_name, rpc_method(method_name))
119
+
120
+ TrackedIterable.__reduce__ = tracked_iterable__reduce__
121
+
122
+
123
+ def tracked_iterable__reduce__(self):
124
+ res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
125
+ cls, base, state, *_ = res
126
+ return cls, base,{**state, **{
127
+ 'iterable': None,
128
+ '_tqdm': None,
129
+ }}
130
+
131
+
132
+ def supports_auth():
133
+ return version.parse(gr.__version__) >= version.Version('4.27.0')
134
+
135
+
136
+ Param = ParamSpec('Param')
137
+
138
+ def one_launch(task: Callable[Param, None], *task_args: Param.args, **task_kwargs: Param.kwargs):
139
+ _launch = gr.Blocks.launch
140
+ @wraps(gr.Blocks.launch)
141
+ def launch(*args, **kwargs):
142
+ task(*task_args, **task_kwargs)
143
+ gr.Blocks.launch = _launch
144
+ return gr.Blocks.launch(*args, **kwargs)
145
+ gr.Blocks.launch = launch
146
+
147
+
148
+ class HTMLError(gr.Error):
149
+ def __str__(self): # pragma: no cover
150
+ return self.message
spaces/zero/torch/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from ...config import Config
5
+
6
+
7
+ try:
8
+
9
+ import torch
10
+
11
+ except ImportError:
12
+
13
+ _patch = lambda *args, **kwargs: None
14
+ _unpatch = lambda *args, **kwargs: None
15
+ _pack = lambda *args, **kwargs: None
16
+ _init = lambda *args, **kwargs: None
17
+ _size = lambda *args, **kwargs: 0
18
+ _move = lambda *args, **kwargs: None
19
+ _is_in_bad_fork = lambda *args, **kwargs: False
20
+
21
+ else:
22
+
23
+ if Config.zero_gpu_v2:
24
+ from . import patching as _patching
25
+ else: # pragma: no cover
26
+ from . import patching_legacy as _patching
27
+
28
+ _patch = _patching.patch
29
+ _unpatch = _patching.unpatch
30
+ _pack = _patching.pack
31
+ _init = _patching.init
32
+ _size = _patching.size
33
+ _move = _patching.move
34
+ _is_in_bad_fork = _patching.is_in_bad_fork
35
+
36
+ patch = _patch
37
+ unpatch = _unpatch
38
+ pack = _pack
39
+ init = _init
40
+ size = _size
41
+ move = _move
42
+ is_in_bad_fork = _is_in_bad_fork
spaces/zero/torch/bitsandbytes.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ # pyright: reportPrivateImportUsage=false
4
+
5
+ from __future__ import annotations
6
+
7
+ import importlib
8
+ from contextlib import contextmanager
9
+ from importlib import metadata
10
+ from types import ModuleType
11
+ from typing import TYPE_CHECKING
12
+ from typing import Tuple
13
+
14
+ import torch
15
+ from packaging import version
16
+
17
+ if TYPE_CHECKING:
18
+ import torch as Torch
19
+
20
+
21
+ @contextmanager
22
+ def cuda_unavailable(torch: ModuleType):
23
+ _is_available = torch.cuda.is_available
24
+ torch.cuda.is_available = lambda: False
25
+ yield
26
+ torch.cuda.is_available = _is_available
27
+
28
+
29
+ def maybe_import_bitsandbytes():
30
+ try:
31
+ import torch
32
+ except ImportError: # pragma: no cover
33
+ return None
34
+ with cuda_unavailable(torch):
35
+ try:
36
+ import bitsandbytes
37
+ except ImportError:
38
+ bitsandbytes = None
39
+ else:
40
+ if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
41
+ raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
42
+ print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
43
+ return bitsandbytes
44
+
45
+
46
+ if (bnb := maybe_import_bitsandbytes()):
47
+
48
+ from torch.utils.weak import WeakTensorKeyDictionary
49
+
50
+ with cuda_unavailable(torch):
51
+ from bitsandbytes import cextension
52
+ from bitsandbytes import functional
53
+ try: # bitsandbytes < 0.44
54
+ from bitsandbytes.cuda_setup.main import CUDASetup
55
+ except ModuleNotFoundError: # pragma: no cover
56
+ CUDASetup = None
57
+ from bitsandbytes.nn import Int8Params
58
+ from bitsandbytes.nn import Params4bit
59
+
60
+ _param_to_8bit = Int8Params.to # type: ignore
61
+ _param_cuda_8bit = Int8Params.cuda
62
+ _param_to_4bit = Params4bit.to # type: ignore
63
+ _param_cuda_4bit = Params4bit.cuda
64
+
65
+ TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
66
+
67
+ to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
68
+ to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
69
+
70
+ def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
71
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
72
+ device, *_ = parsed
73
+ if not isinstance(device, torch.device): # pragma: no cover
74
+ return _param_to_8bit(self, *args, **kwargs)
75
+ if device.type != 'cuda':
76
+ return _param_to_8bit(self, *args, **kwargs)
77
+ to_ops_8bit[self] = parsed
78
+ return self
79
+
80
+ def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
81
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
82
+ device, *_ = parsed
83
+ if not isinstance(device, torch.device): # pragma: no cover
84
+ return _param_to_4bit(self, *args, **kwargs)
85
+ if device.type != 'cuda':
86
+ return _param_to_4bit(self, *args, **kwargs)
87
+ to_ops_4bit[self] = parsed
88
+ return self
89
+
90
+ def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
91
+ if device is None: # pragma: no cover
92
+ return True
93
+ if isinstance(device, int):
94
+ return True
95
+ if isinstance(device, str): # pragma: no cover
96
+ device = torch.device(device)
97
+ return device.type == 'cuda' # pragma: no cover
98
+
99
+ def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
100
+ if not _cuda_op_arg_check(device): # pragma: no cover
101
+ # Let PyTorch handle the fail
102
+ return _param_cuda_8bit(self, device, **kwargs)
103
+ to_ops_8bit[self] = None
104
+ return self
105
+
106
+ def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
107
+ if not _cuda_op_arg_check(device): # pragma: no cover
108
+ # Let PyTorch handle the fail
109
+ return _param_cuda_4bit(self, device, **kwargs)
110
+ to_ops_4bit[self] = None
111
+ return self
112
+
113
+ def _patch():
114
+ Int8Params.to = _to_op_register_8bit # type: ignore
115
+ Int8Params.cuda = _cuda_op_register_8bit # type: ignore
116
+ Params4bit.to = _to_op_register_4bit # type: ignore
117
+ Params4bit.cuda = _cuda_op_register_4bit # type: ignore
118
+
119
+ def _unpatch():
120
+ Int8Params.to = _param_to_8bit # type: ignore
121
+ Int8Params.cuda = _param_cuda_8bit
122
+ Params4bit.to = _param_to_4bit # type: ignore
123
+ Params4bit.cuda = _param_cuda_4bit
124
+
125
+ def _move():
126
+ if CUDASetup is not None:
127
+ CUDASetup._instance = None
128
+ importlib.reload(cextension)
129
+ functional.lib = cextension.lib
130
+ for op in to_ops_8bit.items():
131
+ tensor, parsed_args = op
132
+ if parsed_args:
133
+ _, dtype, _, memory_format = parsed_args
134
+ else:
135
+ dtype, memory_format = None, None
136
+ tensor.data = _param_to_8bit(tensor,
137
+ device='cuda',
138
+ dtype=dtype,
139
+ memory_format=memory_format,
140
+ ) # type: ignore
141
+ for op in to_ops_4bit.items():
142
+ tensor, parsed_args = op
143
+ if parsed_args:
144
+ _, dtype, _, memory_format = parsed_args
145
+ else:
146
+ dtype, memory_format = None, None
147
+ tensor.data = _param_to_4bit(tensor,
148
+ device='cuda',
149
+ dtype=dtype,
150
+ memory_format=memory_format,
151
+ ) # type: ignore
152
+
153
+ else:
154
+
155
+ _patch = lambda: None
156
+ _unpatch = lambda: None
157
+ _move = lambda: None
158
+
159
+
160
+ patch = _patch
161
+ unpatch = _unpatch
162
+ move = _move
spaces/zero/torch/packing.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import time
6
+
7
+ import ctypes
8
+ import os
9
+ from concurrent.futures import as_completed
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from contextvars import copy_context
12
+ from dataclasses import dataclass
13
+ from queue import Queue
14
+ from typing import Callable
15
+
16
+ from ...utils import debug
17
+
18
+ import torch
19
+ from typing_extensions import TypeAlias
20
+
21
+
22
+ PAGE_SIZE = 4096
23
+ TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
24
+ VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2)
25
+
26
+ BUFFER_SIZE = 64 * 2**20
27
+ BUFFER_COUNT = 2
28
+
29
+
30
+ TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]'
31
+
32
+ @dataclass
33
+ class ZeroGPUTensorPack:
34
+ base_dir: str
35
+ batches: list[list[TensorWithSizes]]
36
+ big_tensors: list[TensorWithSizes]
37
+ fakes: dict[torch.Tensor, list[torch.Tensor]]
38
+ total_size: int
39
+ def path(self):
40
+ return f'{self.base_dir}/{id(self)}'
41
+ def __del__(self):
42
+ try:
43
+ os.remove(self.path())
44
+ except FileNotFoundError: # pragma: no cover
45
+ pass
46
+
47
+
48
+ def write(fd: int, tensor: torch.Tensor):
49
+ clone = torch.empty_like(tensor)
50
+ size = clone.untyped_storage().size() # pyright: ignore [reportAttributeAccessIssue]
51
+ buffer = torch.UntypedStorage(VM_MAX_SIZE)
52
+ buffer_ptr = buffer.data_ptr()
53
+ offset = -buffer_ptr % PAGE_SIZE
54
+ padding = -size % PAGE_SIZE
55
+ clone.set_(buffer[offset:offset+size], 0, clone.shape, clone.stride()) # pyright: ignore [reportArgumentType]
56
+ clone.copy_(tensor)
57
+ mv = memoryview((ctypes.c_char * (size+padding)).from_address(buffer_ptr+offset))
58
+ written_bytes = 0
59
+ while written_bytes < size:
60
+ written_bytes += os.write(fd, mv[written_bytes:])
61
+
62
+
63
+ def pack_tensors(
64
+ tensors: set[torch.Tensor],
65
+ fakes: dict[torch.Tensor, list[torch.Tensor]],
66
+ offload_dir: str,
67
+ callback: Callable[[int]] | None = None,
68
+ ):
69
+
70
+ callback = (lambda bytes: None) if callback is None else callback
71
+
72
+ batches: list[list[TensorWithSizes]] = []
73
+ big_tensors: list[TensorWithSizes] = []
74
+
75
+ tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = []
76
+ for tensor in tensors:
77
+ size = tensor.numel() * tensor.element_size()
78
+ aligned_size = size + (-size % PAGE_SIZE)
79
+ tensors_with_sizes += [(tensor, size, aligned_size)]
80
+
81
+ current_batch, current_size = [], 0
82
+ for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]):
83
+ if aligned_size > BUFFER_SIZE:
84
+ big_tensors += [(tensor, size, aligned_size)]
85
+ continue
86
+ current_size += aligned_size
87
+ if current_size > BUFFER_SIZE:
88
+ batches += [current_batch]
89
+ current_batch, current_size = [(tensor, size, aligned_size)], aligned_size
90
+ else:
91
+ current_batch += [(tensor, size, aligned_size)]
92
+
93
+ if current_batch:
94
+ batches += [current_batch]
95
+
96
+ get_meta = {tensor: torch.empty_like(tensor) for tensor in tensors}
97
+ batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches]
98
+ big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors]
99
+ fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()}
100
+
101
+ pack = ZeroGPUTensorPack(
102
+ base_dir=offload_dir,
103
+ batches=batches_meta,
104
+ big_tensors=big_tensors_meta,
105
+ fakes=fakes_meta,
106
+ total_size=sum([size for _, size, _ in tensors_with_sizes]),
107
+ )
108
+
109
+ fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT)
110
+ try:
111
+ total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch])
112
+ total_asize += sum([aligned_size for *_, aligned_size in big_tensors])
113
+ if total_asize > 0:
114
+ os.posix_fallocate(fd, 0, total_asize)
115
+ for batch in batches:
116
+ for tensor, size, _ in batch:
117
+ write(fd, tensor)
118
+ callback(size)
119
+ for tensor, size, _ in big_tensors:
120
+ write(fd, tensor)
121
+ callback(size)
122
+ return pack
123
+ finally:
124
+ os.close(fd)
125
+
126
+
127
+ def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int]] | None = None):
128
+
129
+ callback = (lambda bytes: None) if callback is None else callback
130
+
131
+ free_buffers: Queue[torch.Tensor] = Queue()
132
+ read_buffers: Queue[torch.Tensor] = Queue()
133
+
134
+ for _ in range(BUFFER_COUNT):
135
+ free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory())
136
+
137
+ def read(fd: int, buffer: torch.Tensor, size: int):
138
+ mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr()))
139
+ read_bytes = 0
140
+ while read_bytes < size:
141
+ read_bytes += os.readv(fd, [mv[read_bytes:]])
142
+
143
+ def disk_to_pin(fd: int):
144
+ for batch in pack.batches:
145
+ buffer = free_buffers.get()
146
+ batch_size = sum([aligned_size for *_, aligned_size in batch])
147
+ read(fd, buffer, batch_size)
148
+ read_buffers.put(buffer)
149
+ for *_, aligned_size in pack.big_tensors:
150
+ read_bytes = 0
151
+ while read_bytes < aligned_size:
152
+ buffer = free_buffers.get()
153
+ read_size = min(BUFFER_SIZE, aligned_size - read_bytes)
154
+ read(fd, buffer, read_size)
155
+ read_buffers.put(buffer)
156
+ read_bytes += read_size
157
+
158
+ def pin_to_cuda():
159
+ total_duration_in_callback = 0
160
+ for batch in pack.batches:
161
+ buffer = read_buffers.get()
162
+ offset = 0
163
+ cuda_storages = []
164
+ for tensor, size, aligned_size in batch:
165
+ cuda_storages += [buffer[offset:offset+size].cuda(non_blocking=True)]
166
+ offset += aligned_size
167
+ torch.cuda.synchronize()
168
+ free_buffers.put(buffer)
169
+ batch_total_size = 0
170
+ for (tensor, size, _), cuda_storage in zip(batch, cuda_storages):
171
+ cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
172
+ cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
173
+ for fake in pack.fakes[tensor]:
174
+ fake.data = cuda_tensor
175
+ batch_total_size += size
176
+ t0 = time.perf_counter()
177
+ callback(batch_total_size)
178
+ total_duration_in_callback += time.perf_counter() - t0
179
+ for tensor, size, _ in pack.big_tensors:
180
+ cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda')
181
+ offset = 0
182
+ while offset < size:
183
+ buffer = read_buffers.get()
184
+ read_size = min(BUFFER_SIZE, size - offset)
185
+ cuda_storage[offset:offset+read_size] = buffer[:read_size]
186
+ offset += read_size
187
+ torch.cuda.synchronize() # Probably not needed
188
+ free_buffers.put(buffer)
189
+ t0 = time.perf_counter()
190
+ callback(read_size)
191
+ total_duration_in_callback += time.perf_counter() - t0
192
+ cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
193
+ cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
194
+ for fake in pack.fakes[tensor]:
195
+ fake.data = cuda_tensor
196
+
197
+ debug(f"{total_duration_in_callback=}")
198
+
199
+ with ThreadPoolExecutor(2) as e:
200
+ fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT)
201
+ try:
202
+ futures = [
203
+ e.submit(copy_context().run, disk_to_pin, fd),
204
+ e.submit(copy_context().run, pin_to_cuda),
205
+ ]
206
+ for future in as_completed(futures):
207
+ future.result()
208
+ finally:
209
+ os.close(fd)
spaces/zero/torch/patching.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ # pyright: reportPrivateImportUsage=false
4
+
5
+ from __future__ import annotations
6
+
7
+ import gc
8
+ import multiprocessing
9
+ import os
10
+ from collections import defaultdict
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from contextlib import nullcontext
14
+ from contextvars import copy_context
15
+ from types import SimpleNamespace
16
+ from typing import Any
17
+ from typing import Callable
18
+
19
+ import torch
20
+ from torch.overrides import TorchFunctionMode
21
+ from torch.overrides import resolve_name
22
+ from torch.utils._python_dispatch import TorchDispatchMode
23
+ from torch.utils._pytree import tree_map_only
24
+ from torch.utils.weak import WeakTensorKeyDictionary
25
+
26
+ from ...config import Config
27
+ from ...utils import malloc_trim
28
+ from ..tqdm import tqdm
29
+ from . import bitsandbytes
30
+ from .packing import ZeroGPUTensorPack
31
+ from .packing import pack_tensors
32
+ from .packing import pack_to_cuda
33
+ from .types import AliasId
34
+
35
+
36
+ # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
37
+ CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
38
+ CUDA_TOTAL_MEMORY = 42144366592
39
+ CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
40
+ CUDA_DEVICE_CAPABILITY = (8, 0)
41
+ CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
42
+
43
+ OPS_INPUTS_CHECK_NO_RETURN = (
44
+ torch.Tensor.equal,
45
+ )
46
+
47
+ OPS_INPUT_CHECK_SELF_RETURN = (
48
+ torch.Tensor.set_, # probably never dispatched
49
+ torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
50
+ )
51
+
52
+ OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
53
+
54
+ _tensor_make_subclass = torch.Tensor._make_subclass
55
+ _asarray = torch.asarray
56
+ _cuda_init = torch._C._cuda_init
57
+ _cuda_exchange_device = torch.cuda._exchange_device
58
+ _cuda_available = torch.cuda.is_available
59
+ _cuda_device_count = torch.cuda.device_count
60
+ _cuda_current_device = torch.cuda.current_device
61
+ _cuda_mem_get_info = torch.cuda.mem_get_info
62
+ _cuda_get_device_capability = torch.cuda.get_device_capability
63
+ _cuda_get_device_properties = torch.cuda.get_device_properties
64
+ _cuda_get_device_name = torch.cuda.get_device_name
65
+
66
+ # PyTorch 2.3
67
+ _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
68
+
69
+
70
+ cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
71
+
72
+ tensor_packs: list[ZeroGPUTensorPack] = []
73
+
74
+ class ZeroGPUTensor(torch.Tensor):
75
+ pass
76
+
77
+ def empty_fake(tensor: torch.Tensor):
78
+ fake = torch.empty_like(tensor, requires_grad=tensor.requires_grad)
79
+ if fake.__class__ != tensor.__class__:
80
+ fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
81
+ return fake
82
+
83
+ class ZeroGPUFunctionMode(TorchFunctionMode):
84
+
85
+ def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
86
+
87
+ kwargs = {} if kwargs is None else kwargs
88
+
89
+ if func == torch._C._nn._parse_to:
90
+ return func(*args, **kwargs)
91
+
92
+ # Redispatch: tensor.cuda() -> tensor.to(device='cuda')
93
+ if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
94
+ memory_format = kwargs.get('memory_format')
95
+ return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
96
+ 'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
97
+ **({'memory_format': memory_format} if memory_format is not None else {}),
98
+ })
99
+
100
+ # Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
101
+ if func == torch.Tensor.to and len(args) > 1:
102
+ device, dtype, _, memory_format = torch._C._nn._parse_to(*args[1:], **kwargs)
103
+ return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
104
+ 'device': device,
105
+ 'dtype': dtype,
106
+ 'memory_format': memory_format,
107
+ })
108
+
109
+ if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
110
+ self, target = args
111
+ if target in cuda_aliases:
112
+ if (target_original := cuda_aliases[target]) is None:
113
+ raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
114
+ original = empty_fake(self)
115
+ original.data = target_original
116
+ cuda_aliases[self] = original
117
+ elif self in cuda_aliases:
118
+ del cuda_aliases[self]
119
+ self.data = target
120
+ return
121
+
122
+ if func == torch.Tensor.device.__get__:
123
+ tensor, = args
124
+ if tensor in cuda_aliases:
125
+ return torch.device('cuda', index=0)
126
+
127
+ elif func == torch.Tensor.__repr__:
128
+ tensor, = args
129
+ if tensor in cuda_aliases:
130
+ if (original := cuda_aliases[tensor]) is None:
131
+ original = tensor.to('meta')
132
+ original_class = original.__class__
133
+ original.__class__ = ZeroGPUTensor
134
+ try:
135
+ return func(original, **kwargs)
136
+ finally:
137
+ original.__class__ = original_class
138
+
139
+ elif func == torch.Tensor.untyped_storage:
140
+ tensor, = args
141
+ if tensor in cuda_aliases:
142
+ if (original := cuda_aliases[tensor]) is None:
143
+ raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
144
+ res = func(original, **kwargs)
145
+ res._zerogpu = True
146
+ return res
147
+
148
+ cuda: bool | None = None
149
+
150
+ # Handle device kwarg
151
+ if (device := kwargs.get('device')) is not None:
152
+ device = torch.device(device)
153
+ if device.type == 'cuda':
154
+ kwargs['device'] = torch.device('cpu')
155
+ cuda = True
156
+ else:
157
+ cuda = False
158
+
159
+ # Swap fake inputs with original data
160
+ swapped = {}
161
+ inputs_are_cuda = set()
162
+ def swap(tensor: torch.Tensor):
163
+ nonlocal inputs_are_cuda
164
+ if tensor not in cuda_aliases:
165
+ inputs_are_cuda |= {False}
166
+ return tensor
167
+ if (original := cuda_aliases[tensor]) is None:
168
+ raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
169
+ swapped[original] = tensor
170
+ inputs_are_cuda |= {True}
171
+ return original
172
+ args_ = tree_map_only(torch.Tensor, swap, args)
173
+ kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
174
+ if inputs_are_cuda == {True}:
175
+ if cuda is not False:
176
+ cuda = True
177
+
178
+ res = func(*args_, **kwargs_)
179
+
180
+ # Re-generate swapped fakes in case of mutation
181
+ for original, fake in swapped.items():
182
+ fake.data = empty_fake(original)
183
+
184
+ # Special case for Tensor indexing where only 'self' matters
185
+ if func in {
186
+ torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
187
+ torch.Tensor.__getitem__, # PyTorch 2.4+
188
+ }:
189
+ self = args[0]
190
+ cuda = self in cuda_aliases
191
+ inputs_are_cuda = {cuda}
192
+
193
+ # Emulate device check
194
+ if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
195
+ self = None
196
+ if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
197
+ self = args_[0]
198
+ # Only raise if func does not return its first input (Tensor.copy_)
199
+ if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
200
+ if inputs_are_cuda == {True, False}:
201
+ raise RuntimeError(
202
+ "Expected all tensors to be on the same device, "
203
+ "but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
204
+ )
205
+
206
+ # Register output
207
+ def register(tensor: torch.Tensor):
208
+ if tensor in swapped and cuda is not False:
209
+ return swapped[tensor]
210
+ if cuda is not True:
211
+ return tensor
212
+ fake = empty_fake(tensor)
213
+ cuda_aliases[fake] = tensor
214
+ return fake
215
+
216
+ return tree_map_only(torch.Tensor, register, res)
217
+
218
+ # When enabling DispatchMode, some aten ops are dispatched to FunctionMode
219
+ # We are using it for aten.alias.default and aten.set_.source_Tensor
220
+ class DefaultDispatchMode(TorchDispatchMode):
221
+ def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
222
+ return func(*args, **(kwargs or {}))
223
+
224
+
225
+ function_mode = ZeroGPUFunctionMode()
226
+ dispatch_mode = DefaultDispatchMode()
227
+
228
+
229
+ def _untyped_storage_new_register(*args, **kwargs):
230
+ cuda = False
231
+ if (device := kwargs.get('device')) is not None and device.type == 'cuda':
232
+ cuda = True
233
+ del kwargs['device']
234
+ storage = torch._C.StorageBase.__new__(*args, **kwargs)
235
+ if cuda:
236
+ storage._zerogpu = True
237
+ return storage
238
+
239
+ @property
240
+ def _untyped_storage_device(self):
241
+ if hasattr(self, '_zerogpu'):
242
+ return torch.device('cuda', index=0)
243
+ return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue]
244
+
245
+ # Force dispatch
246
+ def _tensor_make_subclass_function_mode(*args, **kwargs):
247
+ with torch._C.DisableTorchFunction():
248
+ return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs)
249
+ def _asarray_function_mode(*args, **kwargs):
250
+ with torch._C.DisableTorchFunction():
251
+ return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs)
252
+
253
+ def _cuda_init_raise():
254
+ raise RuntimeError(
255
+ "CUDA must not be initialized in the main process "
256
+ "on Spaces with Stateless GPU environment.\n"
257
+ "You can look at this Stacktrace to find out "
258
+ "which part of your code triggered a CUDA init"
259
+ )
260
+
261
+ def _cuda_dummy_exchange_device(device):
262
+ assert device in {-1, 0}
263
+ return device
264
+
265
+ def patch():
266
+ function_mode.__enter__()
267
+ dispatch_mode.__enter__()
268
+ # TODO: only patch bellow methods on current Thread to be consistent with TorchModes
269
+ # (or hijack threading.Thread.__init__ to force Modes on all threads)
270
+ torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
271
+ torch.UntypedStorage.__new__ = _untyped_storage_new_register
272
+ torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue]
273
+ torch.asarray = _asarray_function_mode
274
+ torch._C._cuda_init = _cuda_init_raise
275
+ torch.cuda._exchange_device = _cuda_dummy_exchange_device
276
+ torch.cuda.is_available = lambda: True
277
+ torch.cuda.device_count = lambda: 1
278
+ torch.cuda.current_device = lambda: 0
279
+ torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
280
+ torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
281
+ torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
282
+ torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
283
+ # PyTorch 2.3
284
+ if _cuda_maybe_exchange_device is not None: # pragma: no cover
285
+ setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
286
+ bitsandbytes.patch()
287
+
288
+ def unpatch():
289
+ try:
290
+ dispatch_mode.__exit__(None, None, None)
291
+ function_mode.__exit__(None, None, None)
292
+ except RuntimeError:
293
+ pass # patch() and unpatch() called from != threads
294
+ torch.Tensor._make_subclass = _tensor_make_subclass
295
+ torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
296
+ torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
297
+ torch.asarray = _asarray
298
+ torch._C._cuda_init = _cuda_init
299
+ torch.cuda._exchange_device = _cuda_exchange_device
300
+ torch.cuda.is_available = _cuda_available
301
+ torch.cuda.device_count = _cuda_device_count
302
+ torch.cuda.current_device = _cuda_current_device
303
+ torch.cuda.mem_get_info = _cuda_mem_get_info
304
+ torch.cuda.get_device_capability = _cuda_get_device_capability
305
+ torch.cuda.get_device_properties = _cuda_get_device_properties
306
+ torch.cuda.get_device_name = _cuda_get_device_name
307
+ # PyTorch 2.3
308
+ if _cuda_maybe_exchange_device is not None: # pragma: no cover
309
+ setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)
310
+ bitsandbytes.unpatch()
311
+
312
+
313
+ def _total_unpacked_size():
314
+ tensors = [tensor for tensor in cuda_aliases.values() if tensor is not None]
315
+ deduped = {AliasId.from_tensor(tensor): tensor for tensor in tensors}
316
+ return sum([tensor.numel() * tensor.element_size() for tensor in deduped.values()])
317
+
318
+
319
+ def _pack(offload_dir: str):
320
+ # Pack to disk
321
+ originals: set[torch.Tensor] = set()
322
+ originals_dedup: dict[AliasId, torch.Tensor] = {}
323
+ fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
324
+ for fake, original in cuda_aliases.items():
325
+ # TODO filter-out sparse Tensors
326
+ if original is not None:
327
+ original_id = AliasId.from_tensor(original)
328
+ if original_id not in originals_dedup:
329
+ originals_dedup[original_id] = original
330
+ originals |= {original}
331
+ fakes[originals_dedup[original_id]] += [fake]
332
+ progress = tqdm(
333
+ total=_total_unpacked_size(),
334
+ unit='B',
335
+ unit_scale=True,
336
+ desc="ZeroGPU tensors packing",
337
+ ) if tqdm is not None else nullcontext()
338
+ with progress as progress:
339
+ update = progress.update if progress is not None else lambda _: None
340
+ pack = pack_tensors(originals, fakes, offload_dir, callback=update)
341
+ tensor_packs.append(pack)
342
+ # Free memory
343
+ for fake_list in fakes.values():
344
+ for fake in fake_list:
345
+ cuda_aliases[fake] = None
346
+
347
+ def pack():
348
+ _pack(Config.zerogpu_offload_dir)
349
+ gc.collect()
350
+ malloc_trim()
351
+
352
+ def init(nvidia_uuid: str):
353
+ os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
354
+ torch.Tensor([0]).cuda()
355
+
356
+ def size():
357
+ return _total_unpacked_size() + sum([pack.total_size for pack in tensor_packs])
358
+
359
+ def _move(callback: Callable[[int]] | None = None):
360
+ callback = callback if callback is not None else lambda _: None
361
+ # CPU -> CUDA
362
+ moved: dict[AliasId, torch.Tensor] = {}
363
+ for fake, original in cuda_aliases.items():
364
+ if original is not None:
365
+ original_id = AliasId.from_tensor(original)
366
+ if original_id not in moved:
367
+ moved[original_id] = original.cuda()
368
+ callback(fake.numel() * fake.element_size())
369
+ for fake, original in cuda_aliases.items():
370
+ if original is not None:
371
+ fake.data = moved[AliasId.from_tensor(original)]
372
+ # Disk -> CUDA
373
+ for tensor_pack in tensor_packs:
374
+ pack_to_cuda(tensor_pack, callback=callback)
375
+ bitsandbytes.move()
376
+
377
+ def move(callback: Callable[[int]] | None = None):
378
+ callback = callback if callback is not None else lambda _: None
379
+ with ThreadPoolExecutor(1) as e:
380
+ e.submit(copy_context().run, _move, callback=callback).result()
381
+ torch.cuda.synchronize()
382
+
383
+ def is_in_bad_fork():
384
+ with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
385
+ f = e.submit(torch.cuda._is_in_bad_fork)
386
+ return f.result()
spaces/zero/torch/patching_legacy.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ # pyright: reportPrivateImportUsage=false
4
+
5
+ from __future__ import annotations
6
+
7
+ import multiprocessing
8
+ import os
9
+ from concurrent.futures import ProcessPoolExecutor
10
+ from contextlib import suppress
11
+ from functools import partial
12
+ from types import SimpleNamespace
13
+ from typing import Any
14
+ from typing import Callable
15
+ from typing import Optional
16
+ from typing import Tuple
17
+
18
+ import torch
19
+ from torch.utils.weak import WeakTensorKeyDictionary
20
+
21
+ from ...config import Config
22
+ from . import bitsandbytes
23
+
24
+
25
+ # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
26
+ CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
27
+ CUDA_TOTAL_MEMORY = 42144366592
28
+ CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
29
+ CUDA_DEVICE_CAPABILITY = (8, 0)
30
+ CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
31
+
32
+ GENERIC_METHOD_NAMES = [
33
+ 'arange',
34
+ 'as_tensor',
35
+ 'asarray',
36
+ 'bartlett_window',
37
+ 'blackman_window',
38
+ 'empty',
39
+ 'empty_like',
40
+ 'empty_strided',
41
+ 'eye',
42
+ 'full',
43
+ 'full_like',
44
+ 'hamming_window',
45
+ 'hann_window',
46
+ 'kaiser_window',
47
+ 'linspace',
48
+ 'logspace',
49
+ 'ones',
50
+ 'ones_like',
51
+ 'rand',
52
+ 'rand_like',
53
+ 'randint',
54
+ 'randint_like',
55
+ 'randn',
56
+ 'randn_like',
57
+ 'randperm',
58
+ 'range',
59
+ 'sparse_bsc_tensor',
60
+ 'sparse_bsr_tensor',
61
+ 'sparse_compressed_tensor',
62
+ 'sparse_coo_tensor',
63
+ 'sparse_csc_tensor',
64
+ 'sparse_csr_tensor',
65
+ 'tensor',
66
+ 'tril_indices',
67
+ 'triu_indices',
68
+ 'zeros',
69
+ 'zeros_like',
70
+ ]
71
+
72
+
73
+ TO_CUDA = (torch.device('cuda'), None, False, None)
74
+
75
+ _tensor__deepcopy__ = torch.Tensor.__deepcopy__
76
+ _tensor_to = torch.Tensor.to
77
+ _tensor_cuda = torch.Tensor.cuda
78
+ _tensor_cpu = torch.Tensor.cpu
79
+ _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
80
+ _cuda_init = torch._C._cuda_init
81
+ _cuda_available = torch.cuda.is_available
82
+ _cuda_device_count = torch.cuda.device_count
83
+ _cuda_current_device = torch.cuda.current_device
84
+ _cuda_mem_get_info = torch.cuda.mem_get_info
85
+ _cuda_get_device_capability = torch.cuda.get_device_capability
86
+ _cuda_get_device_properties = torch.cuda.get_device_properties
87
+ _cuda_get_device_name = torch.cuda.get_device_name
88
+
89
+ TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
90
+
91
+ to_ops: dict[torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
92
+
93
+ def _tensor_new_register(*args, **kwargs):
94
+ new_tensor: torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
95
+ if (base_tensor := new_tensor._base) is not None:
96
+ if base_tensor in to_ops:
97
+ to_ops[new_tensor] = to_ops[base_tensor]
98
+ return new_tensor
99
+
100
+ def _tensor_deepcopy_register(self: torch.Tensor, memo):
101
+ new_tensor = _tensor__deepcopy__(self, memo)
102
+ if isinstance(new_tensor, torch.Tensor):
103
+ if self in to_ops:
104
+ to_ops[new_tensor] = to_ops[self]
105
+ return new_tensor
106
+
107
+ @property
108
+ def _tensor_device_property(self: torch.Tensor):
109
+ if self in to_ops:
110
+ return torch.device(type='cuda', index=0)
111
+ del torch.Tensor.device
112
+ try:
113
+ return self.device
114
+ finally:
115
+ torch.Tensor.device = _tensor_device_property # type: ignore
116
+
117
+ @property
118
+ def _tensor_dtype_property(self: torch.Tensor):
119
+ if self in to_ops:
120
+ if (to_dtype := to_ops[self][1]) is not None:
121
+ return to_dtype
122
+ del torch.Tensor.dtype
123
+ try:
124
+ return self.dtype
125
+ finally:
126
+ torch.Tensor.dtype = _tensor_dtype_property # type: ignore
127
+
128
+ def _to_op_register(self: torch.Tensor, *args, **kwargs):
129
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
130
+ device, dtype, *_ = parsed
131
+ try:
132
+ to_args = to_ops.pop(self)
133
+ except KeyError:
134
+ to_args = None
135
+ if device is None: # pyright: ignore [reportUnnecessaryComparison]
136
+ if to_args is not None:
137
+ to_ops[self] = (to_args[0], dtype, *to_args[2:])
138
+ return self
139
+ return _tensor_to(self, *args, **kwargs)
140
+ if device.type != 'cuda':
141
+ if to_args is not None:
142
+ if (to_dtype := to_args[1]) is not None:
143
+ kwargs = {'dtype': to_dtype, **kwargs}
144
+ return _tensor_to(self, *args, **kwargs)
145
+ to_ops[self] = parsed
146
+ return self
147
+
148
+ def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool:
149
+ if device is None:
150
+ return True
151
+ if isinstance(device, int):
152
+ return True
153
+ if isinstance(device, str):
154
+ device = torch.device(device)
155
+ return device.type == 'cuda'
156
+
157
+ def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs):
158
+ if not _cuda_op_arg_check(device):
159
+ # Let PyTorch handle the fail
160
+ return _tensor_cuda(self, device, **kwargs)
161
+ to_ops[self] = TO_CUDA
162
+ return self
163
+
164
+ def _cpu_op_remove(self: torch.Tensor, **kwargs):
165
+ try:
166
+ to_args = to_ops.pop(self)
167
+ except KeyError:
168
+ to_args = None
169
+ if to_args is not None:
170
+ if (to_dtype := to_args[1]) is not None:
171
+ return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
172
+ return _tensor_cpu(self, **kwargs)
173
+
174
+ def _cuda_init_raise():
175
+ raise RuntimeError(
176
+ "CUDA must not be initialized in the main process "
177
+ "on Spaces with Stateless GPU environment.\n"
178
+ "You can look at this Stacktrace to find out "
179
+ "which part of your code triggered a CUDA init"
180
+ )
181
+
182
+ def _generic_method_register(name: str, *args: Any, **kwargs: Any):
183
+ try:
184
+ device = torch.device(kwargs.get('device', "cpu"))
185
+ except Exception:
186
+ return _torch_generics[name](*args, **kwargs)
187
+ if device.type != 'cuda':
188
+ return _torch_generics[name](*args, **kwargs)
189
+ tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
190
+ to_ops[tensor] = TO_CUDA
191
+ return tensor
192
+
193
+ def patch():
194
+ torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
195
+ torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
196
+ torch.Tensor.to = _to_op_register # type: ignore
197
+ torch.Tensor.cuda = _cuda_op_register # type: ignore
198
+ torch.Tensor.cpu = _cpu_op_remove # type: ignore
199
+ if Config.zero_patch_torch_device:
200
+ torch.Tensor.device = _tensor_device_property # type: ignore
201
+ torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
202
+ for name in GENERIC_METHOD_NAMES:
203
+ setattr(torch, name, partial(_generic_method_register, name))
204
+ torch._C._cuda_init = _cuda_init_raise
205
+ torch.cuda.is_available = lambda: True
206
+ torch.cuda.device_count = lambda: 1
207
+ torch.cuda.current_device = lambda: 0
208
+ torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
209
+ torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
210
+ torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
211
+ torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
212
+ bitsandbytes.patch()
213
+
214
+ def unpatch():
215
+ torch.Tensor.__deepcopy__ = _tensor__deepcopy__
216
+ with suppress(AttributeError):
217
+ del torch.Tensor.__new__
218
+ torch.Tensor.to = _tensor_to
219
+ torch.Tensor.cuda = _tensor_cuda
220
+ torch.Tensor.cpu = _tensor_cpu
221
+ with suppress(AttributeError):
222
+ del torch.Tensor.device
223
+ with suppress(AttributeError):
224
+ del torch.Tensor.dtype
225
+ for name in GENERIC_METHOD_NAMES:
226
+ setattr(torch, name, _torch_generics[name])
227
+ torch._C._cuda_init = _cuda_init
228
+ torch.cuda.is_available = _cuda_available
229
+ torch.cuda.device_count = _cuda_device_count
230
+ torch.cuda.current_device = _cuda_current_device
231
+ torch.cuda.mem_get_info = _cuda_mem_get_info
232
+ torch.cuda.get_device_capability = _cuda_get_device_capability
233
+ torch.cuda.get_device_properties = _cuda_get_device_properties
234
+ torch.cuda.get_device_name = _cuda_get_device_name
235
+ bitsandbytes.unpatch()
236
+
237
+ def pack():
238
+ pass
239
+
240
+ def init(nvidia_uuid: str):
241
+ os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
242
+ torch.Tensor([0]).cuda() # CUDA init
243
+
244
+ def size():
245
+ return 0
246
+
247
+ def move(callback: Callable[[int]] | None = None):
248
+ for op in to_ops.items():
249
+ tensor, parsed_args = op
250
+ _, dtype, _, memory_format = parsed_args
251
+ tensor.data = _tensor_to(tensor,
252
+ device='cuda',
253
+ dtype=dtype,
254
+ memory_format=memory_format,
255
+ ) # type: ignore
256
+ bitsandbytes.move()
257
+ torch.cuda.synchronize()
258
+
259
+ def is_in_bad_fork():
260
+ with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
261
+ f = e.submit(torch.cuda._is_in_bad_fork)
262
+ return f.result()
263
+
264
+ def disable_cuda_intercept():
265
+ torch.Tensor.to = _tensor_to
266
+ torch.Tensor.cuda = _tensor_cuda
spaces/zero/torch/types.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ from typing import NamedTuple
6
+
7
+ import torch
8
+
9
+
10
+ class AliasId(NamedTuple):
11
+ data_ptr: int
12
+ dtype: torch.dtype
13
+ shape: tuple[int, ...]
14
+ stride: tuple[int, ...]
15
+
16
+ @classmethod
17
+ def from_tensor(cls, tensor: torch.Tensor):
18
+ return cls(
19
+ tensor.data_ptr(),
20
+ tensor.dtype,
21
+ tensor.shape,
22
+ tensor.stride(),
23
+ )
spaces/zero/tqdm.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from multiprocessing.synchronize import RLock as MultiprocessingRLock
5
+
6
+
7
+ try:
8
+ from tqdm import tqdm as _tqdm
9
+ except ImportError: # pragma: no cover
10
+ _tqdm = None
11
+
12
+
13
+ def remove_tqdm_multiprocessing_lock():
14
+ if _tqdm is None: # pragma: no cover
15
+ return
16
+ tqdm_lock = _tqdm.get_lock()
17
+ assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
18
+ tqdm_lock.locks = [
19
+ lock for lock in tqdm_lock.locks
20
+ if not isinstance(lock, MultiprocessingRLock)
21
+ ]
22
+
23
+
24
+ tqdm = _tqdm
spaces/zero/types.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+
6
+ from dataclasses import dataclass
7
+ from datetime import timedelta
8
+ from typing import Any
9
+ from typing import Dict
10
+ from typing import Tuple
11
+ from typing import TypedDict
12
+ from typing_extensions import Callable
13
+ from typing_extensions import Generic
14
+ from typing_extensions import ParamSpec
15
+ from typing_extensions import TypeAlias
16
+ from typing_extensions import TypeVar
17
+
18
+
19
+ Params = Tuple[Tuple[object, ...], Dict[str, Any]]
20
+ Res = TypeVar('Res')
21
+ Param = ParamSpec('Param')
22
+
23
+ class EmptyKwargs(TypedDict):
24
+ pass
25
+
26
+ @dataclass
27
+ class OkResult(Generic[Res]):
28
+ value: Res
29
+ @dataclass
30
+ class ExceptionResult:
31
+ value: Exception
32
+ @dataclass
33
+ class AbortedResult:
34
+ pass
35
+ @dataclass
36
+ class EndResult:
37
+ pass
38
+ @dataclass
39
+ class GradioQueueEvent:
40
+ method_name: str
41
+ args: tuple[Any, ...]
42
+ kwargs: dict[str, Any]
43
+
44
+ RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
45
+ GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
46
+ YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
47
+
48
+ Duration: TypeAlias = "int | timedelta"
49
+ DynamicDuration: TypeAlias = "Duration | Callable[Param, Duration] | None"
spaces/zero/wrappers.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import multiprocessing
6
+ import os
7
+ import signal
8
+ import time
9
+ import traceback
10
+ import warnings
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from contextlib import nullcontext
13
+ from contextvars import copy_context
14
+ from datetime import timedelta
15
+ from functools import partial
16
+ from functools import wraps
17
+ from multiprocessing.context import ForkProcess
18
+ from pickle import PicklingError
19
+ from queue import Empty
20
+ from queue import Queue as ThreadQueue
21
+ from threading import Thread
22
+ from typing import TYPE_CHECKING
23
+ from typing import Callable
24
+ from typing import Generator
25
+ from typing import Generic
26
+ from typing_extensions import assert_never
27
+
28
+ import psutil
29
+
30
+ from ..config import Config
31
+ from ..utils import debug
32
+ from ..utils import drop_params
33
+ from ..utils import gradio_request_var
34
+ from ..utils import SimpleQueue as Queue
35
+ from . import client
36
+ from . import torch
37
+ from .api import AllowToken
38
+ from .api import NvidiaIndex
39
+ from .api import NvidiaUUID
40
+ from .gradio import GradioPartialContext
41
+ from .gradio import get_server_port
42
+ from .gradio import patch_gradio_queue
43
+ from .gradio import try_process_queue_event
44
+ from .tqdm import remove_tqdm_multiprocessing_lock
45
+ from .tqdm import tqdm
46
+ from .types import * # TODO: Please don't do that
47
+
48
+
49
+ GENERATOR_GLOBAL_TIMEOUT = 20 * 60
50
+
51
+ SPAWN_PROGRESS_CLEANUP = 0.1
52
+ SPAWN_PROGRESS_INIT = 0.1
53
+
54
+
55
+ Process = multiprocessing.get_context('fork').Process
56
+ forked = False
57
+
58
+
59
+ class Worker(Generic[Res]):
60
+ process: ForkProcess
61
+ arg_queue: Queue[tuple[Params, GradioPartialContext]]
62
+ res_queue: Queue[Res | None]
63
+ _sentinel: Thread
64
+
65
+ def __init__(
66
+ self,
67
+ target: Callable[[
68
+ Queue[tuple[Params, GradioPartialContext]],
69
+ Queue[Res | None],
70
+ AllowToken,
71
+ NvidiaUUID,
72
+ list[int],
73
+ ], None],
74
+ allow_token: str,
75
+ nvidia_uuid: str,
76
+ ):
77
+ self._sentinel = Thread(target=self._close_on_exit, daemon=True)
78
+ self.arg_queue = Queue()
79
+ self.res_queue = Queue()
80
+ debug(f"{self.arg_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
81
+ debug(f"{self.res_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
82
+ if (server_port := get_server_port()) is not None:
83
+ fds = [c.fd for c in psutil.Process().connections() if c.laddr.port == server_port]
84
+ debug(f"{fds=}")
85
+ else:
86
+ warnings.warn("Using a ZeroGPU function outside of Gradio caching or request might block the app")
87
+ fds = []
88
+ args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
89
+ if TYPE_CHECKING:
90
+ target(*args)
91
+ self.process = Process(
92
+ target=target,
93
+ args=args,
94
+ daemon=True,
95
+ )
96
+ self.process.start()
97
+ self._sentinel.start()
98
+
99
+ def _close_on_exit(self):
100
+ self.process.join()
101
+ self.arg_queue.close()
102
+ self.res_queue.wlock_release()
103
+ self.res_queue.put(None)
104
+
105
+
106
+ def worker_init(
107
+ res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
108
+ allow_token: str,
109
+ nvidia_uuid: str,
110
+ fds: list[int],
111
+ ) -> None | ExceptionResult:
112
+ # Immediately close file descriptors
113
+ for fd in fds:
114
+ try:
115
+ os.close(fd)
116
+ except Exception as e: # pragma: no cover
117
+ if isinstance(e, OSError) and e.errno == 9:
118
+ continue
119
+ traceback.print_exc()
120
+ return ExceptionResult(e)
121
+ progress = nullcontext()
122
+ if tqdm is not None and Config.zero_gpu_v2:
123
+ progress = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w'))
124
+ try: # Unrecoverable init part
125
+ patch_gradio_queue(res_queue)
126
+ with progress as progress:
127
+ current_progress = 0 # Gradio does not support float progress updates
128
+ def update(n: float):
129
+ nonlocal current_progress
130
+ current_progress += n
131
+ if progress is not None:
132
+ progress.update(round(current_progress * 100) - progress.n)
133
+ t0 = time.perf_counter()
134
+ client.allow(allow_token)
135
+ print("client.allow", (dt := time.perf_counter() - t0)); t0 = dt
136
+ update(SPAWN_PROGRESS_CLEANUP)
137
+ torch.unpatch()
138
+ print("torch.unpatch", (dt := time.perf_counter() - t0)); t0 = dt
139
+ torch.init(nvidia_uuid)
140
+ print("torch.init", (dt := time.perf_counter() - t0)); t0 = dt
141
+ update(SPAWN_PROGRESS_INIT)
142
+ callback = None
143
+ if (transfer_size := torch.size()) > 0:
144
+ remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
145
+ callback = lambda n: update(n * remaining / transfer_size)
146
+ torch.move(callback=callback)
147
+ print("torch.move", (dt := time.perf_counter() - t0)); t0 = dt
148
+ except Exception as e: # pragma: no cover
149
+ traceback.print_exc()
150
+ return ExceptionResult(e)
151
+ try:
152
+ remove_tqdm_multiprocessing_lock()
153
+ except Exception: # pragma: no cover
154
+ print("Error while trying to remove tqdm mp_lock:")
155
+ traceback.print_exc()
156
+
157
+
158
+ def process_duration(duration: Duration | None):
159
+ if duration is None or isinstance(duration, timedelta):
160
+ return duration
161
+ return timedelta(seconds=duration)
162
+
163
+
164
+ def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs):
165
+ if not callable(duration):
166
+ return duration
167
+ return duration(*args, **kwargs)
168
+
169
+
170
+ def regular_function_wrapper(
171
+ task: Callable[Param, Res],
172
+ duration: DynamicDuration[Param],
173
+ ) -> Callable[Param, Res]:
174
+
175
+ import gradio as gr
176
+
177
+ request_var = gradio_request_var()
178
+ workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
179
+ task_id = id(task)
180
+
181
+ @wraps(task)
182
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
183
+
184
+ if forked:
185
+ return task(*args, **kwargs)
186
+
187
+ request = request_var.get()
188
+ duration_ = static_duration(duration, *args, **kwargs)
189
+ duration_ = process_duration(duration_)
190
+ t0 = time.perf_counter()
191
+ schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
192
+ print("client.schedule", time.perf_counter() - t0)
193
+ allow_token = schedule_response.allowToken
194
+ nvidia_index = schedule_response.nvidiaIndex
195
+ nvidia_uuid = schedule_response.nvidiaUUID
196
+ release = partial(client.release, allow_token)
197
+
198
+ try:
199
+ worker = workers.pop(nvidia_index)
200
+ except KeyError:
201
+ worker = None
202
+
203
+ if worker is not None and worker.process.is_alive() and schedule_response.idle:
204
+ assert worker.arg_queue.empty()
205
+ assert worker.res_queue.empty()
206
+ else:
207
+ worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
208
+
209
+ try:
210
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
211
+ except PicklingError: # TODO: detailed serialization diagnostic
212
+ release(fail=True)
213
+ raise
214
+
215
+ while True:
216
+ res = worker.res_queue.get()
217
+ if res is None:
218
+ release(fail=True, allow_404=True)
219
+ raise gr.Error("GPU task aborted")
220
+ if isinstance(res, ExceptionResult):
221
+ release(fail=True)
222
+ raise res.value
223
+ if isinstance(res, OkResult):
224
+ t0 = time.perf_counter()
225
+ release()
226
+ print("client.release", time.perf_counter() - t0)
227
+ workers[nvidia_index] = worker
228
+ return res.value
229
+ if isinstance(res, GradioQueueEvent):
230
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
231
+ continue
232
+ assert_never(res)
233
+
234
+
235
+ def thread_wrapper(
236
+ arg_queue: Queue[tuple[Params, GradioPartialContext]],
237
+ res_queue: Queue[RegularResQueueResult[Res] | None],
238
+ allow_token: str,
239
+ nvidia_uuid: str,
240
+ fds: list[int],
241
+ ):
242
+ global forked
243
+ forked = True
244
+ signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
245
+ initialized = False
246
+ while True:
247
+ try:
248
+ (args, kwargs), gradio_context = arg_queue.get()
249
+ except OSError:
250
+ break
251
+ if not initialized:
252
+ t0 = time.perf_counter()
253
+ if (res := worker_init(
254
+ res_queue=res_queue,
255
+ allow_token=allow_token,
256
+ nvidia_uuid=nvidia_uuid,
257
+ fds=fds,
258
+ )) is not None:
259
+ res_queue.put(res)
260
+ return
261
+ print("worker_init", time.perf_counter() - t0)
262
+ initialized = True
263
+ GradioPartialContext.apply(gradio_context)
264
+ context = copy_context()
265
+ with ThreadPoolExecutor() as executor:
266
+ future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
267
+ try:
268
+ res = future.result()
269
+ except Exception as e:
270
+ traceback.print_exc()
271
+ res = ExceptionResult(e)
272
+ else:
273
+ res = OkResult(res)
274
+ try:
275
+ res_queue.put(res)
276
+ except PicklingError as e:
277
+ res_queue.put(ExceptionResult(e))
278
+
279
+ # https://github.com/python/cpython/issues/91002
280
+ if not hasattr(task, '__annotations__'):
281
+ gradio_handler.__annotations__ = {}
282
+
283
+ return gradio_handler
284
+
285
+
286
+ def generator_function_wrapper(
287
+ task: Callable[Param, Generator[Res, None, None]],
288
+ duration: DynamicDuration[Param],
289
+ ) -> Callable[Param, Generator[Res, None, None]]:
290
+
291
+ import gradio as gr
292
+
293
+ request_var = gradio_request_var()
294
+ workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
295
+ task_id = id(task)
296
+
297
+ @wraps(task)
298
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
299
+
300
+ if forked:
301
+ yield from task(*args, **kwargs)
302
+ return
303
+
304
+ request = request_var.get()
305
+ duration_ = static_duration(duration, *args, **kwargs)
306
+ duration_ = process_duration(duration_)
307
+ schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
308
+ allow_token = schedule_response.allowToken
309
+ nvidia_index = schedule_response.nvidiaIndex
310
+ nvidia_uuid = schedule_response.nvidiaUUID
311
+ release = partial(client.release, allow_token)
312
+
313
+ try:
314
+ worker = workers.pop(nvidia_index)
315
+ except KeyError:
316
+ worker = None
317
+
318
+ if worker is not None and worker.process.is_alive() and schedule_response.idle:
319
+ assert worker.arg_queue.empty()
320
+ assert worker.res_queue.empty()
321
+ else:
322
+ worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
323
+
324
+ try:
325
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
326
+ except PicklingError: # TODO: detailed serialization diagnostic
327
+ release(fail=True)
328
+ raise
329
+
330
+ yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
331
+ def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
332
+ while True:
333
+ res = worker.res_queue.get()
334
+ if res is None:
335
+ release(fail=True, allow_404=True)
336
+ yield_queue.put(AbortedResult())
337
+ return
338
+ if isinstance(res, ExceptionResult):
339
+ release(fail=True)
340
+ yield_queue.put(ExceptionResult(res.value))
341
+ return
342
+ if isinstance(res, EndResult):
343
+ release()
344
+ workers[nvidia_index] = worker
345
+ yield_queue.put(EndResult())
346
+ return
347
+ if isinstance(res, OkResult):
348
+ yield_queue.put(OkResult(res.value))
349
+ continue
350
+ if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
351
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
352
+ continue
353
+ debug(f"fill_yield_queue: assert_never({res=})")
354
+ assert_never(res)
355
+ from typing_extensions import assert_never
356
+ with ThreadPoolExecutor() as e:
357
+ f = e.submit(copy_context().run, fill_yield_queue, worker)
358
+ f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
359
+ while True:
360
+ try:
361
+ res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
362
+ except Empty: # pragma: no cover
363
+ debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
364
+ raise
365
+ if isinstance(res, AbortedResult):
366
+ raise gr.Error("GPU task aborted")
367
+ if isinstance(res, ExceptionResult):
368
+ raise res.value
369
+ if isinstance(res, EndResult):
370
+ break
371
+ if isinstance(res, OkResult):
372
+ yield res.value
373
+ continue
374
+ debug(f"gradio_handler: assert_never({res=})")
375
+ assert_never(res)
376
+
377
+
378
+ def thread_wrapper(
379
+ arg_queue: Queue[tuple[Params, GradioPartialContext]],
380
+ res_queue: Queue[GeneratorResQueueResult[Res] | None],
381
+ allow_token: str,
382
+ nvidia_uuid: str,
383
+ fds: list[int],
384
+ ):
385
+ global forked
386
+ forked = True
387
+ signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
388
+ initialized = False
389
+ while True:
390
+ try:
391
+ (args, kwargs), gradio_context = arg_queue.get()
392
+ except OSError:
393
+ break
394
+ if not initialized:
395
+ if (res := worker_init(
396
+ res_queue=res_queue,
397
+ allow_token=allow_token,
398
+ nvidia_uuid=nvidia_uuid,
399
+ fds=fds,
400
+ )) is not None:
401
+ res_queue.put(res)
402
+ return
403
+ initialized = True
404
+ def iterate():
405
+ gen = task(*args, **kwargs) # type: ignore
406
+ while True:
407
+ try:
408
+ res = next(gen)
409
+ except StopIteration:
410
+ break
411
+ except Exception as e:
412
+ res_queue.put(ExceptionResult(e))
413
+ break
414
+ try:
415
+ res_queue.put(OkResult(res))
416
+ except PicklingError as e:
417
+ res_queue.put(ExceptionResult(e))
418
+ break
419
+ else:
420
+ continue
421
+ GradioPartialContext.apply(gradio_context)
422
+ with ThreadPoolExecutor() as executor:
423
+ executor.submit(copy_context().run, iterate)
424
+ res_queue.put(EndResult())
425
+
426
+ # https://github.com/python/cpython/issues/91002
427
+ if not hasattr(task, '__annotations__'):
428
+ gradio_handler.__annotations__ = {}
429
+
430
+ return gradio_handler