import gc
import os
import ssl
import traceback
from io import BytesIO

import fitz  # PyMuPDF
import gradio as gr
import nltk
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    NougatProcessor,
    VisionEncoderDecoderModel,
)


def _ensure_nltk_punkt():
    """Make sure the NLTK ``punkt`` tokenizer data is available locally.

    A normal (certificate-verified) download is attempted first. Only if
    that fails is the download retried with HTTPS verification disabled,
    and the process-wide default SSL context factory is restored
    afterwards so no other HTTPS traffic is affected.
    """
    try:
        nltk.data.find('tokenizers/punkt')
        return
    except LookupError:
        pass

    # nltk.download() reports failure through its boolean return value.
    if nltk.download('punkt'):
        return

    unverified_factory = getattr(ssl, "_create_unverified_context", None)
    if unverified_factory is None:
        return

    # SECURITY: verification is disabled for this single retry only.
    verified_factory = ssl._create_default_https_context
    ssl._create_default_https_context = unverified_factory
    try:
        nltk.download('punkt')
    finally:
        ssl._create_default_https_context = verified_factory


# Download the NLTK data the app needs at import time.
_ensure_nltk_punkt()

# Set environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)

DESCRIPTION = '''

Academic Paper Improver

This Space helps you improve a selected content of your academic paper using the XtraGPT model series, ensuring controllability on criteria following and in-context ability.

Upload your PDF paper, select a section of text you want to improve, and specify your requirements.

'''

CITATION = """
@misc{XtraGPT, title = {XtraGPT}, url = {https://huggingface.co/Xtra-Computing/XtraGPT-7B}, author = {Nuo Chen, Andre Lin HuiKai, Junyi Hou, Zining Zhang, Qian Wang, Xidong Wang, Bingsheng He}, month = {March}, year = {2025} }
"""

LICENSE = """

--- Built with XtraGPT models """

css = """ h1 { text-align: center; display: block; } #duplicate-button { margin: auto; color: white; background: #1565c0; border-radius: 100vh; } """

# Default paper content
default_paper_content = """ The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data. 
"""

# Available models
AVAILABLE_MODELS = {
    "XtraGPT-1.5B": "Xtra-Computing/XtraGPT-1.5B",
    "XtraGPT-3B": "Xtra-Computing/XtraGPT-3B",
    "XtraGPT-7B": "Xtra-Computing/XtraGPT-7B",
    "XtraGPT-14B": "Xtra-Computing/XtraGPT-14B",
}

# Checkpoint used to transcribe PDF pages.
NOUGAT_MODEL_ID = "facebook/nougat-base"

# Upper bound on extracted paper length, in whitespace-separated words
# (the log messages call these "tokens").
MAX_PAPER_WORDS = 15000

# Context window assumed for the XtraGPT models; prompt and generated
# tokens together must fit inside it.
MAX_CONTEXT_TOKENS = 16384

# Slack kept when trimming the paper, because re-tokenizing a trimmed
# string can yield a few more tokens than the slice that produced it.
_TRUNCATION_MARGIN_TOKENS = 16

# Global variables for model and tokenizer
current_model = None
current_tokenizer = None
current_model_name = None
nougat_model = None
nougat_processor = None


@spaces.GPU(duration=200)
def load_nougat_model():
    """Load (once) and return the Nougat processor and model.

    The pair is cached in module globals. The model is moved to CUDA when
    available, otherwise to CPU; this is repeated on every call, which is
    a no-op once the model is already on the target device.

    Returns:
        Tuple ``(processor, model)``.
    """
    global nougat_model, nougat_processor

    if nougat_model is None or nougat_processor is None:
        nougat_processor = NougatProcessor.from_pretrained(NOUGAT_MODEL_ID)
        nougat_model = VisionEncoderDecoderModel.from_pretrained(NOUGAT_MODEL_ID)

    nougat_model.to("cuda" if torch.cuda.is_available() else "cpu")
    return nougat_processor, nougat_model


