"""Gradio demo: BERT-based Chinese text feature extraction.

For a piece of input text the app extracts the [CLS] embedding from
``bert-base-chinese``, reports its cosine similarity against a few predefined
sentences, and renders a 2-D t-SNE projection of all vectors as an inline PNG.
"""

import base64
import io
from typing import List, Tuple

import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (kept: existing file-level import)
import numpy as np
import torch
from datasets import load_dataset
from matplotlib.figure import Figure
from sklearn.manifold import TSNE
from transformers import AutoModel, AutoTokenizer

# Load the pretrained model and tokenizer.
MODEL_NAME = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# Load the tnews dataset.
# NOTE(review): `dataset` is not referenced anywhere else in this file; it is
# kept so the module-level name stays available to any external user.
dataset = load_dataset("clue", "tnews")


def preprocess_text(text):
    """Tokenize *text* into PyTorch tensors ready to be fed to ``model``.

    The input is truncated to at most 128 tokens.
    """
    return tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )


def extract_features(text):
    """Return the [CLS] token embedding of *text* as a NumPy array."""
    inputs = preprocess_text(text)
    with torch.no_grad():
        outputs = model(**inputs)
    # Position 0 of the last hidden state is the [CLS] token; use it as the
    # sentence-level feature vector.
    return outputs.last_hidden_state[:, 0, :].squeeze().numpy()


def cosine_similarity(vec1, vec2):
    """Return the cosine similarity between *vec1* and *vec2*.

    Returns 0.0 when either vector has zero norm, instead of dividing by zero
    and producing ``nan``.
    """
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        return 0.0
    return np.dot(vec1, vec2) / denominator


# Predefined texts that the user input is compared against.
predefined_texts = [
    "今天的天气很好,我想去散步。",
    "股票市场今天表现不错。",
    "人工智能正在改变我们的生活。",
]
predefined_features: List[np.ndarray] = [
    extract_features(text) for text in predefined_texts
]


def plot_features(features):
    """Render a t-SNE projection of *features* plus the predefined vectors.

    Args:
        features: Feature vector of the user input, as produced by
            ``extract_features``.

    Returns:
        An HTML ``<img>`` tag embedding the plot as a base64-encoded PNG.
    """
    points = np.vstack([features, *predefined_features])
    n_samples = points.shape[0]

    # scikit-learn requires perplexity < n_samples; the default of 30 is
    # rejected for the handful of points plotted here, so clamp it.
    perplexity = min(30.0, float(n_samples - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_features = tsne.fit_transform(points)

    colors = ["red"] + ["blue"] * len(predefined_texts)

    # Use a standalone Figure rather than pyplot's global state: this handler
    # is invoked by the web server, possibly from worker threads.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    for i, (x, y) in enumerate(reduced_features):
        label = "Input" if i == 0 else f"Text {i}"
        ax.scatter(x, y, c=colors[i], label=label)
    ax.legend()
    ax.set_title("Feature Vector Visualization (t-SNE)")
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.grid(True)

    # Serialize the figure to an in-memory PNG and inline it as a data URI.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    img_str = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_str}"/>'


def predict(text) -> Tuple[str, str]:
    """Gradio handler: summarize features, similarities, and the t-SNE plot.

    Args:
        text: Input text entered by the user.

    Returns:
        A ``(summary_text, plot_html)`` pair matching the interface outputs.
    """
    features = extract_features(text)

    # Similarity of the input against each predefined sentence.
    similarities = [
        (ref_text, cosine_similarity(features, ref_features))
        for ref_text, ref_features in zip(predefined_texts, predefined_features)
    ]
    similarity_text = "\n".join(
        f'与 "{t}" 的相似度: {s:.2f}' for t, s in similarities
    )

    tsne_plot = plot_features(features)
    summary = (
        f"特征维度: {features.shape}\n"
        f"特征向量(部分展示): {features[:10]}\n\n"
        f"相似性结果:\n{similarity_text}\n"
    )
    return summary, tsne_plot


# Define the Gradio interface.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="输入中文文本..."),
    outputs=[
        "text",  # textual summary
        "html",  # inline plot image
    ],
    title="中文特征提取与分析",
    description="基于BERT的中文文本特征提取,支持相似性分析与降维可视化。",
)

# Launch the Gradio app.
if __name__ == "__main__":
    demo.launch()