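"""Evaluate a fine-tuned BERTBaseUncased model on a test CSV.

Writes the predictions to '<model-dir>.csv' and, when --features is set,
plots a PCA + t-SNE projection of the extracted features and saves it as
'<model-dir>-figure.svg'.
"""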
import datetime
import random

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from absl import app, flags, logging
from loguru import logger
from scipy import stats
from sklearn import metrics, model_selection
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.tensorboard import SummaryWriter

import config
import dataset
import engine
from model import BERTBaseUncased
from utils import categorical_accuracy, label_decoder, label_encoder

matplotlib.rcParams['interactive'] = True
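
# Fix all random seeds so evaluation runs are reproducible.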
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
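
# TensorBoard summary writer and a loguru file sink (experiment.log).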
writer = SummaryWriter()
logger.add("experiment.log")

flags.DEFINE_boolean('features', True,
                     "Extract features during prediction and plot a t-SNE projection.")
flags.DEFINE_string('test_file', None,
                    "Test CSV to evaluate (default: config.DATASET_LOCATION + 'eval.prep.test.csv').")
flags.DEFINE_string('model_path', None,
                    "Model checkpoint to load (default: config.MODEL_PATH).")
FLAGS = flags.FLAGS


def main(_):
    test_file = config.DATASET_LOCATION + "eval.prep.test.csv"
    model_path = config.MODEL_PATH
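    # Command-line flags override the defaults taken from config.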
    if FLAGS.test_file:
        test_file = FLAGS.test_file
    if FLAGS.model_path:
        model_path = FLAGS.model_path

    df_test = pd.read_csv(test_file).fillna("none")
    # Gold labels are only needed (and encoded) when feature plotting is enabled.
    if FLAGS.features:
        df_test.label = df_test.label.apply(label_encoder)

logger.info(f"Bert Model: {config.BERT_PATH}") | |
logger.info( | |
f"Current date and time :{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ") | |
logger.info(f"Test file: {test_file}") | |
logger.info(f"Test size : {len(df_test):.4f}") | |
    test_dataset = dataset.BERTDataset(
        review=df_test.text.values,
        target=df_test.label.values
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3
    )

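    # Instantiate the model and load the trained weights onto the configured device.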
    device = config.device
    model = BERTBaseUncased()
    model.load_state_dict(torch.load(
        model_path, map_location=torch.device(device)))
    model.to(device)

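    # Run inference; with --features, engine.predict_fn also returns the feature
    # vectors that feed the t-SNE plot below.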
    outputs, extracted_features = engine.predict_fn(
        test_data_loader, model, device, extract_features=FLAGS.features)
    df_test["predicted"] = outputs

    # Save predictions (no header/index), named after the model directory.
    df_test.to_csv(model_path.split(
        "/")[-2] + '.csv', header=False, index=False)

    if FLAGS.features:
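        # Project the extracted features to 2D: PCA down to 50 components first,
        # then t-SNE for the final embedding.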
        pca = PCA(n_components=50, random_state=7)
        X1 = pca.fit_transform(extracted_features)
        tsne = TSNE(n_components=2, perplexity=10, random_state=6,
                    learning_rate=1000, n_iter=1500)
        X1 = tsne.fit_transform(X1)
        # if row == 0: print("Shape after t-SNE: ", X1.shape)
        X = pd.DataFrame(np.concatenate([X1], axis=1),
                         columns=["x1", "y1"])
        X = X.astype({"x1": float, "y1": float})

        # Plot for layer -1
        plt.figure(figsize=(20, 15))
        p1 = sns.scatterplot(x=X["x1"], y=X["y1"], palette="coolwarm")
        # p1.set_title("development-"+str(row+1)+", layer -1")

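        # Build one text label per point: correct predictions get an '@' prefix plus
        # the first letter of the decoded label; errors become '<gold>-<predicted>'.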
        x_texts = []
        for output, value in zip(outputs, df_test.label.values):
            if output == value:
                x_texts.append(
                    "@" + label_decoder(output)[0] + label_decoder(output))
            else:
                x_texts.append(
                    label_decoder(value) + "-" + label_decoder(output))
        X["texts"] = x_texts
        # X["texts"] = ["@G" + label_decoder(output) if output == value else "@R-" + label_decoder(value) + "-" + label_decoder(output)
        #               for output, value in zip(outputs, df_test.label.values)]
        # df_test.label.astype(str)
        # ([str(output)+"-" + str(value)] for output, value in zip(outputs, df_test.label.values))

        # Label each datapoint with the word it corresponds to
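        # The first letter of the decoded label picks the colour: '@U' -> blue,
        # '@P' -> green, '@N' -> red; misclassified points fall through to black.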
        for line in X.index:
            text = X.loc[line, "texts"] + "-" + str(line)
            if "@U" in text:
                p1.text(X.loc[line, "x1"] + 0.2, X.loc[line, "y1"], text[2:],
                        horizontalalignment='left', size='medium',
                        color='blue', weight='semibold')
            elif "@P" in text:
                p1.text(X.loc[line, "x1"] + 0.2, X.loc[line, "y1"], text[2:],
                        horizontalalignment='left', size='medium',
                        color='green', weight='semibold')
            elif "@N" in text:
                p1.text(X.loc[line, "x1"] + 0.2, X.loc[line, "y1"], text[2:],
                        horizontalalignment='left', size='medium',
                        color='red', weight='semibold')
            else:
                p1.text(X.loc[line, "x1"] + 0.2, X.loc[line, "y1"], text,
                        horizontalalignment='left', size='medium',
                        color='black', weight='semibold')

        # Save before plt.show(): closing the interactive window can clear the
        # figure, in which case a later savefig writes a blank file.
        plt.savefig(model_path.split(
            "/")[-2] + '-figure.svg', format="svg")
        plt.show()

    # loocv = model_selection.LeaveOneOut()
    # model = KNeighborsClassifier(n_neighbors=8)
    # results = model_selection.cross_val_score(model, X, Y, cv=loocv)
    # for i, j in outputs, extracted_features:
    #     utils.write_embeddings_to_file(extracted_features, outputs)


if __name__ == "__main__":
    app.run(main)