Threatthriver committed on
Commit fb5b02d
1 Parent(s): 8ebe706

Update app.py

Files changed (1)
app.py +73 -63
app.py CHANGED
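Note: the first hunk below starts at line 4 of app.py, so the file's opening lines are outside the diff's context window. The hunk header's trailing context confirms that import random sits just above, and the changed code also relies on InferenceClient and gr, so the preamble is presumably along these lines (an assumption for readers who want to run the file, not part of the commit):

from huggingface_hub import InferenceClient  # assumed: needed for client = InferenceClient(...)
import gradio as gr                          # assumed: needed for gr.Slider / gr.Blocks / gr.ChatInterface
import random                                # confirmed by the first hunk header's context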
@@ -4,24 +4,29 @@ import random
 
 API_URL = "https://api-inference.huggingface.co/models/"
 
-client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.1"
-)
+# Initialize the InferenceClient
+client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
 
 def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+    """Format the prompt for the text generation model."""
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
 
 def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
+    """Generate a response using the text generation model."""
+    # Ensure temperature is not too low
+    temperature = max(float(temperature), 1e-2)
     top_p = float(top_p)
 
+    # Check if the prompt is asking who created the bot
+    if "who created you" in prompt.lower():
+        return "I was created by Aniket Kumar and many more."
+
+    # Set up parameters for text generation
     generate_kwargs = dict(
         temperature=temperature,
         max_new_tokens=max_new_tokens,
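As a quick, self-contained check of the prompt layout that format_prompt() in the hunk above builds (the exchange below is made up for illustration and is not part of the commit), each past turn is wrapped in Mistral's [INST] ... [/INST] markers and closed with </s>:

# Hypothetical example; format_prompt copied verbatim from the updated app.py.
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Hi there", "Hello! How can I help?")]
print(format_prompt("Who created you?", history))
# -> <s>[INST] Hi there [/INST] Hello! How can I help?</s> [INST] Who created you? [/INST]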
@@ -31,67 +36,72 @@ def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
         seed=random.randint(0, 10**7),
     )
 
+    # Format the prompt
     formatted_prompt = format_prompt(prompt, history)
 
+    # Generate the response
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
-
     for response in stream:
         output += response.token.text
         yield output
     return output
 
-
-additional_inputs=[
-    gr.Slider(
-        label="Temperature",
-        value=0.9,
-        minimum=0.0,
-        maximum=1.0,
-        step=0.05,
-        interactive=True,
-        info="Higher values produce more diverse outputs",
-    ),
-    gr.Slider(
-        label="Max new tokens",
-        value=512,
-        minimum=64,
-        maximum=1024,
-        step=64,
-        interactive=True,
-        info="The maximum numbers of new tokens",
-    ),
-    gr.Slider(
-        label="Top-p (nucleus sampling)",
-        value=0.90,
-        minimum=0.0,
-        maximum=1,
-        step=0.05,
-        interactive=True,
-        info="Higher values sample more low-probability tokens",
-    ),
-    gr.Slider(
-        label="Repetition penalty",
-        value=1.2,
-        minimum=1.0,
-        maximum=2.0,
-        step=0.05,
-        interactive=True,
-        info="Penalize repeated tokens",
-    )
-]
+def create_interface():
+    """Create the Gradio interface."""
+    additional_inputs=[
+        gr.Slider(
+            label="Temperature",
+            value=0.9,
+            minimum=0.0,
+            maximum=1.0,
+            step=0.05,
+            interactive=True,
+            info="Higher values produce more diverse outputs",
+        ),
+        gr.Slider(
+            label="Max new tokens",
+            value=512,
+            minimum=64,
+            maximum=1024,
+            step=64,
+            interactive=True,
+            info="The maximum numbers of new tokens",
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            value=0.90,
+            minimum=0.0,
+            maximum=1,
+            step=0.05,
+            interactive=True,
+            info="Higher values sample more low-probability tokens",
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            value=1.2,
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            interactive=True,
+            info="Penalize repeated tokens",
+        )
+    ]
 
-customCSS = """
-#component-7 { # this is the default element ID of the chat component
-  height: 800px; # adjust the height as needed
-  flex-grow: 1;
-}
-"""
+    customCSS = """
+    #component-7 { # this is the default element ID of the chat component
+      height: 800px; # adjust the height as needed
+      flex-grow: 1;
+    }
+    """
 
-with gr.Blocks(css=customCSS) as demo:
-    gr.ChatInterface(
-        generate,
-        additional_inputs=additional_inputs,
-    )
+    with gr.Blocks(css=customCSS) as demo:
+        gr.ChatInterface(
+            generate,
+            additional_inputs=additional_inputs,
+        )
 
-demo.queue().launch(debug=True)
+    demo.queue().launch(debug=True)
+
+# Run the application
+create_interface()
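Because generate() yields the cumulative text after every streamed token, it can also be exercised without the Gradio UI. A minimal console sketch, assuming the updated functions above are in scope (for example, pasted into a REPL) and that the Hugging Face Inference API is reachable with a valid token; this is an illustration, not part of the commit:

# Stream a reply to the console; each yield is the full output generated so far.
history = [("Hi there", "Hello! How can I help?")]
for partial in generate("Explain top-p sampling in one sentence.", history):
    print(partial, end="\r")
print()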
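One detail worth flagging in the customCSS string this commit carries over into create_interface(): CSS only supports /* ... */ comments, so the inline # annotations make the declarations invalid and browsers will likely discard the height and flex-grow rules during error recovery. A corrected string could look like the sketch below (a suggestion, not what the commit contains):

# Hypothetical fix: use CSS block comments instead of '#'.
customCSS = """
#component-7 { /* default element ID of the chat component */
  height: 800px; /* adjust the height as needed */
  flex-grow: 1;
}
"""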