import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from typing import Iterable
class SQLGEN(Base):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.green,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

model_id = "alibidaran/Gemma2_SQLGEN"

# bnb_config: optional 4-bit quantization for GPU usage (left disabled here)
#bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.bfloat16
#)
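# To enable 4-bit loading (assumption: a CUDA GPU and the bitsandbytes package are
# available), uncomment bnb_config above and pass quantization_config=bnb_config
# in the from_pretrained call below.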
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
tokenizer.padding_side = 'right'

def generate_sql(query, context):
    # Build the prompt in the ##Question / ##Context / ##Answer format expected by the model.
    text = f"<s>##Question: {query} \n ##Context: {context} \n ##Answer:"
    # Move inputs to whichever device device_map='auto' placed the model on (CPU or GPU).
    inputs = tokenizer(text, return_tensors='pt').to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.99,
            top_k=10,
            temperature=0.5,
        )
    # Drop the prompt tokens and decode only the newly generated SQL.
    output_text = outputs[:, inputs.input_ids.shape[1]:]
    return tokenizer.decode(output_text[0], skip_special_tokens=True)
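# Example call (hypothetical inputs, for illustration only):
#   generate_sql(
#       "How many singers do we have?",
#       "CREATE TABLE singer (singer_id INT, name TEXT)",
#   )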

interface = gr.Interface(generate_sql, ['text', 'text'], gr.Code(), title='SQLGEN', theme=SQLGEN())

if __name__ == '__main__':
    interface.launch()
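# To try this locally (assuming the file is saved as app.py): run `python app.py`
# and open the local URL that Gradio prints.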