# Captured from a Hugging Face Space file view (commit d6d9787, file size 6,285 bytes).
# The Space status shown at capture time was "Runtime error".
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
from langchain_community.vectorstores.faiss import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextIteratorStreamer
# Optional Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub repository id of the chat model; also used to load the tokenizer.
MODEL_NAME_OR_PATH = "StevenChen16/llama3-8b-Lawyer"

# HTML header rendered at the top of the page (logo, title, organisation link).
DESCRIPTION = """
<div style="display: flex; align-items: center; justify-content: center; text-align: center;">
<a href="https://wealthwizards.org/" target="_blank">
<img src="./images/logo.png" alt="Wealth Wizards Logo" style="width: 60px; height: auto; margin-right: 10px;">
</a>
<div style="display: inline-block; text-align: left;">
<h1 style="font-size: 36px; margin: 0;">AI Lawyer</h1>
<a href="https://wealthwizards.org/" target="_blank" style="text-decoration: none; color: inherit;">
<p style="font-size: 16px; margin: 0;">wealth wizards</p>
</a>
</div>
</div>
"""

# Footer markdown crediting the fine-tuned model and its base model.
LICENSE = """
<p/>
---
Built with model "StevenChen16/Llama3-8B-Lawyer", based on "meta-llama/Meta-Llama-3-8B"
"""

# HTML shown inside the chat window while the conversation is still empty.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">AI Lawyer</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything about US and Canada law...</p>
</div>
"""

# Page-level CSS handed to gr.Blocks.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)

# Load the model with disk offloading.
# The weights are placed while loading (device_map="auto") instead of loading
# the whole model first and dispatching afterwards: the previous code called
# infer_auto_device_map/dispatch_model, which were never imported, so the
# script failed with NameError before the UI could start.
print("Loading the model with disk offloading...")

# Per-device memory budgets. Accelerate identifies GPUs by integer index
# (0, 1, ...), not by "cuda:0"; the GPU budget is only declared when a CUDA
# device is actually present so that CPU-only hosts fall back to CPU + disk.
max_memory = {"cpu": "50GB"}
if torch.cuda.is_available():
    max_memory[0] = "16GB"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    trust_remote_code=True,
    low_cpu_mem_usage=True,  # Optimize memory usage during loading
    device_map="auto",
    max_memory=max_memory,
    offload_folder="./offload",  # Folder for weights that do not fit in memory
)

# Keep the module-level name the original script defined; it now holds the
# placement that was actually applied during loading.
device_map = model.hf_device_map

# Token ids that end generation: the tokenizer's EOS plus the "<|eot_id|>"
# end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Embedding model and FAISS vector store
def create_embedding_model(model_name):
    """Build a HuggingFaceEmbeddings wrapper for the given model name.

    Args:
        model_name: Name of the embedding model, forwarded as ``model_name``.

    Returns:
        A ``HuggingFaceEmbeddings`` instance whose underlying model is created
        with ``trust_remote_code`` enabled.
    """
    loader_options = {'trust_remote_code': True}
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=loader_options)
# Embedding model used both to load the index and to embed incoming queries.
embedding_model = create_embedding_model('intfloat/multilingual-e5-large-instruct')

# Fetch the prebuilt FAISS index from the Hub and load it. Any failure is
# fatal for the app, so it is re-raised as RuntimeError with the original
# exception attached as the cause.
try:
    print("Downloading vector store from HuggingFace Hub...")
    repo_path = snapshot_download(
        repo_id="StevenChen16/laws.faiss",
        repo_type="model",
    )
    print("Loading vector store...")
    # SECURITY NOTE(review): allow_dangerous_deserialization=True lets the
    # loader unpickle data from the downloaded repository, which can execute
    # arbitrary code. Only acceptable while the "StevenChen16/laws.faiss"
    # repository is fully trusted; consider pinning a revision.
    vector_store = FAISS.load_local(
        folder_path=repo_path,
        embeddings=embedding_model,
        allow_dangerous_deserialization=True,
    )
    print("Vector store loaded successfully")
except Exception as e:
    # "from e" keeps the underlying download/load error as the explicit cause.
    raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}") from e
# System-style instructions prepended to every user question before it is
# sent to the model.
background_prompt = (
    "\n"
    "As an AI legal assistant, you are a highly trained expert in U.S. and Canadian law. "
    "Your purpose is to provide accurate, comprehensive, and professional legal information...\n"
    "[Shortened for brevity]\n"
)
def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
    """Return the text of the documents most similar to ``query``.

    Args:
        vector_store: Store to search; only its ``as_retriever`` method is used.
        query: Text to look up.
        k: Maximum number of documents requested from the retriever.
        relevance_threshold: Value passed to the retriever as ``score_threshold``.

    Returns:
        The ``page_content`` of the retrieved documents joined by single
        spaces, or an empty string when nothing is retrieved.
    """
    search_options = {"score_threshold": relevance_threshold, "k": k}
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs=search_options,
    )
    # Joining an empty sequence already yields "", so no special case is needed.
    return " ".join(doc.page_content for doc in retriever.invoke(query))
def _history_to_messages(history):
    """Convert Gradio chat history into a list of role/content message dicts.

    Accepts both history layouts Gradio can supply: a list of
    ``(user, assistant)`` pairs, or a list of ``{"role": ..., "content": ...}``
    dicts. Entries whose content is ``None`` are skipped instead of being
    turned into the literal text "None".
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role and content is not None:
                messages.append({"role": role, "content": str(content)})
            continue
        user, assistant = turn
        if user is not None:
            messages.append({"role": "user", "content": str(user)})
        if assistant is not None:
            messages.append({"role": "assistant", "content": str(assistant)})
    return messages


