# Spaces: Running
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer | |
import gradio as gr | |
from gradio.themes.base import Base | |
from gradio.themes.utils import colors, fonts, sizes | |
from typing import Iterable | |
class SQLGEN(Base):
    """Gradio theme used by the SQLGEN demo.

    The class adds no behaviour of its own: every keyword argument is passed
    straight through to ``Base.__init__``. Its only customisation is the set
    of default hues, sizes and font stacks declared in the constructor
    signature.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.green,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        """Initialise the theme; all arguments are keyword-only.

        Each argument is forwarded unchanged to ``Base.__init__`` under the
        same name.
        """
        # Gather the options once so the hand-off to the base class is a
        # single, easily scanned call.
        theme_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**theme_options)
# Checkpoint identifier passed to both ``from_pretrained`` calls below.
model_id = "alibidaran/Gemma2_SQLGEN"

# Optional 4-bit quantisation settings for GPU usage (currently disabled):
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )

# Load the tokenizer and configure its padding side right away.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"

# device_map="auto" leaves device placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
def generate_sql(query: str, context: str) -> str:
    """Generate an answer for a question, given accompanying context text.

    Builds a ``##Question / ##Context / ##Answer`` prompt, samples up to 100
    new tokens from the module-level ``model`` and returns only the newly
    generated text.

    Args:
        query: Text inserted after ``##Question:`` in the prompt.
        context: Text inserted after ``##Context:`` in the prompt
            (presumably the table schema -- confirm against the model card).

    Returns:
        The decoded continuation with special tokens stripped; the prompt
        itself is not included.
    """
    text = f"<s>##Question: {query} \n ##Context: {context} \n ##Answer:"

    # Send the inputs to whatever device the model actually lives on. The
    # model is loaded with device_map='auto', so hard-coding 'cuda' here
    # crashes on CPU-only hosts.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.99,
            top_k=10,
            temperature=0.5,
        )

    # Slice off the prompt tokens so only the continuation is decoded.
    prompt_length = inputs.input_ids.shape[1]
    return tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
# Two free-text inputs are passed to generate_sql; the result is rendered in
# a code component, styled with the SQLGEN theme.
interface = gr.Interface(
    fn=generate_sql,
    inputs=["text", "text"],
    outputs=gr.Code(),
    title="SQLGEN",
    theme=SQLGEN(),
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch()