# chat-with-SFF / app.py
# Last commit: "Update app.py" by ccm (7623dc6, verified)
import gradio # Interface handling
import spaces # GPU
import langchain_community.vectorstores # Vectorstore for publications
import langchain_huggingface # Embeddings
import transformers # LLM
# How many publication excerpts are pulled from the vectorstore for each prompt.
PUBLICATIONS_TO_RETRIEVE: int = 5

# Prompt used for retrieval-augmented generation. It is filled in with
# str.format using the `research_excerpts` and `query` fields; every line,
# including the last one, ends with a newline.
RAG_TEMPLATE: str = (
    "You are an AI assistant who enjoys helping users learn about research.\n"
    "Answer the USER_QUERY on additive manufacturing research using the RESEARCH_EXCERPTS.\n"
    "Provide a concise ANSWER based on these excerpts. Avoid listing references.\n"
    "===== RESEARCH_EXCERPTS =====\n"
    "{research_excerpts}\n"
    "===== USER_QUERY =====\n"
    "{query}\n"
    "===== ANSWER =====\n"
)
# Clickable example prompts shown in the chat interface. Each entry is a dict
# with a single "text" key holding the question.
EXAMPLE_QUERIES = [
    {"text": question}
    for question in (
        "What is multi-material 3D printing?",
        "How is additive manufacturing being applied in aerospace?",
        "Tell me about innovations in metal 3D printing techniques.",
        "What are some sustainable materials for 3D printing?",
        "What are the challenges with support structures in 3D printing?",
        "How is 3D printing impacting the medical field?",
        "What are common applications of additive manufacturing in industry?",
        "What are the benefits and limitations of using polymers in 3D printing?",
        "Tell me about the environmental impacts of additive manufacturing.",
        "What are the primary limitations of current 3D printing technologies?",
        "How are researchers improving the speed of 3D printing processes?",
        "What are best practices for post-processing in additive manufacturing?",
    )
]
# Sentence-embedding model used to embed queries for the SFF publication index.
# NOTE(review): it runs on "cuda" and leaves embeddings unnormalized. This
# presumably has to match how the stored index was built — confirm.
_publication_embeddings = langchain_huggingface.HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": "cuda"},
    encode_kwargs={"normalize_embeddings": False},
)

# Load the FAISS vectorstore of SFF publications from the local folder.
# NOTE(review): FAISS.load_local deserializes a pickle file, which is why it
# needs allow_dangerous_deserialization=True. That is only safe if
# "publication_vectorstore" is trusted data shipped with this app — confirm it
# can never be replaced by untrusted input.
publication_vectorstore = langchain_community.vectorstores.FAISS.load_local(
    folder_path="publication_vectorstore",
    embeddings=_publication_embeddings,
    allow_dangerous_deserialization=True,
)
# Hugging Face model id of the LLM (an instruction-tuned, AWQ-quantized Qwen2.5
# model, per its name). The model weights and the tokenizer both load from this
# one id so they cannot drift apart.
LLM_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct-AWQ"

# Create the callable LLM
model = transformers.AutoModelForCausalLM.from_pretrained(LLM_MODEL_ID)
model.to("cuda")  # Move the model to GPU
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_ID)
# Text-generation pipeline that pairs the model with its own tokenizer.
llm = transformers.pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device="cuda",
)
def preprocess(query: str) -> str:
    """
    Build the RAG prompt for a query from its most similar publication excerpts.

    Runs a similarity search that returns PUBLICATIONS_TO_RETRIEVE documents.
    Each document's text is wrapped as a quoted fragment, the fragments are
    separated by blank lines, and the result is inserted into RAG_TEMPLATE
    together with the query.

    Args:
        query (str): The user's query.

    Returns:
        str: The filled-in prompt containing the research excerpts and the query.
    """
    matches = publication_vectorstore.search(
        query, search_type="similarity", k=PUBLICATIONS_TO_RETRIEVE
    )
    # Quote each excerpt and surround it with ellipses to mark it as a fragment.
    excerpt_block = "\n\n".join(f'"... {match.page_content}..."' for match in matches)
    return RAG_TEMPLATE.format(research_excerpts=excerpt_block, query=query)
@spaces.GPU(duration=30)
def reply(message: str, history: list) -> str:
    """
    Generate a retrieval-augmented answer to the user's message.

    Args:
        message (str): The user's message or query.
        history (list): The conversation history passed in by
            gradio.ChatInterface. It is not used; each answer is generated only
            from the current message and the excerpts retrieved for it.

    Returns:
        str: The generated answer. Surrounding whitespace and any "=" characters
            (leftovers from the prompt's "=====" section markers) are removed.
    """
    # Build the RAG prompt from the most similar publication excerpts.
    rag_prompt = preprocess(message)
    # The LLM is an instruction-tuned chat model. Passing a list of chat messages
    # (not a bare string) makes the pipeline apply the tokenizer's chat template,
    # instead of treating the prompt as raw text to be continued.
    messages = [{"role": "user", "content": rag_prompt}]
    response = llm(messages, max_new_tokens=512, return_full_text=False)
    # With return_full_text=False, "generated_text" holds only the new text.
    # Newlines and tabs are stripped as well: the previous "= " set left them in.
    return response[0]["generated_text"].strip("= \t\r\n")
# Chat display: hides the label and the share/copy buttons and keeps message
# bubbles from spanning the full width.
chatbot_component = gradio.Chatbot(
    bubble_full_width=False,
    show_copy_button=False,
    show_share_button=False,
    show_label=False,
)

# Chat UI that answers through `reply` and offers EXAMPLE_QUERIES as clickable
# prompts. Examples are not pre-computed at startup (cache_examples=False).
demo = gradio.ChatInterface(
    reply,
    chatbot=chatbot_component,
    cache_examples=False,
    examples=EXAMPLE_QUERIES,
)

# Run the Gradio Interface
demo.launch(debug=True)