sasan committed on
Commit
6b9d2e8
·
1 Parent(s): 74474b8
Files changed (2) hide show
  1. car_assistant_slim.ipynb +42 -11
  2. kitt.py +195 -0
car_assistant_slim.ipynb CHANGED
@@ -19,13 +19,13 @@
19
  "name": "stderr",
20
  "output_type": "stream",
21
  "text": [
22
- "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
23
  " from .autonotebook import tqdm as notebook_tqdm\n",
24
- "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
25
  " _torch_pytree._register_pytree_node(\n",
26
- "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
27
  " _torch_pytree._register_pytree_node(\n",
28
- "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
29
  " _torch_pytree._register_pytree_node(\n"
30
  ]
31
  }
@@ -116,21 +116,35 @@
116
  "name": "stderr",
117
  "output_type": "stream",
118
  "text": [
119
- "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
 
 
 
 
 
 
 
 
 
 
 
120
  ]
121
  },
122
  {
123
  "name": "stdout",
124
  "output_type": "stream",
125
  "text": [
126
- " > tts_models/multilingual/multi-dataset/xtts_v1.1 is already downloaded.\n"
127
  ]
128
  },
129
  {
130
  "name": "stderr",
131
  "output_type": "stream",
132
  "text": [
133
- "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
 
 
 
134
  " _torch_pytree._register_pytree_node(\n"
135
  ]
136
  },
@@ -138,6 +152,8 @@
138
  "name": "stdout",
139
  "output_type": "stream",
140
  "text": [
 
 
141
  " > Using model: xtts\n"
142
  ]
143
  }
@@ -160,7 +176,20 @@
160
  "collapsed": true,
161
  "id": "JNALTDb0LT90"
162
  },
163
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  "source": [
165
  "#load model language recognition\n",
166
  "model_ckpt = \"papluca/xlm-roberta-base-language-detection\"\n",
@@ -183,8 +212,10 @@
183
  "source": [
184
  "#load model llama2\n",
185
  "mn = 'stabilityai/StableBeluga-7B' #mn = \"TheBloke/Llama-2-7b-Chat-GPTQ\" --> other possibility \n",
186
- "model = AutoModelForCausalLM.from_pretrained(mn, device_map=0, load_in_4bit=True) #torch_dtype=torch.float16\n",
187
- "tokr = AutoTokenizer.from_pretrained(mn, load_in_4bit=True) #tokenizer"
 
 
188
  ]
189
  },
190
  {
@@ -881,7 +912,7 @@
881
  "name": "python",
882
  "nbconvert_exporter": "python",
883
  "pygments_lexer": "ipython3",
884
- "version": "3.10.13"
885
  }
886
  },
887
  "nbformat": 4,
 
19
  "name": "stderr",
20
  "output_type": "stream",
21
  "text": [
22
+ "/opt/homebrew/Caskroom/miniconda/base/envs/llm/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
23
  " from .autonotebook import tqdm as notebook_tqdm\n",
24
+ "/opt/homebrew/Caskroom/miniconda/base/envs/llm/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
25
  " _torch_pytree._register_pytree_node(\n",
26
+ "/opt/homebrew/Caskroom/miniconda/base/envs/llm/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
27
  " _torch_pytree._register_pytree_node(\n",
28
+ "/opt/homebrew/Caskroom/miniconda/base/envs/llm/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
29
  " _torch_pytree._register_pytree_node(\n"
30
  ]
31
  }
 
116
  "name": "stderr",
117
  "output_type": "stream",
118
  "text": [
119
+ "preprocessor_config.json: 100%|██████████| 185k/185k [00:00<00:00, 94.3MB/s]\n",
120
+ "tokenizer_config.json: 100%|██████████| 283k/283k [00:00<00:00, 1.05MB/s]\n",
121
+ "vocab.json: 100%|██████████| 836k/836k [00:00<00:00, 3.03MB/s]\n",
122
+ "tokenizer.json: 100%|██████████| 2.48M/2.48M [00:00<00:00, 50.6MB/s]\n",
123
+ "merges.txt: 100%|██████████| 494k/494k [00:00<00:00, 28.8MB/s]\n",
124
+ "normalizer.json: 100%|██████████| 52.7k/52.7k [00:00<00:00, 67.8MB/s]\n",
125
+ "added_tokens.json: 100%|██████████| 34.6k/34.6k [00:00<00:00, 38.7MB/s]\n",
126
+ "special_tokens_map.json: 100%|██████████| 2.19k/2.19k [00:00<00:00, 8.88MB/s]\n",
127
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
128
+ "config.json: 100%|██████████| 1.97k/1.97k [00:00<00:00, 4.46MB/s]\n",
129
+ "model.safetensors: 100%|██████████| 967M/967M [00:12<00:00, 74.9MB/s] \n",
130
+ "generation_config.json: 100%|██████████| 3.87k/3.87k [00:00<00:00, 39.0MB/s]\n"
131
  ]
