import pandas as pd
import torch
import random
import numpy as np
import os
import sys
import json
import matplotlib.pyplot as plt
import wandb
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from datasets import Dataset
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
import seaborn as sns
from datetime import datetime
import torch.nn.functional as F
# ✅ Set random seeds to make results reproducible
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ✅ Initialize wandb (optional; disable it if not needed)
# Option 1: use an API key
# WANDB_API_KEY = "YOUR_API_KEY_HERE"
# os.environ["WANDB_API_KEY"] = WANDB_API_KEY
# Option 2: disable wandb entirely
os.environ["WANDB_DISABLED"] = "true"
# Initialize wandb (if not disabled)
if os.environ.get("WANDB_DISABLED") != "true":
    run = wandb.init(
        project="chinese-topic-classifier",
        name=f"roberta-topic-classifier-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        config={
            "model_name": "hfl/chinese-roberta-wwm-ext",
            "epochs": 12,
            "batch_size": 8,
            "learning_rate": 1e-5,
            "weight_decay": 0.01,
            "max_length": 128,
            "seed": seed
        }
    )
# ✅ Create output directories
base_output_dir = "./roberta_output"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"{base_output_dir}_{timestamp}"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
os.makedirs(f"{output_dir}/results", exist_ok=True)
os.makedirs(f"{output_dir}/logs", exist_ok=True)
os.makedirs(f"{output_dir}/api", exist_ok=True)
# ✅ Configure logging
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f"{output_dir}/logs/training.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
logger.info(f"Starting training, output directory: {output_dir}")
# ✅ Read the CSV file (assumes the CSV has been uploaded to the Hugging Face Space)
file_path = "ragproject7.csv"
logger.info(f"Reading dataset: {file_path}")
df = pd.read_csv(file_path)
logger.info(f"Dataset size: {df.shape}")
# ✅ Clean the data
df = df[['text', 'topic']].dropna()
df = df.drop_duplicates(subset=['text'])  # remove duplicate texts
unique_topics = df["topic"].unique()
logger.info(f"Number of classes: {len(unique_topics)}")
logger.info(f"Class distribution: \n{df['topic'].value_counts()}")
# Build the label mapping dictionaries
topic_dict = {topic: i for i, topic in enumerate(unique_topics)}
inv_topic_dict = {i: topic for topic, i in topic_dict.items()}
# Update the wandb config (if enabled)
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.config.update({
        "num_classes": len(unique_topics),
        "class_distribution": df['topic'].value_counts().to_dict(),
        "topic_dict": topic_dict
    })
# Save the label mapping for later use
with open(f"{output_dir}/topic_dict.json", "w", encoding="utf-8") as f:
    json.dump(topic_dict, f, ensure_ascii=False, indent=2)
logger.info(f"Saved label mapping with {len(unique_topics)} classes")
# Convert labels to integers
df["numeric_topic"] = df["topic"].map(topic_dict)
# ✅ Compute class weights to handle class imbalance
# (weights are computed on the original distribution, before any augmentation below)
class_counts = df['numeric_topic'].value_counts().sort_index()
total_samples = len(df)
class_weights = torch.FloatTensor([total_samples / (len(class_counts) * count) for count in class_counts])
logger.info(f"Class weights: {class_weights}")
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.config.update({"class_weights": class_weights.tolist()})
# ✅ Load the tokenizer (should load without issues on Hugging Face)
model_name = "hfl/chinese-roberta-wwm-ext"
logger.info(f"Loading tokenizer: {model_name}")
# Use AutoTokenizer instead of a specific RobertaTokenizer for better compatibility
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Error while loading tokenizer: {e}")
    # Try backup models on the Hugging Face platform
    backup_model_names = ["hfl/chinese-macbert-base", "bert-base-chinese"]
    for backup_name in backup_model_names:
        try:
            logger.info(f"Trying backup tokenizer: {backup_name}")
            tokenizer = AutoTokenizer.from_pretrained(backup_name)
            model_name = backup_name  # update the model name accordingly
            logger.info(f"Backup tokenizer loaded: {backup_name}")
            break
        except Exception as e2:
            logger.error(f"Failed to load backup tokenizer {backup_name}: {e2}")
    else:
        raise Exception("Could not load any tokenizer; please check the environment setup")
# ✅ Define the metric computation function
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    # Basic metrics
    acc = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average='weighted')
    precision = precision_score(labels, predictions, average='weighted', zero_division=0)
    recall = recall_score(labels, predictions, average='weighted', zero_division=0)
    # Per-class F1 scores (pass `labels` explicitly so classes missing from the eval split
    # keep their position in the output array)
    f1_per_class = f1_score(labels, predictions, average=None, labels=list(range(len(unique_topics))))
    f1_per_class_dict = {inv_topic_dict[i]: score for i, score in enumerate(f1_per_class)}
    # Assemble the results
    result = {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
    # Add the per-class F1 scores
    for class_name, score in f1_per_class_dict.items():
        result[f'f1_{class_name}'] = score
    return result
# ✅ Define the tokenization function (used with Dataset.map below)
max_length = 128
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
# ✅ Prepare the dataset
logger.info("Processing dataset...")
# Data augmentation for minority classes
# Count samples per class
class_samples = df['numeric_topic'].value_counts().sort_index()
max_samples = class_samples.max()
augmented_texts = []
augmented_topics = []
# Simple augmentation for minority classes (the augmentation method can be improved as needed)
for idx, count in enumerate(class_samples):
    if count < max_samples * 0.5:  # if the class has fewer than half the samples of the largest class
        # Collect all samples of this class
        class_texts = df[df['numeric_topic'] == idx]['text'].tolist()
        # Number of samples to add
        augment_count = int(max_samples * 0.7) - count
        if augment_count > 0 and len(class_texts) > 0:
            # Randomly sample existing texts and apply light perturbations
            for _ in range(augment_count):
                text = random.choice(class_texts)
                # Simple augmentation: randomly delete or repeat a few characters
                if len(text) > 20:  # make sure the text is long enough
                    if random.random() < 0.5:
                        # Randomly delete a few characters
                        remove_pos = random.randint(0, len(text) - 10)
                        remove_len = random.randint(1, 3)
                        text = text[:remove_pos] + text[remove_pos + remove_len:]
                    else:
                        # Randomly repeat a few characters
                        repeat_pos = random.randint(0, len(text) - 5)
                        repeat_len = random.randint(1, 3)
                        repeat_text = text[repeat_pos:repeat_pos + repeat_len]
                        text = text[:repeat_pos] + repeat_text + text[repeat_pos:]
                augmented_texts.append(text)
                augmented_topics.append(idx)
# Append the augmented samples to the original data
if augmented_texts:
    aug_df = pd.DataFrame({
        'text': augmented_texts,
        'numeric_topic': augmented_topics
    })
    df = pd.concat([df, aug_df], ignore_index=True)
    logger.info(f"Added {len(augmented_texts)} augmented samples; new dataset size: {df.shape}")
    logger.info(f"Class distribution after augmentation: \n{df['numeric_topic'].value_counts().sort_index()}")
# Convert to a datasets.Dataset
dataset = Dataset.from_pandas(
    df[['text', 'numeric_topic']].rename(columns={'numeric_topic': 'labels'}),
    preserve_index=False  # avoid carrying the pandas index into the dataset
)
# Tokenize
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Cast the labels to ClassLabel so the split below can be stratified
tokenized_dataset = tokenized_dataset.class_encode_column("labels")
# Split into train and test sets
train_test_split_ratio = 0.2
train_test = tokenized_dataset.train_test_split(test_size=train_test_split_ratio, seed=seed, stratify_by_column="labels")
train_dataset = train_test["train"]
eval_dataset = train_test["test"]
logger.info(f"Training set size: {len(train_dataset)}, test set size: {len(eval_dataset)}")
# ✅ Load and configure the model
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(unique_topics),
    hidden_dropout_prob=0.2,            # slightly higher dropout than the default 0.1
    attention_probs_dropout_prob=0.2,
    classifier_dropout=0.3,             # higher dropout on the classifier head to reduce overfitting
)
# Load the model
logger.info(f"Loading model: {model_name}")
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    ignore_mismatched_sizes=True  # allow the classification head size to differ
)
# ✅ Freeze most of the encoder and fine-tune only the last few layers (suited to small datasets)
# Freeze the first 80% of transformer layers; the remaining layers, embeddings and classifier head stay trainable.
num_hidden_layers = getattr(model.config, "num_hidden_layers", 12)
freeze_layers = int(0.8 * num_hidden_layers)
logger.info(f"Freezing the first {freeze_layers} of {num_hidden_layers} transformer layers...")
for name, param in model.named_parameters():
    if ".layer." in name:
        # Parameter names look like "roberta.encoder.layer.<idx>....": extract the layer index
        layer_idx = int(name.split(".layer.")[1].split(".")[0])
        if layer_idx < freeze_layers:
            param.requires_grad = False  # freeze the earlier layers
            logger.debug(f"Frozen: {name}")
        else:
            param.requires_grad = True   # fine-tune the last layers
            logger.debug(f"Trainable: {name}")
    else:
        param.requires_grad = True       # embeddings, pooler and classifier head stay trainable
        logger.debug(f"Trainable: {name}")
# Report the number of trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"Trainable parameters: {trainable_params:,} / total parameters: {total_params:,} ({trainable_params/total_params:.2%})")
# Attach the class weights to the model so the custom loss can use them later
model.class_weights = class_weights
# ✅ Training arguments (tuned for a small dataset)
logger.info("Configuring training arguments for a small dataset...")
# Update the wandb config (if enabled)
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.config.update({
        "epochs": 12,                      # small datasets benefit from more epochs
        "batch_size": 8,                   # small batches improve stability
        "learning_rate": 1e-5,             # lower learning rate to avoid overfitting
        "weight_decay": 0.01,              # regularization to keep the model from memorizing
        "gradient_accumulation_steps": 2,  # compensate for the small batch size with steadier gradients
        "frozen_layers": "first 80% of layers",
        "early_stopping_patience": 3       # stop after 3 epochs without improvement
    })
# Reporting backends
report_to_list = ["tensorboard"]
if os.environ.get("WANDB_DISABLED") != "true":
    report_to_list.append("wandb")
training_args = TrainingArguments(
    output_dir=f"{output_dir}/checkpoints",
    num_train_epochs=12,                 # ✅ 12 epochs (small datasets need more passes)
    per_device_train_batch_size=8,       # ✅ small batch size (more stable)
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir=f"{output_dir}/logs/tensorboard",
    logging_strategy="steps",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    learning_rate=1e-5,                  # ✅ lower learning rate for small data, avoids overfitting
    weight_decay=0.01,                   # ✅ regularization against memorization
    warmup_ratio=0.1,                    # ✅ warm-up so the learning rate ramps up gradually
    gradient_accumulation_steps=2,       # ✅ compensates for the small batch size, steadier gradients
    fp16=torch.cuda.is_available(),      # use mixed precision only when a GPU is available
    remove_unused_columns=False,
    report_to=report_to_list,            # enable wandb only when configured
    save_total_limit=3,                  # keep only the 3 most recent checkpoints
    push_to_hub=False,                   # do not push to the Hugging Face Hub
    dataloader_num_workers=2,            # few workers; small datasets do not need heavy parallelism
    group_by_length=True,                # group sequences of similar length for efficiency
)
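# Effective batch size arithmetic: per_device_train_batch_size (8) × gradient_accumulation_steps (2)
# = 16 examples per optimizer step on each device.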
# Trainer with a custom, class-weighted loss
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer transformers versions
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Move the class weights to the right device
        device = logits.device
        class_weights = model.class_weights.to(device)
        # Weighted cross-entropy loss
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
# Early stopping callback to avoid overfitting
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,  # ✅ stop if the validation metric does not improve for 3 epochs
    early_stopping_threshold=0.001
)
logger.info("Configuring trainer with early stopping...")
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],  # early stopping to avoid overfitting
)
# ✅ Train
logger.info("Starting training...")
trainer.train()
# ✅ Evaluate the model
logger.info("Evaluating the final model...")
eval_results = trainer.evaluate()
logger.info(f"Evaluation results: {eval_results}")
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.log({"final_results": eval_results})
# Detailed evaluation on the test set
logger.info("Generating detailed test report...")
predictions = trainer.predict(eval_dataset)
preds = np.argmax(predictions.predictions, axis=1)
labels = predictions.label_ids
# Classification report (pass `labels` explicitly in case a class is missing from the test split)
class_names = [inv_topic_dict[i] for i in range(len(unique_topics))]
classification_rep = classification_report(
    labels, preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
    output_dict=True,
    zero_division=0
)
with open(f"{output_dir}/results/classification_report.json", "w", encoding="utf-8") as f:
    json.dump(classification_rep, f, ensure_ascii=False, indent=2)
# Plot the confusion matrix
plt.figure(figsize=(10, 8))
cm = confusion_matrix(labels, preds, labels=list(range(len(class_names))))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(f"{output_dir}/results/confusion_matrix.png")
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.log({"confusion_matrix": wandb.Image(f"{output_dir}/results/confusion_matrix.png")})
logger.info(f"Confusion matrix saved to {output_dir}/results/confusion_matrix.png")
# ✅ Save the final model and tokenizer
final_model_path = f"{output_dir}/final_model"
model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)
logger.info(f"Final model and tokenizer saved to {final_model_path}")
# Upload the model to the Hugging Face Hub (if desired)
push_to_hub = False  # set to True to upload to the Hugging Face Hub
if push_to_hub:
    from huggingface_hub import HfFolder
    # Set your Hugging Face credentials
    # HfFolder.save_token("YOUR_HF_TOKEN")
    # Push to the Hub
    repo_name = f"chinese-topic-classifier-{timestamp}"
    model.push_to_hub(repo_name)
    tokenizer.push_to_hub(repo_name)
    logger.info(f"Model uploaded to the Hugging Face Hub: {repo_name}")
# Upload the model to wandb (if wandb is enabled)
if os.environ.get("WANDB_DISABLED") != "true":
    try:
        model_artifact = wandb.Artifact('roberta-topic-model', type='model')
        model_artifact.add_dir(final_model_path)
        wandb.log_artifact(model_artifact)
        logger.info("Model uploaded to wandb")
    except Exception as e:
        logger.warning(f"Error while uploading the model to wandb: {e}")
else:
    logger.info("wandb is disabled, skipping model upload")
# ✅ Prediction helper
def predict(text, return_probs=False):
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)[0]
        prediction = torch.argmax(logits, dim=1).item()
    if return_probs:
        probs_dict = {inv_topic_dict[i]: float(probs[i]) for i in range(len(unique_topics))}
        return inv_topic_dict[prediction], probs_dict
    return inv_topic_dict[prediction]
# ✅ Create a Gradio interface (especially handy on the Hugging Face platform)
try:
    import gradio as gr
    def predict_for_gradio(text):
        topic, probs = predict(text, return_probs=True)
        # Sort by probability (descending) and format as percentages
        sorted_probs = sorted(probs.items(), key=lambda x: x[1], reverse=True)
        result_text = f"預測主題: {topic}\n\n各類別機率:\n"
        for class_name, prob in sorted_probs:
            result_text += f"- {class_name}: {prob:.2%}\n"
        return result_text
    # Build the Gradio interface
    demo = gr.Interface(
        fn=predict_for_gradio,
        inputs=gr.Textbox(lines=5, placeholder="請輸入要分類的中文文本..."),
        outputs="text",
        title="中文主題分類器",
        description=f"此模型可將文本分類為以下主題: {', '.join(unique_topics)}",
        examples=[
            ["這篇文章探討了太陽能電池的最新研究進展。"],
            ["碳捕捉技術可以減少溫室氣體排放。"],
            ["社區參與對環保項目的成功至關重要。"]
        ]
    )
    # Launch the Gradio app without blocking the rest of the script
    demo.launch(share=True, prevent_thread_lock=True)
    logger.info("Gradio interface launched")
except ImportError:
    logger.info("Gradio is not installed, skipping interface creation")
# ✅ Generate the API service code
api_code = '''
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import os

app = Flask(__name__)
CORS(app)

# Globals
model = None
tokenizer = None
topic_dict = None
inv_topic_dict = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    global model, tokenizer, topic_dict, inv_topic_dict
    model_path = "./final_model"
    if not os.path.exists(model_path):
        return {"error": f"Model path {model_path} does not exist"}
    topic_dict_path = "./topic_dict.json"
    if not os.path.exists(topic_dict_path):
        return {"error": f"Label mapping file {topic_dict_path} does not exist"}
    try:
        # Load the model and tokenizer
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        model.to(device)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Load the label mapping
        with open(topic_dict_path, "r", encoding="utf-8") as f:
            topic_dict = json.load(f)
        inv_topic_dict = {v: k for k, v in topic_dict.items()}
        return {"success": "Model loaded successfully"}
    except Exception as e:
        return {"error": f"Error while loading the model: {str(e)}"}

@app.route("/", methods=["GET"])
def index():
    return jsonify({"status": "API service running", "endpoints": {"/predict": "text classification"}})

@app.route("/predict", methods=["POST"])
def predict_topic():
    # Make sure the model is loaded
    global model, tokenizer, topic_dict, inv_topic_dict
    if model is None:
        result = load_model()
        if "error" in result:
            return jsonify(result), 500
    # Read the request payload
    data = request.json
    if not data or "text" not in data:
        return jsonify({"error": "Request must contain a 'text' field"}), 400
    text = data["text"]
    return_probs = data.get("return_probs", False)
    try:
        # Run the prediction
        inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probs = F.softmax(logits, dim=1)[0]
            prediction = torch.argmax(logits, dim=1).item()
        result = {"topic": inv_topic_dict[prediction]}
        if return_probs:
            result["probabilities"] = {inv_topic_dict[i]: float(probs[i]) for i in range(len(inv_topic_dict))}
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": f"Error during prediction: {str(e)}"}), 500

if __name__ == "__main__":
    # Load the model up front
    load_model()
    app.run(host="0.0.0.0", port=5000, debug=False)
'''
with open(f"{output_dir}/api/app.py", "w", encoding="utf-8") as f:
    f.write(api_code)
# Create the startup script (keep the shebang on the first line of the file)
startup_script = '''#!/bin/bash
cd "$(dirname "$0")"
export PYTHONIOENCODING=utf-8
export FLASK_APP=app.py
flask run --host=0.0.0.0 --port=5000
'''
with open(f"{output_dir}/api/start_api.sh", "w", encoding="utf-8") as f:
    f.write(startup_script)
os.chmod(f"{output_dir}/api/start_api.sh", 0o755)
# Create the README file
readme = f'''
# Chinese Topic Classification Model API Service

## Overview
This is a Chinese topic classification API service built on a pre-trained language model. It classifies text into the following categories:
{json.dumps({v: k for k, v in topic_dict.items()}, ensure_ascii=False, indent=2)}

## Usage

### Start the API service
1. Make sure the required packages are installed: `pip install flask flask-cors transformers torch`
2. Run the startup script: `./start_api.sh`

### API endpoints
- `GET /`: check the API status
- `POST /predict`: classify a text

### Example prediction request
```bash
curl -X POST http://localhost:5000/predict \\
  -H "Content-Type: application/json" \\
  -d '{{"text": "您的文本內容", "return_probs": true}}'
```

### Response format
```json
{{
  "topic": "predicted class",
  "probabilities": {{
    "class_1": 0.8,
    "class_2": 0.1,
    "class_3": 0.05,
    "class_4": 0.03,
    "class_5": 0.02
  }}
}}
```

## Using the API from other applications

### Python
```python
import requests

def predict_topic(text, return_probs=False):
    response = requests.post('http://localhost:5000/predict',
                             json={{'text': text, 'return_probs': return_probs}})
    return response.json()

# Example usage
result = predict_topic("您的文本", return_probs=True)
print(f"Predicted class: {{result['topic']}}")
if 'probabilities' in result:
    for topic, prob in result['probabilities'].items():
        print(f"{{topic}}: {{prob:.2f}}")
```

### JavaScript
```javascript
async function predictTopic(text, returnProbs = false) {{
  const response = await fetch('http://localhost:5000/predict', {{
    method: 'POST',
    headers: {{
      'Content-Type': 'application/json',
    }},
    body: JSON.stringify({{ text, return_probs: returnProbs }}),
  }});
  return await response.json();
}}

// Example usage
predictTopic("您的文本", true).then(result => {{
  console.log(`Predicted class: ${{result.topic}}`);
  if (result.probabilities) {{
    Object.entries(result.probabilities).forEach(([topic, prob]) => {{
      console.log(`${{topic}}: ${{prob.toFixed(2)}}`);
    }});
  }}
}});
```

## Training details
- Model: {model_name}
- Training time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
- Training set size: {len(train_dataset)}
- Test set size: {len(eval_dataset)}
- Final test F1 score: {eval_results.get('eval_f1', 'N/A')}
'''
with open(f"{output_dir}/api/README.md", "w", encoding="utf-8") as f:
    f.write(readme)
# Bundle the files needed by the API
import shutil
os.makedirs(f"{output_dir}/api/final_model", exist_ok=True)
shutil.copytree(final_model_path, f"{output_dir}/api/final_model", dirs_exist_ok=True)
shutil.copy(f"{output_dir}/topic_dict.json", f"{output_dir}/api/topic_dict.json")
logger.info(f"API service code and files are ready under {output_dir}/api/")
# Try to zip the API folder (may not be needed in the Hugging Face environment)
try:
    import zipfile
    def zip_directory(directory_path, zip_path):
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _, files in os.walk(directory_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, os.path.dirname(directory_path))
                    zipf.write(file_path, arcname)
    zip_directory(f"{output_dir}/api", f"{output_dir}/api.zip")
    logger.info(f"API service files packaged as {output_dir}/api.zip")
except Exception as e:
    logger.warning(f"Error while zipping the API files: {e}; this does not affect the model or the API")
# ✅ Classify sample texts and allow manual input
print("\n" + "="*50)
print("Training complete! The model can now be tested on sample texts")
print("="*50 + "\n")
# On the Hugging Face platform, showcase the results on a few sample texts automatically
sample_texts = [
    "這篇文章探討了太陽能電池的最新研究進展。",
    "碳捕捉技術可以減少溫室氣體排放。",
    "社區參與對環保項目的成功至關重要。"
]
print("Sample classification results:")
for i, text in enumerate(sample_texts, 1):
    topic, probs = predict(text, return_probs=True)
    print(f"\nSample {i}: {text}")
    print(f"Predicted topic: {topic}")
    print("Class probabilities:")
    for topic_name, prob in sorted(probs.items(), key=lambda x: x[1], reverse=True):
        print(f"- {topic_name}: {prob:.4f}")
# In an interactive environment, still offer manual input
if 'ipykernel' in sys.modules:
    text_input = input("\nEnter a text to classify: ")
    if text_input:
        topic, probs = predict(text_input, return_probs=True)
        print(f"\nPredicted topic: {topic}")
        print("Class probabilities:")
        for topic_name, prob in sorted(probs.items(), key=lambda x: x[1], reverse=True):
            print(f"- {topic_name}: {prob:.4f}")
# ✅ Save the model evaluation results
with open(f"{output_dir}/results/model_evaluation.txt", "w", encoding="utf-8") as f:
    f.write("Model evaluation results:\n")
    f.write(f"Accuracy: {eval_results.get('eval_accuracy', 'N/A')}\n")
    f.write(f"F1 score: {eval_results.get('eval_f1', 'N/A')}\n")
    f.write(f"Precision: {eval_results.get('eval_precision', 'N/A')}\n")
    f.write(f"Recall: {eval_results.get('eval_recall', 'N/A')}\n\n")
    f.write("Per-class F1 scores:\n")
    for class_name in class_names:
        f.write(f"{class_name}: {eval_results.get(f'eval_f1_{class_name}', 'N/A')}\n")
# ✅ Finish the wandb run (if enabled)
if os.environ.get("WANDB_DISABLED") != "true":
    try:
        wandb.finish()
        logger.info("wandb run finished")
    except Exception as e:
        logger.warning(f"Error while finishing the wandb run: {e}")
# Instructions for saving the model to the Hugging Face Hub
hub_instructions = f'''
# Saving the model to the Hugging Face Hub

If you want to share the trained model on the Hugging Face Hub, follow these steps:

1. Make sure you are logged in to Hugging Face:
```python
from huggingface_hub import login
login()  # you will be prompted for your token
```

2. Upload the model to the Hub:
```python
model_id = "your-username/chinese-topic-classifier"  # replace with your username
# Upload the model
model.push_to_hub(model_id)
# Upload the tokenizer
tokenizer.push_to_hub(model_id)
# Upload a config file
with open("config.json", "w") as f:
    json.dump({{"model_name": "{model_name}",
               "num_classes": {len(unique_topics)},
               "classes": {list(topic_dict.keys())}
               }}, f)
from huggingface_hub import upload_file
upload_file(
    path_or_fileobj="config.json",
    path_in_repo="config.json",
    repo_id=model_id
)
```

3. Create a Gradio-based demo and upload it:
```python
%%writefile app.py
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch.nn.functional as F
import torch

# Model ID (the one you used above)
model_id = "your-username/chinese-topic-classifier"

# Load the model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Label mapping
topic_names = {list(topic_dict.keys())}

# Prediction function
def predict(text):
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)[0]
        prediction = torch.argmax(logits, dim=1).item()
    # Format the result
    result = f"預測類別: {{topic_names[prediction]}}\\n\\n機率分布:\\n"
    for i, prob in enumerate(probs):
        result += f"- {{topic_names[i]}}: {{prob:.4f}}\\n"
    return result

# Build the Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=5, placeholder="請輸入中文文本..."),
    outputs="text",
    title="中文主題分類器",
    description="輸入中文文本,預測其所屬主題類別。",
    examples=[
        "這篇文章探討了太陽能電池的最新研究進展。",
        "碳捕捉技術可以減少溫室氣體排放。",
        "社區參與對環保項目的成功至關重要。"
    ]
)

if __name__ == "__main__":
    demo.launch()

# Then upload this app to Hugging Face Spaces
```
'''
with open(f"{output_dir}/hub_instructions.md", "w", encoding="utf-8") as f:
    f.write(hub_instructions)
print(f""" | |
======================================== | |
訓練和API服務準備完成! | |
======================================== | |
訓練結果摘要: | |
- 模型保存在: {final_model_path} | |
- API服務代碼在: {output_dir}/api/ | |
將模型部署到Hugging Face Hub: | |
- 說明文件: {output_dir}/hub_instructions.md | |
API使用示例: | |
curl -X POST http://localhost:5000/predict \\ | |
-H "Content-Type: application/json" \\ | |
-d '{{"text": "您的文本", "return_probs": true}}' | |
更多詳情請參閱: {output_dir}/api/README.md | |
""") | |
# Hugging Face Spaces deployment notes (added specifically for the Hugging Face environment)
print("""
Deploying the model on the Hugging Face platform:

1. Create a new Space:
   - Go to huggingface.co/new-space
   - Choose Gradio as the SDK
   - Fill in a name and description

2. Upload the model and app.py:
   - Upload the trained model to your Hugging Face account
   - Create app.py following the instructions in hub_instructions.md
   - Upload it to your Space

3. Configure the Space:
   - Add the dependencies in the Space settings: transformers, torch, gradio
     (see the requirements.txt sketch written below)

After these steps you will have a publicly accessible inference interface!
""")