timep12345 committed on
Commit
a743f3e
·
1 Parent(s): c5d1b72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -4,10 +4,12 @@ import json
4
 
5
  from langchain.document_loaders import DataFrameLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- from langchain.llms import HuggingFaceHub
8
  from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
9
  from langchain.vectorstores import Chroma
10
  from langchain.chains import RetrievalQA
 
 
 
11
 
12
  from trafilatura import fetch_url, extract
13
  from trafilatura.spider import focused_crawler
@@ -50,8 +52,25 @@ def url_changes(url, pages_to_visit, urls_to_scrape, repo_id):
50
  persist_directory = './vector_db'
51
  db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
52
  retriever = db.as_retriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature":0.1, "max_new_tokens":250})
55
  global qa
56
  qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
57
  return "Ready"
 
4
 
5
  from langchain.document_loaders import DataFrameLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
7
  from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
8
  from langchain.vectorstores import Chroma
9
  from langchain.chains import RetrievalQA
10
+ from langchain import HuggingFacePipeline
11
+
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
13
 
14
  from trafilatura import fetch_url, extract
15
  from trafilatura.spider import focused_crawler
 
52
  persist_directory = './vector_db'
53
  db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
54
  retriever = db.as_retriever()
55
+
56
+ MODEL = 'beomi/KoAlpaca-Polyglot-5.8B'
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ MODEL,
59
+ torch_dtype=torch.float16,
60
+ low_cpu_mem_usage=True,
61
+ ).to(device=f"cuda", non_blocking=True)
62
+ model.eval()
63
+ pipe = pipeline(
64
+ 'text-generation',
65
+ model=model,
66
+ tokenizer=MODEL,
67
+ max_length=512,
68
+ temperature=0,
69
+ top_p=0.95,
70
+ repetition_penalty=1.15
71
+ )
72
+ llm = HuggingFacePipeline(pipeline=pipe)
73
 
 
74
  global qa
75
  qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
76
  return "Ready"