nuojohnchen commited on
Commit
60f72e5
·
verified ·
1 Parent(s): 3cc5888

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -47
app.py CHANGED
@@ -1,64 +1,245 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
 
 
 
6
  """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- response = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
41
 
 
 
 
42
 
 
 
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
  import gradio as gr
2
+ import os
3
+ import spaces
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import PyPDF2
6
+ from io import BytesIO
7
+ import torch
8
 
9
+ # 设置环境变量
10
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+
12
+ DESCRIPTION = '''
13
+ <div>
14
+ <h1 style="text-align: center;">Academic Paper Improver</h1>
15
+ <p>This Space helps you improve sections of your academic paper using the <a href="https://huggingface.co/Xtra-Computing/XtraGPT-7B"><b>XtraGPT-7B</b></a> model.</p>
16
+ <p>Upload your PDF paper, select a section of text you want to improve, and specify your requirements.</p>
17
+ </div>
18
+ '''
19
+
20
+ CITATION = """
21
+ <div style="font-family: monospace; white-space: pre; margin-top: 20px; line-height: 1.2;">
22
+ @misc{XtraGPT,
23
+ title = {XtraGPT},
24
+ url = {https://huggingface.co/Xtra-Computing/XtraGPT-7B},
25
+ author = {Nuo Chen, Andre Lin HuiKai, Junyi Hou, Zining Zhang, Qian Wang, Xidong Wang, Bingsheng He},
26
+ month = {March},
27
+ year = {2025}
28
+ }
29
+ </div>
30
  """
31
+
32
+ LICENSE = """
33
+ <p/>
34
+ ---
35
+ Built with XtraGPT-7B
36
  """
 
37
 
38
+ css = """
39
+ h1 {
40
+ text-align: center;
41
+ display: block;
42
+ }
43
+ #duplicate-button {
44
+ margin: auto;
45
+ color: white;
46
+ background: #1565c0;
47
+ border-radius: 100vh;
48
+ }
49
+ """
50
 
51
+ # 默认论文内容
52
+ default_paper_content = """
53
+ The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
54
+ """
 
 
 
 
 
55
 
56
+ # 直接加载模型和分词器
57
+ tokenizer = AutoTokenizer.from_pretrained("Xtra-Computing/XtraGPT-7B")
58
+ model = AutoModelForCausalLM.from_pretrained("Xtra-Computing/XtraGPT-7B", device_map="auto")
 
 
59
 
60
+ def extract_text_from_pdf(pdf_bytes):
61
+ """从上传的PDF文件中提取文本"""
62
+ if pdf_bytes is None:
63
+ return default_paper_content
64
+
65
+ try:
66
+ # 确保pdf_bytes是字节类型
67
+ if isinstance(pdf_bytes, str):
68
+ return pdf_bytes # 如果已经是字符串,直接返回
69
+
70
+ # 直接使用字节对象
71
+ pdf_reader = PyPDF2.PdfReader(BytesIO(pdf_bytes))
72
+
73
+ # 从所有页面提取文本
74
+ text = ""
75
+ for page_num in range(len(pdf_reader.pages)):
76
+ page = pdf_reader.pages[page_num]
77
+ text += page.extract_text() + "\n\n"
78
+
79
+ # 限制文本长度,防止超出模型最大长度
80
+ if len(text) > 10000: # 保守估计,留出足够空间给提示和生成
81
+ text = text[:10000] + "...(文本已截断)"
82
+
83
+ return text
84
+ except Exception as e:
85
+ print(f"PDF提取错误: {str(e)}")
86
+ return default_paper_content
87
 
88
+ @spaces.GPU(duration=120)
89
+ def improve_paper_section(paper_content, selected_content, improvement_prompt, temperature=0.1, max_new_tokens=512):
90
+ """
91
+ 改进学术论文的一个部分 - 使用非流式生成
92
+ """
93
+ # 检查输入
94
+ if not selected_content or not improvement_prompt:
95
+ return "请同时提供要改进的文本和改进要求。"
96
+
97
+ try:
98
+ # 限制paper_content长度,防止超出模型最大长度
99
+ if len(paper_content) > 10000: # 保守估计
100
+ paper_content = paper_content[:10000] + "...(文本已截断)"
101
+
102
+ # 构建提示
103
+ content = f"""
104
+ Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
105
 
106
+ The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
107
+ Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
 
 
 
 
 
 
108
 
109
+ <PAPER_CONTENT>
110
+ {paper_content}
111
+ </PAPER_CONTENT>
112
 
113
+ <SELECTED_CONTENT>
114
+ {selected_content}
115
+ </SELECTED_CONTENT>
116
 
