Spaces:
Sleeping
Sleeping
Upload 20 files
Browse files- .gitattributes +6 -0
- app.py +15 -0
- data/cleaned_df.csv +3 -0
- data/final_with_emb.csv +3 -0
- data/to_translate_df.csv +3 -0
- for_models/embeddings.npy +3 -0
- for_models/faiss_index_hnsw.bin +3 -0
- for_models/faiss_index_ip.bin +3 -0
- for_models/faiss_index_l2.bin +3 -0
- images/logo_0.jpeg +3 -0
- images/logo_1.jpeg +3 -0
- images/screenshot.png +3 -0
- models/dashas/distilbert_index.pkl +3 -0
- models/dashas/labse_index.pkl +3 -0
- models/dashas/tiny2_index.pkl +3 -0
- models/description_vectors_MiniLM-L12-v2.pt +3 -0
- pages/page_02.py +53 -0
- pages/page_03.py +228 -0
- pages/page_04.py +64 -0
- pages/random_10.py +17 -0
- requirements.txt +7 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
data/cleaned_df.csv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
data/final_with_emb.csv filter=lfs diff=lfs merge=lfs -text
|
38 |
+
data/to_translate_df.csv filter=lfs diff=lfs merge=lfs -text
|
39 |
+
images/logo_0.jpeg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
images/logo_1.jpeg filter=lfs diff=lfs merge=lfs -text
|
41 |
+
images/screenshot.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
st.title("Умный поиск сериалов")
|
5 |
+
|
6 |
+
logo = Image.open('images/logo_0.jpeg')
|
7 |
+
st.image(logo, width=800)
|
8 |
+
|
9 |
+
st.write("### Оглавление")
|
10 |
+
st.write("Для улучшения поиска на стриминговом сервисе мы создали систему семантического поиска, которая учитывает описание сериалов.")
|
11 |
+
|
12 |
+
st.write("### Команда проекта:")
|
13 |
+
st.write("[Илья](https://github.com/lefuuu)")
|
14 |
+
st.write("[Алина](https://github.com/RenaTheDv)")
|
15 |
+
st.write("[Даша](https://github.com/DashonokOk)")
|
data/cleaned_df.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:937cfb75767b97a1c0d8769f5f7d19acff29f50494f49554c9a77458bdf8ba26
|
3 |
+
size 18382297
|
data/final_with_emb.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac79c9b1169792207187145289bb12712ece61ce0577450bc94a42371d08d825
|
3 |
+
size 358228433
|
data/to_translate_df.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:107478c345f28d11daa257792baa4a193cfa27bf6691dbe29d9b17cfb0559619
|
3 |
+
size 16177462
|
for_models/embeddings.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:504ed433a6d9507a551425974a69a30be1eb597c34cfe81153e42279d96990fe
|
3 |
+
size 121442432
|
for_models/faiss_index_hnsw.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:89702507a9e9e459774f877208596641562151b70c63cb6edafc17835fd94e7b
|
3 |
+
size 66100074
|
for_models/faiss_index_ip.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a41346a31ef936b4739da33bd1c64d464cf158669f6216c477cc7dc9b2eea74f
|
3 |
+
size 60721197
|
for_models/faiss_index_l2.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1b408cae8ce3dc64de21db17d0eea2e0e13fb7cd8f837a168ecc65abb048d925
|
3 |
+
size 60721197
|
images/logo_0.jpeg
ADDED
![]() |
Git LFS Details
|
images/logo_1.jpeg
ADDED
![]() |
Git LFS Details
|
images/screenshot.png
ADDED
![]() |
Git LFS Details
|
models/dashas/distilbert_index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:55f517adbb228363e26b42f582362817dcddb5d901ebc951f7f2c029890d276f
|
3 |
+
size 60721279
|
models/dashas/labse_index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6a088ce654496357664a98d61caff34c16d7a04dc2c725b2840795dd9d053214
|
3 |
+
size 60721279
|
models/dashas/tiny2_index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0aefe24fb67cb8bc404a03beaf0789596d880951e762523dd485bbe76f6cde89
|
3 |
+
size 24668095
|
models/description_vectors_MiniLM-L12-v2.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3dc29187c076888ec1688092ddea8be357892ae67250f7d7f2545a48c9819338
|
3 |
+
size 25912158
|
pages/page_02.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import pandas as pd
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
6 |
+
|
7 |
+
@st.cache_resource
def load_data():
    """Load the series metadata table once per app lifetime (cached)."""
    return pd.read_csv('data/to_translate_df.csv', index_col='Unnamed: 0')


@st.cache_resource
def load_model():
    """Load the multilingual sentence-embedding model once (cached)."""
    return SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')


@st.cache_data
def load_vectors():
    """Load the precomputed description embeddings (cached).

    map_location='cpu' lets a tensor file saved on a GPU machine load on
    the CPU-only Streamlit host instead of raising a CUDA error.
    """
    return torch.load("models/description_vectors_MiniLM-L12-v2.pt", map_location="cpu")
|
19 |
+
|
20 |
+
|
21 |
+
def search_series(query, top_k=5, df=None, model=None, description_vectors=None):
    """Return the ``top_k`` series whose descriptions best match *query*.

    The data/model arguments default to the module-level objects created
    in the button handler below, so existing calls keep working; passing
    them explicitly makes the function testable in isolation.

    Returns a DataFrame with title/description/image_url/genres plus a
    "similarity (%)" column, best match first.
    """
    # Fall back to the lazily-created module globals when not injected.
    mod = globals()
    if df is None:
        df = mod['df']
    if model is None:
        model = mod['model']
    if description_vectors is None:
        description_vectors = mod['description_vectors']

    query_vector = model.encode([query], convert_to_tensor=True)
    # Cosine similarity of the query against every stored description
    # vector, computed directly in torch (no sklearn round-trip needed).
    similarities = torch.nn.functional.cosine_similarity(
        description_vectors.cpu(), query_vector.cpu(), dim=1
    ).numpy()
    top_indices = similarities.argsort()[-top_k:][::-1]

    results = df.iloc[top_indices][["title", "description", "image_url", "genres"]].copy()
    results["similarity (%)"] = (similarities[top_indices] * 100).round(2)
    return results
|
29 |
+
|
30 |
+
# UI
|
31 |
+
|
32 |
+
st.title('Умный (не самый) поиск сериалов')
|
33 |
+
st.write('Использовалась модель MiniLM-L12-v2')
|
34 |
+
st.divider()
|
35 |
+
user_input = st.text_input("Чего желаешь посмотреть")
|
36 |
+
col2, col1, col3 = st.columns([2, 5,1])
|
37 |
+
with col3:
|
38 |
+
click = st.button("Найти")
|
39 |
+
with col2:
|
40 |
+
n = st.number_input('Топ', 0, 50, value=10 , step=5)
|
41 |
+
if click:
|
42 |
+
df = load_data()
|
43 |
+
model = load_model()
|
44 |
+
description_vectors = load_vectors()
|
45 |
+
|
46 |
+
results = search_series(user_input,n)
|
47 |
+
for i in range(n):
|
48 |
+
st.subheader(results.iloc[i]['title'])
|
49 |
+
st.write(f"**Жанры:** {results.iloc[i]['genres']}")
|
50 |
+
st.write(f"**Процент схожести:** {round(results.iloc[i]['similarity (%)'])}%")
|
51 |
+
st.image(results.iloc[i]['image_url'])
|
52 |
+
st.write(results.iloc[i]['description'])
|
53 |
+
st.divider()
|
pages/page_03.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import faiss
|
6 |
+
# import openai
|
7 |
+
import spacy
|
8 |
+
from googletrans import Translator
|
9 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
@st.cache_resource
|
14 |
+
def load_model():
|
15 |
+
return SentenceTransformer("sentence-transformers/paraphrase-xlm-r-multilingual-v1")
|
16 |
+
|
17 |
+
@st.cache_data
|
18 |
+
def load_data():
|
19 |
+
df = pd.read_csv('data/final_with_emb.csv')
|
20 |
+
return df
|
21 |
+
|
22 |
+
@st.cache_data
|
23 |
+
def load_embeddings():
|
24 |
+
return np.load('for_models/embeddings.npy')
|
25 |
+
|
26 |
+
@st.cache_resource
|
27 |
+
def load_faiss_index():
|
28 |
+
index_l2 = faiss.read_index('for_models/faiss_index_l2.bin')
|
29 |
+
index_ip = faiss.read_index('for_models/faiss_index_ip.bin')
|
30 |
+
index_hnsw = faiss.read_index('for_models/faiss_index_hnsw.bin')
|
31 |
+
return {'L2': index_l2, 'IP': index_ip, 'HNSW': index_hnsw}
|
32 |
+
|
33 |
+
st.title('Рекомендация сериалов')
|
34 |
+
st.markdown(
|
35 |
+
"""
|
36 |
+
<style>
|
37 |
+
.header {
|
38 |
+
font-size: 32px;
|
39 |
+
font-weight: bold;
|
40 |
+
color: #7147e6;
|
41 |
+
margin-bottom: 20px;
|
42 |
+
}
|
43 |
+
|
44 |
+
.subheader {
|
45 |
+
font-size: 24px;
|
46 |
+
font-weight: 600;
|
47 |
+
color: #7147e6;
|
48 |
+
margin-bottom: 15px;
|
49 |
+
}
|
50 |
+
|
51 |
+
.paragraph {
|
52 |
+
font-size: 18px;
|
53 |
+
line-height: 1.6;
|
54 |
+
color: #4799e6;
|
55 |
+
margin-bottom: 20px;
|
56 |
+
}
|
57 |
+
|
58 |
+
.list {
|
59 |
+
font-size: 18px;
|
60 |
+
color: #4799e6;
|
61 |
+
line-height: 1.8;
|
62 |
+
padding-left: 20px;
|
63 |
+
}
|
64 |
+
|
65 |
+
.service {
|
66 |
+
background-color: #ECF0F1;
|
67 |
+
border-radius: 10px;
|
68 |
+
padding: 20px;
|
69 |
+
margin-bottom: 30px;
|
70 |
+
}
|
71 |
+
|
72 |
+
.highlight {
|
73 |
+
color: #E74C3C;
|
74 |
+
font-weight: bold;
|
75 |
+
}
|
76 |
+
</style>
|
77 |
+
""", unsafe_allow_html=True
|
78 |
+
)
|
79 |
+
|
80 |
+
st.markdown('<div class="header">Добро пожаловать на мою страницу!</div>', unsafe_allow_html=True)
|
81 |
+
st.markdown(
|
82 |
+
"""
|
83 |
+
<div class="paragraph">
|
84 |
+
Этот сервис использует передовые технологии машинного обучения и обработки естественного языка для того, чтобы порекомендовать вам сериалы, которые могут вам понравиться. Мы применяем XLM-RoBERTa для поиска и обработки данных, чтобы вывести наиболее релевантные результаты по вашему запросу.
|
85 |
+
</div>
|
86 |
+
""", unsafe_allow_html=True)
|
87 |
+
st.markdown(
|
88 |
+
"""
|
89 |
+
<div class="subheader">Что умеет сервис?</div>
|
90 |
+
<div class="paragraph">
|
91 |
+
Cервис предоставляет следующие возможности:
|
92 |
+
</div>
|
93 |
+
<ul class="list">
|
94 |
+
<li>Поиск сериалов по вашему запросу с использованием различных методов поиска.</li>
|
95 |
+
<li>Перевод информации о сериале в режиме реального времени (если язык - не русский).</li>
|
96 |
+
<li>Вывод информации о сериале, включая название, описание и изображение.</li>
|
97 |
+
<li>Интерактивный поиск с возможностью выбора метода поиска: L2, IP, HNSW.</li>
|
98 |
+
<li>Отображение списка сериалов в удобном формате.</li>
|
99 |
+
</ul>
|
100 |
+
""", unsafe_allow_html=True)
|
101 |
+
|
102 |
+
def calculate_cosine_similarity(query_emb, embeddings):
    """Cosine similarity of the (1, d) query against each (n, d) embedding row.

    Reimplemented with numpy (already imported) instead of sklearn, so the
    helper has no extra dependency; returns a flat array of n scores.
    """
    q = np.asarray(query_emb, dtype=np.float64).ravel()
    m = np.asarray(embeddings, dtype=np.float64)
    denom = np.linalg.norm(m, axis=1) * np.linalg.norm(q)
    # Avoid division by zero for all-zero vectors (score becomes 0).
    denom[denom == 0] = 1.0
    return (m @ q) / denom
|
105 |
+
|
106 |
+
def calculate_l2_similarity(query_emb, embeddings):
    """Euclidean (L2) distance from the query to every embedding row.

    Note: despite the name these are distances — smaller means closer.
    """
    return np.linalg.norm(embeddings - query_emb, axis=1)
|
109 |
+
|
110 |
+
top_k = st.slider('Сколько выдаем рекомендаций?', min_value=1, max_value=20, value=5)
|
111 |
+
|
112 |
+
def search_similar(query, index_type, top_k=5):
    """Embed *query* and return (rows, distances) of its top_k nearest series.

    index_type selects one of the prebuilt FAISS indexes: 'L2', 'IP'
    (inner product) or 'HNSW'. Uses the module-level model/indexes/df.
    """
    query_emb = model.encode([query]).astype(np.float32)

    # Inner-product search approximates cosine similarity only on unit
    # vectors, so normalize the query for the IP index.
    if index_type == 'IP':
        faiss.normalize_L2(query_emb)

    distances, indices = indexes[index_type].search(query_emb, top_k)
    return df.iloc[indices[0]], distances[0]
|
125 |
+
|
126 |
+
translator = Translator()
|
127 |
+
|
128 |
+
def detect_and_translate(text):
    """Return *text* translated into Russian, unless it already is Russian."""
    source_lang = translator.detect(text).lang
    if source_lang == 'ru':
        return text
    return translator.translate(text, src=source_lang, dest='ru').text
|
134 |
+
|
135 |
+
nlp = spacy.load('en_core_web_sm')
|
136 |
+
|
137 |
+
def show_desc(desc, title, max_lines=4):
    """Render a translated title plus a shortened description in Streamlit.

    Shows the first *max_lines* sentences (sentence-split via the spaCy
    pipeline) and tucks the full original description into an expander.
    """
    st.markdown(f'### {detect_and_translate(title)}')

    sentences = [sent.text for sent in nlp(detect_and_translate(desc)).sents]
    st.write(' '.join(sentences[:max_lines]))

    with st.expander('Показать полное описание'):
        st.write(desc)
|
149 |
+
|
150 |
+
|
151 |
+
# client = openai.OpenAI(api_key='сюда свой APIKEY от ChatGPT')
|
152 |
+
|
153 |
+
def generate_summary(query, title, desc):
    """Ask GPT-4 for a short, query-aware pitch for the series.

    Requires the OpenAI ``client`` (its creation is commented out above);
    raises RuntimeError with a clear message when it is missing, instead
    of an opaque NameError at call time.
    """
    if 'client' not in globals():
        raise RuntimeError(
            'OpenAI client is not configured: set your API key and create '
            '`client` before calling generate_summary().'
        )

    prompt = f"""Ты – эксперт по кино. Пользователь ищет сериал по запросу: "{query}".
    Опиши сериал "{title}" коротко и понятно. Объясни, почему он подходит.

    Описание из базы: {desc}

    Ответь в формате:
    - Краткое описание:
    - Почему стоит посмотреть:
    """

    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": prompt}]
    )

    return response.choices[0].message.content
|
170 |
+
|
171 |
+
model = load_model()
|
172 |
+
df = load_data()
|
173 |
+
embeddings = load_embeddings()
|
174 |
+
indexes = load_faiss_index()
|
175 |
+
|
176 |
+
query = st.text_input('Введите описание сериала', 'Найди мне что-нибудь про автомобили')
|
177 |
+
|
178 |
+
index_type = st.selectbox('Выберите метод поиска:', ['IP', 'L2', 'HNSW'])
|
179 |
+
|
180 |
+
if st.button('Начать поиск'):
|
181 |
+
if query:
|
182 |
+
results, scores = search_similar(query, index_type, top_k)
|
183 |
+
|
184 |
+
st.subheader(f'Результаты c использованием {index_type}:')
|
185 |
+
for _, row in results.iterrows():
|
186 |
+
title = row['title']
|
187 |
+
desc = row['description']
|
188 |
+
image_url = row['image_url']
|
189 |
+
|
190 |
+
# summary = generate_summary(query, title, desc) раскоммитить при работе с ChatGPT
|
191 |
+
|
192 |
+
with st.container():
|
193 |
+
col1, col2 = st.columns([1, 3])
|
194 |
+
with col1:
|
195 |
+
st.image(image_url, width=500)
|
196 |
+
with col2:
|
197 |
+
# st.write(summary) если работает ChatGPT
|
198 |
+
show_desc(desc, title)
|
199 |
+
|
200 |
+
st.markdown('---')
|
201 |
+
|
202 |
+
query_emb = model.encode([query]).astype(np.float32)
|
203 |
+
cosine_scores = calculate_cosine_similarity(query_emb, embeddings)
|
204 |
+
l2_scores = calculate_l2_similarity(query_emb, embeddings)
|
205 |
+
faiss.normalize_L2(query_emb)
|
206 |
+
distances_hnsw, _ = indexes['HNSW'].search(query_emb, len(df))
|
207 |
+
hnsw_scores = distances_hnsw[0]
|
208 |
+
|
209 |
+
df['cosine_similarity'] = cosine_scores
|
210 |
+
df['l2_similarity'] = l2_scores
|
211 |
+
df['hnsw_similarity'] = hnsw_scores
|
212 |
+
|
213 |
+
df_sorted = df[['title', 'cosine_similarity', 'l2_similarity', 'hnsw_similarity']].sort_values(by='cosine_similarity', ascending=False)
|
214 |
+
|
215 |
+
st.subheader('Таблица с метриками')
|
216 |
+
st.markdown(
|
217 |
+
"""
|
218 |
+
<style>
|
219 |
+
.stDataFrame {
|
220 |
+
height: 400px;
|
221 |
+
overflow-y: auto;
|
222 |
+
width: 100%;
|
223 |
+
}
|
224 |
+
</style>
|
225 |
+
""",
|
226 |
+
unsafe_allow_html=True
|
227 |
+
)
|
228 |
+
st.dataframe(df_sorted)
|
pages/page_04.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
import pickle
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd


@st.cache_data
def _load_df():
    """Series metadata table, cached across Streamlit reruns."""
    return pd.read_csv('data/cleaned_df.csv')


@st.cache_resource
def _load_models():
    """Load the three embedding models once; reruns reuse the cache.

    Без кэша эти модели перезагружались бы при каждом клике — the other
    pages already cache their models, this makes page_04 consistent.
    """
    labse = SentenceTransformer('sentence-transformers/LaBSE')
    db_tok = AutoTokenizer.from_pretrained('distilbert-base-multilingual-cased')
    db_model = AutoModel.from_pretrained('distilbert-base-multilingual-cased')
    t2_tok = AutoTokenizer.from_pretrained('cointegrated/rubert-tiny2')
    t2_model = AutoModel.from_pretrained('cointegrated/rubert-tiny2')
    return labse, db_tok, db_model, t2_tok, t2_model


@st.cache_resource
def _load_indexes():
    """Unpickle the prebuilt FAISS indexes once (labse, distilbert, tiny2)."""
    indexes = []
    for path in ('models/dashas/labse_index.pkl',
                 'models/dashas/distilbert_index.pkl',
                 'models/dashas/tiny2_index.pkl'):
        with open(path, 'rb') as f:
            indexes.append(pickle.load(f))
    return indexes


df = _load_df()
labse_model, distilbert_tokenizer, distilbert_model, tiny2_tokenizer, tiny2_model = _load_models()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
distilbert_model.to(device)
tiny2_model.to(device)

labse_index, distilbert_index, tiny2_index = _load_indexes()
|
28 |
+
|
29 |
+
def search_series(query, model, tokenizer=None, index=None, top_k=5):
    """Embed *query* with the given model and return the top_k rows of df.

    With a tokenizer the model is treated as a raw HF transformer and the
    query embedding is the mean-pooled last hidden state; without one the
    model must expose SentenceTransformer's ``encode``.
    """
    if tokenizer:
        encoded = tokenizer([query], return_tensors="pt", padding=True,
                            truncation=True, max_length=128)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
        query_embedding = hidden.mean(dim=1).cpu().numpy()
    else:
        query_embedding = model.encode([query])

    _, neighbor_ids = index.search(query_embedding, top_k)
    return df.iloc[neighbor_ids[0]]
|
41 |
+
|
42 |
+
st.title("Умный поиск сериалов")
|
43 |
+
st.image("images/logo_1.jpeg", width=800) # Add your logo here
|
44 |
+
|
45 |
+
query = st.text_input("Введите запрос:")
|
46 |
+
model_choice = st.selectbox("Выберите модель:", ["LaBSE", "DistilBERT", "tiny2"])
|
47 |
+
top_k = st.slider("Количество результатов:", min_value=1, max_value=20, value=5)
|
48 |
+
|
49 |
+
if st.button("Найти"):
|
50 |
+
if query:
|
51 |
+
if model_choice == "LaBSE":
|
52 |
+
results = search_series(query, labse_model, index=labse_index, top_k=top_k)
|
53 |
+
elif model_choice == "DistilBERT":
|
54 |
+
results = search_series(query, distilbert_model, distilbert_tokenizer, distilbert_index, top_k=top_k)
|
55 |
+
elif model_choice == "tiny2":
|
56 |
+
results = search_series(query, tiny2_model, tiny2_tokenizer, tiny2_index, top_k=top_k)
|
57 |
+
|
58 |
+
st.write("Результаты поиска:")
|
59 |
+
for i, row in results.iterrows():
|
60 |
+
st.write(f"**{row['title']}**")
|
61 |
+
st.write(row['description'])
|
62 |
+
st.image(row['image_url'], width=600)
|
63 |
+
else:
|
64 |
+
st.write("Пожалуйста, введите запрос.")
|
pages/random_10.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
import pandas as pd
# task 1 sample 10 streamlit


df = pd.read_csv('data/full_df_without_nans.csv')
n = 10
rand = df.sample(n, ignore_index=True)

st.title('Cлучайные 10 сериалов ')
st.divider()
st.write('По нажатию на кнопку генерирует случайные 10 сериалов из датасета')
st.divider()

if st.button('Сгенерировать 10 сериалов'):
    # Streamlit reruns the whole script on the click, so `rand` above is
    # a fresh sample each time the button is pressed.
    for row in rand.itertuples(index=False):
        st.subheader(f"{row.title}")
        st.write(f"{row.description}")
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
pandas
|
3 |
+
torch
|
4 |
+
transformers
|
5 |
+
sentence-transformers
|
6 |
+
faiss-cpu
|
7 |
+
numpy
|