AjithKSenthil commited on
Commit
c77f143
·
1 Parent(s): 4247f5a

Upload ChatAttachmentAnalysisWithXG.py

Browse files
Files changed (1) hide show
  1. ChatAttachmentAnalysisWithXG.py +27 -0
ChatAttachmentAnalysisWithXG.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from sklearn.multioutput import MultiOutputRegressor
4
+ import xgboost as xgb
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
7
+
8
+ datafile_path = "data/chat_transcripts_with_embeddings.csv"
9
+
10
+ df = pd.read_csv(datafile_path)
11
+ df["embedding"] = df.embedding.apply(eval).apply(np.array)
12
+
13
+ X = np.array(df.embedding.tolist())
14
+ y = df[["Attachment", "Avoidance"]]
15
+
16
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
17
+
18
+ xg_reg = xgb.XGBRegressor(objective ='reg:squarederror', colsample_bytree = 0.3, learning_rate = 0.1, max_depth = 5, alpha = 10, n_estimators = 10)
19
+
20
+ multioutputregressor = MultiOutputRegressor(xg_reg).fit(X_train, y_train)
21
+
22
+ preds = multioutputregressor.predict(X_test)
23
+
24
+ mse = mean_squared_error(y_test, preds)
25
+ mae = mean_absolute_error(y_test, preds)
26
+
27
+ print(f"ada-002 embedding performance on chat transcripts: mse={mse:.2f}, mae={mae:.2f}")