Spaces: Runtime error
Initial commit
app.py CHANGED
@@ -1,63 +1,105 @@
+import argparse
+import os
+import spaces
+
 import gradio as gr
-… (old lines 2-33 not shown)
+
+import json
+from threading import Thread
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+MAX_LENGTH = 4096
+DEFAULT_MAX_NEW_TOKENS = 1024
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base_model", type=str)  # model path
+    parser.add_argument("--n_gpus", type=int, default=1)  # n_gpu
+    return parser.parse_args()
+
+@spaces.GPU()
+def predict(message, history, system_prompt, temperature, max_tokens):
+    global model, tokenizer, device
+    messages = [{'role': 'system', 'content': system_prompt}]
+    for human, assistant in history:
+        messages.append({'role': 'user', 'content': human})
+        messages.append({'role': 'assistant', 'content': assistant})
+    messages.append({'role': 'user', 'content': message})
+    problem = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
+    stop_tokens = ["<|endoftext|>", "<|im_end|>"]
+    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
+    enc = tokenizer(problem, return_tensors="pt", padding=True, truncation=True)
+    input_ids = enc.input_ids
+    attention_mask = enc.attention_mask
+
+    if input_ids.shape[1] > MAX_LENGTH:
+        input_ids = input_ids[:, -MAX_LENGTH:]
+
+    input_ids = input_ids.to(device)
+    attention_mask = attention_mask.to(device)
+    generate_kwargs = dict(
+        {"input_ids": input_ids, "attention_mask": attention_mask},
+        streamer=streamer,
+        do_sample=True,
+        top_p=0.95,
         temperature=temperature,
-… (old lines 35-43 not shown)
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
+        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+        use_cache=True,
+        eos_token_id=tokenizer.eos_token_id  # <|im_end|>
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)

+def submit_correction(original_answer, corrected_answer):
+    # No operation function for the submit button click event
+    return "Correction submitted!"

 if __name__ == "__main__":
-    demo.launch()
+    args = parse_args()
+    tokenizer = AutoTokenizer.from_pretrained("lliu01/fortios_one_config")
+    model = AutoModelForCausalLM.from_pretrained(
+        "lliu01/fortios_one_config",
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True
+    )
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = model.to(device)
+
+    correct_answer = gr.Textbox(label="Correct Answer", placeholder="Enter the correct answer if the provided one is wrong")
+    submit_btn = gr.Button("Submit Correction")
+    submit_btn.click(fn=submit_correction, inputs=[chatbot, correct_answer], outputs="text")
+
+    gr.ChatInterface(
+        predict,
+        title="FortiOS CLI Chat - Demo",
+        description="FortiOS CLI Chat",
+        theme="soft",
+        chatbot=gr.Chatbot(label="Chat History",),
+        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
+        retry_btn=None,
+        undo_btn="Delete Previous",
+        clear_btn="Clear",
+        additional_inputs=[
+            gr.Textbox(sys_prompt, label="System Prompt"),
+            gr.Slider(0, 1, 0.5, label="Temperature"),
+            gr.Slider(100, 2048, 1024, label="Max Tokens"),
+            correct_answer,
+            submit_btn,
+        ],
+        examples=[
+            ["Allow all traffic from any source IP address and any source interface 'port10' to any destination IP address and any destination interface 'port9'. This policy will be applied at all times (always) and will allow all services. Additionally, this policy will enable UTM features, use proxy-based inspection mode, and use an SSL-SSH profile named 'deep-custom'. Finally, this policy will also enable source NAT."],
+            ["Configure a firewall policy to allow users 'dina' and '15947' to access 'DR-Exchange-Servers' and 'HQ-Exchange-Servers' using RDP protocol from the 'SSL-VPN-IT-Pool' address range, incoming from the 'ssl.FG-Traffic' interface and outgoing to the 'FG-PA-Inside' interface. The policy should have Antivirus scanning enabled with profile 'ABE_AV' and log all traffic. The policy should be always active and currently disabled for testing or maintenance purposes."],
+            ["Configure a firewall policy named 'ZoomAccess' that allows traffic from the 'IP_10.96.54.149' and 'HighCourt_Zoom' addresses coming in through the 'VLAN51' interface to access the 'Zoom_access' destination through the 'npu0_vlink1' interface, at any time, with all services allowed, using proxy-based inspection and SSL certificate inspection."],
+            ["Create a dynamic firewall address object named 'EMS2_ZTNA_Condiciones-Clinic' that is based on a FortiClient EMS tag. This object will be used to represent a group of devices that have the 'Condiciones-Clinic' tag in the EMS system, which is related to zero-trust access control (ZTNA)."],
+            ["The user wants to create a dynamic firewall address object named 'Pre-Prod DMN Servers' that retrieves IP addresses from a VMware vCenter SDN (Software-Defined Networking) environment. The object will dynamically include IP addresses that match the filter criteria 'Name=b4dmn*' from the vCenter inventory. Specifically, the object will include the following IP addresses: 172.21.121.44, 172.21.121.45, 172.21.121.46, 172.21.121.47, 172.21.121.48, and 172.21.121.49, each with associated object IDs and network IDs for further identification and grouping."],
+            ["The user wants to create a traffic shaper named 'Videoconferencia' that limits the maximum bandwidth to 60 megabits per second, effectively enforcing an upper bandwidth limit for video conferencing traffic."],
+            ["Configure an interface named 'Sec60' in the 'root' virtual domain with an IP address of 172.18.60.1/24. Allow management access to this interface for ping, fabric, and speed-test. Enable device identification and set the interface role to LAN. Set the SNMP index to 41 and enable auto-authentication for dedicated Fortinet extension devices. Additionally, enable switch controller features such as IGMP snooping, IGMP snooping proxy, and DHCP snooping. Set the color of the interface icon on the GUI to 7 and associate it with the 'FortiLink' interface and VLAN ID 60."],
+        ],
+        additional_inputs_accordion_name="Parameters",
+    ).queue().launch()
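
As committed, the app crashes at startup, which matches the "Runtime error" badge above: `sys_prompt` is passed to `gr.Textbox` but is never defined (a NameError the moment the script runs), `submit_btn.click(...)` references an undefined `chatbot` name and is wired outside any Blocks context, and `outputs="text"` is not a component. Including `correct_answer` and `submit_btn` in `additional_inputs` also hands `predict` five extra values when its signature accepts three. Below is a minimal repair sketch of the `__main__` block, assuming Gradio 4.x and reusing the commit's own `parse_args`, `predict`, and `submit_correction`; the default system-prompt string and the choice to render the correction widgets below the chat are placeholders, not taken from the commit:

if __name__ == "__main__":
    args = parse_args()  # parsed but otherwise unused, as in the commit
    tokenizer = AutoTokenizer.from_pretrained("lliu01/fortios_one_config")
    model = AutoModelForCausalLM.from_pretrained(
        "lliu01/fortios_one_config",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # The commit references sys_prompt without defining it; any default works.
    sys_prompt = "You are a FortiOS CLI assistant."  # placeholder value

    # Instantiate shared components up front so they can be referenced by name.
    chatbot = gr.Chatbot(label="Chat History")
    correct_answer = gr.Textbox(
        label="Correct Answer",
        placeholder="Enter the correct answer if the provided one is wrong",
    )
    submit_btn = gr.Button("Submit Correction")

    demo = gr.ChatInterface(
        predict,
        title="FortiOS CLI Chat - Demo",
        description="FortiOS CLI Chat",
        theme="soft",
        chatbot=chatbot,
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Exactly three additional inputs, matching predict's extra parameters.
        additional_inputs=[
            gr.Textbox(sys_prompt, label="System Prompt"),
            gr.Slider(0, 1, 0.5, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        additional_inputs_accordion_name="Parameters",
        # examples list from the commit omitted here for brevity
    )

    # Event wiring must happen inside a Blocks context (ChatInterface is a
    # Blocks subclass), and outputs must be components, not the string "text".
    with demo:
        correct_answer.render()
        submit_btn.render()
        submit_btn.click(
            fn=submit_correction,
            inputs=[chatbot, correct_answer],
            outputs=[correct_answer],
        )

    demo.queue().launch()

Even with these repairs, two smaller issues from the commit remain: `predict` ignores its `max_tokens` argument (generation is pinned to DEFAULT_MAX_NEW_TOKENS) and never applies `stop_tokens`, so passing `max_new_tokens=max_tokens` into `generate_kwargs` and converting the stop strings into a transformers `StoppingCriteriaList` would make those controls effective.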