@spaces.GPU(duration=200)
def extract_text_from_pdf(pdf_bytes):
    """Extract text from an uploaded PDF by running Nougat on each page.

    Pages are rendered at 2x zoom and transcribed one at a time.
    Extraction stops once more than ``MAX_PAPER_WORDS`` words have been
    collected, and the result is cut to exactly that many words.

    Args:
        pdf_bytes: Raw bytes of the PDF, or ``None`` when nothing was
            uploaded.

    Returns:
        The extracted text. ``default_paper_content`` is returned when
        ``pdf_bytes`` is ``None`` or when extraction fails for any reason
        (the failure is printed with its traceback).
    """
    if pdf_bytes is None:
        return default_paper_content

    try:
        _ensure_nltk_punkt()

        # Load Nougat model
        processor, model = load_nougat_model()

        page_texts = []
        word_count = 0

        # The context manager closes the document even if a page fails.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            total_pages = len(doc)
            for page_number, page in enumerate(doc, start=1):
                # 2x zoom for better quality
                pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
                img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

                # Process with Nougat
                pixel_values = processor(img, return_tensors="pt").pixel_values.to(model.device)
                outputs = model.generate(
                    pixel_values,
                    min_length=1,
                    max_new_tokens=1024,
                    bad_words_ids=[[processor.tokenizer.unk_token_id]],
                )

                # Decode and post-process
                page_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
                page_text = processor.post_process_generation(page_text, fix_markdown=True)

                # Release per-page tensors before moving on, including on
                # the page that triggers the early exit below.
                del pixel_values, outputs
                torch.cuda.empty_cache()

                page_texts.append(page_text)
                # Running total avoids re-splitting the whole text per page.
                word_count += len(page_text.split())
                print(f"Processed page {page_number}/{total_pages}")

                if word_count > MAX_PAPER_WORDS:
                    print("Reached 15000 token limit, stopping extraction")
                    break

        full_text = "".join(text + "\n\n" for text in page_texts)

        # Enforce the hard cap on the final text.
        words = full_text.split()
        if len(words) > MAX_PAPER_WORDS:
            full_text = " ".join(words[:MAX_PAPER_WORDS])
            print(f"Truncated paper content to 15000 tokens")

        return full_text
    except Exception as e:
        # Top-level boundary for the UI callback: log and fall back.
        error_details = traceback.format_exc()
        print(f"PDF extraction error: {str(e)}\n{error_details}")
        return default_paper_content
    finally:
        # Clear GPU memory
        torch.cuda.empty_cache()


def load_model(model_name):
    """Load the requested XtraGPT model and tokenizer on demand.

    One model is cached at a time. Asking for the cached model returns it
    directly; asking for another one releases the cached model first.

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.

    Returns:
        Tuple ``(tokenizer, model)``.

    Raises:
        KeyError: If ``model_name`` is not in ``AVAILABLE_MODELS``.
    """
    global current_model, current_tokenizer, current_model_name

    # If the requested model is already loaded, return it
    if (
        current_model_name == model_name
        and current_model is not None
        and current_tokenizer is not None
    ):
        return current_tokenizer, current_model

    # Resolve the path first so an unknown name does not evict the
    # currently loaded model.
    model_path = AVAILABLE_MODELS[model_name]

    # Drop the previous model by rebinding to None rather than `del`:
    # deleting the global names would leave them undefined if the load
    # below fails, making every later call raise NameError.
    if current_model is not None or current_tokenizer is not None:
        current_model = None
        current_tokenizer = None
        current_model_name = None
        gc.collect()
        torch.cuda.empty_cache()

    # Publish the new pair only once both loads have succeeded.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    current_tokenizer = tokenizer
    current_model = model
    current_model_name = model_name

    return current_tokenizer, current_model


def _build_prompt(paper_content, selected_content, improvement_prompt):
    """Return the user message sent to the model for one request."""
    # NOTE(review): the section tags are taken from the names the
    # instruction text refers to; confirm against the XtraGPT model card.
    return f"""
Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
<PAPER_CONTENT>
{paper_content}
</PAPER_CONTENT>
<SELECTED_CONTENT>
{selected_content}
</SELECTED_CONTENT>
<QUESTION>
{improvement_prompt}
</QUESTION>
"""


