multimodalart HF staff commited on
Commit
fd01f63
1 Parent(s): 99fe006

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -1
app.py CHANGED
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_d
12
  import copy
13
  import random
14
  import time
 
15
 
16
  # Load LoRAs from JSON file
17
  with open('loras.json', 'r') as f:
@@ -56,6 +57,26 @@ class calculateDuration:
56
  else:
57
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def update_selection(evt: gr.SelectData, selected_indices, width, height):
60
  selected_index = evt.index
61
  selected_indices = selected_indices or []
@@ -337,10 +358,12 @@ def add_custom_lora(custom_lora, selected_indices):
337
  print(f"Loaded custom LoRA: {repo}")
338
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
339
  if existing_item_index is None:
 
 
340
  new_item = {
341
  "image": image if image is not None else "",
342
  "title": title,
343
- "repo": repo,
344
  "weights": path,
345
  "trigger_word": trigger_word
346
  }
 
12
  import copy
13
  import random
14
  import time
15
+ import requests
16
 
17
  # Load LoRAs from JSON file
18
  with open('loras.json', 'r') as f:
 
57
  else:
58
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
59
 
60
+ def download_file(url, directory=None):
61
+ if directory is None:
62
+ directory = os.getcwd() # Use current working directory if not specified
63
+
64
+ # Get the filename from the URL
65
+ filename = url.split('/')[-1]
66
+
67
+ # Full path for the downloaded file
68
+ filepath = os.path.join(directory, filename)
69
+
70
+ # Download the file
71
+ response = requests.get(url)
72
+ response.raise_for_status() # Raise an exception for bad status codes
73
+
74
+ # Write the content to the file
75
+ with open(filepath, 'wb') as file:
76
+ file.write(response.content)
77
+
78
+ return filepath
79
+
80
  def update_selection(evt: gr.SelectData, selected_indices, width, height):
81
  selected_index = evt.index
82
  selected_indices = selected_indices or []
 
358
  print(f"Loaded custom LoRA: {repo}")
359
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
360
  if existing_item_index is None:
361
+ if(repo.endswith(".safetensors")):
362
+ downloaded_path = download_file(url)
363
  new_item = {
364
  "image": image if image is not None else "",
365
  "title": title,
366
+ "repo": downloaded_path,
367
  "weights": path,
368
  "trigger_word": trigger_word
369
  }