import pytest as pytest | |
from grouped_sampling import GroupedSamplingPipeLine | |
from available_models import AVAILABLE_MODELS | |
from hanlde_form_submit import create_pipeline, on_form_submit | |
def test_on_form_submit():
    """Check ``on_form_submit`` for a normal prompt and for an empty one.

    Covers two cases:

    * A non-empty prompt with ``gpt2`` and ``output_length=10`` must return
      something that is not ``None`` and has a positive length.
    * An empty prompt must make ``on_form_submit`` raise ``ValueError``.
    """
    model = "gpt2"
    length = 10
    question = "Answer yes or no, is the sky blue?"

    # Happy path: a real prompt must yield a non-empty result.
    result = on_form_submit(model, length, question)
    assert result is not None
    assert len(result) > 0

    # Invalid input: an empty prompt is expected to be rejected.
    blank = ""
    with pytest.raises(ValueError):
        on_form_submit(model, length, blank)
# Without parametrization pytest would treat ``model_name`` as a fixture name.
# No such fixture exists, so the test would error out instead of running.
# NOTE(review): assumes AVAILABLE_MODELS is an iterable of model-name strings
# accepted by create_pipeline — confirm against available_models.py.
@pytest.mark.parametrize("model_name", AVAILABLE_MODELS)
def test_create_pipeline(model_name: str):
    """Check that ``create_pipeline`` builds a correctly configured pipeline.

    Args:
        model_name: Name of the model to build a pipeline for. Supplied by
            the ``parametrize`` decorator, one case per entry in
            ``AVAILABLE_MODELS``.

    Asserts that the pipeline:

    * exists;
    * reports the requested ``model_name``;
    * has a wrapped model with ``group_size == 5``;
    * has a wrapped model with ``end_of_sentence_stop`` set to ``False``.
    """
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name)
    assert pipeline is not None
    assert pipeline.model_name == model_name
    assert pipeline.wrapped_model.group_size == 5
    assert pipeline.wrapped_model.end_of_sentence_stop is False
    # Drop the local reference before the next parametrized case.
    # Presumably this lets a large model be garbage-collected early — verify.
    del pipeline
if __name__ == "__main__":
    # Run only this module's tests.
    # Forward pytest's exit status to the shell so that a failing run exits
    # non-zero instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))