ashutoshzade committed on
Commit
20fce2f
·
verified ·
1 Parent(s): 5be6d03

Create app.py

Browse files

Initial version

Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Minimal retrieval-augmented QA demo: Gemma 2B answering from a Wikipedia article.

The script fetches one Wikipedia article as plain text, splits it into chunks,
indexes the chunks in an in-memory Chroma vector store, and answers a sample
question with a LangChain ``RetrievalQA`` chain backed by a local Hugging Face
text-generation pipeline.
"""

import json
import urllib.parse
import urllib.request

from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Human-readable article URL; stored as the ``source`` metadata of every chunk.
SOURCE_URL = "https://en.wikipedia.org/wiki/Cheetah"
# Title of the article requested from the MediaWiki API.
ARTICLE_TITLE = "Cheetah"
WIKIPEDIA_API_URL = "https://en.wikipedia.org/w/api.php"
# Sent explicitly because urllib's default User-Agent may be rejected by
# Wikimedia servers.
USER_AGENT = "rag-demo-app/0.1 (LangChain RetrievalQA example)"


def _fetch_wikipedia_text(title: str, *, timeout: float = 30.0) -> str:
    """Return the plain-text body of the English Wikipedia article *title*.

    Args:
        title: Article title as it appears in the page URL.
        timeout: Socket timeout, in seconds, for the HTTP request.

    Raises:
        ValueError: If the API response contains no text for *title*.
        urllib.error.URLError: If the HTTP request fails.
    """
    params = urllib.parse.urlencode(
        {
            "action": "query",
            "format": "json",
            "formatversion": "2",
            "prop": "extracts",
            "explaintext": "1",
            "redirects": "1",
            "titles": title,
        }
    )
    request = urllib.request.Request(
        f"{WIKIPEDIA_API_URL}?{params}",
        headers={"User-Agent": USER_AGENT},
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        payload = json.load(response)

    pages = payload.get("query", {}).get("pages", [])
    extract = pages[0].get("extract", "") if pages else ""
    if not extract:
        raise ValueError(f"No article text returned for Wikipedia page {title!r}")
    return extract


# Load Gemma model and tokenizer
model_name = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Create a text generation pipeline
text_generation_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    # Sampling must be enabled for ``temperature`` to take effect; with the
    # default greedy decoding the value is ignored.
    do_sample=True,
    temperature=0.7,
)

# Create a LangChain LLM from the pipeline
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Load and process documents.
# TextLoader reads local files only, so the article is downloaded as plain
# text first and then wrapped into documents by the splitter.
article_text = _fetch_wikipedia_text(ARTICLE_TITLE)
text_splitter = CharacterTextSplitter(
    # NOTE(review): plain-text extracts appear to separate paragraphs with a
    # single newline; splitting on "\n" keeps chunks near chunk_size - confirm.
    separator="\n",
    chunk_size=1000,
    chunk_overlap=0,
)
texts = text_splitter.create_documents(
    [article_text],
    metadatas=[{"source": SOURCE_URL}],
)

# Create embeddings and vector store
embeddings = HuggingFaceEmbeddings()
db = Chroma.from_documents(texts, embeddings)

# Create a retriever
retriever = db.as_retriever()

# Create a prompt template
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Answer:"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])

# Create the RetrievalQA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt},
)

# Example query
query = "How fast cheetah can run?"
# ``invoke`` replaces the deprecated direct call ``qa_chain({...})``.
result = qa_chain.invoke({"query": query})

print("Question:", query)
print("Answer:", result["result"])