117
+ <QUESTION>
118
+ {improvement_prompt}
119
+ </QUESTION>
120
  """
121
+
122
+ # 准备输入
123
+ messages = [
124
+ {"role": "user", "content": content}
125
+ ]
126
+
127
+ text = tokenizer.apply_chat_template(
128
+ messages,
129
+ tokenize=False,
130
+ add_generation_prompt=True
131
+ )
132
+
133
+ # 检查输入长度并截断
134
+ input_tokens = tokenizer.encode(text)
135
+ if len(input_tokens) > 15000: # 为生成留出空间
136
+ input_tokens = input_tokens[:15000]
137
+ text = tokenizer.decode(input_tokens)
138
+ print(f"输入已截断至15000个token")
139
+
140
+ # 使用非流式方式生成
141
+ input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
142
+
143
+ with torch.no_grad():
144
+ output_ids = model.generate(
145
+ input_ids,
146
+ max_new_tokens=max_new_tokens,
147
+ do_sample=(temperature > 0),
148
+ temperature=temperature if temperature > 0 else 1.0,
149
+ pad_token_id=tokenizer.eos_token_id
150
+ )
151
+
152
+ # 只保留新生成的部分
153
+ generated_ids = output_ids[0, len(input_ids[0]):]
154
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True)
155
+
156
+ return response
157
+
158
+ except Exception as e:
159
+ import traceback
160
+ error_details = traceback.format_exc()
161
+ print(f"生成错误: {str(e)}\n{error_details}")
162
+ return f"生成文本时出错: {str(e)}\n\n请尝试使用不同的参数或输入。"
163
 
164
+ # 创建Gradio界面
165
+ with gr.Blocks(fill_height=True, css=css) as demo:
166
+ # 存储提取的PDF文本
167
+ extracted_pdf_text = gr.State(default_paper_content)
168
+
169
+ gr.Markdown(DESCRIPTION)
170
+ # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
171
+
172
+ with gr.Row():
173
+ with gr.Column():
174
+ # 步骤1:上传PDF
175
+ with gr.Group():
176
+ gr.Markdown("### Step 1: Upload your academic paper")
177
+ pdf_file = gr.File(
178
+ label="Upload PDF",
179
+ file_types=[".pdf"],
180
+ type="binary" # 直接获取二进制数据
181
+ )
182
+
183
+ # 步骤2:提取并选择文本
184
+ with gr.Group():
185
+ gr.Markdown("### Step 2: Enter the text section to improve")
186
+ selected_content = gr.Textbox(
187
+ label="Text to improve",
188
+ placeholder="Paste the section of text you want to improve...",
189
+ lines=5,
190
+ value="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration."
191
+ )
192
+
193
+ # 步骤3:指定改进要求
194
+ with gr.Group():
195
+ gr.Markdown("### Step 3: Specify your improvement requirements")
196
+ improvement_prompt = gr.Textbox(
197
+ label="Improvement requirements",
198
+ placeholder="e.g., 'Make this more concise', 'Add more technical details', 'Redefine this concept'...",
199
+ lines=3,
200
+ value="help me make it more concise."
201
+ )
202
+
203
+ with gr.Accordion("⚙️ Parameters", open=False):
204
+ temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature")
205
+ max_tokens = gr.Slider(minimum=128, maximum=1024, step=32, value=512, label="Max Tokens")
206
+
207
+ submit_btn = gr.Button("Improve Text")
208
+
209
+ with gr.Column():
210
+ # 输出
211
+ output = gr.Textbox(label="Improved Text", lines=20)
212
+
213
+ # 显示提取的PDF文本(可折叠)
214
+ with gr.Accordion("Extracted PDF Content (for reference)", open=False):
215
+ pdf_content_display = gr.Textbox(
216
+ label="Paper Content",
217
+ lines=10,
218
+ value=default_paper_content
219
+ )
220
+
221
+ # 当PDF上传时自动提取文本
222
+ def update_pdf_content(pdf_bytes):
223
+ if pdf_bytes is not None:
224
+ content = extract_text_from_pdf(pdf_bytes)
225
+ return content, content
226
+ return default_paper_content, default_paper_content
227
+
228
+ pdf_file.change(
229
+ fn=update_pdf_content,
230
+ inputs=[pdf_file],
231
+ outputs=[extracted_pdf_text, pdf_content_display]
232
+ )
233
+
234
+ # 处理文本改进
235
+ submit_btn.click(
236
+ fn=improve_paper_section,
237
+ inputs=[extracted_pdf_text, selected_content, improvement_prompt, temperature, max_tokens],
238
+ outputs=[output]
239
+ )
240
+
241
+ # gr.Markdown(LICENSE)
242
+ gr.Markdown(CITATION)
243
 
244
  if __name__ == "__main__":
245
+ demo.launch()