{ "cells": [ { "cell_type": "code", "execution_count": 8, "id": "ace57031", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Question_IDQuestionsAnswers
01590140What does it mean to have a mental illness?Mental illnesses are health conditions that di...
12110618Who does mental illness affect?It is estimated that mental illness affects 1 ...
26361820What causes mental illness?It is estimated that mental illness affects 1 ...
39434130What are some of the warning signs of mental i...Symptoms of mental health disorders vary depen...
47657263Can people with mental illness recover?When healing from mental illness, early identi...
\n", "
" ], "text/plain": [ " Question_ID Questions \\\n", "0 1590140 What does it mean to have a mental illness? \n", "1 2110618 Who does mental illness affect? \n", "2 6361820 What causes mental illness? \n", "3 9434130 What are some of the warning signs of mental i... \n", "4 7657263 Can people with mental illness recover? \n", "\n", " Answers \n", "0 Mental illnesses are health conditions that di... \n", "1 It is estimated that mental illness affects 1 ... \n", "2 It is estimated that mental illness affects 1 ... \n", "3 Symptoms of mental health disorders vary depen... \n", "4 When healing from mental illness, early identi... " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "from huggingface_hub import notebook_login\n",
    "# notebook_login()\n",
    "\n",
    "# Step 1: Collect and preprocess data\n",
    "# Load the Q&A dataset once with pandas. Its CSV parser honours quoted fields, so\n",
    "# answers that contain commas are kept whole; a naive line.split(\",\") truncates\n",
    "# them at the first comma and also treats the header row as a training example.\n",
    "data = pd.read_csv(\"data.csv\")\n",
    "\n",
    "# Training inputs come from the same frame that is inspected below, so the\n",
    "# missing-value / dtype checks describe exactly what the model is trained on.\n",
    "questions = data[\"Questions\"].tolist()\n",
    "responses = data[\"Answers\"].tolist()\n",
    "q_id = data[\"Question_ID\"].tolist()\n",
    "\n",
    "data.head()"
   ] }, {
"cell_type": "code", "execution_count": 9, "id": "60e154b4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "missing values: Question_ID 0\n", "Questions 0\n", "Answers 0\n", "dtype: int64\n" ] } ], "source": [ "print('missing values:', data.isnull().sum())" ] }, {
   "cell_type": "code",
   "execution_count": 10,
   "id": "41311468",
   "metadata": {},
   "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 149 entries, 0 to 148\n", "Data columns (total 3 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Question_ID 149 non-null object\n", " 1 Questions 149 non-null object\n", " 2 Answers 149 non-null object\n", "dtypes: object(3)\n", "memory usage: 3.6+ KB\n" ] } ],
   "source": [
    "# DataFrame.info() prints its report itself and returns None, so it is not wrapped in print().\n",
    "data.info()"
   ]
  }, {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6719ffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorizer = TfidfVectorizer()\n",
    "X = vectorizer.fit_transform(questions)\n",
    "# Each answer string is used as the class label for its question.\n",
    "y = responses\n",
    "\n",
    "# Step 2: Split data into training and testing sets\n",
    "# random_state pins the shuffle so the split and the accuracy below are reproducible.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "\n",
    "# Step 3: Choose a machine learning algorithm\n",
    "model = LogisticRegression()\n",
    "\n",
    "# Step 4: Train the model\n",
    "model.fit(X_train, y_train)\n",
    "\n",
    "# Step 5: Evaluate the model\n",
    "# NOTE(review): the classifier can only predict answers seen in the training split;\n",
    "# answers that occur once and land in the test split can never be predicted, so\n",
    "# accuracy is expected to be low. Similarity-based retrieval over the questions\n",
    "# may suit this dataset better.\n",
    "y_pred = model.predict(X_test)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(\"Accuracy:\", accuracy)"
   ]
  }, {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d8d18524",
   "metadata": {},
   "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ask me anythingWho are you\n" ] } ],
   "source": [
    "# Step 6: Use the model to make predictions\n",
    "new_question = input(\"Ask me anything : \")\n"
   ]
  }, {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51d4ca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_question_vector = vectorizer.transform([new_question])\n",
    "prediction = model.predict(new_question_vector)\n",
    "# predict returns one label per input row; show the single answer as plain text.\n",
    "print(\"Prediction:\", prediction[0])"
   ]
  } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.7" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 5 }