wop committed on
Commit 7165422 · verified · 1 Parent(s): b5fc4fe

Update app.py

Files changed (1)
  1. app.py +51 -60
app.py CHANGED
@@ -1,73 +1,64 @@
- import json
- import gradio as gr
- import random
  from huggingface_hub import InferenceClient

- API_URL = "https://api-inference.huggingface.co/models/"

- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

  def format_prompt(message, history):
-     prompt = "You're a helpful assistant."
      for user_prompt, bot_response in history:
-         prompt += f" [INST] {user_prompt} [/INST] {bot_response}</s> "
-     prompt += f" [INST] {message} [/INST]"
      return prompt

- def generate(prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
-     temperature = float(temperature) if temperature > 0 else 0.01
      top_p = float(top_p)

-     generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
-         do_sample=True,
-         seed=random.randint(0, 10**7),
-     )
-
      formatted_prompt = format_prompt(prompt, history)
-
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-
-     for response in stream:
-         output += response.token.text
-         yield output
-
- def load_database():
-     try:
-         with open("database.json", "r", encoding="utf-8") as f:
-             data = json.load(f)
-         if not isinstance(data, list):
-             raise ValueError("Invalid data format")
-         return data
-     except (FileNotFoundError, json.JSONDecodeError, ValueError):
-         print("Error loading database: File not found, invalid format, or empty. Creating an empty database.")
-         return []
-
-
- def save_database(data):
-     try:
-         with open("database.json", "w", encoding="utf-8") as f:
-             json.dump(data, f, indent=4)
-     except (IOError, json.JSONEncodeError):
-         print("Error saving database: Encountered an issue while saving.")
-
- def chat_interface(message):
-     database = load_database()
-
-     if (message, None) not in database:
-         response = next(generate(message, history=[]))
-         database.append((message, response))
-         save_database(database)
      else:
-         _, stored_response = next(item for item in database if item[0] == message)
-         response = stored_response
-
-     return response
-
- with gr.Interface(fn=chat_interface, inputs="textbox", outputs="textbox", title="Chat Interface") as iface:
-     iface.launch()

  from huggingface_hub import InferenceClient
+ import gradio as gr
+ import json
+
+ client = InferenceClient(
+     "mistralai/Mistral-7B-Instruct-v0.1"
+ )
+
+ DATABASE_PATH = "database.json"

+ def load_database():
+     try:
+         with open(DATABASE_PATH, "r") as file:
+             return json.load(file)
+     except FileNotFoundError:
+         return {}

+ def save_database(database):
+     with open(DATABASE_PATH, "w") as file:
+         json.dump(database, file)

  def format_prompt(message, history):
+     prompt = "<s>"
      for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
      return prompt

+ def generate(
+     prompt, history, temperature=0.9, max_new_tokens=2000, top_p=0.9, repetition_penalty=1.2,
+ ):
+     database = load_database() # Load the database
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
      top_p = float(top_p)

      formatted_prompt = format_prompt(prompt, history)
+     if formatted_prompt in database:
+         response = database[formatted_prompt]
      else:
+         response = client.text_generation(formatted_prompt, details=True, return_full_text=False)
+         response_text = response.generated_tokens[0].text
+         database[formatted_prompt] = response_text
+         save_database(database) # Save the updated database
+
+     yield response_text
+
+ css = """
+ #mkd {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.ChatInterface(
+         generate,
+         examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."], ["Write a short story about Paris."]]
+     )

+ demo.launch(debug=True)
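
For reference, the updated format_prompt wraps each turn in Mistral's [INST] ... [/INST] tags, and the cache in database.json is keyed on that full string. Below is a minimal standalone sketch of what the prompt looks like for a one-turn history; the function body mirrors the committed code, while the sample history and message are made up for illustration.

# Sketch: reproduce the prompt string the updated app builds and caches on.
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

# Illustrative inputs only.
history = [("Hi there", "Hello! How can I help?")]
print(format_prompt("Tell me a joke.", history))
# <s>[INST] Hi there [/INST] Hello! How can I help?</s> [INST] Tell me a joke. [/INST]

Because the cache key is this exact formatted string, repeating a question with an identical history is served from database.json instead of calling the Inference API again.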