import gradio as gr
import os
import faiss
import torch
from huggingface_hub import InferenceClient, hf_hub_download
from sentence_transformers import SentenceTransformer
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
# Hugging Face Credentials
HF_REPO = "Futuresony/future_ai_12_10_2024.gguf" # Your model repo
HF_FAISS_REPO = "Futuresony/future_ai_12_10_2024.gguf" # Your FAISS repo
HF_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN') # API token from env
# Load FAISS Index
faiss_index_path = hf_hub_download(
    repo_id=HF_FAISS_REPO,
    filename="asa_faiss.index",
    repo_type="model",
    token=HF_TOKEN
)
faiss_index = faiss.read_index(faiss_index_path)
# Load Sentence Transformer for embedding queries
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
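
# Optional sanity check: retrieval assumes the index was built with this same
# embedding model (all-MiniLM-L6-v2, 384-dim vectors); if the dimensions
# disagree, faiss_index.search() will fail at query time, so warn early.
if faiss_index.d != embed_model.get_sentence_embedding_dimension():
    logging.warning(
        "FAISS index dimension (%d) does not match embedding dimension (%d)",
        faiss_index.d,
        embed_model.get_sentence_embedding_dimension(),
    )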
# Hugging Face Model Client
client = InferenceClient(
    model=HF_REPO,
    token=HF_TOKEN
)

# Function to retrieve relevant context from FAISS
def retrieve_context(query, top_k=3):
    """Retrieve relevant past knowledge using FAISS."""
    query_embedding = embed_model.encode([query], convert_to_tensor=True).cpu().numpy()
    distances, indices = faiss_index.search(query_embedding, top_k)
    # Map the returned IDs to placeholder text (FAISS itself only returns IDs)
    retrieved_context = "\n".join(
        f"Context {i+1}: Retrieved data for index {idx}"
        for i, idx in enumerate(indices[0])
    )
    return retrieved_context
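
# In a full setup the FAISS IDs returned above would be mapped back to stored
# text passages. A minimal sketch, assuming a hypothetical `corpus` list saved
# in the same order the index was built (not shipped with this repo):
#
#     corpus = json.load(open("corpus.json"))  # one passage per FAISS ID
#     retrieved_context = "\n".join(
#         f"Context {i+1}: {corpus[idx]}"
#         for i, idx in enumerate(indices[0])
#         if idx != -1  # FAISS pads missing neighbours with -1
#     )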
# Function to format input in Alpaca style
def format_alpaca_prompt(user_input, system_prompt, history):
    """Format the input in Alpaca/LLaMA instruction style."""
    retrieved_context = retrieve_context(user_input)  # Retrieve past knowledge
    history_str = "\n".join(f"### Instruction:\n{h[0]}\n### Response:\n{h[1]}" for h in history)
    prompt = f"""{system_prompt}
{history_str}
### Instruction:
{user_input}
### Retrieved Context:
{retrieved_context}
### Response:
"""
    return prompt

# Chatbot response function
def respond(message, history, system_message, max_tokens, temperature, top_p):
    formatted_prompt = format_alpaca_prompt(message, system_message, history)
    response = client.text_generation(
        formatted_prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # Extract only the response
    cleaned_response = response.split("### Response:")[-1].strip()
    history.append((message, cleaned_response))  # Update chat history
    yield cleaned_response  # Output only the answer
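
# respond() yields the whole answer in one step. If token-by-token streaming is
# wanted, InferenceClient.text_generation also accepts stream=True and then
# returns an iterator of text chunks; a rough sketch with the same parameters:
#
#     partial = ""
#     for chunk in client.text_generation(formatted_prompt, max_new_tokens=max_tokens,
#                                         temperature=temperature, top_p=top_p,
#                                         stream=True):
#         partial += chunk
#         yield partial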
# Gradio Chat Interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI.", label="System message"),
        gr.Slider(minimum=1, maximum=250, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.99, step=0.01, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()