132
  },
133
  {
134
  "name": "stdout",
135
  "output_type": "stream",
136
  "text": [
137
+ " > Downloading model to /Users/sasan.jafarnejad/Library/Application Support/tts/tts_models--multilingual--multi-dataset--xtts_v1.1\n"
138
  ]
139
  },
140
  {
141
  "name": "stderr",
142
  "output_type": "stream",
143
  "text": [
144
+ "100%|██████████| 1.87G/1.87G [00:24<00:00, 75.6MiB/s]\n",
145
+ "100%|██████████| 4.70k/4.70k [00:00<00:00, 17.9kiB/s]\n",
146
+ "100%|██████████| 294k/294k [00:00<00:00, 1.23MiB/s]\n",
147
+ "/opt/homebrew/Caskroom/miniconda/base/envs/llm/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
148
  " _torch_pytree._register_pytree_node(\n"
149
  ]
150
  },
 
152
  "name": "stdout",
153
  "output_type": "stream",
154
  "text": [
155
+ " > Model's license - CPML\n",
156
+ " > Check https://coqui.ai/cpml.txt for more info.\n",
157
  " > Using model: xtts\n"
158
  ]
159
  }
 
176
  "collapsed": true,
177
  "id": "JNALTDb0LT90"
178
  },
179
+ "outputs": [
180
+ {
181
+ "name": "stderr",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "config.json: 100%|██████████| 1.42k/1.42k [00:00<00:00, 3.24MB/s]\n",
185
+ "model.safetensors: 100%|██████████| 1.11G/1.11G [00:13<00:00, 79.5MB/s]\n",
186
+ "tokenizer_config.json: 100%|██████████| 502/502 [00:00<00:00, 5.01MB/s]\n",
187
+ "sentencepiece.bpe.model: 100%|██████████| 5.07M/5.07M [00:00<00:00, 78.4MB/s]\n",
188
+ "tokenizer.json: 100%|██████████| 9.08M/9.08M [00:00<00:00, 61.5MB/s]\n",
189
+ "special_tokens_map.json: 100%|██████████| 239/239 [00:00<00:00, 372kB/s]\n"
190
+ ]
191
+ }
192
+ ],
193
  "source": [
194
  "#load model language recognition\n",
195
  "model_ckpt = \"papluca/xlm-roberta-base-language-detection\"\n",
 
212
  "source": [
213
  "#load model llama2\n",
214
  "mn = 'stabilityai/StableBeluga-7B' #mn = \"TheBloke/Llama-2-7b-Chat-GPTQ\" --> other possibility \n",
215
+ "# model = AutoModelForCausalLM.from_pretrained(mn, device_map=0, load_in_4bit=True) #torch_dtype=torch.float16\n",
216
+ "model = AutoModelForCausalLM.from_pretrained(mn, device_map=0) #torch_dtype=torch.float16\n",
217
+ "# tokr = AutoTokenizer.from_pretrained(mn, load_in_4bit=True) #tokenizer\n",
218
+ "tokr = AutoTokenizer.from_pretrained(mn) #tokenizer"
219
  ]
220
  },
221
  {
 
912
  "name": "python",
913
  "nbconvert_exporter": "python",
914
  "pygments_lexer": "ipython3",
915
+ "version": "3.11.8"
916
  }
917
  },
918
  "nbformat": 4,
