"""Gradio app that randomly allocates trial participants to groups by ratio."""

import os
import csv
import random
import shutil

import pandas as pd
import gradio as gr

DATA_DIR = "./data"


def list_to_csv(list_of_dicts: list, filename: str, fieldnames=None):
    """Write a list of dicts to ``filename`` as a UTF-8 CSV file.

    Args:
        list_of_dicts: Rows to write, one mapping of column name to value
            per row.
        filename: Destination path; an existing file is overwritten.
        fieldnames: Optional column names (and order) for the header. When
            omitted, the keys of the first row are used.

    Returns:
        ``filename``, unchanged.
    """
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if fieldnames is None:
            if not list_of_dicts:
                # No rows and no declared columns: there is no header to
                # derive, so leave an empty file instead of raising.
                return filename
            fieldnames = dict(list_of_dicts[0]).keys()

        # Write the dicts in the list to the CSV file.
        writer = csv.DictWriter(csvfile, fieldnames=list(fieldnames))
        writer.writeheader()
        writer.writerows(list_of_dicts)

    return filename


def _group_sizes(participants: int, ratio: list) -> list:
    """Return per-group head counts proportional to ``ratio``.

    Uses largest-remainder rounding: every group first gets the integer part
    of its exact share, then the seats still unassigned go one by one to the
    groups with the largest fractional parts (earlier groups win ties). The
    sizes always sum to ``participants``.
    """
    total = sum(ratio)
    quotas = [r * participants / total for r in ratio]
    sizes = [int(q) for q in quotas]
    leftover = participants - sum(sizes)

    # sorted() is stable even with reverse=True, so ties keep group order.
    by_remainder = sorted(
        range(len(ratio)),
        key=lambda i: quotas[i] - sizes[i],
        reverse=True,
    )
    for k in range(leftover):
        # The modulo only matters if float error ever left more seats than
        # groups; it guarantees that every participant is still placed.
        sizes[by_remainder[k % len(by_remainder)]] += 1
    return sizes


def random_allocation(participants: int, ratio: list):
    """Randomly assign participants ``1..participants`` to groups.

    Args:
        participants: Number of participants; ids run from 1 to this value.
        ratio: Relative group sizes, e.g. ``[8, 1, 1]``. Groups are numbered
            from 1 in the order given.

    Returns:
        A tuple ``(filename, dataframe)``: the path of the CSV written to
        ``DATA_DIR/output.csv`` and a ``pandas.DataFrame`` with the same
        rows, both sorted by ``id`` with columns ``id`` and ``group``.

    Raises:
        ValueError: If ``participants`` is negative, or ``ratio`` is empty,
            contains a negative number, or does not have a positive, finite
            sum.
    """
    if participants < 0:
        raise ValueError("participants must not be negative")
    if (
        not ratio
        or any(r < 0 for r in ratio)
        or not 0 < sum(ratio) < float("inf")
    ):
        raise ValueError(
            "ratio must hold non-negative numbers with a positive, finite sum"
        )

    sizes = _group_sizes(participants, ratio)

    shuffled = list(range(1, participants + 1))
    random.shuffle(shuffled)

    # Cut the shuffled ids into consecutive slices, one slice per group.
    allocation = []
    start = 0
    for group, size in enumerate(sizes, start=1):
        allocation.extend(
            {"id": pid, "group": group}
            for pid in shuffled[start:start + size]
        )
        start += size
    allocation.sort(key=lambda row: row["id"])

    columns = ["id", "group"]
    # Make sure the output directory exists when called without inference().
    os.makedirs(DATA_DIR, exist_ok=True)
    filename = list_to_csv(
        allocation,
        f"{DATA_DIR}/output.csv",
        fieldnames=columns,
    )
    # Passing the columns explicitly keeps the header for an empty result.
    return filename, pd.DataFrame(allocation, columns=columns)


def inference(participants: float, ratios: str):
    """Gradio callback: parse the inputs and run the random allocation.

    Args:
        participants: Number of participants from the number input; it is
            truncated to an integer.
        ratios: Group ratio as numbers separated by ``:``, e.g. ``"8:1:1"``.
            Entries that are not positive, finite numbers are dropped.

    Returns:
        The ``(filename, dataframe)`` tuple from ``random_allocation``.

    Raises:
        gr.Error: If the participant count is missing, not a finite number
            or negative, or if the ratio cannot be parsed or contains no
            positive entry.
    """
    try:
        count = int(participants)
    except (TypeError, ValueError, OverflowError):
        # None (empty field), NaN and infinity all end up here.
        raise gr.Error("Invalid number of participants!") from None
    if count < 0:
        raise gr.Error("Invalid number of participants!")

    try:
        parsed = [float(part.strip()) for part in ratios.split(":")]
    except (AttributeError, ValueError):
        print("Invalid input of ratio!")
        raise gr.Error("Invalid input of ratio!") from None

    # Keep only positive, finite ratios; the comparison also rejects NaN.
    ratio = [value for value in parsed if 0 < value < float("inf")]
    if not ratio:
        print("Invalid input of ratio!")
        raise gr.Error("Invalid input of ratio!")

    # Start from a clean output directory, but only once the input is valid.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)

    os.makedirs(DATA_DIR, exist_ok=True)
    return random_allocation(count, ratio)


if __name__ == "__main__":
    gr.Interface(
        fn=inference,
        inputs=[
            gr.Number(
                label="输入参与者数量 (Number of participants)",
                value=10,
            ),
            gr.Textbox(label="输入分组比率 (Grouping ratio)", value="8:1:1"),
        ],
        outputs=[
            gr.components.File(label="下载随机分组数据 CSV (Download data CSV)"),
            gr.Dataframe(label="随机分组数据预览 (Data preview)"),
        ],
        title="随机对照试验随机数生成器\nRandomized Controlled Trial Generator",
        description=(
            "输入参与者数量和分组比率,格式为用:隔开的数字,生成随机分组数据。\n"
            "Enter the number of participants and the grouping ratio in the format "
            "of numbers separated by : to generate randomized grouping data."
        ),
        flagging_mode="never",
    ).launch()