damerajee committed on
Commit
c232079
·
verified ·
1 Parent(s): d57ca26

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from mingru_lm import MinGRU_LM
4
+
5
+
6
# Pick the inference device first so the checkpoint can be mapped onto it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MinGRU_LM(dim=512, num_tokens=256, num_layers=6)
pt_model = "model/best_model.pt"
# map_location remaps stored tensors to the chosen device; without it a
# checkpoint saved on a GPU fails to deserialize on a CPU-only host.
checkpoint = torch.load(pt_model, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Move model to GPU if available
model = model.to(device)
14
+
15
def decode_tokens(tokens):
    """Turn token ids into text, keeping only ids in the range 32..255.

    Ids below 32 and ids of 256 or above are dropped; every remaining id is
    converted to its character with ``chr`` and the results are concatenated.
    """
    chars = []
    for token_id in tokens:
        if 32 <= token_id < 256:
            chars.append(chr(token_id))
    return ''.join(chars)
17
+
18
def tokenize_text(text):
    """Map each character of ``text`` to its code point, dropping any that is 256 or above."""
    code_points = map(ord, text)
    return [cp for cp in code_points if cp < 256]
20
+
21
def generate_text(start_text, max_length, temperature):
    """Sample a continuation of ``start_text`` from the module-level model.

    Args:
        start_text: Prompt string. Characters with a code point of 256 or
            above are dropped by ``tokenize_text`` before generation.
        max_length: Number of sampling steps to run. Converted with ``int``
            because a UI slider may deliver the value as a float.
        temperature: Positive divisor applied to the last-position logits
            before the softmax.

    Returns:
        The prompt tokens followed by the sampled tokens, decoded with
        ``decode_tokens``.

    Raises:
        ValueError: If the prompt yields no usable tokens, or if
            ``temperature`` is not greater than zero.
    """
    tokens = tokenize_text(start_text)
    if not tokens:
        # An empty sequence gives the model nothing to condition on, and
        # indexing the last position of the logits would fail opaquely.
        raise ValueError(
            "Prompt must contain at least one character with a code point below 256."
        )
    if temperature <= 0:
        # Dividing the logits by zero (or a negative value) does not produce
        # a meaningful sampling distribution.
        raise ValueError("Temperature must be greater than 0.")
    num_steps = int(max_length)

    model.eval()
    input_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    generated_tokens = list(tokens)

    with torch.no_grad():
        for _ in range(num_steps):
            _, logits = model(input_tensor, labels=None)

            last_token_logits = logits[0, -1, :] / temperature
            probs = torch.softmax(last_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            # Ids of 256 and above cannot be decoded; skip them without
            # extending the context.
            if next_token >= 256:
                continue

            generated_tokens.append(next_token)
            next_column = torch.tensor([[next_token]], dtype=torch.long, device=device)
            input_tensor = torch.cat([input_tensor, next_column], dim=1)

    return decode_tokens(generated_tokens)
45
+
46
+ # Gradio interface
47
# Input widgets are created first, in the same order as the parameters of
# generate_text (prompt, max length, temperature); the output widget follows.
prompt_box = gr.Textbox(lines=3, label="Enter your prompt", value="Once upon a time")
length_slider = gr.Slider(minimum=10, maximum=500, value=200, step=1, label="Max Length")
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
output_box = gr.Textbox(lines=10, label="Generated Text")

iface = gr.Interface(
    fn=generate_text,
    inputs=[prompt_box, length_slider, temperature_slider],
    outputs=output_box,
    title="Text Generation with MinGRU_LM",
    description="Enter a prompt and adjust parameters to generate text using the MinGRU_LM model.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()