kitt.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import plotly.express as px
3
+ import requests
4
+
5
+ # INTERFACE WITH AUDIO TO AUDIO
6
+
7
+
8
def calculate_route(
    origin="49.631997,6.171029",
    destination="49.586745,6.140002",
    api_key="api_key",
    timeout=10,
):
    """Fetch a route from the TomTom Routing API and draw it on a map.

    Called without arguments it behaves as before and plots the built-in
    default route.

    Args:
        origin: Start position as a ``"latitude,longitude"`` string.
        destination: End position as a ``"latitude,longitude"`` string.
        api_key: TomTom API key. The default is only a placeholder and has to
            be replaced by a real key for the request to succeed.
        timeout: Seconds to wait for the HTTP response before giving up, so a
            stalled connection cannot block the UI callback forever.

    Returns:
        A Plotly figure showing the route as a line on an OpenStreetMap
        base map, centred on the first point of the route.

    Raises:
        requests.HTTPError: If the API answers with an error status code.
        requests.RequestException: On connection problems or timeout.
        KeyError, IndexError: If the response contains no route points.
    """
    url = f"https://api.tomtom.com/routing/1/calculateRoute/{origin}:{destination}/json?key={api_key}"
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of tripping over a missing "routes" key
    # in the error payload further down.
    response.raise_for_status()
    data = response.json()

    # Only the first leg of the first returned route is plotted.
    points = data["routes"][0]["legs"][0]["points"]
    lats = [point["latitude"] for point in points]
    lons = [point["longitude"] for point in points]

    fig = px.line_mapbox(lat=lats, lon=lons, zoom=12, height=600)
    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox_zoom=12,
        mapbox_center_lat=lats[0],
        mapbox_center_lon=lons[0],
        # No margins: let the map fill the whole plot area.
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )

    return fig
32
+
33
+
34
def transcript(
    general_context, link_to_audio, voice, emotion, place, time, delete_history, state
):
    """Handle one voice interaction: speech-to-text, answer, text-to-speech.

    The recording at ``link_to_audio`` is transcribed, the transcription is
    passed to ``FnAnswer`` together with the context arguments, and the
    returned answer is synthesised into ``output.wav`` using
    ``Audio_Files/{voice}.wav`` as the speaker reference.

    Returns:
        A tuple ``(audio_path, state["context"], state)`` where ``audio_path``
        is ``"output.wav"`` and ``state`` is element 2 of the ``FnAnswer``
        result.

    NOTE(review): the ``emotion`` argument is not used; the TTS call always
    passes ``emotion="angry"`` -- confirm whether this is intended.
    """
    # Load the recording, resampled to 16 kHz.
    waveform, sample_rate = librosa.load(link_to_audio, sr=16000)

    # Speech-to-text: feature extraction, generation, decoding.
    features = processor(waveform, sample_rate, return_tensors="pt").input_features
    token_ids = modelw.generate(features)
    transcription = processor.batch_decode(token_ids, skip_special_tokens=True)

    # Index 0 of the result is the answer, 2 the updated state and 3 the
    # language passed on to the TTS engine.
    quest_processing = FnAnswer(
        general_context, transcription, place, time, delete_history, state
    )
    state = quest_processing[2]
    print("langue " + quest_processing[3])

    # Text-to-speech: write the spoken answer to a fixed output file.
    output_path = "output.wav"
    tts.tts_to_file(
        text=str(quest_processing[0]),
        file_path=output_path,
        speaker_wav=f"Audio_Files/{voice}.wav",
        language=quest_processing[3],
        emotion="angry",
    )

    return output_path, state["context"], state
66
+
67
+
68
+ # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
69
+ # in "Insecure origins treated as secure", enable it and relaunch chrome
70
+
71
+ # example question:
72
+ # what's the weather like outside?
73
+ # What's the closest restaurant from here?
74
+
75
+
76
+ import gradio as gr
77
+
78
# HTML/JS snippet with keyboard shortcuts: outside of <input>/<textarea>
# elements, Ctrl+R calls start_recording() and Ctrl+S calls stop_recording()
# on the element with id "recorder".
# NOTE(review): in the visible part of this file the snippet is only
# referenced from commented-out code (``head=shortcut_js``), so it does not
# appear to be attached to the running interface -- confirm.
shortcut_js = """
<script>
function shortcuts(e) {
var event = document.all ? window.event : e;
switch (e.target.tagName.toLowerCase()) {
case "input":
case "textarea":
break;
default:
if (e.key.toLowerCase() == "r" && e.ctrlKey) {
console.log("recording")
document.getElementById("recorder").start_recording();
}
if (e.key.toLowerCase() == "s" && e.ctrlKey) {
console.log("stopping")
document.getElementById("recorder").stop_recording();
}
}
}
document.addEventListener('keypress', shortcuts, false);
</script>
"""
100
+
101
+ # with gr.Blocks(head=shortcut_js) as demo:
102
+ # action_button = gr.Button(value="Name", elem_id="recorder")
103
+ # textbox = gr.Textbox()
104
+ # action_button.click(lambda : "button pressed", None, textbox)
105
+
106
+ # demo.launch()
107
+
108
+
109
# Dropdown choices: one entry per full hour, "00:00:00" through "23:00:00".
hour_options = ["%02d:00:00" % hour for hour in range(24)]

