Maslov-Artem commited on
Commit
3a905e4
1 Parent(s): b90441b

Add styles

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ space_main_background.jpeg filter=lfs diff=lfs merge=lfs -text
37
+ main_background.png filter=lfs diff=lfs merge=lfs -text
38
+ text_generation.png filter=lfs diff=lfs merge=lfs -text
39
+ space_background.jpeg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,5 +1,79 @@
 
 
1
  import streamlit as st
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  static_toxicity_path = "https://imagizer.imageshack.com/v2/480x360q70/r/924/L4Ditq.jpg"
4
  animated_toxicity_path = (
5
  "https://i.kym-cdn.com/photos/images/original/001/264/967/cdc.gif"
@@ -7,44 +81,58 @@ animated_toxicity_path = (
7
  animated_enlighten_path = "https://gifdb.com/images/high/zen-meditation-chakras-illustration-6lujnenasnfmn8dt.gif"
8
  static_enlighten_path = "https://imagizer.imageshack.com/v2/668x500q70/r/922/bpoy6G.jpg"
9
 
10
- # Calculate the column widths dynamically
11
-
12
-
13
  toxicity_html = f"""
14
- <div class="toxicity-image-container">
15
- <a href="review_predictor" target="_self" class="toxicity-link">
16
- <img src="{static_toxicity_path}" class="toxicity-image" />
17
- </a>
18
- </div>
19
  <style>
20
- /* Define the hover state for column 1 */
21
- .toxicity-image-container:hover .toxicity-image {{
22
  content: url("{animated_toxicity_path}");
23
  transform: scale(1.1); /* Enlarge the image by 10% */
24
  transition: transform 0.5s ease; /* Add smooth transition */
25
  }}
26
  </style>
 
27
  """
28
 
 
29
  enlighten_html = f"""
30
- <div class="enlighten-image-container">
31
- <a href="text_generator" target="_self" class="enlighten-link">
32
- <img src="{static_enlighten_path}" class="enlighten-image" />
33
- </a>
34
- </div>
35
  <style>
36
- /* Define the hover state for column 2 */
37
- .enlighten-image-container:hover .enlighten-image {{
38
  content: url("{animated_enlighten_path}");
39
  transform: scale(1.1); /* Enlarge the image by 10% */
40
  transition: transform 0.5s ease; /* Add smooth transition */
41
  }}
42
  </style>
 
43
  """
44
 
45
- # Display HTML code with Streamlit
46
- st.markdown(toxicity_html, unsafe_allow_html=True)
47
- st.markdown(enlighten_html, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
48
 
 
49
 
50
- # Display JavaScript code with Streamlit
 
 
 
 
 
 
1
+ import base64
2
+
3
  import streamlit as st
4
 
5
+
6
+ def get_base64(file_path):
7
+ with open(file_path, "rb") as file:
8
+ base64_bytes = base64.b64encode(file.read())
9
+ base64_string = base64_bytes.decode("utf-8")
10
+ return base64_string
11
+
12
+
13
+ def set_background(png_file):
14
+ bin_str = get_base64(png_file)
15
+ page_bg_img = (
16
+ """
17
+ <style>
18
+ .stApp {
19
+ background-image: url("data:image/png;base64,%s");
20
+ background-size: cover;
21
+ }
22
+ </style>
23
+ """
24
+ % bin_str
25
+ )
26
+ st.markdown(page_bg_img, unsafe_allow_html=True)
27
+
28
+
29
+ set_background("space_background.jpeg")
30
+
31
+ # About section
32
+ about = """
33
+ <div class="text-shadow">
34
+ <h1>About</h1>
35
+ <p class="bigger">This is a multipage application created using the Streamlit library and hosted on HuggingFace Spaces.
36
+ Our application focuses on solving various natural language processing (NLP) tasks using modern machine learning models.</p>
37
+ </div>
38
+ """
39
+
40
+ # Page 1 content
41
+ page_1 = """
42
+ <div class="text-shadow">
43
+ <h1>Classification of Reviews on Clinics</h1>
44
+
45
+ <p class="bigger">You can input your review about a clinic here, and our application will classify it using three different models:</p>
46
+
47
+ <ol>
48
+ <li>Logistic Regression trained on TF-IDF representation.</li>
49
+ <li>LSTM model with attention mechanism.</li>
50
+ <li>ruBERTtiny2.</li>
51
+ </ol>
52
+ </div>
53
+ """
54
+
55
+ # Page 2 content
56
+ page_2 = """
57
+ <div class="text-shadow">
58
+ <h1>Text Generation with GPT Model</h1>
59
+
60
+ <p class="bigger">Ask about the mysteries of the universe</p>
61
+ </div>
62
+ """
63
+
64
+ # Project collaborators section
65
+ project_colaborators = """
66
+ <div class="text-shadow">
67
+ <h1>Project Collaborators</h1>
68
+ <ul>
69
+ <li>Артем</li>
70
+ <li>Валера</li>
71
+ <li>Иван</li>
72
+ </ul>
73
+ </div>
74
+ """
75
+
76
+ st.markdown(about, unsafe_allow_html=True)
77
  static_toxicity_path = "https://imagizer.imageshack.com/v2/480x360q70/r/924/L4Ditq.jpg"
78
  animated_toxicity_path = (
79
  "https://i.kym-cdn.com/photos/images/original/001/264/967/cdc.gif"
 
81
  animated_enlighten_path = "https://gifdb.com/images/high/zen-meditation-chakras-illustration-6lujnenasnfmn8dt.gif"
82
  static_enlighten_path = "https://imagizer.imageshack.com/v2/668x500q70/r/922/bpoy6G.jpg"
83
 
84
+ # Toxicity image HTML
 
 
85
  toxicity_html = f"""
86
+ <div class="text-shadow">
87
+ <a href="review_predictor" target="_self">
88
+ <img src="{static_toxicity_path}" width="400" class="toxicity-image" />
89
+ </a>
 
90
  <style>
91
+ /* Define the hover state for the image */
92
+ .toxicity-image:hover {{
93
  content: url("{animated_toxicity_path}");
94
  transform: scale(1.1); /* Enlarge the image by 10% */
95
  transition: transform 0.5s ease; /* Add smooth transition */
96
  }}
97
  </style>
98
+ </div>
99
  """
100
 
101
+ # Enlightenment image HTML
102
  enlighten_html = f"""
103
+ <div class="text-shadow">
104
+ <a href="text_generator" target="_self">
105
+ <img src="{static_enlighten_path}" width="400" class="enlighten-image" />
106
+ </a>
 
107
  <style>
108
+ /* Define the hover state for the image */
109
+ .enlighten-image:hover {{
110
  content: url("{animated_enlighten_path}");
111
  transform: scale(1.1); /* Enlarge the image by 10% */
112
  transition: transform 0.5s ease; /* Add smooth transition */
113
  }}
114
  </style>
115
+ </div>
116
  """
117
 
118
+ # Add shadow to text content
119
+ text_shadow_style = """
120
+ <style>
121
+ .text-shadow {
122
+ color: white;
123
+ text-shadow: 4px 4px 8px #000000;
124
+ }
125
+ .bigger {
126
+ font-size: 20px;
127
+ }
128
+ </style>
129
+ """
130
 
131
+ st.markdown(text_shadow_style, unsafe_allow_html=True)
132
 
133
+ # Display the styled text with shadow
134
+ st.markdown(page_1, unsafe_allow_html=True)
135
+ st.markdown(toxicity_html, unsafe_allow_html=True)
136
+ st.markdown(page_2, unsafe_allow_html=True)
137
+ st.markdown(enlighten_html, unsafe_allow_html=True)
138
+ st.markdown(project_colaborators, unsafe_allow_html=True)
main_background.png ADDED

Git LFS Details

  • SHA256: 601b89f48519f65b3dd4c0da79a54c27812e6f051b48fda45145be424ae2a04b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
model/funcs.py CHANGED
@@ -12,11 +12,30 @@ from torch.utils.data import Dataset
12
  def execution_time(func):
13
  @wraps(func)
14
  def wrapper(*args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  start_time = time.time()
16
  result = func(*args, **kwargs)
17
  end_time = time.time()
18
  execution_seconds = end_time - start_time
19
- st.write(f"Model calculating time = {execution_seconds:.5f} seconds")
 
 
 
 
 
20
  return result
21
 
22
  return wrapper
 
12
  def execution_time(func):
13
  @wraps(func)
14
  def wrapper(*args, **kwargs):
15
+ # Define the styling for the execution time text
16
+ styled_text = """
17
+ <style>
18
+ .execution-time {
19
+ font-size: 20px;
20
+ color: #FFFFFF;
21
+ text-shadow: -2px -2px 4px #000000;
22
+ }
23
+ </style>
24
+ """
25
+
26
+ # Apply the styling directly before writing the execution time text
27
+ st.markdown(styled_text, unsafe_allow_html=True)
28
+
29
  start_time = time.time()
30
  result = func(*args, **kwargs)
31
  end_time = time.time()
32
  execution_seconds = end_time - start_time
33
+
34
+ # Write the styled text for the execution time
35
+ st.markdown(
36
+ f'<div class="execution-time">Model execution time = {execution_seconds:.5f} seconds</div>',
37
+ unsafe_allow_html=True,
38
+ )
39
  return result
40
 
41
  return wrapper
model/model.py CHANGED
@@ -1,18 +1,17 @@
1
-
2
  from typing import Tuple
 
3
  import torch
4
  import torch.nn as nn
5
 
6
- HIDDEN_SIZE = 32
7
- VOCAB_SIZE =196906
8
- EMBEDDING_DIM = 64 # embedding_dim
9
  SEQ_LEN = 100
10
- BATCH_SIZE = 64
11
 
12
 
13
  class BahdanauAttention(nn.Module):
14
  def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
15
-
16
  super().__init__()
17
  self.hidden_size = hidden_size
18
  self.W_q = nn.Linear(hidden_size, hidden_size)
@@ -26,66 +25,25 @@ class BahdanauAttention(nn.Module):
26
  lstm_outputs: torch.Tensor, # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
27
  final_hidden: torch.Tensor, # BATCH_SIZE x HIDDEN_SIZE
28
  ) -> Tuple[torch.Tensor, torch.Tensor]:
29
-
30
- """Bahdanau Attention module
31
-
32
- Args:
33
- keys (torch.Tensor): lstm hidden states (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
34
- query (torch.Tensor): lstm final hidden state (BATCH_SIZE, HIDDEN_SIZE)
35
-
36
- Returns:
37
- Tuple[torch.Tensor]:
38
- context_matrix (BATCH_SIZE, HIDDEN_SIZE)
39
- attention scores (BATCH_SIZE, SEQ_LEN)
40
- """
41
- # input:
42
- # keys – lstm hidden states (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
43
- # query - lstm final hidden state (BATCH_SIZE, HIDDEN_SIZE)
44
-
45
  keys = self.W_k(lstm_outputs)
46
- # print(f'After linear keys: {keys.shape}')
47
-
48
  query = self.W_q(final_hidden)
49
- # print(f"After linear query: {query.shape}")
50
-
51
- # print(f"query.unsqueeze(1) {query.unsqueeze(1).shape}")
52
 
53
  sum = query.unsqueeze(1) + keys
54
- # print(f"After sum: {sum.shape}")
55
 
56
  tanhed = self.tanh(sum)
57
- # print(f"After tanhed: {tanhed.shape}")
58
 
59
  vector = self.W_v(tanhed).squeeze(-1)
60
- # print(f"After linear vector: {vector.shape}")
61
 
62
  att_weights = torch.softmax(vector, -1)
63
- # print(f"After softmax att_weights: {att_weights.shape}")
64
 
65
  context = torch.bmm(att_weights.unsqueeze(1), keys).squeeze()
66
- # print(f"After bmm context: {context.shape}")
67
 
68
  return context, att_weights
69
 
70
- # att_weights = self.linear(lstm_outputs)
71
- # # print(f'After linear: {att_weights.shape, final_hidden.unsqueeze(2).shape}')
72
-
73
- # att_weights = self.linear(lstm_outputs)
74
- # # print(f'After linear: {att_weights.shape, final_hidden.unsqueeze(2).shape}')
75
- # att_weights = torch.bmm(att_weights, final_hidden.unsqueeze(2))
76
- # # print(f'After bmm: {att_weights.shape}')
77
- # att_weights = F.softmax(att_weights.squeeze(2), dim=1)
78
- # # print(f'After softmax: {att_weights.shape}')
79
- # cntxt = torch.bmm(lstm_outputs.transpose(1, 2), att_weights.unsqueeze(2))
80
- # # print(f'Context: {cntxt.shape}')
81
- # concatted = torch.cat((cntxt, final_hidden.unsqueeze(2)), dim=1)
82
- # # print(f'Concatted: {concatted.shape}')
83
- # att_hidden = self.tanh(self.align(concatted.squeeze(-1)))
84
- # # print(f'Att Hidden: {att_hidden.shape}')
85
- # return att_hidden, att_weights
86
 
87
- # Test on random numbers
88
- BahdanauAttention()(torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE))[1].shape
 
89
 
90
 
91
  class LSTMConcatAttentionEmbed(nn.Module):
@@ -97,17 +55,18 @@ class LSTMConcatAttentionEmbed(nn.Module):
97
  self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
98
  self.attn = BahdanauAttention(HIDDEN_SIZE)
99
  self.clf = nn.Sequential(
100
- nn.Linear(HIDDEN_SIZE, 128),
101
- nn.Dropout(),
102
- nn.Tanh(),
103
- nn.Linear(128, 1)
 
 
 
104
  )
105
 
106
- def forward(self, x):
107
  embeddings = self.embedding(x)
108
  outputs, (h_n, _) = self.lstm(embeddings)
109
  att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
110
  out = self.clf(att_hidden)
111
  return out, att_weights
112
-
113
-
 
 
1
  from typing import Tuple
2
+
3
  import torch
4
  import torch.nn as nn
5
 
6
+ HIDDEN_SIZE = 64
7
+ VOCAB_SIZE = 196906
8
+ EMBEDDING_DIM = 64 # embedding_dim
9
  SEQ_LEN = 100
10
+ BATCH_SIZE = 16
11
 
12
 
13
  class BahdanauAttention(nn.Module):
14
  def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
 
15
  super().__init__()
16
  self.hidden_size = hidden_size
17
  self.W_q = nn.Linear(hidden_size, hidden_size)
 
25
  lstm_outputs: torch.Tensor, # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
26
  final_hidden: torch.Tensor, # BATCH_SIZE x HIDDEN_SIZE
27
  ) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  keys = self.W_k(lstm_outputs)
 
 
29
  query = self.W_q(final_hidden)
 
 
 
30
 
31
  sum = query.unsqueeze(1) + keys
 
32
 
33
  tanhed = self.tanh(sum)
 
34
 
35
  vector = self.W_v(tanhed).squeeze(-1)
 
36
 
37
  att_weights = torch.softmax(vector, -1)
 
38
 
39
  context = torch.bmm(att_weights.unsqueeze(1), keys).squeeze()
 
40
 
41
  return context, att_weights
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ BahdanauAttention()(
45
+ torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE)
46
+ )[1].shape
47
 
48
 
49
  class LSTMConcatAttentionEmbed(nn.Module):
 
55
  self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
56
  self.attn = BahdanauAttention(HIDDEN_SIZE)
57
  self.clf = nn.Sequential(
58
+ nn.Linear(HIDDEN_SIZE, 128),
59
+ nn.Dropout(),
60
+ nn.Tanh(),
61
+ nn.Linear(128, 64),
62
+ nn.Dropout(),
63
+ nn.Tanh(),
64
+ nn.Linear(64, 1),
65
  )
66
 
67
+ def forward(self, x):
68
  embeddings = self.embedding(x)
69
  outputs, (h_n, _) = self.lstm(embeddings)
70
  att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
71
  out = self.clf(att_hidden)
72
  return out, att_weights
 
 
pages/review_predictor.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  import pickle
3
 
@@ -14,6 +15,32 @@ from preprocessing.preprocessing import data_preprocessing
14
  from preprocessing.rnn_preprocessing import preprocess_single_string
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  @st.cache_resource
18
  def load_logreg():
19
  with open("vectorizer.pkl", "rb") as f:
@@ -93,7 +120,6 @@ metrics = {
93
  }
94
 
95
 
96
- col1, col2 = st.columns([1, 3])
97
  df = pd.DataFrame(metrics)
98
  df.set_index("Models", inplace=True)
99
  df.index.name = "Model"
@@ -101,10 +127,40 @@ df.index.name = "Model"
101
 
102
  st.sidebar.title("Model Selection")
103
  model_type = st.sidebar.radio("Select Model Type", ["Classic ML", "LSTM", "BERT"])
104
- st.title("Review Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  # Streamlit app code
107
- st.title("Sentiment Analysis with Logistic Regression")
108
  text_input = st.text_input("Enter your review:")
109
  if st.button("Predict"):
110
  if model_type == "Classic ML":
@@ -116,11 +172,14 @@ if st.button("Predict"):
116
  elif model_type == "BERT":
117
  prediction = predict_sentiment(text_input, model, tokenizer, "cpu")
118
 
 
119
  if prediction == 1:
120
- st.write("prediction")
121
- st.write("Отзыв положительный")
 
122
  elif prediction == 0:
123
- st.write("prediction")
124
- st.write("Отзыв отрицательный")
 
125
 
126
  st.write(df)
 
1
+ import base64
2
  import json
3
  import pickle
4
 
 
15
  from preprocessing.rnn_preprocessing import preprocess_single_string
16
 
17
 
18
+ def get_base64(file_path):
19
+ with open(file_path, "rb") as file:
20
+ base64_bytes = base64.b64encode(file.read())
21
+ base64_string = base64_bytes.decode("utf-8")
22
+ return base64_string
23
+
24
+
25
+ def set_background(png_file):
26
+ bin_str = get_base64(png_file)
27
+ page_bg_img = (
28
+ """
29
+ <style>
30
+ .stApp {
31
+ background-image: url("data:image/png;base64,%s");
32
+ background-size: auto;
33
+ }
34
+ </style>
35
+ """
36
+ % bin_str
37
+ )
38
+ st.markdown(page_bg_img, unsafe_allow_html=True)
39
+
40
+
41
+ set_background("main_background.png")
42
+
43
+
44
  @st.cache_resource
45
  def load_logreg():
46
  with open("vectorizer.pkl", "rb") as f:
 
120
  }
121
 
122
 
 
123
  df = pd.DataFrame(metrics)
124
  df.set_index("Models", inplace=True)
125
  df.index.name = "Model"
 
127
 
128
  st.sidebar.title("Model Selection")
129
  model_type = st.sidebar.radio("Select Model Type", ["Classic ML", "LSTM", "BERT"])
130
+
131
+
132
+ styled_text = """
133
+ <style>
134
+ .styled-title {
135
+ color: #FF00FF;
136
+ font-size: 40px;
137
+ text-shadow: -2px -2px 4px #000000;
138
+ -webkit-text-stroke-width: 1px;
139
+ -webkit-text-stroke-color: #000000;
140
+ }
141
+ .positive {
142
+ color: #00FF00;
143
+ font-size: 30px;
144
+ text-shadow: -2px -2px 4px #000000;
145
+ -webkit-text-stroke-width: 1px;
146
+ -webkit-text-stroke-color: #000000;
147
+
148
+ }
149
+ .negative {
150
+ color: #FF0000;
151
+ font-size: 30px;
152
+ text-shadow: -2px -2px 4px #000000;
153
+ -webkit-text-stroke-width: 1px;
154
+ -webkit-text-stroke-color: #000000;
155
+
156
+ }
157
+ </style>
158
+ """
159
+
160
+ st.markdown(styled_text, unsafe_allow_html=True)
161
 
162
  # Streamlit app code
163
+ st.markdown('<div class="styled-title">Review Prediction</div>', unsafe_allow_html=True)
164
  text_input = st.text_input("Enter your review:")
165
  if st.button("Predict"):
166
  if model_type == "Classic ML":
 
172
  elif model_type == "BERT":
173
  prediction = predict_sentiment(text_input, model, tokenizer, "cpu")
174
 
175
+ # Apply different styles based on prediction result
176
  if prediction == 1:
177
+ st.markdown(
178
+ f'<div class="positive">Отзыв положительный</div>', unsafe_allow_html=True
179
+ )
180
  elif prediction == 0:
181
+ st.markdown(
182
+ f'<div class="negative">Отзыв отрицательный</div>', unsafe_allow_html=True
183
+ )
184
 
185
  st.write(df)
pages/text_generator.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
  import torch
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -5,6 +7,32 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
  from model.funcs import execution_time
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  @st.cache_data
9
  def load_model():
10
  model_path = "17/"
@@ -18,26 +46,65 @@ tokenizer, model = load_model()
18
 
19
 
20
  @execution_time
21
- def generate_text(promt):
22
- promt = tokenizer.encode(promt, return_tensors="pt")
 
 
23
  model.eval()
24
  with torch.no_grad():
25
  out = model.generate(
26
- promt,
27
  do_sample=True,
28
- num_beams=2,
29
- temperature=1.5,
30
- top_p=0.9,
31
- max_length=150,
 
32
  )
33
  out = list(map(tokenizer.decode, out))[0]
34
  return out
35
 
36
 
37
- promt = st.text_input("Ask a question")
38
- generate = st.button("Generate")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  if generate:
40
- if not promt:
41
  st.write("42")
42
  else:
43
- st.write(generate_text(promt))
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+
3
  import streamlit as st
4
  import torch
5
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
7
  from model.funcs import execution_time
8
 
9
 
10
+ def get_base64(file_path):
11
+ with open(file_path, "rb") as file:
12
+ base64_bytes = base64.b64encode(file.read())
13
+ base64_string = base64_bytes.decode("utf-8")
14
+ return base64_string
15
+
16
+
17
+ def set_background(png_file):
18
+ bin_str = get_base64(png_file)
19
+ page_bg_img = (
20
+ """
21
+ <style>
22
+ .stApp {
23
+ background-image: url("data:image/png;base64,%s");
24
+ background-size: cover;
25
+ }
26
+ </style>
27
+ """
28
+ % bin_str
29
+ )
30
+ st.markdown(page_bg_img, unsafe_allow_html=True)
31
+
32
+
33
+ set_background("text_generation.png")
34
+
35
+
36
  @st.cache_data
37
  def load_model():
38
  model_path = "17/"
 
46
 
47
 
48
  @execution_time
49
+ def generate_text(
50
+ prompt, num_beams=2, temperature=1.5, top_p=0.9, top_k=3, max_length=150
51
+ ):
52
+ prompt = tokenizer.encode(prompt, return_tensors="pt")
53
  model.eval()
54
  with torch.no_grad():
55
  out = model.generate(
56
+ prompt,
57
  do_sample=True,
58
+ num_beams=num_beams,
59
+ temperature=temperature,
60
+ top_p=top_p,
61
+ top_k=top_k,
62
+ max_length=max_length,
63
  )
64
  out = list(map(tokenizer.decode, out))[0]
65
  return out
66
 
67
 
68
+ with st.sidebar:
69
+ num_beams = st.slider("Number of Beams", min_value=1, max_value=5, value=2)
70
+ temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.5)
71
+ top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)
72
+ top_k = st.slider("Top-k", min_value=1, max_value=10, value=3)
73
+ max_length = st.slider("Maximum Length", min_value=20, max_value=300, value=150)
74
+
75
+ styled_text = """
76
+ <style>
77
+ .styled-text {
78
+ font-size: 30px;
79
+ text-shadow: -2px -2px 4px #000000;
80
+ color: #FFFFFF;
81
+ -webkit-text-stroke-width: 1px;
82
+ -webkit-text-stroke-color: #000000;
83
+ }
84
+ </style>
85
+ """
86
+
87
+ st.markdown(styled_text, unsafe_allow_html=True)
88
+
89
+ prompt = st.text_input(
90
+ "Ask a question",
91
+ key="question_input",
92
+ placeholder="Type here...",
93
+ type="default",
94
+ value="",
95
+ )
96
+ generate = st.button("Generate", key="generate_button")
97
+
98
  if generate:
99
+ if not prompt:
100
  st.write("42")
101
  else:
102
+ generated_text = generate_text(
103
+ prompt, num_beams, temperature, top_p, top_k, max_length
104
+ )
105
+ paragraphs = generated_text.split("\n")
106
+ styled_paragraphs = [
107
+ f'<div class="styled-text">{paragraph}</div>' for paragraph in paragraphs
108
+ ]
109
+ styled_generated_text = " ".join(styled_paragraphs)
110
+ st.markdown(styled_generated_text, unsafe_allow_html=True)
space_background.jpeg ADDED

Git LFS Details

  • SHA256: b42be30730cb6ddff0686c7331153ff6b378e77f7152f2ae898e495e7d592ab9
  • Pointer size: 131 Bytes
  • Size of remote file: 400 kB
space_main_background.avif ADDED
space_main_background.jpeg ADDED

Git LFS Details

  • SHA256: 09f95f7481849c2787568396b97564b6b64e8b0653f5386c2d77a3a3f292b64e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.03 MB
text_generation.png ADDED

Git LFS Details

  • SHA256: dfcc883890dcd2f2ee20fe8d431284fe2a2bccf925f7fb3abfcd37f24e90f9c2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.62 MB