AlexWortega commited on
Commit
bff356a
1 Parent(s): aea90af

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +11 -3
main.py CHANGED
@@ -11,6 +11,7 @@ from uuid import uuid4
11
  # Load the necessary models and tokenizers
12
  model_path = "Vikhrmodels/salt-116k"
13
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
14
  # Специальные токены
15
  start_audio_token = "<soa>"
16
  end_audio_token = "<eoa>"
@@ -98,9 +99,14 @@ def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
98
 
99
 
100
  # Inference functions
101
- def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
102
- print(text)
 
 
 
103
  print(type(tokenizer))
 
 
104
  text_tokenized = tokenizer(str(text), return_tensors="pt")
105
  text_input_tokens = text_tokenized["input_ids"].to(device)
106
 
@@ -117,7 +123,9 @@ def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024,
117
 
118
  return audio_signal
119
 
120
- def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
 
 
121
  audio_data, sample_rate = torchaudio.load(audio_path)
122
 
123
  audio = audio_data.view(1, 1, -1).float().to(device)
 
11
  # Load the necessary models and tokenizers
12
  model_path = "Vikhrmodels/salt-116k"
13
  tokenizer = AutoTokenizer.from_pretrained(model_path)
14
+ print(tokenizer)
15
  # Специальные токены
16
  start_audio_token = "<soa>"
17
  end_audio_token = "<eoa>"
 
99
 
100
 
101
  # Inference functions
102
+ def infer_text_to_audio(text):
103
+
104
+ max_seq_length=1024
105
+ top_k=20
106
+
107
  print(type(tokenizer))
108
+ print(text)
109
+
110
  text_tokenized = tokenizer(str(text), return_tensors="pt")
111
  text_input_tokens = text_tokenized["input_ids"].to(device)
112
 
 
123
 
124
  return audio_signal
125
 
126
+ def infer_audio_to_text(audio_path):
127
+ max_seq_length=1024
128
+ top_k=20
129
  audio_data, sample_rate = torchaudio.load(audio_path)
130
 
131
  audio = audio_data.view(1, 1, -1).float().to(device)