db / app.py
latterworks's picture
Update app.py
d46ba59 verified
import json
import os
import re
from datetime import datetime

import gradio as gr
from huggingface_hub import InferenceClient

from noaa_incidents import NOAAIncidentDB, NOAAIncidentScraper
# Shared Hugging Face inference client; the chatbot streams its replies from this model.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
class NOAAIncidentApp:
    """Gradio app combining a NOAA incident search with a streaming chatbot.

    Incidents are stored in a ``NOAAIncidentDB`` persisted under ``noaa_db``.
    The time of the last successful refresh is kept in ``metadata.json``.
    """

    # Chat messages containing this phrase (in any letter case) are routed to
    # the incident search instead of the language model.
    _SEARCH_TRIGGER = re.compile(re.escape("search noaa"), re.IGNORECASE)

    def __init__(self):
        """Initialize the NOAA Incident App with database and chatbot components."""
        self.db = NOAAIncidentDB(persist_directory="noaa_db")
        # Timestamp string of the last refresh (written by refresh_database), or None.
        self.last_update = None
        self._load_last_update_time()

    def _load_last_update_time(self):
        """Load the last update time from metadata.json, if that file exists.

        Best-effort: any read or parse problem is printed and ``last_update``
        keeps its current value.
        """
        try:
            if os.path.exists("metadata.json"):
                with open("metadata.json", "r", encoding="utf-8") as f:
                    metadata = json.load(f)
                self.last_update = metadata.get("last_update")
        except Exception as e:
            print(f"Error loading metadata: {e}")

    def _save_last_update_time(self):
        """Persist ``last_update`` to metadata.json (best-effort; errors are printed)."""
        try:
            with open("metadata.json", "w", encoding="utf-8") as f:
                json.dump({"last_update": self.last_update}, f)
        except Exception as e:
            print(f"Error saving metadata: {e}")

    @staticmethod
    def _passes_filters(result, min_date, max_date, location_filter):
        """Return True if *result* satisfies every filter that was supplied."""
        if min_date and result['date'] < min_date:
            return False
        if max_date and result['date'] > max_date:
            return False
        if location_filter:
            # A missing (None) location can never match a location filter.
            location = result['location'] or ""
            if location_filter.lower() not in location.lower():
                return False
        return True

    def search_incidents(self, query, min_date=None, max_date=None, location_filter=None, num_results=5):
        """Search incidents with optional filters and render the matches as Markdown.

        Args:
            query: Free-text query passed to ``self.db.search``.
            min_date: Optional inclusive lower bound on ``result['date']``.
            max_date: Optional inclusive upper bound on ``result['date']``.
                NOTE(review): dates are compared with ``<``/``>`` directly, which
                assumes a lexically sortable format such as ISO 8601 — confirm.
            location_filter: Optional case-insensitive substring that
                ``result['location']`` must contain.
            num_results: Number of results requested from the database. Filters
                are applied afterwards, so fewer results may be returned.

        Returns:
            A Markdown string with one section per match, or
            "No matching incidents found." when nothing matches.
        """
        results = self.db.search(query, n_results=num_results)
        matches = [
            result for result in results
            if self._passes_filters(result, min_date, max_date, location_filter)
        ]
        if not matches:
            return "No matching incidents found."

        output = []
        for i, result in enumerate(matches, 1):
            output.extend([
                f"## Result {i}: {result['title']}",
                f"**Date:** {result['date']}",
                f"**Location:** {result['location']}",
                f"**Details:** {result['details']}",
                "---\n",
            ])
        return "\n".join(output)

    def respond(self, message, history, system_message, max_tokens, temperature, top_p):
        """Stream a chatbot reply, or NOAA search results for "search noaa" messages.

        This is a generator. Each yielded value is the full text to display so far.

        Args:
            message: The user's new message.
            history: Prior turns as ``[user_text, assistant_text]`` pairs;
                empty entries are skipped. ``None`` is treated as no history.
            system_message: System prompt for the language model.
            max_tokens: Maximum tokens to generate, forwarded to ``client.chat_completion``.
            temperature: Sampling temperature, forwarded to ``client.chat_completion``.
            top_p: Nucleus-sampling cutoff, forwarded to ``client.chat_completion``.
        """
        if self._SEARCH_TRIGGER.search(message):
            # Strip the trigger case-insensitively to match the detection above;
            # otherwise e.g. "Search NOAA" would leak into the query text.
            query = self._SEARCH_TRIGGER.sub("", message).strip()
            # This method is a generator, so `return value` would discard the
            # results. They must be yielded to reach the UI.
            yield self.search_incidents(query=query, num_results=5)
            return

        messages = [{"role": "system", "content": system_message}]
        for turn in history or []:
            if turn[0]:
                messages.append({"role": "user", "content": turn[0]})
            if turn[1]:
                messages.append({"role": "assistant", "content": turn[1]})
        messages.append({"role": "user", "content": message})

        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # delta.content may be None (e.g. role-only or final chunks), and a
            # chunk may carry no choices; neither should break the stream.
            token = chunk.choices[0].delta.content if chunk.choices else None
            if token:
                response += token
            yield response

    def _chat(self, message, history, system_message, max_tokens, temperature, top_p):
        """UI handler: stream ``respond`` to the Markdown view and record the turn.

        Yields ``(markdown, history)`` pairs. The last pair carries *history*
        extended with ``[message, reply]`` so later requests include this turn.
        """
        history = list(history or [])
        reply = ""
        for reply in self.respond(message, history, system_message, max_tokens, temperature, top_p):
            yield reply, history
        yield reply, history + [[message, reply]]

    def refresh_database(self, progress=gr.Progress()):
        """Scrape new incidents, load them into the database and record the refresh time.

        Args:
            progress: Gradio progress tracker, injected by Gradio through this default.

        Returns:
            A status message. Failures are reported in the message rather than raised.
        """
        try:
            progress(0, desc="Initializing scraper...")
            scraper = NOAAIncidentScraper(max_workers=5)
            progress(0.2, desc="Scraping new incidents...")
            csv_file, _ = scraper.run(validate_first=True)
            if not csv_file:
                return "Error: Failed to scrape new incidents."
            progress(0.6, desc="Loading new data into database...")
            num_loaded = self.db.load_incidents(csv_file)
            self.last_update = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            self._save_last_update_time()
            progress(1.0, desc="Complete!")
            return f"Successfully refreshed database with {num_loaded} incidents."
        except Exception as e:
            return f"Error refreshing database: {str(e)}"

    def create_interface(self):
        """Create the Gradio interface (chat column plus database-refresh column)."""
        with gr.Blocks(title="NOAA Incident & Chatbot App") as interface:
            gr.Markdown("# NOAA Incident & Chatbot Application")
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("### Chatbot Interaction")
                    system_message = gr.Textbox(
                        label="System Message", value="You are a friendly assistant."
                    )
                    # Per-session list of [user, assistant] turns, updated by _chat.
                    chat_history = gr.State([])
                    message = gr.Textbox(label="Message")
                    max_tokens = gr.Slider(1, 2048, 512, step=1, label="Max Tokens")
                    temperature = gr.Slider(0.1, 4.0, 0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, 0.95, step=0.05, label="Top-p")
                    chat_btn = gr.Button("Send")
                    chat_md = gr.Markdown()
                with gr.Column(scale=1):
                    refresh_btn = gr.Button("Refresh Database")
                    last_update_md = gr.Markdown(
                        f"*Last database update: {self.last_update or 'Never'}*"
                    )
            chat_btn.click(
                self._chat,
                inputs=[
                    message,
                    chat_history,
                    system_message,
                    max_tokens,
                    temperature,
                    top_p,
                ],
                outputs=[chat_md, chat_history],
            )
            refresh_btn.click(self.refresh_database, inputs=[], outputs=last_update_md)
        return interface
# Build the app and its UI at import time (exposing `app` and `demo` at module
# level); start the web server only when this file is executed directly.
app = NOAAIncidentApp()
demo = app.create_interface()

if __name__ == "__main__":
    demo.launch()