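# FastAPI service that answers questions about an uploaded PDF using a
# LangChain RAG pipeline: the PDF is chunked, embedded, indexed in Chroma,
# and queried with a history-aware retrieval chain.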
import os

from fastapi import FastAPI, Request, UploadFile, File
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_chroma import Chroma
from sentence_transformers import SentenceTransformer

# Point the Hugging Face / transformers caches at writable directories.
os.environ['HF_HOME'] = '/hug/cache/'
os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

app = FastAPI()
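# predict() wires up the RAG chain for a single request: it rewrites the
# question into a standalone form when chat history is present, retrieves the
# top-k chunks from the vector store, asks the LLM for an answer, and appends
# the retrieved passages with their source/page citations to the response.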
def predict(message, db):
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    template = """You are a general purpose chatbot. Be friendly and kind. Help people answer their questions. Use the context below to answer the questions.
{context}
Question: {question}
Helpful Answer:"""
    QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True
    )
    # Retrieve the top 3 chunks for each query.
    retriever = db.as_retriever(search_kwargs={"k": 3})
    contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )
    contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()

    def contextualized_question(input: dict):
        # Only rewrite the question when there is prior chat history to resolve.
        if input.get("chat_history"):
            return contextualize_q_chain
        else:
            return input["question"]

    rag_chain = (
        RunnablePassthrough.assign(
            context=contextualized_question | retriever
        )
        | QA_CHAIN_PROMPT
        | llm
    )
    history = []
    ai_msg = rag_chain.invoke({"question": message, "chat_history": history})
    print(ai_msg)
    bot_response = ai_msg.content.strip()
    # History is kept as a list of (user_message, bot_response) message pairs.
    history.append((HumanMessage(content=message), AIMessage(content=bot_response)))
    # Collect the retrieved passages and their source/page citations.
    docs = db.similarity_search(message, k=3)
    extra = "\n" + "*" * 100 + "\n"
    for d in docs:
        citation = d.metadata["source"] + " pg." + str(d.metadata["page"])
        extra += citation + "\n" + d.page_content + "\n" + "*" * 100 + "\n"
    # Return the bot's response followed by the cited passages.
    return bot_response + extra
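# upload_file() indexes a PDF: load its pages with PyPDFLoader, split them
# into 512-character chunks, embed them with thenlper/gte-large, and store
# the vectors in a Chroma collection that predict() can query.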
def upload_file(file_path):
    loaders = []
    print(file_path)
    loaders.append(PyPDFLoader(file_path))
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    # Split the PDF pages into overlapping chunks for retrieval.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=16)
    docs = text_splitter.split_documents(documents)
    model = "thenlper/gte-large"
    embedding_function = SentenceTransformerEmbeddings(model_name=model)
    print(f"Model's maximum sequence length: {SentenceTransformer(model).max_seq_length}")
    collection_name = "Autism"
    persist_directory = "./chroma"
    print(len(docs))
    # Build the Chroma index; collection_name and persist_directory control
    # where the vectors are stored.
    db = Chroma.from_documents(
        docs,
        embedding_function,
        collection_name=collection_name,
        persist_directory=persist_directory,
    )
    print("Done Processing, you can query")
    return db
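# Simple health-check / version endpoint.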
@app.get("/")
async def root():
    return {"Entvin": "Version 1.0 'First Draft'"}
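# Main Q&A endpoint: saves the uploaded PDF to disk, builds the vector store,
# and runs the RAG chain over the user's question.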
@app.post("/predict")  # NOTE: the route path is an assumption; no decorator appears in the original
def predicts(question: str, file: UploadFile = File(...)):
    contents = file.file.read()
    # Write the uploaded PDF to disk so PyPDFLoader can read it by path.
    with open(file.filename, 'wb') as f:
        f.write(contents)
    db = upload_file(file.filename)
    result = predict(question, db)
    return {"answer": result}