yonikremer
commited on
Commit
•
0499581
1
Parent(s):
6bcf2e3
added a checkbox that can disable the web search
Browse files- app.py +12 -1
- hanlde_form_submit.py +14 -2
app.py
CHANGED
@@ -40,6 +40,12 @@ with st.form("request_form"):
|
|
40 |
max_chars=2048,
|
41 |
)
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
submitted: bool = st.form_submit_button(
|
44 |
label="Generate",
|
45 |
help="Generate the output text.",
|
@@ -48,7 +54,12 @@ with st.form("request_form"):
|
|
48 |
|
49 |
if submitted:
|
50 |
try:
|
51 |
-
output = on_form_submit(
|
|
|
|
|
|
|
|
|
|
|
52 |
except CudaError as e:
|
53 |
st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
|
54 |
except (ValueError, TypeError, RuntimeError) as e:
|
|
|
40 |
max_chars=2048,
|
41 |
)
|
42 |
|
43 |
+
web_search: bool = st.checkbox(
|
44 |
+
label="Web search",
|
45 |
+
value=True,
|
46 |
+
help="If checked, the model will get your prompt as well as some web search results."
|
47 |
+
)
|
48 |
+
|
49 |
submitted: bool = st.form_submit_button(
|
50 |
label="Generate",
|
51 |
help="Generate the output text.",
|
|
|
54 |
|
55 |
if submitted:
|
56 |
try:
|
57 |
+
output = on_form_submit(
|
58 |
+
selected_model_name,
|
59 |
+
output_length,
|
60 |
+
submitted_prompt,
|
61 |
+
web_search,
|
62 |
+
)
|
63 |
except CudaError as e:
|
64 |
st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
|
65 |
except (ValueError, TypeError, RuntimeError) as e:
|
hanlde_form_submit.py
CHANGED
@@ -52,15 +52,20 @@ def generate_text(
|
|
52 |
pipeline: GroupedSamplingPipeLine,
|
53 |
prompt: str,
|
54 |
output_length: int,
|
|
|
55 |
) -> str:
|
56 |
"""
|
57 |
Generates text using the given pipeline.
|
58 |
:param pipeline: The pipeline to use. GroupedSamplingPipeLine.
|
59 |
:param prompt: The prompt to use. str.
|
60 |
:param output_length: The size of the text to generate in tokens. int > 0.
|
|
|
61 |
:return: The generated text. str.
|
62 |
"""
|
63 |
-
|
|
|
|
|
|
|
64 |
return pipeline(
|
65 |
prompt_s=better_prompt,
|
66 |
max_new_tokens=output_length,
|
@@ -69,12 +74,18 @@ def generate_text(
|
|
69 |
)["generated_text"]
|
70 |
|
71 |
|
72 |
-
def on_form_submit(
|
|
|
|
|
|
|
|
|
|
|
73 |
"""
|
74 |
Called when the user submits the form.
|
75 |
:param model_name: The name of the model to use.
|
76 |
:param output_length: The size of the groups to use.
|
77 |
:param prompt: The prompt to use.
|
|
|
78 |
:return: The output of the model.
|
79 |
:raises ValueError: If the model name is not supported, the output length is <= 0,
|
80 |
the prompt is empty or longer than
|
@@ -99,6 +110,7 @@ def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
|
|
99 |
pipeline=pipeline,
|
100 |
prompt=prompt,
|
101 |
output_length=output_length,
|
|
|
102 |
)
|
103 |
generation_end_time = time()
|
104 |
generation_time = generation_end_time - generation_start_time
|
|
|
52 |
pipeline: GroupedSamplingPipeLine,
|
53 |
prompt: str,
|
54 |
output_length: int,
|
55 |
+
web_search: bool,
|
56 |
) -> str:
|
57 |
"""
|
58 |
Generates text using the given pipeline.
|
59 |
:param pipeline: The pipeline to use. GroupedSamplingPipeLine.
|
60 |
:param prompt: The prompt to use. str.
|
61 |
:param output_length: The size of the text to generate in tokens. int > 0.
|
62 |
+
:param web_search: Whether to use web search or not. bool.
|
63 |
:return: The generated text. str.
|
64 |
"""
|
65 |
+
if web_search:
|
66 |
+
better_prompt = rewrite_prompt(prompt)
|
67 |
+
else:
|
68 |
+
better_prompt = prompt
|
69 |
return pipeline(
|
70 |
prompt_s=better_prompt,
|
71 |
max_new_tokens=output_length,
|
|
|
74 |
)["generated_text"]
|
75 |
|
76 |
|
77 |
+
def on_form_submit(
|
78 |
+
model_name: str,
|
79 |
+
output_length: int,
|
80 |
+
prompt: str,
|
81 |
+
web_search: bool
|
82 |
+
) -> str:
|
83 |
"""
|
84 |
Called when the user submits the form.
|
85 |
:param model_name: The name of the model to use.
|
86 |
:param output_length: The size of the groups to use.
|
87 |
:param prompt: The prompt to use.
|
88 |
+
:param web_search: Whether to use web search or not.
|
89 |
:return: The output of the model.
|
90 |
:raises ValueError: If the model name is not supported, the output length is <= 0,
|
91 |
the prompt is empty or longer than
|
|
|
110 |
pipeline=pipeline,
|
111 |
prompt=prompt,
|
112 |
output_length=output_length,
|
113 |
+
web_search=web_search,
|
114 |
)
|
115 |
generation_end_time = time()
|
116 |
generation_time = generation_end_time - generation_start_time
|