maiurilorenzo commited on
Commit
e7e4046
·
verified ·
1 Parent(s): bf865b9

Update tools/analyze_data.py

Browse files
Files changed (1) hide show
  1. tools/analyze_data.py +113 -66
tools/analyze_data.py CHANGED
@@ -1,81 +1,128 @@
1
- from smolagents import Tool
2
  import pandas as pd
3
- import seaborn as sns
4
- import matplotlib.pyplot as plt
5
- from io import BytesIO
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- class DataSummaryTool(Tool):
9
- name = "data_summary"
10
- description = "Provides a summary of the dataset."
11
- inputs = {'df': {'type': 'dataframe', 'description': 'The dataset to analyze.'}}
12
- output_type = "dict"
13
-
14
- def __init__(self, *args, **kwargs):
15
- self.is_initialized = False
16
-
17
- def forward(self, df: pd.DataFrame) -> dict:
18
  return {
19
  "num_rows": df.shape[0],
20
  "num_columns": df.shape[1],
21
  "preview": df.head().to_dict()
22
  }
23
-
24
- class MissingValuesTool(Tool):
25
- name = "missing_values"
26
- description = "Analyzes missing values in the dataset."
27
- inputs = {'df': {'type': 'dataframe', 'description': 'The dataset to analyze.'}}
28
- output_type = "dict"
29
 
