File size: 2,802 Bytes
7cf86e5
 
 
 
35429ce
7cf86e5
 
 
9909b2f
7cf86e5
 
 
 
 
 
 
35429ce
7cf86e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9909b2f
7cf86e5
 
 
 
 
 
9909b2f
7cf86e5
 
 
 
 
 
 
 
9909b2f
 
 
7cf86e5
 
 
 
 
 
 
 
 
 
 
 
b0b394a
9909b2f
b0b394a
 
7cf86e5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import csv
import random
import pandas as pd
import gradio as gr
from utils import clean_dir, TMP_DIR, EN_US

# Chinese UI label -> English UI label. Looked up by _L() when EN_US is truthy;
# every label passed to _L() needs an entry here, otherwise the lookup raises
# KeyError.
ZH2EN = {
    "输入参与者数量": "Number of participants",
    "输入分组比率 (格式为用:隔开的数字,生成随机分组数据)": "Grouping ratio (numbers separated by : to generate randomized controlled trial)",
    "状态栏": "Status",
    "下载随机分组数据 CSV": "Download data CSV",
    "随机分组数据预览": "Data preview",
}


def _L(zh_txt: str):
    """Return the label text to display for the Chinese label ``zh_txt``.

    When ``EN_US`` is truthy the English text is looked up in ``ZH2EN``
    (a missing key raises ``KeyError``); otherwise ``zh_txt`` is returned
    unchanged.
    """
    if EN_US:
        return ZH2EN[zh_txt]

    return zh_txt


def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of dictionaries to ``filename`` as a UTF-8 CSV file.

    The header row is taken from the keys of the first dictionary, so every
    row is expected to use that same set of keys (``csv.DictWriter`` raises
    ``ValueError`` for a row that carries a key missing from the header).

    Args:
        list_of_dicts: Rows to write, one dictionary per CSV row. May be
            empty, in which case an empty file is created.
        filename: Path of the CSV file to create or overwrite.
    """
    rows = list(list_of_dicts)
    # newline="" lets the csv module control line endings itself.
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if not rows:
            # No first row to derive a header from: leave the file empty
            # instead of failing with an opaque IndexError.
            return

        writer = csv.DictWriter(csvfile, fieldnames=list(dict(rows[0]).keys()))
        writer.writeheader()
        writer.writerows(rows)


def random_allocate(participants: int, ratio: list, out_csv: str):
    """Randomly assign participants ``1..participants`` to groups by ratio.

    Participant ids are shuffled and cut into consecutive slices whose sizes
    follow ``ratio``; group numbers start at 1 in the order the ratios are
    given. The result, sorted by participant id, is written to ``out_csv``
    and also returned as a DataFrame.

    Args:
        participants: Number of participants; must be at least 1.
        ratio: Non-negative relative group weights (e.g. ``[8, 1, 1]``) with
            a positive sum.
        out_csv: Path of the CSV file to write (columns ``id``, ``group``).

    Returns:
        A tuple ``(out_csv, dataframe)`` where ``dataframe`` holds the same
        ``id``/``group`` rows that were written to the file.

    Raises:
        ValueError: If ``participants`` is below 1, or ``ratio`` is empty,
            contains a negative weight, or does not sum to a positive number.
    """
    if participants < 1:
        raise ValueError("participants must be a positive integer")

    if not ratio or any(r < 0 for r in ratio):
        raise ValueError("ratio must be a non-empty list of non-negative numbers")

    total = sum(ratio)
    if not total > 0:  # also rejects NaN
        raise ValueError("ratio must sum to a positive number")

    # Floor the *cumulative* share rather than each group's own share, so the
    # rounding remainder is spread across groups instead of all landing in
    # the last one (7 equal groups of 13 people: 1,2,2,2,2,2,2 not 1,...,1,7).
    boundaries = [0]
    running = 0
    for r in ratio:
        running += r
        # The small epsilon absorbs float fuzz such as 2.9999999999999996.
        cut = int(running * participants / total + 1e-9)
        boundaries.append(min(cut, participants))

    # Guarantee every participant is allocated regardless of rounding.
    boundaries[-1] = participants

    shuffled_ids = list(range(1, participants + 1))
    random.shuffle(shuffled_ids)

    allocation = [
        {"id": pid, "group": group_no}
        for group_no, (start, end) in enumerate(
            zip(boundaries, boundaries[1:]), start=1
        )
        for pid in shuffled_ids[start:end]
    ]
    allocation.sort(key=lambda row: row["id"])

    list_to_csv(allocation, out_csv)
    return out_csv, pd.DataFrame(allocation)


# outer func
def infer(participants: float, ratios: str, cache=f"{TMP_DIR}/rct"):
    """Parse the UI inputs and produce a random group allocation.

    Args:
        participants: Number of participants; truncated with ``int()``.
        ratios: Colon-separated group weights such as ``"8:1:1"``. Entries
            that are not greater than zero are dropped.
        cache: Directory that receives ``output.csv``. It is passed to
            ``clean_dir`` first (presumably to reset it; confirm in utils).

    Returns:
        A tuple ``(status, csv_path, dataframe)``. On success ``status`` is
        ``"Success"``; on any exception ``status`` is the exception text and
        the other two items are ``None``.
    """
    try:
        tokens = ratios.split(":")
        clean_dir(cache)
        parsed = (float(token.strip()) for token in tokens)
        ratio = [weight for weight in parsed if weight > 0]
        out_csv, previews = random_allocate(
            int(participants), ratio, f"{cache}/output.csv"
        )

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return f"{e}", None, None

    return "Success", out_csv, previews


def rct_generator():
    """Build the Gradio interface that wraps ``infer``.

    Inputs are a participant count (default 10) and a ratio string (default
    ``"8:1:1"``); outputs are a status textbox, the downloadable CSV file and
    a dataframe preview. Labels are localized through ``_L``.
    """
    participants_input = gr.Number(label=_L("输入参与者数量"), value=10)
    ratio_input = gr.Textbox(
        label=_L("输入分组比率 (格式为用:隔开的数字,生成随机分组数据)"),
        value="8:1:1",
    )

    status_output = gr.Textbox(label=_L("状态栏"), show_copy_button=True)
    csv_output = gr.File(label=_L("下载随机分组数据 CSV"))
    preview_output = gr.Dataframe(label=_L("随机分组数据预览"))

    return gr.Interface(
        fn=infer,
        inputs=[participants_input, ratio_input],
        outputs=[status_output, csv_output, preview_output],
        flagging_mode="never",
    )