def _render_chat_prompt(tokenizer, paper_content, selected_content, improvement_prompt):
    """Apply the tokenizer's chat template to a single-turn user request."""
    messages = [
        {
            "role": "user",
            "content": _build_prompt(paper_content, selected_content, improvement_prompt),
        }
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def _fit_prompt_to_context(tokenizer, paper_content, selected_content,
                           improvement_prompt, token_budget):
    """Render the prompt, trimming the paper so it fits ``token_budget``.

    Only the paper content is shortened. Cutting the rendered prompt
    itself would remove its tail, which holds the selected text, the
    question and the generation marker.
    """
    text = _render_chat_prompt(tokenizer, paper_content, selected_content, improvement_prompt)
    overflow = len(tokenizer.encode(text)) - token_budget
    if overflow <= 0:
        return text

    paper_ids = tokenizer.encode(paper_content, add_special_tokens=False)
    keep = max(len(paper_ids) - overflow - _TRUNCATION_MARGIN_TOKENS, 0)
    trimmed_paper = tokenizer.decode(paper_ids[:keep])
    print(f"Paper content truncated from {len(paper_ids)} to {keep} tokens to fit the context window")
    return _render_chat_prompt(tokenizer, trimmed_paper, selected_content, improvement_prompt)


@spaces.GPU(duration=200)
def improve_paper_section(model_name, paper_content, selected_content, improvement_prompt,
                          temperature=0.1, max_new_tokens=512, progress=gr.Progress()):
    """Improve a section of an academic paper (non-streaming generation).

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.
        paper_content: Full paper text used as context.
        selected_content: The passage to improve.
        improvement_prompt: What the user wants changed.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        progress: Gradio progress reporter.

    Returns:
        The generated text, or a human-readable message when an input is
        missing or generation fails.
    """
    # Check inputs (whitespace-only counts as missing)
    if not (selected_content or "").strip() or not (improvement_prompt or "").strip():
        return "Please provide both text to improve and improvement requirements."

    try:
        progress(0.1, desc="Loading model...")
        tokenizer, model = load_model(model_name)

        progress(0.3, desc="Processing input...")
        # UI sliders may deliver floats; generate() needs an int.
        max_new_tokens = int(max_new_tokens)
        # Reserve room for the generated tokens inside the context window.
        token_budget = max(MAX_CONTEXT_TOKENS - max_new_tokens, 1)
        text = _fit_prompt_to_context(
            tokenizer,
            paper_content or "",
            selected_content,
            improvement_prompt,
            token_budget,
        )

        progress(0.5, desc="Generating improved text...")
        input_ids = tokenizer.encode(text, return_tensors="pt")
        if input_ids.shape[1] > token_budget:
            # Last resort: the prompt is too long even with no paper text.
            input_ids = input_ids[:, :token_budget]
            print(f"Input truncated to {token_budget} tokens")
        input_ids = input_ids.to(model.device)
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=(temperature > 0),
                temperature=temperature if temperature > 0 else 1.0,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Only keep the newly generated part
        generated_ids = output_ids[0, input_ids.shape[1]:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)

        progress(1.0, desc="Complete!")
        return response
    except Exception as e:
        # Top-level boundary for the UI callback: log and report.
        error_details = traceback.format_exc()
        print(f"Generation error: {str(e)}\n{error_details}")
        return f"Error generating text: {str(e)}\n\nPlease try with different parameters or input."
# Build the Gradio UI: inputs in the left column, generated text on the right.
with gr.Blocks(fill_height=True, css=css) as demo:
    # Holds the paper text handed to the model; replaced on every upload.
    paper_state = gr.State(default_paper_content)

    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            # Step 1: upload, plus a collapsible view of the extracted text.
            with gr.Group():
                gr.Markdown("### Step 1: Upload your academic paper")
                pdf_upload = gr.File(
                    label="Upload PDF",
                    file_types=[".pdf"],
                    type="binary",  # handlers receive the raw bytes
                )

                with gr.Accordion("Extracted PDF Content (by Nougat)", open=False):
                    paper_preview = gr.Textbox(
                        label="Paper Content",
                        lines=10,
                        value=default_paper_content,
                    )

            # Which XtraGPT checkpoint to run.
            with gr.Group():
                gr.Markdown("### Select Model")
                model_choice = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS),
                    value="XtraGPT-7B",
                    label="Select XtraGPT Model",
                )

            # Step 2: the passage to rewrite.
            with gr.Group():
                gr.Markdown("### Step 2: Enter the text section to improve")
                passage_box = gr.Textbox(
                    label="Text to improve",
                    placeholder="Paste the section of text you want to improve...",
                    lines=5,
                    value="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.",
                )

            # Step 3: what should change.
            with gr.Group():
                gr.Markdown("### Step 3: Specify your improvement requirements")
                requirements_box = gr.Textbox(
                    label="Improvement requirements",
                    placeholder="e.g., 'Make this more concise', 'Add more technical details', 'Redefine this concept'...",
                    lines=3,
                    value="help me make it more concise.",
                )

            # Generation settings, collapsed by default.
            with gr.Accordion("⚙️ Parameters", open=False):
                temperature_slider = gr.Slider(
                    minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature"
                )
                max_tokens_slider = gr.Slider(
                    minimum=128, maximum=1024, step=32, value=512, label="Max Tokens"
                )

            improve_button = gr.Button("Improve Text")

        with gr.Column():
            result_box = gr.Textbox(label="Improved Text", lines=20)

    def update_pdf_content(pdf_bytes):
        """Return the paper text twice: once for the state, once for the preview."""
        text = (
            default_paper_content
            if pdf_bytes is None
            else extract_text_from_pdf(pdf_bytes)
        )
        return text, text

    # Re-extract whenever the uploaded file changes (or is cleared).
    pdf_upload.change(
        fn=update_pdf_content,
        inputs=[pdf_upload],
        outputs=[paper_state, paper_preview],
    )

    # Run the improvement with the current selections.
    improve_button.click(
        fn=improve_paper_section,
        inputs=[
            model_choice,
            paper_state,
            passage_box,
            requirements_box,
            temperature_slider,
            max_tokens_slider,
        ],
        outputs=[result_box],
    )

    gr.HTML(CITATION)

if __name__ == "__main__":
    demo.launch()