deeksonparlma committed on
Commit
15e0607
·
1 Parent(s): 1643835

update on model

Browse files
.ipynb_checkpoints/model-checkpoint.ipynb CHANGED
@@ -2,17 +2,90 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "ace57031",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "Accuracy: 0.0\n",
14
- "Prediction: ['Symptoms of depression include sadness, lack of energy, and loss of interest in activities.']\n"
15
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  }
17
  ],
18
  "source": [
@@ -20,14 +93,116 @@
20
  "from sklearn.model_selection import train_test_split\n",
21
  "from sklearn.linear_model import LogisticRegression\n",
22
  "from sklearn.metrics import accuracy_score\n",
23
- "\n",
 
 
 
 
 
24
  "# Step 1: Collect and preprocess data\n",
25
- "questions = [\"What are some symptoms of depression?\",\n",
26
- " \"How can I manage my anxiety?\",\n",
27
- " \"What are the treatments for bipolar disorder?\"]\n",
28
- "responses = [\"Symptoms of depression include sadness, lack of energy, and loss of interest in activities.\",\n",
29
- " \"You can manage your anxiety through techniques such as deep breathing, meditation, and therapy.\",\n",
30
- " \"Treatments for bipolar disorder include medication, therapy, and lifestyle changes.\"]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  "\n",
32
  "vectorizer = TfidfVectorizer()\n",
33
  "X = vectorizer.fit_transform(questions)\n",
@@ -42,16 +217,56 @@
42
  "# Step 4: Train the model\n",
43
  "model.fit(X_train, y_train)\n",
44
  "\n",
 
 
 
 
 
45
  "# Step 5: Evaluate the model\n",
46
  "y_pred = model.predict(X_test)\n",
47
  "accuracy = accuracy_score(y_test, y_pred)\n",
48
  "print(\"Accuracy:\", accuracy)\n",
49
  "\n",
50
  "# Step 6: Use the model to make predictions\n",
51
- "new_question = \"What are the symptoms of anxiety?\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  "new_question_vector = vectorizer.transform([new_question])\n",
53
  "prediction = model.predict(new_question_vector)\n",
54
- "print(\"Prediction:\", prediction)\n"
55
  ]
56
  }
57
  ],
@@ -72,6 +287,11 @@
72
  "nbconvert_exporter": "python",
73
  "pygments_lexer": "ipython3",
74
  "version": "3.10.7"
 
 
 
 
 
75
  }
76
  },
77
  "nbformat": 4,
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 8,
6
  "id": "ace57031",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
