ubermenchh committed on
Commit
257b0ba
·
1 Parent(s): 5b56653

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py -- retrieval-augmented QA demo.
#
# Pipeline, in execution order:
#   1. Build a sentence-transformers embedding model.
#   2. Connect to Pinecone and create the index if it does not exist yet.
#   3. Embed the chunked arXiv dataset and upsert it into the index.
#   4. Load a 4-bit quantized Llama-2 chat model as a text-generation pipeline.
#   5. Wire retriever + LLM into a LangChain RetrievalQA chain.
#   6. Serve the chain through a Gradio chat interface.

import os
import time

import gradio as gr
import pinecone
import transformers
from datasets import load_dataset
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import Pinecone
from torch import bfloat16

# --- 1. Embedding model ------------------------------------------------------

embed_model_id = 'sentence-transformers/all-MiniLM-L6-v2'
device = 'cpu'

embed_model = HuggingFaceEmbeddings(
    model_name=embed_model_id,
    model_kwargs={'device': device},
    encode_kwargs={'device': device, 'batch_size': 32},
)

# --- 2. Pinecone index -------------------------------------------------------

# The original code read the misspelled variable 'PINCONE_API_KEY'. Prefer the
# correctly spelled name, but keep the old spelling as a fallback so that
# deployments which configured the secret under the typo keep working.
api_key = (
    os.environ.get('PINECONE_API_KEY')
    or os.environ.get('PINCONE_API_KEY')
)
env_name = os.environ.get('PINECONE_ENV')

pinecone.init(
    api_key=api_key,
    environment=env_name,
)

# NOTE(review): this name contains a '.'; confirm that Pinecone accepts it
# (index names are commonly restricted to lowercase alphanumerics and '-').
index_name = 'llama-2-rag-2.0'

if index_name not in pinecone.list_indexes():
    # The index dimension must equal the embedding size. Measure it from the
    # model itself instead of relying on a variable that was never defined.
    embedding_dim = len(embed_model.embed_query('dimension probe'))
    pinecone.create_index(
        index_name,
        dimension=embedding_dim,
        metric='cosine',
    )
    # Block until the freshly created index reports that it is ready.
    while not pinecone.describe_index(index_name).status['ready']:
        time.sleep(1)

index = pinecone.Index(index_name)

# --- 3. Dataset ingestion ----------------------------------------------------

data = load_dataset('jamescalam/llama-2-arxiv-papers-chunked', split='train')
data = data.to_pandas()
batch_size = 32

# Every start-up re-embeds and upserts the full dataset; ids are derived from
# the row contents, so repeated runs write to the same vector ids.
for start in range(0, len(data), batch_size):
    stop = min(len(data), start + batch_size)
    # One pass over the batch instead of three separate iterrows() loops.
    rows = data.iloc[start:stop].to_dict('records')
    ids = [f"{row['doi']}-{row['chunk-id']}" for row in rows]
    texts = [row['chunk'] for row in rows]
    embeds = embed_model.embed_documents(texts)
    metadata = [
        {
            'text': row['chunk'],
            'source': row['source'],
            'title': row['title'],
        }
        for row in rows
    ]
    index.upsert(vectors=list(zip(ids, embeds, metadata)))

# --- 4. Language model -------------------------------------------------------

model_id = 'meta-llama/Llama-2-7b-chat-hf'

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16,
)
model_config = transformers.AutoConfig.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
)
model.eval()

# Fixed: the module is imported as `transformers`, not `transformer`.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    task='text-generation',
    temperature=0.3,
    max_new_tokens=512,
    repetition_penalty=1.1,
)

# --- 5. Retrieval-augmented QA chain -----------------------------------------

llm = HuggingFacePipeline(pipeline=generate_text)
text_field = 'text'  # metadata key that holds the chunk text (see upsert above)
vectorstore = Pinecone(index, embed_model.embed_query, text_field)
rag_pipeline = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=vectorstore.as_retriever(),
)

# --- 6. Gradio UI ------------------------------------------------------------

title = 'arxiv-retrieval'


def predict(input, history=None):
    """Answer one question and extend the chat history.

    The interface below declares two inputs ('text', 'state') and two outputs
    ('chatbot', 'state'), so this callback must accept the stored history and
    return a value for each output.

    Args:
        input: The question typed by the user. (The name shadows the builtin
            but is kept so existing keyword callers do not break.)
        history: List of (question, answer) pairs from earlier turns, or None
            on the first turn.

    Returns:
        A 2-tuple ``(history, history)``: the first element is rendered by the
        chatbot component, the second is stored as the new session state.
    """
    turns = list(history) if history else []
    # NOTE(review): assumes the chain returns a dict with a 'result' key;
    # confirm against the installed LangChain version.
    answer = rag_pipeline(input)['result']
    turns.append((input, answer))
    return turns, turns


gr.Interface(
    fn=predict,
    inputs=['text', 'state'],
    outputs=['chatbot', 'state'],
    title=title,
).launch()