30
- def forward(self, df: pd.DataFrame) -> dict:
31
- missing_values = df.isnull().sum()
32
- missing_percentage = (missing_values / len(df)) * 100
33
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return {
35
- "missing_values": missing_values.to_dict(),
36
- "missing_percentage": missing_percentage.to_dict()
37
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- class DuplicatesDetectorTool(Tool):
40
- name = "detect_duplicates"
41
- description = "Detects duplicate rows in the dataset."
42
- inputs = {'df': {'type': 'dataframe', 'description': 'The dataset to analyze.'}}
43
- output_type = "dict"
44
-
45
- def forward(self, df: pd.DataFrame) -> dict:
46
- duplicate_count = df.duplicated().sum()
47
- return {"duplicate_count": duplicate_count}
48
-
49
- class DataStatisticsTool(Tool):
50
- name = "data_statistics"
51
- description = "Provides basic statistics for numerical columns and lists column data types."
52
- inputs = {'df': {'type': 'dataframe', 'description': 'The dataset to analyze.'}}
53
- output_type = "dict"
54
-
55
- def forward(self, df: pd.DataFrame) -> dict:
56
  return {
57
- "data_types": df.dtypes.astype(str).to_dict(),
58
- "statistics": df.describe().to_dict()
59
  }
 
 
60
 
61
- class CorrelationMatrixTool(Tool):
62
- name = "correlation_matrix"
63
- description = "Generates a correlation matrix heatmap for numerical columns."
64
- inputs = {'df': {'type': 'dataframe', 'description': 'The dataset to analyze.'}}
65
- output_type = "bytes"
66
-
67
- def forward(self, df: pd.DataFrame) -> BytesIO:
68
- numeric_df = df.select_dtypes(include=["number"])
69
-
70
- if numeric_df.shape[1] < 2:
71
- raise ValueError("Not enough numerical columns for correlation analysis.")
72
-
73
- plt.figure(figsize=(10, 6))
74
- sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', fmt='.2f')
75
- plt.title("Correlation Matrix")
76
-
77
- img_bytes = BytesIO()
78
- plt.savefig(img_bytes, format='png')
79
- plt.close()
80
- img_bytes.seek(0)
81
- return img_bytes
 
1
+ from smolagents import tool
2
  import pandas as pd
 
 
 
3
 
4
+ @tool
5
+ def read_data(file_path: str) -> pd.DataFrame:
6
+ """A tool that reads an Excel or CSV file from a given path and returns a pandas DataFrame.
7
+ Args:
8
+ file_path: The path to the Excel (.xlsx) or CSV (.csv) file.
9
+ Returns:
10
+ A pandas DataFrame containing the data from the file.
11
+ """
12
+ try:
13
+ if file_path.endswith('.csv'):
14
+ df = pd.read_csv(file_path)
15
+ elif file_path.endswith('.xls'):
16
+ df = pd.read_excel(file_path)
17
+ else:
18
+ raise f"Unsupported file extension: {file_path}"
19
+
20
+ return df
21
+ except Exception as e:
22
+ raise Exception(f"Error reading the file: {str(e)}")
23
 
24
+ @tool
25
+ def get_data_summary(df: pd.DataFrame) -> dict:
26
+ """A tool that gives a summary of the data.
27
+ Args:
28
+ df: A pandas DataFrame.
29
+ Returns: A dictionary containing the number of rows and columns in the DataFrame, and a preview of the first few rows.
30
+ """
31
+ try:
 
 
32
  return {
33
  "num_rows": df.shape[0],
34
  "num_columns": df.shape[1],
35
  "preview": df.head().to_dict()
36
  }
37
+ except Exception as e:
38
+ raise Exception(f"Error in analyzing the dataset: {str(e)}")
 
 
 
 
39
 
40
+ import pandas as pd
41
+
42
+ @tool
43
+ def get_dataframe_statistics(data: dict) -> dict:
44
+ """A tool that calculates statistical summaries of a pandas DataFrame.
45
+
46
+ Args:
47
+ data: A dictionary where keys are column names and values are lists of column values.
48
+
49
+ Returns:
50
+ A dictionary containing summary statistics such as mean, median, standard deviation,
51
+ and count for numerical columns.
52
+ """
53
+ try:
54
+ # Convert input dictionary to DataFrame
55
+ df = pd.DataFrame(data)
56
+
57
+ # Generate summary statistics
58
+ stats = df.describe().to_dict()
59
+
60
+ # Convert NaN values to None for JSON compatibility
61
+ for col, col_stats in stats.items():
62
+ stats[col] = {key: (None if pd.isna(value) else value) for key, value in col_stats.items()}
63
+
64
+ return stats
65
+ except Exception as e:
66
+ raise Exception(f"error: {str(e)}")
67
+
68
+ @tool
69
+ def get_missing_values(data: dict) -> dict:
70
+ """A tool that calculates the number and percentage of missing values in a pandas DataFrame.
71
+
72
+ Args:
73
+ data: A dictionary where keys are column names and values are lists of column values.
74
+
75
+ Returns:
76
+ A dictionary with column names as keys and missing value statistics (count and percentage).
77
+ """
78
+ try:
79
+ df = pd.DataFrame(data)
80
+ missing_count = df.isnull().sum()
81
+ missing_percentage = (missing_count / len(df)) * 100
82
+
83
  return {
84
+ col: {"missing_count": int(missing_count[col]), "missing_percentage": missing_percentage[col]}
85
+ for col in df.columns
86
  }
87
+ except Exception as e:
88
+ return {"error": str(e)}
89
+
90
+ @tool
91
+ def get_duplicate_rows(data: dict) -> dict:
92
+ """A tool that finds duplicate rows in a pandas DataFrame.
93
+
94
+ Args:
95
+ data: A dictionary where keys are column names and values are lists of column values.
96
+
97
+ Returns:
98
+ A dictionary with the number of duplicate rows and sample duplicate rows.
99
+ """
100
+ try:
101
+ df = pd.DataFrame(data)
102
+ duplicates = df[df.duplicated(keep=False)]
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  return {
105
+ "duplicate_count": int(df.duplicated().sum()),
106
+ "duplicate_rows": duplicates.to_dict(orient="records"),
107
  }
108
+ except Exception as e:
109
+ return {"error": str(e)}
110
 
111
+ @tool
112
+ def get_correlation_matrix(data: dict) -> dict:
113
+ """A tool that calculates the correlation matrix for numerical columns in a pandas DataFrame.
114
+
115
+ Args:
116
+ data: A dictionary where keys are column names and values are lists of column values.
117
+
118
+ Returns:
119
+ A dictionary representing the correlation matrix.
120
+ """
121
+ try:
122
+ df = pd.DataFrame(data)
123
+ correlation_matrix = df.corr().to_dict()
124
+
125
+ return correlation_matrix
126
+ except Exception as e:
127
+ return {"error": str(e)}
128
+