"""Gradio app that simplifies Boolean expressions with a vLLM-served Qwen model.

The model is asked for a step-by-step simplification and told to put its
final result in a LaTeX ``\\boxed{...}``; the app extracts and displays only
that boxed result.
"""

from collections import Counter

import gradio as gr
import torch  # noqa: F401  -- kept from the original file; not used directly here
import vllm

# Initialize Model
# Loaded once at import time and shared by every Gradio request.
llm = vllm.LLM(
    "Qwen/Qwen2.5-32B-Instruct-AWQ",
    tensor_parallel_size=2,
    quantization="AWQ",
    gpu_memory_utilization=0.95,
    trust_remote_code=True,
    dtype="half",
    enforce_eager=True,
    max_model_len=10500,
)
tokenizer = llm.get_tokenizer()


# Helper Functions
def extract_answer(text):
    """Return the contents of the last ``\\boxed{...}`` in *text*, or None.

    Nested braces inside the box are balanced, so answers such as
    ``\\boxed{\\overline{A}.(B + C)}`` are returned whole.

    Returns None when *text* is empty, contains no ``\\boxed``, the last
    ``\\boxed`` is not immediately followed by ``{``, or its braces are never
    closed (e.g. the model output was truncated).
    """
    if not text:
        return None
    prefix = "\\boxed{"
    start = text.rfind("\\boxed")
    if start < 0 or not text.startswith(prefix, start):
        return None
    content_start = start + len(prefix)
    depth = 1  # we are already inside the opening brace of \boxed{
    for i in range(content_start, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
    return None  # unbalanced: no matching closing brace


def majority_vote(answers):
    """Return the most common non-None answer, or None if there are none.

    Answers that differ only in whitespace (``"A + B"`` vs ``"A+B"``) are
    counted together. Ties go to the answer seen first, and the returned
    string is the first original spelling of the winning answer.
    """
    valid = [a for a in answers if a is not None]
    if not valid:
        return None

    def _key(answer):
        # Whitespace-insensitive form used only for counting.
        return "".join(answer.split())

    counts = Counter(_key(a) for a in valid)
    best_key = counts.most_common(1)[0][0]
    return next(a for a in valid if _key(a) == best_key)


class TIRAgent:
    """One chat-based attempt at simplifying a Boolean expression.

    The agent builds a single user prompt, receives one model response via
    :meth:`add_response`, and records any ``\\boxed{...}`` answer in it.

    Args:
        problem_id: Identifier of the problem; written to ``log``.
        id: Index of this sample within its problem; written to ``log``.
            (Shadows the builtin, kept for caller compatibility.)
        problem: The Boolean expression to simplify.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        max_depth: Stored only; the agent always completes after its first
            response regardless of this value.
        log: None to disable logging, or an object with a ``writerow``
            method (presumably a ``csv.writer`` -- confirm with callers).
    """

    def __init__(self, problem_id, id, problem, tokenizer, max_depth, log):
        self.problem_id = problem_id
        self.id = id
        self.depth = 1
        self.max_depth = max_depth
        self.tokenizer = tokenizer
        self.problem = problem
        self.messages = [
            {
                "role": "user",
                "content": f"""Here is a boolean expression to simplify:
{self.problem}

Show the step by step simplification using Boolean algebra laws. For each step:
1. Write the current expression
2. Name the rule applied
3. Explain the transformation clearly

Put your final simplified answer in a LaTeX box \\boxed{{}}.""",
            }
        ]
        self.last_response = None
        self.answers = []  # every boxed answer extracted so far, in order
        self.is_complete = False
        self.log = log
        self.next_prompt = None

    def complete(self):
        """Return True once the agent has received its response."""
        return self.is_complete

    def add_response(self, response):
        """Record a model *response* and mark the agent complete."""
        self.depth += 1
        self.last_response = response
        self.messages.append({"role": "assistant", "content": response})
        # Extract boxed answer if present.
        answer = extract_answer(response)
        if answer is not None:
            self.answers.append(answer)
        # Single-turn agent: complete after the first response.
        self.is_complete = True

    def next_message(self):
        """Return the chat-templated prompt text for the next generation.

        Raises:
            RuntimeError: if the agent has already completed.
        """
        if self.is_complete:
            raise RuntimeError("next_message() called on a completed agent")
        return self.tokenizer.apply_chat_template(
            self.messages, tokenize=False, add_generation_prompt=True
        )

    def final_answer(self):
        """Return the last extracted boxed answer (or None) and log it."""
        ans = self.answers[-1] if self.answers else None
        if self.log:
            self.log.writerow([self.problem_id, self.id, ans])
        return ans


class SCTIRAgent:
    """Self-consistency wrapper: runs ``samples`` TIRAgents and majority-votes.

    Args:
        problem_id: Identifier of the problem, passed to every sub-agent.
        problem: The Boolean expression to simplify.
        tokenizer: Tokenizer passed to every sub-agent.
        samples: Number of independent TIRAgents to run.
        max_depth: Passed through to each TIRAgent (currently unused there).
        log: Optional row logger passed to every sub-agent.
    """

    def __init__(self, problem_id, problem, tokenizer, samples, max_depth, log):
        self.problem_id = problem_id
        self.problem = problem
        self.tokenizer = tokenizer
        self.samples = samples
        self.max_depth = max_depth
        self.agents = [
            TIRAgent(problem_id, i, problem, tokenizer, max_depth, log)
            for i in range(samples)
        ]
        self.log = log

    def complete(self):
        """Return True when every sub-agent has completed."""
        return all(agent.complete() for agent in self.agents)

    def get_ready_agents(self):
        """Return the sub-agents still waiting for a model response."""
        return [agent for agent in self.agents if not agent.complete()]

    def final_answer(self):
        """Return the majority-voted answer across sub-agents, or None.

        Raises:
            RuntimeError: if some sub-agent has not completed yet.
        """
        if not self.complete():
            raise RuntimeError("final_answer() called before all agents completed")
        answers = [agent.final_answer() for agent in self.agents]
        return majority_vote(answers)


# Sampling parameters
# The prompt asks for step-by-step working with a named rule and explanation
# per step; 512 tokens routinely truncated the output before the final
# \boxed{}, so allow more room (still far below max_model_len=10500).
sampling_params = vllm.SamplingParams(
    max_tokens=2048,
    temperature=0.7,
    top_p=0.9,
)


def simplify_boolean_expression(expression):
    """Ask the model to simplify *expression* and return its boxed answer.

    Returns the extracted answer string, or None when the model response
    contains no complete ``\\boxed{...}``.
    """
    agent = SCTIRAgent(0, expression, tokenizer, samples=1, max_depth=1, log=None)
    while not agent.complete():
        ready_agents = agent.get_ready_agents()
        prompts = [a.next_message() for a in ready_agents]
        # vLLM returns one RequestOutput per prompt, in the same order.
        outputs = llm.generate(prompts, sampling_params)
        for ready_agent, output in zip(ready_agents, outputs):
            ready_agent.add_response(output.outputs[0].text)
    return agent.final_answer()


# Gradio Interface
def interface(boolean_expr):
    """Gradio callback: simplify *boolean_expr* and return text for display.

    Returns a short explanatory message instead of an empty box when the
    input is blank or no boxed answer could be extracted.
    """
    if not boolean_expr or not boolean_expr.strip():
        return "Please enter a Boolean expression."
    simplified_expr = simplify_boolean_expression(boolean_expr.strip())
    if simplified_expr is None:
        return (
            "Could not extract a simplified expression from the model output. "
            "Please try again."
        )
    return simplified_expr


# Gradio app
app = gr.Interface(
    fn=interface,
    inputs=gr.Textbox(
        label="Enter Boolean Expression",
        placeholder="e.g., (B.C' + A'.D).(A.B' + C.D')",
    ),
    outputs=gr.Textbox(label="Final Simplified Expression"),
    title="Boolean Expression Simplifier",
    description="Input a Boolean expression, and the model will provide the final simplified result.",
)

if __name__ == "__main__":
    app.launch()