abrakjamson commited on
Commit
fb1c1ed
·
1 Parent(s): 2c5c709

New model training interface

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  import re
7
  import tempfile
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
- from repeng import ControlVector, ControlModel
10
  import gradio as gr
11
 
12
  # Initialize model and tokenizer
@@ -35,6 +35,7 @@ if cuda:
35
 
36
  model = ControlModel(model, list(range(-5, -18, -1)))
37
 
 
38
  # Generation settings
39
  default_generation_settings = {
40
  "pad_token_id": tokenizer.eos_token_id,
@@ -382,16 +383,16 @@ def train_model_persona(positive_text, negative_text):
382
  positive_list,
383
  negative_list,
384
  output_suffixes)
385
- # model.reset()
386
- # output_model = ControlVector.train(model, tokenizer, dataset
387
  # Write file to temporary directory returning the path to Gradio for download
388
- filename = re.sub(r'[<>:"/\\|?*]', '', positive_text) + '_'
389
  temp_file = tempfile.NamedTemporaryFile(
390
  prefix=filename,
391
  suffix=".gguf",
392
  delete= False
393
  )
394
- # ControlVector.export_gguf(output_model, temp_file.name)
395
  temp_file.close()
396
  return temp_file.name
397
 
@@ -408,7 +409,7 @@ def train_model_facts(positive_text, negative_text):
408
  )
409
 
410
  output_model = ControlVector.train(model, tokenizer, dataset)
411
- filename = re.sub(r'[<>:"/\\|?*]', '', positive_text) + '_'
412
  temp_file = tempfile.NamedTemporaryFile(
413
  prefix=filename,
414
  suffix=".gguf",
@@ -787,10 +788,10 @@ with gr.Blocks(
787
  gr.Markdown("Fill in the blank with a persona and its opposite within, \"Pretend to be a (persona) making statements about the world.\"")
788
  facts_input_positive = gr.Text(
789
  label="Positive",
790
- placeholder="time traveller from the future")
791
  facts_input_negative = gr.Text(
792
  label="Negative",
793
- placeholder="time travaller from the past")
794
  button_facts = gr.Button(
795
  value="Generate fact control model"
796
  )
 
6
  import re
7
  import tempfile
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+ from repeng import ControlVector, ControlModel, DatasetEntry
10
  import gradio as gr
11
 
12
  # Initialize model and tokenizer
 
35
 
36
  model = ControlModel(model, list(range(-5, -18, -1)))
37
 
38
+ # Generation settings
39
  # Generation settings
40
  default_generation_settings = {
41
  "pad_token_id": tokenizer.eos_token_id,
 
383
  positive_list,
384
  negative_list,
385
  output_suffixes)
386
+ model.reset()
387
+ output_model = ControlVector.train(model, tokenizer, dataset)
388
  # Write file to temporary directory returning the path to Gradio for download
389
+ filename = re.sub(r'[ <>:"/\\|?*]', '', positive_text) + '_'
390
  temp_file = tempfile.NamedTemporaryFile(
391
  prefix=filename,
392
  suffix=".gguf",
393
  delete= False
394
  )
395
+ ControlVector.export_gguf(output_model, temp_file.name)
396
  temp_file.close()
397
  return temp_file.name
398
 
 
409
  )
410
 
411
  output_model = ControlVector.train(model, tokenizer, dataset)
412
+ filename = re.sub(r'[ <>:"/\\|?*]', '', positive_text) + '_'
413
  temp_file = tempfile.NamedTemporaryFile(
414
  prefix=filename,
415
  suffix=".gguf",
 
788
  gr.Markdown("Fill in the blank with a persona and its opposite within, \"Pretend to be a (persona) making statements about the world.\"")
789
  facts_input_positive = gr.Text(
790
  label="Positive",
791
+ placeholder="time traveler from the future")
792
  facts_input_negative = gr.Text(
793
  label="Negative",
794
+ placeholder="time travaler from the past")
795
  button_facts = gr.Button(
796
  value="Generate fact control model"
797
  )