import gradio  # Interface handling
import spaces  # Hugging Face Spaces GPU support
import langchain_community.vectorstores  # Vectorstore for publications
import langchain_huggingface  # Embeddings
import transformers  # LLM

# The number of publications to retrieve for the prompt
PUBLICATIONS_TO_RETRIEVE = 5

# The template for the RAG prompt
RAG_TEMPLATE = """You are an AI assistant who enjoys helping users learn about research.
Answer the USER_QUERY on additive manufacturing research using the RESEARCH_EXCERPTS.
Provide a concise ANSWER based on these excerpts. Avoid listing references.
===== RESEARCH_EXCERPTS =====
{research_excerpts}
===== USER_QUERY =====
{query}
===== ANSWER =====
"""

# Example Queries for Interface
EXAMPLE_QUERIES = [
    {"text": "What is multi-material 3D printing?"},
    {"text": "How is additive manufacturing being applied in aerospace?"},
    {"text": "Tell me about innovations in metal 3D printing techniques."},
    {"text": "What are some sustainable materials for 3D printing?"},
    {"text": "What are the challenges with support structures in 3D printing?"},
    {"text": "How is 3D printing impacting the medical field?"},
    {"text": "What are common applications of additive manufacturing in industry?"},
    {"text": "What are the benefits and limitations of using polymers in 3D printing?"},
    {"text": "Tell me about the environmental impacts of additive manufacturing."},
    {"text": "What are the primary limitations of current 3D printing technologies?"},
    {"text": "How are researchers improving the speed of 3D printing processes?"},
    {"text": "What are best practices for post-processing in additive manufacturing?"},
]

# Load vectorstore of SFF publications
publication_vectorstore = langchain_community.vectorstores.FAISS.load_local(
    folder_path="publication_vectorstore",
    embeddings=langchain_huggingface.HuggingFaceEmbeddings(
        model_name="all-MiniLM-L12-v2",
        model_kwargs={"device": "cuda"},
        encode_kwargs={"normalize_embeddings": False},
    ),
    allow_dangerous_deserialization=True,
)
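
# Quick retrieval sanity check (a minimal sketch; assumes the vectorstore above
# loaded correctly). similarity_search is the standard LangChain FAISS method;
# the sample query is illustrative. Uncomment to spot-check retrieval:
# for doc in publication_vectorstore.similarity_search("metal 3D printing", k=1):
#     print(doc.page_content[:200])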

# Create the callable LLM
model = transformers.AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct-AWQ"
)
model.to("cuda")  # Move the model to the GPU
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-AWQ")
llm = transformers.pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device="cuda",
)
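
# Optional smoke test (kept commented out: on ZeroGPU, GPU code must run inside
# a @spaces.GPU-decorated function, not at import time). Uncomment for local runs:
# print(llm("Hello", max_new_tokens=8, return_full_text=False)[0]["generated_text"])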


def preprocess(query: str) -> str:
    """
    Generates a prompt based on the top k documents matching the query.

    Args:
        query (str): The user's query.

    Returns:
        str: The formatted prompt containing research excerpts and the user's query.
    """
    # Search for the top k documents matching the query
    documents = publication_vectorstore.search(
        query, k=PUBLICATIONS_TO_RETRIEVE, search_type="similarity"
    )

    # Extract the page content from the documents
    research_excerpts = [f'"... {doc.page_content}..."' for doc in documents]

    # Format the prompt with the research excerpts and the user's query
    prompt = RAG_TEMPLATE.format(
        research_excerpts="\n\n".join(research_excerpts), query=query
    )

    return prompt
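
# Example (illustrative query): preprocess("What is binder jetting?") returns
# RAG_TEMPLATE filled with the PUBLICATIONS_TO_RETRIEVE most similar excerpts
# and the query, ready to pass to the llm pipeline.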


@spaces.GPU  # Attach a GPU for the duration of each call (Hugging Face ZeroGPU)
def reply(message: str, history: list[str]) -> str:
    """
    Generates a response to the user's message.

    Args:
        message (str): The user's message or query.
        history (list[str]): The conversation history (managed by Gradio; unused here).

    Returns:
        str: The generated response from the language model.
    """
    # Preprocess the user's message into a retrieval-augmented prompt
    rag_prompt = preprocess(message)

    # Generate a response from the language model
    response = llm(rag_prompt, max_new_tokens=512, return_full_text=False)

    # Return the generated text, trimming any stray "=" delimiters echoed
    # from the prompt template
    return response[0]["generated_text"].strip("= ")
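
# Local spot-check (assumes a GPU is attached; on ZeroGPU this only works
# inside a @spaces.GPU-decorated call like reply itself):
# print(reply("What is multi-material 3D printing?", history=[]))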


# Run the Gradio Interface
gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    cache_examples=False,
    chatbot=gradio.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=False,
        bubble_full_width=False,
    ),
).launch(debug=True)