Spaces:
Paused
Paused
Commit
·
e106a6d
1
Parent(s):
00662e9
add quantization config
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ import csv
|
|
12 |
import json
|
13 |
import torch
|
14 |
from tqdm.auto import tqdm
|
|
|
15 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
16 |
|
17 |
|
@@ -33,6 +34,8 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
|
|
33 |
|
34 |
|
35 |
|
|
|
|
|
36 |
prompt_template = """
|
37 |
|
38 |
You are the chatbot and the face of Asian Institute of Technology (AIT). Your job is to give answers to prospective and current students about the school.
|
@@ -59,7 +62,10 @@ st.set_page_config(
|
|
59 |
page_title = 'aitGPT',
|
60 |
page_icon = '✅')
|
61 |
|
62 |
-
|
|
|
|
|
|
|
63 |
|
64 |
|
65 |
@st.cache_data
|
@@ -91,19 +97,21 @@ def load_faiss_index():
|
|
91 |
|
92 |
@st.cache_resource
|
93 |
def load_llm_model():
|
94 |
-
#
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
99 |
|
100 |
|
101 |
-
llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
|
102 |
-
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
return llm
|
108 |
|
109 |
|
|
|
12 |
import json
|
13 |
import torch
|
14 |
from tqdm.auto import tqdm
|
15 |
+
from transformers import BitsAndBytesConfig
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
|
18 |
|
|
|
34 |
|
35 |
|
36 |
|
37 |
+
|
38 |
+
|
39 |
prompt_template = """
|
40 |
|
41 |
You are the chatbot and the face of Asian Institute of Technology (AIT). Your job is to give answers to prospective and current students about the school.
|
|
|
62 |
page_title = 'aitGPT',
|
63 |
page_icon = '✅')
|
64 |
|
65 |
+
bitsandbyte_config = BitsAndBytesConfig(
|
66 |
+
load_in_4bit=True,
|
67 |
+
bnb_4bit_quant_type="nf4",
|
68 |
+
bnb_4bit_compute_dtype=torch.float16)
|
69 |
|
70 |
|
71 |
@st.cache_data
|
|
|
97 |
|
98 |
@st.cache_resource
|
99 |
def load_llm_model():
|
100 |
+
#this one is for running with GPT
|
101 |
+
llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
|
102 |
+
task= 'text2text-generation',
|
103 |
+
model_kwargs={ "device_map": "auto",
|
104 |
+
"max_length": 256, "temperature": 0,
|
105 |
+
"repetition_penalty": 1.5,
|
106 |
+
"quantization_config": bitsandbyte_config}) #add this quantization config
|
107 |
|
108 |
|
109 |
+
# llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
|
110 |
+
# task= 'text2text-generation',
|
111 |
|
112 |
+
# model_kwargs={ "max_length": 256, "temperature": 0,
|
113 |
+
# "torch_dtype":torch.float32,
|
114 |
+
# "repetition_penalty": 1.3})
|
115 |
return llm
|
116 |
|
117 |
|