Spaces:
Running
on
T4
Running
on
T4
Julien Simon
commited on
Commit
·
70ec0d7
1
Parent(s):
dad2516
Initial version
Browse files- .gitattributes +1 -0
- README.md +3 -3
- app.py +110 -0
- df10k_SP500_2020.csv.zip +3 -0
- df10k_embeddings_msmarco-distilbert-base-v4.npz +3 -0
- dummy.wav +0 -0
- energy_16k_fr.wav +0 -0
- energy_24k_es.wav +0 -0
- requirements.txt +7 -0
- sales_16k_fr.wav +0 -0
- tax_24k_de.wav +0 -0
.gitattributes
CHANGED
@@ -23,5 +23,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
23 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
|
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
23 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
27 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
28 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: Voice Queries
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
app_file: app.py
|
8 |
pinned: false
|
|
|
1 |
---
|
2 |
title: Voice Queries
|
3 |
+
emoji: 🐢
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
app_file: app.py
|
8 |
pinned: false
|
app.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
import pickle
|
3 |
+
import pandas as pd
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
from sentence_transformers import SentenceTransformer, util
|
7 |
+
from transformers import pipeline
|
8 |
+
from librosa import load, resample
|
9 |
+
|
10 |
+
# Constants
|
11 |
+
filename = 'df10k_SP500_2020.csv.zip'
|
12 |
+
|
13 |
+
model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
|
14 |
+
max_sequence_length = 512
|
15 |
+
embeddings_filename = 'df10k_embeddings_msmarco-distilbert-base-v4.npz'
|
16 |
+
asr_model = 'facebook/wav2vec2-xls-r-300m-21-to-en'
|
17 |
+
|
18 |
+
# Load corpus
|
19 |
+
df = pd.read_csv(filename)
|
20 |
+
df.drop_duplicates(inplace=True)
|
21 |
+
print(f'Number of documents: {len(df)}')
|
22 |
+
|
23 |
+
corpus = []
|
24 |
+
sentence_count = []
|
25 |
+
for _, row in df.iterrows():
|
26 |
+
# We're interested in the 'mdna' column: 'Management discussion and analysis'
|
27 |
+
sentences = nltk.tokenize.sent_tokenize(str(row['mdna']), language='english')
|
28 |
+
sentence_count.append(len(sentences))
|
29 |
+
for _,s in enumerate(sentences):
|
30 |
+
corpus.append(s)
|
31 |
+
print(f'Number of sentences: {len(corpus)}')
|
32 |
+
|
33 |
+
# Load pre-embedded corpus
|
34 |
+
corpus_embeddings = np.load(embeddings_filename)['arr_0']
|
35 |
+
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
|
36 |
+
|
37 |
+
# Load embedding model
|
38 |
+
model = SentenceTransformer(model_name)
|
39 |
+
model.max_seq_length = max_sequence_length
|
40 |
+
|
41 |
+
# Load speech to text model
|
42 |
+
asr = pipeline('automatic-speech-recognition', model=asr_model, feature_extractor=asr_model)
|
43 |
+
|
44 |
+
def find_sentences(query, hits):
|
45 |
+
query_embedding = model.encode(query)
|
46 |
+
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=hits)
|
47 |
+
hits = hits[0]
|
48 |
+
|
49 |
+
output = pd.DataFrame(columns=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'])
|
50 |
+
for hit in hits:
|
51 |
+
corpus_id = hit['corpus_id']
|
52 |
+
# Find source document based on sentence index
|
53 |
+
count = 0
|
54 |
+
for idx, c in enumerate(sentence_count):
|
55 |
+
count+=c
|
56 |
+
if (corpus_id > count-1):
|
57 |
+
continue
|
58 |
+
else:
|
59 |
+
doc = df.iloc[idx]
|
60 |
+
new_row = {
|
61 |
+
'Ticker' : doc['ticker'],
|
62 |
+
'Form type' : doc['form_type'],
|
63 |
+
'Filing date': doc['filing_date'],
|
64 |
+
'Text' : corpus[corpus_id],
|
65 |
+
'Score' : '{:.2f}'.format(hit['score'])
|
66 |
+
}
|
67 |
+
output = output.append(new_row, ignore_index=True)
|
68 |
+
break
|
69 |
+
return output
|
70 |
+
|
71 |
+
|
72 |
+
def process(input_selection, query, filepath, hits):
|
73 |
+
if input_selection=='speech':
|
74 |
+
speech, sampling_rate = load(filepath)
|
75 |
+
if sampling_rate != 16000:
|
76 |
+
speech = resample(speech, sampling_rate, 16000)
|
77 |
+
text = asr(speech)['text']
|
78 |
+
else:
|
79 |
+
text = query
|
80 |
+
return text, find_sentences(text, hits)
|
81 |
+
|
82 |
+
# Gradio inputs
|
83 |
+
buttons = gr.inputs.Radio(['text','speech'], type='value', default='speech', label='Input selection')
|
84 |
+
text_query = gr.inputs.Textbox(lines=1, label='Text input', default='The company is under investigation by tax authorities for potential fraud.')
|
85 |
+
mic = gr.inputs.Audio(source='microphone', type='filepath', label='Speech input', optional=True)
|
86 |
+
slider = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label='Number of hits')
|
87 |
+
|
88 |
+
# Gradio outputs
|
89 |
+
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
|
90 |
+
results = gr.outputs.Dataframe(
|
91 |
+
headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
|
92 |
+
label='Query results')
|
93 |
+
|
94 |
+
iface = gr.Interface(
|
95 |
+
theme='huggingface',
|
96 |
+
description='This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages.',
|
97 |
+
fn=process,
|
98 |
+
inputs=[buttons,text_query,mic,slider],
|
99 |
+
outputs=[speech_query, results],
|
100 |
+
examples=[
|
101 |
+
['text', "The company is under investigation by tax authorities for potential fraud.", 'dummy.wav', 3],
|
102 |
+
['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
|
103 |
+
['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
|
104 |
+
['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
|
105 |
+
['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
|
106 |
+
['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3]
|
107 |
+
],
|
108 |
+
allow_flagging=False
|
109 |
+
)
|
110 |
+
iface.launch()
|
df10k_SP500_2020.csv.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:984d7a036dba3c32e176e6609392909b005bc5ac030de24427f0982c88aaaf0d
|
3 |
+
size 134796242
|
df10k_embeddings_msmarco-distilbert-base-v4.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa084a851fd82187d2e0aa1ab72fb4e1ac127a07cbcc3597a551719850b8b25d
|
3 |
+
size 526747035
|
dummy.wav
ADDED
File without changes
|
energy_16k_fr.wav
ADDED
Binary file (156 kB). View file
|
|
energy_24k_es.wav
ADDED
Binary file (182 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
nltk
|
4 |
+
pandas
|
5 |
+
numpy
|
6 |
+
sentence-transformers
|
7 |
+
librosa
|
sales_16k_fr.wav
ADDED
Binary file (136 kB). View file
|
|
tax_24k_de.wav
ADDED
Binary file (152 kB). View file
|
|