Spaces:
Sleeping
Sleeping
File size: 2,486 Bytes
732a846 01142fd 732a846 01142fd 732a846 01142fd 732a846 ede71a9 732a846 01142fd ede71a9 01142fd 732a846 01142fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import streamlit as st
import torch
import pickle
from PIL import Image
import io
from model_execute import preprocess_images, output_to_names
from summarization import init_model_and_tokenizer, summarize
from wikipedia_api import getWikipedia
from mapbox_map import plot_map
@st.cache_resource
def load_recognition_model():
    """
    Load the pickled landmark-recognition model from disk.

    The model is read from ``pickle_model.pkl``, a relative path that is
    resolved against the process working directory. The function is wrapped
    in ``st.cache_resource`` so the loaded object is reused instead of being
    unpickled again on each call.

    Returns:
        The object stored in the pickle file. Elsewhere in this file it is
        invoked as ``model(images)``.

    Raises:
        FileNotFoundError: If ``pickle_model.pkl`` does not exist.
    """
    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # Only deploy a pickle_model.pkl that comes from a trusted source.
    model_path = "pickle_model.pkl"
    with open(model_path, "rb") as model_file:
        return pickle.load(model_file)
@st.cache_resource
def load_summarizer():
    """
    Create the summarization model together with its tokenizer.

    Returns:
        A ``(summarizer, tokenizer)`` tuple unpacked from the result of
        ``init_model_and_tokenizer()``.
    """
    summarization_model, text_tokenizer = init_model_and_tokenizer()
    return summarization_model, text_tokenizer
def predict_images(images, model):
    """
    Recognise the landmark shown on every image in ``images``.

    Args:
        images: Images to classify; handed unchanged to ``preprocess_images``.
        model: Callable recognition model applied to the preprocessed input.

    Returns:
        The result of ``output_to_names`` applied to the model output
        (presumably one landmark name per image -- confirm in model_execute).
    """
    prepared = preprocess_images(images)
    # Inference only: no gradient tracking is needed for the forward pass.
    with torch.no_grad():
        raw_output = model(prepared)
    return output_to_names(raw_output)
def load_images():
    """
    Show the upload widget, preview the uploaded photos and open them.

    Every uploaded file is rendered on the page with ``st.image`` and then
    opened with PIL from its in-memory bytes.

    Returns:
        A list of ``PIL.Image`` objects, one per uploaded file, or ``None``
        when the uploader itself returns ``None``.
    """
    uploads = st.file_uploader(
        label="Загрузите ваши фотографии.",
        type=['png', 'jpg'],
        accept_multiple_files=True,
    )

    # NOTE(review): with accept_multiple_files=True Streamlit is documented to
    # return a list (empty when nothing is uploaded), so this guard is likely
    # never taken -- confirm against the installed Streamlit version.
    if uploads is None:
        return None

    # First pass: read the raw bytes and preview each photo on the page.
    raw_payloads = []
    for upload in uploads:
        payload = upload.getvalue()
        st.image(payload)
        raw_payloads.append(payload)

    # Second pass: decode the bytes into PIL images only after all previews.
    opened = []
    for payload in raw_payloads:
        opened.append(Image.open(io.BytesIO(payload)))
    return opened
# Both loaders are wrapped in st.cache_resource, so the models are reused.
landmark_model = load_recognition_model()
summarizer, tokenizer = load_summarizer()

st.title("Распознавание достопримечательностей")

# User input: photos, the "short description" toggle and the trigger button.
images = load_images()
want_summary = st.checkbox("Короткое описание")
recognize_clicked = st.button('Распознать')

# Run the pipeline only when photos are present and the button was pressed.
if images and recognize_clicked:
    # Step 1: recognise the landmarks and show the predicted names.
    predicted_names = predict_images(images, landmark_model)
    st.write(predicted_names)

    # Step 2: fetch descriptions and coordinates from Wikipedia.
    landmarks_info = getWikipedia(predicted_names)
    st.write("Загружены данные с википедии.")

    # Step 3 (optional): attach a shortened description to every entry.
    if want_summary:
        for entry in landmarks_info:
            entry['summarized'] = summarize(entry['summary'], summarizer, tokenizer)

    st.write(landmarks_info)

    # Step 4: plot the landmarks on a map.
    plot_map(landmarks_info)
|