import gradio as gr
import pinecone
import requests
import PyPDF2
from transformers import AutoTokenizer, AutoModel
import torch
import re
import google.generativeai as genai
import os
import time
from datetime import datetime, timedelta
from google.api_core import exceptions
# Constants
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY") # Set in HF Spaces Secrets
PINECONE_INDEX_NAME = "diabetes-bot"
PINECONE_NAMESPACE = "general"
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") # Set in HF Spaces Secrets
MODEL_NAME = "dmis-lab/biobert-base-cased-v1.1"
# Free tier limits
FREE_TIER_RPD_LIMIT = 1000 # Requests per day
FREE_TIER_RPM_LIMIT = 15 # Requests per minute
FREE_TIER_TPM_LIMIT = 1000000 # Tokens per minute
WARNING_THRESHOLD = 0.9 # Stop at 90% of the limit to be safe
# Usage tracking
usage_file = "usage.txt"
def load_usage():
    if not os.path.exists(usage_file):
        return {"requests": [], "tokens": []}
    with open(usage_file, "r") as f:
        data = f.read().strip()
    if not data:
        return {"requests": [], "tokens": []}
    requests, tokens = data.split("|")
    return {
        "requests": [float(t) for t in requests.split(",") if t],
        "tokens": [(float(t), float(n)) for t, n in [pair.split(":") for pair in tokens.split(",") if pair]]
    }
def save_usage(requests, tokens):
    with open(usage_file, "w") as f:
        f.write(",".join(map(str, requests)) + "|" + ",".join(f"{t}:{n}" for t, n in tokens))
def check_usage():
    usage = load_usage()
    now = time.time()
    # Clean up old requests (older than 24 hours)
    day_ago = now - 24 * 60 * 60
    usage["requests"] = [t for t in usage["requests"] if t > day_ago]
    # Clean up old token counts (older than 1 minute)
    minute_ago = now - 60
    usage["tokens"] = [(t, n) for t, n in usage["tokens"] if t > minute_ago]
    # Count requests per day
    rpd = len(usage["requests"])
    rpd_limit = int(FREE_TIER_RPD_LIMIT * WARNING_THRESHOLD)
    if rpd >= rpd_limit:
        return False, f"Approaching daily request limit ({rpd}/{FREE_TIER_RPD_LIMIT}). Stopping to stay in free tier. Try again tomorrow."
    # Count requests per minute
    rpm = len([t for t in usage["requests"] if t > minute_ago])
    rpm_limit = int(FREE_TIER_RPM_LIMIT * WARNING_THRESHOLD)
    if rpm >= rpm_limit:
        return False, f"Approaching minute request limit ({rpm}/{FREE_TIER_RPM_LIMIT}). Wait a minute and try again."
    # Count tokens per minute
    tpm = sum(n for t, n in usage["tokens"])
    tpm_limit = int(FREE_TIER_TPM_LIMIT * WARNING_THRESHOLD)
    if tpm >= tpm_limit:
        return False, f"Approaching token limit ({tpm}/{FREE_TIER_TPM_LIMIT} per minute). Wait a minute and try again."
    return True, (rpd, rpm, tpm)
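# Note: check_usage() returns (False, user-facing limit message) when a limit is near,
# otherwise (True, (requests_today, requests_this_minute, tokens_this_minute)).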
# Initialize Pinecone
pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX_NAME)
# Initialize BioBERT for embedding queries
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
if torch.cuda.is_available():
    model.cuda()
# Initialize Gemini and check available models
genai.configure(api_key=GEMINI_API_KEY)
# List available models to confirm free tier access
available_models = [model.name for model in genai.list_models()]
print("Available Gemini models:", available_models)
preferred_model = "gemini-pro"  # Use the generally available model
# genai.list_models() returns names prefixed with "models/", so compare against that form
if f"models/{preferred_model}" in available_models:
    gemini_model = genai.GenerativeModel(preferred_model)
    print(f"Using model: {preferred_model}")
else:
    # Fall back to other known models; use the first available match
    for model_name in ["gemini-2.0-flash", "gemini-1.5-pro"]:
        if f"models/{model_name}" in available_models:
            gemini_model = genai.GenerativeModel(f"models/{model_name}")
            print(f"Using model: models/{model_name}")
            break
    else:
        raise ValueError("No suitable Gemini model available. Available models: " + str(available_models))
# Clean text
def clean_text(text):
    text = re.sub(r'<[^>]+>', '', text)  # Remove HTML tags
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)  # Remove non-ASCII
    text = re.sub(r'\s+', ' ', text)  # Normalize spaces
    return text.strip()
# Embed text using BioBERT
def embed_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()[0]
    return embedding.tolist()
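# The query vector is the [CLS] token embedding (768 dimensions for BioBERT base);
# this assumes the Pinecone index was created with matching 768-dimensional vectors
# from the same model.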
# Extract text from PDF (up to 10 pages)
def extract_pdf_text(pdf_file):
    reader = PyPDF2.PdfReader(pdf_file)
    num_pages = min(len(reader.pages), 10)  # Limit to 10 pages
    text = ""
    for page in range(num_pages):
        # extract_text() can come back empty for scanned/image-only pages, so guard the concatenation
        text += (reader.pages[page].extract_text() or "") + "\n"
    return clean_text(text)
# Retrieve relevant chunks from Pinecone
def retrieve_from_pinecone(query, top_k=5):
    query_embedding = embed_text(query)
    results = index.query(
        namespace=PINECONE_NAMESPACE,
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True
    )
    retrieved_chunks = [match["metadata"]["chunk"] for match in results["matches"]]
    return "\n".join(retrieved_chunks)
# Count tokens using the Gemini API
def count_tokens(text):
    try:
        response = gemini_model.count_tokens(text)
        return response.total_tokens
    except exceptions.ResourceExhausted:
        # Quota exhausted (HTTP 429); return 0 so callers can detect it and stop
        return 0
# Generate answer using Gemini
def generate_answer(query, context):
    prompt = f"""
    You are a diabetes research assistant. Answer the following question based on the provided context. If the context is insufficient, use your knowledge to provide a helpful answer, but note if the information might be limited.
    **Question**: {query}
    **Context**:
    {context}
    **Answer**:
    """
    try:
        response = gemini_model.generate_content(prompt)
        return response.text
    except exceptions.ResourceExhausted as e:
        return f"Error: Gemini API quota exceeded ({str(e)}). Try again later."
    except Exception as e:
        return f"Error generating answer: {str(e)}"
# Main function to handle user input
def diabetes_bot(query, pdf_file=None):
    # Check usage limits
    can_proceed, usage_info = check_usage()
    if not can_proceed:
        return usage_info
    # Step 1: Get context from PDF if uploaded
    pdf_context = ""
    if pdf_file is not None:
        pdf_context = extract_pdf_text(pdf_file)
        if pdf_context:
            pdf_context = f"Uploaded PDF content:\n{pdf_context}\n\n"
    # Step 2: Retrieve relevant chunks from Pinecone
    pinecone_context = retrieve_from_pinecone(query)
    if pinecone_context:
        pinecone_context = f"Pinecone retrieved content (latest research, 2010 onward):\n{pinecone_context}\n\n"
    # Step 3: Combine contexts
    full_context = pdf_context + pinecone_context
    if not full_context.strip():
        full_context = "No relevant context found in Pinecone or uploaded PDF."
    # Step 4: Count tokens for the prompt
    prompt = f"""
    You are a diabetes research assistant. Answer the following question based on the provided context. If the context is insufficient, use your knowledge to provide a helpful answer, but note if the information might be limited.
    **Question**: {query}
    **Context**:
    {full_context}
    **Answer**:
    """
    input_tokens = count_tokens(prompt)
    if input_tokens == 0:  # Quota exceeded during token counting
        return "Error: Gemini API quota exceeded while counting tokens. Try again later."
    # Update usage
    usage = load_usage()
    now = time.time()
    usage["requests"].append(now)
    usage["tokens"].append((now, input_tokens))
    save_usage(usage["requests"], usage["tokens"])
    # Step 5: Generate answer using Gemini
    answer = generate_answer(query, full_context)
    # Step 6: Count output tokens and update usage
    output_tokens = count_tokens(answer)
    if output_tokens == 0:  # Quota exceeded during output token counting
        return answer + "\n\nError: Gemini API quota exceeded while counting output tokens. Usage stats may be incomplete."
    usage = load_usage()
    usage["tokens"].append((now, output_tokens))
    save_usage(usage["requests"], usage["tokens"])
    # Step 7: Show usage stats (check_usage() returns a limit message instead of counts when a limit is near)
    ok, info = check_usage()
    if ok:
        rpd, rpm, tpm = info
        usage_message = f"\n\nUsage: {rpd}/{FREE_TIER_RPD_LIMIT} requests today, {rpm}/{FREE_TIER_RPM_LIMIT} requests this minute, {tpm}/{FREE_TIER_TPM_LIMIT} tokens this minute."
    else:
        usage_message = f"\n\n{info}"
    return answer + usage_message
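# Illustrative call (hypothetical query, shown for reference only):
#   diabetes_bot("What does recent research say about SGLT2 inhibitors for type 2 diabetes?")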
# Gradio interface
def chat_wrapper(query, pdf, history):
    # Initialize history if empty
    if history is None:
        history = []
    # If no query is provided, return the current history without changes
    if query.strip() == "":
        return history, "", None, history
    # Call the existing diabetes_bot function to generate an answer
    answer = diabetes_bot(query, pdf)
    # Append the new interaction as message-style dicts (role/content), as expected by gr.Chatbot(type="messages")
    history.append({"role": "user", "content": query})
    history.append({"role": "assistant", "content": answer})
    # Return the updated conversation and clear the query and pdf inputs
    return history, "", None, history

def clear_all():
    # Clear conversation history and inputs
    return [], "", None, []
with gr.Blocks() as app:
    gr.HTML(
        """
        <h1 style="text-align:center;">Diabetes Research ChatBot powered by Gemini 2.0 Flash and Pinecone 🩺</h1>
        <p style="text-align:center;"><strong>Powered by the latest diabetes research, running on the Gemini 2.0 Flash API</strong></p>
        <p style="text-align:center;">Ask questions about diabetes directly or upload a research paper (up to 10 pages) for specific Q&A.</p>
        <br>
        <div style="border: 1px solid #ccc; border-radius: 5px; padding: 10px; background-color: #f9f9f9; margin:auto; width:80%;">
            <strong>Disclaimer:</strong>
            The information provided by this chatbot is for research and informational purposes only and is not intended to substitute professional medical advice, diagnosis, or treatment. Always seek the advice of your physician or other qualified health provider with any questions you may have regarding a medical condition.
        </div>
        <br>
        """
    )
    # Create a Chatbot component with type set to "messages" and a specified height
    chatbot = gr.Chatbot(label="Conversation", type="messages", height=370)
    # Input row for query and PDF file (with the PDF box sized smaller)
    with gr.Row():
        query_input = gr.Textbox(label="Ask a Question", placeholder="Type your diabetes-related query here...", lines=2)
        with gr.Column(scale=0.2):
            pdf_input = gr.File(label="Upload a PDF (optional, max 10 pages)", file_types=[".pdf"])
    # Row for Submit and Clear buttons
    with gr.Row():
        submit_button = gr.Button("Ask", variant="primary")
        clear_button = gr.Button("Clear")
    # State to maintain conversation history
    state = gr.State([])
    # On submit, update the conversation and clear inputs; outputs: chatbot, query_input, pdf_input, state
    submit_button.click(
        fn=chat_wrapper,
        inputs=[query_input, pdf_input, state],
        outputs=[chatbot, query_input, pdf_input, state]
    )
    # Clear all components, including the conversation history
    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[chatbot, query_input, pdf_input, state]
    )
app.launch()