+ "data": {
11
+ "text/html": [
12
+ "<div>\n",
13
+ "<style scoped>\n",
14
+ " .dataframe tbody tr th:only-of-type {\n",
15
+ " vertical-align: middle;\n",
16
+ " }\n",
17
+ "\n",
18
+ " .dataframe tbody tr th {\n",
19
+ " vertical-align: top;\n",
20
+ " }\n",
21
+ "\n",
22
+ " .dataframe thead th {\n",
23
+ " text-align: right;\n",
24
+ " }\n",
25
+ "</style>\n",
26
+ "<table border=\"1\" class=\"dataframe\">\n",
27
+ " <thead>\n",
28
+ " <tr style=\"text-align: right;\">\n",
29
+ " <th></th>\n",
30
+ " <th>Question_ID</th>\n",
31
+ " <th>Questions</th>\n",
32
+ " <th>Answers</th>\n",
33
+ " </tr>\n",
34
+ " </thead>\n",
35
+ " <tbody>\n",
36
+ " <tr>\n",
37
+ " <th>0</th>\n",
38
+ " <td>1590140</td>\n",
39
+ " <td>What does it mean to have a mental illness?</td>\n",
40
+ " <td>Mental illnesses are health conditions that di...</td>\n",
41
+ " </tr>\n",
42
+ " <tr>\n",
43
+ " <th>1</th>\n",
44
+ " <td>2110618</td>\n",
45
+ " <td>Who does mental illness affect?</td>\n",
46
+ " <td>It is estimated that mental illness affects 1 ...</td>\n",
47
+ " </tr>\n",
48
+ " <tr>\n",
49
+ " <th>2</th>\n",
50
+ " <td>6361820</td>\n",
51
+ " <td>What causes mental illness?</td>\n",
52
+ " <td>It is estimated that mental illness affects 1 ...</td>\n",
53
+ " </tr>\n",
54
+ " <tr>\n",
55
+ " <th>3</th>\n",
56
+ " <td>9434130</td>\n",
57
+ " <td>What are some of the warning signs of mental i...</td>\n",
58
+ " <td>Symptoms of mental health disorders vary depen...</td>\n",
59
+ " </tr>\n",
60
+ " <tr>\n",
61
+ " <th>4</th>\n",
62
+ " <td>7657263</td>\n",
63
+ " <td>Can people with mental illness recover?</td>\n",
64
+ " <td>When healing from mental illness, early identi...</td>\n",
65
+ " </tr>\n",
66
+ " </tbody>\n",
67
+ "</table>\n",
68
+ "</div>"
69
+ ],
70
+ "text/plain": [
71
+ " Question_ID Questions \\\n",
72
+ "0 1590140 What does it mean to have a mental illness? \n",
73
+ "1 2110618 Who does mental illness affect? \n",
74
+ "2 6361820 What causes mental illness? \n",
75
+ "3 9434130 What are some of the warning signs of mental i... \n",
76
+ "4 7657263 Can people with mental illness recover? \n",
77
+ "\n",
78
+ " Answers \n",
79
+ "0 Mental illnesses are health conditions that di... \n",
80
+ "1 It is estimated that mental illness affects 1 ... \n",
81
+ "2 It is estimated that mental illness affects 1 ... \n",
82
+ "3 Symptoms of mental health disorders vary depen... \n",
83
+ "4 When healing from mental illness, early identi... "
84
+ ]
85
+ },
86
+ "execution_count": 8,
87
+ "metadata": {},
88
+ "output_type": "execute_result"
89
  }
90
  ],
