Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import os | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
from io import BytesIO | |
from PIL import Image | |
import fitz # PyMuPDF | |
import numpy as np | |
from transformers import NougatProcessor, VisionEncoderDecoderModel | |
import nltk | |
import ssl | |
# Initialise NLTK: make sure the "punkt" tokenizer data is installed.
def _ensure_nltk_punkt():
    """Download NLTK's ``punkt`` data if it is not already available.

    A normal (certificate-verified) download is attempted first. Only if that
    fails is TLS verification bypassed, and then only for the single retry:
    the default HTTPS context is restored afterwards, instead of leaving
    certificate checks disabled for every urllib request the process makes.
    """
    try:
        nltk.data.find('tokenizers/punkt')
        return
    except LookupError:
        pass
    # nltk.download returns False (rather than raising) when it fails.
    if nltk.download('punkt'):
        return
    # Some hosts ship an incomplete CA bundle, which makes the download fail
    # with an SSL error; retry once without verification.
    unverified_context = getattr(ssl, "_create_unverified_context", None)
    if unverified_context is None:
        return
    previous_context = ssl._create_default_https_context
    ssl._create_default_https_context = unverified_context
    try:
        nltk.download('punkt')
    finally:
        ssl._create_default_https_context = previous_context


# Download the necessary NLTK data at import time.
_ensure_nltk_punkt()
# Read configuration from environment variables.
# Hugging Face access token, or None when HF_TOKEN is unset.
# NOTE(review): not referenced anywhere else in this file; presumably
# huggingface_hub picks HF_TOKEN up from the environment on its own — confirm,
# then either pass it explicitly (token=HF_TOKEN) or drop this constant.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# HTML intro shown at the top of the page via gr.Markdown(DESCRIPTION).
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Academic Paper Improver</h1>
<p>This Space helps you improve a selected content of your academic paper using the <a href="https://huggingface.co/Xtra-Computing/XtraGPT-7B"><b>XtraGPT</b></a> model series, ensuring controllability on criteria following and in-context ability.</p>
<p>Upload your PDF paper, select a section of text you want to improve, and specify your requirements.</p>
</div>
'''
# BibTeX entry rendered at the bottom of the page via gr.HTML(CITATION).
# BibTeX separates multiple authors with " and "; a comma means
# "Last, First", so a comma-separated list is a single malformed name
# (BibTeX reports "Too many commas in name").
CITATION = """
<div style="font-family: monospace; white-space: pre; margin-top: 20px; line-height: 1.2;">
@misc{XtraGPT,
title = {XtraGPT},
url = {https://huggingface.co/Xtra-Computing/XtraGPT-7B},
author = {Nuo Chen and Andre Lin HuiKai and Junyi Hou and Zining Zhang and Qian Wang and Xidong Wang and Bingsheng He},
month = {March},
year = {2025}
}
</div>
"""
# Footer text. NOTE(review): not rendered anywhere in this file — either add
# it to the layout or remove it.
LICENSE = """
<p/>
---
Built with XtraGPT models
"""
# Custom stylesheet passed to gr.Blocks(css=css).
# NOTE(review): no component in this file uses elem_id="duplicate-button",
# so the #duplicate-button rule currently has no effect.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Default paper content: sample abstract used as the paper context until a PDF
# is processed. It is the initial value of the paper-text state and textbox,
# and extract_text_from_pdf returns it for a missing PDF or a failed extraction.
default_paper_content = """
The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
"""
# Dropdown label -> Hugging Face Hub repository id. Dict order is the order
# in which the choices appear in the model dropdown.
AVAILABLE_MODELS = {
    "XtraGPT-1.5B": "Xtra-Computing/XtraGPT-1.5B",
    "XtraGPT-3B": "Xtra-Computing/XtraGPT-3B",
    "XtraGPT-7B": "Xtra-Computing/XtraGPT-7B",
    "XtraGPT-14B": "Xtra-Computing/XtraGPT-14B",
}

# Lazily populated caches. load_model() keeps exactly one XtraGPT model (plus
# its tokenizer and dropdown label) resident; load_nougat_model() fills the
# Nougat OCR pair on first use.
current_model = current_tokenizer = current_model_name = None
nougat_model = nougat_processor = None
def load_nougat_model():
    """Return the cached ``(processor, model)`` pair for ``facebook/nougat-base``.

    Both objects are created on the first call (or whenever either cache slot
    is empty), the model is moved to CUDA when available and to CPU otherwise,
    and the pair is stored in module globals for reuse.
    """
    global nougat_model, nougat_processor
    # Fast path: both halves are already loaded.
    if nougat_model is not None and nougat_processor is not None:
        return nougat_processor, nougat_model

    checkpoint = "facebook/nougat-base"
    nougat_processor = NougatProcessor.from_pretrained(checkpoint)
    nougat_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    nougat_model.to(target_device)
    return nougat_processor, nougat_model
def _ocr_page_with_nougat(page, processor, model):
    """Render one PyMuPDF page at 2x zoom and return Nougat's text for it."""
    # 2x zoom for better OCR quality; alpha=False guarantees the 3-channel
    # buffer that Image.frombytes("RGB", ...) below expects.
    pix = page.get_pixmap(matrix=fitz.Matrix(2, 2), alpha=False)
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    pixel_values = processor(img, return_tensors="pt").pixel_values.to(model.device)
    outputs = model.generate(
        pixel_values,
        min_length=1,
        max_new_tokens=1024,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    page_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return processor.post_process_generation(page_text, fix_markdown=True)


def _truncate_to_words(text, max_words):
    """Cut *text* right after its first *max_words* whitespace-separated words.

    Text with at most *max_words* words is returned unchanged. The kept part
    retains its original whitespace (newlines, Markdown tables, headings),
    unlike ``" ".join(text.split()[:max_words])``, which flattens everything
    onto one line.
    """
    import re

    words = list(re.finditer(r"\S+", text))
    if len(words) <= max_words:
        return text
    if max_words <= 0:
        return ""
    return text[:words[max_words - 1].end()]


# Run on a ZeroGPU device when this Space has one (`spaces` is imported at the
# top of the file). NOTE(review): the default GPU time slot may be too short
# for long PDFs; consider @spaces.GPU(duration=...).
@spaces.GPU
def extract_text_from_pdf(pdf_bytes, max_words=15000):
    """Transcribe an uploaded PDF into text with the Nougat OCR model.

    Pages are rendered and transcribed one at a time. Extraction stops after
    the first page on which the running word count exceeds *max_words*, and
    the result is then cut to exactly *max_words* words.

    Args:
        pdf_bytes: Raw PDF bytes (from ``gr.File(type="binary")``) or None.
        max_words: Maximum number of whitespace-separated words to keep.

    Returns:
        The extracted text, with each page followed by a blank line. Returns
        ``default_paper_content`` when *pdf_bytes* is None or when extraction
        raises.
    """
    if pdf_bytes is None:
        return default_paper_content
    try:
        # `nltk` is imported at module level, so only the data can be missing
        # here. It is presumably needed by Nougat's post-processing — confirm.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')

        processor, model = load_nougat_model()

        page_texts = []
        word_count = 0
        # The context manager closes the document even if OCR fails midway.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            page_total = doc.page_count
            for page_index, page in enumerate(doc):
                page_text = _ocr_page_with_nougat(page, processor, model)
                page_texts.append(page_text + "\n\n")
                # Equals len("".join(page_texts).split()): the "\n\n"
                # separator stops words from merging across pages.
                word_count += len(page_text.split())
                print(f"Processed page {page_index + 1}/{page_total}")
                # Release cached GPU memory before rendering the next page.
                torch.cuda.empty_cache()
                if word_count > max_words:
                    print(f"Reached {max_words} word limit, stopping extraction")
                    break

        full_text = "".join(page_texts)
        if word_count > max_words:
            full_text = _truncate_to_words(full_text, max_words)
            print(f"Truncated paper content to {max_words} words")
        return full_text
    except Exception as e:
        # Deliberate best-effort fallback: log the failure and keep the app
        # usable with the sample context. NOTE(review): the user is not told
        # that extraction failed.
        import traceback
        error_details = traceback.format_exc()
        print(f"PDF extraction error: {str(e)}\n{error_details}")
        return default_paper_content
    finally:
        # Clear GPU memory
        torch.cuda.empty_cache()
def load_model(model_name):
    """Return ``(tokenizer, model)`` for *model_name*, loading it on demand.

    Only one XtraGPT model is kept in memory at a time. Requesting a different
    name unloads the current model first, then loads the new one from the Hub
    with ``device_map="auto"`` and the dtype stored in the checkpoint
    (``torch_dtype="auto"``), rather than the float32 default that would need
    twice the memory of a bf16/fp16 checkpoint.

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        KeyError: If *model_name* is not in ``AVAILABLE_MODELS``. In that case
            the currently loaded model stays loaded.
    """
    global current_model, current_tokenizer, current_model_name
    # If the requested model is already loaded, return it.
    if current_model_name == model_name and current_model is not None and current_tokenizer is not None:
        return current_tokenizer, current_model

    # Resolve the repo id before unloading, so an unknown name does not throw
    # away the model that is currently loaded.
    model_path = AVAILABLE_MODELS[model_name]

    if current_model is not None or current_tokenizer is not None:
        import gc
        # Rebind to None instead of `del`. With `del`, the globals became
        # undefined, so a failed load below made every later call raise
        # NameError.
        current_model = None
        current_tokenizer = None
        current_model_name = None
        gc.collect()  # drop lingering references before releasing CUDA memory
        torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
    # Publish the new state only once both loads have succeeded.
    current_tokenizer, current_model, current_model_name = tokenizer, model, model_name
    return tokenizer, model
# Total token budget (prompt + generated tokens) enforced before generation.
# NOTE(review): the original code described 16384 as the model's maximum
# context length; confirm against the deployed checkpoints' configs.
_MAX_CONTEXT_TOKENS = 16384

# Instruction template sent as the single user turn (text kept verbatim).
_IMPROVEMENT_PROMPT_TEMPLATE = """
Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
<PAPER_CONTENT>
{paper_content}
</PAPER_CONTENT>
<SELECTED_CONTENT>
{selected_content}
</SELECTED_CONTENT>
<QUESTION>
{improvement_prompt}
</QUESTION>
"""


def _encode_improvement_chat(tokenizer, paper_content, selected_content, improvement_prompt):
    """Render the chat prompt for the three inputs and return its token ids."""
    content = _IMPROVEMENT_PROMPT_TEMPLATE.format(
        paper_content=paper_content,
        selected_content=selected_content,
        improvement_prompt=improvement_prompt,
    )
    messages = [{"role": "user", "content": content}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains its special tokens. Letting
    # encode() add them again can duplicate a BOS token on some tokenizers.
    return tokenizer.encode(text, add_special_tokens=False)


def _fit_prompt_to_budget(tokenizer, paper_content, selected_content, improvement_prompt, budget):
    """Return prompt token ids no longer than *budget*, or None if impossible.

    Only the paper context is shortened, by dropping its tail. The selected
    text, the request and the chat template's trailing generation prompt are
    always kept intact.
    """
    input_ids = _encode_improvement_chat(tokenizer, paper_content, selected_content, improvement_prompt)
    if len(input_ids) <= budget:
        return input_ids
    paper_ids = tokenizer.encode(paper_content, add_special_tokens=False)
    overflow = len(input_ids) - budget
    while True:
        keep = max(len(paper_ids) - overflow, 0)
        shortened = tokenizer.decode(paper_ids[:keep])
        input_ids = _encode_improvement_chat(tokenizer, shortened, selected_content, improvement_prompt)
        if len(input_ids) <= budget:
            print(f"Paper content truncated to {keep} tokens to fit the {budget}-token input budget")
            return input_ids
        if keep == 0:
            # Too long even without any paper context.
            return None
        # Decode/re-encode can shift token boundaries slightly, so cut a bit
        # further. `overflow` grows every pass, so the loop terminates.
        overflow += len(input_ids) - budget


# Run on a ZeroGPU device when this Space has one (`spaces` is imported at the
# top of the file).
@spaces.GPU
def improve_paper_section(model_name, paper_content, selected_content, improvement_prompt, temperature=0.1, max_new_tokens=512, progress=gr.Progress()):
    """Rewrite *selected_content* following *improvement_prompt* (non-streaming).

    The whole paper is passed as context. If the prompt plus *max_new_tokens*
    would exceed the context budget, the paper context is shortened from its
    end; the selected text and the request are never truncated.

    Args:
        model_name: Key of ``AVAILABLE_MODELS`` chosen in the dropdown.
        paper_content: Full paper text used as context.
        selected_content: Passage to improve.
        improvement_prompt: The user's instruction for the rewrite.
        temperature: Sampling temperature; 0 selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The generated text, or a human-readable message for missing or
        oversized input and for generation errors. Errors are caught and
        reported as text instead of being raised to the UI.
    """
    # Check inputs
    if not selected_content or not improvement_prompt:
        return "Please provide both text to improve and improvement requirements."
    try:
        progress(0.1, desc="Loading model...")
        tokenizer, model = load_model(model_name)

        progress(0.3, desc="Processing input...")
        # Slider values may arrive as floats; generate() needs an int here.
        max_new_tokens = int(max_new_tokens)
        temperature = float(temperature)
        # Reserve room for the generated tokens inside the context budget.
        budget = _MAX_CONTEXT_TOKENS - max_new_tokens
        prompt_ids = _fit_prompt_to_budget(
            tokenizer, paper_content or "", selected_content, improvement_prompt, budget
        )
        if prompt_ids is None:
            return "Input is too long: please shorten the text to improve or the improvement requirements."

        progress(0.5, desc="Generating improved text...")
        input_ids = torch.tensor([prompt_ids], device=model.device)
        # Single unpadded sequence, so every position is attended to.
        attention_mask = torch.ones_like(input_ids)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=(temperature > 0),
                temperature=temperature if temperature > 0 else 1.0,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Only keep the newly generated part.
        generated_ids = output_ids[0, input_ids.shape[1]:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        progress(1.0, desc="Complete!")
        return response
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Generation error: {str(e)}\n{error_details}")
        return f"Error generating text: {str(e)}\n\nPlease try with different parameters or input."
# Create Gradio interface.
# Layout: a left column with the input steps, model picker and sampling
# parameters, and a right column with the generated text. Components appear
# on screen in the order they are created inside each `with` context.
with gr.Blocks(fill_height=True, css=css) as demo:
    # Paper text used as context by improve_paper_section. Starts as the
    # sample abstract and is replaced whenever the PDF upload changes.
    extracted_pdf_text = gr.State(default_paper_content)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Step 1: Upload PDF
            with gr.Group():
                gr.Markdown("### Step 1: Upload your academic paper")
                pdf_file = gr.File(
                    label="Upload PDF",
                    file_types=[".pdf"],
                    type="binary"  # Handler receives the raw file bytes directly
                )
                # Extracted PDF text. This textbox is only wired as an output
                # below; generation reads extracted_pdf_text, not this box.
                with gr.Accordion("Extracted PDF Content (by Nougat)", open=False):
                    pdf_content_display = gr.Textbox(
                        label="Paper Content",
                        lines=10,
                        value=default_paper_content
                    )
            # Model selection (keys of AVAILABLE_MODELS)
            with gr.Group():
                gr.Markdown("### Select Model")
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value="XtraGPT-7B",  # Default selection
                    label="Select XtraGPT Model"
                )
            # Step 2: the passage the user wants rewritten (typed or pasted)
            with gr.Group():
                gr.Markdown("### Step 2: Enter the text section to improve")
                selected_content = gr.Textbox(
                    label="Text to improve",
                    placeholder="Paste the section of text you want to improve...",
                    lines=5,
                    value="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration."
                )
            # Step 3: Specify improvement requirements
            with gr.Group():
                gr.Markdown("### Step 3: Specify your improvement requirements")
                improvement_prompt = gr.Textbox(
                    label="Improvement requirements",
                    placeholder="e.g., 'Make this more concise', 'Add more technical details', 'Redefine this concept'...",
                    lines=3,
                    value="help me make it more concise."
                )
            # Sampling controls passed through to improve_paper_section;
            # temperature 0 selects greedy decoding there.
            with gr.Accordion("⚙️ Parameters", open=False):
                temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature")
                max_tokens = gr.Slider(minimum=128, maximum=1024, step=32, value=512, label="Max Tokens")
            submit_btn = gr.Button("Improve Text")
        with gr.Column():
            # Output
            output = gr.Textbox(label="Improved Text", lines=20)
            # The PDF content display lives in the left column (Step 1).

    # Automatically extract text when PDF is uploaded (or cleared).
    def update_pdf_content(pdf_bytes):
        """Return the paper text twice: once for the state, once for the textbox.

        NOTE(review): extract_text_from_pdf already returns
        default_paper_content for None, so this None branch is redundant.
        """
        if pdf_bytes is not None:
            content = extract_text_from_pdf(pdf_bytes)
            return content, content
        return default_paper_content, default_paper_content

    pdf_file.change(
        fn=update_pdf_content,
        inputs=[pdf_file],
        outputs=[extracted_pdf_text, pdf_content_display]
    )
    # Process text improvement; inputs map positionally onto
    # improve_paper_section(model_name, paper_content, selected_content,
    # improvement_prompt, temperature, max_new_tokens).
    submit_btn.click(
        fn=improve_paper_section,
        inputs=[model_dropdown, extracted_pdf_text, selected_content, improvement_prompt, temperature, max_tokens],
        outputs=[output]
    )
    gr.HTML(CITATION)
if __name__ == "__main__":
    # Start the Gradio server when run as a script.
    demo.launch()