peter szemraj committed · Commit 88b1e11 · 1 Parent(s): 91d1162
format

app.py CHANGED
@@ -11,6 +11,7 @@ logging.basicConfig(
 
 use_gpu = torch.cuda.is_available()
 
+
 def generate_text(
     prompt: str,
     gen_length=64,
@@ -40,7 +41,7 @@ def generate_text(
     st = time.perf_counter()
 
     input_tokens = generator.tokenizer(prompt)
-    input_len = len(input_tokens['input_ids'])
+    input_len = len(input_tokens["input_ids"])
     if input_len > abs_max_length:
         logging.info(f"Input too long {input_len} > {abs_max_length}, may cause errors")
     result = generator(
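For context on the reformatted line: a Hugging Face tokenizer call returns a dict-like BatchEncoding, and the token count lives under the "input_ids" key. A minimal sketch of the length check, assuming the transformers library and the Space's model tag:

    from transformers import AutoTokenizer

    # the tokenizer returns a dict-like BatchEncoding; "input_ids" holds the token ids
    tokenizer = AutoTokenizer.from_pretrained("postbot/distilgpt2-emailgen")
    encoded = tokenizer("Hello, just a quick note to")
    input_len = len(encoded["input_ids"])  # token count, not character count
    print(input_len)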
@@ -55,9 +56,8 @@ def generate_text(
         early_stopping=True,
         # tokenizer
         truncation=True,
-    )
-
-    response = result[0]['generated_text']
+    )  # generate
+    response = result[0]["generated_text"]
     rt = time.perf_counter() - st
     if verbose:
         logging.info(f"Generated text: {response}")
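The result[0]["generated_text"] indexing works because a text-generation pipeline returns a list with one dict per generated sequence. A sketch of the call pattern, with illustrative generation parameters rather than the Space's exact defaults:

    from transformers import pipeline

    generator = pipeline("text-generation", model="postbot/distilgpt2-emailgen")
    result = generator(
        "Hello, following up on our call, I",
        max_new_tokens=64,  # illustrative values
        num_beams=4,
        no_repeat_ngram_size=2,
        early_stopping=True,
        truncation=True,
    )
    # list of dicts, one per returned sequence
    print(result[0]["generated_text"])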
@@ -74,12 +74,12 @@ def get_parser():
     )
 
     parser.add_argument(
-        '-m',
-        '--model',
+        "-m",
+        "--model",
         required=False,
         type=str,
         default="postbot/distilgpt2-emailgen",
-        help='Pass an different huggingface model tag to use a custom model',
+        help="Pass an different huggingface model tag to use a custom model",
     )
 
     parser.add_argument(
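The "-m"/"--model" flag resolves to args.model after parsing. A short usage sketch of the argument as added here (the tag passed on the command line is a hypothetical example):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen",
        help="Pass a different huggingface model tag to use a custom model",
    )
    args = parser.parse_args(["-m", "some-user/some-emailgen-model"])  # hypothetical tag
    print(args.model)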
@@ -91,6 +91,7 @@ def get_parser():
     )
     return parser
 
+
 default_prompt = """
 Hello,
 
@@ -109,7 +110,6 @@ if __name__ == "__main__":
         device=0 if use_gpu else -1,
     )
 
-
     demo = gr.Blocks()
 
     logging.info("launching interface...")
@@ -119,7 +119,9 @@ if __name__ == "__main__":
         gr.Markdown(
             "Enter part of an email, and the model will autocomplete it for you!"
         )
-        gr.Markdown("The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)")
+        gr.Markdown(
+            "The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)"
+        )
         gr.Markdown("---")
 
         with gr.Column():
@@ -151,10 +153,11 @@ if __name__ == "__main__":
                 value=2,
             )
             length_penalty = gr.Slider(
-                minimum=0.5, maximum=1.0, label='length penalty', default=0.8, step=0.05
+                minimum=0.5, maximum=1.0, label="length penalty", default=0.8, step=0.05
             )
             generated_email = gr.Textbox(
-                label="Generated Result",
+                label="Generated Result",
+                placeholder="The completed email will appear here",
             )
 
             generate_button = gr.Button(
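One thing to note about the slider line: default= is the Gradio 2.x keyword for a component's initial value; in Gradio 3+ that parameter is value= (as the value=2 context line above already uses), so depending on the installed version this argument may be ignored or warn. A sketch of the 3.x spelling, assuming Gradio 3:

    import gradio as gr

    # Gradio 3.x spelling: the initial value is `value=`, not `default=`
    length_penalty = gr.Slider(
        minimum=0.5, maximum=1.0, step=0.05, value=0.8, label="length penalty"
    )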
@@ -168,16 +171,24 @@ if __name__ == "__main__":
         gr.Markdown(
             "This model is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset."
         )
-        gr.Markdown("The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements before accepting/sending something.")
+        gr.Markdown(
+            "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements before accepting/sending something."
+        )
         gr.Markdown("---")
 
         generate_button.click(
             fn=generate_text,
-            inputs=[prompt_text, num_gen_tokens, num_beams, no_repeat_ngram_size, length_penalty],
+            inputs=[
+                prompt_text,
+                num_gen_tokens,
+                num_beams,
+                no_repeat_ngram_size,
+                length_penalty,
+            ],
             outputs=[generated_email],
         )
 
     demo.launch(
         enable_queue=True,
-        share=True,
+        share=True,  # for local testing
     )
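The reflowed inputs=[...] list is positional: each component's value is passed to generate_text in the order given, so the list order must match the function signature. A self-contained sketch of the same wiring pattern (component names here are hypothetical, and enable_queue is the Gradio 3.x launch flag, since replaced by demo.queue()):

    import gradio as gr

    # each component in `inputs` maps positionally onto the fn's parameters
    def echo(text: str, times: float) -> str:
        return text * int(times)

    with gr.Blocks() as demo:
        box = gr.Textbox(label="prompt")
        times = gr.Slider(minimum=1, maximum=5, step=1, value=2)
        out = gr.Textbox(label="result")
        btn = gr.Button("run")
        btn.click(fn=echo, inputs=[box, times], outputs=[out])

    demo.launch(enable_queue=True, share=True)  # share=True gives a temporary public URL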