vdgb_project / app.py
DSVmon's picture
1
f1840e7 verified
# === Загрузка библиотек ===
from pypdf import PdfReader, PdfWriter
import gradio as gr
import fitz # PyMuPDF
from PIL import Image
import pandas as pd
import cv2
import numpy as np
import os
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import difflib
# Загрузка TrOCR
processor = TrOCRProcessor.from_pretrained('kazars24/trocr-base-handwritten-ru')
model = VisionEncoderDecoderModel.from_pretrained('kazars24/trocr-base-handwritten-ru')
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# === 1. Функция поиска и группировки линий ===
def group_lines(contours, img_size, y_tolerance=10, is_horizontal=True):
line_groups = []
used = [False] * len(contours)
for i in range(len(contours)):
if used[i]:
continue
group = [contours[i]]
used[i] = True
x, y, w, h = cv2.boundingRect(contours[i])
for j in range(i + 1, len(contours)):
if used[j]:
continue
x2, y2, w2, h2 = cv2.boundingRect(contours[j])
if is_horizontal:
if abs(y2 - y) < y_tolerance:
group.append(contours[j])
used[j] = True
else:
if abs(x2 - x) < y_tolerance:
group.append(contours[j])
used[j] = True
line_groups.append(group)
return line_groups
# === 2. Основная функция отрисовки линий и сохранения координат ===
def detect_table_lines_and_cells(img, min_cell_size=15):
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
# Горизонтальные линии
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
detect_horizontal = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
horizontal_contours = cv2.findContours(detect_horizontal, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
horizontal_contours = horizontal_contours[0] if len(horizontal_contours) == 2 else horizontal_contours[1]
horizontal_line_groups = group_lines(horizontal_contours, img.shape[0], is_horizontal=True)
horizontal_line_groups.sort(key=lambda g: np.mean([cv2.boundingRect(c)[1] for c in g]))
horizontal_coords = [int(np.mean([cv2.boundingRect(c)[1] + cv2.boundingRect(c)[3] / 2 for c in group])) for group in horizontal_line_groups[3:]] # от 4-й линии
# Вертикальные линии
vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
detect_vertical = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, vertical_kernel, iterations=2)
vertical_contours = cv2.findContours(detect_vertical, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
vertical_contours = vertical_contours[0] if len(vertical_contours) == 2 else vertical_contours[1]
vertical_line_groups = group_lines(vertical_contours, img.shape[1], is_horizontal=False)
vertical_coords = [int(np.mean([cv2.boundingRect(c)[0] + cv2.boundingRect(c)[2] / 2 for c in group])) for group in vertical_line_groups]
# Поиск ячеек
cells = []
horizontal_coords = sorted(horizontal_coords)
vertical_coords = sorted(vertical_coords)
for row_idx in range(len(horizontal_coords) - 1):
y1, y2 = horizontal_coords[row_idx], horizontal_coords[row_idx + 1]
for col_idx in range(len(vertical_coords) - 1):
x1, x2 = vertical_coords[col_idx], vertical_coords[col_idx + 1]
w = x2 - x1
h = y2 - y1
if w > min_cell_size and h > min_cell_size:
cells.append({'row': row_idx, 'col': col_idx, 'box': (x1, y1, w, h)})
return cells
# === 3. Функция распознавания текста в ячейках ===
def recognize_text(image, max_length=10):
if image is None:
return "Не удалось загрузить изображение"
try:
inputs = processor(images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
**inputs,
max_length=max_length,
early_stopping=True,
num_beams=1,
use_cache=True
)
return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
except Exception as e:
print(f"Ошибка распознавания: {e}")
return "Ошибка распознавания"
# === 4. Обрезка ячеек и OCR для таблицы ===
def crop_and_recognize_cells(image, cells):
allowed_words = ["труба", "врезка", "зкл", "отвод", "арм", "переход", "тройник", "заглушка", "зад-ка", "т-т", "комп-р"]
recognized = {}
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
for cell in cells:
x, y, w, h = cell['box']
cropped = pil_image.crop((x, y, x + w, y + h))
text = recognize_text(cropped, max_length=10).strip().lower()
# Заменяем все запятые на точки
text = text.replace(',', '.')
if any(char.isdigit() for char in text):
# Если есть цифры - оставляем как есть (уже заменили запятые)
final_text = text
else:
# Если текст состоит только из букв
if len(text) <= 2:
# Если длина 1 или 2 символа - заменяем на пустую строку
final_text = ""
else:
# Ищем наиболее похожее слово из словаря
matches = difflib.get_close_matches(text, allowed_words, n=1, cutoff=0.5)
final_text = matches[0] if matches else text
recognized[(cell['row'], cell['col'])] = final_text
return recognized
# === 5. Полный процесс обработки изображения таблицы ===
def process_pdf_table(pdf_path, output_excel='results.xlsx'):
# Извлечение изображений из PDF с помощью PyMuPDF
doc = fitz.open(pdf_path)
images = []
for page_num in range(len(doc)):
page = doc.load_page(page_num)
pix = page.get_pixmap(dpi=300)
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
images.append(img)
if not images:
print("Ошибка: PDF пустой или не удалось сконвертировать.")
return
image = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR) # Переводим PIL -> OpenCV
cells = detect_table_lines_and_cells(image, min_cell_size=15)
print(f"Найдено ячеек: {len(cells)}")
recognized = crop_and_recognize_cells(image, cells)
# Собираем в DataFrame
data = {}
for (row, col), text in recognized.items():
data.setdefault(row, {})[col] = text
max_cols = max((max(cols.keys()) for cols in data.values()), default=0) + 1
rows = []
for row_idx in range(max(data.keys()) + 1):
row = []
for col_idx in range(max_cols):
row.append(data.get(row_idx, {}).get(col_idx, ""))
rows.append(row)
df = pd.DataFrame(rows)
df.to_excel(output_excel, index=False, header=False)
print(f"Результат сохранён в {output_excel}")
return output_excel
# === Gradio приложение ===
def gradio_process(pdf_file, progress=gr.Progress()):
progress(0, desc="Чтение PDF...")
# Получаем имя без расширения и меняем его на .xlsx
base_name = os.path.splitext(os.path.basename(pdf_file.name))[0]
output_excel = f"{base_name}.xlsx"
progress(0.3, desc="Поиск ячеек таблицы...")
result_file = process_pdf_table(pdf_file.name, output_excel=output_excel)
progress(1.0, desc="Готово! Таблица сохранена.")
return result_file
app = gr.Interface(
fn=gradio_process,
inputs=gr.File(label="Загрузите PDF файл таблицы"),
outputs=gr.File(label="Скачайте Excel с распознанными ячейками"),
title="📄 PDF → Excel распознавание таблиц",
description="Загрузите PDF-файл с таблицей. Программа найдет ячейки, распознает текст и сохранит результат в Excel-файл.",
allow_flagging="never"
)
app.launch()