import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from io import BytesIO
from PIL import Image
import fitz  # PyMuPDF
import numpy as np
from transformers import NougatProcessor, VisionEncoderDecoderModel
import nltk
import ssl

# Initialize NLTK: allow an unverified SSL context so resource downloads succeed
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# Download the NLTK data required for tokenization
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
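# Note: recent NLTK releases may also require the separate 'punkt_tab' resource;
# if sentence tokenization fails at runtime, try nltk.download('punkt_tab') as
# well (this depends on the installed NLTK version).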

# Set environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)

DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Academic Paper Improver</h1>
<p>This Space helps you improve a selected passage of your academic paper using the <a href="https://huggingface.co/Xtra-Computing/XtraGPT-7B"><b>XtraGPT</b></a> model series, which is built to follow your stated revision criteria while staying grounded in the paper's context.</p>
<p>Upload your PDF paper, select a section of text you want to improve, and specify your requirements.</p>
</div>
'''

CITATION = """
<div style="font-family: monospace; white-space: pre; margin-top: 20px; line-height: 1.2;">
@misc{XtraGPT,
    title = {XtraGPT},
    url = {https://huggingface.co/Xtra-Computing/XtraGPT-7B},
    author = {Nuo Chen and Andre Lin HuiKai and Junyi Hou and Zining Zhang and Qian Wang and Xidong Wang and Bingsheng He},
    month = {March},
    year = {2025}
}
</div>
"""

LICENSE = """
<p/>
---
Built with XtraGPT models
"""

css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

# Default paper content
default_paper_content = """
The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
"""

# Available models
AVAILABLE_MODELS = {
    "XtraGPT-1.5B": "Xtra-Computing/XtraGPT-1.5B",
    "XtraGPT-3B": "Xtra-Computing/XtraGPT-3B",
    "XtraGPT-7B": "Xtra-Computing/XtraGPT-7B",
    "XtraGPT-14B": "Xtra-Computing/XtraGPT-14B"
}

# Global variables for model and tokenizer
current_model = None
current_tokenizer = None
current_model_name = None
nougat_model = None
nougat_processor = None
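
# Note: @spaces.GPU is the Hugging Face ZeroGPU decorator used below; `duration`
# gives the expected upper bound (in seconds) of GPU time per call. Outside a
# ZeroGPU Space it is effectively a no-op, though the exact behavior depends on
# the installed `spaces` package version.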

@spaces.GPU(duration=200)
def load_nougat_model():
    """Load Nougat model for PDF processing"""
    global nougat_model, nougat_processor
    
    if nougat_model is None or nougat_processor is None:
        nougat_processor = NougatProcessor.from_pretrained("facebook/nougat-base")
        nougat_model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
        nougat_model.to("cuda" if torch.cuda.is_available() else "cpu")
    
    return nougat_processor, nougat_model

@spaces.GPU(duration=200)
def extract_text_from_pdf(pdf_bytes):
    """Extract text from uploaded PDF file using Nougat"""
    if pdf_bytes is None:
        return default_paper_content
    
    try:
        # Ensure NLTK is available (install it on the fly if missing)
        try:
            import nltk
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt')
        except ImportError:
            print("Installing NLTK...")
            import subprocess
            import sys
            # Use the current interpreter's pip so the install targets the right environment
            subprocess.check_call([sys.executable, "-m", "pip", "install", "nltk", "python-Levenshtein"])
            import nltk
            nltk.download('punkt')
        
        # Load Nougat model
        processor, model = load_nougat_model()
        
        # Convert PDF to images using PyMuPDF
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        full_text = ""
        
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom for better quality
            
            # Convert to PIL Image
            img_data = pix.samples
            img = Image.frombytes("RGB", [pix.width, pix.height], img_data)
            
            # Process with Nougat
            pixel_values = processor(img, return_tensors="pt").pixel_values.to(model.device)
            
            # Generate text
            outputs = model.generate(
                pixel_values,
                min_length=1,
                max_new_tokens=1024,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
            )
            
            # Decode and post-process
            page_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
            page_text = processor.post_process_generation(page_text, fix_markdown=True)
            
            full_text += page_text + "\n\n"
            
            # Print progress
            print(f"Processed page {page_num+1}/{len(doc)}")
            
            # Stop early once the text exceeds roughly 15,000 tokens (approximated by whitespace-split word count)
            if len(full_text.split()) > 15000:
                print("Reached 15000 token limit, stopping extraction")
                break
        
        # Clear GPU memory (guard against an empty PDF, where these names were never bound)
        if len(doc) > 0:
            del pixel_values, outputs
        torch.cuda.empty_cache()
        
        # Cap the extracted text at roughly 15,000 tokens (word-count approximation)
        words = full_text.split()
        if len(words) > 15000:
            full_text = " ".join(words[:15000])
            print("Truncated paper content to 15000 tokens")
        
        return full_text
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"PDF extraction error: {str(e)}\n{error_details}")
        return default_paper_content
    finally:
        # Clear GPU memory
        torch.cuda.empty_cache()
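# Minimal standalone usage sketch (outside the Gradio app; "paper.pdf" is a
# hypothetical local path):
#     with open("paper.pdf", "rb") as f:
#         text = extract_text_from_pdf(f.read())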

def load_model(model_name):
    """Load model and tokenizer on demand"""
    global current_model, current_tokenizer, current_model_name
    
    # If the requested model is already loaded, return it
    if current_model_name == model_name and current_model is not None and current_tokenizer is not None:
        return current_tokenizer, current_model
    
    # Clear GPU memory if a model is already loaded
    if current_model is not None:
        del current_model
        del current_tokenizer
        torch.cuda.empty_cache()
    
    # Load the requested model
    model_path = AVAILABLE_MODELS[model_name]
    current_tokenizer = AutoTokenizer.from_pretrained(model_path)
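    # device_map="auto" asks Hugging Face Accelerate to place the model weights
    # on the available device(s) automatically (requires the `accelerate` package).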
    current_model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    current_model_name = model_name
    
    return current_tokenizer, current_model

@spaces.GPU(duration=200)
def improve_paper_section(model_name, paper_content, selected_content, improvement_prompt, temperature=0.1, max_new_tokens=512, progress=gr.Progress()):
    """
    Improve a section of an academic paper - non-streaming generation
    """
    # Check inputs
    if not selected_content or not improvement_prompt:
        return "Please provide both text to improve and improvement requirements."
    
    try:
        progress(0.1, desc="Loading model...")
        # Load the selected model
        tokenizer, model = load_model(model_name)
        
        progress(0.3, desc="Processing input...")
        # Build prompt
        content = f"""
Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
<PAPER_CONTENT>
{paper_content}
</PAPER_CONTENT>
<SELECTED_CONTENT>
{selected_content}
</SELECTED_CONTENT>
<QUESTION>
{improvement_prompt}
</QUESTION>
"""
        
        # Prepare input
        messages = [
            {"role": "user", "content": content}
        ]
        
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Check input length and truncate to 16384 tokens before encoding
        input_tokens = tokenizer.encode(text)
        if len(input_tokens) > 16384:  # the model's maximum context length
            input_tokens = input_tokens[:16384]
            text = tokenizer.decode(input_tokens)
            print("Input truncated to 16384 tokens")
        
        progress(0.5, desc="Generating improved text...")
        # Generate non-streaming
        input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
        
        # Create attention mask
        attention_mask = torch.ones_like(input_ids)
        
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=(temperature > 0),
                temperature=temperature if temperature > 0 else 1.0,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Only keep the newly generated part
        generated_ids = output_ids[0, len(input_ids[0]):]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        progress(1.0, desc="Complete!")
        return response
        
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Generation error: {str(e)}\n{error_details}")
        return f"Error generating text: {str(e)}\n\nPlease try with different parameters or input."

# Create Gradio interface
with gr.Blocks(fill_height=True, css=css) as demo:
    # Store extracted PDF text
    extracted_pdf_text = gr.State(default_paper_content)
    
    gr.Markdown(DESCRIPTION)
    
    with gr.Row():
        with gr.Column():
            # Step 1: Upload PDF
            with gr.Group():
                gr.Markdown("### Step 1: Upload your academic paper")
                pdf_file = gr.File(
                    label="Upload PDF",
                    file_types=[".pdf"],
                    type="binary"  # Get binary data directly
                )
            
            # Display extracted PDF text
            with gr.Accordion("Extracted PDF Content (by Nougat)", open=False):
                pdf_content_display = gr.Textbox(
                    label="Paper Content", 
                    lines=10,
                    value=default_paper_content
                )
            
            # Model selection
            with gr.Group():
                gr.Markdown("### Select Model")
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value="XtraGPT-7B",  # Default selection
                    label="Select XtraGPT Model"
                )
            
            # Step 2: Extract and select text
            with gr.Group():
                gr.Markdown("### Step 2: Enter the text section to improve")
                selected_content = gr.Textbox(
                    label="Text to improve",
                    placeholder="Paste the section of text you want to improve...",
                    lines=5,
                    value="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration."
                )
            
            # Step 3: Specify improvement requirements
            with gr.Group():
                gr.Markdown("### Step 3: Specify your improvement requirements")
                improvement_prompt = gr.Textbox(
                    label="Improvement requirements",
                    placeholder="e.g., 'Make this more concise', 'Add more technical details', 'Redefine this concept'...",
                    lines=3,
                    value="help me make it more concise."
                )
            
            with gr.Accordion("⚙️ Parameters", open=False):
                temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature")
                max_tokens = gr.Slider(minimum=128, maximum=1024, step=32, value=512, label="Max Tokens")
            
            submit_btn = gr.Button("Improve Text")
        
        with gr.Column():
            # Output
            output = gr.Textbox(label="Improved Text", lines=20)
            
    
    # Automatically extract text when PDF is uploaded
    def update_pdf_content(pdf_bytes):
        if pdf_bytes is not None:
            content = extract_text_from_pdf(pdf_bytes)
            return content, content
        return default_paper_content, default_paper_content
    
    pdf_file.change(
        fn=update_pdf_content,
        inputs=[pdf_file],
        outputs=[extracted_pdf_text, pdf_content_display]
    )
    
    # Process text improvement
    submit_btn.click(
        fn=improve_paper_section,
        inputs=[model_dropdown, extracted_pdf_text, selected_content, improvement_prompt, temperature, max_tokens],
        outputs=[output]
    )
    
    gr.HTML(CITATION)

if __name__ == "__main__":
    demo.launch()