Ashrafb committed on
Commit 54d4cab · verified · 1 Parent(s): a1917c8

Update main.py

Files changed (1)
  1. main.py +29 -1222
main.py CHANGED
@@ -1,1237 +1,44 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- from PIL import Image
5
- from einops import rearrange
6
- from torchvision.transforms.v2 import (
7
- Compose,
8
- Resize,
9
- InterpolationMode,
10
- ToImage,
11
- ToDtype,
12
- Normalize,
13
- )
14
-
15
- from transformers import CodeGenTokenizerFast as Tokenizer
16
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
17
- import re
18
-
19
- import math
20
- from typing import Optional
21
-
22
- from transformers import PretrainedConfig
23
-
24
-
25
- import math
26
- from dataclasses import dataclass, field
27
- from typing import Any, Dict, Optional, Tuple, Union
28
-
29
- import torch
30
- import torch.nn as nn
31
- from einops import rearrange, repeat
32
- from transformers import PretrainedConfig, PreTrainedModel
33
- from transformers.activations import ACT2FN
34
- from transformers.modeling_outputs import CausalLMOutputWithPast
35
-
36
- pad_input, unpad_input = None, None
37
- FlashRotaryEmbedding = None
38
- FlashSelfAttention, FlashCrossAttention = None, None
39
- FusedDense = None
40
-
41
- if torch.cuda.is_available():
42
- DEVICE = "cuda"
43
- DTYPE = torch.float16
44
- else:
45
- DEVICE = "cpu"
46
- DTYPE = torch.float32
47
-
48
-
49
- class PhiConfig(PretrainedConfig):
50
- """Phi configuration."""
51
-
52
- model_type = "phi-msft"
53
- attribute_map = {
54
- "max_position_embeddings": "n_positions",
55
- "hidden_size": "n_embd",
56
- "num_attention_heads": "n_head",
57
- "num_hidden_layers": "n_layer",
58
- }
59
-
60
- def __init__(
61
- self,
62
- vocab_size: int = 50304,
63
- n_positions: int = 2048,
64
- n_embd: int = 1024,
65
- n_layer: int = 20,
66
- n_inner: Optional[int] = None,
67
- n_head: int = 16,
68
- n_head_kv: Optional[int] = None,
69
- rotary_dim: Optional[int] = 32,
70
- activation_function: Optional[str] = "gelu_new",
71
- flash_attn: bool = False,
72
- flash_rotary: bool = False,
73
- fused_dense: bool = False,
74
- attn_pdrop: float = 0.0,
75
- embd_pdrop: float = 0.0,
76
- resid_pdrop: float = 0.0,
77
- layer_norm_epsilon: float = 1e-5,
78
- initializer_range: float = 0.02,
79
- tie_word_embeddings: bool = False,
80
- pad_vocab_size_multiple: int = 64,
81
- gradient_checkpointing: bool = False,
82
- **kwargs,
83
- ) -> None:
84
- self.vocab_size = int(
85
- math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
86
- )
87
- self.n_positions = n_positions
88
- self.n_embd = n_embd
89
- self.n_layer = n_layer
90
- self.n_inner = n_inner
91
- self.n_head = n_head
92
- self.n_head_kv = n_head_kv
93
- self.rotary_dim = min(rotary_dim, n_embd // n_head)
94
- self.activation_function = activation_function
95
- self.flash_attn = flash_attn
96
- self.flash_rotary = flash_rotary
97
- self.fused_dense = fused_dense
98
- self.attn_pdrop = attn_pdrop
99
- self.embd_pdrop = embd_pdrop
100
- self.resid_pdrop = resid_pdrop
101
- self.layer_norm_epsilon = layer_norm_epsilon
102
- self.initializer_range = initializer_range
103
- self.gradient_checkpointing = gradient_checkpointing
104
-
105
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
106
-
107
-
108
- @dataclass
109
- class InferenceParams:
110
- """Inference parameters passed to model to efficiently calculate
111
- and store context during inference.
112
-
113
- Reference:
114
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
115
-
116
- Args:
117
- max_seqlen: Maximum sequence length.
118
- max_batch_size: Maximum batch size.
119
- seqlen_offset: Sequence length offset.
120
- batch_size_offset: Batch size offset.
121
- key_value_memory_dict: Key value memory dictionary.
122
- lengths_per_sample: Lengths per sample.
123
-
124
- """
125
-
126
- max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
127
-
128
- max_batch_size: int = field(metadata={"help": "Maximum batch size."})
129
-
130
- seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
131
-
132
- batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
133
-
134
- key_value_memory_dict: Dict[str, Any] = field(
135
- default_factory=dict, metadata={"help": "Key value memory dictionary."}
136
- )
137
-
138
- lengths_per_sample: torch.Tensor = field(
139
- default=None, metadata={"help": "Lengths per sample."}
140
- )
141
-
142
-
143
- class Embedding(nn.Module):
144
- """Token embedding with dropout."""
145
-
146
- def __init__(self, config: PretrainedConfig) -> None:
147
- super().__init__()
148
-
149
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
150
- self.drop = nn.Dropout(config.embd_pdrop)
151
-
152
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
153
- input_shape = input_ids.size()
154
- input_ids = input_ids.view(-1, input_shape[-1])
155
-
156
- hidden_states = self.wte(input_ids)
157
- hidden_states = self.drop(hidden_states)
158
-
159
- return hidden_states
160
-
161
-
162
- # @torch.compile
163
- def _apply_rotary_emb(
164
- x: torch.FloatTensor,
165
- cos: torch.FloatTensor,
166
- sin: torch.FloatTensor,
167
- ) -> torch.FloatTensor:
168
- _, seqlen, _, _ = x.shape
169
- _, rotary_dim = cos.shape
170
- rotary_dim *= 2
171
-
172
- x_rot = x[:, :, :, :rotary_dim]
173
- x_pass = x[:, :, :, rotary_dim:]
174
-
175
- x1, x2 = x_rot.chunk(2, dim=-1)
176
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
177
- sin[:seqlen], "s d -> s 1 d"
178
- )
179
- x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
180
-
181
- x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
182
-
183
- return torch.cat([x_rot, x_pass], axis=-1)
184
-
185
-
186
- # @torch.compile
187
- def _apply_rotary_emb_kv(
188
- kv: torch.FloatTensor,
189
- cos: torch.FloatTensor,
190
- sin: torch.FloatTensor,
191
- cos_k: Optional[torch.FloatTensor] = None,
192
- sin_k: Optional[torch.FloatTensor] = None,
193
- ) -> torch.FloatTensor:
194
- _, seqlen, _, _, _ = kv.shape
195
- _, rotary_dim = cos.shape
196
- rotary_dim *= 2
197
-
198
- k_rot = kv[:, :, 0, :, :rotary_dim]
199
- k_pass = kv[:, :, 0, :, rotary_dim:]
200
-
201
- k1, k2 = k_rot.chunk(2, dim=-1)
202
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
203
- sin[:seqlen], "s d -> s 1 d"
204
- )
205
- k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
206
-
207
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
208
-
209
- return torch.cat(
210
- [
211
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
212
- kv[:, :, 1:2, :, :],
213
- ],
214
- axis=2,
215
- )
216
-
217
-
218
- # @torch.compile
219
- def _apply_rotary_emb_qkv(
220
- qkv: torch.FloatTensor,
221
- cos: torch.FloatTensor,
222
- sin: torch.FloatTensor,
223
- cos_k: Optional[torch.FloatTensor] = None,
224
- sin_k: Optional[torch.FloatTensor] = None,
225
- ) -> torch.FloatTensor:
226
- _, seqlen, _, _, _ = qkv.shape
227
- _, rotary_dim = cos.shape
228
- rotary_dim *= 2
229
-
230
- q_rot = qkv[:, :, 0, :, :rotary_dim]
231
- q_pass = qkv[:, :, 0, :, rotary_dim:]
232
-
233
- k_rot = qkv[:, :, 1, :, :rotary_dim]
234
- k_pass = qkv[:, :, 1, :, rotary_dim:]
235
-
236
- q1, q2 = q_rot.chunk(2, dim=-1)
237
- k1, k2 = k_rot.chunk(2, dim=-1)
238
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
239
- sin[:seqlen], "s d -> s 1 d"
240
- )
241
- q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
242
-
243
- q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
244
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
245
-
246
- return torch.cat(
247
- [
248
- torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
249
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
250
- qkv[:, :, 2:3, :, :],
251
- ],
252
- axis=2,
253
- )
254
-
255
-
256
- class RotaryEmbedding(nn.Module):
257
- """Rotary positional embedding (RoPE).
258
-
259
- Reference:
260
- RoFormer: Enhanced Transformer with Rotary Position Embedding.
261
- https://arxiv.org/pdf/2104.09864.pdf.
262
-
263
- """
264
-
265
- def __init__(
266
- self,
267
- dim: int,
268
- base: int = 10000,
269
- scale_base: Optional[float] = None,
270
- pos_idx_in_fp32: bool = True,
271
- max_position_embeddings: int = 2048,
272
- device: Optional[str] = None,
273
- **kwargs,
274
- ) -> None:
275
- super().__init__()
276
-
277
- if scale_base is not None:
278
- raise NotImplementedError
279
-
280
- self.dim = dim
281
- self.base = float(base)
282
- self.scale_base = scale_base
283
- self.pos_idx_in_fp32 = pos_idx_in_fp32
284
- self.max_position_embeddings = max_position_embeddings
285
- self.device = device
286
-
287
- # Generate and save the inverse frequency buffer (non-trainable)
288
- inv_freq = self._compute_inv_freq(device)
289
- self.register_buffer("inv_freq", inv_freq, persistent=False)
290
-
291
- # Generate and save the scale buffer (non-trainable)
292
- scale = (
293
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
294
- / (1.4 * dim)
295
- if scale_base is not None
296
- else None
297
- )
298
- self.register_buffer("scale", scale, persistent=False)
299
-
300
- # Initialize cached attributes since ONNX can't rely on dynamic initialization
301
- self._update_cos_sin_cache(
302
- max_position_embeddings, device=device, dtype=torch.float32
303
- )
304
-
305
- def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
306
- return 1.0 / (
307
- self.base
308
- ** (
309
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
310
- / self.dim
311
- )
312
- )
313
-
314
- def _update_cos_sin_cache(
315
- self,
316
- seqlen: int,
317
- device: Optional[str] = None,
318
- dtype: Optional[torch.dtype] = None,
319
- ) -> None:
320
- self._seq_len_cached = seqlen
321
-
322
- # fp32 is preferred since the output of `torch.arange` can be quite large
323
- # and bf16 would lose a lot of precision
324
- if self.pos_idx_in_fp32:
325
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
326
- if self.inv_freq.dtype != torch.float32:
327
- inv_freq = self._compute_inv_freq(device=device)
328
- else:
329
- inv_freq = self.inv_freq
330
- else:
331
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
332
- inv_freq = self.inv_freq
333
-
334
- # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
335
- freqs = torch.outer(t, inv_freq)
336
- if self.scale is None:
337
- self._cos_cached = torch.cos(freqs).to(dtype)
338
- self._sin_cached = torch.sin(freqs).to(dtype)
339
- else:
340
- power = (
341
- torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
342
- - seqlen // 2
343
- ) / self.scale_base
344
- scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
345
-
346
- # Force the scale multiplication to happen in fp32
347
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
348
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
349
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
350
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
351
-
352
- def forward(
353
- self,
354
- qkv: torch.Tensor,
355
- kv: Optional[torch.Tensor] = None,
356
- seqlen_offset: int = 0,
357
- **kwargs,
358
- ) -> Tuple[torch.Tensor, torch.Tensor]:
359
- if (
360
- self._seq_len_cached < qkv.shape[1] + seqlen_offset
361
- or self._cos_cached.device != qkv.device
362
- or self._cos_cached.dtype != qkv.dtype
363
- or (self.training and self._cos_cached.is_inference())
364
- ):
365
- self._update_cos_sin_cache(
366
- qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype
367
- )
368
-
369
- if kv is None:
370
- return _apply_rotary_emb_qkv(
371
- qkv,
372
- self._cos_cached[seqlen_offset:],
373
- self._sin_cached[seqlen_offset:],
374
- )
375
- else:
376
- q = _apply_rotary_emb(
377
- qkv,
378
- self._cos_cached[seqlen_offset:],
379
- self._sin_cached[seqlen_offset:],
380
- )
381
- kv = _apply_rotary_emb_kv(
382
- kv,
383
- self._cos_cached[seqlen_offset:],
384
- self._sin_cached[seqlen_offset:],
385
- )
386
-
387
- return q, kv
388
-
389
-
390
- class MLP(nn.Module):
391
- """Multi-Layer Perceptron.
392
-
393
- Reference:
394
- Attention Is All You Need.
395
- https://arxiv.org/pdf/1706.03762.pdf.
396
-
397
- """
398
-
399
- def __init__(
400
- self,
401
- config: PretrainedConfig,
402
- n_inner: Optional[int] = None,
403
- act_fn: Optional[str] = None,
404
- ) -> None:
405
- super().__init__()
406
-
407
- act_fn = config.activation_function if act_fn is None else act_fn
408
-
409
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
410
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
411
-
412
- self.fc1 = nn.Linear(config.n_embd, n_inner)
413
- self.fc2 = nn.Linear(n_inner, config.n_embd)
414
- self.act = ACT2FN[act_fn]
415
-
416
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
417
- hidden_states = self.fc1(hidden_states)
418
- hidden_states = self.act(hidden_states)
419
- hidden_states = self.fc2(hidden_states)
420
-
421
- return hidden_states
422
-
423
-
424
- class SelfAttention(nn.Module):
425
- """Self-attention layer (compatible with PyTorch).
426
-
427
- Reference:
428
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
429
-
430
- """
431
-
432
- def __init__(
433
- self,
434
- causal: bool = True,
435
- softmax_scale: Optional[float] = None,
436
- attention_dropout: float = 0.0,
437
- ) -> None:
438
- super().__init__()
439
-
440
- self.causal = causal
441
- self.softmax_scale = softmax_scale
442
- self.drop = nn.Dropout(attention_dropout)
443
-
444
- @torch.autocast("cpu", enabled=False)
445
- @torch.autocast("cuda", enabled=False)
446
- def forward(
447
- self,
448
- qkv: torch.FloatTensor,
449
- causal: bool = None,
450
- key_padding_mask: Optional[torch.BoolTensor] = None,
451
- **kwargs,
452
- ) -> torch.FloatTensor:
453
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
454
- q, k, v = qkv.unbind(dim=2)
455
-
456
- q = q.to(torch.float32)
457
- k = k.to(torch.float32)
458
-
459
- causal = self.causal if causal is None else causal
460
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
461
-
462
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
463
- # using float16, which might lead to overflow
464
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
465
-
466
- if key_padding_mask is not None:
467
- padding_mask = torch.full(
468
- (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
469
- )
470
- padding_mask.masked_fill_(key_padding_mask, 0.0)
471
-
472
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
473
-
474
- if causal:
475
- causal_mask = torch.triu(
476
- torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
477
- )
478
- scores = scores + causal_mask.to(dtype=scores.dtype)
479
-
480
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
481
- attention = self.drop(attention)
482
-
483
- output = torch.einsum("bhts,bshd->bthd", attention, v)
484
-
485
- return output
486
-
487
-
488
- class CrossAttention(nn.Module):
489
- """Cross-attention layer (compatible with PyTorch).
490
-
491
- Reference:
492
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
493
-
494
- """
495
-
496
- def __init__(
497
- self,
498
- causal: bool = True,
499
- softmax_scale: Optional[float] = None,
500
- attention_dropout: float = 0.0,
501
- ) -> None:
502
- super().__init__()
503
-
504
- self.causal = causal
505
- self.softmax_scale = softmax_scale
506
- self.drop = nn.Dropout(attention_dropout)
507
-
508
- @torch.autocast("cpu", enabled=False)
509
- @torch.autocast("cuda", enabled=False)
510
- def forward(
511
- self,
512
- q: torch.FloatTensor,
513
- kv: torch.FloatTensor,
514
- causal: bool = None,
515
- key_padding_mask: Optional[torch.BoolTensor] = None,
516
- **kwargs,
517
- ) -> torch.FloatTensor:
518
- batch_size, seqlen_q = q.shape[0], q.shape[1]
519
- seqlen_k = kv.shape[1]
520
-
521
- if kv.shape[3] != q.shape[2]:
522
- kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
523
- k, v = kv.unbind(dim=2)
524
-
525
- q = q.to(torch.float32)
526
- k = k.to(torch.float32)
527
-
528
- causal = self.causal if causal is None else causal
529
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
530
-
531
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
532
- # using float16, which might lead to overflow
533
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
534
-
535
- if key_padding_mask is not None:
536
- padding_mask = torch.full(
537
- (batch_size, seqlen_k),
538
- -10000.0,
539
- dtype=scores.dtype,
540
- device=scores.device,
541
- )
542
- padding_mask.masked_fill_(key_padding_mask, 0.0)
543
-
544
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
545
-
546
- if causal:
547
- rows = rearrange(
548
- torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
549
- )
550
- cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
551
- causal_mask = cols > rows + seqlen_k - seqlen_q
552
-
553
- scores = scores.masked_fill(causal_mask, -10000.0)
554
-
555
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
556
- attention = self.drop(attention)
557
-
558
- output = torch.einsum("bhts,bshd->bthd", attention, v)
559
-
560
- return output
561
-
562
-
563
- def _find_mha_dims(
564
- config: PretrainedConfig,
565
- n_head: Optional[int] = None,
566
- n_head_kv: Optional[int] = None,
567
- head_dim: Optional[int] = None,
568
- ) -> Tuple[int, int]:
569
- if n_head is None and head_dim is None:
570
- head_dim = config.n_embd // config.n_head
571
- n_head = config.n_head
572
- elif n_head is None or head_dim is None:
573
- raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
574
-
575
- if n_head_kv is None:
576
- n_head_kv = getattr(config, "n_head_kv", None) or n_head
577
-
578
- return n_head, n_head_kv, head_dim
579
-
580
-
581
- def _update_kv_cache(
582
- kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int
583
- ) -> torch.FloatTensor:
584
- num_heads, head_dim = kv.shape[-2:]
585
-
586
- if layer_idx not in inference_params.key_value_memory_dict:
587
- inference_params.key_value_memory_dict[layer_idx] = torch.empty(
588
- inference_params.max_batch_size,
589
- inference_params.max_seqlen,
590
- 2,
591
- num_heads,
592
- head_dim,
593
- dtype=kv.dtype,
594
- device=kv.device,
595
- )
596
-
597
- batch_start = inference_params.batch_size_offset
598
- batch_end = batch_start + kv.shape[0]
599
-
600
- sequence_start = inference_params.seqlen_offset
601
- sequence_end = sequence_start + kv.shape[1]
602
-
603
- # When the current sequence length is equal to or larger than the maximum sequence length,
604
- # we need to concatenate the current `kv` with the cached `kv` to expand its length
605
- if sequence_end >= inference_params.max_seqlen:
606
- inference_params.key_value_memory_dict[layer_idx] = torch.concatenate(
607
- (inference_params.key_value_memory_dict[layer_idx], kv), dim=1
608
- )
609
-
610
- inference_params.key_value_memory_dict[layer_idx][
611
- batch_start:batch_end, sequence_start:sequence_end, ...
612
- ] = kv
613
- kv = inference_params.key_value_memory_dict[layer_idx][
614
- batch_start:batch_end, :sequence_end, ...
615
- ]
616
-
617
- return kv
618
-
619
-
620
- class MHA(nn.Module):
621
- """Multi-head attention layer."""
622
-
623
- def __init__(
624
- self,
625
- config: PretrainedConfig,
626
- dtype: Optional[torch.dtype] = None,
627
- device: Optional[str] = None,
628
- rotary_dim: Optional[int] = None,
629
- rotary_base: float = 10000.0,
630
- rotary_scale_base: Optional[float] = None,
631
- n_head: Optional[int] = None,
632
- n_head_kv: Optional[int] = None,
633
- head_dim: Optional[int] = None,
634
- bias: bool = True,
635
- causal: bool = True,
636
- softmax_scale: Optional[float] = None,
637
- layer_idx: Optional[int] = None,
638
- return_residual: bool = False,
639
- checkpointing: bool = False,
640
- ) -> None:
641
- super().__init__()
642
-
643
- # Rotary embedding
644
- self.rotary_dim = (
645
- rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
646
- )
647
-
648
- if self.rotary_dim > 0:
649
- self.rotary_emb = RotaryEmbedding(
650
- self.rotary_dim,
651
- base=rotary_base,
652
- scale_base=rotary_scale_base,
653
- device=device,
654
- max_position_embeddings=config.n_positions,
655
- )
656
-
657
- # MLP
658
- self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
659
- config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
660
- )
661
- op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
662
- hidden_size = config.n_embd
663
-
664
- linear_cls = FusedDense if config.fused_dense else nn.Linear
665
- if linear_cls is None:
666
- linear_cls = nn.Linear
667
-
668
- self.Wqkv = linear_cls(
669
- hidden_size, op_size, bias=bias, device=device, dtype=dtype
670
- )
671
- self.out_proj = linear_cls(
672
- hidden_size, hidden_size, bias=bias, device=device, dtype=dtype
673
- )
674
-
675
- # Attention
676
- self.inner_attn = SelfAttention(
677
- causal=causal,
678
- softmax_scale=softmax_scale,
679
- attention_dropout=config.attn_pdrop,
680
- )
681
- self.inner_cross_attn = CrossAttention(
682
- causal=causal,
683
- softmax_scale=softmax_scale,
684
- attention_dropout=config.attn_pdrop,
685
- )
686
-
687
- self.layer_idx = layer_idx
688
- self.return_residual = return_residual
689
- self.checkpointing = checkpointing
690
-
691
- def _forward_self_attn(
692
- self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
693
- ) -> torch.FloatTensor:
694
- qkv = self.Wqkv(x)
695
- qkv = rearrange(
696
- qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
697
- )
698
-
699
- if self.rotary_dim > 0:
700
- qkv = self.rotary_emb(qkv)
701
-
702
- if self.checkpointing:
703
- return torch.utils.checkpoint.checkpoint(
704
- self.inner_attn, qkv, key_padding_mask=key_padding_mask
705
- )
706
-
707
- return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
708
-
709
- def _forward_cross_attn(
710
- self,
711
- x: torch.FloatTensor,
712
- past_key_values: Optional[InferenceParams],
713
- key_padding_mask: Optional[torch.BoolTensor],
714
- ) -> torch.FloatTensor:
715
- batch_size = x.shape[0]
716
-
717
- qkv = self.Wqkv(x)
718
-
719
- q = qkv[..., : self.n_head * self.head_dim]
720
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
721
-
722
- kv = qkv[..., self.n_head * self.head_dim :]
723
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
724
-
725
- seqlen_offset = (
726
- past_key_values.seqlen_offset if past_key_values is not None else 0
727
- )
728
- causal = None if seqlen_offset == 0 else False
729
- if self.rotary_dim > 0:
730
- q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
731
-
732
- if past_key_values is not None:
733
- kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
734
-
735
- if self.checkpointing:
736
- return torch.utils.checkpoint.checkpoint(
737
- self.inner_cross_attn,
738
- q,
739
- kv,
740
- key_padding_mask=key_padding_mask,
741
- causal=causal,
742
- )
743
-
744
- return self.inner_cross_attn(
745
- q, kv, key_padding_mask=key_padding_mask, causal=causal
746
- )
747
-
748
- def forward(
749
- self,
750
- x: torch.FloatTensor,
751
- past_key_values: Optional[InferenceParams] = None,
752
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
753
- **kwargs,
754
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
755
- if attention_mask is not None:
756
- attention_mask = attention_mask.bool()
757
- else:
758
- attention_mask = None
759
-
760
- # MHA
761
- if self.n_head == self.n_head_kv:
762
- if past_key_values is None:
763
- # If `past_key_values` are not supplied, we run self-attention
764
- attn_output = self._forward_self_attn(x, attention_mask)
765
- else:
766
- # If `past_key_values` are supplied, it means that we might have cached values and
767
- # could take advantage of cross-attention
768
- attn_output = self._forward_cross_attn(
769
- x, past_key_values, attention_mask
770
- )
771
- # MQA / GQA
772
- else:
773
- # Regardless of `past_key_values` being supplied or not, it always use cross-attention
774
- # because `q` and `kv` lengths might be different
775
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
776
-
777
- output = rearrange(attn_output, "... h d -> ... (h d)")
778
- output = self.out_proj(output)
779
-
780
- return output if not self.return_residual else (output, x)
781
-
782
-
783
- class ParallelBlock(nn.Module):
784
- """Parallel block.
785
-
786
- This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
787
-
788
- """
789
-
790
- def __init__(
791
- self,
792
- config: PretrainedConfig,
793
- block_idx: Optional[int] = None,
794
- ) -> None:
795
- super().__init__()
796
-
797
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
798
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
799
- self.block_idx = block_idx
800
-
801
- self.mixer = MHA(config, layer_idx=block_idx)
802
- self.mlp = MLP(config)
803
-
804
- def forward(
805
- self,
806
- hidden_states: torch.FloatTensor,
807
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
808
- attention_mask: Optional[torch.BoolTensor] = None,
809
- **kwargs,
810
- ) -> torch.FloatTensor:
811
- residual = hidden_states
812
- hidden_states = self.ln(hidden_states)
813
-
814
- attn_outputs = self.mixer(
815
- hidden_states,
816
- past_key_values=past_key_values,
817
- attention_mask=attention_mask,
818
- )
819
- if isinstance(attn_outputs, tuple):
820
- attn_outputs = attn_outputs[0]
821
-
822
- attn_outputs = self.resid_dropout(attn_outputs)
823
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
824
-
825
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
826
-
827
- return hidden_states
828
-
829
-
830
- class CausalLMHead(nn.Module):
831
- """Causal Language Modeling head.
832
-
833
- Reference:
834
- Improving Language Understanding by Generative Pre-Training.
835
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
836
-
837
- """
838
-
839
- def __init__(self, config: PretrainedConfig) -> None:
840
- super().__init__()
841
-
842
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
843
- self.linear = nn.Linear(config.n_embd, config.vocab_size)
844
-
845
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
846
- hidden_states = self.ln(hidden_states)
847
- logits = self.linear(hidden_states).to(torch.float32)
848
-
849
- return logits
850
-
851
-
852
- class CausalLMLoss(nn.Module):
853
- """Causal Language Modeling loss.
854
-
855
- Reference:
856
- Improving Language Understanding by Generative Pre-Training.
857
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
858
-
859
- """
860
-
861
- def __init__(self, shift_labels: bool = True) -> None:
862
- super().__init__()
863
-
864
- self.shift_labels = shift_labels
865
- self.loss_fct = nn.CrossEntropyLoss()
866
-
867
- def forward(
868
- self, logits: torch.FloatTensor, labels: torch.LongTensor
869
- ) -> torch.FloatTensor:
870
- if self.shift_labels:
871
- logits = logits[..., :-1, :].contiguous()
872
- labels = labels[..., 1:].contiguous()
873
-
874
- loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
875
-
876
- return loss
877
-
878
-
879
- class PhiPreTrainedModel(PreTrainedModel):
880
- """Phi pre-trained model."""
881
-
882
- config_class = PhiConfig
883
- base_model_prefix = "transformer"
884
- supports_gradient_checkpointing = False
885
- _no_split_modules = ["ParallelBlock"]
886
-
887
- def __init__(self, *inputs, **kwargs) -> None:
888
- super().__init__(*inputs, **kwargs)
889
-
890
- def prepare_inputs_for_generation(
891
- self,
892
- input_ids: torch.LongTensor = None,
893
- inputs_embeds: torch.FloatTensor = None,
894
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
895
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
896
- **kwargs,
897
- ) -> Dict[str, Any]:
898
- if inputs_embeds is not None:
899
- max_batch_size = inputs_embeds.shape[0]
900
- seqlen_offset = inputs_embeds.shape[1] + input_ids.shape[1] - 2
901
- elif input_ids is not None:
902
- max_batch_size = input_ids.shape[0]
903
- seqlen_offset = input_ids.shape[1] - 1
904
- else:
905
- raise ValueError(
906
- "You have to specify either `input_ids` or `inputs_embeds`."
907
- )
908
-
909
- args = {}
910
-
911
- if past_key_values is None or not (
912
- isinstance(past_key_values, InferenceParams)
913
- ):
914
- past_key_values = InferenceParams(
915
- max_seqlen=self.config.n_positions,
916
- max_batch_size=max_batch_size,
917
- seqlen_offset=0,
918
- batch_size_offset=0,
919
- key_value_memory_dict={},
920
- lengths_per_sample=None,
921
- )
922
- if inputs_embeds is not None:
923
- args = {"inputs_embeds": inputs_embeds}
924
- elif input_ids is not None:
925
- args = {"input_ids": input_ids}
926
- else:
927
- raise ValueError(
928
- "You have to specify either `input_ids` or `inputs_embeds`."
929
- )
930
- else:
931
- # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
932
- past_key_values.seqlen_offset = seqlen_offset
933
- input_ids = input_ids[:, -1].unsqueeze(-1)
934
- args = {"input_ids": input_ids}
935
-
936
- return {
937
- **args,
938
- "past_key_values": past_key_values,
939
- "attention_mask": attention_mask,
940
- }
941
-
942
-
943
- class PhiModel(PhiPreTrainedModel):
944
- """Phi model."""
945
-
946
- _keys_to_ignore_on_load_missing = [""]
947
- _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
948
-
949
- def __init__(self, config: PhiConfig) -> None:
950
- super().__init__(config)
951
-
952
- self.embd = Embedding(config)
953
- self.h = nn.ModuleList(
954
- [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
955
- )
956
- self.gradient_checkpointing = config.gradient_checkpointing
957
- self.post_init()
958
-
959
- def get_input_embeddings(self) -> nn.Embedding:
960
- return self.embd.wte
961
-
962
- def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
963
- self.embd.wte = new_embeddings
964
-
965
- def forward(
966
- self,
967
- input_ids: torch.LongTensor = None,
968
- inputs_embeds: torch.FloatTensor = None,
969
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
970
- attention_mask: Optional[torch.BoolTensor] = None,
971
- ) -> torch.FloatTensor:
972
- if input_ids is not None and inputs_embeds is not None:
973
- raise ValueError(
974
- "You cannot specify both `input_ids` and `inputs_embeds` at the same time."
975
- )
976
- elif input_ids is None and inputs_embeds is None:
977
- raise ValueError(
978
- "You have to specify either `input_ids` or `inputs_embeds`."
979
- )
980
- elif input_ids is not None:
981
- hidden_states = self.embd(input_ids)
982
- else:
983
- hidden_states = inputs_embeds
984
-
985
- for layer in self.h:
986
- if self.gradient_checkpointing:
987
- hidden_states = torch.utils.checkpoint.checkpoint(
988
- layer.__call__,
989
- hidden_states,
990
- past_key_values,
991
- attention_mask,
992
- use_reentrant=True,
993
- )
994
- else:
995
- hidden_states = layer(
996
- hidden_states,
997
- past_key_values=past_key_values,
998
- attention_mask=attention_mask,
999
- )
1000
-
1001
- return hidden_states
1002
-
1003
-
1004
- class PhiForCausalLM(PhiPreTrainedModel):
1005
- """Phi for Causal Language Modeling."""
1006
-
1007
- _keys_to_ignore_on_load_missing = [""]
1008
- _keys_to_ignore_on_load_unexpected = [
1009
- r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
1010
- ]
1011
-
1012
- def __init__(self, config: PhiConfig) -> None:
1013
- super().__init__(config)
1014
-
1015
- self.transformer = PhiModel(config)
1016
- self.lm_head = CausalLMHead(config)
1017
- self.loss = CausalLMLoss()
1018
-
1019
- self.post_init()
1020
-
1021
- def get_output_embeddings(self) -> nn.Linear:
1022
- return self.lm_head.linear
1023
-
1024
- def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1025
- self.lm_head.linear = new_embeddings
1026
-
1027
- def forward(
1028
- self,
1029
- input_ids: torch.LongTensor = None,
1030
- inputs_embeds: torch.FloatTensor = None,
1031
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
1032
- attention_mask: Optional[torch.BoolTensor] = None,
1033
- labels: Optional[torch.LongTensor] = None,
1034
- **kwargs,
1035
- ) -> CausalLMOutputWithPast:
1036
- hidden_states = self.transformer(
1037
- input_ids,
1038
- inputs_embeds,
1039
- past_key_values=past_key_values,
1040
- attention_mask=attention_mask,
1041
- )
1042
- lm_logits = self.lm_head(hidden_states)
1043
-
1044
- loss = None
1045
- if labels is not None:
1046
- loss = self.loss(lm_logits, labels)
1047
-
1048
- return CausalLMOutputWithPast(
1049
- loss=loss, logits=lm_logits, past_key_values=past_key_values
1050
- )
1051
-
1052
-
1053
- class VisionEncoder(nn.Module):
1054
- def __init__(self, model_path: str = "model") -> None:
1055
- super().__init__()
1056
- self.model = torch.jit.load(f"{model_path}/vision.pt").to(DEVICE, dtype=DTYPE)
1057
- self.preprocess = Compose(
1058
- [
1059
- Resize(size=(384, 384), interpolation=InterpolationMode.BICUBIC),
1060
- ToImage(),
1061
- ToDtype(torch.float32, scale=True),
1062
- Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
1063
- ]
1064
- )
1065
-
1066
- def __call__(self, image: Image) -> torch.Tensor:
1067
- with torch.no_grad():
1068
- image_vec = self.preprocess(image.convert("RGB")).unsqueeze(0)
1069
- image_vec = image_vec[:, :, :-6, :-6]
1070
- image_vec = rearrange(
1071
- image_vec, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14
1072
- )
1073
-
1074
- image_vec = image_vec.to(DEVICE, dtype=DTYPE)
1075
- return self.model(image_vec)
1076
-
1077
-
1078
- class TextModel(nn.Module):
1079
- def __init__(self, model_path: str = "model") -> None:
1080
- super().__init__()
1081
- self.tokenizer = Tokenizer.from_pretrained(f"{model_path}/tokenizer")
1082
- phi_config = PhiConfig.from_pretrained(f"{model_path}/text_model_cfg.json")
1083
-
1084
- with init_empty_weights():
1085
- self.model = PhiForCausalLM(phi_config)
1086
-
1087
- self.model = load_checkpoint_and_dispatch(
1088
- self.model,
1089
- f"{model_path}/text_model.pt",
1090
- device_map={"": DEVICE},
1091
- dtype=DTYPE,
1092
- )
1093
-
1094
- self.text_emb = self.model.get_input_embeddings()
1095
-
1096
- def input_embeds(self, prompt, image_embeds):
1097
- embeds = []
1098
-
1099
- def _add_toks(toks):
1100
- embeds.append(self.text_emb(toks))
1101
-
1102
- def _tokenize(txt):
1103
- return self.tokenizer(
1104
- txt, return_tensors="pt", add_special_tokens=False
1105
- ).input_ids.to(self.model.device)
1106
-
1107
- # Add BOS token
1108
- _add_toks(
1109
- torch.tensor([[self.tokenizer.bos_token_id]], device=self.model.device)
1110
- )
1111
-
1112
- if "<image>" not in prompt:
1113
- embeds.append(self.text_emb(_tokenize(prompt)))
1114
- else:
1115
- assert prompt.count("<image>") == 1
1116
- before, after = prompt.split("<image>")
1117
- embeds.append(self.text_emb(_tokenize(f"{before}<image>")))
1118
- embeds.append(image_embeds.to(self.model.device))
1119
- embeds.append(self.text_emb(_tokenize(f"</image>{after}")))
1120
-
1121
- return torch.cat(embeds, dim=1)
1122
-
1123
- def generate(
1124
- self, image_embeds, prompt, eos_text="Human:", max_new_tokens=128, **kwargs
1125
- ):
1126
- eos_tokens = self.tokenizer(eos_text, add_special_tokens=False)[0].ids
1127
-
1128
- generate_config = {
1129
- "eos_token_id": eos_tokens,
1130
- "bos_token_id": self.tokenizer.bos_token_id,
1131
- "pad_token_id": self.tokenizer.eos_token_id,
1132
- "max_new_tokens": max_new_tokens,
1133
- **kwargs,
1134
- }
1135
-
1136
- with torch.no_grad():
1137
- inputs_embeds = self.input_embeds(prompt, image_embeds)
1138
- output_ids = self.model.generate(
1139
- inputs_embeds=inputs_embeds, **generate_config
1140
- )
1141
-
1142
- return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
1143
-
1144
- def answer_question(self, image_embeds, question, **kwargs):
1145
- prompt = f"<image>\n\nQuestion: {question}\n\nAnswer:"
1146
- answer = self.generate(
1147
- image_embeds,
1148
- prompt,
1149
- eos_text="<END>",
1150
- max_new_tokens=128,
1151
- **kwargs,
1152
- )[0]
1153
-
1154
- return re.sub("<$", "", re.sub("END$", "", answer)).strip()
1155
-
1156
-
1157
- ##### GRADIO INTERFACE #####
1158
-
1159
- import gradio as gr
1160
- from huggingface_hub import snapshot_download
1161
- from threading import Thread
1162
- from transformers import TextIteratorStreamer
1163
- import hashlib
1164
- import os
1165
- from fastapi import FastAPI, File, UploadFile, Form
1166
- from PIL import Image
1167
- from io import BytesIO
1168
- from typing import List
1169
- from pydantic import BaseModel
1170
- from fastapi.responses import HTMLResponse, FileResponse
1171
  from fastapi.staticfiles import StaticFiles
1172
-
1173
- model_path = snapshot_download("vikhyatk/moondream1")
1174
- vision_encoder = VisionEncoder(model_path).to(DEVICE, dtype=DTYPE)
1175
- text_model = TextModel(model_path).to(DEVICE, dtype=DTYPE)
1176
 
1177
 
1178
 
1179
- # Define a FastAPI app
1180
  app = FastAPI()
1181
- def cached_vision_encoder(image):
1182
- # Calculate checksum of the image
1183
- image_hash = hashlib.sha256(image.tobytes()).hexdigest()
1184
-
1185
- # Check if `image_encoder_cache/{image_hash}.pt` exists, if so load and return it.
1186
- # Otherwise, save the encoded image to `image_encoder_cache/{image_hash}.pt` and return it.
1187
- cache_path = f"image_encoder_cache/{image_hash}.pt"
1188
- if os.path.exists(cache_path):
1189
- return torch.load(cache_path).to(DEVICE, dtype=DTYPE)
1190
- else:
1191
- image_vec = vision_encoder(image).to("cpu", dtype=torch.float16)
1192
- os.makedirs("image_encoder_cache", exist_ok=True)
1193
- torch.save(image_vec, cache_path)
1194
- return image_vec.to(DEVICE, dtype=DTYPE)
1195
- def answer_question(image, question):
1196
- yield ""
1197
 
1198
- streamer = TextIteratorStreamer(text_model.tokenizer, skip_special_tokens=True)
1199
- generation_kwargs = dict(
1200
- image_embeds=cached_vision_encoder(image), question=question, streamer=streamer
1201
- )
1202
- thread = Thread(target=text_model.answer_question, kwargs=generation_kwargs)
1203
- thread.start()
1204
 
1205
- # Create an empty list to store generated sentences
1206
- generated_sentences = []
1207
-
1208
- for new_text in streamer:
1209
- # Append each new sentence to the list
1210
- generated_sentences.append(new_text)
1211
-
1212
- # Concatenate the sentences into a single string
1213
- combined_result = " ".join(generated_sentences)
1214
-
1215
- # Return the combined result as a single sentence
1216
- yield combined_result
1217
 
 
1218
  @app.post("/upload/")
1219
- async def answer(image: UploadFile = File(...), Question: str = Form(...)):
1220
- image_pil = Image.open(image.file)
1221
-
1222
- # Generate the list of sentences
1223
- answer_generator = answer_question(image_pil, Question)
1224
- result_list = list(answer_generator)
 
1225
 
1226
- # Concatenate the sentences into a single string
1227
- combined_result = ", ".join(result_list)
1228
-
1229
- # Return the combined result as a single sentence
1230
- return {combined_result}
1231
 
1232
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
1233
 
1234
-
1235
  @app.get("/")
1236
  def index() -> FileResponse:
1237
- return FileResponse(path="/app/static/index.html", media_type="text/html")
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from fastapi.responses import HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.templating import Jinja2Templates
5
+ from fastapi.responses import FileResponse
6
+ from gradio_client import Client
7
+ import shutil
8
+ import os
9
+ import tempfile
10
 
11
 
12
 
 
13
  app = FastAPI()
14
 
15
+ hf_token = os.environ.get('HF_TOKEN')
16
+ client = Client("Ashrafb/moondream_captioning", hf_token=hf_token)
17
 
18
+ # Function to make API prediction
19
+ def predict_image_description(file_path, question):
20
+     hf_token = os.environ.get('HF_TOKEN')
21
+     client = Client("Ashrafb/moondream_captioning", hf_token=hf_token)
22
+     result = client.predict(file_path, question, api_name="/get_caption")
23
+     return result
24
 
25
+ # Route to handle file upload
26
  @app.post("/upload/")
27
+ async def answer(image: UploadFile = File(...), question: str = Form(...)):
28
+     try:
29
+         with tempfile.NamedTemporaryFile(delete=False) as temp_file:
30
+             shutil.copyfileobj(image.file, temp_file)
31
+             temp_file_path = temp_file.name
32
+         description = predict_image_description(temp_file_path, question)
33
+         os.unlink(temp_file_path)
34
+         return {"description": description}
35
+     except Exception as e:
36
+         raise HTTPException(status_code=500, detail=str(e))
37
 
38
 
39
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
40
 
 
41
  @app.get("/")
42
  def index() -> FileResponse:
43
+ return FileResponse(path="/app/static/index.html", media_type="text/html")
44
+
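
For reference, a minimal sketch of how the new /upload/ endpoint could be exercised from a client once the app is running. The server URL and image path below are placeholders (not part of this commit); only the "image" file field and "question" form field are taken from the endpoint signature above.

import requests  # assumed client-side dependency, not part of this repo

# Hypothetical local address; adjust to wherever the FastAPI app is served.
url = "http://127.0.0.1:7860/upload/"

with open("example.jpg", "rb") as f:  # placeholder image path
    resp = requests.post(
        url,
        files={"image": ("example.jpg", f, "image/jpeg")},
        data={"question": "What is in this image?"},
    )

print(resp.json())  # expected shape after this commit: {"description": "..."}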