91
  "source": [
 
93
  "from sklearn.model_selection import train_test_split\n",
94
  "from sklearn.linear_model import LogisticRegression\n",
95
  "from sklearn.metrics import accuracy_score\n",
96
+ "import pandas as pd\n",
97
+ "import numpy as np\n",
98
+ "import torch\n",
99
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
100
+ "from huggingface_hub import notebook_login\n",
101
+ "# notebook_login()\n",
102
  "# Step 1: Collect and preprocess data\n",
103
+ "# Get all the questions from the Questions column and the responses from the Answers column in the dataset data.csv\n",
104
+ "# questions = data[\"Questions\"].tolist()\n",
105
+ "# responses = data[\"Responses\"].tolist()\n",
106
+ "questions = []\n",
107
+ "responses = []\n",
108
+ "q_id = []\n",
109
+ "with open(\"mental_health_bot.csv\", \"r\") as f:\n",
110
+ " for line in f:\n",
111
+ " \n",
112
+ " array = line.split(\",\") \n",
113
+ " # questions.append(question)\n",
114
+ " # responses.append(response)\n",
115
+ " # q_id.append(question_id)\n",
116
+ " try:\n",
117
+ " question = array[1]\n",
118
+ " response = array[2]\n",
119
+ " question_id = array[0]\n",
120
+ " questions.append(question)\n",
121
+ " responses.append(response)\n",
122
+ " q_id.append(question_id)\n",
123
+ " except:\n",
124
+ " pass\n",
125
+ "\n",
126
+ "data = pd.read_csv(\"data.csv\")\n",
127
+ "data.head()\n",
128
+ " \n",
129
+ "\n"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": 9,
135
+ "id": "60e154b4",
136
+ "metadata": {},
137
+ "outputs": [
138
+ {
139
+ "name": "stdout",
140
+ "output_type": "stream",
141
+ "text": [
142
+ "missing values: Question_ID 0\n",
143
+ "Questions 0\n",
144
+ "Answers 0\n",
145
+ "dtype: int64\n"
146
+ ]
147
+ }
148
+ ],
149
+ "source": [
150
+ "print('missing values:', data.isnull().sum())"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 10,
156
+ "id": "41311468",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "name": "stdout",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "<class 'pandas.core.frame.DataFrame'>\n",
164
+ "RangeIndex: 149 entries, 0 to 148\n",
165
+ "Data columns (total 3 columns):\n",
166
+ " # Column Non-Null Count Dtype \n",
167
+ "--- ------ -------------- ----- \n",
168
+ " 0 Question_ID 149 non-null object\n",
169
+ " 1 Questions 149 non-null object\n",
170
+ " 2 Answers 149 non-null object\n",
171
+ "dtypes: object(3)\n",
172
+ "memory usage: 3.6+ KB\n",
173
+ "None\n"
174
+ ]
175
+ }
176
+ ],
177
+ "source": [
178
+ "print(data.info())"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 12,
184
+ "id": "f6719ffa",
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stdout",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "Accuracy: 0.03333333333333333\n"
192
+ ]
193
+ }
194
+ ],
195
+ "source": [
196
+ "# print(questions)\n",
197
+ "# print(responses)\n",
198
+ "\n",
199
+ "\n",
200
+ "# questions = [\"What are some symptoms of depression?\",\n",
201
+ "# \"How can I manage my anxiety?\",\n",
202
+ "# \"What are the treatments for bipolar disorder?\"]\n",
203
+ "# responses = [\"Symptoms of depression include sadness, lack of energy, and loss of interest in activities.\",\n",
204
+ "# \"You can manage your anxiety through techniques such as deep breathing, meditation, and therapy.\",\n",
205
+ "# \"Treatments for bipolar disorder include medication, therapy, and lifestyle changes.\"]\n",
206
  "\n",
207
  "vectorizer = TfidfVectorizer()\n",
208
  "X = vectorizer.fit_transform(questions)\n",
 
217
  "# Step 4: Train the model\n",
218
  "model.fit(X_train, y_train)\n",
219
  "\n",
220
+ "# model.push_to_hub(\"tabibu-ai/mental-health-chatbot\")\n",
221
+ "# pt_model = DistilBertForSequenceClassification.from_pretrained(\"model.ipynb\", from_tf=True)\n",
222
+ "# pt_model.save_pretrained(\"model.ipynb\")\n",
223
+ "# load model from hub\n",
224
+ "\n",
225
  "# Step 5: Evaluate the model\n",
226
  "y_pred = model.predict(X_test)\n",
227
  "accuracy = accuracy_score(y_test, y_pred)\n",
228
  "print(\"Accuracy:\", accuracy)\n",
229
  "\n",
230
  "# Step 6: Use the model to make predictions\n",
231
+ "\n"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": 14,
237
+ "id": "d8d18524",
238
+ "metadata": {},
239
+ "outputs": [
240
+ {
241
+ "name": "stdout",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "Ask me anythingWho are you\n"
245
+ ]
246
+ }
247
+ ],
248
+ "source": [
249
+ "new_question = input(\"Ask me anything : \")\n"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 15,
255
+ "id": "e51d4ca5",
256
+ "metadata": {},
257
+ "outputs": [
258
+ {
259
+ "name": "stdout",
260
+ "output_type": "stream",
261
+ "text": [
262
+ "Prediction: ['\"It is estimated that mental illness affects 1 in 5 adults in America']\n"
263
+ ]
264
+ }
265
+ ],
266
+ "source": [
267
  "new_question_vector = vectorizer.transform([new_question])\n",
268
  "prediction = model.predict(new_question_vector)\n",
269
+ "print(\"Prediction:\", prediction)"
270
  ]
271
  }
