# XtraGPT-7B / app.py
# Last commit by nuojohnchen: "Update app.py" (044d0d9, verified)
import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from io import BytesIO
from PIL import Image
import fitz # PyMuPDF
import numpy as np
from transformers import NougatProcessor, VisionEncoderDecoderModel
import nltk
import ssl
# Initialize NLTK: make sure the 'punkt' tokenizer data is present, downloading it if needed.
def _download_nltk_data(resource):
    """Download an NLTK resource with HTTPS certificate verification relaxed for this call only.

    The previous startup code replaced ``ssl._create_default_https_context`` with the
    unverified factory for the whole process (a common workaround for hosts without a
    usable CA bundle). ``http.client``/``urllib`` use that factory for every HTTPS
    connection, so verification stayed off for the lifetime of the app. Here the
    unverified factory is installed only around ``nltk.download`` and the previous
    factory is restored afterwards.

    Args:
        resource: NLTK data package id, e.g. ``'punkt'``.

    Returns:
        Whatever ``nltk.download`` returns.
    """
    unverified_factory = getattr(ssl, "_create_unverified_context", None)
    if unverified_factory is None:
        # Same fallback as before: without the private helper, leave SSL settings untouched.
        return nltk.download(resource)
    previous_factory = ssl._create_default_https_context
    ssl._create_default_https_context = unverified_factory
    try:
        return nltk.download(resource)
    finally:
        # Restore verification for every other HTTPS client in the process.
        ssl._create_default_https_context = previous_factory


