|
import pandas as pd |
|
|
|
|
|
# Load the Toutiao news-headline dump. Records are "_!_"-separated with the
# fields listed in `names`. A multi-character separator is treated as a regex
# and is only supported by pandas' python engine, so request that engine
# explicitly rather than relying on the C engine's automatic fallback (which
# emits a ParserWarning on every run).
df = pd.read_csv('toutiao_cat_data.txt',
                 sep='_!_', lineterminator='\n',
                 encoding='utf8', engine='python',
                 names=["id", "type", "type_text", "text", "keywords"])

# Keep only the headline text and its numeric category.
# simpletransformers looks for "text" and "labels" columns. Without them it
# warns and falls back to using the first two columns positionally, so the
# target is named "labels".
# Building a fresh frame also avoids assigning into a column-subset copy,
# which triggers pandas' SettingWithCopyWarning.
# The code assumes category codes are offset by 100; subtracting it yields
# the class ids the classifier is trained on.
df = pd.DataFrame({"text": df["text"], "labels": df["type"] - 100})
|
|
|
|
|
# Shuffle every row, then hold out the last TEST_SIZE rows for evaluation.
# A fixed random_state makes the split (and therefore the reported metrics)
# reproducible from run to run.
TEST_SIZE = 1000
df = df.sample(frac=1, random_state=42)
train_df, test_df = df.iloc[:-TEST_SIZE], df.iloc[-TEST_SIZE:]
|
|
|
|
|
import torch
from simpletransformers.classification import ClassificationModel

# Fine-tune the pretrained "bert-base-chinese" checkpoint as an 18-way
# classifier. Every class id in the training data must be below num_labels.
# reprocess_input_data / overwrite_output_dir: rebuild cached features and
# allow reuse of the output directory on repeated runs.
# ClassificationModel defaults to use_cuda=True and raises when CUDA is
# unavailable, so choose the device from what torch actually detects.
model = ClassificationModel(
    "bert",
    "bert-base-chinese",
    num_labels=18,
    args={"reprocess_input_data": True, "overwrite_output_dir": True},
    use_cuda=torch.cuda.is_available(),
)
|
|
|
|
|
# Fine-tune the classifier on the training split.
model.train_model(train_df=train_df)
|
|
|
|
|
# `import sklearn` alone is not guaranteed to load the `sklearn.metrics`
# submodule. Import it explicitly instead of relying on another library
# having done so. This still binds the name `sklearn`.
import sklearn.metrics

# Evaluate on the held-out split. Extra keyword arguments to eval_model are
# used as additional metric functions, so accuracy is reported under "acc".
result = model.eval_model(test_df, acc=sklearn.metrics.accuracy_score)

# A bare `result[0]` only displays inside a notebook, so print it explicitly.
# It is presumably the metrics dict per the simpletransformers docs; confirm.
print(result[0])
|
|
|
|
|
# Classify one unseen headline. predict() returns the predicted class ids
# together with the raw model outputs; only the ids are printed here.
# Adding 100 to an id recovers the original category code, since labels were
# shifted by -100 when the data was loaded.
predictions, raw_outputs = model.predict(["M2处理器IPad mini7值得期待吗?"])
print(predictions)