# The conversation starts with an empty answer and an empty context.
model_answer = ""
general_context = ""
print(general_context)
# Initial state of the interface, seeded with the (empty) context.
initial_state = dict(context=general_context)
initial_context = initial_state["context"]
118
+ # Create the Gradio interface.
119
+
120
+
121
# Two-column layout: settings and audio input on the left, route map and
# audio output on the right.
with gr.Blocks(theme=gr.themes.Default()) as demo:

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            time_picker = gr.Dropdown(
                choices=hour_options, label="What time is it?", value="08:00:00"
            )
            history = gr.Radio(
                ["Yes", "No"], label="Maintain the conversation history?", value="No"
            )
            voice_character = gr.Radio(
                choices=[
                    "Rick Sanches",
                    "Eddie Murphy",
                    "David Attenborough",
                    "Morgan Freeman",
                ],
                label="Choose a voice",
                # The default has to match one of the choices exactly; the
                # previous default "Rick Sancher" matched none of them.
                value="Rick Sanches",
                show_label=True,
            )
            emotion = gr.Radio(
                choices=["Cheerful", "Grumpy"],
                label="Choose an emotion",
                value="Cheerful",
                show_label=True,
            )
            # place = gr.Radio(
            #     choices=[
            #         "Luxembourg Gare, Luxembourg",
            #         "Kirchberg Campus, Kirchberg",
            #         "Belval Campus, Belval",
            #         "Eiffel Tower, Paris",
            #         "Thionville, France",
            #     ],
            #     label="Choose a location for your car",
            #     value="Kirchberg Campus, Kirchberg",
            #     show_label=True,
            # )
            origin = gr.Textbox(
                value="Luxembourg Gare, Luxembourg", label="Origin", interactive=True
            )
            destination = gr.Textbox(
                value="Kirchberg Campus, Kirchberg",
                label="Destination",
                interactive=True,
            )
            recorder = gr.Audio(
                type="filepath", label="input audio", elem_id="recorder"
            )
        with gr.Column(scale=2, min_width=600):
            map_plot = gr.Plot()
            # Pressing Enter in either textbox redraws the map. No inputs are
            # wired, so calculate_route is called without arguments and the
            # textbox contents are not passed to it.
            origin.submit(fn=calculate_route, outputs=map_plot)
            destination.submit(fn=calculate_route, outputs=map_plot)
            output_audio = gr.Audio(label="output audio")
169
+ # map_if = gr.Interface(fn=plot_map, inputs=year_input, outputs=map_plot)
170
+
171
+ # iface = gr.Interface(
172
+ # fn=transcript,
173
+ # inputs=[
174
+ # gr.Textbox(value=initial_context, visible=False),
175
+ # gr.Audio(type="filepath", label="input audio", elem_id="recorder"),
176
+ # voice_character,
177
+ # emotion,
178
+ # place,
179
+ # time_picker,
180
+ # history,
181
+ # gr.State(), # This will keep track of the context state across interactions.
182
+ # ],
183
+ # outputs=[gr.Audio(label="output audio"), gr.Textbox(visible=False), gr.State()],
184
+ # head=shortcut_js,
185
+ # )
186
+
187
# close all interfaces open to make the port available
gr.close_all()
# Launch the interface.

# Enable the request queue, then start the server bound to 0.0.0.0 (all
# network interfaces) on port 7860.
# NOTE(review): this runs at import time (there is no ``__main__`` guard), and
# ``debug=True`` / ``ssl_verify=False`` look like development settings --
# confirm before exposing the app beyond a trusted network.
demo.queue().launch(
    debug=True, server_name="0.0.0.0", server_port=7860, ssl_verify=False
)
194
+
195
+ # iface.launch(debug=True, share=False, server_name="0.0.0.0", server_port=7860, ssl_verify=False)