Testys commited on
Commit
e67cd9e
·
1 Parent(s): 8c50a25

Upload 3 files

Browse files
utils/caption_utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BlipProcessor, BlipForConditionalGeneration
3
+ from utils.image_utils import load_image
4
+
5
+ device = "cuda" if torch.cuda.is_available() else "cpu"
6
+
7
+
8
+ class ImageCaptioning:
9
+
10
+ def __int__(self):
11
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
12
+ self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
13
+
14
+ def get_caption(self, image_path):
15
+ image = load_image(image_path)
16
+
17
+ # Preprocessing the Image
18
+ img = self.processor(image, return_tensors="pt").to(device)
19
+
20
+ # Generating captions
21
+ output = self.model.generate(**img)
22
+
23
+ # decode the output
24
+ caption = self.processor.batch_decode(output, skip_special_tokens=True)[0]
25
+
26
+ return caption
utils/image_utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image
3
+ import urllib.parse as parse
4
+ import os
5
+
6
+
7
+ # Verify url
8
+ def check_url(string):
9
+ try:
10
+ result = parse.urlparse(string)
11
+ return all([result.scheme, result.netloc, result.path])
12
+ except:
13
+ return False
14
+
15
+
16
+ # Load an image
17
+ def load_image(image_path):
18
+ if check_url(image_path):
19
+ return Image.open(requests.get(image_path, stream=True).raw)
20
+ elif os.path.exists(image_path):
21
+ return Image.open(image_path)
utils/topic_generation.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
+ device = "cuda" if torch.cuda.is_available() else "cpu"
4
+
5
+
6
+ class TopicGenerator:
7
+
8
+ def __init__(self):
9
+ # Initialize tokenizer and model upon class instantiation
10
+ self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
11
+ self.model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large").to(device) # assuming you have a GPU available
12
+
13
+ def generate_topics(self, user_input, num_topics=3):
14
+ """
15
+ Generate topic sentences based on the user input.
16
+
17
+ Args:
18
+ - user_input (str): The input text provided by the user.
19
+ - num_topics (int, optional): Number of topics to generate. Defaults to 3.
20
+
21
+ Returns:
22
+ - list: A list of generated topic sentences.
23
+ """
24
+ prompt_text = f"Generate a topic sentence based on the following input: {user_input}"
25
+ input_ids = self.tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)
26
+
27
+ # Generate topics
28
+ outputs = self.model.generate(input_ids, do_sample=True, top_k=50, temperature=0.7, max_length=50, num_return_sequences=num_topics)
29
+
30
+ # Decode the outputs and return as a list of topic sentences
31
+ return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]