# Spaces: Running on Zero
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import spaces | |
# Static UI copy. Each constant is handed to gr.Interface below under the
# keyword of the same role (title / description / article).
title = "ืืืืื ืฉืืจืื"

# Markdown-formatted text: a heading, a link to the model page on the
# Hugging Face Hub, and a one-line hint about how to phrase the request.
DESCRIPTION = """\
# ืฆืจื ืฉืืจืื ืืืืคืฉืื
ืืืืื ืืื [ืคืืื ืืืื ืฉื ืืืื ืืืืื 2 - 2ืืณ](https://huggingface.co/Norod78/hebrew_lyrics-gemma2_2b-unsloth-gguf)
ืืชืื ืคืจืืืคื ืืกืื ืื ืดืืชืื ืื ืืืงืฉื ืฉืืจ ืขื / ืืืชืืจ / ืฉืืืืจ ืขื ____ืด
"""

# Footer text containing a link to the author's page.
article = """\
ืืืืื ืึผืึผืึทึผืื ืขืดื [ืืืจืื ืืืืจ](https://linktr.ee/Norod78)
"""
# Checkpoint to load. A local directory can be used instead:
# model_id = "./hebrew_lyrics-gemma2_2b"
model_id = "Norod78/hebrew_lyrics-gemma2_2b"

# Tokenizer and causal-LM weights are loaded once, at import time, and shared
# by every request handled by generate_song().
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Disabled alternative kept for reference: load the GGUF export and pin it to
# the CPU.
# model_id = "Norod78/hebrew_lyrics-gemma2_2b-unsloth-gguf"
# gguf_file_name = "hebrew_lyrics-gemma2_2b-unsloth.BF16.gguf"
# tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file_name)
# model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, gguf_file=gguf_file_name).to("cpu")

# Seed torch's global RNG once at start-up (generation below samples).
torch.manual_seed(1234)
@spaces.GPU
def generate_song(prompt_text: str = '') -> str:
    """Generate text for a single user request and return it with display labels.

    The request is wrapped as a one-turn chat (role ``user``) using the
    tokenizer's chat template, the model samples up to 384 new tokens, and the
    first decoded sequence is returned after its ``user\\n`` / ``model\\n``
    marker lines are replaced with the UI's display labels.

    The function is decorated with ``spaces.GPU`` because this app is deployed
    on ZeroGPU hardware, where GPU work has to run inside a decorated function;
    per the ``spaces`` documentation the decorator has no effect elsewhere.

    Args:
        prompt_text: The user's request. Defaults to an empty string.

    Returns:
        The decoded model output with role markers replaced by display labels.
    """
    # Render the single-turn conversation to text, ending with the generation
    # prompt so the model continues as the assistant.
    chat_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_text}],
        tokenize=False,
        add_generation_prompt=True,
    )

    with torch.no_grad():
        model_inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
        # Other settings that were tried and left disabled in the original:
        #   max_new_tokens=512, temperature=0.4, top_p=0.95, top_k=40
        #   max_new_tokens=512, temperature=0.5 (no top_p / top_k)
        generated = model.generate(
            **model_inputs,
            max_new_tokens=384,
            repetition_penalty=1.1,
            temperature=0.6,
            top_p=0.4,
            top_k=40,
            do_sample=True,
        )

    # Only one prompt is submitted, so only the first decoded sequence is used.
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    return decoded.replace("user\n", "ืืฉืชืืฉ:\n").replace("model\n", "\nืืืื:\n")
# Single-textbox UI: one request in, the labelled generation out. Both
# textboxes are configured right-to-left.
demo = gr.Interface(
    fn=generate_song,
    inputs=gr.Textbox(lines=1, label="ืืงืฉื ืฉืืจ", rtl=True),
    outputs=gr.Textbox(label="ืืคืื ืฉื ืืืืื", rtl=True),
    title=title,
    description=DESCRIPTION,
    article=article,
    # Ready-made requests offered to the user.
    examples=[
        "ืชื ื ืืฉืืฉ ืืขืืืช, ืืืืงืจ ืืืืืจ",
        "ืืชืื ืื ืืืงืฉื ืฉืืจ ืขื ืชืคืื ืืืื ืขื ืืจืื ืืืจืชืืช",
        "ืฉืืจ ืืืชืืจ ืืช ืืืืื ืฉื ืืืืจืื ืืฆืืข ืกืืื ืขื ืื ืคืืื ืืืืจืคืชืงืืืช ืืืขืืคืคืืช ืฉืืื",
        "ืฉืืจ ืขื ืคืชื ืชืงืืื",
        "ืฉืืจ ืขื ืืื ืงืจื ืืืจืืืื ๐ฆ",
    ],
    allow_flagging="never",
)

# Enable request queuing, then start the app.
demo.queue()
demo.launch()