File size: 1,321 Bytes
e1aaaac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os

import pandas as pd
import argparse


def main(input_file: str, output_file: str = "pivoted.csv") -> None:
    """Pivot a long-format results CSV into one accuracy column per dataset.

    The input CSV must contain the columns ``acc1``, ``dataset``, ``model``,
    ``pretrained``, ``attack``, ``eps`` and ``iterations_adv``. ``acc1`` is
    multiplied by 100 and rounded to two decimals, and a leading
    ``wds/vtab/`` or ``wds/`` prefix is stripped from ``dataset``.

    The pivoted table is written to ``output_file`` and also printed to
    stdout, both as a DataFrame and as CSV text.

    Args:
        input_file: Path of the CSV file to read.
        output_file: Path the pivoted CSV is written to. Defaults to
            ``"pivoted.csv"`` in the current working directory.
    """
    # Read the original CSV
    df = pd.read_csv(input_file)
    df['acc1'] = (df['acc1'] * 100).round(2)
    df['dataset'] = df['dataset'].str.replace(r'^(wds/vtab/|wds/)', '', regex=True)

    # Each distinct dataset value becomes its own column
    columns_to_pivot = ["dataset"]

    # Each distinct combination of these columns becomes one output row
    index_columns = ["model", "pretrained", "attack", "eps", "iterations_adv"]

    # pivot_table aggregates with the mean by default, so duplicate
    # (index, dataset) entries are averaged rather than raising an error.
    df_pivot = df.pivot_table(values="acc1", index=index_columns, columns=columns_to_pivot).reset_index()
    del df

    # Save the pivoted DataFrame as a new CSV
    df_pivot.to_csv(output_file, index=False)
    print(df_pivot, "\n")
    print(df_pivot.to_csv(index=False))
    print(f"Pivoted CSV saved as {output_file}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pivot a CSV file.")
    # nargs="?" makes the positional argument optional. Without it argparse
    # treats the argument as required and exits with an error when it is
    # missing, so the interactive prompt below could never be reached.
    parser.add_argument(
        "input_file",
        type=str,
        nargs="?",
        default=None,
        help="The input CSV file to be pivoted.",
    )
    args = parser.parse_args()

    # Fall back to an interactive prompt when no path was given on the command line.
    input_file = args.input_file or input("enter input file: ")

    main(input_file)