phanerozoic committed
Commit 785e1a7 · verified · 1 Parent(s): b12d444

Update app.py

Files changed (1): app.py (+52, -91)
app.py CHANGED
@@ -1,15 +1,11 @@
-# app.py • SchoolSpirit AI chatbot Space
-# Granite‑3.3‑2B‑Instruct | Streaming + rate‑limit + hallucination guard
 import os, re, time, datetime, threading, traceback, torch, gradio as gr
-from transformers import (AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer)
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from transformers.utils import logging as hf_logging

-# ───────────────────────────────── Log helper ────────────────────────────────
 os.environ["HF_HOME"] = "/data/.huggingface"
 LOG_FILE = "/data/requests.log"
-def log(msg: str):
-    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
-    line = f"[{ts}] {msg}"
+def log(m):
+    line = f"[{datetime.datetime.utcnow().strftime('%H:%M:%S.%f')[:-3]}] {m}"
     print(line, flush=True)
     try:
         with open(LOG_FILE, "a") as f:
@@ -17,13 +13,9 @@ def log(msg: str):
     except FileNotFoundError:
         pass

-# ─────────────────────────────── Configuration ───────────────────────────────
-MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
-CTX_TOKENS = 1800
-MAX_NEW_TOKENS = 120
-TEMP = 0.6
-MAX_INPUT_CH = 300
-RATE_N, RATE_SEC = 5, 60  # 5 msgs / 60 s per IP
+MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
+CTX_TOK, MAX_NEW, TEMP = 1800, 64, 0.6
+MAX_IN, RATE_N, RATE_T = 300, 5, 60

 SYSTEM_MSG = (
     "You are **SchoolSpirit AI**, the friendly digital mascot of "
@@ -37,121 +29,90 @@ SYSTEM_MSG = (
     "• If you can’t answer, politely direct the user to [email protected].\n"
     "• Keep language age‑appropriate; avoid profanity, politics, mature themes."
 )
-WELCOME = "Hi there! I’m SchoolSpirit AI. Ask me anything about our services!"
+WELCOME = "Hi there! I’m SchoolSpirit AI. How can I help?"

 strip = lambda s: re.sub(r"\s+", " ", s.strip())

-# ─────────────────────── Load tokenizer & model ──────────────────────────────
 hf_logging.set_verbosity_error()
 try:
-    log("Loading tokenizer …")
     tok = AutoTokenizer.from_pretrained(MODEL_ID)
-
-    if torch.cuda.is_available():
-        log("GPU detected → loading model in FP‑16")
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            device_map="auto",
-            torch_dtype=torch.float16,
-        )
-    else:
-        log("No GPU → loading model on CPU (this is slower)")
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            device_map="cpu",
-            torch_dtype="auto",
-            low_cpu_mem_usage=True,
-        )
-
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        device_map="auto" if torch.cuda.is_available() else "cpu",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
+        low_cpu_mem_usage=True,
+    )
     MODEL_ERR = None
-    log("Model loaded ✔")
-except Exception as exc:
-    MODEL_ERR = f"Model load error: {exc}"
-    log("❌ " + MODEL_ERR + "\n" + traceback.format_exc())
+    log("Model loaded")
+except Exception as e:
+    MODEL_ERR = f"Model load error: {e}"
+    log(MODEL_ERR + "\n" + traceback.format_exc())

-# ────────────────────────── Per‑IP rate limiter ──────────────────────────────
-VISITS: dict[str, list[float]] = {}
-def allowed(ip: str) -> bool:
+VISITS = {}
+def allowed(ip):
     now = time.time()
-    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_SEC]
+    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_T]
     if len(VISITS[ip]) >= RATE_N:
         return False
     VISITS[ip].append(now)
     return True

-# ─────────────────────── Prompt builder (token budget) ───────────────────────
-def build_prompt(raw: list[dict]) -> str:
+def build_prompt(raw):
     def render(m):
         if m["role"] == "system":
             return m["content"]
-        prefix = "User:" if m["role"] == "user" else "AI:"
-        return f"{prefix} {m['content']}"
-    system, convo = raw[0], raw[1:]
+        return f"{'User:' if m['role']=='user' else 'AI:'} {m['content']}"
+    sys, convo = raw[0], raw[1:]
     while True:
-        parts = [system["content"]] + [render(m) for m in convo] + ["AI:"]
-        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CTX_TOKENS or len(convo) <= 2:
+        parts = [sys["content"]] + [render(m) for m in convo] + ["AI:"]
+        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CTX_TOK or len(convo) <= 2:
             return "\n".join(parts)
-        convo = convo[2:]  # drop oldest user+assistant pair
+        convo = convo[2:]

-# ───────────────────────── Streaming chat callback ───────────────────────────
-def chat_fn(user_msg, chat_hist, state, request: gr.Request):
+def chat_fn(user_msg, hist, state, request: gr.Request):
     ip = request.client.host if request else "anon"
     if not allowed(ip):
-        chat_hist.append((user_msg, "Rate limit exceeded — please wait a minute."))
-        return chat_hist, state
-
+        hist.append((user_msg, "Rate limit exceeded — please wait a minute."))
+        return hist, state, ""
     user_msg = strip(user_msg or "")
     if not user_msg:
-        return chat_hist, state
-    if len(user_msg) > MAX_INPUT_CH:
-        chat_hist.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
-        return chat_hist, state
+        return hist, state, ""
+    if len(user_msg) > MAX_IN:
+        hist.append((user_msg, f"Input >{MAX_IN} chars."))
+        return hist, state, ""
     if MODEL_ERR:
-        chat_hist.append((user_msg, MODEL_ERR))
-        return chat_hist, state
+        hist.append((user_msg, MODEL_ERR))
+        return hist, state, ""

-    # append user turn & empty assistant slot
-    chat_hist.append((user_msg, ""))
+    hist.append((user_msg, ""))
     state["raw"].append({"role": "user", "content": user_msg})

     prompt = build_prompt(state["raw"])
-    input_ids = tok(prompt, return_tensors="pt").to(model.device).input_ids
-
+    ids = tok(prompt, return_tensors="pt").to(model.device).input_ids
     streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
     threading.Thread(
         target=model.generate,
-        kwargs=dict(
-            input_ids=input_ids,
-            max_new_tokens=MAX_NEW_TOKENS,
-            temperature=TEMP,
-            streamer=streamer,
-        ),
+        kwargs=dict(input_ids=ids, max_new_tokens=MAX_NEW, temperature=TEMP, streamer=streamer),
     ).start()

     partial = ""
-    try:
-        for token in streamer:
-            partial += token
-            # hallucination guard: stop if model starts new speaker tag
-            if "User:" in partial or "\nAI:" in partial:
-                partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip()
-                break
-            chat_hist[-1] = (user_msg, partial)
-            yield chat_hist, state
-    except Exception as exc:
-        log("❌ Stream error:\n" + traceback.format_exc())
-        partial = "Apologies—internal error. Please try again."
+    for piece in streamer:
+        partial += piece
+        if "User:" in partial or "\nAI:" in partial:
+            partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip()
+            break
+        hist[-1] = (user_msg, partial)
+        yield hist, state, ""

     reply = strip(partial)
-    chat_hist[-1] = (user_msg, reply)
+    hist[-1] = (user_msg, reply)
     state["raw"].append({"role": "assistant", "content": reply})
-    yield chat_hist, state  # final
+    yield hist, state, ""

-# ─────────────────────────── Gradio Blocks UI ────────────────────────────────
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     gr.Markdown("### SchoolSpirit AI Chat")
-    bot = gr.Chatbot(value=[("", WELCOME)], height=480, label="SchoolSpirit AI")
-    st = gr.State({
+    bot = gr.Chatbot(value=[("", WELCOME)], height=480)
+    st = gr.State({
         "raw": [
             {"role": "system", "content": SYSTEM_MSG},
             {"role": "assistant", "content": WELCOME},
@@ -159,8 +120,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     })
     with gr.Row():
         txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
-        btn = gr.Button("Send", variant="primary")
-        btn.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st])
-        txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st])
+        send = gr.Button("Send", variant="primary")
+        send.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])
+        txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])

 demo.launch()
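
The core of the updated app.py is the background-generation streaming loop in chat_fn. Below is a minimal, standalone sketch of that pattern, using the same Granite checkpoint; the prompt text, token budget, and variable names here are illustrative and not taken from the commit:

# Minimal sketch of the streaming pattern used in chat_fn above.
# Assumptions: the checkpoint can be downloaded, and the prompt/stop
# values below are placeholders chosen for illustration.
import threading
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
tok = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

prompt = "You are a friendly school mascot.\nUser: What can you help with?\nAI:"
ids = tok(prompt, return_tensors="pt").to(model.device).input_ids

# generate() runs in a worker thread; the streamer yields decoded text
# chunks on the main thread as soon as they are produced.
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
threading.Thread(
    target=model.generate,
    kwargs=dict(input_ids=ids, max_new_tokens=64, streamer=streamer),
).start()

partial = ""
for piece in streamer:
    partial += piece
    # same idea as the app's speaker-tag guard: stop once the model
    # starts inventing the next "User:" turn
    if "User:" in partial:
        partial = partial.split("User:")[0].strip()
        break
print(partial)

In the Space itself, the same loop assigns the growing partial string to the last Chatbot entry and yields it, which is what produces the token-by-token display in the Gradio UI.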