# Hugging Face Spaces status: Runtime error
import json
import os
import re
from datetime import datetime

import gradio as gr
from huggingface_hub import InferenceClient

from noaa_incidents import NOAAIncidentDB, NOAAIncidentScraper
# Hugging Face Inference API client; NOAAIncidentApp.respond streams chat replies from it.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
class NOAAIncidentApp:
    """Gradio app combining a NOAA incident search with a streaming chatbot.

    Chat messages containing ``SEARCH_TRIGGER`` (any case) are answered from
    the local incident database; every other message is answered by the
    module-level Hugging Face ``client``.
    """

    # File that stores the timestamp of the last successful database refresh.
    METADATA_FILE = "metadata.json"
    # Phrase that routes a chat message to the incident search instead of the LLM.
    SEARCH_TRIGGER = "search noaa"

    def __init__(self):
        """Initialize the NOAA Incident App with database and chatbot components."""
        self.db = NOAAIncidentDB(persist_directory="noaa_db")
        self.last_update = None
        self._load_last_update_time()

    def _load_last_update_time(self):
        """Load ``self.last_update`` from the metadata file, if it exists.

        Any error (unreadable file, invalid JSON, unexpected structure) is
        printed and otherwise ignored, leaving ``self.last_update`` unchanged.
        """
        try:
            if os.path.exists(self.METADATA_FILE):
                with open(self.METADATA_FILE, "r", encoding="utf-8") as f:
                    metadata = json.load(f)
                self.last_update = metadata.get("last_update")
        except Exception as e:
            print(f"Error loading metadata: {e}")

    def _save_last_update_time(self):
        """Write ``self.last_update`` to the metadata file; errors are printed, not raised."""
        try:
            with open(self.METADATA_FILE, "w", encoding="utf-8") as f:
                json.dump({"last_update": self.last_update}, f)
        except Exception as e:
            print(f"Error saving metadata: {e}")

    def _last_update_text(self):
        """Return the Markdown line that shows when the database was last refreshed."""
        return f"*Last database update: {self.last_update or 'Never'}*"

    @staticmethod
    def _append_to_history(message, history, response):
        """Return a new history list with the ``[message, response]`` turn appended."""
        return list(history or []) + [[message, response]]

    def search_incidents(self, query, min_date=None, max_date=None, location_filter=None, num_results=5):
        """Search the incident database and format the matches as Markdown.

        Args:
            query: Free-text query passed to ``self.db.search``.
            min_date: If truthy, drop results whose ``'date'`` compares below it.
            max_date: If truthy, drop results whose ``'date'`` compares above it.
            location_filter: If truthy, keep only results whose ``'location'``
                contains it (case-insensitive substring match).
            num_results: Number of results requested from the database. The
                filters run afterwards, so fewer results may be returned.

        Returns:
            A Markdown string with one section per match, or
            ``"No matching incidents found."``.
        """
        # NOTE(review): dates are compared with < / >, so min_date/max_date must
        # have the same type and format as the stored 'date' values; confirm.
        needle = location_filter.lower() if location_filter else None

        matches = []
        for result in self.db.search(query, n_results=num_results):
            if min_date and result['date'] < min_date:
                continue
            if max_date and result['date'] > max_date:
                continue
            if needle and needle not in result['location'].lower():
                continue
            matches.append(result)

        if not matches:
            return "No matching incidents found."

        lines = []
        for i, result in enumerate(matches, 1):
            lines.extend([
                f"## Result {i}: {result['title']}",
                f"**Date:** {result['date']}",
                f"**Location:** {result['location']}",
                f"**Details:** {result['details']}",
                "---\n",
            ])
        return "\n".join(lines)

    def respond(self, message, history, system_message, max_tokens, temperature, top_p):
        """Stream a chatbot reply to ``message``, or run a NOAA search if requested.

        This is a generator. Each yielded value is the complete reply text so
        far, which Gradio renders in place.

        Args:
            message: The user's message. If it contains ``SEARCH_TRIGGER`` (any
                case), that phrase is removed and the rest is the search query.
            history: Earlier turns as ``[user, assistant]`` pairs. Empty parts
                of a pair are skipped.
            system_message: System prompt sent first to the model.
            max_tokens: Upper bound on generated tokens (cast to ``int``).
            temperature: Sampling temperature forwarded to the model.
            top_p: Nucleus-sampling parameter forwarded to the model.
        """
        if self.SEARCH_TRIGGER in message.lower():
            # Detection is case-insensitive, so removal must be too; a plain
            # str.replace would leave e.g. "Search NOAA" inside the query.
            query = re.sub(
                re.escape(self.SEARCH_TRIGGER), "", message, flags=re.IGNORECASE
            ).strip()
            # This method is a generator: the result has to be yielded. A bare
            # `return value` would end the stream without displaying anything.
            yield self.search_incidents(query=query, num_results=5)
            return

        messages = [{"role": "system", "content": system_message}]
        for turn in history or []:
            if turn[0]:
                messages.append({"role": "user", "content": turn[0]})
            if turn[1]:
                messages.append({"role": "assistant", "content": turn[1]})
        messages.append({"role": "user", "content": message})

        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=int(max_tokens),
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Stream chunks may have no choices or a None delta; skip those
            # instead of failing on indexing or `str + None`.
            if not chunk.choices:
                continue
            response += chunk.choices[0].delta.content or ""
            yield response

    def refresh_database(self, progress=gr.Progress()):
        """Scrape new NOAA incidents and load them into the database.

        Args:
            progress: Gradio progress tracker. Gradio supplies one when the
                default is ``gr.Progress()``.

        Returns:
            A Markdown status message for the "last update" panel. Exceptions
            are caught and reported in the message rather than raised.
        """
        try:
            progress(0, desc="Initializing scraper...")
            scraper = NOAAIncidentScraper(max_workers=5)
            progress(0.2, desc="Scraping new incidents...")
            csv_file, _ = scraper.run(validate_first=True)
            if not csv_file:
                return "Error: Failed to scrape new incidents."
            progress(0.6, desc="Loading new data into database...")
            num_loaded = self.db.load_incidents(csv_file)
            self.last_update = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            self._save_last_update_time()
            progress(1.0, desc="Complete!")
            # This text replaces the panel that shows the last-update time, so
            # include that time instead of dropping it.
            return (
                f"Successfully refreshed database with {num_loaded} incidents.\n\n"
                f"{self._last_update_text()}"
            )
        except Exception as e:
            return f"Error refreshing database: {e}"

    def create_interface(self):
        """Build and return the Gradio Blocks UI (the caller launches it)."""
        with gr.Blocks(title="NOAA Incident & Chatbot App") as interface:
            gr.Markdown("# NOAA Incident & Chatbot Application")
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("### Chatbot Interaction")
                    system_message = gr.Textbox(
                        label="System Message", value="You are a friendly assistant."
                    )
                    # List of [user, assistant] pairs, extended after each reply below.
                    chat_history = gr.State([])
                    message = gr.Textbox(label="Message")
                    max_tokens = gr.Slider(1, 2048, 512, step=1, label="Max Tokens")
                    temperature = gr.Slider(0.1, 4.0, 0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, 0.95, step=0.05, label="Top-p")
                    chat_btn = gr.Button("Send")
                    chat_md = gr.Markdown()
                with gr.Column(scale=1):
                    refresh_btn = gr.Button("Refresh Database")
                    last_update_md = gr.Markdown(self._last_update_text())

            # Event listeners are attached inside the Blocks context.
            chat_btn.click(
                self.respond,
                inputs=[
                    message,
                    chat_history,
                    system_message,
                    max_tokens,
                    temperature,
                    top_p,
                ],
                outputs=chat_md,
            ).then(
                # Store the finished exchange so later turns send it as context;
                # otherwise chat_history stays empty forever.
                self._append_to_history,
                inputs=[message, chat_history, chat_md],
                outputs=chat_history,
            )
            refresh_btn.click(self.refresh_database, inputs=[], outputs=last_update_md)
        return interface
# Run the app: build the app and its interface at import time, and launch
# the server only when this file is executed directly.
app = NOAAIncidentApp()
demo = app.create_interface()

if __name__ == "__main__":
    demo.launch()