|
import os |
|
|
|
import pandas as pd |
|
import argparse |
|
|
|
|
|
def main(input_file: str, output_file: str = "pivoted.csv") -> None:
    """Pivot a long-format results CSV so each dataset becomes a column.

    The input is read with ``pandas.read_csv`` and must contain the columns
    ``model``, ``pretrained``, ``attack``, ``eps``, ``iterations_adv``,
    ``dataset`` and ``acc1``. The pivoted table is written to *output_file*
    and also echoed to stdout, both as a DataFrame repr and as CSV text.

    Args:
        input_file: Path of the CSV file to read.
        output_file: Path the pivoted CSV is written to. Defaults to
            ``"pivoted.csv"`` in the current working directory.
    """
    df = pd.read_csv(input_file)

    # Scale acc1 by 100 and keep two decimals
    # (presumably a fraction -> percentage; confirm against the producer).
    df["acc1"] = (df["acc1"] * 100).round(2)

    # Drop the leading "wds/vtab/" or "wds/" prefix so dataset names are short
    # enough to serve as column headers.
    df["dataset"] = df["dataset"].str.replace(
        r'^(wds/vtab/|wds/)', '', regex=True
    )

    columns_to_pivot = ["dataset"]
    index_columns = ["model", "pretrained", "attack", "eps", "iterations_adv"]

    # pivot_table aggregates duplicate (index, dataset) pairs with its default
    # aggfunc (mean); reset_index turns the index levels back into columns.
    df_pivot = df.pivot_table(
        values="acc1",
        index=index_columns,
        columns=columns_to_pivot,
    ).reset_index()

    df_pivot.to_csv(output_file, index=False)

    print(df_pivot, "\n")
    print(df_pivot.to_csv(index=False))
    print(f"Pivoted CSV saved as {output_file}")
|
|
|
if __name__ == "__main__":
    # Command-line entry point: take the input CSV path from argv, or prompt
    # for it interactively when none is given.
    parser = argparse.ArgumentParser(description="Pivot a CSV file.")
    # nargs="?" makes the positional optional. Without it argparse treats the
    # argument as required and exits before the prompt below can ever run.
    parser.add_argument(
        "input_file",
        type=str,
        nargs="?",
        default=None,
        help="The input CSV file to be pivoted.",
    )
    args = parser.parse_args()

    # Fall back to an interactive prompt when no (or an empty) path was given.
    input_file = args.input_file or input("enter input file: ")

    main(input_file)
|
|