272
  ],
 
287
  "nbconvert_exporter": "python",
288
  "pygments_lexer": "ipython3",
289
  "version": "3.10.7"
290
+ },
291
+ "vscode": {
292
+ "interpreter": {
293
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
294
+ }
295
  }
296
  },
297
  "nbformat": 4,
.~lock.mental_health_bot.xlsx# ADDED
@@ -0,0 +1 @@
 
 
1
+ ,dickson,dickson,20.02.2023 22:13,file:///home/dickson/.config/libreoffice/4;
excel-data.xls ADDED
The diff for this file is too large to render. See raw diff
 
mental_health_bot.csv ADDED
The diff for this file is too large to render. See raw diff
 
mental_health_bot.ods ADDED
Binary file (85.9 kB). View file
 
mental_health_bot.xlsx ADDED
Binary file (55.6 kB). View file
 
model.ipynb CHANGED
@@ -2,17 +2,90 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 15,
6
  "id": "ace57031",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "Accuracy: 0.023255813953488372\n",
14
- "Prediction: [' is a member of the BC Partners for Mental Health and Addictions Information. The institute is dedicated to the study of substance use in support of community-wide efforts aimed at providing all people with access to healthier lives']\n"
15
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  }
17
  ],
18
  "source": [
@@ -20,7 +93,12 @@
20
  "from sklearn.model_selection import train_test_split\n",
21
  "from sklearn.linear_model import LogisticRegression\n",
22
  "from sklearn.metrics import accuracy_score\n",
23
- "\n",
 
 
 
 
 
24
  "# Step 1: Collect and preprocess data\n",
25
  "# Get all the questions from the Questions column and the responses from the Answers column in the dataset data.csv\n",
26
  "# questions = data[\"Questions\"].tolist()\n",
@@ -28,7 +106,7 @@
28
  "questions = []\n",
29
  "responses = []\n",
30
  "q_id = []\n",
31
- "with open(\"data.csv\", \"r\") as f:\n",
32
  " for line in f:\n",
33
  " \n",
34
  " array = line.split(\",\") \n",
@@ -45,9 +123,74 @@
45
  " except:\n",
46
  " pass\n",
47
  "\n",
48
- "\n",
49
- " \n",
50
- "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  "# print(questions)\n",
52
  "# print(responses)\n",
53
  "\n",
@@ -72,8 +215,9 @@
72
  "# Step 4: Train the model\n",
73
  "model.fit(X_train, y_train)\n",
74
  "\n",
75
- "model.push_to_hub(\"tabibu-ai/mental-health-chatbot\")\n",
76
- "\n",
 
77
  "# load model from hub\n",
78
  "\n",
79
  "# Step 5: Evaluate the model\n",
@@ -82,16 +226,51 @@
82
  "print(\"Accuracy:\", accuracy)\n",
83
  "\n",
84
  "# Step 6: Use the model to make predictions\n",
85
- "new_question = \"I feel sad\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  "new_question_vector = vectorizer.transform([new_question])\n",
87
  "prediction = model.predict(new_question_vector)\n",
88
- "print(\"Prediction:\", prediction)\n"
89
  ]
90
  }
91
  ],
