# Rgenerator / app.py
# Author: Prasanna Dhungana
# Last commit: 661fe9a (verified) - "updated quantization config"
# (Hugging Face Space page header; file size 2.09 kB)
import re

import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Load the StarCoder2-3B base model in 4-bit and apply the fine-tuned PEFT
# adapter published at `model_path` on top of it.

# Pick the compute device once at import time. `device` is a module-level
# name that other code in this module may use to place tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_path = "parsanna17/finetune_starcoder2_with_R_data"
checkpoint = "bigcode/starcoder2-3b"
config = PeftConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, quantization_config=quantization_config
)
# transformers places bitsandbytes-quantized weights on a device while
# loading and does not support `.to()` on 4-bit models, so the wrapped model
# is deliberately not moved here.
model = PeftModel.from_pretrained(model, model_path)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Fall back to the EOS token for padding when the tokenizer defines no pad
# token, so `tokenizer.pad_token_id` is never None.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def remove_header_trailer(input, eos_token="<|endoftext|>"):
    """Extract the generated solution from decoded prompt + completion text.

    The decoded text echoes the prompt. Everything up to and including the
    first whitespace-separated ``#Solution:`` word is discarded. The kept
    text then runs until the first of:

    * a ``Solution:`` or ``#Question:`` word (the model starting another
      question/answer pair);
    * a word that is immediately repeated (a crude guard against degenerate
      repetition). Words made only of brackets, such as ``}``, are exempt,
      because nested R blocks legitimately end with ``}`` followed by ``}``;
    * a word containing the literal ``eos_token`` marker (any text glued
      directly in front of the marker is kept);
    * the end of the text.

    The original whitespace between kept words, including line breaks, is
    preserved so the returned code keeps its layout.

    Args:
        input: Decoded prompt + completion text.
        eos_token: Literal end-of-sequence marker to cut at. Pass ``""`` or
            ``None`` to disable this check.

    Returns:
        The extracted solution text, or ``""`` when there is no
        ``#Solution:`` word or nothing is kept after it.
    """
    stop_words = ("Solution:", "#Question:")
    bracket_chars = set("{}()[]")

    # Character spans of each whitespace-separated word (the same words that
    # str.split() yields), so the result can be sliced from the original text.
    spans = [match.span() for match in re.finditer(r"\S+", input)]
    words = [input[begin:finish] for begin, finish in spans]

    if "#Solution:" not in words:
        return ""
    start = words.index("#Solution:") + 1

    end = None  # character offset just past the last kept text
    for pos in range(start, len(words)):
        word = words[pos]
        if word in stop_words:
            break
        if eos_token and eos_token in word:
            head_len = word.index(eos_token)
            if head_len:
                end = spans[pos][0] + head_len
            break
        is_repeat = pos + 1 < len(words) and word == words[pos + 1]
        if is_repeat and not set(word) <= bracket_chars:
            break
        end = spans[pos][1]

    if end is None:
        return ""
    return input[spans[start][0]:end]
def generate(inputs, max_new_tokens=100):
    """Generate R code for a natural-language task description.

    Args:
        inputs: Description of the function to write. It is inserted into
            the ``#Question:`` line of the prompt.
        max_new_tokens: Upper bound on the number of generated tokens
            (default 100, the value previously hard-coded).

    Returns:
        The solution text extracted from the decoded output by
        ``remove_header_trailer``.
    """
    prompt = (
        "Write a code as R programmer.\n"
        "#Context: You are a R Programmer going for an interview you need to provide code snippet for the given question in R programming Language.\n"
        f"#Question: create a function to {inputs} in R language\n"
        "#Solution: "
    )
    # Place the encoded prompt wherever the model's weights live, rather than
    # relying on a separate `device` global.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
        )
    # Drop special-token text (e.g. the EOS marker) so it cannot end up glued
    # to the last word of the generated code.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return remove_header_trailer(decoded)
# Gradio UI: one free-text task description in, generated R code out.
demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(
        lines=5,
        placeholder="write your program details to generate code in R",
        label="Code Prompt",
    ),
    outputs=gr.Textbox(
        lines=5,
        placeholder="Code will be generated here",
        label="R Code",
    ),
    title="R Programming Language Code Generator",
    description="Code is being generated using Starcoder2-3b llm fine tuned on Kaggle using R dataset",
    article="Created and Maintained By Prasanna Dhungana",
)
demo.launch()