# toutiao/main.py
# myml: "First model version" (commit 0ccd07a)
"""Fine-tune a Chinese BERT news classifier on the Toutiao category dataset.

When run as a script this reads ``toutiao_cat_data.txt``, shuffles it, holds
out ``TEST_SIZE`` rows for evaluation, fine-tunes ``bert-base-chinese`` with
simpletransformers, and prints the evaluation metrics plus one sample
prediction.
"""
import pandas as pd

DATA_PATH = "toutiao_cat_data.txt"
# Field order of each ``_!_``-separated record in the data file (no header).
COLUMNS = ["id", "type", "type_text", "text", "keywords"]
# Subtracted from the raw ``type`` code to obtain the 0-based class label.
LABEL_OFFSET = 100
NUM_LABELS = 18
TEST_SIZE = 1000
# Fixed shuffle seed so the train/test split (and therefore the reported
# accuracy) is reproducible between runs.
SEED = 42


def load_dataset(
    path=DATA_PATH, *, label_offset=LABEL_OFFSET, num_labels=NUM_LABELS
):
    """Read the Toutiao category file into a ``text``/``labels`` frame.

    Args:
        path: ``_!_``-separated file whose fields are ``COLUMNS``, no header.
        label_offset: Subtracted from the raw ``type`` column to form labels.
        num_labels: Number of classes the model will be built with; every
            label must lie in ``[0, num_labels)``.

    Returns:
        A DataFrame with the ``text`` column and a ``labels`` column holding
        ``type - label_offset`` (the column names documented for
        simpletransformers training data).

    Raises:
        ValueError: if any label is missing or outside ``[0, num_labels)``;
            otherwise the mismatch would only surface as an opaque error
            during training.
    """
    raw = pd.read_csv(
        path,
        sep="_!_",
        # A multi-character separator is handled only by the python engine;
        # requesting it explicitly avoids pandas' fallback ParserWarning.
        engine="python",
        # NOTE(review): pandas documents ``lineterminator`` as C-parser only;
        # kept from the original, confirm whether it has any effect here.
        lineterminator="\n",
        encoding="utf8",
        names=COLUMNS,
    )
    labels = raw["type"] - label_offset
    # between() is False for NaN, so malformed/missing codes are caught too.
    out_of_range = ~labels.between(0, num_labels - 1)
    if out_of_range.any():
        raise ValueError(
            f"{int(out_of_range.sum())} row(s) have labels outside "
            f"[0, {num_labels}) after subtracting {label_offset}"
        )
    # Build a fresh frame instead of assigning into a column subset, which
    # would trigger pandas' SettingWithCopyWarning.
    return pd.DataFrame({"text": raw["text"], "labels": labels})


def split_dataset(df, test_size=TEST_SIZE, random_state=None):
    """Shuffle *df* and split off the last *test_size* rows as a test set.

    Args:
        df: Frame to split.
        test_size: Number of rows to hold out; must be in ``[1, len(df))``.
        random_state: Seed passed to ``DataFrame.sample``; ``None`` gives a
            different shuffle on every call.

    Returns:
        ``(train_df, test_df)``; together they contain every row exactly once.

    Raises:
        ValueError: if *test_size* is out of range. Plain negative slicing
            would otherwise silently yield an empty train or test frame.
    """
    if not 0 < test_size < len(df):
        raise ValueError(
            f"test_size must be between 1 and {len(df) - 1}, got {test_size}"
        )
    shuffled = df.sample(frac=1, random_state=random_state)
    # iloc makes the positional intent explicit; the shuffled index is unordered.
    return shuffled.iloc[:-test_size], shuffled.iloc[-test_size:]


def build_model(num_labels=NUM_LABELS):
    """Create a simpletransformers BERT classifier from ``bert-base-chinese``."""
    # Imported lazily so the data helpers above can be imported (e.g. for
    # tests) without the transformers stack installed.
    from simpletransformers.classification import ClassificationModel

    return ClassificationModel(
        "bert",
        "bert-base-chinese",
        num_labels=num_labels,
        args={"reprocess_input_data": True, "overwrite_output_dir": True},
    )


def main():
    """Run the full pipeline: load, split, train, evaluate, predict."""
    # ``import sklearn`` alone does not reliably load the ``sklearn.metrics``
    # submodule, so import the metric function directly.
    from sklearn.metrics import accuracy_score

    train_df, test_df = split_dataset(load_dataset(), random_state=SEED)

    model = build_model()
    model.train_model(train_df)

    # The first element of eval_model's return value holds the metrics
    # (per the simpletransformers docs, including the extra ``acc`` metric).
    result = model.eval_model(test_df, acc=accuracy_score)[0]
    print(result)

    # Outside a notebook a bare expression is discarded, so print explicitly.
    print(model.predict(["M2处理器IPad mini7值得期待吗?"]))


if __name__ == "__main__":
    main()