import gradio as gr
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextIteratorStreamer
from accelerate import infer_auto_device_map, dispatch_model
from threading import Thread
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import snapshot_download

# Read the Hugging Face access token from the environment (used for gated/private repos)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

MODEL_NAME_OR_PATH = 'StevenChen16/llama3-8b-Lawyer'

DESCRIPTION = '''
<div style="display: flex; align-items: center; justify-content: center; text-align: center;">
    <a href="https://wealthwizards.org/" target="_blank">
        <img src="./images/logo.png" alt="Wealth Wizards Logo" style="width: 60px; height: auto; margin-right: 10px;">
    </a>
    <div style="display: inline-block; text-align: left;">
        <h1 style="font-size: 36px; margin: 0;">AI Lawyer</h1>
        <a href="https://wealthwizards.org/" target="_blank" style="text-decoration: none; color: inherit;">
            <p style="font-size: 16px; margin: 0;">wealth wizards</p>
        </a>
    </div>
</div>
'''

LICENSE = """
<p/>

---

Built with model "StevenChen16/Llama3-8B-Lawyer", based on "meta-llama/Meta-Llama-3-8B"
"""

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">AI Lawyer</h1>
    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything about US and Canadian law...</p>
</div>
"""

css = """
h1 {
    text-align: center;
    display: block;
}

#duplicate-button {
    margin: auto;
    color: white;
    background: #1565c0;
    border-radius: 100vh;
}
"""

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, token=HF_TOKEN)

# Load the model with disk offloading
print("Loading the model with disk offloading...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    token=HF_TOKEN,
    trust_remote_code=True,
    low_cpu_mem_usage=True  # Optimize memory usage during loading
)

# Map the model across available GPU and CPU memory, offloading the remainder to disk
device_map = infer_auto_device_map(model, max_memory={"cpu": "50GB", "cuda:0": "16GB"})
model = dispatch_model(
    model,
    device_map=device_map,
    offload_folder="./offload"  # Folder for offloaded weights
)
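
# Weights that fit within the "cuda:0" budget stay on the GPU, the remainder
# sits in CPU RAM, and anything beyond the 50GB CPU budget is offloaded to
# ./offload and loaded back in as needed during the forward pass.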

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
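
# "<|eot_id|>" is Llama 3's end-of-turn token; including it alongside
# eos_token_id stops generation at the end of the assistant's turn instead of
# running on until max_new_tokens.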

# Embedding model and FAISS vector store
def create_embedding_model(model_name):
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'trust_remote_code': True}
    )

embedding_model = create_embedding_model('intfloat/multilingual-e5-large-instruct')

try:
    print("Downloading vector store from HuggingFace Hub...")
    repo_path = snapshot_download(
        repo_id="StevenChen16/laws.faiss",
        repo_type="model"
    )
    print("Loading vector store...")
    vector_store = FAISS.load_local(
        folder_path=repo_path,
        embeddings=embedding_model,
        allow_dangerous_deserialization=True
    )
    print("Vector store loaded successfully")
except Exception as e:
    raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}")
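
# Note: allow_dangerous_deserialization=True is required because the FAISS
# docstore is pickled on disk; only enable it for repositories you trust.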

background_prompt = '''
As an AI legal assistant, you are a highly trained expert in U.S. and Canadian law. Your purpose is to provide accurate, comprehensive, and professional legal information...
[Shortened for brevity]
'''

def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
    """
    Return the concatenated contents of documents similar to the query,
    keeping at most k documents scoring above relevance_threshold.
    """
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": relevance_threshold, "k": k}
    )
    similar_docs = retriever.invoke(query)
    context = [doc.page_content for doc in similar_docs]
    return " ".join(context) if context else ""
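
# Usage sketch (hypothetical query): returns the concatenated page contents of
# up to k matching documents, or "" when nothing clears the threshold:
#   context = query_vector_store(vector_store, "What is duty of care?")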

def chat_llama3_8b(message: str, history: list, temperature=0.6, max_new_tokens=4096):
    """
    Stream a response from the model, grounding it in retrieved legal
    references when the vector store returns relevant passages.
    """
    citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
    conversation = []
    for user, assistant in history:
        conversation.extend([
            {"role": "user", "content": str(user)},
            {"role": "assistant", "content": str(assistant)}
        ])
    if citation:
        final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
    else:
        final_message = f"{background_prompt}\n{message}"
    conversation.append({"role": "user", "content": final_message})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,  # End the prompt with the assistant header so the model answers rather than continuing the user turn
        return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generation_config = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "temperature": temperature,
        "eos_token_id": terminators
    }
    thread = Thread(target=model.generate, kwargs=generation_config)
    thread.start()

    accumulated_text = []
    for text_chunk in streamer:
        accumulated_text.append(text_chunk)
        yield "".join(accumulated_text)
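
# Yielding the accumulated text rather than each chunk matters here:
# gr.ChatInterface replaces the displayed message with every yielded value,
# so each yield must contain the full response so far.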

# Gradio interface
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        examples=[
            ['What are the key differences between a sole proprietorship and a partnership?'],
            ['What legal steps should I take if I want to start a business in the US?'],
            ['Can you explain the concept of "duty of care" in negligence law?'],
            ['What are the legal requirements for obtaining a patent in Canada?'],
            ['How can I protect my intellectual property when sharing my idea with potential investors?']
        ],
        cache_examples=False,
    )
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch()