Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import torch | |
from geopy.geocoders import ArcGIS | |
import folium | |
from streamlit_folium import folium_static | |
from transformers import AutoTokenizer, AutoModel | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
session_state = st.session_state | |
if not hasattr(session_state, 'recommended_countries'): | |
session_state.recommended_countries = [] | |
st.set_page_config(layout="wide") | |
def load_model(): | |
model = AutoModel.from_pretrained("cointegrated/rubert-tiny2") | |
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") | |
return model, tokenizer | |
model, tokenizer = load_model() | |
def load_data(): | |
df = pd.read_csv('data/countries.csv') | |
return df | |
df = load_data() | |
def embed_bert_cls(text, model, tokenizer): | |
t = tokenizer(text, padding=True, truncation=True, return_tensors='pt') | |
with torch.no_grad(): | |
model_output = model(**{k: v.to(model.device) for k, v in t.items()}) | |
embeddings = model_output.last_hidden_state[:, 0, :] | |
embeddings = torch.nn.functional.normalize(embeddings) | |
return embeddings[0].cpu().numpy() | |
def get_coordinates(country_name): | |
geolocator = ArcGIS() | |
location = geolocator.geocode(country_name) | |
if location: | |
return location.latitude, location.longitude | |
else: | |
return None | |
st.markdown(""" | |
<style> | |
body { | |
zoom: 67%; /* Установите желаемый масштаб здесь */ | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
col10, col11 = st.columns([2,1.5]) | |
with col10: | |
st.markdown("<style> input {font-size: 20px !important;}</style>", unsafe_allow_html=True) | |
first_input = st.text_input('Введите предпочтения по климату и типу местности') | |
second_input = st.text_input('Введите предпочтения по еде') | |
third_input = st.text_input('Введите предпочтения по активностям') | |
with col11: | |
st.markdown(f"<p style='font-size: 25px;'>Инструкция для пользователя:</p>", unsafe_allow_html=True) | |
st.markdown(f"<p style='font-size: 15px;'>1.Должны быть заполнены все поля!</p>", unsafe_allow_html=True) | |
st.markdown(f"<p style='font-size: 15px;'>2.Чем больше информации в текстовых полях, тем точнее будет обработан ваш запрос</p>", unsafe_allow_html=True) | |
st.markdown(f"<p style='font-size: 15px;'>3.Будьте аккуратны с параметром безопасности страны, если вам не интересен этот параметр, не трогайте его</p>", unsafe_allow_html=True) | |
st.markdown(f"<p style='font-size: 15px;'>4.Если попалась страна, где вы уже были, воспользуйтесь кнопкой 'Следующая рекомендация'</p>", unsafe_allow_html=True) | |
st.markdown(f"<p style='font-size: 15px;'>5.Если понравилось и помогло данное приложение, не забудьте порекомендовать вашим друзьям, будем очень благодарны</p>", unsafe_allow_html=True) | |
col0, col1,_ = st.columns([0.05,0.1, 1]) | |
with col0: | |
option = st.selectbox( | |
'Виза', | |
('Да', 'Нет') | |
) | |
# col1, _ = st.columns([0.1,1]) | |
with col1: | |
sec_option = st.selectbox( | |
'Местоположение', | |
('Африка','Азия','Европа','Океания','Северная Америка','Южная Америка') | |
) | |
col2, _ = st.columns([0.5,1]) | |
with col2: | |
third_option = st.slider('Выберите значение, характеризующее оценку безопасности страны (чем ниже значение, тем более безопасная страна)', 1.0, 3.6, 3.6, -0.1) | |
col3,col4 = st.columns([1,8]) | |
col5,col6, col7 = st.columns([5,5,5]) | |
with col3: | |
button_test = st.button('Получить рекомендацию') | |
if button_test and first_input and second_input and third_input : | |
session_state.recommended_countries = [] | |
filtered_df = df[df['visa'] == option] | |
filtered_df = filtered_df[filtered_df['location'] == sec_option] | |
filtered_df = filtered_df[filtered_df['peace_index'] <= third_option] | |
decode_first = embed_bert_cls(first_input, model, tokenizer) | |
decode_second = embed_bert_cls(second_input, model, tokenizer) | |
decode_third = embed_bert_cls(third_input, model, tokenizer) | |
try: | |
review_embeddings = np.vstack(filtered_df['embeddings_review'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) | |
kitchen_embeddings = np.vstack(filtered_df['embeddings_kitchen'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) | |
activity_embeddings = np.vstack(filtered_df['embeddings_activity'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) | |
similarity_col1 = cosine_similarity(decode_first.reshape(1, -1), review_embeddings) | |
similarity_col2 = cosine_similarity(decode_second.reshape(1, -1), kitchen_embeddings) | |
similarity_col3 = cosine_similarity(decode_third.reshape(1, -1), activity_embeddings) | |
mean_similarity = np.mean([similarity_col1, similarity_col2, similarity_col3], axis=0) | |
max_similarity_row = np.argmax(mean_similarity) | |
max_similarity_value = np.max(mean_similarity) | |
recommended_country = filtered_df.iloc[max_similarity_row]['country'] | |
recommended_review = filtered_df.iloc[max_similarity_row]['short_review'] | |
recommended_flag = filtered_df.iloc[max_similarity_row]['flag'] | |
recommended_photo = filtered_df.iloc[max_similarity_row]['country_photo'] | |
similarity_values = [similarity_col1[:, max_similarity_row], | |
similarity_col2[:, max_similarity_row], | |
similarity_col3[:, max_similarity_row]] | |
session_state.recommended_countries.append(recommended_country) | |
with col5: | |
st.image(recommended_photo, width=795, use_column_width=False) | |
with col6: | |
st.image(recommended_flag, width=200, use_column_width=False) | |
st.markdown(f"<p style='font-size: 25px;'>Рекомендуемая страна: {recommended_country}</p>", unsafe_allow_html=True) | |
st.markdown(f"<p style='font-size: 25px;'> {recommended_review}</p>", unsafe_allow_html=True) | |
scale_html = f'<div style="width: 300px; height: 30px;">' | |
scale_html += f'<progress value="{max_similarity_value}" max="1" style="width: 100%; height: 100%;"></progress>' | |
scale_html += f'<div style="position: relative; top: -22px; text-align: center;">' | |
scale_html += f'<span style="position: absolute; left: 0;">0</span>' | |
scale_html += f'<span style="position: absolute; right: 0;">1</span>' | |
scale_html += f'</div></div>' | |
st.markdown(f"<p style='font-size: 25px;'>Оценка близости вашего запроса и страны</p>", unsafe_allow_html=True) | |
st.markdown(scale_html, unsafe_allow_html=True) | |
with col7: | |
coordinates = get_coordinates(recommended_country) | |
if coordinates: | |
my_map = folium.Map(location=coordinates, zoom_start=5, tiles="Cartodb Positron",max_bounds=True, min_lon=-180, max_lon=180, min_lat=-90, max_lat=90,min_zoom=2,max_zoom=15) | |
folium.Marker(location=coordinates, popup=recommended_country).add_to(my_map) | |
folium_static(my_map) | |
else: | |
st.write(f"Координаты для страны {recommended_country} не найдены.") | |
except ValueError as e: | |
st.markdown(f"<p style='font-size: 25px;'>Проверьте, пожалуйста ваш запрос, по данным параметрам не получается порекомендовать страну</p>", unsafe_allow_html=True) | |
if session_state.recommended_countries: | |
with col4: | |
next_button = st.button("Следующая рекомендация") | |
if next_button and session_state.recommended_countries : | |
filtered_df = df[df['visa'] == option] | |
filtered_df = filtered_df[filtered_df['location'] == sec_option] | |
filtered_df = filtered_df[~filtered_df['country'].isin(session_state.recommended_countries)] | |
decode_first = embed_bert_cls(first_input, model, tokenizer) | |
decode_second = embed_bert_cls(second_input, model, tokenizer) | |
decode_third = embed_bert_cls(third_input, model, tokenizer) | |
review_embeddings = np.vstack(filtered_df['embeddings_review'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) if not filtered_df.empty else None | |
kitchen_embeddings = np.vstack(filtered_df['embeddings_kitchen'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) if not filtered_df.empty else None | |
activity_embeddings = np.vstack(filtered_df['embeddings_activity'].apply(lambda x: np.fromstring(x[1:-1], sep=' '))) if not filtered_df.empty else None | |
if review_embeddings is not None and kitchen_embeddings is not None and activity_embeddings is not None: | |
similarity_col1 = cosine_similarity(decode_first.reshape(1, -1), review_embeddings) | |
similarity_col2 = cosine_similarity(decode_second.reshape(1, -1), kitchen_embeddings) | |
similarity_col3 = cosine_similarity(decode_third.reshape(1, -1), activity_embeddings) | |
mean_similarity = np.mean([similarity_col1, similarity_col2, similarity_col3], axis=0) | |
max_similarity_row = np.argmax(mean_similarity) | |
max_similarity_value = np.max(mean_similarity) | |
if max_similarity_value > 0: | |
recommended_country = filtered_df.iloc[max_similarity_row]['country'] | |
recommended_review = filtered_df.iloc[max_similarity_row]['short_review'] | |
recommended_flag = filtered_df.iloc[max_similarity_row]['flag'] | |
recommended_photo = filtered_df.iloc[max_similarity_row]['country_photo'] | |
similarity_values = [similarity_col1[:, max_similarity_row], | |
similarity_col2[:, max_similarity_row], | |
similarity_col3[:, max_similarity_row]] | |
session_state.recommended_countries.append(recommended_country) | |
with col5: | |
st.image(recommended_photo, width=795, use_column_width=False) | |
with col6: | |
st.image(recommended_flag, width=200, use_column_width=False) | |
st.markdown(f"<p style='font-size: 25px;'>Рекомендуемая страна: {recommended_country}</p>", unsafe_allow_html=True) | |
st.markdown(f"<p style='font-size: 25px;'> {recommended_review}</p>", unsafe_allow_html=True) | |
scale_html = f'<div style="width: 300px; height: 30px;">' | |
scale_html += f'<progress value="{max_similarity_value}" max="1" style="width: 100%; height: 100%;"></progress>' | |
scale_html += f'<div style="position: relative; top: -22px; text-align: center;">' | |
scale_html += f'<span style="position: absolute; left: 0;">0</span>' | |
scale_html += f'<span style="position: absolute; right: 0;">1</span>' | |
scale_html += f'</div></div>' | |
st.markdown(f"<p style='font-size: 25px;'>Оценка близости вашего запроса и страны</p>", unsafe_allow_html=True) | |
st.markdown(scale_html, unsafe_allow_html=True) | |
with col7: | |
coordinates = get_coordinates(recommended_country) | |
if coordinates: | |
my_map = folium.Map(location=coordinates, zoom_start=5, tiles="Cartodb Positron", | |
max_bounds=True, | |
min_lon=-180, max_lon=180, min_lat=-90, max_lat=90, min_zoom=2, | |
max_zoom=15) | |
folium.Marker(location=coordinates, popup=recommended_country).add_to(my_map) | |
folium_static(my_map) | |
else: | |
st.write(f"Координаты для страны {recommended_country} не найдены.") | |
else: | |
st.markdown(f"<p style='font-size: 25px;'>Больше нет рекомендованых стран по вашему запросу</p>", unsafe_allow_html=True) | |