import streamlit as st

from grouped_sampling import GroupedSamplingPipeLine
from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names

SUPPORTED_MODEL_NAMES = get_supported_model_names()


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    print(f"Started downloading model: {model_name} from the internet.")
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=True,  # stop generating at an end-of-sentence token
        temp=0.5,  # sampling temperature
        top_p=0.6,  # nucleus (top-p) sampling threshold
    )
    print(f"Finished downloading model: {model_name} from the internet.")
    return pipeline


def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Generates text using the given pipeline.
    :param pipeline: The GroupedSamplingPipeLine to generate with.
    :param prompt: The prompt to complete.
    :param output_length: The length of the text to generate, in tokens. Must be > 0.
    :return: The generated text.
    """
    # Rewrite the raw prompt into a clearer instruction before generating.
    better_prompt = rewrite_prompt(prompt)
    return pipeline(
        prompt_s=better_prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )["generated_text"]


@st.cache
def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
"""
Called when the user submits the form.
:param model_name: The name of the model to use.
:param output_length: The size of the groups to use.
:param prompt: The prompt to use.
:return: The output of the model.
:raises ValueError: If the model name is not supported, the output length is <= 0,
the prompt is empty or longer than
16384 characters, or the output length is not an integer.
TypeError: If the output length is not an integer or the prompt is not a string.
RuntimeError: If the model is not found.
"""
    if not isinstance(prompt, str):
        raise TypeError("The prompt must be a string.")
    if not isinstance(output_length, int):
        raise TypeError("The output length must be an integer.")
    if model_name not in SUPPORTED_MODEL_NAMES:
        raise ValueError(
            f"The selected model {model_name} is not supported. "
            f"Supported models are all the models in: "
            f"https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
        )
    if output_length <= 0:
        raise ValueError(f"The output length {output_length} must be > 0.")
    if len(prompt) == 0:
        raise ValueError("The prompt must not be empty.")
    if len(prompt) > 16384:
        raise ValueError("The prompt must be at most 16384 characters long.")
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    return generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
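

# A minimal usage sketch, not part of the original module: one way
# on_form_submit could be wired into a Streamlit form. The widget labels,
# form key, and default values below are illustrative assumptions.
if __name__ == "__main__":
    with st.form("text_generation_form"):
        selected_model = st.selectbox("Model", SUPPORTED_MODEL_NAMES)
        requested_length = st.number_input(
            "Output length (tokens)", min_value=1, max_value=1024, value=100
        )
        user_prompt = st.text_area("Prompt", max_chars=16384)
        submitted = st.form_submit_button("Generate")
    if submitted:
        st.write(on_form_submit(selected_model, int(requested_length), user_prompt))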