import gradio as gr
import torch
import os
import numpy as np
from groq import Groq
import spaces
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusion3Pipeline
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from langchain_groq import ChatGroq
from PIL import Image
from tavily import TavilyClient
from langchain.schema import AIMessage
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA
from torchvision import transforms
import json
import pandas
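
# NOTE: the app expects two environment variables to be set:
#   GROQ_API_KEY - Groq LLM, Whisper transcription, and function classification
#   TAVILY_API   - Tavily web search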
# Initialize models and clients
MODEL = 'llama-3.1-70b-versatile'
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                      device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)

tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")

# Image generation model
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Tavily client for web search
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))
# Function to play voice output
def play_voice_output(response):
    print("Executing play_voice_output function")
    description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
    prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
    return "output.wav"
# Function to classify user input using the LLM
def classify_function(user_prompt):
    prompt = f"""
You are a function classifier AI assistant. You are given a user input and you need to classify it into one of the following functions:
- `image_generation`: If the user wants to generate an image.
- `image_vqa`: If the user wants to ask questions about an image.
- `document_qa`: If the user wants to ask questions about a document.
- `text_to_text`: If the user wants a text-based response.
Respond with a JSON object containing only the chosen function. For example:
```json
{{"function": "image_generation"}}
```
User input: {user_prompt}
"""
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model="llama3-8b-8192",
    )
    content = chat_completion.choices[0].message.content
    # The prompt asks for a fenced JSON block, so extract the JSON object before parsing.
    start, end = content.find("{"), content.rfind("}")
    if start != -1 and end != -1:
        content = content[start:end + 1]
    try:
        response = json.loads(content)
        function = response.get("function", "text_to_text")
        return function
    except json.JSONDecodeError:
        print(f"Error decoding JSON: {content}")
        return "text_to_text"  # Default to text-to-text if JSON parsing fails
# Document Question Answering tool
class DocumentQuestionAnswering:
    def __init__(self, document):
        self.document = document
        self.qa_chain = self._setup_qa_chain()

    def _setup_qa_chain(self):
        print("Setting up DocumentQuestionAnswering tool")
        # Load the document, split it into chunks, and index the chunks in FAISS.
        loader = TextLoader(self.document)
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings()
        db = FAISS.from_documents(texts, embeddings)
        retriever = db.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            llm=ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY")),
            chain_type="stuff",
            retriever=retriever,
        )
        return qa_chain

    def run(self, query: str) -> str:
        print("Executing DocumentQuestionAnswering tool")
        response = self.qa_chain.run(query)
        return str(response)
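
# Usage sketch (the file name below is hypothetical):
#   doc_qa = DocumentQuestionAnswering("notes.txt")
#   answer = doc_qa.run("What are the key points of this document?")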
# Function to handle different input types and choose the right pipeline
def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
    print(f"Handling input: {user_prompt}")
    # Initialize the LLM
    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))

    # Handle voice-only mode
    if audio:
        print("Processing audio input")
        # gr.Audio(type="filepath") passes a path string, so open the file before uploading it.
        with open(audio, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                file=(os.path.basename(audio), audio_file.read()),
                model="whisper-large-v3"
            )
        user_prompt = transcription.text
        response = llm.invoke(user_prompt).content
        audio_output = play_voice_output(response)
        return "Response generated.", audio_output
    # Handle websearch mode
    if websearch:
        print("Executing Web Search")
        answer = tavily_client.qna_search(query=user_prompt)
        return answer, None

    # Handle cases with only image or document input
    if user_prompt is None or user_prompt.strip() == "":
        if image:
            user_prompt = "Describe this image"
        elif document:
            user_prompt = "Summarize this document"

    # Classify user input using the LLM
    function = classify_function(user_prompt)
    # Handle different functions
    if function == "image_generation":
        print("Executing Image Generation")
        image = pipe(
            user_prompt,
            negative_prompt="",
            num_inference_steps=15,
            guidance_scale=7.0,
        ).images[0]
        image.save("output.jpg")
        return "output.jpg", None
elif function == "image_vqa": | |
print("Executing Image Description") | |
if image: | |
print("1") | |
image = Image.open(image).convert('RGB') | |
print("2") | |
# Add preprocessing steps here (see examples above) | |
preprocess = transforms.Compose([ | |
transforms.Resize((512, 512)), # Example size, replace with the correct one | |
transforms.ToTensor(), | |
]) | |
image = preprocess(image) | |
image = image.unsqueeze(0) # Add batch dimension | |
image = image.to(torch.float32) # Ensure correct data type | |
print("3") | |
messages = [{"role": "user", "content": user_prompt}] | |
print("4") | |
response,ctxt = vqa_model.chat(image=image, msgs=messages, tokenizer=tokenizer, context=None, temperature=0.5) | |
print("5") | |
return response, None | |
else: | |
return "Please upload an imagee.", None | |
elif function == "document_qa": | |
print("Executing Document Summarization") | |
if document: | |
document_qa = DocumentQuestionAnswering(document) | |
response = document_qa.run(user_prompt) | |
return response, None | |
else: | |
return "Please upload a documentt.", None | |
    else:  # function == "text_to_text"
        print("Executing Text-to-Text")
        response = llm.invoke(user_prompt).content
        return response, None
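
# handle_input always returns a (text_or_filepath, audio_path_or_None) tuple, so main_interface
# can forward it directly to the [output_label, audio_output] Gradio components.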
# Main interface function
def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
    print("Starting main_interface function")
    vqa_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    pipe.to("cuda")
    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
    try:
        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
        print("handle_input function executed successfully")
    except Exception as e:
        print(f"Error in handle_input: {e}")
        # Return a (text, audio) tuple so the outputs still match the two Gradio components.
        response = ("Error occurred during processing.", None)
    return response
def create_ui():
    with gr.Blocks(css="""
        /* Overall Styling */
        body {
            font-family: 'Poppins', sans-serif;
            background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
            margin: 0;
            padding: 0;
            color: #333;
        }
        /* Title Styling */
        .gradio-container h1 {
            text-align: center;
            padding: 20px 0;
            background: linear-gradient(45deg, #007bff, #00c6ff);
            color: white;
            font-size: 2.5em;
            font-weight: bold;
            letter-spacing: 1px;
            text-transform: uppercase;
            margin: 0;
            box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2);
        }
        /* Input Area Styling */
        .gradio-container .gr-row {
            display: flex;
            justify-content: space-around;
            align-items: center;
            padding: 20px;
            background-color: white;
            border-radius: 10px;
            box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1);
            margin-bottom: 20px;
        }
        .gradio-container .gr-column {
            flex: 1;
            margin: 0 10px;
        }
        /* Textbox Styling */
        .gradio-container textarea {
            width: calc(100% - 20px);
            padding: 15px;
            border: 2px solid #007bff;
            border-radius: 8px;
            font-size: 1.1em;
            transition: border-color 0.3s, box-shadow 0.3s;
        }
        .gradio-container textarea:focus {
            border-color: #00c6ff;
            box-shadow: 0px 0px 8px rgba(0, 198, 255, 0.5);
            outline: none;
        }
        /* Button Styling */
        .gradio-container button {
            background: linear-gradient(45deg, #007bff, #00c6ff);
            color: white;
            padding: 15px 25px;
            border: none;
            border-radius: 8px;
            cursor: pointer;
            font-size: 1.2em;
            font-weight: bold;
            transition: background 0.3s, transform 0.3s;
            box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
        }
        .gradio-container button:hover {
            background: linear-gradient(45deg, #0056b3, #009bff);
            transform: translateY(-3px);
        }
        .gradio-container button:active {
            transform: translateY(0);
        }
        /* Output Area Styling */
        .gradio-container .output-area {
            padding: 20px;
            text-align: center;
            background-color: #f7f9fc;
            border-radius: 10px;
            box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1);
            margin-top: 20px;
        }
        /* Image Styling */
        .gradio-container img {
            max-width: 100%;
            height: auto;
            border-radius: 10px;
            box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
            transition: transform 0.3s, box-shadow 0.3s;
        }
        .gradio-container img:hover {
            transform: scale(1.05);
            box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.2);
        }
        /* Checkbox Styling */
        .gradio-container input[type="checkbox"] {
            width: 20px;
            height: 20px;
            cursor: pointer;
            accent-color: #007bff;
            transition: transform 0.3s;
        }
        .gradio-container input[type="checkbox"]:checked {
            transform: scale(1.2);
        }
        /* Audio and Document Upload Styling */
        .gradio-container .gr-file-upload input[type="file"] {
            width: 100%;
            padding: 10px;
            border: 2px solid #007bff;
            border-radius: 8px;
            cursor: pointer;
            background-color: white;
            transition: border-color 0.3s, background-color 0.3s;
        }
        .gradio-container .gr-file-upload input[type="file"]:hover {
            border-color: #00c6ff;
            background-color: #f0f8ff;
        }
        /* Advanced Tooltip Styling */
        .gradio-container .gr-tooltip {
            position: relative;
            display: inline-block;
            cursor: pointer;
        }
        .gradio-container .gr-tooltip .tooltiptext {
            visibility: hidden;
            width: 200px;
            background-color: black;
            color: #fff;
            text-align: center;
            border-radius: 6px;
            padding: 5px;
            position: absolute;
            z-index: 1;
            bottom: 125%;
            left: 50%;
            margin-left: -100px;
            opacity: 0;
            transition: opacity 0.3s;
        }
        .gradio-container .gr-tooltip:hover .tooltiptext {
            visibility: visible;
            opacity: 1;
        }
        /* Footer Styling */
        .gradio-container footer {
            text-align: center;
            padding: 10px;
            background: #007bff;
            color: white;
            font-size: 0.9em;
            border-radius: 0 0 10px 10px;
            box-shadow: 0px -2px 8px rgba(0, 0, 0, 0.1);
        }
    """) as demo:
gr.Markdown("# AI Assistant") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1) | |
with gr.Column(scale=1): | |
image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon") | |
audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon") | |
document_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon") | |
voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode") | |
websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode") | |
with gr.Column(scale=1): | |
submit = gr.Button("Submit") | |
output_label = gr.Label(label="Output") | |
audio_output = gr.Audio(label="Audio Output", visible=False) | |
submit.click( | |
fn=main_interface, | |
inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode, document_input], | |
outputs=[output_label, audio_output] | |
) | |
voice_only_mode.change( | |
lambda x: gr.update(visible=not x), | |
inputs=voice_only_mode, | |
outputs=[user_prompt, image_input, websearch_mode, document_input, submit] | |
) | |
voice_only_mode.change( | |
lambda x: gr.update(visible=x), | |
inputs=voice_only_mode, | |
outputs=[audio_input] | |
) | |
return demo | |
# Launch the UI
demo = create_ui()
demo.launch()