pro-grammer committed on
Commit 3e6a8ce · verified · 1 parent: 514a64a

Update app.py

Files changed (1)
  1. app.py +49 -21
app.py CHANGED
@@ -2,36 +2,60 @@ import gradio as gr
 import torch
 import tiktoken
 
-from model import GPTLanguageModel
+from model import GPTLanguageModel  # Import the model from model.py
 
 # Load the model and tokenizer
 def load_model():
     """Load the trained GPT model"""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     GPT_CONFIG = {
-        "vocab_size" : 50257,
-        "n_heads" : 8,
-        "n_layers" : 6,
-        "head_size" : 64,
-        "n_embd" : 512,
-        "block_size" : 128,
-        "dropout" : 0.1,
-        "learning_rate" : 3e-4,
-        "weight_decay" : 0.1,
+        "vocab_size": 50257,
+        "n_heads": 8,
+        "n_layers": 6,
+        "n_embd": 512,
+        "block_size": 128,
+        "dropout": 0.1,
     }
-    model = GPTLanguageModel(GPT_CONFIG)
+
+    model = GPTLanguageModel(
+        GPT_CONFIG["vocab_size"],
+        GPT_CONFIG["n_embd"],
+        GPT_CONFIG["block_size"],
+        GPT_CONFIG["n_layers"],
+        GPT_CONFIG["n_heads"],
+        device
+    )
+
+    # Load the trained weights
     model.load_state_dict(torch.load("model_weights.pth", map_location=device))
     model.to(device)
-    model.eval()
+    model.eval()  # Set the model to evaluation mode
+
+    # Use tiktoken for tokenization
     tokenizer = tiktoken.get_encoding("gpt2")
+
     return model, tokenizer, device
 
-# Load model globally
+# Load the model globally
 model, tokenizer, device = load_model()
 
-# Define the respond function
+# Tokenization and detokenization functions
+def text_to_token_ids(text, tokenizer):
+    return torch.tensor([tokenizer.encode(text)], dtype=torch.long)
+
+def token_ids_to_text(token_ids, tokenizer):
+    return tokenizer.decode(token_ids[0].tolist())
+
+# Generate text function using the model
+def generate_text(model, idx, max_new_tokens, context_size=256):
+    # Call the model's generate function
+    token_ids = model.generate(idx, max_new_tokens)
+    return token_ids
+
+# Define the response function
 def respond(message, history: list[tuple[str, str]], system_message, max_tokens):
-    # Build message history with system message
+    # Build the message history with the system message
    messages = [{"role": "system", "content": system_message}]
 
     for val in history:
@@ -43,8 +67,11 @@ def respond(message, history: list[tuple[str, str]], system_message, max_tokens)
     # Add the user message to the conversation
     messages.append({"role": "user", "content": message})
 
+    # Concatenate the history into one context
+    conversation_history = " ".join([msg["content"] for msg in messages])
+
     # Convert the latest user message to token IDs
-    input_ids = text_to_token_ids(message, tokenizer).to(device)
+    input_ids = text_to_token_ids(conversation_history, tokenizer).to(device)
 
     # Generate the response from the model
     token_ids = generate_text(
@@ -56,16 +83,17 @@ def respond(message, history: list[tuple[str, str]], system_message, max_tokens)
 
     # Convert the token IDs back to text and return
     response_text = token_ids_to_text(token_ids, tokenizer)
+
     return response_text
 
-# Gradio ChatInterface
+# Gradio Chat Interface
 demo = gr.ChatInterface(
-    respond,
+    fn=respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=256, value=50, step=1, label="Max new tokens")
+        gr.Textbox(value="You are a friendly chatbot.", label="System message"),  # System message input
+        gr.Slider(minimum=1, maximum=256, value=50, step=1, label="Max new tokens")  # Max tokens slider
     ]
 )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True)
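
The updated app.py constructs GPTLanguageModel positionally and relies on a model.generate(idx, max_new_tokens) method, so the commit assumes model.py exposes roughly the interface below. This is a minimal hypothetical sketch inferred from the call sites in the diff; the stand-in forward pass and the sampling loop are assumptions, not the actual model.py.

# Hypothetical sketch of the model.py interface this commit relies on.
# Only the constructor arity and the generate() signature come from the
# diff above; the stand-in forward pass and sampling loop are assumed.
import torch
import torch.nn as nn

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, n_layers, n_heads, device):
        super().__init__()
        self.block_size = block_size
        self.device = device
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        # The real model presumably stacks n_layers transformer blocks with
        # n_heads attention heads here; a single linear head stands in.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        return self.lm_head(self.tok_emb(idx))  # (B, T, vocab_size) logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # Autoregressive sampling; cropping to block_size keeps long chat
        # histories from overrunning the model's context window.
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            probs = torch.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx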
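
A quick way to exercise the new handler without launching the UI is to call respond directly; a hypothetical smoke test, assuming app.py is importable and model_weights.pth is present:

# Hypothetical smoke test; note that importing app runs load_model() at
# module level, so the weights load on import.
from app import respond

reply = respond(
    "Hello!",                       # latest user message
    [],                             # empty (user, assistant) history
    "You are a friendly chatbot.",  # system message
    50,                             # max new tokens
)
print(reply)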