|
import streamlit as st |
|
from grouped_sampling import GroupedSamplingPipeLine |
|
|
|
from prompt_engeneering import rewrite_prompt |
|
from supported_models import get_supported_model_names |
|
|
|
|
|
# Fetched once when this module is imported; on_form_submit checks the
# requested model name against this collection.
SUPPORTED_MODEL_NAMES = get_supported_model_names()
|
|
|
|
|
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Build a grouped-sampling pipeline for a single model.

    A progress message is printed before and after the pipeline is constructed.
    Apart from the two arguments, the pipeline settings are fixed:
    end_of_sentence_stop=True, temp=0.5 and top_p=0.6.

    :param model_name: The name of the model the pipeline should use.
    :param group_size: The group size handed to the pipeline.
    :return: The newly constructed pipeline.
    """
    # Every pipeline built here shares the same sampling configuration.
    pipeline_settings = {
        "model_name": model_name,
        "group_size": group_size,
        "end_of_sentence_stop": True,
        "temp": 0.5,
        "top_p": 0.6,
    }
    print(f"Starts downloading model: {model_name} from the internet.")
    grouped_pipeline = GroupedSamplingPipeLine(**pipeline_settings)
    print(f"Finished downloading model: {model_name} from the internet.")
    return grouped_pipeline
|
|
|
|
|
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Produce text for a prompt with an already-built pipeline.

    The prompt is first passed through ``rewrite_prompt``; the rewritten prompt
    is what the pipeline actually receives.

    :param pipeline: The pipeline to call. GroupedSamplingPipeLine.
    :param prompt: The raw prompt. str.
    :param output_length: Passed to the pipeline as ``max_new_tokens``. int > 0.
    :return: The ``"generated_text"`` entry of the pipeline's result. str.
    """
    rewritten = rewrite_prompt(prompt)
    # The pipeline is asked for text (return_text=True) and not for the full
    # text (return_full_text=False); only the generated part is handed back.
    result = pipeline(
        prompt_s=rewritten,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )
    return result["generated_text"]
|
|
|
|
|
class _ArgumentTypeError(TypeError, ValueError):
    """
    Raised by on_form_submit when an argument has the wrong type.

    It derives from both TypeError and ValueError so that callers catching
    either of the two keep working.
    """


# NOTE(review): st.cache is deprecated in recent Streamlit releases in favour of
# st.cache_data / st.cache_resource - confirm against the Streamlit version in use.
@st.cache
def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
    """
    Validate the submitted form values, build a pipeline and generate text.

    :param model_name: The name of the model to use.
        Must be one of SUPPORTED_MODEL_NAMES.
    :param output_length: The number of tokens to generate. The same value is
        also used as the group size of the pipeline. int > 0.
    :param prompt: The prompt to use. Non-empty str.
    :return: The output of the model.
    :raises ValueError: If the model name is not supported,
        the output length is <= 0 or the prompt is empty.
    :raises TypeError: If the output length is not an integer or the prompt is
        not a string. The raised exception is also a ValueError.

    NOTE(review): earlier documentation mentioned a 16384 character limit on
    the prompt and a RuntimeError when the model is not found. Neither is
    checked or raised in this function; any such error would have to come from
    the pipeline itself - confirm before relying on it.
    """
    if model_name not in SUPPORTED_MODEL_NAMES:
        raise ValueError(
            f"The selected model {model_name} is not supported. "
            "Supported models are all the models in:"
            " https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
        )
    # Types are checked before values, so the comparisons below never run on
    # an object of the wrong type and fail with an unrelated error.
    if not isinstance(prompt, str):
        raise _ArgumentTypeError("The prompt must be a string.")
    if not isinstance(output_length, int):
        raise _ArgumentTypeError("The output length must be an integer.")
    if output_length <= 0:
        raise ValueError(f"The output length {output_length} must be > 0.")
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    # The whole output is generated as a single group: group size == output length.
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    return generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
|
|