# Download the required NLTK data only if it is not already installed.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    _download_nltk_data('punkt')
# Hugging Face access token taken from the environment (None when unset).
# NOTE(review): not referenced anywhere else in the code shown here.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# HTML header rendered at the top of the Gradio page.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Academic Paper Improver</h1>
<p>This Space helps you improve a selected content of your academic paper using the <a href="https://huggingface.co/Xtra-Computing/XtraGPT-7B"><b>XtraGPT</b></a> model series, ensuring controllability on criteria following and in-context ability.</p>
<p>Upload your PDF paper, select a section of text you want to improve, and specify your requirements.</p>
</div>
'''

# BibTeX entry rendered at the bottom of the page. BibTeX separates authors with
# " and "; a comma-separated list is read as one malformed "Last, Jr, First" name
# ("too many commas in name"), so copied citations would break.
CITATION = """
<div style="font-family: monospace; white-space: pre; margin-top: 20px; line-height: 1.2;">
@misc{XtraGPT,
title = {XtraGPT},
url = {https://huggingface.co/Xtra-Computing/XtraGPT-7B},
author = {Nuo Chen and Andre Lin HuiKai and Junyi Hou and Zining Zhang and Qian Wang and Xidong Wang and Bingsheng He},
month = {March},
year = {2025}
}
</div>
"""

# Footer text. NOTE(review): not rendered by any component in the code shown here.
LICENSE = """
<p/>
---
Built with XtraGPT models
"""
# Extra CSS handed to gr.Blocks: centres <h1> headings and styles an element with
# id "duplicate-button" (no such element is created in the UI code shown here).
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""

# Placeholder paper text (the "Attention Is All You Need" abstract). It is shown
# before any PDF is uploaded and returned whenever PDF extraction fails.
default_paper_content = """
The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
"""

# Dropdown label -> Hugging Face repo id, in the order the dropdown lists them.
_MODEL_SIZES = ("1.5B", "3B", "7B", "14B")
AVAILABLE_MODELS = {
    f"XtraGPT-{size}": f"Xtra-Computing/XtraGPT-{size}" for size in _MODEL_SIZES
}

# Lazily populated caches: the currently loaded XtraGPT model/tokenizer (and its
# dropdown name) and the Nougat OCR model/processor.
current_model = current_tokenizer = current_model_name = None
nougat_model = nougat_processor = None
@spaces.GPU(duration=200)
def load_nougat_model():
    """Return the cached ``(processor, model)`` pair for ``facebook/nougat-base``.

    Both objects are created on the first call and kept in the module-level
    ``nougat_processor`` / ``nougat_model`` globals; the model is moved to CUDA
    when available, otherwise to CPU. Later calls return the cached pair.
    """
    global nougat_model, nougat_processor
    already_loaded = nougat_model is not None and nougat_processor is not None
    if not already_loaded:
        checkpoint = "facebook/nougat-base"
        nougat_processor = NougatProcessor.from_pretrained(checkpoint)
        nougat_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        nougat_model.to(target_device)
    return nougat_processor, nougat_model
# Upper bound on the extracted paper size, counted in whitespace-separated words
# (the log messages call them "tokens", but no tokenizer is involved here).
MAX_EXTRACTED_WORDS = 15000


def _truncate_to_words(text, max_words):
    """Return ``text`` cut after its first ``max_words`` whitespace-separated words.

    Unlike ``" ".join(text.split()[:max_words])`` this keeps the original line
    breaks and spacing of the retained part, so Nougat's Markdown structure
    (headings, paragraphs, tables) survives. Text with at most ``max_words``
    words is returned unchanged.
    """
    # With maxsplit, the last element is the untouched remainder of the text that
    # starts at word ``max_words + 1``; its length tells us where to cut.
    parts = text.split(maxsplit=max_words)
    if len(parts) <= max_words:
        return text
    cut = len(text) - len(parts[-1])
    return text[:cut].rstrip()


def _ocr_page(page, processor, model):
    """Render one PyMuPDF page at 2x zoom and transcribe it to Markdown with Nougat."""
    # 2x zoom for better OCR quality; alpha=False (the PyMuPDF default) keeps the
    # samples packed as 3-channel RGB, matching the "RGB" mode below.
    pix = page.get_pixmap(matrix=fitz.Matrix(2, 2), alpha=False)
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    pixel_values = processor(img, return_tensors="pt").pixel_values.to(model.device)
    outputs = model.generate(
        pixel_values,
        min_length=1,
        max_new_tokens=1024,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    page_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    page_text = processor.post_process_generation(page_text, fix_markdown=True)
    # Drop the per-page tensors before releasing cached GPU memory.
    del pixel_values, outputs
    torch.cuda.empty_cache()
    return page_text


@spaces.GPU(duration=200)
def extract_text_from_pdf(pdf_bytes, max_words=MAX_EXTRACTED_WORDS):
    """Extract the text of an uploaded PDF page by page with the Nougat OCR model.

    Args:
        pdf_bytes: Raw PDF file contents, or None when nothing is uploaded.
        max_words: Stop OCR once more than this many whitespace-separated words
            have been collected, then cut the result down to exactly this many.

    Returns:
        The Markdown-ish text produced by Nougat (pages separated by blank
        lines). Returns ``default_paper_content`` when ``pdf_bytes`` is None or
        when extraction raises; the error is printed with its traceback.
    """
    if pdf_bytes is None:
        return default_paper_content
    try:
        # nltk itself is imported at module level; only its 'punkt' data can be
        # missing (e.g. if the startup download failed), so fetch it if needed.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')

        processor, model = load_nougat_model()

        page_texts = []
        word_count = 0  # running total, avoids re-splitting the whole text per page
        # The context manager closes the document even if OCR fails mid-way.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            page_count = len(doc)
            for page_num, page in enumerate(doc):
                page_text = _ocr_page(page, processor, model)
                page_texts.append(page_text + "\n\n")
                word_count += len(page_text.split())
                print(f"Processed page {page_num+1}/{page_count}")
                # Skip OCR of the remaining pages once the word budget is exceeded.
                if word_count > max_words:
                    print(f"Reached {max_words} token limit, stopping extraction")
                    break

        full_text = "".join(page_texts)
        if word_count > max_words:
            full_text = _truncate_to_words(full_text, max_words)
            print(f"Truncated paper content to {max_words} tokens")
        return full_text
    except Exception as e:
        # Best-effort: fall back to the placeholder paper rather than failing the UI.
        import traceback
        error_details = traceback.format_exc()
        print(f"PDF extraction error: {str(e)}\n{error_details}")
        return default_paper_content
    finally:
        # Clear GPU memory
        torch.cuda.empty_cache()
def load_model(model_name):
    """Return ``(tokenizer, model)`` for ``model_name``, loading it on first use.

    Only one XtraGPT model is kept in memory: requesting a different model
    releases the currently loaded one before loading the new one.

    Args:
        model_name: A key of ``AVAILABLE_MODELS`` (e.g. ``"XtraGPT-7B"``).

    Returns:
        Tuple ``(tokenizer, model)``.

    Raises:
        KeyError: if ``model_name`` is not in ``AVAILABLE_MODELS``. This is
            checked before the currently loaded model is evicted.
    """
    global current_model, current_tokenizer, current_model_name
    # If the requested model is already loaded, return it
    if current_model_name == model_name and current_model is not None and current_tokenizer is not None:
        return current_tokenizer, current_model
    # Resolve the repo id first so an unknown name does not evict the loaded model.
    model_path = AVAILABLE_MODELS[model_name]
    if current_model is not None:
        # Rebind to None rather than `del`: deleting the globals would leave them
        # undefined (NameError on the next call) if loading the new model fails.
        current_model = None
        current_tokenizer = None
        current_model_name = None
        import gc
        gc.collect()  # drop any lingering reference cycles so the weights are really freed
        torch.cuda.empty_cache()
    current_tokenizer = AutoTokenizer.from_pretrained(model_path)
    # torch_dtype="auto" keeps the checkpoint's stored dtype instead of upcasting to
    # float32 (the from_pretrained default), which would double GPU memory use.
    current_model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype="auto"
    )
    current_model_name = model_name
    return current_tokenizer, current_model
# Maximum sequence length (prompt + generated tokens) assumed for the XtraGPT models,
# taken from the earlier "model's maximum context length" note.
# NOTE(review): confirm this value for every checkpoint in AVAILABLE_MODELS.
MAX_CONTEXT_TOKENS = 16384


def _build_improvement_prompt(paper_content, selected_content, improvement_prompt):
    """Return the user-turn instruction wrapping paper, selection and request in XML-style tags."""
    return f"""
Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
<PAPER_CONTENT>
{paper_content}
</PAPER_CONTENT>
<SELECTED_CONTENT>
{selected_content}
</SELECTED_CONTENT>
<QUESTION>
{improvement_prompt}
</QUESTION>
"""


def _encode_prompt_within_budget(tokenizer, paper_content, selected_content, improvement_prompt, max_input_tokens):
    """Tokenize the chat prompt, trimming the tail of the paper context until it fits.

    Only the paper context is shortened: the selection, the request and the
    chat template's generation prompt at the end are always kept intact. If
    those alone exceed ``max_input_tokens``, the returned list is still longer
    than the budget and the caller must handle it.

    Returns:
        List of prompt token ids.
    """
    def encode(paper):
        messages = [
            {"role": "user", "content": _build_improvement_prompt(paper, selected_content, improvement_prompt)}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return tokenizer.encode(text)

    input_tokens = encode(paper_content)
    if len(input_tokens) <= max_input_tokens:
        return input_tokens

    paper_tokens = tokenizer.encode(paper_content, add_special_tokens=False)
    keep = len(paper_tokens)
    # Re-tokenizing the cut paper can shift the count by a few tokens at the cut
    # point, so iterate; every pass shrinks `keep` by at least one token.
    while len(input_tokens) > max_input_tokens and keep > 0:
        keep = max(0, keep - (len(input_tokens) - max_input_tokens))
        input_tokens = encode(tokenizer.decode(paper_tokens[:keep]))
    print(f"Paper content truncated to {keep} tokens so the prompt fits in {max_input_tokens} tokens")
    return input_tokens


@spaces.GPU(duration=200)
def improve_paper_section(model_name, paper_content, selected_content, improvement_prompt, temperature=0.1, max_new_tokens=512, progress=gr.Progress()):
    """Rewrite ``selected_content`` according to ``improvement_prompt`` (non-streaming).

    Args:
        model_name: Key of ``AVAILABLE_MODELS`` selecting the XtraGPT checkpoint.
        paper_content: Full paper text, given to the model as context.
        selected_content: The passage to improve.
        improvement_prompt: The user's instruction for the rewrite.
        temperature: Sampling temperature; 0 means greedy decoding.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        progress: Gradio progress tracker.

    Returns:
        The generated text, or a user-facing message string when inputs are
        missing/too long or generation raises (exceptions are caught, printed
        with a traceback, and turned into a message rather than propagated).
    """
    # Check inputs (whitespace-only text counts as missing)
    if not (selected_content and selected_content.strip()) or not (improvement_prompt and improvement_prompt.strip()):
        return "Please provide both text to improve and improvement requirements."
    try:
        progress(0.1, desc="Loading model...")
        tokenizer, model = load_model(model_name)
        progress(0.3, desc="Processing input...")

        # Gradio sliders may deliver floats; generate() expects an int.
        max_new_tokens = int(max_new_tokens)
        # Leave room in the context window for the tokens we are about to generate.
        max_input_tokens = MAX_CONTEXT_TOKENS - max_new_tokens
        input_tokens = _encode_prompt_within_budget(
            tokenizer, paper_content or "", selected_content, improvement_prompt, max_input_tokens
        )
        if len(input_tokens) > max_input_tokens:
            return "The selected text and improvement requirements are too long to fit in the model's context window. Please shorten them."

        progress(0.5, desc="Generating improved text...")
        input_ids = torch.tensor([input_tokens], dtype=torch.long, device=model.device)
        attention_mask = torch.ones_like(input_ids)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=(temperature > 0),
                # temperature must be positive even when sampling is disabled
                temperature=temperature if temperature > 0 else 1.0,
                pad_token_id=tokenizer.eos_token_id
            )
        # Only keep the newly generated part
        generated_ids = output_ids[0, input_ids.shape[1]:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        progress(1.0, desc="Complete!")
        return response
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Generation error: {str(e)}\n{error_details}")
        return f"Error generating text: {str(e)}\n\nPlease try with different parameters or input."
# Build the Gradio UI: inputs (PDF upload, model choice, passage, improvement
# request, sampling parameters) in the left column, generated text in the right.
with gr.Blocks(fill_height=True, css=css) as demo:
    # Paper text passed to improve_paper_section as context. It starts as the
    # placeholder abstract and is overwritten by update_pdf_content on upload.
    extracted_pdf_text = gr.State(default_paper_content)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Step 1: Upload PDF
            with gr.Group():
                gr.Markdown("### Step 1: Upload your academic paper")
                pdf_file = gr.File(
                    label="Upload PDF",
                    file_types=[".pdf"],
                    type="binary"  # Get binary data directly
                )
            # Shows the text extracted from the PDF. This textbox is only an event
            # output; generation reads `extracted_pdf_text`, not this box.
            with gr.Accordion("Extracted PDF Content (by Nougat)", open=False):
                pdf_content_display = gr.Textbox(
                    label="Paper Content",
                    lines=10,
                    value=default_paper_content
                )
            # Model selection: choices are the keys of AVAILABLE_MODELS
            with gr.Group():
                gr.Markdown("### Select Model")
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value="XtraGPT-7B",  # Default selection
                    label="Select XtraGPT Model"
                )
            # Step 2: the passage to rewrite (pasted by the user)
            with gr.Group():
                gr.Markdown("### Step 2: Enter the text section to improve")
                selected_content = gr.Textbox(
                    label="Text to improve",
                    placeholder="Paste the section of text you want to improve...",
                    lines=5,
                    value="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration."
                )
            # Step 3: Specify improvement requirements
            with gr.Group():
                gr.Markdown("### Step 3: Specify your improvement requirements")
                improvement_prompt = gr.Textbox(
                    label="Improvement requirements",
                    placeholder="e.g., 'Make this more concise', 'Add more technical details', 'Redefine this concept'...",
                    lines=3,
                    value="help me make it more concise."
                )
            # Generation settings forwarded to improve_paper_section
            with gr.Accordion("⚙️ Parameters", open=False):
                temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature")
                max_tokens = gr.Slider(minimum=128, maximum=1024, step=32, value=512, label="Max Tokens")
            submit_btn = gr.Button("Improve Text")
        with gr.Column():
            # Output
            output = gr.Textbox(label="Improved Text", lines=20)

    def update_pdf_content(pdf_bytes):
        """Run Nougat extraction for a newly uploaded PDF.

        Returns the same text twice, once for the `extracted_pdf_text` state and
        once for the display textbox. When there is no file (None), both are
        reset to the placeholder paper text.
        """
        if pdf_bytes is not None:
            content = extract_text_from_pdf(pdf_bytes)
            return content, content
        return default_paper_content, default_paper_content

    # Re-extract whenever the uploaded file changes.
    pdf_file.change(
        fn=update_pdf_content,
        inputs=[pdf_file],
        outputs=[extracted_pdf_text, pdf_content_display]
    )
    # Generate the improved passage; the paper context comes from the state, not the display box.
    submit_btn.click(
        fn=improve_paper_section,
        inputs=[model_dropdown, extracted_pdf_text, selected_content, improvement_prompt, temperature, max_tokens],
        outputs=[output]
    )
    # Citation block at the bottom of the page
    gr.HTML(CITATION)

# Start the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()