File size: 2,620 Bytes
f507537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b907431
f507537
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import csv
import random
import shutil
import pandas as pd
import gradio as gr

DATA_DIR = "./data"


def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of row dictionaries to a UTF-8 CSV file.

    The header row is taken from the keys of the first row, in insertion
    order, and every row is written underneath it. An empty list produces an
    empty file (no header), because there are no keys to derive one from.

    Args:
        list_of_dicts: Rows to write. Every row may only use keys that appear
            in the first row; ``csv.DictWriter`` raises ``ValueError``
            otherwise.
        filename: Destination path; an existing file is overwritten.

    Returns:
        ``filename`` unchanged, so callers can use the call inline.
    """
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        # Only emit a header when there is a first row to take keys from;
        # indexing [0] on an empty list would raise IndexError.
        if list_of_dicts:
            fieldnames = list(dict(list_of_dicts[0]))
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(list_of_dicts)

    return filename


def _apportion(total: int, weights: list) -> list:
    """Split ``total`` into integer counts proportional to ``weights``.

    Largest-remainder (Hamilton) method: each group first receives the floor
    of its exact quota ``total * w / sum(weights)``, then the units left over
    are handed out one each to the groups with the largest fractional
    remainders (ties go to the earlier group). The counts sum to ``total``.
    """
    weight_sum = sum(weights)
    quotas = [total * w / weight_sum for w in weights]
    counts = [int(q) for q in quotas]  # quotas are >= 0, so int() floors
    leftover = total - sum(counts)
    # sorted() stays stable with reverse=True, so equal remainders keep index order.
    order = sorted(range(len(quotas)), key=lambda i: quotas[i] - counts[i], reverse=True)
    for i in order[:leftover]:
        counts[i] += 1
    return counts


def random_allocation(participants: int, ratio: list):
    """Randomly assign participants ``1..participants`` to groups by ratio.

    Group sizes follow ``ratio`` via largest-remainder apportionment, so the
    rounding leftovers are spread over the groups with the largest fractional
    quotas instead of all landing in the last group. The participant IDs are
    shuffled and sliced into consecutive groups of those sizes.

    Args:
        participants: Number of participants; IDs run from 1 to this value.
        ratio: Relative group weights; ``ratio[i]`` is the weight of group
            ``i + 1``. Zero weights yield empty groups.

    Returns:
        ``(filename, dataframe)``: the path of the CSV written to
        ``f"{DATA_DIR}/output.csv"`` (the directory must already exist) and a
        DataFrame with ``id`` and ``group`` columns, sorted by ``id``.

    Raises:
        ValueError: If ``participants`` < 1, or ``ratio`` is empty, contains a
            negative or non-finite weight, or sums to zero.
    """
    if participants < 1:
        raise ValueError("participants must be at least 1")
    # NaN fails both comparisons and inf fails the upper bound, so this check
    # also rejects non-finite weights (which would poison the quotas).
    if not ratio or not all(0 <= r < float("inf") for r in ratio) or sum(ratio) <= 0:
        raise ValueError("ratio must be finite, non-negative weights with a positive sum")

    ids = list(range(1, participants + 1))
    random.shuffle(ids)

    allocation = []
    start = 0
    for group, size in enumerate(_apportion(participants, ratio), start=1):
        allocation.extend({"id": pid, "group": group} for pid in ids[start:start + size])
        start += size

    sorted_data = sorted(allocation, key=lambda row: row["id"])
    filename = list_to_csv(sorted_data, f"{DATA_DIR}/output.csv")
    return filename, pd.DataFrame(sorted_data)


def inference(participants: float, ratios: str):
    """Gradio handler: validate the form inputs, then run a random allocation.

    Args:
        participants: Participant count from the number input; any fractional
            part is discarded by ``int()``.
        ratios: Group weights separated by ``:``, e.g. ``"8:1:1"``. Blank
            fields (such as a trailing ``:``) are ignored; every other field
            must be a finite positive number.

    Returns:
        The ``(csv_path, dataframe)`` pair produced by ``random_allocation``.

    Raises:
        ValueError: If a ratio field is not a number or not a finite positive
            value, no ratio is given, or the participant count is missing,
            non-finite, or below 1.
    """
    # NOTE(review): raising gr.Error instead of ValueError would show these
    # messages verbatim in the Gradio UI; confirm which behavior is wanted.
    ratio = []
    for field in ratios.split(":"):
        field = field.strip()
        if not field:
            continue
        try:
            value = float(field)
        except ValueError:
            raise ValueError(f"Invalid input of ratio! {field!r} is not a number.") from None
        # Rejecting zero (rather than dropping it, as before) keeps group
        # numbers aligned with the positions the user typed; the chained
        # comparison also rejects NaN and inf.
        if not 0 < value < float("inf"):
            raise ValueError(f"Invalid input of ratio! {field!r} must be a positive number.")
        ratio.append(value)

    if not ratio:
        raise ValueError("Invalid input of ratio! Use numbers separated by ':', e.g. 8:1:1.")

    try:
        count = int(participants)
    except (TypeError, ValueError, OverflowError):
        raise ValueError("Invalid number of participants!") from None
    if count < 1:
        raise ValueError("Number of participants must be at least 1.")

    # Clear previous output only after the inputs are known to be valid, so a
    # bad submission does not delete the last generated file.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)

    return random_allocation(count, ratio)


if __name__ == "__main__":
    # Script entry point: wire the two form inputs (participant count and
    # ratio string) to `inference`, whose (csv_path, dataframe) result feeds
    # the download component and the table preview, then start the server.
    participant_input = gr.Number(
        label="输入参与者数量 (Number of participants)",
        value=10,
    )
    ratio_input = gr.Textbox(label="输入分组比率 (Grouping ratio)", value="8:1:1")
    csv_output = gr.components.File(label="下载随机分组数据 CSV (Download data CSV)")
    preview_output = gr.Dataframe(label="随机分组数据预览 (Data preview)")

    demo = gr.Interface(
        fn=inference,
        inputs=[participant_input, ratio_input],
        outputs=[csv_output, preview_output],
        title="随机对照试验随机数生成器<br>Randomized Controlled Trial Generator",
        description="输入参与者数量和分组比率,格式为用:隔开的数字,生成随机分组数据。<br>Enter the number of participants and the grouping ratio in the format of numbers separated by : to generate randomized grouping data.",
        flagging_mode="never",
    )
    demo.launch()