SonyaX20
commited on
Commit
·
27be9b9
1
Parent(s):
79fb563
new
Browse files
app.py
CHANGED
@@ -2,6 +2,11 @@ import gradio as gr
|
|
2 |
from transformers import AutoModel, AutoTokenizer
|
3 |
from datasets import load_dataset
|
4 |
import torch
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# 加载预训练模型和分词器
|
7 |
MODEL_NAME = "bert-base-chinese"
|
@@ -25,84 +30,75 @@ def extract_features(text):
|
|
25 |
cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
|
26 |
return cls_embedding
|
27 |
|
28 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def predict(text):
|
|
|
30 |
features = extract_features(text)
|
31 |
-
return f"特征维度: {features.shape}\n特征向量(部分展示): {features[:10]}"
|
32 |
|
33 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
demo = gr.Interface(
|
35 |
fn=predict,
|
36 |
inputs=gr.Textbox(lines=2, placeholder="输入中文文本..."),
|
37 |
-
outputs=
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
40 |
)
|
41 |
|
42 |
# 运行Gradio应用
|
43 |
if __name__ == "__main__":
|
44 |
-
demo.launch()
|
45 |
-
# import gradio as gr
|
46 |
-
# from huggingface_hub import InferenceClient
|
47 |
-
|
48 |
-
# """
|
49 |
-
# For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
50 |
-
# """
|
51 |
-
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
52 |
-
|
53 |
-
|
54 |
-
# def respond(
|
55 |
-
# message,
|
56 |
-
# history: list[tuple[str, str]],
|
57 |
-
# system_message,
|
58 |
-
# max_tokens,
|
59 |
-
# temperature,
|
60 |
-
# top_p,
|
61 |
-
# ):
|
62 |
-
# messages = [{"role": "system", "content": system_message}]
|
63 |
-
|
64 |
-
# for val in history:
|
65 |
-
# if val[0]:
|
66 |
-
# messages.append({"role": "user", "content": val[0]})
|
67 |
-
# if val[1]:
|
68 |
-
# messages.append({"role": "assistant", "content": val[1]})
|
69 |
-
|
70 |
-
# messages.append({"role": "user", "content": message})
|
71 |
-
|
72 |
-
# response = ""
|
73 |
-
|
74 |
-
# for message in client.chat_completion(
|
75 |
-
# messages,
|
76 |
-
# max_tokens=max_tokens,
|
77 |
-
# stream=True,
|
78 |
-
# temperature=temperature,
|
79 |
-
# top_p=top_p,
|
80 |
-
# ):
|
81 |
-
# token = message.choices[0].delta.content
|
82 |
-
|
83 |
-
# response += token
|
84 |
-
# yield response
|
85 |
-
|
86 |
-
|
87 |
-
# """
|
88 |
-
# For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
89 |
-
# """
|
90 |
-
# demo = gr.ChatInterface(
|
91 |
-
# respond,
|
92 |
-
# additional_inputs=[
|
93 |
-
# gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
94 |
-
# gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
95 |
-
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
96 |
-
# gr.Slider(
|
97 |
-
# minimum=0.1,
|
98 |
-
# maximum=1.0,
|
99 |
-
# value=0.95,
|
100 |
-
# step=0.05,
|
101 |
-
# label="Top-p (nucleus sampling)",
|
102 |
-
# ),
|
103 |
-
# ],
|
104 |
-
# )
|
105 |
-
|
106 |
-
|
107 |
-
# if __name__ == "__main__":
|
108 |
-
# demo.launch()
|
|
|
2 |
from transformers import AutoModel, AutoTokenizer
|
3 |
from datasets import load_dataset
|
4 |
import torch
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.manifold import TSNE
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import io
|
9 |
+
import base64
|
10 |
|
11 |
# 加载预训练模型和分词器
|
12 |
MODEL_NAME = "bert-base-chinese"
|
|
|
30 |
cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
|
31 |
return cls_embedding
|
32 |
|
33 |
+
# 余弦相似度计算
|
34 |
+
def cosine_similarity(vec1, vec2):
|
35 |
+
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
36 |
+
|
37 |
+
# 预定义相似性对比文本
|
38 |
+
predefined_texts = [
|
39 |
+
"今天的天气很好,我想去散步。",
|
40 |
+
"股票市场今天表现不错。",
|
41 |
+
"人工智能正在改变我们的生活。"
|
42 |
+
]
|
43 |
+
predefined_features = [extract_features(text) for text in predefined_texts]
|
44 |
+
|
45 |
+
# 绘制降维可视化
|
46 |
+
def plot_features(features):
|
47 |
+
# 用 t-SNE 进行降维
|
48 |
+
tsne = TSNE(n_components=2, random_state=42)
|
49 |
+
reduced_features = tsne.fit_transform([features] + predefined_features)
|
50 |
+
colors = ['red'] + ['blue'] * len(predefined_texts)
|
51 |
+
|
52 |
+
# 绘制图像
|
53 |
+
plt.figure(figsize=(8, 6))
|
54 |
+
for i, point in enumerate(reduced_features):
|
55 |
+
label = "Input" if i == 0 else f"Text {i}"
|
56 |
+
plt.scatter(point[0], point[1], c=colors[i], label=label)
|
57 |
+
plt.legend()
|
58 |
+
plt.title("Feature Vector Visualization (t-SNE)")
|
59 |
+
plt.xlabel("Dimension 1")
|
60 |
+
plt.ylabel("Dimension 2")
|
61 |
+
plt.grid()
|
62 |
+
|
63 |
+
# 保存图像为字符串
|
64 |
+
buf = io.BytesIO()
|
65 |
+
plt.savefig(buf, format="png")
|
66 |
+
buf.seek(0)
|
67 |
+
img_str = base64.b64encode(buf.read()).decode("utf-8")
|
68 |
+
plt.close()
|
69 |
+
return f'<img src="data:image/png;base64,{img_str}" />'
|
70 |
+
|
71 |
+
# Gradio接口函数
|
72 |
def predict(text):
|
73 |
+
# 提取特征
|
74 |
features = extract_features(text)
|
|
|
75 |
|
76 |
+
# 计算相似性
|
77 |
+
similarities = [
|
78 |
+
(predefined_texts[i], cosine_similarity(features, predefined_features[i]))
|
79 |
+
for i in range(len(predefined_texts))
|
80 |
+
]
|
81 |
+
|
82 |
+
# 构造相似性结果文本
|
83 |
+
similarity_text = "\n".join([f"与 \"{t}\" 的相似度: {s:.2f}" for t, s in similarities])
|
84 |
+
|
85 |
+
# 降维图
|
86 |
+
tsne_plot = plot_features(features)
|
87 |
+
|
88 |
+
return f"特征维度: {features.shape}\n特征向量(部分展示): {features[:10]}\n\n相似性结果:\n{similarity_text}\n", tsne_plot
|
89 |
+
|
90 |
+
# 定义Gradio界面
|
91 |
demo = gr.Interface(
|
92 |
fn=predict,
|
93 |
inputs=gr.Textbox(lines=2, placeholder="输入中文文本..."),
|
94 |
+
outputs=[
|
95 |
+
"text", # 文本输出
|
96 |
+
"html", # 图像输出
|
97 |
+
],
|
98 |
+
title="中文特征提取与分析",
|
99 |
+
description="基于BERT的中文文本特征提取,支持相似性分析与降维可视化。",
|
100 |
)
|
101 |
|
102 |
# 运行Gradio应用
|
103 |
if __name__ == "__main__":
|
104 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|