# Hugging Face Space app (scraped page header said: "Spaces: Sleeping").
import html
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# --- Model setup -----------------------------------------------------------
# Hub repo and the quantized GGUF checkpoint inside it.
MODEL_ID = "erikbeltran/pydiff"
GGUF_FILE = "unsloth.Q4_K_M.gguf"

# Both tokenizer and model are loaded from the same GGUF file.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, gguf_file=GGUF_FILE)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, gguf_file=GGUF_FILE)

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
def format_diff_response(response):
    """Render model output as HTML, color-coding diff lines.

    Lines beginning with ``+`` are wrapped in green, lines beginning with
    ``-`` in red; everything else is passed through uncolored. Each line is
    HTML-escaped first so that code containing ``<``, ``>`` or ``&`` displays
    literally instead of being interpreted (or injected) as markup.

    Args:
        response: Raw model text, possibly containing diff-style lines.

    Returns:
        A single HTML string with ``<br>`` line breaks.
    """
    formatted = []
    for line in response.split('\n'):
        # Escape BEFORE wrapping so our own <span> markup survives intact.
        safe = html.escape(line)
        if line.startswith('+'):
            formatted.append(f'<span style="color: green">{safe}</span>')
        elif line.startswith('-'):
            formatted.append(f'<span style="color: red">{safe}</span>')
        else:
            formatted.append(safe)
    return '<br>'.join(formatted)
def create_prompt(request, file_content, system_message):
    """Build the tagged prompt string the pydiff model expects.

    Layout: a <system> line, a <request> line, then the file content wrapped
    in <file>...</file> tags, one section per line.
    """
    sections = [
        f"<system>{system_message}</system>",
        f"<request>{request}</request>",
        "<file>",
        file_content,
        "</file>",
    ]
    return "\n".join(sections)
def respond(request, file_content, system_message, max_tokens, temperature, top_p):
    """Generate a streamed, diff-highlighted review of *file_content*.

    Builds the tagged prompt, runs ``model.generate`` in a worker thread with
    a ``TextIteratorStreamer``, and yields progressively longer HTML strings
    (via ``format_diff_response``) as tokens arrive — the shape Gradio expects
    from a streaming generator callback.

    Args:
        request: The user's instruction (e.g. "add error handling").
        file_content: The source code under review.
        system_message: System prompt prepended to the request.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.

    Yields:
        HTML strings containing the formatted response so far.
    """
    prompt = create_prompt(request, file_content, system_message)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # skip_prompt=True: stream only newly generated tokens; without it the
    # entire prompt is echoed back into the diff display.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        inputs=inputs["input_ids"],
        # Pass the mask explicitly so padding is handled deterministically.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_tokens,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        streamer=streamer,
    )

    # generate() blocks until completion, so run it off-thread and consume
    # the streamer on this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield format_diff_response(response)

    # Make sure generation has fully finished before the generator returns.
    thread.join()
# --- Gradio interface ------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Code Review Assistant")

    with gr.Row():
        # Left: the user's instruction plus the code to review.
        with gr.Column():
            request_input = gr.Textbox(
                label="Request",
                placeholder="Enter your request (e.g., 'fix the function', 'add error handling')",
                lines=3,
            )
            file_input = gr.Code(label="File Content", language="python", lines=10)
        # Right: the streamed, color-coded diff.
        with gr.Column():
            output = gr.HTML(label="Diff Output")

    # Generation knobs, collapsed by default.
    with gr.Accordion("Advanced Settings", open=False):
        system_msg = gr.Textbox(
            value=(
                "You are a code review assistant. Analyze the code and provide "
                "suggestions in diff format. Use '+' for additions and '-' for deletions."
            ),
            label="System Message",
        )
        max_tokens = gr.Slider(
            minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"
        )
        temperature = gr.Slider(
            minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"
        )

    submit_btn = gr.Button("Submit")
    # respond is a generator, so its yields stream straight into the HTML pane.
    submit_btn.click(
        fn=respond,
        inputs=[request_input, file_input, system_msg, max_tokens, temperature, top_p],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()