OpenSound committed
Commit 3fd1350 · 1 Parent(s): 1b1262d

Update audiocraft/audiocraft/modules/conditioners.py

audiocraft/audiocraft/modules/conditioners.py CHANGED
@@ -18,7 +18,7 @@ import warnings
 
 import einops
 from num2words import num2words
-import spacy
+# import spacy
 from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
@@ -185,72 +185,72 @@ class Tokenizer:
         raise NotImplementedError()
 
 
-class WhiteSpaceTokenizer(Tokenizer):
-    """This tokenizer should be used for natural language descriptions.
-    For example:
-    ["he didn't, know he's going home.", 'shorter sentence'] =>
-    [[78, 62, 31, 4, 78, 25, 19, 34],
-    [59, 77, 0, 0, 0, 0, 0, 0]]
-    """
-    PUNCTUATION = "?:!.,;"
-
-    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
-                 lemma: bool = True, stopwords: bool = True) -> None:
-        self.n_bins = n_bins
-        self.pad_idx = pad_idx
-        self.lemma = lemma
-        self.stopwords = stopwords
-        try:
-            self.nlp = spacy.load(language)
-        except IOError:
-            spacy.cli.download(language)  # type: ignore
-            self.nlp = spacy.load(language)
-
-    @tp.no_type_check
-    def __call__(self, texts: tp.List[tp.Optional[str]],
-                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        """Take a list of strings and convert them to a tensor of indices.
-
-        Args:
-            texts (list[str]): List of strings.
-            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
-        Returns:
-            tuple[torch.Tensor, torch.Tensor]:
-                - Indices of words in the LUT.
-                - And a mask indicating where the padding tokens are
-        """
-        output, lengths = [], []
-        texts = deepcopy(texts)
-        for i, text in enumerate(texts):
-            # if current sample doesn't have a certain attribute, replace with pad token
-            if text is None:
-                output.append(torch.Tensor([self.pad_idx]))
-                lengths.append(0)
-                continue
-
-            # convert numbers to words
-            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
-            # normalize text
-            text = self.nlp(text)  # type: ignore
-            # remove stopwords
-            if self.stopwords:
-                text = [w for w in text if not w.is_stop]  # type: ignore
-            # remove punctuation
-            text = [w for w in text if w.text not in self.PUNCTUATION]  # type: ignore
-            # lemmatize if needed
-            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
-
-            texts[i] = " ".join(text)
-            lengths.append(len(text))
-            # convert to tensor
-            tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
-            output.append(tokens)
-
-        mask = length_to_mask(torch.IntTensor(lengths)).int()
-        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
-        if return_text:
-            return padded_output, mask, texts  # type: ignore
-        return padded_output, mask
+# class WhiteSpaceTokenizer(Tokenizer):
+#     """This tokenizer should be used for natural language descriptions.
+#     For example:
+#     ["he didn't, know he's going home.", 'shorter sentence'] =>
+#     [[78, 62, 31, 4, 78, 25, 19, 34],
+#     [59, 77, 0, 0, 0, 0, 0, 0]]
+#     """
+#     PUNCTUATION = "?:!.,;"
+
+#     def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
+#                  lemma: bool = True, stopwords: bool = True) -> None:
+#         self.n_bins = n_bins
+#         self.pad_idx = pad_idx
+#         self.lemma = lemma
+#         self.stopwords = stopwords
+#         try:
+#             self.nlp = spacy.load(language)
+#         except IOError:
+#             spacy.cli.download(language)  # type: ignore
+#             self.nlp = spacy.load(language)
+
+#     @tp.no_type_check
+#     def __call__(self, texts: tp.List[tp.Optional[str]],
+#                  return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+#         """Take a list of strings and convert them to a tensor of indices.
+
+#         Args:
+#             texts (list[str]): List of strings.
+#             return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
+#         Returns:
+#             tuple[torch.Tensor, torch.Tensor]:
+#                 - Indices of words in the LUT.
+#                 - And a mask indicating where the padding tokens are
+#         """
+#         output, lengths = [], []
+#         texts = deepcopy(texts)
+#         for i, text in enumerate(texts):
+#             # if current sample doesn't have a certain attribute, replace with pad token
+#             if text is None:
+#                 output.append(torch.Tensor([self.pad_idx]))
+#                 lengths.append(0)
+#                 continue
+
+#             # convert numbers to words
+#             text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
+#             # normalize text
+#             text = self.nlp(text)  # type: ignore
+#             # remove stopwords
+#             if self.stopwords:
+#                 text = [w for w in text if not w.is_stop]  # type: ignore
+#             # remove punctuation
+#             text = [w for w in text if w.text not in self.PUNCTUATION]  # type: ignore
+#             # lemmatize if needed
+#             text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
+
+#             texts[i] = " ".join(text)
+#             lengths.append(len(text))
+#             # convert to tensor
+#             tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
+#             output.append(tokens)
+
+#         mask = length_to_mask(torch.IntTensor(lengths)).int()
+#         padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
+#         if return_text:
+#             return padded_output, mask, texts  # type: ignore
+#         return padded_output, mask
 
 
 class NoopTokenizer(Tokenizer):
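
Note: with import spacy and the WhiteSpaceTokenizer class commented out, this module only ships NoopTokenizer, so any code path that still instantiates the whitespace tokenizer (e.g. a LUT-based conditioner configured to use it) would now raise a NameError. For reference, the core of the removed tokenizer is easy to reproduce without spacy: split on whitespace, hash each token into one of n_bins buckets, pad the batch, and build a padding mask. The sketch below illustrates only that hashing scheme; hash_trick here is a local stand-in (a stable SHA-256 hash modulo n_bins, written for this example rather than taken from audiocraft's utility), and the spacy-dependent steps (number-to-word expansion, stopword removal, lemmatization) are deliberately omitted.

import hashlib
import typing as tp

import torch
from torch.nn.utils.rnn import pad_sequence


def hash_trick(word: str, n_bins: int) -> int:
    # Stable hash into [0, n_bins); a stand-in for audiocraft's hash_trick.
    digest = hashlib.sha256(word.encode("utf-8")).hexdigest()
    return int(digest, 16) % n_bins


def whitespace_tokenize(texts: tp.List[tp.Optional[str]],
                        n_bins: int, pad_idx: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    # Split on whitespace, hash each token to a bucket index, pad the batch
    # to a common length, and return (indices, mask) like the removed class.
    output, lengths = [], []
    for text in texts:
        if text is None:
            # Missing description: a single pad token with recorded length 0.
            output.append(torch.tensor([pad_idx]))
            lengths.append(0)
            continue
        words = text.split()
        output.append(torch.tensor([hash_trick(w, n_bins) for w in words]))
        lengths.append(len(words))
    lens = torch.tensor(lengths)
    # Mask is 1 over real tokens and 0 over padding positions.
    mask = (torch.arange(int(lens.max()))[None, :] < lens[:, None]).int()
    padded = pad_sequence(output, padding_value=pad_idx).int().t()
    return padded, mask


idx, mask = whitespace_tokenize(["he didn't know he's going home", "shorter sentence"], n_bins=100)
print(idx.shape)  # torch.Size([2, 6])
print(mask[1])    # tensor([1, 1, 0, 0, 0, 0], dtype=torch.int32)

As in the original class, the mask lets downstream conditioners ignore padded positions when pooling or attending over token embeddings.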