PHBJT committed on
Commit
80cff80
1 Parent(s): 44bb8b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -26
app.py CHANGED
@@ -11,10 +11,13 @@ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
11
  # Device setup
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
- # SmolLM Instruct setup
15
- checkpoint = "HuggingFaceTB/SmolLM-360M-Instruct"
16
- smol_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
17
- smol_model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
 
 
 
18
 
19
  # Original model setup
20
  repo_id = "ylacombe/p-m-e"
@@ -50,39 +53,28 @@ def format_description(raw_description, do_format=True):
50
  return raw_description
51
 
52
  messages = [{
53
- "role": "system",
54
- "content": "You are a helpful assistant that formats voice descriptions precisely according to the template provided."
55
- }, {
56
- "role": "user",
57
  "content": f"""Format this voice description exactly as:
58
  "a [gender] with a [pitch] voice speaks [speed] in a [environment], [delivery style]"
59
 
60
  Required format:
61
- - gender: man/woman
62
- - pitch: slightly low-pitched/moderate pitch/high-pitched
63
- - speed: slowly/moderately/quickly
64
- - environment: close-sounding and clear/distant-sounding and noisy
65
- - delivery style: with monotone delivery/with animated delivery
66
 
67
- Input description: {raw_description}
68
 
69
  Return only the formatted description, nothing else."""
70
  }]
71
 
72
- input_text = smol_tokenizer.apply_chat_template(messages, tokenize=False)
73
- inputs = smol_tokenizer.encode(input_text, return_tensors="pt").to(device)
74
- outputs = smol_model.generate(
75
- inputs,
76
- max_new_tokens=100,
77
- temperature=0.2,
78
- top_p=0.9,
79
- do_sample=True
80
- )
81
- formatted = smol_tokenizer.decode(outputs[0], skip_special_tokens=True)
82
 
83
- # Extract just the formatted description
84
  if "a woman" in formatted.lower() or "a man" in formatted.lower():
85
- return formatted.strip()
86
  return raw_description
87
 
88
  def preprocess(text):
 
11
  # Device setup
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
+ # Gemma setup
15
+ pipe = pipeline(
16
+ "text-generation",
17
+ model="google/gemma-2-2b-it",
18
+ model_kwargs={"torch_dtype": torch.bfloat16},
19
+ device=device
20
+ )
21
 
22
  # Original model setup
23
  repo_id = "ylacombe/p-m-e"
 
53
  return raw_description
54
 
55
  messages = [{
56
+ "role": "user",
 
 
 
57
  "content": f"""Format this voice description exactly as:
58
  "a [gender] with a [pitch] voice speaks [speed] in a [environment], [delivery style]"
59
 
60
  Required format:
61
+ - gender must be: man/woman
62
+ - pitch must be: slightly low-pitched/moderate pitch/high-pitched
63
+ - speed must be: slowly/moderately/quickly
64
+ - environment must be: close-sounding and clear/distant-sounding and noisy
65
+ - delivery style must be: with monotone delivery/with animated delivery
66
 
67
+ Input: {raw_description}
68
 
69
  Return only the formatted description, nothing else."""
70
  }]
71
 
72
+ outputs = pipe(messages, max_new_tokens=100)
73
+ formatted = outputs[0]["generated_text"][-1]["content"].strip()
 
 
 
 
 
 
 
 
74
 
75
+ # Validate and extract formatted description
76
  if "a woman" in formatted.lower() or "a man" in formatted.lower():
77
+ return formatted
78
  return raw_description
79
 
80
  def preprocess(text):