File size: 4,664 Bytes
8824f88 4354c7d fcdfc0f fbe5614 f6f433d 8824f88 7cb6017 8824f88 fcdfc0f 5946569 8824f88 a5c0568 a174343 8824f88 31bf44d 0737a9d 34353a1 0737a9d 8824f88 abf5e5e 8824f88 fe36abc 8824f88 fe36abc d0af199 2698250 d0af199 8824f88 fe36abc abf5e5e fe36abc 8824f88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
#!/usr/bin/env python
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
HF_TOKEN = os.environ['HF_TOKEN']
DESCRIPTION = """# π GEITje-7B-chat π
## Een groot open Nederlands taalmodel
[_Coming soon_](https://github.com/Rijgersberg/GEITje)"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU π₯Ά This demo does not work on CPU.</p>"
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
if torch.cuda.is_available():
model_id = "Rijgersberg/GEITje-7B-chat"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
@spaces.GPU
def generate(
message: str,
chat_history: list[tuple[str, str]],
max_new_tokens: int = 1024,
temperature: float = 0.06,
top_p: float = 0.95,
top_k: int = 40,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
conversation = []
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
chat_interface = gr.ChatInterface(
fn=generate,
chatbot=gr.Chatbot(height=400),
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.,
maximum=1.2,
step=0.05,
value=0.2,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
examples=[
["""Welk woord hoort er niet in dit rijtje thuis: "auto, vliegtuig, geit, bus"?"""],
["Schrijf een nieuwsbericht voor De Speld over de inzet van een kudde geiten door het Nederlands Forensisch Instituut"],
["Wat zijn leuke dingen om te doen als ik een weekendje naar Friesland ga?"],
["Kan je naar de maan fietsen?"],
["Wat is het belang van open source taalmodellen?"],
],
title="π GEITje 7B Chat",
description="""Een eerste chatbot op basis van GEITje 7B: een groot open Nederlands taalmodel.
Dit is een chatbot gebaseerd op GEITje 7B, gemaakt voor demonstratiedoeleinden. Generatieve taalmodellen maken fouten, controleer daarom feiten voordat je ze overneemt. GEITJje Chat is niet uitgebreid getraind om _gealigned_ te zijn met menselijke waarden. Het is daarom mogelijk dat het problematische output genereert, zeker als het daartoe ge_prompt_ wordt.
Voor meer info over GEITJje: zie de <a href="https://github.com/Rijgersberg/GEITje">π README op GitHub</a>.""",
submit_btn="Genereer",
stop_btn="Stop",
retry_btn="π Opnieuw",
undo_btn="β©οΈ Ongedaan maken",
clear_btn="ποΈ Wissen",
)
with gr.Blocks(css="style.css") as demo:
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()
|