John6666 committed on
Commit 8e6bf35 · verified · 1 Parent(s): 0c1dfae

Upload joycaption.py

Files changed (1)
  1. joycaption.py +28 -4
joycaption.py CHANGED
@@ -1,4 +1,13 @@
-import spaces
+import os
+if os.environ.get("SPACES_ZERO_GPU") is not None:
+    import spaces
+else:
+    class spaces:
+        @staticmethod
+        def GPU(func):
+            def wrapper(*args, **kwargs):
+                return func(*args, **kwargs)
+            return wrapper
 import gradio as gr
 from huggingface_hub import InferenceClient
 from torch import nn
@@ -7,11 +16,13 @@ from pathlib import Path
 import torch
 import torch.amp.autocast_mode
 from PIL import Image
-import os
 import torchvision.transforms.functional as TVF
 import gc
 from peft import PeftConfig
 
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 use_inference_client = False
@@ -119,6 +130,8 @@ class ImageAdapter(nn.Module):
 # https://huggingface.co/blog/4bit-transformers-bitsandbytes
 # https://huggingface.co/docs/transformers/main/en/peft
 # https://huggingface.co/docs/transformers/main/en/peft#enable-and-disable-adapters
+# https://huggingface.co/docs/transformers/main/quantization/bitsandbytes?bnb=4-bit
+# https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4
 tokenizer = None
 text_model_client = None
 text_model = None
@@ -171,14 +184,12 @@ load_text_model.zerogpu = True
 print("Loading CLIP")
 clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
 clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
-
 if (CHECKPOINT_PATH / "clip_model.pt").exists():
     print("Loading VLM's custom vision model")
     checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
     checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
     clip_model.load_state_dict(checkpoint)
     del checkpoint
-
 clip_model.eval().requires_grad_(False).to(device)
 
 # Tokenizer
@@ -376,6 +387,19 @@ def is_repo_exists(repo_id):
     return True # for safe
 
 
+def is_valid_repo(repo_id):
+    from huggingface_hub import HfApi
+    import re
+    try:
+        if not re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', repo_id): return False
+        api = HfApi()
+        if api.repo_exists(repo_id=repo_id): return True
+        else: return False
+    except Exception as e:
+        print(f"Failed to connect {repo_id}. {e}")
+        return False
+
+
 def get_text_model():
     return list(llm_models.keys())
 
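
The central change replaces the unconditional `import spaces` with a guard: when `SPACES_ZERO_GPU` is set (ZeroGPU Spaces) the real `spaces` package is imported, otherwise a no-op stand-in is defined so that `@spaces.GPU`-decorated functions still run on plain CUDA or CPU machines. Below is a minimal sketch of how code behaves under this stub; `caption_image` is a hypothetical example function, not part of the commit.

import os

# Same fallback pattern as the commit: off ZeroGPU, spaces.GPU becomes a pass-through decorator.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU  # real GPU scheduling on ZeroGPU, a do-nothing wrapper everywhere else
def caption_image(image_path):
    # hypothetical stand-in for the Space's actual captioning function
    return f"caption for {image_path}"

print(caption_image("cat.png"))

Note that this stand-in only covers the bare `@spaces.GPU` form; the parameterized `@spaces.GPU(duration=...)` form would need a slightly different wrapper.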
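The commit also installs flash-attn at runtime with the CUDA build skipped and adds two reference links on bitsandbytes 4-bit quantization and the NF4 format next to the existing PEFT notes. As a rough illustration of what those links describe, a 4-bit NF4 load with transformers' BitsAndBytesConfig could look like the sketch below; the model id is purely illustrative and not taken from joycaption.py.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 4-bit quantization as described in the linked bitsandbytes docs.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # illustrative model id only
    quantization_config=nf4_config,
    device_map="auto",
)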
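Finally, the new `is_valid_repo` helper checks the `owner/name` shape with a regex before asking the Hub through `HfApi.repo_exists`, so malformed ids fail fast without a network call. A short usage sketch follows, condensed from the committed helper; the repo ids passed in are examples only.

import re
from huggingface_hub import HfApi

def is_valid_repo(repo_id):
    # Shape check first, then an existence query against the Hub.
    try:
        if not re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', repo_id): return False
        return HfApi().repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Failed to connect {repo_id}. {e}")
        return False

print(is_valid_repo("openai/clip-vit-large-patch14"))  # public repo used only as an example
print(is_valid_repo("not a repo id"))                  # fails the regex, returns False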