AImused committed on
Commit 6b28dbc · verified · 1 parent: 76494b7

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,36 @@
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     DEBIAN_FRONTEND=noninteractive \
+     CUDA_HOME=/usr/local/cuda \
+     PATH=/usr/local/cuda/bin:$PATH \
+     LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
+     NVIDIA_VISIBLE_DEVICES=all \
+     NVIDIA_DRIVER_CAPABILITIES=compute,utility
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     python3 \
+     python3-pip \
+     python3-dev \
+     build-essential \
+     ffmpeg \
+     libsndfile1 \
+     curl \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Upgrade pip and install build tools
+ RUN python3 -m pip install --upgrade pip setuptools wheel
+
+ WORKDIR /app
+
+ COPY . .
+
+ # Install requirements
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ EXPOSE 8000
+
+ CMD ["python3", "server.py"]
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ license: mit
+ tags:
+ - any-to-any
+ - omega
+ - omegalabs
+ - bittensor
+ - agi
+ ---
+
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
+
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
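For reference, a minimal sketch of pulling this checkpoint folder locally with the pinned huggingface-hub package from requirements.txt; the repo id below is a placeholder, not the actual repository name.

# Hypothetical repo id -- substitute the real repository name for this checkpoint.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="<org>/<omega-a2a-checkpoint>")
print(local_dir)  # folder containing models/, server.py, training_config.yml, requirements.txt, ...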
hotkey.txt ADDED
@@ -0,0 +1 @@
+ 5HEWNpoj22h12CoM4Bue3TyQV9X5ayHcbdG8qwennaqyPw3p
inference.py ADDED
@@ -0,0 +1,297 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ import itertools
+ import sys
+ import time
+ from typing import Any, Dict, List
+
+ import torch
+ from torch import nn
+ from omegaconf import DictConfig
+ from PIL import Image
+
+ from torchtune import config, utils
+ from torchtune.utils._generation import sample
+ from torchtune.models import convert_weights
+ from torchtune.data import Message
+
+ from models.tokenizer import START_IMAGE, END_IMAGE, START_AUDIO, END_AUDIO, START_VIDEO, END_VIDEO
+ from imagebind.models.imagebind_model import ModalityType
+ from diffusers import DiffusionPipeline
+
+ from models import add_proj_convert_weights, _BASE_TRAINABLE
+ import os
+
+ log = utils.get_logger("DEBUG")
+ add_proj_convert_weights()
+
+
+ class InferenceRecipe:
+     """
+     Recipe for generating tokens from a dense Transformer-based LLM.
+
+     Currently this recipe supports single-GPU generation only. Speculative
+     decoding is not supported.
+
+     For more details on how to use this recipe for generation, please see our
+     tutorial: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#generation
+
+     For using this recipe with a quantized model, please see the following section of
+     the above tutorial:
+     https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#speeding-up-generation-using-quantization
+     """
+
+     def __init__(self, cfg: DictConfig) -> None:
+         self._device = utils.get_device(device=cfg.device)
+         self._dtype = utils.get_dtype(dtype=cfg.dtype)
+         self._quantizer = config.instantiate(cfg.inference.quantizer)
+         self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
+         self.prompt_template = cfg.inference.prompt_template
+         perception_tokens = cfg.model.perception_tokens
+         self._perception_tokens = ("0 " * perception_tokens)[:perception_tokens]
+         utils.set_seed(seed=cfg.seed)
+
+     def setup(self, cfg: DictConfig) -> None:
+         checkpointer = config.instantiate(cfg.checkpointer)
+         if self._quantization_mode is None:
+             ckpt_dict = checkpointer.load_checkpoint()
+         else:
+             # weights_only needs to be False when loading a quantized model
+             # currently loading a quantized model is only supported with the
+             # FullModelTorchTuneCheckpointer
+             ckpt_dict = checkpointer.load_checkpoint(weights_only=False)
+
+         self._model = self._setup_model(
+             model_cfg=cfg.model,
+             model_state_dict=ckpt_dict[utils.MODEL_KEY],
+         )
+         with self._device:
+             self._model.setup_caches(max_batch_size=cfg.batch_size, dtype=self._dtype)
+
+         self._tokenizer = config.instantiate(cfg.tokenizer)
+         self._mm_ids_start = self._tokenizer.encode(START_IMAGE + START_AUDIO + START_VIDEO, add_eos=False, add_bos=False)
+         self._mm_ids_end = self._tokenizer.encode(END_IMAGE + END_AUDIO + END_VIDEO, add_eos=False, add_bos=False)
+         self.use_clip = cfg.model.use_clip
+         if self.use_clip:
+             self._clip_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=self._dtype).to(self._device)
+
+     def _setup_model(
+         self,
+         model_cfg: DictConfig,
+         model_state_dict: Dict[str, Any],
+     ) -> nn.Module:
+         with utils.set_default_dtype(self._dtype), self._device:
+             model = config.instantiate(model_cfg)
+
+         if self._quantization_mode is not None:
+             model = self._quantizer.quantize(model)
+             model = model.to(device=self._device, dtype=self._dtype)
+
+         model.load_state_dict(model_state_dict)
+
+         # Validate model was loaded in with the expected dtype.
+         utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
+         log.debug(f"Model is initialized with precision {self._dtype}.")
+
+         return model
+
+     def mm_process_prompt(self, prompt):
+         return (
+             prompt
+             .replace("{image}", f"{START_IMAGE}{self._perception_tokens}{END_IMAGE}")
+             .replace("{audio}", f"{START_AUDIO}{self._perception_tokens}{END_AUDIO}")
+             .replace("{video}", f"{START_VIDEO}{self._perception_tokens}{END_VIDEO}")
+         )
+
+     def extract_mm_context(self, video_ib_embed, tokens):
+         context = {}
+         in_mm_embed = False
+         for idx, tok in enumerate(tokens):
+             in_mm_embed = in_mm_embed and tok not in self._mm_ids_end
+             if in_mm_embed:
+                 # tokens[idx]  # to support multiple embeds: get the value, match it up with the sample embed
+                 context[idx] = {
+                     "ib_embed": video_ib_embed.to(dtype=self._dtype, device=self._device),
+                 }
+             in_mm_embed = in_mm_embed or tok in self._mm_ids_start
+         return context
+
+     @torch.no_grad()
+     def generate(self, cfg: DictConfig, video_ib_embed: List[float]):
+         messages = [
+             Message(
+                 role="user",
+                 content=self.mm_process_prompt(self.prompt_template),
+             ),
+             Message(
+                 role="assistant",
+                 content="",
+             )
+         ]
+         tokens, mask = self._tokenizer.tokenize_messages(messages)
+         tokens = tokens[:-2]  # strip eot and eos
+         mm_context = [self.extract_mm_context(video_ib_embed, tokens)]  # context should be a list, batch-id indexed
+         prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)
+
+         self._model.tok_embeddings.set_context(mm_context)
+         self._model.output.set_context(mm_context)
+
+         bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
+         allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
+         disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))
+         # self._model.output.weight.data[disallowed_tokens, :] = 0
+
+         def custom_generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
+             model.tok_embeddings.set_context([])
+             model.output.set_context([])
+             # x: [1, s]
+             # input_pos: [s]
+             logits = model(x, input_pos=input_pos)
+             # logits: [1, s, v] where v is vocab_size
+             # for sampling we extract the logits for the
+             # last token and convert to shape: [v]
+             logits = logits[0, -1]
+             # logits[disallowed_tokens] = float("-inf")
+             # sample the next token
+             token = sample(logits, temperature, top_k)
+             if token in disallowed_tokens:
+                 return torch.tensor([self._tokenizer.eos_id]).to(x)
+             return token
+
+         # since quantized model uses torch.compile to get speedup, it needs a warm up / prefill run
+         # to get the accurate performance measurement
+         if self._quantization_mode is not None:
+             log.info("Starting compilation to improve generation performance ...")
+             custom_generate_next_token = torch.compile(
+                 custom_generate_next_token, mode="max-autotune", fullgraph=True
+             )
+             t0 = time.perf_counter()
+             _ = utils.generate(
+                 model=self._model,
+                 prompt=prompt,
+                 max_generated_tokens=2,
+                 temperature=cfg.temperature,
+                 top_k=cfg.top_k,
+                 eos_id=self._tokenizer.eos_id,
+                 custom_generate_next_token=custom_generate_next_token,
+             )
+             t = time.perf_counter() - t0
+             log.info(f"Warmup run for quantized model takes: {t:.02f} sec")
+
+         t0 = time.perf_counter()
+         generated_tokens = utils.generate(
+             model=self._model,
+             prompt=prompt,
+             max_generated_tokens=cfg.max_new_tokens,
+             temperature=cfg.temperature,
+             top_k=cfg.top_k,
+             eos_id=self._tokenizer.eos_id,
+             custom_generate_next_token=custom_generate_next_token,
+         )
+         t = time.perf_counter() - t0
+
+         cleaned_tokens = [t for t in generated_tokens[len(prompt):] if t not in disallowed_tokens + allowed_id]
+         caption = self._tokenizer.decode(cleaned_tokens)
+
+         # log.debug(f"Generated caption: {caption} in {t:.02f} sec")
+
+         return caption
+
+     @torch.no_grad()
+     def generate_batch(self, cfg: DictConfig, video_ib_embed: torch.Tensor):
+         log.info(f"inside generate_batch, video_ib_embed shape: {video_ib_embed.shape}")
+         batch_dim = video_ib_embed.size(0)
+         messages = [
+             Message(
+                 role="user",
+                 content=self.mm_process_prompt(self.prompt_template),
+             ),
+             Message(role="assistant", content="")
+         ]
+         tokens, mask = self._tokenizer.tokenize_messages(messages)
+         tokens = tokens[:-2]  # strip eot and eos
+         mm_context = [self.extract_mm_context(e, tokens) for e in video_ib_embed]  # context should be a list, batch-id indexed
+         prompt = torch.tensor(tokens, dtype=torch.int, device=self._device).expand(batch_dim, -1).clone()
+         prompt_length = prompt.size(1)
+
+         self._model.tok_embeddings.set_context(mm_context)
+         self._model.output.set_context(mm_context)
+
+         bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
+         allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
+         disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))
+
+         def generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
+             # x: [B, s]
+             # input_pos: [s]
+             # logits: [B, s, v] where v is vocab_size
+             logits = model(x, input_pos=input_pos)[:, -1]
+             tokens = sample(logits, temperature, top_k)
+             return torch.tensor([
+                 [self._tokenizer.eos_id if t in disallowed_tokens else t for t in toks]
+                 for toks in tokens
+             ]).to(x.device)
+
+         generated_tokens = prompt.clone()
+         # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
+         stop_token_reached = torch.zeros(batch_dim, dtype=torch.bool, device=prompt.device)
+
+         # generate the first tokens conditioned on the prompt
+         tokens = generate_next_token(
+             self._model,
+             input_pos=torch.arange(0, prompt_length, device=prompt.device),
+             x=prompt,
+             temperature=cfg.temperature,
+             top_k=cfg.top_k,
+         )
+         eot_reached_b = tokens == self._tokenizer.eot_id
+         generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
+
+         self._model.tok_embeddings.set_context([])
+         self._model.output.set_context([])
+
+         input_pos = torch.tensor([prompt_length], device=prompt.device)
+         for _ in range(cfg.max_new_tokens - 1):
+             tokens = generate_next_token(
+                 self._model, input_pos=input_pos, x=tokens, temperature=cfg.temperature, top_k=cfg.top_k
+             )
+             eot_reached_b |= tokens == self._tokenizer.eot_id
+             tokens *= ~eot_reached_b
+             generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
+             if eot_reached_b.all():
+                 print('eot_reached_b.all()')
+                 break
+             input_pos += 1
+
+         captions = []
+         for caption_tokens in generated_tokens.tolist():
+             captions.append(self._tokenizer.decode(caption_tokens[prompt.size(1):]))
+         return captions
+
+
+ @config.parse
+ def main(cfg: DictConfig) -> None:
+     config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
+     cfg.model = DictConfig({
+         "_component_": "models.mmllama3_8b",
+         "use_clip": False,
+         "perception_tokens": cfg.model.perception_tokens,
+     })
+     cfg.batch_size = 4
+     cfg.checkpointer.checkpoint_dir = os.path.dirname("/home/salman/tezuesh/omegalabs-anytoany-bittensor/sandboxing/cache/xzistance_omega-a2a-hotkey/meta_model_0.pth")
+
+     cfg.checkpointer.checkpoint_files = ["models/meta_model_0.pt"]
+     cfg.inference.max_new_tokens = 300
+     cfg.tokenizer.path = "./models/tokenizer.model"
+     inference_recipe = InferenceRecipe(cfg)
+     inference_recipe.setup(cfg=cfg)
+     captions = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=torch.randn(4, 1024))
+     print(captions)
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
models/__init__.py ADDED
@@ -0,0 +1,34 @@
+ from torchtune.models import convert_weights
+
+ from models.tokenizer import a2a_tokenizer
+ from models.mmllama3 import lora_mmllama3_8b, mmllama3_8b, imagebind_huge
+
+ __all__ = [
+     "a2a_tokenizer",
+     "lora_mmllama3_8b",
+     "mmllama3_8b",
+     "imagebind_huge",
+ ]
+
+ _BASE_TRAINABLE = [
+     "tok_embeddings.proj_to_llama.0.weight",
+     "tok_embeddings.proj_to_llama.0.bias",
+     "tok_embeddings.proj_to_llama.2.weight",
+     "tok_embeddings.proj_to_llama.2.bias",
+     "tok_embeddings.proj_to_llama.3.weight",
+     "tok_embeddings.proj_to_llama.3.bias",
+     "output.proj_from_llama.0.weight",
+     "output.proj_from_llama.0.bias",
+     "output.proj_from_llama.2.weight",
+     "output.proj_from_llama.2.bias",
+     "output.proj_from_llama.3.weight",
+     "output.proj_from_llama.3.bias",
+ ]
+
+ def add_proj_convert_weights():
+     # extend _FROM_META torchtune -> meta mapping with new parameter names
+     # allow existing ckpt-save code to work without changes
+     convert_weights._FROM_META.update({a: a for a in _BASE_TRAINABLE})
models/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
models/imagebind_wrapper.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ import torch
+ import numpy as np
+ from typing import BinaryIO, List
+
+ from imagebind import imagebind_model
+ from imagebind.models.imagebind_model import ModalityType
+ from imagebind.models.multimodal_preprocessors import SimpleTokenizer, TextPreprocessor
+
+
+ V2_URL = "https://huggingface.co/jondurbin/videobind-v0.2/resolve/main/videobind.pth"
+ V2_PATH = "./.checkpoints/videobind-v0.2.pth"
+ BPE_PATH = "./models/bpe_simple_vocab_16e6.txt.gz"
+ TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH)
+ LENGTH_TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH, context_length=1024)
+ TOKEN_CHUNK_SIZE = 74
+
+ def get_imagebind_v2(path: str = V2_PATH):
+     if not os.path.isfile(path):
+         os.makedirs(os.path.dirname(path), exist_ok=True)
+         torch.hub.download_url_to_file(V2_URL, path, progress=True)
+     imagebind_model = torch.load(path)
+     return imagebind_model
+
+
+ def load_and_transform_text(text, device):
+     if text is None:
+         return None
+     tokens = [TOKENIZER(t).unsqueeze(0).to(device) for t in text]
+     tokens = torch.cat(tokens, dim=0)
+     return tokens
+
+ def split_text_by_token_limit(text, tokenizer, max_tokens=TOKEN_CHUNK_SIZE):
+     def fits_in_token_limit(text_segment):
+         tokens = tokenizer(text_segment)
+         tokens = tokens[tokens != 0][1:-1].tolist()
+         return len(tokens) <= max_tokens
+
+     def recursive_split(text, delimiters):
+         if fits_in_token_limit(text):
+             return [text]
+         if not delimiters:
+             return split_by_tokens(text)
+         delimiter = delimiters[0]
+         parts = text.split(delimiter)
+         result = []
+         current_segment = ""
+         for part in parts:
+             candidate_segment = current_segment + (delimiter if current_segment else '') + part
+             if fits_in_token_limit(candidate_segment):
+                 current_segment = candidate_segment
+             else:
+                 if current_segment:
+                     result.append(current_segment)
+                 current_segment = part
+         if current_segment:
+             result.append(current_segment)
+         final_result = []
+         for segment in result:
+             if fits_in_token_limit(segment):
+                 final_result.append(segment)
+             else:
+                 final_result.extend(recursive_split(segment, delimiters[1:]))
+         return final_result
+
+     def split_by_tokens(text):
+         tokens = tokenizer(text)
+         tokens = tokens[tokens != 0][1:-1].tolist()
+         chunks = np.array_split(tokens, int(len(tokens) / max_tokens) or 1)
+         return [
+             tokenizer.decode(segment_tokens)
+             for segment_tokens in chunks
+         ]
+
+     return recursive_split(text, ['\n', '.', '!', '?', ',', ' '])
+
+ def load_and_transform_text_chunks(text, device):
+     if not text:
+         return []
+     all_tokens = LENGTH_TOKENIZER(text)
+     all_tokens = all_tokens[all_tokens != 0][1:-1].tolist()
+
+     return [
+         load_and_transform_text([segment], device)
+         for segment in split_text_by_token_limit(text, LENGTH_TOKENIZER)
+     ]
+
+
+ class ImageBind:
+     def __init__(self, device="cuda:0", v2=False):
+         self.device = device
+         self.v2 = v2
+         if v2:
+             if not os.path.exists(V2_PATH):
+                 os.makedirs(os.path.dirname(V2_PATH), exist_ok=True)
+                 torch.hub.download_url_to_file(
+                     V2_URL,
+                     V2_PATH,
+                     progress=True,
+                 )
+             self.imagebind = torch.load(V2_PATH)
+         else:
+             self.imagebind = imagebind_model.imagebind_huge(pretrained=True)
+         self.imagebind.eval()
+         self.imagebind.to(self.device)
+
+     def generate_text_embeddings(self, text: str):
+         if not self.v2:
+             return self.imagebind({
+                 ModalityType.TEXT: load_and_transform_text([text], self.device)
+             })[ModalityType.TEXT]
+         chunks = load_and_transform_text_chunks(text, self.device)
+         embeddings = [
+             self.imagebind({ModalityType.TEXT: chunk})[ModalityType.TEXT]
+             for chunk in chunks
+         ]
+         return torch.mean(torch.stack(embeddings), dim=0)
+
+     """ Deactivating full embeddings as they are not used in the current implementation
+     def get_inputs(self, video_file: BinaryIO) -> dict:
+         audio_file = video_utils.copy_audio(video_file.name)
+         try:
+             duration = video_utils.get_video_duration(video_file.name)
+             video_data = data.load_and_transform_video_data(
+                 [video_file.name],
+                 self.device,
+             )
+             audio_data = data.load_and_transform_audio_data(
+                 [audio_file.name],
+                 self.device,
+             )
+             inputs = {
+                 ModalityType.VISION: video_data,
+                 ModalityType.AUDIO: audio_data,
+             }
+             return inputs
+         finally:
+             audio_file.close()
+
+     @torch.no_grad()
+     def embed(self, descriptions: List[str], video_files: List[BinaryIO]) -> Embeddings:
+         return_value = None
+         for idx in range(len(descriptions)):
+             inputs = self.get_inputs(video_files[idx])
+             embeddings = self.imagebind(inputs)
+             text_embeddings = self.generate_text_embeddings(descriptions[idx])
+             if not return_value:
+                 return_value = Embeddings(
+                     video=embeddings[ModalityType.VISION],
+                     audio=embeddings[ModalityType.AUDIO],
+                     description=text_embeddings,
+                 )
+             else:
+                 return_value.video = torch.cat((return_value.video, embeddings[ModalityType.VISION]))
+                 return_value.audio = torch.cat((return_value.audio, embeddings[ModalityType.AUDIO]))
+                 return_value.description = torch.cat((return_value.description, text_embeddings))
+         return return_value
+
+     @torch.no_grad()
+     def embed_only_video(self, video_files: List[BinaryIO]) -> Embeddings:
+         video_filepaths = [video_file.name for video_file in video_files]
+         durations = [video_utils.get_video_duration(f.name) for f in video_files]
+         embeddings = self.imagebind({
+             ModalityType.VISION: [
+                 data.load_and_transform_video_data(
+                     [video_filepaths[idx]],
+                     self.device,
+                 )[0]
+                 for idx in range(len(video_filepaths))
+             ]
+         })
+         return Embeddings(
+             video=embeddings[ModalityType.VISION],
+         )
+
+     @torch.no_grad()
+     def embed_video_and_text(self, video_files: List[BinaryIO], descriptions: List[str]) -> Embeddings:
+         video_filepaths = [video_file.name for video_file in video_files]
+         durations = [video_utils.get_video_duration(f.name) for f in video_files]
+         embeddings = self.imagebind({
+             ModalityType.VISION: [
+                 data.load_and_transform_video_data(
+                     [video_filepaths[idx]],
+                     self.device,
+                 )[0]
+                 for idx in range(len(video_filepaths))
+             ],
+         })
+         description_embeddings = torch.stack([
+             self.generate_text_embeddings(description)
+             for description in descriptions
+         ])
+         return Embeddings(
+             video=embeddings[ModalityType.VISION],
+             description=description_embeddings,
+         )
+
+     @torch.no_grad()
+     def embed_text(self, texts: List[str]) -> torch.Tensor:
+         return_value = None
+         for text in texts:
+             emb = self.generate_text_embeddings(text)
+             if not return_value:
+                 return_value = emb
+             else:
+                 return_value = torch.cat((return_value, emb))
+         return return_value
+     """
+
+     @torch.no_grad()
+     def embed_text(self, texts: List[str]) -> torch.Tensor:
+         embeddings = []
+         for text in texts:
+             emb = self.generate_text_embeddings(text)
+             embeddings.append(emb)
+
+         if not embeddings:
+             return None
+
+         # Stack all embeddings along dimension 0
+         return torch.stack(embeddings, dim=0)
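A minimal usage sketch of the wrapper above, assuming the pinned imagebind package is installed, the BPE vocab ships in models/, and the videobind-v0.2 checkpoint can be downloaded to ./.checkpoints/ on first use. For long captions in v2 mode, the text is split into chunks of at most 74 BPE tokens and the per-chunk embeddings are mean-pooled.

import torch
from models.imagebind_wrapper import ImageBind

device = "cuda:0" if torch.cuda.is_available() else "cpu"
ib = ImageBind(device=device, v2=True)  # v2 fetches videobind-v0.2.pth if it is not cached yet

# Each input text maps to a 1024-dim ImageBind embedding (chunked and mean-pooled if long).
embs = ib.embed_text(["a dog chasing a ball", "a lecture about transformers"])
print(embs.shape)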
models/meta_model_5.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a40e2a0f8c070103237a1c2c147fef1ba7cac02cd53b9d320370e3fbfee7ad84
+ size 16219158403
models/mmllama3.py ADDED
@@ -0,0 +1,153 @@
+ from typing import List
+ import warnings
+
+ import torch
+ from torch import nn, Tensor
+ from torchvision import transforms
+
+ from torchtune.models.llama3 import lora_llama3_8b, llama3_8b
+ from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear
+ from torchtune.modules import TransformerDecoder
+
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore", UserWarning)
+     from imagebind.models import imagebind_model
+     from models.imagebind_wrapper import get_imagebind_v2, V2_PATH
+     from models.imagebind_wrapper import ImageBind
+
+ IMAGEBIND_DIM = 1024
+ CLIP_DIM = 768
+
+
+ class MMEmbedding(nn.Embedding):
+     def __init__(self, e, perception_tokens=1, use_clip=False):
+         super().__init__(
+             num_embeddings=e.num_embeddings,
+             embedding_dim=e.embedding_dim,
+             padding_idx=e.padding_idx,
+             max_norm=e.max_norm,
+             norm_type=e.norm_type,
+             scale_grad_by_freq=e.scale_grad_by_freq,
+             sparse=e.sparse,
+         )
+         self._perception_tokens = perception_tokens
+         self._context = []
+         self._use_clip = use_clip
+
+         dim_in = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
+         dim_out = e.embedding_dim * perception_tokens
+
+         self.proj_to_llama = nn.Sequential(
+             nn.Linear(dim_in, dim_out),
+             nn.GELU(),
+             nn.LayerNorm(dim_out),
+             nn.Linear(dim_out, dim_out),
+         )
+
+     def set_context(self, context):
+         self._context = context
+
+     def forward(self, input: Tensor) -> Tensor:
+         r = super().forward(input)
+         # self._context is first indexed by batch idx
+         for b, context_dict in enumerate(self._context):
+             # then by sequence idx
+             for s, embed in context_dict.items():
+                 # and then must be transformed from imagebind dim -> llama3 dim
+                 if self._use_clip:
+                     llama_embed = self.proj_to_llama(torch.cat([embed["ib_embed"], embed["clip_embed"]]))
+                 else:
+                     llama_embed = self.proj_to_llama(torch.cat([embed["ib_embed"]]))
+                 r[b, s:s + self._perception_tokens] = llama_embed.view(self._perception_tokens, -1)
+         return r
+
+
+ class MMLinear(nn.Linear):
+     def __init__(self, o):
+         super().__init__(
+             in_features=o.in_features,
+             out_features=o.out_features,
+             bias=(o.bias is not None),
+         )
+         self._context = []
+
+         dim_out = CLIP_DIM
+         dim_in = o.in_features
+         self.proj_from_llama = nn.Sequential(
+             nn.Linear(dim_in, dim_out),
+             nn.GELU(),
+             nn.LayerNorm(dim_out),
+             nn.Linear(dim_out, dim_out),
+         )
+
+     def set_context(self, context):
+         self._context = context
+
+     def forward(self, input_bsd: Tensor) -> Tensor:
+         # self._context has the indexes of image llama tokens: process these with proj_from_llama
+         self._clip_projections = []
+         # # self._context is first indexed by batch idx
+         # for b, context_dict in enumerate(self._context):
+         #     # then by sequence idx
+         #     for s, embed in context_dict.items():
+         #         # and then must be transformed from llama3 dim -> clip dim
+         #         self._clip_projections.append((
+         #             self.proj_from_llama(input_bsd[b, s]),
+         #             (embed["clip_embed"] if "clip_embed" in embed else None)  # terrible
+         #         ))
+         r = super().forward(input_bsd)
+         return r
+
+
+ def lora_mmllama3_8b(
+     lora_attn_modules: List[LORA_ATTN_MODULES],
+     apply_lora_to_mlp: bool = False,
+     apply_lora_to_output: bool = False,
+     lora_rank: int = 8,
+     lora_alpha: float = 16,
+     quantize_base: bool = False,
+     perception_tokens: int = 2,
+     use_clip: bool = False
+ ) -> TransformerDecoder:
+     llama3 = lora_llama3_8b(
+         lora_attn_modules,
+         apply_lora_to_mlp,
+         apply_lora_to_output,
+         lora_rank,
+         lora_alpha,
+         quantize_base,
+     )
+     llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
+     llama3.output = MMLinear(llama3.output)
+     return llama3
+
+
+ def mmllama3_8b(
+     perception_tokens: int = 2,
+     use_clip: bool = False
+ ) -> TransformerDecoder:
+     llama3 = llama3_8b()
+     llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
+     llama3.output = MMLinear(llama3.output)
+     return llama3
+
+
+ def imagebind_huge(use_v2: bool = True):
+     if use_v2:
+         imagebind = ImageBind(v2=True)
+     else:
+         imagebind = imagebind_model.imagebind_huge(pretrained=True)
+     imagebind.transform_from_pil = transforms.Compose([
+         transforms.Resize(
+             224, interpolation=transforms.InterpolationMode.BICUBIC
+         ),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=(0.48145466, 0.4578275, 0.40821073),
+             std=(0.26862954, 0.26130258, 0.27577711),
+         ),
+     ])
+     return imagebind
models/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82e9d31979e92ab929cd544440f129d9ecd797b69e327f80f17e1c50d5551b55
+ size 2183982
models/tokenizer.py ADDED
@@ -0,0 +1,85 @@
+ from typing import List
+
+ from torchtune.modules.tokenizers import TikTokenTokenizer
+ from torchtune.modules.tokenizers._utils import _split_long_repetitions
+ from torchtune.modules.tokenizers._tiktoken import (
+     MAX_ENCODE_CHARS,
+     MAX_NO_WHITESPACE_CHARS,
+     ALL_SPECIAL_TOKENS,
+ )
+
+
+ # use special tokens from TikTokenTokenizer, add some for MM delimiters
+ START_IMAGE = "<|start_image|>"
+ END_IMAGE = "<|end_image|>"
+ START_VIDEO = "<|start_video|>"
+ END_VIDEO = "<|end_video|>"
+ START_AUDIO = "<|start_audio|>"
+ END_AUDIO = "<|end_audio|>"
+
+ A2A_SPECIAL_TOKENS = ALL_SPECIAL_TOKENS[:-2] + [
+     START_IMAGE,
+     END_IMAGE,
+     START_VIDEO,
+     END_VIDEO,
+     START_AUDIO,
+     END_AUDIO,
+ ] + ALL_SPECIAL_TOKENS[-2:]
+
+ # override to allow START_IMAGE, END_IMAGE to be encoded
+ class A2ATokenizer(TikTokenTokenizer):
+     def encode(
+         self,
+         text: str,
+         add_bos: bool,
+         add_eos: bool,
+     ) -> List[int]:
+         """
+         Encode a string into a list of token ids. Assumes that the string
+         contains no special tokens.
+
+         Args:
+             text (str): The string to encode.
+             add_bos (bool): Whether to add the beginning of sequence token.
+             add_eos (bool): Whether to add the end of sequence token.
+
+         Returns:
+             List[int]: The list of token ids.
+         """
+         substrs: List[str] = []
+         tokens = []
+         for i in range(0, len(text), MAX_ENCODE_CHARS):
+             substr = text[i : i + MAX_ENCODE_CHARS]
+             # See https://github.com/openai/tiktoken/issues/195
+             sliced_substr = _split_long_repetitions(substr, MAX_NO_WHITESPACE_CHARS)
+             substrs.extend(sliced_substr)
+         for substr in substrs:
+             # allowed_special and disallowed_special are used by tiktoken to define
+             # how special tokens are encoded. Our setting here is to encode any
+             # special token as regular text and prevent tiktoken from raising errors.
+             # This means we should only call encode on strings not containing special tokens.
+             tokens.extend(
+                 self.tt_model.encode(
+                     substr,
+                     allowed_special=set([
+                         START_IMAGE,
+                         END_IMAGE,
+                         START_VIDEO,
+                         END_VIDEO,
+                         START_AUDIO,
+                         END_AUDIO,
+                     ]),
+                     disallowed_special=(),
+                 )
+             )
+         if add_bos:
+             tokens.insert(0, self.bos_id)
+         if add_eos:
+             tokens.append(self.eos_id)
+         return tokens
+
+
+ def a2a_tokenizer(path: str) -> TikTokenTokenizer:
+     tiktoken = A2ATokenizer(path, all_special_tokens=A2A_SPECIAL_TOKENS)
+     tiktoken.pad_id = 0
+     return tiktoken
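The six delimiter tokens above are what inference.py uses to mark where a multimodal embedding belongs in the prompt. A standalone sketch of the placeholder substitution done by mm_process_prompt (values here mirror the defaults: perception_tokens=2):

# Mirrors inference.py: "{video}" becomes <|start_video|> + placeholder chars + <|end_video|>.
START_VIDEO, END_VIDEO = "<|start_video|>", "<|end_video|>"

perception_tokens = 2
perception_placeholder = ("0 " * perception_tokens)[:perception_tokens]  # "0 " for 2 tokens

prompt_template = "Video:\n\n{video}\n\nCaption the previous video."
prompt = prompt_template.replace("{video}", f"{START_VIDEO}{perception_placeholder}{END_VIDEO}")
print(prompt)
# At generation time, the placeholder positions between the delimiters are overwritten in
# embedding space by MMEmbedding with the projected ImageBind video embedding.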
models/training_config.yml ADDED
@@ -0,0 +1,83 @@
+ model:
+   _component_: models.lora_mmllama3_8b
+   lora_attn_modules:
+   - q_proj
+   - v_proj
+   apply_lora_to_mlp: false
+   apply_lora_to_output: false
+   lora_rank: 8
+   lora_alpha: 16
+   perception_tokens: 2
+   use_clip: false
+ tokenizer:
+   _component_: models.a2a_tokenizer
+   path: models/tokenizer.model
+ checkpointer:
+   _component_: torchtune.utils.FullModelMetaCheckpointer
+   checkpoint_dir: /workspace/omega_a2a/training
+   checkpoint_files:
+   - consolidated.00.pth
+   adapter_checkpoint: null
+   recipe_checkpoint: null
+   output_dir: /workspace/omega_a2a/checkpoints
+   model_type: LLAMA3
+ resume_from_checkpoint: false
+ interim_checkpoint_steps: 5000
+ interim_gen_steps: null
+ max_new_tokens: 170
+ temperature: 0.8
+ top_k: 200
+ dataset:
+   _component_: ds.EvenBatcher
+   buffer_size: 36
+   dataset:
+     _component_: ds.RoundRobinDataset
+     datasets:
+     - _component_: ds.OmegaVideoCaptionDataset
+       length: 500000
+     - _component_: ds.LlavaInstructDataset
+       dataset_path: ds/coco_llava_instruct/output.parquet
+       train_on_input: false
+     - _component_: ds.LlavaInstructDataset
+       dataset_path: ds/vision_flan/output.parquet
+       train_on_input: false
+     - _component_: ds.CaptionInstructDataset
+       dataset_path: ds/sam_llava/output.parquet
+       train_on_input: false
+ seed: null
+ shuffle: true
+ batch_size: 4
+ optimizer:
+   _component_: torch.optim.AdamW
+   weight_decay: 0.0001
+   lr: 3.0e-05
+ lr_scheduler:
+   _component_: torchtune.modules.get_cosine_schedule_with_warmup
+   num_warmup_steps: 100
+ loss:
+   _component_: torch.nn.CrossEntropyLoss
+ epochs: 6
+ max_steps_per_epoch: null
+ gradient_accumulation_steps: 64
+ compile: false
+ output_dir: /tmp/lora_finetune_output
+ metric_logger:
+   _component_: torchtune.utils.metric_logging.DiskLogger
+   log_dir: ${output_dir}
+ log_every_n_steps: null
+ device: cuda
+ dtype: bf16
+ enable_activation_checkpointing: false
+ profiler:
+   _component_: torchtune.utils.profiler
+   enabled: false
+ inference:
+   prompt_template: 'Video:
+
+     {video}
+
+     Caption the previous video.'
+   max_new_tokens: 300
+   temperature: 0.6
+   top_k: 5
+   quantizer: null
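A config like this is consumed with OmegaConf and a few fields are overridden before building the inference recipe, following the pattern used in server.py. A minimal sketch (paths here are illustrative, not the container paths server.py uses):

from omegaconf import OmegaConf, DictConfig

cfg = OmegaConf.load("models/training_config.yml")
cfg.model = DictConfig({
    "_component_": "models.mmllama3_8b",  # swap the LoRA training model for the plain inference model
    "use_clip": False,
    "perception_tokens": cfg.model.perception_tokens,
})
cfg.checkpointer.checkpoint_dir = "models"
cfg.checkpointer.checkpoint_files = ["meta_model_5.pt"]
cfg.tokenizer.path = "models/tokenizer.model"
print(OmegaConf.to_yaml(cfg.inference))  # prompt template, sampling settings, quantizer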
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ torch==2.6.0
+ sentencepiece==0.2.0
+ tiktoken==0.4.0
+ torchtune @ git+https://github.com/pytorch/torchtune.git@8f59c2fecd722691271eecca630a526719a32f76#egg=torchtune
+ lm_eval==0.4
+ torchvision==0.21.0
+ diffusers==0.27.2
+ imagebind @ git+https://github.com/omegalabsinc/ImageBind.git@c3c3b2e1ce6fd850ff42ce0375823fe22880a7cc#egg=imagebind
+ llama3 @ git+https://github.com/meta-llama/llama3.git@af6eedf7042fb51d00b2b26d8ef1ceaab73e1670
+ pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d
+ wandb==0.17.1
+ numpy==1.26.4
+ huggingface-hub==0.24.0
+ omegaconf==2.3.0
+ uvicorn==0.25.0
+ fastapi==0.104.1
+ pydantic==2.5.2
+ torchaudio==2.6.0
server.py ADDED
@@ -0,0 +1,146 @@
+ from fastapi import FastAPI, HTTPException
+ import numpy as np
+ import torch
+ from pydantic import BaseModel
+ from typing import List
+ import base64
+ import io
+ import os
+ import logging
+ from pathlib import Path
+ from inference import InferenceRecipe
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from omegaconf import OmegaConf, DictConfig
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI()
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ class EmbeddingRequest(BaseModel):
+     embedding: List[float]
+
+ class TextResponse(BaseModel):
+     texts: List[str] = []
+
+ # Model initialization status
+ INITIALIZATION_STATUS = {
+     "model_loaded": False,
+     "error": None
+ }
+
+ # Global model instance
+ inference_recipe = None
+ cfg = None
+
+
+ def initialize_model():
+     """Initialize the model with correct path resolution"""
+     global inference_recipe, INITIALIZATION_STATUS, cfg
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Initializing model on device: {device}")
+
+         # Critical: Use absolute path for model loading
+         model_path = os.path.abspath(os.path.join('/app', 'models'))
+         logger.info(f"Loading models from: {model_path}")
+
+         if not os.path.exists(model_path):
+             raise RuntimeError(f"Model path {model_path} does not exist")
+
+         # Log available model files for debugging
+         model_files = os.listdir(model_path)
+         logger.info(f"Available model files: {model_files}")
+
+         cfg = OmegaConf.load(os.path.join('/app', 'training_config.yml'))
+         cfg.model = DictConfig({
+             "_component_": "models.mmllama3_8b",
+             "use_clip": False,
+             "perception_tokens": cfg.model.perception_tokens,
+         })
+         cfg.checkpointer.checkpoint_dir = model_path
+         cfg.checkpointer.checkpoint_files = ["meta_model_5.pt"]
+         cfg.inference.max_new_tokens = 300
+         cfg.tokenizer.path = os.path.join(model_path, "tokenizer.model")
+         inference_recipe = InferenceRecipe(cfg)
+         inference_recipe.setup(cfg=cfg)
+         INITIALIZATION_STATUS["model_loaded"] = True
+         logger.info("Model initialized successfully")
+         return True
+     except Exception as e:
+         INITIALIZATION_STATUS["error"] = str(e)
+         logger.error(f"Failed to initialize model: {e}")
+         return False
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize model on startup"""
+     initialize_model()
+
+ @app.get("/api/v1/health")
+ def health_check():
+     """Health check endpoint"""
+     status = {
+         "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
+         "initialization_status": INITIALIZATION_STATUS
+     }
+
+     if inference_recipe is not None:
+         status.update({
+             "device": str(inference_recipe._device),
+             "dtype": str(inference_recipe._dtype)
+         })
+
+     return status
+
+ @app.post("/api/v1/inference")
+ async def inference(request: EmbeddingRequest) -> TextResponse:
+     """Run inference with enhanced error handling and logging"""
+     if not INITIALIZATION_STATUS["model_loaded"]:
+         raise HTTPException(
+             status_code=503,
+             detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
+         )
+
+     try:
+         # Log input validation
+         logger.info("Received inference request")
+
+         # Convert embedding to tensor
+         embedding = request.embedding  # generate() expects List[float]
+         embedding = torch.tensor(embedding)
+         embedding = embedding.unsqueeze(0)  # Add batch dimension
+         embedding = embedding.reshape(-1, 1024)
+         logger.info(f"Converted embedding to tensor with shape: {embedding.shape}")
+
+         # Run inference
+         results = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=embedding)
+         logger.info("Generation complete")
+
+         # Convert results to list if it's not already
+         if isinstance(results, str):
+             results = [results]
+
+         return TextResponse(texts=results)
+
+     except Exception as e:
+         logger.error(f"Inference failed: {str(e)}", exc_info=True)
+         raise HTTPException(
+             status_code=500,
+             detail=str(e)
+         )
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
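The inference endpoint takes a flat list of floats that the server reshapes to (-1, 1024), so a single video corresponds to 1024 values. A minimal Python client sketch for the endpoints above, assuming the server is running locally on port 8000 (standard library only):

import json
import random
import urllib.request

# One 1024-dim ImageBind-style embedding; real clients would send an actual video embedding.
embedding = [random.gauss(0.0, 1.0) for _ in range(1024)]

req = urllib.request.Request(
    "http://localhost:8000/api/v1/inference",
    data=json.dumps({"embedding": embedding}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["texts"])  # list of generated captions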
setup.py ADDED
File without changes
test.sh ADDED
@@ -0,0 +1,60 @@
+ #!/bin/bash
+
+ # Configure bash error handling
+ set -euo pipefail
+
+ # Configuration
+ API_HOST="localhost"
+ API_PORT="8000"
+ API_VERSION="v1"
+ BASE_URL="http://${API_HOST}:${API_PORT}/api/${API_VERSION}"
+
+ # Function to generate test embedding data
+ generate_test_embedding() {
+     python3 - <<EOF
+ import numpy as np
+ import json
+
+ # Generate a 4096-dimensional vector (reshaped server-side into a batch of four 1024-dim embeddings)
+ embedding = np.random.randn(4096).astype(np.float32)
+ # Normalize the embedding
+ embedding = embedding / np.linalg.norm(embedding)
+ print(json.dumps(embedding.tolist()), end="")
+ EOF
+ }
+
+ # Function to test health endpoint
+ test_health() {
+     echo "Testing health endpoint..."
+     curl -s "${BASE_URL}/health" || {
+         echo "Health check failed"
+         exit 1
+     }
+ }
+
+ # Function to test inference endpoint
+ test_inference() {
+     echo
+     start_time=$(date +%s)
+     echo "Testing inference endpoint..."
+     local embedding_data=$(generate_test_embedding)
+
+     curl -X POST "${BASE_URL}/inference" \
+         -H "Content-Type: application/json" \
+         -d "{
+             \"embedding\": ${embedding_data}
+         }" || {
+         echo "Inference request failed"
+         exit 1
+     }
+     end_time=$(date +%s)
+     duration=$((end_time - start_time))
+     echo "Inference request completed in ${duration} seconds"
+ }
+
+ main() {
+     test_health
+     test_inference
+ }
+
+ main "$@"
training_config.yml ADDED
@@ -0,0 +1,76 @@
+ identity_token: 0 1 2
+ model:
+   _component_: models.lora_mmllama3_8b
+   lora_attn_modules:
+   - q_proj
+   - v_proj
+   apply_lora_to_mlp: false
+   apply_lora_to_output: false
+   lora_rank: 8
+   lora_alpha: 16
+   perception_tokens: 2
+   use_clip: false
+ tokenizer:
+   _component_: models.a2a_tokenizer
+   path: checkpoints/Meta-Llama-3-8B-Instruct/original/tokenizer.model
+ checkpointer:
+   _component_: torchtune.utils.FullModelMetaCheckpointer
+   checkpoint_dir: checkpoints/Meta-Llama-3-8B-Instruct/original/
+   checkpoint_files:
+   - consolidated.00.pth
+   adapter_checkpoint: null
+   recipe_checkpoint: null
+   output_dir: output_checkpoints/experiment_4
+   model_type: LLAMA3
+ resume_from_checkpoint: false
+ interim_checkpoint_steps: 1500000
+ interim_gen_steps: null
+ max_new_tokens: 100
+ temperature: 0.6
+ top_k: 300
+ dataset:
+   _component_: ds.EvenBatcher
+   dataset:
+     _component_: ds.RoundRobinDataset
+     datasets:
+     - _component_: ds.IdentityDataset
+       identity: ${identity_token}
+       length: 250000
+       train_on_input: true
+ seed: null
+ shuffle: true
+ batch_size: 4
+ optimizer:
+   _component_: torch.optim.AdamW
+   weight_decay: 0.01
+   lr: 0.0003
+ lr_scheduler:
+   _component_: torchtune.modules.get_cosine_schedule_with_warmup
+   num_warmup_steps: 100
+ loss:
+   _component_: torch.nn.CrossEntropyLoss
+ epochs: 1
+ max_steps_per_epoch: null
+ gradient_accumulation_steps: 64
+ compile: false
+ output_dir: /tmp/lora_finetune_output
+ metric_logger:
+   _component_: torchtune.utils.metric_logging.DiskLogger
+   log_dir: ${output_dir}
+ log_every_n_steps: null
+ device: cuda
+ dtype: bf16
+ enable_activation_checkpointing: false
+ profiler:
+   _component_: torchtune.utils.profiler
+   enabled: false
+ inference:
+   prompt_template: 'Video:
+
+     {video}
+
+     Caption the previous video.'
+   max_new_tokens: 300
+   temperature: 0.6
+   top_k: 300
+   quantizer: null