import datetime
import enum

import streamlit as st

from core.names import find_unique_name
from core.state import Metadata
import mlcroissant as mlc


class RaiEvent(enum.Enum):
    """Event that triggers an RAI metadata change."""

    RAI_DATA_COLLECTION = "RAI_DATA_COLLECTION"
    RAI_DATA_COLLECTION_TYPE = "RAI_DATA_COLLECTION_TYPE"
    RAI_DATA_COLLECTION_MISSING_DATA = "RAI_DATA_COLLECTION_MISSING_DATA"
    RAI_DATA_COLLECTION_RAW = "RAI_DATA_COLLECTION_RAW"
    RAI_DATA_COLLECTION_TIMEFRAME = "RAI_DATA_COLLECTION_TIMEFRAME"
    RAI_DATA_IMPUTATION_PROTOCOL = "RAI_DATA_IMPUTATION_PROTOCOL"
    RAI_DATA_PREPROCESSING_PROTOCOL = "RAI_DATA_PREPROCESSING_PROTOCOL"
    RAI_DATA_MANIPULATION_PROTOCOL = "RAI_DATA_MANIPULATION_PROTOCOL"
    RAI_DATA_ANNOTATION_PROTOCOL = "RAI_DATA_ANNOTATION_PROTOCOL"
    RAI_DATA_ANNOTATION_PLATFORM = "RAI_DATA_ANNOTATION_PLATFORM"
    RAI_DATA_ANNOTATION_ANALYSIS = "RAI_DATA_ANNOTATION_ANALYSIS"
    RAI_DATA_ANNOTATION_PER_ITEM = "RAI_DATA_ANNOTATION_PER_ITEM"
    RAI_DATA_ANNOTATION_DEMOGRAPHICS = "RAI_DATA_ANNOTATION_DEMOGRAPHICS"
    RAI_DATA_ANNOTATION_TOOLS = "RAI_DATA_ANNOTATION_TOOLS"
    RAI_DATA_USE_CASES = "RAI_DATA_USE_CASES"
    RAI_DATA_BIAS = "RAI_DATA_BIAS"
    RAI_DATA_LIMITATION = "RAI_DATA_LIMITATION"
    RAI_DATA_SOCIAL_IMPACT = "RAI_DATA_SOCIAL_IMPACT"
    RAI_SENSITIVE = "RAI_SENSITIVE"
    RAI_MAINTENANCE = "RAI_MAINTENANCE"


def handle_rai_change(event: RaiEvent, metadata: Metadata, key: str, index: int = 0):
    # For one-to-many widgets, `index` selects which list entry to update.
    if event == RaiEvent.RAI_DATA_COLLECTION:
        metadata.data_collection = st.session_state[key]
    if event == RaiEvent.RAI_DATA_COLLECTION_TYPE:
        metadata.data_collection_type = st.session_state[key]
    if event == RaiEvent.RAI_DATA_COLLECTION_MISSING_DATA:
        metadata.data_collection_missing_data = st.session_state[key]
    if event == RaiEvent.RAI_DATA_COLLECTION_RAW:
        metadata.data_collection_raw_data = st.session_state[key]
    if event == RaiEvent.RAI_DATA_COLLECTION_TIMEFRAME:
        # TODO: support a date range for the collection timeframe.
        raise NotImplementedError(
            "Data collection timeframe range is not implemented yet"
        )
    if event == RaiEvent.RAI_DATA_IMPUTATION_PROTOCOL:
        metadata.data_imputation_protocol = st.session_state[key]
    if event == RaiEvent.RAI_DATA_PREPROCESSING_PROTOCOL:
        if metadata.data_preprocessing_protocol:
            metadata.data_preprocessing_protocol[int(index)] = st.session_state[key]
        else:
            metadata.data_preprocessing_protocol = []
            metadata.data_preprocessing_protocol.append(st.session_state[key])
    if event == RaiEvent.RAI_DATA_MANIPULATION_PROTOCOL:
        metadata.data_manipulation_protocol = st.session_state[key]
    if event == RaiEvent.RAI_DATA_ANNOTATION_PROTOCOL:
        metadata.data_annotation_protocol = st.session_state[key]
    if event == RaiEvent.RAI_DATA_ANNOTATION_PLATFORM:
        metadata.data_annotation_platform = st.session_state[key]
    if event == RaiEvent.RAI_DATA_ANNOTATION_ANALYSIS:
        metadata.data_annotation_analysis = st.session_state[key]
    if event == RaiEvent.RAI_DATA_ANNOTATION_PER_ITEM:
        metadata.annotation_per_item = st.session_state[key]
    if event == RaiEvent.RAI_DATA_ANNOTATION_DEMOGRAPHICS:
        metadata.annotator_demographics = st.session_state[key]
    if event == RaiEvent.RAI_DATA_ANNOTATION_TOOLS:
        metadata.machine_annotation_tools = st.session_state[key]
    if event == RaiEvent.RAI_DATA_USE_CASES:
        if metadata.data_use_cases:
            metadata.data_use_cases[int(index)] = st.session_state[key]
        else:
            metadata.data_use_cases = []
            metadata.data_use_cases.append(st.session_state[key])

    if event == RaiEvent.RAI_DATA_BIAS:
        if metadata.data_biases:
            metadata.data_biases[int(index)] = st.session_state[key]
        else:
            metadata.data_biases = []
            metadata.data_biases.append(st.session_state[key])

    if event == RaiEvent.RAI_DATA_LIMITATION:
        if metadata.data_limitations:
            metadata.data_limitations[int(index)] = st.session_state[key]
        else:
            metadata.data_limitations = []
            metadata.data_limitations.append(st.session_state[key])
    if event == RaiEvent.RAI_DATA_SOCIAL_IMPACT:
        metadata.data_social_impact = st.session_state[key]
    if event == RaiEvent.RAI_SENSITIVE:
        if metadata.personal_sensitive_information:
            metadata.personal_sensitive_information[int(index)] = st.session_state[key]
        else:
            metadata.personal_sensitive_information = []
            metadata.personal_sensitive_information.append(st.session_state[key])
    if event == RaiEvent.RAI_MAINTENANCE:
        metadata.data_release_maintenance_plan = st.session_state[key]


def get_widget_cadinality(key: str) -> str:
    """Return the suffix after the last underscore in a widget key."""
    return key.split("_")[-1]