amiraaaa123 commited on
Commit
42b72fb
·
1 Parent(s): a74068a

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +154 -0
main.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from langchain.llms import HuggingFaceHub
4
+ from langchain.chains import LLMChain
5
+ from langchain.prompts import PromptTemplate
6
+
7
+ class UserInterface():
8
+
9
+ def __init__(self, ):
10
+ st.warning("Warning: Some models may not work and some models may require GPU to run")
11
+ st.text("An Open Source Chat Application")
12
+ st.header("Open LLMs")
13
+
14
+ # self.API_KEY = st.sidebar.text_input(
15
+ # 'API Key',
16
+ # type='password',
17
+ # help="Type in your HuggingFace API key to use this app"
18
+ # )
19
+
20
+ models_name = (
21
+ "HuggingFaceH4/zephyr-7b-beta",
22
+ "Sharathhebbar24/chat_gpt2",
23
+ "Sharathhebbar24/convo_bot_gpt_v1",
24
+ "Open-Orca/Mistral-7B-OpenOrca",
25
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
26
+ "Sharathhebbar24/llama_7b_chat",
27
+ "CultriX/MistralTrix-v1",
28
+ "ahxt/LiteLlama-460M-1T",
29
+ )
30
+ self.models = st.sidebar.selectbox(
31
+ label="Choose your models",
32
+ options=models_name,
33
+ help="Choose your model",
34
+ )
35
+
36
+ self.temperature = st.sidebar.slider(
37
+ label='Temperature',
38
+ min_value=0.1,
39
+ max_value=1.0,
40
+ step=0.1,
41
+ value=0.5,
42
+ help="Set the temperature to get accurate or random result"
43
+ )
44
+
45
+ self.max_token_length = st.sidebar.slider(
46
+ label="Token Length",
47
+ min_value=32,
48
+ max_value=2048,
49
+ step=16,
50
+ value=64,
51
+ help="Set max tokens to generate maximum amount of text output"
52
+ )
53
+
54
+
55
+ self.model_kwargs = {
56
+ "temperature": self.temperature,
57
+ "max_new_tokens": self.max_token_length
58
+ }
59
+
60
+ os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.getenv("HF_KEY")
61
+
62
+
63
+ def form_data(self):
64
+
65
+ try:
66
+ # if not self.API_KEY.startswith('hf_'):
67
+ # st.warning('Please enter your API key!', icon='⚠')
68
+ # text_input_visibility = True
69
+ # else:
70
+ # text_input_visibility = False
71
+ text_input_visibility = False
72
+
73
+
74
+ if "messages" not in st.session_state:
75
+ st.session_state.messages = []
76
+
77
+ st.write(f"You are using {self.models} model")
78
+
79
+ for message in st.session_state.messages:
80
+ with st.chat_message(message.get('role')):
81
+ st.write(message.get("content"))
82
+
83
+ context = st.sidebar.text_input(
84
+ label="Context",
85
+ help="Context lets you know on what the answer should be generated"
86
+ )
87
+
88
+
89
+ question = st.chat_input(
90
+ key="question",
91
+ disabled=text_input_visibility
92
+ )
93
+
94
+ template = f"<|system|>\nYou are a intelligent chatbot and expertise in {context}.</s>\n<|user|>\n{question}.\n<|assistant|>"
95
+
96
+ # template = """
97
+ # Answer the question based on the context, if you don't know then output "Out of Context"
98
+ # Context: {context}
99
+ # Question: {question}
100
+
101
+ # Answer:
102
+ # """
103
+ prompt = PromptTemplate(
104
+ template=template,
105
+ input_variables=[
106
+ 'question',
107
+ 'context'
108
+ ]
109
+ )
110
+ llm = HuggingFaceHub(
111
+ repo_id = self.models,
112
+ model_kwargs = self.model_kwargs
113
+ )
114
+
115
+ if question:
116
+ llm_chain = LLMChain(
117
+ prompt=prompt,
118
+ llm=llm,
119
+ )
120
+
121
+ result = llm_chain.run({
122
+ "question": question,
123
+ "context": context
124
+ })
125
+
126
+ if "Out of Context" in result:
127
+ result = "Out of Context"
128
+ st.session_state.messages.append(
129
+ {
130
+ "role":"user",
131
+ "content": f"Context: {context}\n\nQuestion: {question}"
132
+ }
133
+ )
134
+ with st.chat_message("user"):
135
+ st.write(f"Context: {context}\n\nQuestion: {question}")
136
+
137
+ if question.lower() == "clear":
138
+ del st.session_state.messages
139
+ return
140
+
141
+ st.session_state.messages.append(
142
+ {
143
+ "role": "assistant",
144
+ "content": result
145
+ }
146
+ )
147
+ with st.chat_message('assistant'):
148
+ st.markdown(result)
149
+
150
+ except Exception as e:
151
+ st.error(e, icon="🚨")
152
+
153
+ model = UserInterface()
154
+ model.form_data()