add application files

- .gitattributes +36 -35
- .gitignore +158 -0
- LICENSE +21 -0
- README.md +83 -13
- app.py +206 -0
- database.py +132 -0
- download.py +40 -0
- functions.py +270 -0
- generate.py +80 -0
- lib.txt +15 -0
- model.py +74 -0
- train.py +136 -0
.gitattributes
CHANGED
@@ -1,35 +1,36 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/** filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,158 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+lab/
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+database.db
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Alireza
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,13 +1,83 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
+# GPT Tools
+
+![GUI](https://github.com/user-attachments/assets/6a845c99-6ecc-422f-b662-8069cb5c2324)
+
+---
+This project provides a clean, practical graphical interface for generating text, code, and interactive stories, and for evaluating models such as GPT-2 and CodeGen. It lets you manage natural-language models and make use of their output with ease.
+
+---
+
+## **🚨 Requirements**
+
+This project requires **Python 3.8.6**. Please make sure the correct Python version is installed on your system.
+To check your Python version, run the following command:
+```bash
+python --version
+```
+
+---
+
+## **💫 Main features**
+
+- Text generation: produce creative text with the various GPT-2 models.
+- Code generation: generate code from descriptive prompts with the CodeGen model.
+- Interactive stories: build custom, creative stories in collaboration with the model.
+- Model management: download models and store them in custom paths.
+- Model training: retrain models on data of your choice and save the changes.
+---
+## **📁 Project Structure**
+```bash
+.
+├── app.py        # Graphical user interface (Gradio)
+├── model.py      # Model management and loading
+├── generate.py   # Text- and code-generation logic
+├── train.py      # Model retraining
+├── database.py   # Database management for storing inputs
+├── models/       # Default directory for saved models
+└── lib.txt       # List of required libraries
+```
+---
+## **🚀 Installation and setup**
+
+### **Install Python 3.8.6**
+If Python 3.8.6 is not installed on your system, install it from the download page.
+On Linux you can use the following commands:
+```bash
+sudo apt update
+sudo apt install python3.8
+```
+---
+### **Clone the repository**
+First, clone the project repository:
+```bash
+git clone https://github.com/ali0discord/gpt-text-generator.git
+cd gpt-text-generator
+```
+---
+### **Install the libraries**
+Install the required libraries with:
+```bash
+pip install -r lib.txt
+```
+---
+### **Download the models**
+Running this script downloads the required models automatically and saves them in their dedicated directories:
+```bash
+python download.py
+```
+---
+### **Run the app**
+Start the application with:
+```bash
+python app.py
+```
+Once the app is fully running, open this address:
+```bash
+127.0.0.1:7860
+```
+---
+## **Bug reports**
+Get in touch with us through the GitHub Issues section.
+
+---
+### **Brought to life by Alireza**
app.py
ADDED
@@ -0,0 +1,206 @@
+import gradio as gr
+from database import create_db
+from functions import *
+from functions import _generate_code
+
+# Supported models
+models_options_general = ['GPT2', 'GPT2-medium', 'GPT2-large', 'GPT2-persian', 'GPT-Neo-125M']
+models_options_codegen = ['codegen']
+models_options_chatbot = ['dialoGPT', 'dialoGPT-medium', 'dialoGPT-large']
+
+# Create database
+create_db()
+
+# Interface setup
+with gr.Blocks() as interface:
+    gr.Markdown(
+        "# **GPT Tools**\n\n"
+        "Generate something using GPT models. Select the model and adjust the parameters for optimal results."
+    )
+    with gr.Tabs():
+        with gr.Tab("Text Generator"):
+            with gr.Row():
+                with gr.Column(scale=1, min_width=350):
+                    input_text = gr.Textbox(label="Input Text", placeholder="Enter your text here...", lines=4, max_lines=6)
+                    selected_model = gr.Radio(choices=models_options_general, value="GPT2", label="Select Model", type="value")
+                    with gr.Row():
+                        max_tokens = gr.Slider(10, 100, value=50, step=1, label="Max New Tokens", interactive=True)
+                with gr.Column(scale=1, min_width=350):
+                    output_text = gr.Textbox(label="Generated Text", interactive=False, lines=8, max_lines=12)
+                    generate_button = gr.Button("Generate Text", variant="primary")
+
+            generate_button.click(
+                generate,
+                inputs=[input_text, selected_model, max_tokens],
+                outputs=output_text,
+            )
+
+        with gr.Tab("Multiverse Story Generator"):
+            with gr.Row():
+                with gr.Column(scale=1, min_width=350):
+                    input_text = gr.Textbox(label="Enter your story idea", placeholder="e.g. A scientist discovers a parallel universe...", lines=4, max_lines=6)
+                    selected_model = gr.Radio(choices=models_options_general, value="GPT2", label="Select Model for Story Generation", type="value")
+                    max_length = gr.Slider(50, 300, value=150, step=1, label="Max Length", interactive=True)
+                with gr.Column(scale=1, min_width=350):
+                    output_text = gr.Textbox(label="Generated Worlds", interactive=False, lines=12, max_lines=20)
+                    generate_button = gr.Button("Generate Parallel Worlds", variant="primary")
+
+            generate_button.click(
+                generate_multiverse,
+                inputs=[input_text, selected_model, max_length],
+                outputs=output_text,
+            )
+
+        with gr.Tab("Interactive Story Writing"):
+            with gr.Row():
+                with gr.Column(scale=1, min_width=350):
+                    story_input = gr.Textbox(label="Add to Story", placeholder="Enter your part of the story...", lines=4, max_lines=6)
+                    story_model = gr.Radio(choices=models_options_general, value="GPT2", label="Select Model", type="value")
+                    story_max_length = gr.Slider(50, 300, value=50, step=1, label="Max Length", interactive=True)
+                with gr.Column(scale=1, min_width=350):
+                    story_text = gr.Textbox(label="Story So Far", interactive=False, lines=12, max_lines=20)
+                    story_button = gr.Button("Generate Next Part", variant="primary")
+                    reset_button = gr.Button("Reset Story", variant="secondary")
+
+            story_button.click(
+                interactive_story,
+                inputs=[story_input, story_model, story_max_length],
+                outputs=story_text,
+            )
+            reset_button.click(
+                reset_story,
+                inputs=[],
+                outputs=story_text,
+            )
+
+        with gr.Tab("Training"):
+            gr.Markdown("# **Train Model**\n\n")
+            with gr.Column(scale=1, min_width=250):
+                train_model_selector = gr.Radio(choices=models_options_general, value="GPT2", label="Select Model for Training", type="value")
+                train_method = gr.Radio(
+                    choices=["Custom Text", "Database", "Dataset File", "Hugging Face Dataset"],
+                    value="Custom Text",
+                    label="Training Method",
+                    type="value"
+                )
+                dataset_name = gr.Textbox(label="Hugging Face Dataset Name", placeholder="Enter dataset name (e.g., ag_news)")
+                split_name = gr.Textbox(label="Dataset Split", placeholder="e.g., train, test, validation")
+                epochs = gr.Slider(1, 100, value=10, step=1, label="Epochs", interactive=True)
+                batch_size = gr.Slider(1, 100, value=8, step=1, label="Batch Size", interactive=True)
+                password = gr.Textbox(label="Enter Training Password", placeholder="Enter password", type="password")
+                custom_text = gr.Textbox(label="Custom Text (optional)", placeholder="Enter custom text for training...")
+                dataset_file = gr.File(label="Upload Dataset", type="filepath", file_types=[".parquet", ".csv", ".json", ".txt"])
+                train_button = gr.Button("Train Model", variant="primary")
+                train_status = gr.Textbox(label="Training Status", interactive=False)
+
+            # Start training when clicked (registered once only)
+            train_button.click(
+                verify_and_train_combined,
+                inputs=[train_model_selector, train_method, epochs, batch_size, password, custom_text, dataset_file, dataset_name, split_name],
+                outputs=train_status,
+            )
+
+        with gr.Tab("Code Generator"):
+            gr.Markdown("### Generate Code from Descriptions")
+            with gr.Row():
+                with gr.Column(scale=1, min_width=350):
+                    code_prompt = gr.Textbox(label="Code Prompt", placeholder="Describe your coding task, e.g., 'Write a Python function to calculate Fibonacci numbers.'")
+                    code_max_tokens = gr.Slider(10, 500, value=150, step=10, label="Max Tokens")
+                with gr.Column(scale=1, min_width=350):
+                    generated_code = gr.Textbox(label="Generated Code", interactive=False, lines=10, max_lines=20)
+                    generate_code_button = gr.Button("Generate Code")
+
+            generate_code_button.click(
+                _generate_code,
+                inputs=[code_prompt, code_max_tokens],
+                outputs=generated_code,
+            )
+
+        # AI-Powered Story World Builder tab
+        with gr.Tab("Story World Builder"):
+            with gr.Row():
+                with gr.Column(scale=1, min_width=350):
+                    world_name = gr.Textbox(label="World Name", placeholder="Enter your world name...")
+                    locations = gr.Textbox(label="Locations", placeholder="Enter locations separated by commas...")
+                    characters = gr.Textbox(label="Characters", placeholder="Enter characters separated by commas...")
+                    create_button = gr.Button("Create World", variant='primary')
+                    generate_story_button = gr.Button("Generate Story")
+                with gr.Column(scale=1, min_width=350):
+                    world_status = gr.Textbox(label="World Status", interactive=False)
+                    generated_story = gr.Textbox(label="Generated Story", interactive=False, lines=12, max_lines=20)
+
+            create_button.click(
+                define_world,
+                inputs=[world_name, locations, characters],
+                outputs=world_status,
+            )
+
+            gr.Markdown("### Generate a Story in Your World")
+            with gr.Row():
+                with gr.Column(scale=1, min_width=350):
+                    story_world = gr.Textbox(label="Enter World Name", placeholder="World name...")
+                    event = gr.Textbox(label="Event", placeholder="Describe an event in the world...")
+                    selected_model = gr.Radio(choices=models_options_general, value="GPT2", label="Select Model", type="value")
+                    max_length = gr.Slider(50, 300, value=150, step=1, label="Max Length")
+
+            # Input order matches generate_story(model, world_name, event, max_length)
+            generate_story_button.click(
+                generate_story,
+                inputs=[selected_model, story_world, event, max_length],
+                outputs=generated_story,
+            )
+
+        with gr.Tab("Chatbot"):
+            gr.Markdown("### **Chat With AI Models**")
+            with gr.Row():
+                with gr.Column(scale=1, min_width=250):
+                    username = gr.Textbox(label="Username", placeholder="Enter your username", lines=1)
+                    chat_id = gr.Textbox(label="Chat ID (optional)", placeholder="Enter chat ID or leave blank for a new chat", lines=1)
+                    selected_model = gr.Radio(models_options_chatbot, label="Select Model", value="dialoGPT")
+                    send_button = gr.Button("Send", variant="primary")
+                    reset_button = gr.Button("Reset Chat", variant="secondary")
+                with gr.Column(scale=1, min_width=250):
+                    input_text = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2)
+                    emotion_output = gr.Textbox(label="Detected Emotion", interactive=False)
+                    chat_output = gr.Textbox(label="Chat History", lines=10, interactive=False)
+
+            send_button.click(
+                chatbot_response_with_emotion,
+                inputs=[username, input_text, selected_model, chat_id],
+                outputs=[chat_output, chat_id, emotion_output]
+            )
+
+            reset_button.click(
+                reset_chat,
+                inputs=[username],
+                outputs=[chat_output]
+            )
+            gr.Markdown("---")
+            gr.Markdown("### **Fetch Chat IDs**")
+            with gr.Row():
+                with gr.Column(scale=1, min_width=250):
+                    username = gr.Textbox(label="Username", placeholder="Enter your username", lines=1)
+                    fetch_btn = gr.Button("Fetch", variant="primary")
+                with gr.Column(scale=1, min_width=250):
+                    fetch_output = gr.Textbox(label="Chat IDs", lines=3, interactive=False)
+            fetch_btn.click(
+                chat_ids,
+                inputs=[username],
+                outputs=[fetch_output],
+            )
+
+    gr.Markdown("Made by **AliMc2021** with ❤️")
+
+# Launch the interface
+interface.queue().launch(
+    server_port=7860,
+    show_error=True,
+    inline=False,
+    #share=True,
+)
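The UI above is a thin layer over the back-end modules that follow; for a quick smoke test, the same entry points can be driven directly without launching Gradio. A minimal sketch, assuming the models have already been fetched with `download.py` (the prompts are placeholder values):

```python
# Hypothetical smoke test: exercises the same functions the UI wires up.
from database import create_db
from functions import generate, _generate_code

create_db()  # ensure the SQLite tables exist, as app.py does at startup

# Mirrors the "Text Generator" tab: generate(input_text, selected_model, max_new_token)
print(generate("Once upon a time", "GPT2", 50))

# Mirrors the "Code Generator" tab: _generate_code(code_prompt, max_tokens)
print(_generate_code("Write a Python function to reverse a string.", 100))
```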
database.py
ADDED
@@ -0,0 +1,132 @@
+import sqlite3
+
+# Database path
+DATABASE_PATH = 'database.db'
+
+# Create the database tables if they do not already exist
+def create_db():
+    conn = sqlite3.connect(DATABASE_PATH)
+    c = conn.cursor()
+    c.execute("""
+        CREATE TABLE IF NOT EXISTS inputs (
+            id INTEGER PRIMARY KEY,
+            input_text TEXT,
+            selected_model TEXT
+        )
+    """)
+    c.execute("""
+        CREATE TABLE IF NOT EXISTS chats (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            chat_id TEXT NOT NULL,
+            username TEXT NOT NULL,
+            user_message TEXT NOT NULL,
+            ai_response TEXT NOT NULL
+        )
+    """)
+    conn.commit()
+    conn.close()
+
+# Insert a chat turn into the chats table
+def insert_chat(chat_id, username, user_message, ai_response):
+    conn = sqlite3.connect(DATABASE_PATH)
+    try:
+        cursor = conn.cursor()
+        cursor.execute("""
+            INSERT INTO chats (chat_id, username, user_message, ai_response)
+            VALUES (?, ?, ?, ?)
+        """, (str(chat_id), str(username), str(user_message), str(ai_response)))
+        conn.commit()
+    except sqlite3.Error as e:
+        print(f"Error inserting chat: {e}")
+    finally:
+        conn.close()
+
+# Insert a record into the inputs table
+def insert_into_db(input_text, selected_model):
+    conn = sqlite3.connect(DATABASE_PATH)
+    try:
+        c = conn.cursor()
+        c.execute("""
+            INSERT INTO inputs (input_text, selected_model)
+            VALUES (?, ?)
+        """, (str(input_text), str(selected_model)))
+        conn.commit()
+    except sqlite3.Error as e:
+        print(f"Error inserting into inputs: {e}")
+    finally:
+        conn.close()
+
+# Clear the inputs table
+def clear_database():
+    conn = sqlite3.connect(DATABASE_PATH)
+    try:
+        c = conn.cursor()
+        c.execute("DELETE FROM inputs")
+        conn.commit()
+    except sqlite3.Error as e:
+        print(f"Error clearing database: {e}")
+    finally:
+        conn.close()
+
+# Fetch all entries from the inputs table
+def fetch_all_inputs():
+    conn = sqlite3.connect(DATABASE_PATH)
+    try:
+        c = conn.cursor()
+        c.execute("SELECT input_text, selected_model FROM inputs")
+        results = c.fetchall()
+        return results
+    except sqlite3.Error as e:
+        print(f"Error fetching inputs from database: {e}")
+        return []
+    finally:
+        conn.close()
+
+# Fetch the messages and responses for a given chat_id
+def fetch_chats_by_id(chat_id):
+    conn = sqlite3.connect(DATABASE_PATH)
+    try:
+        cursor = conn.cursor()
+        cursor.execute("""
+            SELECT user_message, ai_response FROM chats
+            WHERE chat_id = ?
+        """, (str(chat_id),))
+        rows = cursor.fetchall()
+        return rows
+    except sqlite3.Error as e:
+        print(f"Error fetching chats by ID: {e}")
+        return []
+    finally:
+        conn.close()
+
+# Fetch the chat IDs belonging to a given user
+def fetch_ids_by_user(username):
+    conn = sqlite3.connect(DATABASE_PATH)
+    try:
+        cursor = conn.cursor()
+        cursor.execute("""
+            SELECT chat_id FROM chats
+            WHERE username = ?
+        """, (str(username),))
+        rows = cursor.fetchall()
+        return rows
+    except sqlite3.Error as e:
+        print(f"Error fetching chat IDs by username: {e}")
+        return []
+    finally:
+        conn.close()
+
+# Delete the chats belonging to a given user
+def clear_chats_by_username(username):
+    conn = sqlite3.connect(DATABASE_PATH)
+    try:
+        cursor = conn.cursor()
+        cursor.execute("""
+            DELETE FROM chats
+            WHERE username = ?
+        """, (str(username),))
+        conn.commit()
+    except sqlite3.Error as e:
+        print(f"Error clearing chats by username: {e}")
+    finally:
+        conn.close()
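A short round trip through these helpers, sketching how the chatbot tab uses them (the chat ID and messages are placeholder values):

```python
from database import create_db, insert_chat, fetch_chats_by_id, fetch_ids_by_user

create_db()
insert_chat("demo-chat-1", "alice", "Hello!", "Hi there!")

# fetch_chats_by_id returns (user_message, ai_response) tuples
for msg, resp in fetch_chats_by_id("demo-chat-1"):
    print(f"User: {msg}\nAI: {resp}")

# fetch_ids_by_user returns one single-element tuple per stored row
print(fetch_ids_by_user("alice"))
```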
download.py
ADDED
@@ -0,0 +1,40 @@
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# List of models with their designated save directories
+MODEL_LIST = {
+    "gpt2": {"path": "openai-community/gpt2", "save_dir": "./models/gpt2"},
+    "gpt2-medium": {"path": "openai-community/gpt2-medium", "save_dir": "./models/gpt2-medium"},
+    "gpt2-persian": {"path": "flax-community/gpt2-medium-persian", "save_dir": "./models/gpt2-medium-persian"},
+    "gpt2-large": {"path": "openai-community/gpt2-large", "save_dir": "./models/gpt2-large"},
+    "codegen": {"path": "Salesforce/codegen-350M-mono", "save_dir": "./models/codegen"},
+    "dialogpt": {"path": "microsoft/DialoGPT-small", "save_dir": "./models/dialogpt"},
+    "dialogpt-medium": {"path": "microsoft/DialoGPT-medium", "save_dir": "./models/dialogpt-medium"},
+    "dialogpt-large": {"path": "microsoft/DialoGPT-large", "save_dir": "./models/dialogpt-large"}
+}
+
+def download_and_save_models():
+    """
+    Download all models and save them to their designated directories.
+    """
+    for model_name, model_info in MODEL_LIST.items():
+        model_path = model_info["path"]    # model repo on Hugging Face
+        save_dir = model_info["save_dir"]  # local save directory
+
+        print(f"Downloading and saving model: {model_name} to folder: {save_dir}")
+
+        if not os.path.exists(save_dir):  # only download if the save directory does not exist yet
+            os.makedirs(save_dir, exist_ok=True)
+
+            # Download and save the model and tokenizer
+            model = AutoModelForCausalLM.from_pretrained(model_path)
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
+            model.save_pretrained(save_dir)
+            tokenizer.save_pretrained(save_dir)
+
+            print(f"Model {model_name} saved to {save_dir}")
+        else:
+            print(f"Model {model_name} already exists in {save_dir}")
+
+if __name__ == "__main__":
+    download_and_save_models()
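Each saved snapshot is an ordinary `transformers` checkpoint, so it can be reloaded straight from its `save_dir` without touching the Hub again; a minimal sketch:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the locally saved copy produced by download.py
tokenizer = AutoTokenizer.from_pretrained("./models/gpt2")
model = AutoModelForCausalLM.from_pretrained("./models/gpt2")
```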
functions.py
ADDED
@@ -0,0 +1,270 @@
+import torch
+from model import load_model_lazy, unload_model
+from generate import generate_code, generate_text
+from database import *
+import train
+import uuid
+
+train_pass = '6818'
+
+# AI-Powered Story World Builder functions
+world_data = {}
+
+def _generate_code(code_prompt, max_tokens, selected_model='codegen'):
+    """
+    Generate code based on the code prompt and selected model.
+    """
+    # Load the model lazily
+    model_data = load_model_lazy(selected_model)
+
+    # Generate code
+    generated_code = generate_code(model_data, code_prompt, max_tokens)
+
+    # Unload the model after use
+    unload_model(selected_model)
+
+    return generated_code
+
+def generate(input_text, selected_model, max_new_token):
+    """
+    Generate text based on the selected model and input text.
+    """
+    # Load the model lazily
+    model_data = load_model_lazy(selected_model)
+
+    # Generate text
+    generated_text = generate_text(model_data, input_text, max_new_token)
+    insert_into_db(input_text, selected_model)
+
+    # Unload the model after use
+    unload_model(selected_model)
+
+    return generated_text
+
+def define_world(world_name, locations, characters):
+    """
+    Define a new story world with locations and characters.
+    """
+    world_data["world_name"] = world_name
+    world_data["locations"] = locations.split(", ")
+    world_data["characters"] = characters.split(", ")
+    return f"World '{world_name}' created with locations: {locations} and characters: {characters}"
+
+def generate_story(model, world_name, event, max_length):
+    """
+    Generate a story based on the defined world and an event.
+    """
+    if not world_name or not world_data.get("world_name"):
+        return "Error: Please define a world first."
+
+    if world_name != world_data["world_name"]:
+        return f"Error: World '{world_name}' not found. Define it first."
+
+    prompt = f"In the world of {world_name}, {event}. Locations: {', '.join(world_data['locations'])}. Characters: {', '.join(world_data['characters'])}."
+
+    generated_story = generate(prompt, model, max_length)
+    return generated_story
+
+
+# Story Mode
+story = []
+
+# Main function for story generation
+def interactive_story(input_text, selected_model, max_length):
+    global story
+    if input_text.strip():
+        story.append(input_text)  # Add user input to story
+    current_text = " ".join(story)  # Build cumulative story
+
+    generated_text = generate(current_text, selected_model, max_length)
+    story.append(generated_text)  # Add generated text to story
+
+    return current_text + "\n\n" + generated_text
+
+
+def reset_story():
+    global story
+    story = []  # Reset story
+    return ""
+
+def generate_multiverse(input_text, selected_model, max_new_tokens, num_worlds=3):
+    """
+    Generate multiple parallel worlds from a single input text.
+    """
+    worlds = []
+
+    for i in range(num_worlds):
+        world_intro = f"World {i + 1}: "
+        # Custom logic for different parallel worlds
+        if i == 0:
+            world_intro += f"{input_text} This world leads to a parallel universe!"
+        elif i == 1:
+            world_intro += f"{input_text} In this world, time splits into different periods!"
+        elif i == 2:
+            world_intro += f"{input_text} This world faces a strange physical anomaly that changes everything!"
+
+        # Generate the story for this world
+        generated_text = generate(world_intro, selected_model, max_new_tokens)
+
+        worlds.append(generated_text)
+
+    return "\n\n".join(worlds)
+
+
+# Verify the password, train the model, and clear the database
+def verify_and_train_combined(selected_model, train_method, epochs, batch_size, password, custom_text, dataset_file, dataset_name, split_name):
+    if password != train_pass:
+        return "Error: Incorrect password. Training not started."
+
+    if train_method == "Custom Text" and custom_text.strip():
+        train.train_model_with_text(selected_model, custom_text, epochs, batch_size)
+        return f"Training completed for model: {selected_model} using custom text."
+
+    elif train_method == "Database":
+        train.train_model_with_database(selected_model, epochs, batch_size)
+        clear_database()
+        return f"Training completed for model: {selected_model} using database. Database cleared."
+
+    elif train_method == "Dataset File" and dataset_file is not None:
+        try:
+            dataset_path = dataset_file.name
+            train.train_model_with_dataset(selected_model, epochs, batch_size, dataset_path)
+            return f"Training completed for model: {selected_model} using uploaded dataset."
+        except Exception as e:
+            return f"Error during training with dataset: {str(e)}"
+
+    elif train_method == "Hugging Face Dataset" and dataset_name.strip():
+        try:
+            train.train_model_with_hf_dataset(selected_model, epochs, batch_size, dataset_name, split=split_name.strip())
+            return f"Training completed for model: {selected_model} using Hugging Face dataset {dataset_name}."
+        except Exception as e:
+            return f"Error during training with Hugging Face dataset: {str(e)}"
+
+    else:
+        return "Error: Invalid input for training. Please check your selections."
+
+def limit_chat_history(chat_history, max_turns=3):
+    """
+    Limit the chat history to the most recent max_turns exchanges.
+    """
+    turns = chat_history.split("\n")
+    if len(turns) > max_turns * 2:  # each question/answer pair takes two lines
+        turns = turns[-max_turns * 2:]  # keep only the most recent turns
+    return "\n".join(turns)
+
+def chatbot_response(username, input_text, selected_model, chat_id=None):
+    if not username.strip():
+        return "Error: Please enter a username.", "", str(uuid.uuid4())  # generate a new ID
+
+    # If no chat ID was provided, generate a new one
+    if not chat_id or chat_id.strip() == "":
+        chat_id = str(uuid.uuid4())
+
+    # Load model lazily
+    model_data = load_model_lazy(selected_model)
+
+    # Retrieve previous chats from database
+    previous_chats = fetch_chats_by_id(chat_id)
+    chat_history = "\n".join([f"User: {msg}\nAI: {resp}" for msg, resp in previous_chats])
+
+    # Limit the chat history
+    if chat_history:
+        chat_history = limit_chat_history(chat_history, max_turns=3)
+        prompt = f"{chat_history}\nUser: {input_text}\nAI:"
+    else:
+        prompt = f"User: {input_text}\nAI:"
+
+    # Generate response
+    max_new_token = 150  # number of new tokens
+    full_response = generate_text(model_data, prompt, max_new_token)
+
+    # Extract only the new AI response
+    ai_response = full_response.split("AI:")[-1].strip()
+
+    unload_model(selected_model)
+
+    # Save chat to database
+    insert_chat(chat_id, username, input_text, ai_response)
+
+    # Return updated chat history and chat_id
+    updated_history = chat_history + f"\nUser: {input_text}\nAI: {ai_response}"
+    return updated_history, chat_id
+
+def chat_ids(username):
+    return fetch_ids_by_user(username)
+
+def reset_chat(username):
+    clear_chats_by_username(username)  # delete the chats belonging to this user
+    return f"Chat history cleared for user: {username}", ""
+
+# Emotion-analysis helpers
+def analyze_emotion(user_input):
+    # Load the emotion model
+    model_data = load_model_lazy("bert-emotion")
+
+    # If the model is wrapped in a pipeline
+    if "pipeline" in model_data:
+        emotion_pipeline = model_data["pipeline"]
+        result = emotion_pipeline(user_input)
+        emotion = result[0]['label']
+        confidence = result[0]['score']
+    else:
+        # Fallback for models loaded without a pipeline
+        emotion_tokenizer = model_data['tokenizer']
+        emotion_model = model_data['model']
+        inputs = emotion_tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
+        outputs = emotion_model(**inputs)
+        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+        emotion = probs.argmax().item()
+        confidence = probs.max().item()
+
+    unload_model("bert-emotion")
+    return emotion, confidence
+
+def emotion_label(index):
+    emotions = ["anger", "joy", "sadness", "fear", "love", "surprise"]
+    return emotions[index]
+
+def chatbot_response_with_emotion(username, input_text, selected_model, chat_id=None):
+    if not username.strip():
+        return "Error: Please enter a username.", "", str(uuid.uuid4())
+
+    if not chat_id or chat_id.strip() == "":
+        chat_id = str(uuid.uuid4())
+
+    # Load the chat model
+    model_data = load_model_lazy(selected_model)
+
+    # Analyze the emotion of the user's message
+    emotion, confidence = analyze_emotion(input_text)
+    user_emotion = emotion  # emotion label
+
+    # Retrieve previous chats from the database
+    previous_chats = fetch_chats_by_id(chat_id)
+    chat_history = "\n".join([f"User: {msg}\nAI: {resp}" for msg, resp in previous_chats])
+
+    # Limit the chat history
+    if chat_history:
+        chat_history = limit_chat_history(chat_history, max_turns=3)
+        prompt = f"[Emotion: {user_emotion}]\n{chat_history}\nUser: {input_text}\nAI:"
+    else:
+        prompt = f"[Emotion: {user_emotion}]\nUser: {input_text}\nAI:"
+
+    # Generate the response
+    max_new_token = 150
+    full_response = generate_text(model_data, prompt, max_new_token)
+
+    # Extract the AI response
+    ai_response = full_response.split("AI:")[-1].strip()
+
+    # Unload the models
+    unload_model(selected_model)
+    unload_model("bert-emotion")
+
+    # Save the chat to the database
+    insert_chat(chat_id, username, input_text, ai_response)
+
+    # Return the updated history and the chat ID
+    updated_history = chat_history + f"\nUser: {input_text}\nAI: {ai_response}"
+    return updated_history, chat_id, user_emotion
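The history-trimming logic is easiest to see in isolation: `limit_chat_history` keeps only the last `max_turns` question/answer pairs, counting two lines per turn. A small sketch with placeholder messages:

```python
from functions import limit_chat_history

# Five turns of two lines each -> ten lines of history
history = "\n".join(f"User: question {i}\nAI: answer {i}" for i in range(5))

# Only the six lines of the three most recent turns survive
print(limit_chat_history(history, max_turns=3))
```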
generate.py
ADDED
@@ -0,0 +1,80 @@
+import torch
+
+seed = 0
+
+def generate_text(model_data, input_text, max_new_token):
+    """
+    Generate text using the given model and tokenizer.
+    """
+    if "pipeline" in model_data:
+        # The model is wrapped in a pipeline
+        model_pipeline = model_data["pipeline"]
+        generated_text = model_pipeline(
+            input_text,
+            max_length=max_new_token,
+            do_sample=False,  # disable sampling (greedy decoding)
+            truncation=True   # enable truncation
+        )[0]["generated_text"]
+        return generated_text
+    else:
+        # Fallback for models loaded without a pipeline
+        model = model_data["model"]
+        tokenizer = model_data["tokenizer"]
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+        encodings = tokenizer(
+            input_text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,  # enable truncation
+            max_length=512
+        )
+        input_ids = encodings.input_ids
+        attention_mask = encodings.attention_mask
+
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_new_token,
+            do_sample=False,  # disable sampling (greedy decoding)
+            pad_token_id=tokenizer.eos_token_id,
+            repetition_penalty=1.2,
+            no_repeat_ngram_size=3,
+        )
+
+        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+def generate_code(model_data, prompt, max_new_tokens):
+    """
+    Generate code based on the provided prompt using a code-specific model.
+    """
+    model = model_data["model"]
+    tokenizer = model_data["tokenizer"]
+
+    # Fix the seed for reproducible output
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    # Tokenize the input
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+    # Build an attention mask for the inputs
+    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
+
+    # Generate code
+    outputs = model.generate(
+        input_ids=input_ids,
+        attention_mask=attention_mask,  # pass the attention mask
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id,  # use the EOS token as padding
+        repetition_penalty=1.2,   # discourage repetition
+        no_repeat_ngram_size=3,   # block repeated n-grams
+    )
+
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
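Because the module fixes the random seed and disables sampling (`do_sample=False`), repeated calls with the same prompt are deterministic. A sketch of calling `generate_text` directly on a model loaded through the lazy loader:

```python
from model import load_model_lazy, unload_model
from generate import generate_text

model_data = load_model_lazy("GPT2")  # returns {"model": ..., "tokenizer": ...}
out1 = generate_text(model_data, "The sea was calm and", 30)
out2 = generate_text(model_data, "The sea was calm and", 30)
assert out1 == out2  # greedy decoding with a fixed seed repeats itself
unload_model("GPT2")
```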
lib.txt
ADDED
@@ -0,0 +1,15 @@
+transformers
+torch
+gradio
+datasets
+numpy
+Pillow
+scikit-learn
+wandb
+pyarrow
+pandas
+chardet
+accelerate
+safetensors
+diffusers
+jax
model.py
ADDED
@@ -0,0 +1,74 @@
+import torch
+import gc
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer, pipeline, AutoModelForSequenceClassification
+
+# Dictionary of models and paths
+model_dict = {
+    "GPT2": {"path": "./models/gpt2", "library": GPT2LMHeadModel, "tokenizer": GPT2Tokenizer, "use_pipeline": False},
+    "GPT2-medium": {"path": "./models/gpt2-medium", "library": GPT2LMHeadModel, "tokenizer": GPT2Tokenizer, "use_pipeline": False},
+    "GPT2-large": {"path": "./models/gpt2-large", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
+    "GPT2-persian": {"path": "./models/gpt2-medium-persian", "library": GPT2LMHeadModel, "tokenizer": AutoTokenizer, "use_pipeline": False},
+    "codegen": {"path": "./models/codegen", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
+    "dialoGPT": {"path": "./models/dialogpt", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
+    "dialoGPT-medium": {"path": "./models/dialogpt-medium", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
+    "dialoGPT-large": {"path": "./models/dialogpt-large", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
+    "GPT-Neo-125M": {"path": "./models/GPT-neo-125M", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": True},  # newly added model
+    "bert-emotion": {"path": "./models/bert-emotion", "library": AutoModelForSequenceClassification, "tokenizer": AutoTokenizer, "use_pipeline": True},
+}
+
+loaded_models = {}
+
+def load_model_lazy(model_name):
+    if not isinstance(model_name, str):
+        raise ValueError(f"Model name must be a string, not {type(model_name)}")
+    if model_name not in model_dict:
+        raise ValueError(f"Model {model_name} not found!")
+
+    model_info = model_dict[model_name]
+    print(f"Loading model: {model_name}")
+
+    # If the model is served through a pipeline
+    if model_info.get("use_pipeline", False):
+        print(f"Using pipeline for model: {model_name}")
+        if model_name == "bert-emotion":
+            # bert-emotion is a classifier, so use the text-classification task
+            model_pipeline = pipeline(
+                "text-classification",
+                model=model_info["path"],
+                truncation=True
+            )
+        else:
+            # All other pipeline models are text generators
+            model_pipeline = pipeline(
+                "text-generation",
+                model=model_info["path"],
+                truncation=True,
+                pad_token_id=50256
+            )
+        loaded_models[model_name] = {"pipeline": model_pipeline}
+        return {"pipeline": model_pipeline}
+
+    # Otherwise, load the model and tokenizer directly
+    model = model_info["library"].from_pretrained(model_info["path"])
+    tokenizer = model_info["tokenizer"].from_pretrained(model_info["path"])
+
+    # Default settings
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    loaded_models[model_name] = {"model": model, "tokenizer": tokenizer}
+    return {"model": model, "tokenizer": tokenizer}
+
+def unload_model(model_name):
+    global loaded_models
+    if model_name in loaded_models:
+        if "pipeline" in loaded_models[model_name]:
+            del loaded_models[model_name]["pipeline"]
+        elif "model" in loaded_models[model_name]:
+            del loaded_models[model_name]["model"]
+            del loaded_models[model_name]["tokenizer"]
+        del loaded_models[model_name]  # drop the now-empty registry entry
+        torch.cuda.empty_cache()
+        gc.collect()
+        print(f"Model {model_name} unloaded and memory cleared.")
+    else:
+        print(f"Model {model_name} was not loaded.")
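Registering another checkpoint is a matter of extending `model_dict`; the entry below is purely illustrative (the name and local path are assumptions, not part of this commit), and the directory must already contain a saved model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from model import model_dict

# Hypothetical extra entry; point "path" at a directory created by download.py
model_dict["GPT2-xl"] = {
    "path": "./models/gpt2-xl",
    "library": AutoModelForCausalLM,
    "tokenizer": AutoTokenizer,
    "use_pipeline": False,
}
```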
train.py
ADDED
@@ -0,0 +1,136 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import AdamW  # transformers' AdamW is deprecated in recent versions
+from model import load_model_lazy, unload_model
+from database import fetch_all_inputs, clear_database  # database helpers
+from datasets import load_dataset
+
+class TextDataset(Dataset):
+    def __init__(self, texts, tokenizer, max_length=512):
+        self.texts = texts
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        text = self.texts[idx]
+        encodings = self.tokenizer(
+            text,
+            truncation=True,
+            padding="max_length",  # pad every sample to the same length
+            max_length=self.max_length,
+            return_tensors="pt"
+        )
+        attention_mask = encodings.attention_mask.squeeze(0)
+        return encodings.input_ids.squeeze(0), attention_mask
+
+def train_model_with_text(selected_model, custom_text, epochs, batch_size):
+    """
+    Train the model on a custom text.
+    """
+    model_data = load_model_lazy(selected_model)  # load_model_lazy returns a dict
+    model, tokenizer = model_data["model"], model_data["tokenizer"]
+    dataset = TextDataset([custom_text], tokenizer)
+    dataloader = DataLoader(dataset, batch_size=min(batch_size, len(dataset)), shuffle=True)
+
+    _train_model(model, tokenizer, dataloader, epochs, selected_model, "custom_text")
+    unload_model(selected_model)
+
+def train_model_with_database(selected_model, epochs, batch_size):
+    """
+    Train the model on the data stored in the database.
+    """
+    model_data = load_model_lazy(selected_model)
+    model, tokenizer = model_data["model"], model_data["tokenizer"]
+    inputs_data = fetch_all_inputs()
+    texts = [input_text for input_text, model_name in inputs_data if model_name == selected_model]
+
+    if not texts:
+        print("Error: No data found in the database for the selected model.")
+        return
+
+    dataset = TextDataset(texts, tokenizer)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    _train_model(model, tokenizer, dataloader, epochs, selected_model, "database")
+    clear_database()
+    unload_model(selected_model)
+
+def train_model_with_dataset(selected_model, epochs, batch_size, dataset_path):
+    """
+    Train the model on an uploaded dataset file.
+    """
+    model_data = load_model_lazy(selected_model)
+    model, tokenizer = model_data["model"], model_data["tokenizer"]
+
+    # Read the dataset
+    with open(dataset_path, "r", encoding="utf-8") as f:
+        texts = f.readlines()
+
+    if not texts:
+        print("Error: Dataset is empty.")
+        return
+
+    dataset = TextDataset(texts, tokenizer)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    _train_model(model, tokenizer, dataloader, epochs, selected_model, "dataset")
+    unload_model(selected_model)
+
+def _train_model(model, tokenizer, dataloader, epochs, model_name, method):
+    """
+    Shared training loop.
+    """
+    optimizer = AdamW(model.parameters(), lr=5e-5)
+
+    # Move the model to the GPU if one is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    model.train()
+    for epoch in range(epochs):
+        total_loss = 0
+        for step, (input_ids, attention_mask) in enumerate(dataloader):
+            optimizer.zero_grad()
+            input_ids = input_ids.to(device)
+            attention_mask = attention_mask.to(device)
+
+            # Compute the outputs and the loss
+            outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+            total_loss += loss.item()
+
+        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}")
+
+    # Save the model
+    save_path = f"trained_{model_name}_{method}"
+    model.save_pretrained(save_path)
+    tokenizer.save_pretrained(save_path)
+    print(f"Model {model_name} trained with {method} and saved to {save_path}.")
+
+def train_model_with_hf_dataset(selected_model, epochs, batch_size, dataset_name, split="train"):
+    """
+    Train the model on a Hugging Face dataset.
+
+    Args:
+        selected_model (str): name of the model to train.
+        epochs (int): number of epochs.
+        batch_size (int): batch size.
+        dataset_name (str): name of the dataset on Hugging Face.
+        split (str): which split of the dataset to load (train, test, validation).
+    """
+    model_data = load_model_lazy(selected_model)
+    model, tokenizer = model_data["model"], model_data["tokenizer"]
+
+    # Load the data from Hugging Face; assumes the dataset has a "text" column
+    hf_dataset = load_dataset(dataset_name, split=split)
+    texts = hf_dataset["text"]
+
+    if not texts:
+        print(f"Error: Dataset {dataset_name} ({split} split) is empty or invalid.")
+        return
+
+    dataset = TextDataset(texts, tokenizer)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    _train_model(model, tokenizer, dataloader, epochs, selected_model, f"huggingface_{dataset_name}")
+    unload_model(selected_model)
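End to end, a retraining run reduces to one call; the text and hyperparameters below are placeholders (training on a single short string is only a smoke test), and the result lands in `trained_GPT2_custom_text`:

```python
import train

# Fine-tune GPT2 on one custom string; saved to ./trained_GPT2_custom_text
train.train_model_with_text(
    selected_model="GPT2",
    custom_text="A tiny corpus used only to check that the training loop runs.",
    epochs=1,
    batch_size=1,
)
```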