yonikremer
committed on
Commit
·
30f253f
1
Parent(s):
c7b7b1d
changed pipeline's parameters
Browse files- hanlde_form_submit.py +10 -6
hanlde_form_submit.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
from grouped_sampling import GroupedSamplingPipeLine
|
3 |
|
4 |
-
|
5 |
from supported_models import get_supported_model_names
|
6 |
|
7 |
|
@@ -19,15 +18,17 @@ def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
|
|
19 |
model_name=model_name,
|
20 |
group_size=group_size,
|
21 |
end_of_sentence_stop=True,
|
|
|
|
|
22 |
)
|
23 |
|
24 |
|
25 |
@st.cache
|
26 |
-
def on_form_submit(model_name: str,
|
27 |
"""
|
28 |
Called when the user submits the form.
|
29 |
:param model_name: The name of the model to use.
|
30 |
-
:param
|
31 |
:param prompt: The prompt to use.
|
32 |
:return: The output of the model.
|
33 |
"""
|
@@ -36,7 +37,10 @@ def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
|
|
36 |
f"Supported models are all the models in:"
|
37 |
f" https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch")
|
38 |
pipeline = create_pipeline(
|
39 |
-
model_name,
|
40 |
-
group_size,
|
41 |
)
|
42 |
-
return pipeline(
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from grouped_sampling import GroupedSamplingPipeLine
|
3 |
|
|
|
4 |
from supported_models import get_supported_model_names
|
5 |
|
6 |
|
|
|
18 |
model_name=model_name,
|
19 |
group_size=group_size,
|
20 |
end_of_sentence_stop=True,
|
21 |
+
temp=0.5,
|
22 |
+
top_p=0.6,
|
23 |
)
|
24 |
|
25 |
|
26 |
@st.cache
|
27 |
+
def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
|
28 |
"""
|
29 |
Called when the user submits the form.
|
30 |
:param model_name: The name of the model to use.
|
31 |
+
:param output_length: The size of the groups to use.
|
32 |
:param prompt: The prompt to use.
|
33 |
:return: The output of the model.
|
34 |
"""
|
|
|
37 |
f"Supported models are all the models in:"
|
38 |
f" https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch")
|
39 |
pipeline = create_pipeline(
|
40 |
+
model_name=model_name,
|
41 |
+
group_size=output_length,
|
42 |
)
|
43 |
+
return pipeline(
|
44 |
+
prompt_s=prompt,
|
45 |
+
max_new_tokens=output_length,
|
46 |
+
)["generated_text"]
|