pablo-rf commited on
Commit
fbfc907
1 Parent(s): 3e64466

Add Cerebras model

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -1,9 +1,11 @@
 
1
  import gradio as gr
2
  from gradio.components import Slider
3
  import torch
4
  from transformers import pipeline
5
 
6
  # Model, information and examples ----------------------------------------------
 
7
  model_id = "proxectonos/FLOR-1.3B-GL"
8
  title = "Modelo de xeraci贸n de texto FLOR-1.3B-GL"
9
  markdown_description = """
@@ -27,17 +29,26 @@ few_shot_prompts_examples = [
27
  fronted_theme = 'Soft'
28
 
29
  # Model charge ---------------------------------------------------------
30
- model_id = "proxectonos/FLOR-1.3B-GL"
31
- generator_model = pipeline("text-generation", model=model_id)
 
 
32
 
33
  # Generation functions ---------------------------------------------------------
 
 
 
 
 
 
34
  def remove_empty_lines(text):
35
  lines = text.strip().split("\n")
36
  non_empty_lines = [line for line in lines if line.strip()]
37
  return "\n".join(non_empty_lines)
38
 
39
- def predict(prompt, max_length, repetition_penalty, temperature):
40
  print("Dentro da xeraci贸n...")
 
41
  prompt_length = len(generator_model.tokenizer.encode(prompt))
42
  generated_text = generator_model(
43
  prompt,
@@ -91,7 +102,13 @@ def gradio_app():
91
  gr.HTML('<img src="https://huggingface.co/spaces/proxectonos/README/resolve/main/title-card.png" width="100%" style="border-radius: 0.75rem;">')
92
  with gr.Column():
93
  gr.Markdown(markdown_description)
94
-
 
 
 
 
 
 
95
  with gr.Row(equal_height=True):
96
  with gr.Column():
97
  text_gl = gr.Textbox(label="Input",
@@ -128,7 +145,7 @@ def gradio_app():
128
  pass_btn = gr.Button(value="Pass text to input")
129
  clean_btn = gr.Button(value="Clean")
130
 
131
- generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty, temperature], outputs=generated_gl, api_name="generate-flor-gl")
132
  clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty, temperature], queue=False, api_name=False)
133
  pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl], queue=False, api_name=False)
134
 
 
1
+ import os
2
  import gradio as gr
3
  from gradio.components import Slider
4
  import torch
5
  from transformers import pipeline
6
 
7
  # Model, information and examples ----------------------------------------------
8
+ MODEL_NAMES = ["FLOR-1.3B-GL","Cerebras-1.3B-GL"]
9
  model_id = "proxectonos/FLOR-1.3B-GL"
10
  title = "Modelo de xeraci贸n de texto FLOR-1.3B-GL"
11
  markdown_description = """
 
29
  fronted_theme = 'Soft'
30
 
31
  # Model charge ---------------------------------------------------------
32
+ model_id_flor = "proxectonos/FLOR-1.3B-GL"
33
+ generator_model_flor = pipeline("text-generation", model=model_id_flor)
34
+ model_id_cerebras = "proxectonos/Cerebras-1.3B-GL"
35
+ generator_model_cerebras = pipeline("text-generation", model=model_id_cerebras, token=os.environ['TOKEN_HF'])
36
 
37
  # Generation functions ---------------------------------------------------------
38
+ def get_model(model_selection):
39
+ if model_selection == "FLOR-1.3B-GL":
40
+ return generator_model_flor
41
+ else:
42
+ return generator_model_cerebras
43
+
44
  def remove_empty_lines(text):
45
  lines = text.strip().split("\n")
46
  non_empty_lines = [line for line in lines if line.strip()]
47
  return "\n".join(non_empty_lines)
48
 
49
+ def predict(prompt, model_select, max_length, repetition_penalty, temperature):
50
  print("Dentro da xeraci贸n...")
51
+ generator_model = get_model(model_select)
52
  prompt_length = len(generator_model.tokenizer.encode(prompt))
53
  generated_text = generator_model(
54
  prompt,
 
102
  gr.HTML('<img src="https://huggingface.co/spaces/proxectonos/README/resolve/main/title-card.png" width="100%" style="border-radius: 0.75rem;">')
103
  with gr.Column():
104
  gr.Markdown(markdown_description)
105
+ with gr.Row():
106
+ model_select = gr.Dropdown(
107
+ label="Escolle un modelo:",
108
+ choices=MODEL_NAMES,
109
+ value=MODEL_NAMES[0],
110
+ interactive=True
111
+ )
112
  with gr.Row(equal_height=True):
113
  with gr.Column():
114
  text_gl = gr.Textbox(label="Input",
 
145
  pass_btn = gr.Button(value="Pass text to input")
146
  clean_btn = gr.Button(value="Clean")
147
 
148
+ generator_btn.click(predict, inputs=[text_gl, model_select, max_length, repetition_penalty, temperature], outputs=generated_gl, api_name="generate-flor-gl")
149
  clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty, temperature], queue=False, api_name=False)
150
  pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl], queue=False, api_name=False)
151