Spaces:
Running
Running
import csv | |
import random | |
import pandas as pd | |
import gradio as gr | |
from utils import clean_dir, TMP_DIR, EN_US | |
# Chinese UI label -> English label; _L() looks labels up here when EN_US is set.
ZH2EN = dict(
    [
        ("输入参与者数量", "Number of participants"),
        (
            "输入分组比率 (格式为用:隔开的数字,生成随机分组数据)",
            "Grouping ratio (numbers separated by : to generate randomized controlled trial)",
        ),
        ("状态栏", "Status"),
        ("下载随机分组数据 CSV", "Download data CSV"),
        ("随机分组数据预览", "Data preview"),
    ]
)
def _L(zh_txt: str) -> str:
    """Return the UI label for *zh_txt* in the active interface language.

    The Chinese text is returned unchanged unless ``EN_US`` is truthy, in
    which case its English translation from ``ZH2EN`` is returned. A label
    that has no entry in ``ZH2EN`` falls back to the original text instead of
    raising ``KeyError`` while the interface is being built.
    """
    return ZH2EN.get(zh_txt, zh_txt) if EN_US else zh_txt
def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of row dicts to *filename* as a UTF-8 CSV file.

    The column order is taken from the keys of the first row, and a header
    line is written before the rows. An empty list produces an empty file
    instead of raising ``IndexError``.

    Args:
        list_of_dicts: Rows to write. Each row may only use keys present in
            the first row (``csv.DictWriter`` raises ``ValueError`` otherwise);
            keys missing from a row are written as empty cells.
        filename: Destination path; an existing file is overwritten.
    """
    # newline="" lets the csv module control line endings itself.
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if not list_of_dicts:
            # Nothing to describe: leave a valid, empty CSV file.
            return
        writer = csv.DictWriter(csvfile, fieldnames=list(list_of_dicts[0]))
        writer.writeheader()
        writer.writerows(list_of_dicts)
def random_allocate(participants: int, ratio: list, out_csv: str, seed=None):
    """Randomly assign participants ``1..participants`` to groups by ratio.

    Group ``k`` (1-based) receives a share of the participants proportional
    to ``ratio[k - 1]``. Group boundaries are rounded from the *cumulative*
    weights, so every group size is within one participant of its exact
    share. The rounding remainder is spread across groups instead of all of
    it landing in the last group.

    Args:
        participants: Number of participants; must be at least 1.
        ratio: Positive, finite group weights, e.g. ``[8, 1, 1]``.
        out_csv: Path the allocation is written to as CSV (columns ``id`` and
            ``group``, rows sorted by ``id``).
        seed: Optional int seed for a reproducible allocation. ``None`` uses
            the shared module-level ``random`` generator.

    Returns:
        ``(out_csv, DataFrame)``; the DataFrame holds the same rows that were
        written to the CSV.

    Raises:
        ValueError: If ``participants`` is less than 1, or if ``ratio`` is
            empty or contains a non-positive or non-finite weight.
    """
    from itertools import accumulate

    if participants < 1:
        raise ValueError("participants must be at least 1")
    # The chained comparison also rejects NaN and infinity.
    if not ratio or not all(0 < r < float("inf") for r in ratio):
        raise ValueError("ratio must be a non-empty list of positive numbers")

    # Cumulative weights give the boundary positions; total = final cumulative sum.
    bounds = list(accumulate(ratio))
    total = bounds[-1]
    splits = [0] + [round(b * participants / total) for b in bounds]
    # Guard against float drift so the last group always ends at the final id.
    splits[-1] = participants

    rng = random if seed is None else random.Random(seed)
    order = list(range(1, participants + 1))
    rng.shuffle(order)

    allocation = [
        {"id": pid, "group": group}
        for group, (start, end) in enumerate(zip(splits, splits[1:]), start=1)
        for pid in order[start:end]
    ]
    allocation.sort(key=lambda row: row["id"])
    list_to_csv(allocation, out_csv)
    return out_csv, pd.DataFrame(allocation)
def infer(participants: float, ratios: str, cache=f"{TMP_DIR}/rct"):
    """Gradio handler: parse the UI inputs and run the random allocation.

    Args:
        participants: Participant count from the ``gr.Number`` input. It is
            truncated to ``int``; an empty field (``None``) or a count below 1
            is reported as an error.
        ratios: Group weights separated by ``:``, e.g. ``"8:1:1"``. Weights
            that are not greater than zero are skipped.
        cache: Directory that is cleaned before ``output.csv`` is written
            into it.

    Returns:
        ``(status, out_csv, previews)``. On success this is ``"Success"``, the
        CSV path and a preview DataFrame. On failure it is the error message
        and ``None`` for both outputs. Errors are reported through ``status``
        rather than raised, so the UI can display them.
    """
    status = "Success"
    out_csv = previews = None
    try:
        clean_dir(cache)
        if participants is None:
            raise ValueError("Number of participants is required")
        count = int(participants)
        if count < 1:
            raise ValueError("Number of participants must be at least 1")
        ratio = [w for w in (float(r.strip()) for r in ratios.split(":")) if w > 0]
        if not ratio:
            raise ValueError("Grouping ratio needs at least one positive number")
        out_csv, previews = random_allocate(count, ratio, f"{cache}/output.csv")
    except Exception as e:  # UI boundary: surface every failure in the status box
        status = f"{e}"
    return status, out_csv, previews
def rct_generator():
    """Build the Gradio interface for the random group-allocation tool.

    Inputs are the participant count and the ``:``-separated grouping ratio.
    Outputs are the status text, the downloadable CSV and a table preview,
    in the same order as the ``(status, out_csv, previews)`` tuple that
    ``infer`` returns. Flagging is disabled.
    """
    participant_box = gr.Number(label=_L("输入参与者数量"), value=10)
    ratio_box = gr.Textbox(
        label=_L("输入分组比率 (格式为用:隔开的数字,生成随机分组数据)"),
        value="8:1:1",
    )
    status_box = gr.Textbox(label=_L("状态栏"), show_copy_button=True)
    csv_file = gr.File(label=_L("下载随机分组数据 CSV"))
    preview_table = gr.Dataframe(label=_L("随机分组数据预览"))

    return gr.Interface(
        fn=infer,
        inputs=[participant_box, ratio_box],
        outputs=[status_box, csv_file, preview_table],
        flagging_mode="never",
    )