92
  "metadata": {
93
  "kernelspec": {
94
- "display_name": "Python 3",
95
  "language": "python",
96
  "name": "python3"
97
  },
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 8,
6
  "id": "ace57031",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
+ "data": {
11
+ "text/html": [
12
+ "<div>\n",
13
+ "<style scoped>\n",
14
+ " .dataframe tbody tr th:only-of-type {\n",
15
+ " vertical-align: middle;\n",
16
+ " }\n",
17
+ "\n",
18
+ " .dataframe tbody tr th {\n",
19
+ " vertical-align: top;\n",
20
+ " }\n",
21
+ "\n",
22
+ " .dataframe thead th {\n",
23
+ " text-align: right;\n",
24
+ " }\n",
25
+ "</style>\n",
26
+ "<table border=\"1\" class=\"dataframe\">\n",
27
+ " <thead>\n",
28
+ " <tr style=\"text-align: right;\">\n",
29
+ " <th></th>\n",
30
+ " <th>Question_ID</th>\n",
31
+ " <th>Questions</th>\n",
32
+ " <th>Answers</th>\n",
33
+ " </tr>\n",
34
+ " </thead>\n",
35
+ " <tbody>\n",
36
+ " <tr>\n",
37
+ " <th>0</th>\n",
38
+ " <td>1590140</td>\n",
39
+ " <td>What does it mean to have a mental illness?</td>\n",
40
+ " <td>Mental illnesses are health conditions that di...</td>\n",
41
+ " </tr>\n",
42
+ " <tr>\n",
43
+ " <th>1</th>\n",
44
+ " <td>2110618</td>\n",
45
+ " <td>Who does mental illness affect?</td>\n",
46
+ " <td>It is estimated that mental illness affects 1 ...</td>\n",
47
+ " </tr>\n",
48
+ " <tr>\n",
49
+ " <th>2</th>\n",
50
+ " <td>6361820</td>\n",
51
+ " <td>What causes mental illness?</td>\n",
52
+ " <td>It is estimated that mental illness affects 1 ...</td>\n",
53
+ " </tr>\n",
54
+ " <tr>\n",
55
+ " <th>3</th>\n",
56
+ " <td>9434130</td>\n",
57
+ " <td>What are some of the warning signs of mental i...</td>\n",
58
+ " <td>Symptoms of mental health disorders vary depen...</td>\n",
59
+ " </tr>\n",
60
+ " <tr>\n",
61
+ " <th>4</th>\n",
62
+ " <td>7657263</td>\n",
63
+ " <td>Can people with mental illness recover?</td>\n",
64
+ " <td>When healing from mental illness, early identi...</td>\n",
65
+ " </tr>\n",
66
+ " </tbody>\n",
67
+ "</table>\n",
68
+ "</div>"
69
+ ],
70
+ "text/plain": [
71
+ " Question_ID Questions \\\n",
72
+ "0 1590140 What does it mean to have a mental illness? \n",
73
+ "1 2110618 Who does mental illness affect? \n",
74
+ "2 6361820 What causes mental illness? \n",
75
+ "3 9434130 What are some of the warning signs of mental i... \n",
76
+ "4 7657263 Can people with mental illness recover? \n",
77
+ "\n",
78
+ " Answers \n",
79
+ "0 Mental illnesses are health conditions that di... \n",
80
+ "1 It is estimated that mental illness affects 1 ... \n",
81
+ "2 It is estimated that mental illness affects 1 ... \n",
82
+ "3 Symptoms of mental health disorders vary depen... \n",
83
+ "4 When healing from mental illness, early identi... "
84
+ ]
85
+ },
86
+ "execution_count": 8,
87
+ "metadata": {},
88
+ "output_type": "execute_result"
89
  }
90
  ],
