# Hugging Face Spaces page header (status: Running)
import os | |
import csv | |
import random | |
import shutil | |
import pandas as pd | |
import gradio as gr | |
# Output directory for the generated CSV (random_allocation writes
# output.csv here); inference() deletes and recreates it on every call.
DATA_DIR = "./data"
def list_to_csv(list_of_dicts: list, filename: str):
    """Write ``list_of_dicts`` to ``filename`` as a UTF-8 CSV and return the path.

    The header is the union of all rows' keys in first-seen order; a row that
    lacks a column gets an empty cell. An empty list produces an empty file
    (no header) instead of raising ``IndexError``.

    Args:
        list_of_dicts: rows to write, one mapping per CSV row.
        filename: destination path; an existing file is overwritten.

    Returns:
        ``filename``, unchanged.
    """
    # Collect columns from every row (not only the first) so DictWriter does
    # not reject rows carrying extra keys, and so [] needs no index access.
    fieldnames = list(dict.fromkeys(key for row in list_of_dicts for key in row))
    # newline="" lets the csv module control line endings itself.
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if fieldnames:
            writer.writeheader()
        writer.writerows(list_of_dicts)
    return filename
def _group_boundaries(participants: int, ratio: list) -> list:
    """Return slice bounds ``[0, b1, ..., participants]`` for ``len(ratio)`` groups.

    Raises:
        ValueError: if ``ratio`` is empty, has a negative entry, or does not
            sum to a positive number.
    """
    if not ratio:
        raise ValueError("ratio must contain at least one value")
    if any(r < 0 for r in ratio):
        raise ValueError("ratio values must be non-negative")
    total = sum(ratio)
    # `not total > 0` also rejects a NaN total.
    if not total > 0:
        raise ValueError("ratio must contain at least one positive value")
    boundaries = [0]
    running = 0
    for r in ratio:
        running += r
        # Floor the cumulative share, not each group's own share: per-group
        # truncation pushed every leftover participant into the last group
        # (1:1:1 with 11 participants gave 3/3/5); this gives 3/4/4 and keeps
        # every group within one participant of its exact share.
        boundaries.append(int(running * participants / total))
    # Pin the end exactly so float error can never drop the last participant.
    boundaries[-1] = participants
    return boundaries


def random_allocation(participants: int, ratio: list, rng=None):
    """Randomly allocate participants ``1..participants`` to groups by ratio.

    Participant IDs are shuffled and cut into consecutive slices whose sizes
    follow ``ratio``; group numbers start at 1 in the order of ``ratio``, and a
    zero weight yields an empty group. The result is also written to
    ``{DATA_DIR}/output.csv``.

    Args:
        participants: number of participants; must be >= 1.
        ratio: relative group weights, e.g. ``[8, 1, 1]``.
        rng: optional object with a ``shuffle`` method (e.g.
            ``random.Random(seed)``) for reproducible allocations; defaults
            to the module-level ``random`` generator.

    Returns:
        ``(csv_path, DataFrame)`` with one ``{"id", "group"}`` row per
        participant, sorted by ``id``.

    Raises:
        ValueError: if ``participants`` < 1 or ``ratio`` is invalid.
    """
    if participants < 1:
        raise ValueError("participants must be at least 1")
    boundaries = _group_boundaries(participants, ratio)

    ids = list(range(1, participants + 1))
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(ids)

    allocation = []
    for group, (start, end) in enumerate(zip(boundaries, boundaries[1:]), start=1):
        for participant in ids[start:end]:
            allocation.append({"id": participant, "group": group})
    allocation.sort(key=lambda row: row["id"])

    filename = list_to_csv(allocation, f"{DATA_DIR}/output.csv")
    return filename, pd.DataFrame(allocation)
def _parse_ratios(ratios: str) -> list:
    """Parse a colon-separated ratio string such as ``"8:1:1"`` into floats.

    Both ASCII ``:`` and full-width ``：`` separators are accepted, since the
    UI invites Chinese-language input. Zero entries are kept (as empty groups)
    so group numbers always match the positions the user typed.

    Raises:
        gr.Error: if an entry is not a finite non-negative number, or no entry
            is positive; Gradio shows the message to the user.
    """
    values = []
    for part in (ratios or "").replace("：", ":").split(":"):
        try:
            value = float(part.strip())
        except ValueError:
            raise gr.Error(
                f"分组比率格式无效 (Invalid input of ratio!): {part!r}"
            ) from None
        # A single chained comparison rejects negatives, NaN and infinity.
        if not 0 <= value < float("inf"):
            raise gr.Error(
                f"分组比率不能为负数 (Ratio values must be non-negative): {part!r}"
            )
        values.append(value)
    if not any(v > 0 for v in values):
        raise gr.Error("分组比率至少需要一个正数 (At least one ratio must be positive)")
    return values


def inference(participants: float, ratios: str):
    """Gradio handler: validate inputs, reset ``DATA_DIR``, then allocate.

    Args:
        participants: participant count from the ``gr.Number`` input; must be
            a whole number >= 1.
        ratios: colon-separated group ratio, e.g. ``"8:1:1"``.

    Returns:
        The ``(csv_path, DataFrame)`` pair produced by ``random_allocation``.

    Raises:
        gr.Error: on invalid input, instead of silently allocating with a
            partially parsed ratio or a truncated participant count.
    """
    if (
        participants is None
        or not float(participants).is_integer()
        or participants < 1
    ):
        raise gr.Error(
            "参与者数量必须是正整数 (Number of participants must be a positive integer)"
        )
    ratio = _parse_ratios(ratios)
    # Validate before wiping DATA_DIR so a bad request keeps the previous output.
    # NOTE(review): the shared directory is still reset on every call, so
    # concurrent requests could delete each other's CSV.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)
    return random_allocation(int(participants), ratio)
if __name__ == "__main__":
    # Build and launch the web UI. Inputs map positionally onto
    # inference(participants, ratios); outputs onto its (csv_path, DataFrame).
    gr.Interface(
        fn=inference,
        inputs=[
            gr.Number(
                label="输入参与者数量 (Number of participants)",
                value=10,
                precision=0,  # deliver a whole number rather than a float
            ),
            gr.Textbox(label="输入分组比率 (Grouping ratio)", value="8:1:1"),
        ],
        outputs=[
            gr.File(label="下载随机分组数据 CSV (Download data CSV)"),
            gr.Dataframe(label="随机分组数据预览 (Data preview)"),
        ],
        title="随机对照试验随机数生成器<br>Randomized Controlled Trial Generator",
        description="输入参与者数量和分组比率,格式为用:隔开的数字,生成随机分组数据。<br>Enter the number of participants and the grouping ratio in the format of numbers separated by : to generate randomized grouping data.",
        # Interface takes a string mode here; the boolean form is deprecated.
        allow_flagging="never",
    ).launch()