File size: 16,801 Bytes
f6228f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
# Ultralytics YOLO 🚀, AGPL-3.0 license

import shutil
import threading
import time
from http import HTTPStatus
from pathlib import Path
from urllib.parse import parse_qs, urlparse

import requests

from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX, TQDM
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, __version__, checks, emojis
from ultralytics.utils.errors import HUBModelError

# Agent name string: encodes the package version and whether the code runs in Colab or locally
AGENT_NAME = f"python-{__version__}-{'colab' if IS_COLAB else 'local'}"


class HUBTrainingSession:
    """

    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.



    Attributes:

        model_id (str): Identifier for the YOLO model being trained.

        model_url (str): URL for the model in Ultralytics HUB.

        rate_limits (dict): Rate limits for different API calls (in seconds).

        timers (dict): Timers for rate limiting.

        metrics_queue (dict): Queue for the model's metrics.

        model (dict): Model data fetched from Ultralytics HUB.

    """

    def __init__(self, identifier):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Parses the identifier, builds a HUB client from the API key found in the identifier or in SETTINGS, and then
        either loads the referenced HUB model or prepares an empty model object. Failures while loading are
        deliberately not propagated (best effort): they are logged at debug level, and a login hint is logged when
        the identifier is a HUB model URL and the client is not authenticated.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session.
                It can be a HUB model URL (optionally carrying an `api_key` query parameter) or a local filename
                ending in '.pt' or '.yaml'.

        Raises:
            HUBModelError: If the identifier format is not recognized.
            ModuleNotFoundError: If hub-sdk package is not installed.
        """
        from hub_sdk import HUBClient

        self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.metrics_upload_failed_queue = {}  # holds metrics for each epoch if upload failed
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py
        self.model = None
        self.model_url = None
        self.model_file = None
        self.train_args = None

        # Parse input
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials: a key embedded in the identifier takes precedence over the stored one
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

        # Load models
        try:
            if model_id:
                self.load_model(model_id)  # load existing model
            else:
                self.model = self.client.model()  # load empty model
        except Exception as e:
            # Best effort: keep the session object alive, but leave a trace of what went wrong
            LOGGER.debug(f"{PREFIX}Model initialization failed: {e}")
            if identifier.startswith(f"{HUB_WEB_ROOT}/models/") and not self.client.authenticated:
                LOGGER.warning(
                    f"{PREFIX}WARNING ⚠️ Please log in using 'yolo login API_KEY'. "
                    "You can find your API Key at: https://hub.ultralytics.com/settings?tab=api+keys."
                )

    @classmethod
    def create_session(cls, identifier, args=None):
        """Class method to create an authenticated HUBTrainingSession or return None."""
        try:
            session = cls(identifier)
            if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"):  # not a HUB model URL
                session.create_model(args)
                assert session.model.id, "HUB model not loaded correctly"
            return session
        # PermissionError and ModuleNotFoundError indicate hub-sdk not installed
        except (PermissionError, ModuleNotFoundError, AssertionError):
            return None

    def load_model(self, model_id):
        """
        Load an existing model from Ultralytics HUB using the provided model identifier.

        For a model that is already trained, the URL of its best weights is resolved to a file through
        `checks.check_file` and nothing else is set up. Otherwise the training arguments are prepared and heartbeats
        are started.

        Args:
            model_id (str): Identifier of the HUB model to load.

        Raises:
            ValueError: If the client returns no data for the model.
        """
        self.model = self.client.model(model_id)
        if not self.model.data:  # then model does not exist
            raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
        if self.model.is_trained():
            LOGGER.info(f"{PREFIX}Loading trained HUB model {self.model_url} 🚀")
            url = self.model.get_weights_url("best")  # download URL with auth
            self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
            return

        # Set training args and start heartbeats for HUB to monitor agent
        self._set_train_args()
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args):
        """Initializes a HUB training session with the specified model identifier."""
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),
                "epochs": model_args.get("epochs", 300),
                "imageSize": model_args.get("imgsz", 640),
                "patience": model_args.get("patience", 100),
                "device": str(model_args.get("device", "")),  # convert None to string
                "cache": str(model_args.get("cache", "ram")),  # convert True, False, None to string
            },
            "dataset": {"name": model_args.get("data")},
            "lineage": {
                "architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
                "parent": {},
            },
            "meta": {"name": self.filename},
        }

        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return None

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} ๐Ÿš€")

    @staticmethod
    def _parse_identifier(identifier):
        """

        Parses the given identifier to determine the type of identifier and extract relevant components.



        The method supports different identifier formats:

            - A HUB model URL https://hub.ultralytics.com/models/MODEL

            - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY

            - A local filename that ends with '.pt' or '.yaml'



        Args:

            identifier (str): The identifier string to be parsed.



        Returns:

            (tuple): A tuple containing the API key, model ID, and filename as applicable.



        Raises:

            HUBModelError: If the identifier format is not recognized.

        """
        api_key, model_id, filename = None, None, None
        if Path(identifier).suffix in {".pt", ".yaml"}:
            filename = identifier
        elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            parsed_url = urlparse(identifier)
            model_id = Path(parsed_url.path).stem  # handle possible final backslash robustly
            query_params = parse_qs(parsed_url.query)  # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
            api_key = query_params.get("api_key", [None])[0]
        else:
            raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
        return api_key, model_id, filename

    def _set_train_args(self):
        """

        Initializes training arguments and creates a model entry on the Ultralytics HUB.



        This method sets up training arguments based on the model's state and updates them with any additional

        arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,

        or requires specific file setup.



        Raises:

            ValueError: If the model is already trained, if required dataset information is missing, or if there are

                issues with the provided training arguments.

        """
        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            self.train_args = self.model.data.get("train_args")  # new response

            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

        if "data" not in self.train_args:
            # RF bug - datasets are sometimes not exported
            raise ValueError("Dataset may still be processing. Please wait a minute and try again.")

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

    def request_queue(

        self,

        request_func,

        retry=3,

        timeout=30,

        thread=True,

        verbose=True,

        progress_total=None,

        stream_response=None,

        *args,

        **kwargs,

    ):
        """Attempts to execute `request_func` with retries, timeout handling, optional threading, and progress."""

        def retry_request():
            """Attempts to call `request_func` with retries, timeout, and optional threading."""
            t0 = time.time()  # Record the start time for the timeout
            response = None
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if response is None:
                    LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                    time.sleep(2**i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                if progress_total:
                    self._show_upload_progress(progress_total, response)
                elif stream_response:
                    self._iterate_content(response)

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    # if request related to metrics upload
                    if kwargs.get("metrics"):
                        self.metrics_upload_failed_queue = {}
                    return response  # Success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
                    LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
                    break  # Not an error that should be retried, exit loop

                time.sleep(2**i)  # Exponential backoff for retries

            # if request related to metrics upload and exceed retries
            if response is None and kwargs.get("metrics"):
                self.metrics_upload_failed_queue.update(kwargs.get("metrics"))

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    @staticmethod
    def _should_retry(status_code):
        """Determines if a request should be retried based on the HTTP status code."""
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes

    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
        """

        Generate a retry message based on the response status code.



        Args:

            response: The HTTP response object.

            retry: The number of retry attempts allowed.

            timeout: The maximum timeout duration.



        Returns:

            (str): The retry message.

        """
        if self._should_retry(response.status_code):
            return f"Retrying {retry}x for {timeout}s." if retry else ""
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (
                f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                f"Please retry after {headers['Retry-After']}s."
            )
        else:
            try:
                return response.json().get("message", "No JSON message.")
            except AttributeError:
                return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def upload_model(

        self,

        epoch: int,

        weights: str,

        is_best: bool = False,

        map: float = 0.0,

        final: bool = False,

    ) -> None:
        """

        Upload a model checkpoint to Ultralytics HUB.



        Args:

            epoch (int): The current training epoch.

            weights (str): Path to the model weights file.

            is_best (bool): Indicates if the current model is the best one so far.

            map (float): Mean average precision of the model.

            final (bool): Indicates if the model is the final model after training.

        """
        weights = Path(weights)
        if not weights.is_file():
            last = weights.with_name(f"last{weights.suffix}")
            if final and last.is_file():
                LOGGER.warning(
                    f"{PREFIX} WARNING โš ๏ธ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "
                    "This often happens when resuming training in transient environments like Google Colab. "
                    "For more reliable training, consider using Ultralytics HUB Cloud. "
                    "Learn more at https://docs.ultralytics.com/hub/cloud-training."
                )
                shutil.copy(last, weights)  # copy last.pt to best.pt
            else:
                LOGGER.warning(f"{PREFIX} WARNING โš ๏ธ Model upload issue. Missing model {weights}.")
                return

        self.request_queue(
            self.model.upload_model,
            epoch=epoch,
            weights=str(weights),
            is_best=is_best,
            map=map,
            final=final,
            retry=10,
            timeout=3600,
            thread=not final,
            progress_total=weights.stat().st_size if final else None,  # only show progress if final
            stream_response=True,
        )

    @staticmethod
    def _show_upload_progress(content_length: int, response: requests.Response) -> None:
        """
        Advance a byte-scaled progress bar while the content of `response` is read chunk by chunk.

        Args:
            content_length (int): Total number of bytes the progress bar represents.
            response (requests.Response): Response object whose content is iterated in 1024-byte chunks.

        Returns:
            None
        """
        progress = TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024)
        with progress as pbar:
            for chunk in response.iter_content(chunk_size=1024):
                pbar.update(len(chunk))

    @staticmethod
    def _iterate_content(response: requests.Response) -> None:
        """

        Process the streamed HTTP response data.



        Args:

            response (requests.Response): The response object from the file download request.



        Returns:

            None

        """
        for _ in response.iter_content(chunk_size=1024):
            pass  # Do nothing with data chunks