# File size: 806 Bytes
# 0ccd07a
import pandas as pd
# read dataset
# Each line of the file is split on the literal separator "_!_" into five
# named columns. A separator longer than one character is treated as a regular
# expression, which only pandas' "python" engine supports; naming the engine
# explicitly avoids pandas' automatic-fallback ParserWarning without changing
# how the file is parsed.
df = pd.read_csv('toutiao_cat_data.txt',
                 sep='_!_', lineterminator='\n',
                 engine='python',
                 encoding='utf8',
                 names=["id", "type", "type_text", "text", "keywords"])
# Keep only the input text and its label. ``.copy()`` makes this an
# independent frame, so the in-place label shift below does not raise
# pandas' SettingWithCopyWarning.
df = df[["text", "type"]].copy()
# Shift the integer "type" codes down by 100 to use them as class ids.
# NOTE(review): presumably the raw codes start at 100 so the shifted labels
# fall in range(num_labels) used by the model below - confirm against the data.
df["type"] = df["type"] - 100

# split dataset
# Shuffle every row with a fixed seed so the train/test split (and therefore
# the reported evaluation metrics) is reproducible across runs, then hold out
# the last TEST_SIZE rows for evaluation.
SEED = 42
TEST_SIZE = 1000
if len(df) <= TEST_SIZE:
    raise ValueError(
        f"need more than {TEST_SIZE} rows to hold out a test set, got {len(df)}"
    )
df = df.sample(frac=1, random_state=SEED)
train_df, test_df = df.iloc[:-TEST_SIZE], df.iloc[-TEST_SIZE:]
# create model
from simpletransformers.classification import ClassificationModel
import torch

# Fine-tune the pretrained "bert-base-chinese" checkpoint as an 18-way
# sequence classifier.
# ClassificationModel defaults to use_cuda=True and raises ValueError when no
# CUDA device is available, so only request the GPU when torch can see one;
# otherwise training falls back to the CPU instead of crashing.
# "reprocess_input_data" re-tokenizes inputs instead of reusing cached
# features; "overwrite_output_dir" lets training write into an existing
# output directory.
model = ClassificationModel(
    "bert",
    "bert-base-chinese",
    num_labels=18,
    use_cuda=torch.cuda.is_available(),
    args={"reprocess_input_data": True, "overwrite_output_dir": True},
)
# train
# Fine-tune the model on the shuffled training split (columns "text", "type").
model.train_model(train_df=train_df)
# eval
# A plain ``import sklearn`` does not guarantee the ``sklearn.metrics``
# submodule is loaded; the original only worked because another library had
# already imported it. Import the submodule explicitly. This still binds the
# name ``sklearn``, as before.
import sklearn.metrics

# eval_model returns a tuple whose first element is the metrics dict; the
# extra ``acc`` keyword adds an accuracy entry computed with accuracy_score.
result = model.eval_model(test_df, acc=sklearn.metrics.accuracy_score)
# A bare expression only displays in a notebook; print it so the metrics are
# visible when this file runs as a script.
print(result[0])
# predict
# predict() returns (predicted class ids, raw model outputs). The ids are the
# shifted labels used for training; add 100 to map back to the raw "type" code.
predictions, raw_outputs = model.predict(["M2处理器IPad mini7值得期待吗?"])
# Print so the prediction is visible when this file runs as a script.
print(predictions)