SonyaX20 committed on
Commit
27be9b9
·
1 Parent(s): 79fb563
Files changed (1) hide show
  1. app.py +67 -71
app.py CHANGED
@@ -2,6 +2,11 @@ import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  from datasets import load_dataset
4
  import torch
 
 
 
 
 
5
 
6
  # 加载预训练模型和分词器
7
  MODEL_NAME = "bert-base-chinese"
@@ -25,84 +30,75 @@ def extract_features(text):
25
  cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
26
  return cls_embedding
27
 
28
- # Gradio接口
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def predict(text):
 
30
  features = extract_features(text)
31
- return f"特征维度: {features.shape}\n特征向量(部分展示): {features[:10]}"
32
 
33
- # 定义界面
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  demo = gr.Interface(
35
  fn=predict,
36
  inputs=gr.Textbox(lines=2, placeholder="输入中文文本..."),
37
- outputs="text",
38
- title="中文特征提取",
39
- description="基于BERT的中文文本特征提取,使用tnews数据集进行微调。",
 
 
 
40
  )
41
 
42
  # 运行Gradio应用
43
  if __name__ == "__main__":
44
- demo.launch()
45
- # import gradio as gr
46
- # from huggingface_hub import InferenceClient
47
-
48
- # """
49
- # For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
50
- # """
51
- # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
52
-
53
-
54
- # def respond(
55
- # message,
56
- # history: list[tuple[str, str]],
57
- # system_message,
58
- # max_tokens,
59
- # temperature,
60
- # top_p,
61
- # ):
62
- # messages = [{"role": "system", "content": system_message}]
63
-
64
- # for val in history:
65
- # if val[0]:
66
- # messages.append({"role": "user", "content": val[0]})
67
- # if val[1]:
68
- # messages.append({"role": "assistant", "content": val[1]})
69
-
70
- # messages.append({"role": "user", "content": message})
71
-
72
- # response = ""
73
-
74
- # for message in client.chat_completion(
75
- # messages,
76
- # max_tokens=max_tokens,
77
- # stream=True,
78
- # temperature=temperature,
79
- # top_p=top_p,
80
- # ):
81
- # token = message.choices[0].delta.content
82
-
83
- # response += token
84
- # yield response
85
-
86
-
87
- # """
88
- # For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
89
- # """
90
- # demo = gr.ChatInterface(
91
- # respond,
92
- # additional_inputs=[
93
- # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
94
- # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
95
- # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
96
- # gr.Slider(
97
- # minimum=0.1,
98
- # maximum=1.0,
99
- # value=0.95,
100
- # step=0.05,
101
- # label="Top-p (nucleus sampling)",
102
- # ),
103
- # ],
104
- # )
105
-
106
-
107
- # if __name__ == "__main__":
108
- # demo.launch()
 
2
  from transformers import AutoModel, AutoTokenizer
3
  from datasets import load_dataset
4
  import torch
5
+ import numpy as np
6
+ from sklearn.manifold import TSNE
7
+ import matplotlib.pyplot as plt
8
+ import io
9
+ import base64
10
 
11
  # 加载预训练模型和分词器
12
  MODEL_NAME = "bert-base-chinese"
 
30
  cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
31
  return cls_embedding
32
 
33
# Cosine similarity computation
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two 1-D vectors.

    Args:
        vec1: First vector (array-like of numbers).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``. When either vector has zero
        norm the similarity is undefined; ``0.0`` is returned instead of
        dividing by zero and propagating NaN.
    """
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        # A zero vector has no direction; avoid a 0/0 division.
        return 0.0
    return np.dot(vec1, vec2) / denominator
36
+
37
# Predefined reference texts used for similarity comparison
predefined_texts = [
    "今天的天气很好,我想去散步。",
    "股票市场今天表现不错。",
    "人工智能正在改变我们的生活。",
]
# Embed every reference text once at module load so each request can reuse
# the vectors instead of re-running the model on them.
predefined_features = list(map(extract_features, predefined_texts))
44
+
45
# Dimensionality-reduction visualization
def plot_features(features):
    """Render a 2-D t-SNE scatter plot of the input and reference vectors.

    Args:
        features: Feature vector of the user-supplied text. It is plotted in
            red and labelled "Input"; the entries of ``predefined_features``
            are plotted in blue and labelled "Text 1", "Text 2", ...

    Returns:
        An HTML ``<img>`` tag whose ``src`` is the PNG plot encoded as a
        base64 data URI.
    """
    # Stack the input first so that row 0 is always the user's text.
    samples = np.asarray([features, *predefined_features])
    n_samples = len(samples)

    # scikit-learn requires perplexity < n_samples. The default (30) is
    # rejected for the handful of points plotted here, so cap it.
    perplexity = min(30.0, n_samples - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_features = tsne.fit_transform(samples)

    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure", and close it in `finally` so it is released even if plotting
    # or saving raises.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for i, (x, y) in enumerate(reduced_features):
            is_input = i == 0
            ax.scatter(
                x,
                y,
                c="red" if is_input else "blue",
                label="Input" if is_input else f"Text {i}",
            )
        ax.legend()
        ax.set_title("Feature Vector Visualization (t-SNE)")
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.grid(True)

        # Serialize the figure to an in-memory PNG.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    img_str = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_str}" />'
70
+
71
# Gradio callback
def predict(text):
    """Extract features for ``text`` and build the two Gradio outputs.

    Args:
        text: Input text taken from the Gradio textbox.

    Returns:
        A ``(report, plot_html)`` tuple: ``report`` is a text summary with the
        feature shape, the first ten feature values and the cosine similarity
        to every predefined text; ``plot_html`` is the value returned by
        ``plot_features`` for the same feature vector.
    """
    # Feature extraction
    features = extract_features(text)

    # Similarity against each predefined reference text, in list order
    similarities = [
        (ref_text, cosine_similarity(features, ref_vec))
        for ref_text, ref_vec in zip(predefined_texts, predefined_features)
    ]

    # One line of text per reference
    similarity_lines = [f"与 \"{t}\" 的相似度: {s:.2f}" for t, s in similarities]
    similarity_text = "\n".join(similarity_lines)

    # Dimensionality-reduction plot
    tsne_plot = plot_features(features)

    report = (
        f"特征维度: {features.shape}\n"
        f"特征向量(部分展示): {features[:10]}\n\n"
        f"相似性结果:\n{similarity_text}\n"
    )
    return report, tsne_plot
89
+
90
# Build the Gradio interface
demo = gr.Interface(
    title="中文特征提取与分析",
    description="基于BERT的中文文本特征提取,支持相似性分析与降维可视化。",
    fn=predict,
    inputs=gr.Textbox(placeholder="输入中文文本...", lines=2),
    # Matches predict's return tuple: text report first, HTML plot second.
    outputs=["text", "html"],
)

# Launch the Gradio app only when this file is run as a script
if __name__ == "__main__":
    demo.launch()