wop commited on
Commit
7416d8a
·
verified ·
1 Parent(s): 50d2a7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -40
app.py CHANGED
@@ -1,64 +1,84 @@
 
1
  from huggingface_hub import InferenceClient
2
  import gradio as gr
3
- import json
 
 
4
 
5
  client = InferenceClient(
6
  "mistralai/Mistral-7B-Instruct-v0.1"
7
  )
8
 
9
- DATABASE_PATH = "database.json"
10
-
11
- def load_database():
12
- try:
13
- with open(DATABASE_PATH, "r") as file:
14
- return json.load(file)
15
- except FileNotFoundError:
16
- return {}
17
-
18
- def save_database(database):
19
- with open(DATABASE_PATH, "w") as file:
20
- json.dump(database, file)
21
-
22
  def format_prompt(message, history):
23
- prompt = "<s>"
24
  for user_prompt, bot_response in history:
25
  prompt += f"[INST] {user_prompt} [/INST]"
26
  prompt += f" {bot_response}</s> "
27
  prompt += f"[INST] {message} [/INST]"
28
  return prompt
29
 
30
- def generate(
31
- prompt, history, temperature=0.9, max_new_tokens=2000, top_p=0.9, repetition_penalty=1.2,
32
- ):
33
- database = load_database() # Load the database
34
  temperature = float(temperature)
35
  if temperature < 1e-2:
36
  temperature = 1e-2
37
  top_p = float(top_p)
38
 
 
 
 
 
 
 
 
 
 
39
  formatted_prompt = format_prompt(prompt, history)
40
- if formatted_prompt in database:
41
- response = database[formatted_prompt]
42
- else:
43
- response = client.text_generation(formatted_prompt, details=True, return_full_text=False)
44
- response_text = response.generated_tokens[0].text
45
- database[formatted_prompt] = response_text
46
- save_database(database) # Save the updated database
47
 
48
- yield response_text
 
49
 
50
- css = """
51
- #mkd {
52
- height: 500px;
53
- overflow: auto;
54
- border: 1px solid #ccc;
55
- }
56
- """
57
 
58
- with gr.Blocks(css=css) as demo:
59
- gr.ChatInterface(
60
- generate,
61
- examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."], ["Write a short story about Paris."]]
62
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- demo.launch(debug=True)
 
1
+ import json
2
  from huggingface_hub import InferenceClient
3
  import gradio as gr
4
+ import random
5
+
6
+ API_URL = "https://api-inference.huggingface.co/models/"
7
 
8
  client = InferenceClient(
9
  "mistralai/Mistral-7B-Instruct-v0.1"
10
  )
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def format_prompt(message, history):
13
+ prompt = "You're a helpful assistant."
14
  for user_prompt, bot_response in history:
15
  prompt += f"[INST] {user_prompt} [/INST]"
16
  prompt += f" {bot_response}</s> "
17
  prompt += f"[INST] {message} [/INST]"
18
  return prompt
19
 
20
+ def generate(prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
 
 
 
21
  temperature = float(temperature)
22
  if temperature < 1e-2:
23
  temperature = 1e-2
24
  top_p = float(top_p)
25
 
26
+ generate_kwargs = dict(
27
+ temperature=temperature,
28
+ max_new_tokens=max_new_tokens,
29
+ top_p=top_p,
30
+ repetition_penalty=repetition_penalty,
31
+ do_sample=True,
32
+ seed=random.randint(0, 10**7),
33
+ )
34
+
35
  formatted_prompt = format_prompt(prompt, history)
 
 
 
 
 
 
 
36
 
37
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
38
+ output = ""
39
 
40
+ for response in stream:
41
+ output += response.token.text
42
+ yield output
43
+ return output
 
 
 
44
 
45
+ def load_database():
46
+ try:
47
+ # Attempt to load the database from JSON
48
+ with open("database.json", "r", encoding="utf-8") as f:
49
+ return json.load(f)
50
+ except (FileNotFoundError, json.JSONDecodeError):
51
+ # Handle potential errors gracefully
52
+ print("Error loading database: File not found or invalid format. Creating an empty database.")
53
+ return [] # Return an empty list if database loading fails
54
+
55
+ def save_database(data):
56
+ try:
57
+ # Save the updated database to JSON
58
+ with open("database.json", "w", encoding="utf-8") as f:
59
+ json.dump(data, f, indent=4)
60
+ except (IOError, json.JSONEncodeError):
61
+ # Handle potential errors gracefully
62
+ print("Error saving database: Encountered an issue while saving.")
63
+
64
+ def chat_interface(message):
65
+ database = load_database()
66
+
67
+ # Check if the question already exists in the database
68
+ if (message, None) not in database:
69
+ # If not, generate a response and add it to the database
70
+ response = generate(message, history=[])
71
+ database.append((message, response))
72
+ save_database(database)
73
+ else:
74
+ # If it does, retrieve the stored response
75
+ _, stored_response = next(item for item in database if item[0] == message)
76
+ response = stored_response
77
+
78
+ return response
79
+
80
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
81
+ demo.register("message", gr.Textbox(label="Your question"))
82
+ demo.register("response", gr.Textbox(label="Assistant's response"))
83
 
84
+ demo.launch(fn=chat_interface, inputs=["message"], outputs=["response"])