def chat_llama3_8b(message: str, history: list, temperature=0.6, max_new_tokens=4096) -> Iterator[str]:
    """Stream a model response to ``message``, yielding the text generated so far.

    The question is first looked up in the vector store; any retrieved text is
    inserted into the prompt as references. Generation runs in a background
    thread and the partial answer is yielded each time a new chunk arrives.

    Args:
        message: The user's latest question.
        history: Earlier turns, either ``(user, assistant)`` pairs or
            role/content dicts.
        temperature: Sampling temperature; a value of 0 or less switches to
            greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)

    if citation:
        final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
    else:
        final_message = f"{background_prompt}\n{message}"

    conversation = _history_to_messages(history)
    conversation.append({"role": "user", "content": final_message})

    # add_generation_prompt=True appends the assistant-turn header so the
    # model answers the user instead of continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # NOTE(review): the 10 s queue timeout makes the loop below raise if no
    # chunk arrives in time; with disk-offloaded weights this may be too short
    # -- confirm on the target hardware.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    generation_config = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "eos_token_id": terminators,
    }
    if do_sample:
        # Temperature only applies when sampling; passing it for greedy
        # decoding is at best ignored with a warning.
        generation_config["temperature"] = temperature

    thread = Thread(target=model.generate, kwargs=generation_config)
    thread.start()

    response = ""
    for text_chunk in streamer:
        response += text_chunk
        yield response

    # The streamer is exhausted once generation has finished; reap the worker.
    thread.join()
# Gradio interface
# Example questions offered as one-click prompts in the chat UI.
_EXAMPLE_QUESTIONS = [
    'What are the key differences between a sole proprietorship and a partnership?',
    'What legal steps should I take if I want to start a business in the US?',
    'Can you explain the concept of "duty of care" in negligence law?',
    'What are the legal requirements for obtaining a patent in Canada?',
    'How can I protect my intellectual property when sharing my idea with potential investors?',
]

# The chat widget is created up front and handed to ChatInterface below.
chatbot = gr.Chatbot(
    height=600,
    placeholder=PLACEHOLDER,
    label='Gradio ChatInterface',
)

# Page layout: header, chat interface, then the model credit footer.
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        # Each example is a single-element list, one entry per input component.
        examples=[[question] for question in _EXAMPLE_QUESTIONS],
        cache_examples=False,
    )
    gr.Markdown(LICENSE)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()
|