aperrot42wq commited on
Commit
fc9da2a
·
1 Parent(s): 7cce5b7

- basic diarization

Browse files

- basic transcript
- basic summary

.dockerignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch/
2
+ cache/
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # venv
2
+ .venv
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/Seamlessm4t_diarization_VAD.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Seamlessm4t_diarization_VAD.iml" filepath="$PROJECT_DIR$/.idea/Seamlessm4t_diarization_VAD.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
Dockerfile ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # python build dependencies \
11
+ build-essential \
12
+ libssl-dev \
13
+ zlib1g-dev \
14
+ libbz2-dev \
15
+ libreadline-dev \
16
+ libsqlite3-dev \
17
+ libncursesw5-dev \
18
+ xz-utils \
19
+ tk-dev \
20
+ libxml2-dev \
21
+ libxmlsec1-dev \
22
+ libffi-dev \
23
+ liblzma-dev \
24
+ # gradio dependencies \
25
+ ffmpeg \
26
+ # fairseq2 dependencies \
27
+ libsndfile-dev && \
28
+ apt-get clean && \
29
+ rm -rf /var/lib/apt/lists/*
30
+
31
+ RUN useradd -m -u 1000 user
32
+ USER user
33
+ ENV HOME=/home/user \
34
+ PATH=/home/user/.local/bin:${PATH}
35
+ WORKDIR ${HOME}/app
36
+
37
+ RUN curl https://pyenv.run | bash
38
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
39
+ ARG PYTHON_VERSION=3.10.12
40
+ RUN pyenv install ${PYTHON_VERSION} && \
41
+ pyenv global ${PYTHON_VERSION} && \
42
+ pyenv rehash && \
43
+ pip install --no-cache-dir -U pip setuptools wheel
44
+
45
+ COPY --chown=1000 ./requirements.txt /tmp/requirements.txt
46
+ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
47
+
48
+ COPY --chown=1000 . ${HOME}/app
49
+ ENV PYTHONPATH=${HOME}/app \
50
+ PYTHONUNBUFFERED=1 \
51
+ GRADIO_ALLOW_FLAGGING=never \
52
+ GRADIO_NUM_PORTS=1 \
53
+ GRADIO_SERVER_NAME=0.0.0.0 \
54
+ GRADIO_THEME=huggingface \
55
+ SYSTEM=spaces \
56
+ GRADIO_SERVER_PORT=7860
57
+ EXPOSE 7860
58
+
59
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Dirty one file implementation for expermiental (and fun) purpose only
2
+
3
+ import os
4
+ import gradio as gr
5
+
6
+ from dotenv import load_dotenv
7
+ from pydub import AudioSegment
8
+ from tqdm.auto import tqdm
9
+ print('starting')
10
+
11
+ load_dotenv()
12
+
13
+ from gradio_client import Client
14
+
15
+ HF_API = os.getenv("HF_API")
16
+ SEAMLESS_API_URL = os.getenv("SEAMLESS_API_URL") # path to Seamlessm4t API endpoint
17
+ GPU_AVAILABLE = os.getenv("GPU_AVAILABLE")
18
+ DEFAULT_TARGET_LANGUAGE = "French"
19
+ MISTRAL_SUMMARY_URL= "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
20
+ LLAMA_SUMMARY_URL="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
21
+
22
+ print('env setup ok')
23
+
24
+
25
+ DESCRIPTION = """
26
+ # Transcribe and create a summary of a conversation.
27
+ """
28
+
29
+ DUPLICATE = """
30
+ To duplicate this repo, you have to give permission from three reopsitories and accept all user conditions:
31
+ 1- https://huggingface.co/pyannote/voice-activity-detection
32
+ 2- https://hf.co/pyannote/segmentation
33
+ 3- https://hf.co/pyannote/speaker-diarization
34
+
35
+ """
36
+ from pyannote.audio import Pipeline
37
+ #initialize diarization pipeline
38
+ diarizer = Pipeline.from_pretrained(
39
+ "pyannote/speaker-diarization-3.1",
40
+ use_auth_token=HF_API)
41
+ # send pipeline to GPU (when available)
42
+ import torch
43
+ diarizer.to(torch.device(GPU_AVAILABLE))
44
+
45
+ print('diarizer setup ok')
46
+
47
+
48
+ # predict is a generator that incrementally yields recognized text with speaker label
49
+ def predict(target_language, input_audio):
50
+ print('->predict started')
51
+ print(target_language, type(input_audio), input_audio)
52
+
53
+ print('-->diarization')
54
+ diarized = diarizer(input_audio, min_speakers=2, max_speakers=5)
55
+
56
+ print('-->automatic speech recognition')
57
+ # split audio according to diarization
58
+ song = AudioSegment.from_wav(input_audio)
59
+ client = Client(SEAMLESS_API_URL, hf_token=HF_API)
60
+ output_text = ""
61
+ for turn, _, speaker in diarized.itertracks(yield_label=True):
62
+ print(speaker, turn)
63
+ try:
64
+ clipped = song[turn.start * 1000 : turn.end * 1000]
65
+ clipped.export(f"my.wav", format="wav", bitrate=16000)
66
+
67
+ result = client.predict(
68
+ f"my.wav",
69
+ target_language,
70
+ api_name="/asr"
71
+ )
72
+
73
+ current_text = f"speaker: {speaker} text: {result} "
74
+ print(current_text)
75
+
76
+ if current_text is not None:
77
+ output_text = output_text + "\n" + current_text
78
+ yield output_text
79
+
80
+ except Exception as e:
81
+ print(e)
82
+
83
+
84
+ import requests
85
+
86
+
87
+ def generate_summary_llama3(language, transcript):
88
+ queryTxt = f'''
89
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
90
+
91
+ You are a helpful and truthful patient-doctor encounter summary writer.
92
+ Users sends you transcripts of patient-doctor encounter and you create accurate and concise summaries.
93
+ The summary only contains informations from the transcript.
94
+ Your summary is written in {language}.
95
+ The summary only includes relevant sections.
96
+ <template>
97
+ # Chief Complaint
98
+ # History of Present Illness (HPI)
99
+ # Relevant Past Medical History
100
+ # Physical Examination
101
+ # Assessment and Plan
102
+ # Follow-up
103
+ # Additional Notes
104
+ </template> <|eot_id|>
105
+ <|begin_of_text|><|start_header_id|>user<|end_header_id|>
106
+
107
+ <transcript>
108
+ {transcript}
109
+ </transcript><|eot_id|>
110
+ <|start_header_id|>assistant<|end_header_id|>
111
+ '''
112
+
113
+ payload = {
114
+ "inputs": queryTxt,
115
+ "parameters": {
116
+ "return_full_text": False,
117
+ "wait_for_model": True,
118
+ "min_length": 1000
119
+ },
120
+ "options": {
121
+ "use_cache": False
122
+ }
123
+ }
124
+
125
+ response = requests.post(LLAMA_SUMMARY_URL, headers = {"Authorization": f"Bearer {HF_API}"}, json=payload)
126
+ print(response.json())
127
+ return response.json()[0]['generated_text'][len('<summary>'):]
128
+
129
+ def generate_summary_mistral(language, transcript):
130
+ sysPrompt = f'''<s>[INST]
131
+ You are a helpful and truthful patient-doctor encounter summary writer.
132
+ Users sends you transcripts of patient-doctor encounter and you create accurate and concise summaries.
133
+ The summary only contains informations from the transcript.
134
+ Your summary is written in {language}.
135
+ The summary only includes relevant sections.
136
+ <template>
137
+ # Chief Complaint
138
+ # History of Present Illness (HPI)
139
+ # Relevant Past Medical History
140
+ # Physical Examination
141
+ # Assessment and Plan
142
+ # Follow-up
143
+ # Additional Notes
144
+ </template>
145
+
146
+ '''
147
+ queryTxt=f'''
148
+ <transcript>
149
+ {transcript}
150
+ </transcript>
151
+ [/INST]
152
+ '''
153
+
154
+ payload = {
155
+ "inputs": sysPrompt + queryTxt,
156
+ "parameters": {
157
+ "return_full_text": False,
158
+ "wait_for_model": True,
159
+ "min_length": 1000
160
+ },
161
+ "options": {
162
+ "use_cache": False
163
+ }
164
+ }
165
+
166
+ response = requests.post(MISTRAL_SUMMARY_URL, headers = {"Authorization": f"Bearer {HF_API}"}, json=payload)
167
+ print(response.json())
168
+ return response.json()[0]['generated_text'][len('<summary>'):]
169
+
170
+ def generate_summary(model, language, transcript):
171
+ match model:
172
+ case "Mistral-7B":
173
+ print("-> summarize with mistral")
174
+ return generate_summary_mistral( language, transcript)
175
+ case "LLAMA3":
176
+ print("-> summarize with llama3")
177
+ return generate_summary_llama3(language, transcript)
178
+ case _:
179
+ return f"Unknown model {model}"
180
+
181
+ def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
182
+ mic = audio_source == "microphone"
183
+ return (
184
+ gr.update(visible=mic, value=None), # input_audio_mic
185
+ gr.update(visible=not mic, value=None), # input_audio_file
186
+ )
187
+
188
+
189
+ with gr.Blocks() as demo:
190
+ gr.Markdown(DESCRIPTION)
191
+ with gr.Group():
192
+ with gr.Row():
193
+ target_language = gr.Dropdown(
194
+ choices= ["French", "English"],
195
+ label="Output Language",
196
+ value="French",
197
+ interactive=True,
198
+ info="Select your target language",
199
+ )
200
+ with gr.Row() as audio_box:
201
+ input_audio = gr.Audio(
202
+ type="filepath"
203
+ )
204
+ submit = gr.Button("Transcribe")
205
+ transcribe_output = gr.Textbox(
206
+ label="Transcribed Text",
207
+ value="",
208
+ interactive=False,
209
+ lines=10,
210
+ scale=10,
211
+ max_lines=100,
212
+ )
213
+ submit.click(
214
+ fn=predict,
215
+ inputs=[
216
+ target_language,
217
+ input_audio
218
+ ],
219
+ outputs=[transcribe_output],
220
+ api_name="predict",
221
+ )
222
+ with gr.Row():
223
+ sumary_model = gr.Dropdown(
224
+ choices= ["Mistral-7B", "LLAMA3"],
225
+ label="Summary model",
226
+ value="Mistral-7B",
227
+ interactive=True,
228
+ info="Select your summary model",
229
+ )
230
+ summarize = gr.Button("Summarize")
231
+ summary_output = gr.Textbox(
232
+ label="Summarized Text",
233
+ value="",
234
+ interactive=False,
235
+ lines=10,
236
+ scale=10,
237
+ max_lines=100,
238
+ )
239
+ summarize.click(
240
+ fn=generate_summary,
241
+ inputs=[
242
+ sumary_model,
243
+ target_language,
244
+ transcribe_output
245
+ ],
246
+ outputs=[summary_output],
247
+ api_name="predict",
248
+ )
249
+ gr.Markdown(DUPLICATE)
250
+
251
+ demo.queue(max_size=50).launch()
252
+
docker-compose.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.7"
2
+
3
+ services:
4
+ seamless_diarization_VAD_service:
5
+ image: seamless_diarization_vad:v0.1
6
+ build:
7
+ context: .
8
+ container_name: seamless_diarization_VAD
9
+ user: root
10
+ environment:
11
+ - GIT_PYTHON_REFRESH=quiet
12
+ - HUGGINGFACE_HUB_CACHE=/home/user/temp/.cache/HUGGINGFACE_HUB_CACHE
13
+ - TRANSFORMERS_CACHE=/home/user/temp/.cache/TRANSFORMERS_CACHE
14
+ - HF_HOME=/home/user/temp/.cache/HF_HOME
15
+ - TORCH_HOME=/home/user/temp/torch
16
+ #- TRANSFORMERS_OFFLINE=1
17
+ tty: true
18
+ stdin_open: true
19
+ volumes:
20
+ - ./torch/:/home/user/temp/torch
21
+ - ./cache/:/home/user/.cache/
22
+ ports:
23
+ - 8005:7860
24
+ deploy:
25
+ resources:
26
+ reservations:
27
+ devices:
28
+ - driver: nvidia
29
+ device_ids: ["4"]
30
+ capabilities: [gpu]
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pyannote.audio
2
+ pydub
3
+ gradio_client==0.16.0
4
+ gradio==4.28.3
5
+ python-dotenv==1.0.0
6
+ torch
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }