aolko commited on
Commit
bb14bef
1 Parent(s): 19af994

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -5
app.py CHANGED
@@ -7,16 +7,31 @@ from PIL import Image
7
  from io import BytesIO
8
 
9
  # Initialize models
10
- anime_model = DiffusionPipeline.from_pretrained("SmilingWolf/wd-v1-4-vit-tagger")
11
  photo_model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/florence-base-in21k-retrieval")
12
  processor = AutoProcessor.from_pretrained("facebook/florence-base-in21k-retrieval")
13
 
14
  def get_booru_image(booru, image_id):
15
- # This is a placeholder function. You'd need to implement the actual API calls for each booru.
16
- url = f"https://api.{booru}.org/images/{image_id}"
 
 
 
 
 
 
 
17
  response = requests.get(url)
18
- img = Image.open(BytesIO(response.content))
19
- tags = ["tag1", "tag2", "tag3"] # Placeholder
 
 
 
 
 
 
 
 
20
  return img, tags
21
 
22
  def transcribe_image(image, image_type, transcriber, booru_tags=None):
 
7
  from io import BytesIO
8
 
9
  # Initialize models
10
+ anime_model = DiffusionPipeline.from_pretrained("SmilingWolf/wd-convnext-tagger-v3")
11
  photo_model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/florence-base-in21k-retrieval")
12
  processor = AutoProcessor.from_pretrained("facebook/florence-base-in21k-retrieval")
13
 
14
  def get_booru_image(booru, image_id):
15
+ if booru == "Gelbooru":
16
+ url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
17
+ elif booru == "Danbooru":
18
+ url = f"https://danbooru.donmai.us/posts/{image_id}.json"
19
+ elif booru == "rule34.xxx":
20
+ url = f"https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
21
+ else:
22
+ raise ValueError("Unsupported booru")
23
+
24
  response = requests.get(url)
25
+ data = response.json()
26
+
27
+ # The exact structure of the response will vary depending on the booru
28
+ # You'll need to adjust this part based on each booru's API
29
+ image_url = data[0]['file_url'] if isinstance(data, list) else data['file_url']
30
+ tags = data[0]['tags'].split() if isinstance(data, list) else data['tags'].split()
31
+
32
+ img_response = requests.get(image_url)
33
+ img = Image.open(BytesIO(img_response.content))
34
+
35
  return img, tags
36
 
37
  def transcribe_image(image, image_type, transcriber, booru_tags=None):