Maslov-Artem committed
Commit: 3a905e4
Parent(s): b90441b

Add styles

- .gitattributes +4 -0
- app.py +109 -21
- main_background.png +3 -0
- model/funcs.py +20 -1
- model/model.py +16 -57
- pages/review_predictor.py +66 -7
- pages/text_generator.py +78 -11
- space_background.jpeg +3 -0
- space_main_background.avif +0 -0
- space_main_background.jpeg +3 -0
- text_generation.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+space_main_background.jpeg filter=lfs diff=lfs merge=lfs -text
+main_background.png filter=lfs diff=lfs merge=lfs -text
+text_generation.png filter=lfs diff=lfs merge=lfs -text
+space_background.jpeg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,5 +1,79 @@
+import base64
+
 import streamlit as st
 
+
+def get_base64(file_path):
+    with open(file_path, "rb") as file:
+        base64_bytes = base64.b64encode(file.read())
+        base64_string = base64_bytes.decode("utf-8")
+        return base64_string
+
+
+def set_background(png_file):
+    bin_str = get_base64(png_file)
+    page_bg_img = (
+        """
+    <style>
+    .stApp {
+    background-image: url("data:image/png;base64,%s");
+    background-size: cover;
+    }
+    </style>
+    """
+        % bin_str
+    )
+    st.markdown(page_bg_img, unsafe_allow_html=True)
+
+
+set_background("space_background.jpeg")
+
+# About section
+about = """
+<div class="text-shadow">
+    <h1>About</h1>
+    <p class="bigger">This is a multipage application created using the Streamlit library and hosted on HuggingFace Spaces.
+    Our application focuses on solving various natural language processing (NLP) tasks using modern machine learning models.</p>
+</div>
+"""
+
+# Page 1 content
+page_1 = """
+<div class="text-shadow">
+    <h1>Classification of Reviews on Clinics</h1>
+
+    <p class="bigger">You can input your review about a clinic here, and our application will classify it using three different models:</p>
+
+    <ol>
+        <li>Logistic Regression trained on TF-IDF representation.</li>
+        <li>LSTM model with attention mechanism.</li>
+        <li>ruBERTtiny2.</li>
+    </ol>
+</div>
+"""
+
+# Page 2 content
+page_2 = """
+<div class="text-shadow">
+    <h1>Text Generation with GPT Model</h1>
+
+    <p class="bigger">Ask about the mysteries of the universe</p>
+</div>
+"""
+
+# Project collaborators section
+project_colaborators = """
+<div class="text-shadow">
+    <h1>Project Collaborators</h1>
+    <ul>
+        <li>Артем</li>
+        <li>Валера</li>
+        <li>Иван</li>
+    </ul>
+</div>
+"""
+
+st.markdown(about, unsafe_allow_html=True)
 static_toxicity_path = "https://imagizer.imageshack.com/v2/480x360q70/r/924/L4Ditq.jpg"
 animated_toxicity_path = (
     "https://i.kym-cdn.com/photos/images/original/001/264/967/cdc.gif"
@@ -7,44 +81,58 @@
 animated_enlighten_path = "https://gifdb.com/images/high/zen-meditation-chakras-illustration-6lujnenasnfmn8dt.gif"
 static_enlighten_path = "https://imagizer.imageshack.com/v2/668x500q70/r/922/bpoy6G.jpg"
 
-#
-
-
+# Toxicity image HTML
 toxicity_html = f"""
-<div class="
-
-
-
-</div>
+<div class="text-shadow">
+    <a href="review_predictor" target="_self">
+        <img src="{static_toxicity_path}" width="400" class="toxicity-image" />
+    </a>
 <style>
-    /* Define the hover state for
-    .toxicity-image
+    /* Define the hover state for the image */
+    .toxicity-image:hover {{
         content: url("{animated_toxicity_path}");
         transform: scale(1.1); /* Enlarge the image by 10% */
         transition: transform 0.5s ease; /* Add smooth transition */
     }}
 </style>
+</div>
 """
 
+# Enlightenment image HTML
 enlighten_html = f"""
-<div class="
-
-
-
-</div>
+<div class="text-shadow">
+    <a href="text_generator" target="_self">
+        <img src="{static_enlighten_path}" width="400" class="enlighten-image" />
+    </a>
 <style>
-    /* Define the hover state for
-    .enlighten-image
+    /* Define the hover state for the image */
+    .enlighten-image:hover {{
         content: url("{animated_enlighten_path}");
         transform: scale(1.1); /* Enlarge the image by 10% */
         transition: transform 0.5s ease; /* Add smooth transition */
     }}
 </style>
+</div>
 """
 
-#
-
-
+# Add shadow to text content
+text_shadow_style = """
+<style>
+.text-shadow {
+    color: white;
+    text-shadow: 4px 4px 8px #000000;
+}
+.bigger {
+    font-size: 20px;
+}
+</style>
+"""
 
+st.markdown(text_shadow_style, unsafe_allow_html=True)
 
-# Display
+# Display the styled text with shadow
+st.markdown(page_1, unsafe_allow_html=True)
+st.markdown(toxicity_html, unsafe_allow_html=True)
+st.markdown(page_2, unsafe_allow_html=True)
+st.markdown(enlighten_html, unsafe_allow_html=True)
+st.markdown(project_colaborators, unsafe_allow_html=True)
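Note on the helper added above: set_background inlines the image into the page as a base64 data URI inside a <style> block injected via st.markdown. A minimal standalone sketch of the same pattern follows; the function name build_background_css and the file name local_image.png are illustrative, not part of the commit.

import base64


def build_background_css(image_path: str) -> str:
    # Read the image bytes and encode them as base64 text.
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    # Embed the encoded image as a CSS data URI targeting the Streamlit app container.
    return (
        "<style>\n"
        ".stApp {\n"
        f'    background-image: url("data:image/png;base64,{encoded}");\n'
        "    background-size: cover;\n"
        "}\n"
        "</style>"
    )


# Example (assumes local_image.png exists next to the script):
# print(build_background_css("local_image.png")[:120])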
main_background.png ADDED (Git LFS)
model/funcs.py CHANGED
@@ -12,11 +12,30 @@ from torch.utils.data import Dataset
 def execution_time(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
+        # Define the styling for the execution time text
+        styled_text = """
+        <style>
+        .execution-time {
+            font-size: 20px;
+            color: #FFFFFF;
+            text-shadow: -2px -2px 4px #000000;
+        }
+        </style>
+        """
+
+        # Apply the styling directly before writing the execution time text
+        st.markdown(styled_text, unsafe_allow_html=True)
+
         start_time = time.time()
         result = func(*args, **kwargs)
         end_time = time.time()
         execution_seconds = end_time - start_time
-
+
+        # Write the styled text for the execution time
+        st.markdown(
+            f'<div class="execution-time">Model execution time = {execution_seconds:.5f} seconds</div>',
+            unsafe_allow_html=True,
+        )
         return result
 
     return wrapper
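For context, a hypothetical usage sketch of the execution_time decorator defined above; the predict function and its body are illustrative and not part of the repository.

from model.funcs import execution_time


@execution_time
def predict(text: str) -> int:
    # Placeholder for a real model call; returns a dummy class label.
    return len(text) % 2


# When called from a running Streamlit page, the decorator first injects the
# .execution-time CSS, then writes "Model execution time = ... seconds"
# before returning the wrapped function's result.
label = predict("The clinic was great")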
model/model.py CHANGED
@@ -1,18 +1,17 @@
-
 from typing import Tuple
+
 import torch
 import torch.nn as nn
 
-HIDDEN_SIZE =
-VOCAB_SIZE =196906
-EMBEDDING_DIM = 64
+HIDDEN_SIZE = 64
+VOCAB_SIZE = 196906
+EMBEDDING_DIM = 64  # embedding_dim
 SEQ_LEN = 100
-BATCH_SIZE =
+BATCH_SIZE = 16
 
 
 class BahdanauAttention(nn.Module):
     def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
-
         super().__init__()
         self.hidden_size = hidden_size
         self.W_q = nn.Linear(hidden_size, hidden_size)
@@ -26,66 +25,25 @@ class BahdanauAttention(nn.Module):
         lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
         final_hidden: torch.Tensor,  # BATCH_SIZE x HIDDEN_SIZE
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        """Bahdanau Attention module
-
-        Args:
-            keys (torch.Tensor): lstm hidden states (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
-            query (torch.Tensor): lstm final hidden state (BATCH_SIZE, HIDDEN_SIZE)
-
-        Returns:
-            Tuple[torch.Tensor]:
-                context_matrix (BATCH_SIZE, HIDDEN_SIZE)
-                attention scores (BATCH_SIZE, SEQ_LEN)
-        """
-        # input:
-        # keys – lstm hidden states (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
-        # query - lstm final hidden state (BATCH_SIZE, HIDDEN_SIZE)
-
         keys = self.W_k(lstm_outputs)
-        # print(f'After linear keys: {keys.shape}')
-
         query = self.W_q(final_hidden)
-        # print(f"After linear query: {query.shape}")
-
-        # print(f"query.unsqueeze(1) {query.unsqueeze(1).shape}")
 
         sum = query.unsqueeze(1) + keys
-        # print(f"After sum: {sum.shape}")
 
         tanhed = self.tanh(sum)
-        # print(f"After tanhed: {tanhed.shape}")
 
         vector = self.W_v(tanhed).squeeze(-1)
-        # print(f"After linear vector: {vector.shape}")
 
         att_weights = torch.softmax(vector, -1)
-        # print(f"After softmax att_weights: {att_weights.shape}")
 
         context = torch.bmm(att_weights.unsqueeze(1), keys).squeeze()
-        # print(f"After bmm context: {context.shape}")
 
         return context, att_weights
 
-        # att_weights = self.linear(lstm_outputs)
-        # # print(f'After linear: {att_weights.shape, final_hidden.unsqueeze(2).shape}')
-
-        # att_weights = self.linear(lstm_outputs)
-        # # print(f'After linear: {att_weights.shape, final_hidden.unsqueeze(2).shape}')
-        # att_weights = torch.bmm(att_weights, final_hidden.unsqueeze(2))
-        # # print(f'After bmm: {att_weights.shape}')
-        # att_weights = F.softmax(att_weights.squeeze(2), dim=1)
-        # # print(f'After softmax: {att_weights.shape}')
-        # cntxt = torch.bmm(lstm_outputs.transpose(1, 2), att_weights.unsqueeze(2))
-        # # print(f'Context: {cntxt.shape}')
-        # concatted = torch.cat((cntxt, final_hidden.unsqueeze(2)), dim=1)
-        # # print(f'Concatted: {concatted.shape}')
-        # att_hidden = self.tanh(self.align(concatted.squeeze(-1)))
-        # # print(f'Att Hidden: {att_hidden.shape}')
-        # return att_hidden, att_weights
 
-
-
+
+BahdanauAttention()(
+    torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE)
+)[1].shape
 
 
 class LSTMConcatAttentionEmbed(nn.Module):
@@ -97,17 +55,18 @@ class LSTMConcatAttentionEmbed(nn.Module):
         self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
         self.attn = BahdanauAttention(HIDDEN_SIZE)
         self.clf = nn.Sequential(
-            nn.Linear(HIDDEN_SIZE, 128),
-            nn.Dropout(),
-            nn.Tanh(),
-            nn.Linear(128,
+            nn.Linear(HIDDEN_SIZE, 128),
+            nn.Dropout(),
+            nn.Tanh(),
+            nn.Linear(128, 64),
+            nn.Dropout(),
+            nn.Tanh(),
+            nn.Linear(64, 1),
         )
 
-    def forward(self, x):
+    def forward(self, x):
         embeddings = self.embedding(x)
         outputs, (h_n, _) = self.lstm(embeddings)
         att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
         out = self.clf(att_hidden)
         return out, att_weights
-
-
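A quick shape walk-through of the model after this change, as a hypothetical snippet that is not part of the commit; it assumes LSTMConcatAttentionEmbed can be constructed without arguments, which the truncated __init__ in this diff does not confirm.

import torch

from model.model import BATCH_SIZE, SEQ_LEN, LSTMConcatAttentionEmbed

# Random batch of token ids shaped (BATCH_SIZE, SEQ_LEN) = (16, 100).
tokens = torch.randint(0, 1000, (BATCH_SIZE, SEQ_LEN))

model = LSTMConcatAttentionEmbed()
logits, att_weights = model(tokens)

print(logits.shape)       # expected: torch.Size([16, 1]), one score per review
print(att_weights.shape)  # expected: torch.Size([16, 100]), one weight per token position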
pages/review_predictor.py CHANGED
@@ -1,3 +1,4 @@
+import base64
 import json
 import pickle
 
@@ -14,6 +15,32 @@ from preprocessing.preprocessing import data_preprocessing
 from preprocessing.rnn_preprocessing import preprocess_single_string
 
 
+def get_base64(file_path):
+    with open(file_path, "rb") as file:
+        base64_bytes = base64.b64encode(file.read())
+        base64_string = base64_bytes.decode("utf-8")
+        return base64_string
+
+
+def set_background(png_file):
+    bin_str = get_base64(png_file)
+    page_bg_img = (
+        """
+    <style>
+    .stApp {
+    background-image: url("data:image/png;base64,%s");
+    background-size: auto;
+    }
+    </style>
+    """
+        % bin_str
+    )
+    st.markdown(page_bg_img, unsafe_allow_html=True)
+
+
+set_background("main_background.png")
+
+
 @st.cache_resource
 def load_logreg():
     with open("vectorizer.pkl", "rb") as f:
@@ -93,7 +120,6 @@ metrics = {
 }
 
 
-col1, col2 = st.columns([1, 3])
 df = pd.DataFrame(metrics)
 df.set_index("Models", inplace=True)
 df.index.name = "Model"
@@ -101,10 +127,40 @@ df.index.name = "Model"
 
 st.sidebar.title("Model Selection")
 model_type = st.sidebar.radio("Select Model Type", ["Classic ML", "LSTM", "BERT"])
-
+
+
+styled_text = """
+<style>
+.styled-title {
+    color: #FF00FF;
+    font-size: 40px;
+    text-shadow: -2px -2px 4px #000000;
+    -webkit-text-stroke-width: 1px;
+    -webkit-text-stroke-color: #000000;
+}
+.positive {
+    color: #00FF00;
+    font-size: 30px;
+    text-shadow: -2px -2px 4px #000000;
+    -webkit-text-stroke-width: 1px;
+    -webkit-text-stroke-color: #000000;
+
+}
+.negative {
+    color: #FF0000;
+    font-size: 30px;
+    text-shadow: -2px -2px 4px #000000;
+    -webkit-text-stroke-width: 1px;
+    -webkit-text-stroke-color: #000000;
+
+}
+</style>
+"""
+
+st.markdown(styled_text, unsafe_allow_html=True)
 
 # Streamlit app code
-st.
+st.markdown('<div class="styled-title">Review Prediction</div>', unsafe_allow_html=True)
 text_input = st.text_input("Enter your review:")
 if st.button("Predict"):
     if model_type == "Classic ML":
@@ -116,11 +172,14 @@ if st.button("Predict"):
     elif model_type == "BERT":
         prediction = predict_sentiment(text_input, model, tokenizer, "cpu")
 
+    # Apply different styles based on prediction result
     if prediction == 1:
-        st.
-
+        st.markdown(
+            f'<div class="positive">Отзыв положительный</div>', unsafe_allow_html=True
+        )
     elif prediction == 0:
-        st.
-
+        st.markdown(
+            f'<div class="negative">Отзыв отрицательный</div>', unsafe_allow_html=True
+        )
 
 st.write(df)
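For readers who do not read Russian: the two output strings above mean "The review is positive" (Отзыв положительный) and "The review is negative" (Отзыв отрицательный). A small hypothetical helper showing the same prediction-to-style mapping; the render_prediction name is illustrative and not part of the commit.

import streamlit as st


def render_prediction(prediction: int) -> None:
    # Map the binary model output to the CSS classes defined in the page's <style> block.
    if prediction == 1:
        # "Отзыв положительный" = "The review is positive"
        st.markdown('<div class="positive">Отзыв положительный</div>', unsafe_allow_html=True)
    elif prediction == 0:
        # "Отзыв отрицательный" = "The review is negative"
        st.markdown('<div class="negative">Отзыв отрицательный</div>', unsafe_allow_html=True)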
pages/text_generator.py CHANGED
@@ -1,3 +1,5 @@
+import base64
+
 import streamlit as st
 import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -5,6 +7,32 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from model.funcs import execution_time
 
 
+def get_base64(file_path):
+    with open(file_path, "rb") as file:
+        base64_bytes = base64.b64encode(file.read())
+        base64_string = base64_bytes.decode("utf-8")
+        return base64_string
+
+
+def set_background(png_file):
+    bin_str = get_base64(png_file)
+    page_bg_img = (
+        """
+    <style>
+    .stApp {
+    background-image: url("data:image/png;base64,%s");
+    background-size: cover;
+    }
+    </style>
+    """
+        % bin_str
+    )
+    st.markdown(page_bg_img, unsafe_allow_html=True)
+
+
+set_background("text_generation.png")
+
+
 @st.cache_data
 def load_model():
     model_path = "17/"
@@ -18,26 +46,65 @@ tokenizer, model = load_model()
 
 
 @execution_time
-def generate_text(
-
+def generate_text(
+    prompt, num_beams=2, temperature=1.5, top_p=0.9, top_k=3, max_length=150
+):
+    prompt = tokenizer.encode(prompt, return_tensors="pt")
     model.eval()
     with torch.no_grad():
         out = model.generate(
-
+            prompt,
             do_sample=True,
-            num_beams=
-            temperature=
-            top_p=
-
+            num_beams=num_beams,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_length=max_length,
         )
     out = list(map(tokenizer.decode, out))[0]
     return out
 
 
-
-
+with st.sidebar:
+    num_beams = st.slider("Number of Beams", min_value=1, max_value=5, value=2)
+    temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.5)
+    top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)
+    top_k = st.slider("Top-k", min_value=1, max_value=10, value=3)
+    max_length = st.slider("Maximum Length", min_value=20, max_value=300, value=150)
+
+styled_text = """
+<style>
+.styled-text {
+    font-size: 30px;
+    text-shadow: -2px -2px 4px #000000;
+    color: #FFFFFF;
+    -webkit-text-stroke-width: 1px;
+    -webkit-text-stroke-color: #000000;
+}
+</style>
+"""
+
+st.markdown(styled_text, unsafe_allow_html=True)
+
+prompt = st.text_input(
+    "Ask a question",
+    key="question_input",
+    placeholder="Type here...",
+    type="default",
+    value="",
+)
+generate = st.button("Generate", key="generate_button")
+
 if generate:
-    if not
+    if not prompt:
         st.write("42")
     else:
-
+        generated_text = generate_text(
+            prompt, num_beams, temperature, top_p, top_k, max_length
+        )
+        paragraphs = generated_text.split("\n")
+        styled_paragraphs = [
+            f'<div class="styled-text">{paragraph}</div>' for paragraph in paragraphs
+        ]
+        styled_generated_text = " ".join(styled_paragraphs)
+        st.markdown(styled_generated_text, unsafe_allow_html=True)
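The generation call above uses the standard Hugging Face sampling parameters (do_sample, num_beams, temperature, top_p, top_k, max_length). A minimal standalone sketch of the same call against the public "gpt2" checkpoint, independent of this Space's local "17/" model directory; the prompt text is illustrative.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Illustrative checkpoint; the Space loads its own fine-tuned model from "17/".
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt_ids = tokenizer.encode("What lies beyond the observable universe?", return_tensors="pt")

model.eval()
with torch.no_grad():
    out = model.generate(
        prompt_ids,
        do_sample=True,   # sample instead of greedy decoding
        num_beams=2,      # beam search width
        temperature=1.5,  # values above 1 flatten the distribution (more random)
        top_p=0.9,        # nucleus sampling cutoff
        top_k=3,          # restrict sampling to the 3 most likely tokens
        max_length=150,   # total length cap, prompt included
    )

print(tokenizer.decode(out[0]))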
space_background.jpeg ADDED (Git LFS)
space_main_background.avif ADDED
space_main_background.jpeg ADDED (Git LFS)
text_generation.png ADDED (Git LFS)