91
  "source": [
 
93
  "from sklearn.model_selection import train_test_split\n",
94
  "from sklearn.linear_model import LogisticRegression\n",
95
  "from sklearn.metrics import accuracy_score\n",
96
+ "import pandas as pd\n",
97
+ "import numpy as np\n",
98
+ "import torch\n",
99
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
100
+ "from huggingface_hub import notebook_login\n",
101
+ "# notebook_login()\n",
102
  "# Step 1: Collect and preprocess data\n",
103
  "# Get all the questions from the Questions column and the responses from the Answers column in the dataset data.csv\n",
104
  "# questions = data[\"Questions\"].tolist()\n",
 
106
  "questions = []\n",
107
  "responses = []\n",
108
  "q_id = []\n",
109
+ "with open(\"mental_health_bot.csv\", \"r\") as f:\n",
110
  " for line in f:\n",
111
  " \n",
112
  " array = line.split(\",\") \n",
 
123
  " except:\n",
124
  " pass\n",
125
  "\n",
126
+ "data = pd.read_csv(\"data.csv\")\n",
127
+ "data.head()"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 9,
133
+ "id": "8f51e39d",
134
+ "metadata": {},
135
+ "outputs": [
136
+ {
137
+ "name": "stdout",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "missing values: Question_ID 0\n",
141
+ "Questions 0\n",
142
+ "Answers 0\n",
143
+ "dtype: int64\n"
144
+ ]
145
+ }
146
+ ],
147
+ "source": [
148
+ "print('missing values:', data.isnull().sum())"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 10,
154
+ "id": "1d697a39",
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "name": "stdout",
159
+ "output_type": "stream",
160
+ "text": [
161
+ "<class 'pandas.core.frame.DataFrame'>\n",
162
+ "RangeIndex: 149 entries, 0 to 148\n",
163
+ "Data columns (total 3 columns):\n",
164
+ " # Column Non-Null Count Dtype \n",
165
+ "--- ------ -------------- ----- \n",
166
+ " 0 Question_ID 149 non-null object\n",
167
+ " 1 Questions 149 non-null object\n",
168
+ " 2 Answers 149 non-null object\n",
169
+ "dtypes: object(3)\n",
170
+ "memory usage: 3.6+ KB\n",
171
+ "None\n"
172
+ ]
173
+ }
174
+ ],
175
+ "source": [
176
+ "print(data.info())"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 12,
182
+ "id": "c5dde0e4",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "name": "stdout",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "Accuracy: 0.03333333333333333\n"
190
+ ]
191
+ }
192
+ ],
193
+ "source": [
194
  "# print(questions)\n",
195
  "# print(responses)\n",
196
  "\n",
 
215
  "# Step 4: Train the model\n",
216
  "model.fit(X_train, y_train)\n",
217
  "\n",
218
+ "# model.push_to_hub(\"tabibu-ai/mental-health-chatbot\")\n",
219
+ "pt_model = DistilBertForSequenceClassification.from_pretrained(\"model.ipynb\", from_tf=True)\n",
220
+ "pt_model.save_pretrained(\"model.ipynb\")\n",
221
  "# load model from hub\n",
222
  "\n",
223
  "# Step 5: Evaluate the model\n",
 
226
  "print(\"Accuracy:\", accuracy)\n",
227
  "\n",
228
  "# Step 6: Use the model to make predictions\n",
229
+ "\n"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 18,
235
+ "id": "14406312",
236
+ "metadata": {},
237
+ "outputs": [
238
+ {
239
+ "name": "stdout",
240
+ "output_type": "stream",
241
+ "text": [
242
+ "Ask me anything : I feel sad\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "new_question = input(\"Ask me anything : \")\n"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": 17,
253
+ "id": "6b9198db",
254
+ "metadata": {},
255
+ "outputs": [
256
+ {
257
+ "name": "stdout",
258
+ "output_type": "stream",
259
+ "text": [
260
+ "Prediction: ['\"It is estimated that mental illness affects 1 in 5 adults in America']\n"
261
+ ]
262
+ }
263
+ ],
264
+ "source": [
265
  "new_question_vector = vectorizer.transform([new_question])\n",
266
  "prediction = model.predict(new_question_vector)\n",
267
+ "print(\"Prediction:\", prediction)"
268
  ]
269
  }
270
  ],
271
  "metadata": {
272
  "kernelspec": {
273
+ "display_name": "Python 3 (ipykernel)",
274
  "language": "python",
275
  "name": "python3"
276
  },
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  torch
2
- transformers
 
 
 
1
  torch
2
+ transformers
3
+ huggingface_hub
4
+ tensorflow