Spaces:
Runtime error
Runtime error
File size: 3,537 Bytes
b11ac48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import json
import os
import sys

import pandas as pd
import requests
# Base URL of the SocioFillmore backend. Override with the SOCIOFILLMORE_API
# environment variable; defaults to a server on localhost:5000.
SOCIOFILLMORE_API = os.environ.get("SOCIOFILLMORE_API", "http://127.0.0.1:5000")
# Key sent as the ``auth_key`` query parameter on sampling/frequency queries.
# Prefer supplying it through SOCIOFILLMORE_AUTH_KEY: the literal fallback only
# preserves the previous behaviour and should not be relied on, because it is
# committed in source control.
AUTH_KEY = os.environ.get("SOCIOFILLMORE_AUTH_KEY", "3TrJ397oh#^")
def get_sample(s, dataset, n_samples, frame, construction, role, dependency):
    """Fetch sample sentences where ``role`` of ``frame`` has a given dependency.

    The backend is first switched to ``dataset`` via ``/switch_dataset``, then
    ``/sample_frame`` is queried with the ``lome_0shot`` model. Both requests
    go through the same session ``s``.

    Args:
        s: ``requests.Session`` used for both requests.
        dataset: dataset identifier, e.g. ``"femicides/rai"``.
        n_samples: value sent as the ``n`` query parameter.
        frame: frame name; only frame structures with this frame are kept.
        construction: construction filter forwarded unchanged to the server.
        role: frame-element label whose span is reported.
        dependency: dependency filter forwarded to the server and copied into
            every output row.

    Returns:
        A list of dicts with keys ``dataset``, ``sentence``, ``frame``,
        ``target``, ``role_label``, ``role_span`` and ``dependency``. There is
        one dict per matching frame structure that has at least one ``role``
        span; only the first such span is reported.

    Raises:
        requests.HTTPError: if either request returns an error status. This
            check matters for the dataset switch: otherwise samples from the
            wrong dataset would be silently labelled as ``dataset``.
    """
    switch = s.get(SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset})
    switch.raise_for_status()

    response = s.get(
        SOCIOFILLMORE_API + "/sample_frame",
        params={
            "auth_key": AUTH_KEY,
            "frame": frame,
            "construction": construction,
            "role": role,
            "dependency": dependency,
            "model": "lome_0shot",
            "n": n_samples,
        },
    )
    response.raise_for_status()
    data = response.json()

    rows_out = []
    for sent in data:
        for fns in sent["fn_structures"]:
            if fns["frame"] != frame:
                continue
            # Each role entry is indexed as [label, span]; keep the first span with this label.
            target_role = next((r for r in fns["roles"] if r[0] == role), None)
            if target_role is None:
                continue
            rows_out.append(
                {
                    "dataset": dataset,
                    "sentence": " ".join(sent["sentence"]),
                    "frame": frame,
                    "target": " ".join(fns["target"]["tokens_str"]),
                    "role_label": role,
                    "role_span": " ".join(target_role[1]["tokens_str"]),
                    "dependency": dependency,
                }
            )
    return rows_out
def get_labels(s, dataset, frame):
    """Return the set of labels found in the frame-frequency x-axis for ``frame``.

    The backend is first switched to ``dataset`` via ``/switch_dataset``, then
    ``/frame_freq`` is queried with the ``lome_0shot`` model and
    ``group_by_role_expr=2``. Each entry of
    ``relevant_frame_counts["x"]`` is split on ``"::"`` and its third field
    is collected. Presumably that field is the dependency / role-expression
    label; TODO confirm against the backend.

    Args:
        s: ``requests.Session`` used for both requests.
        dataset: dataset identifier, e.g. ``"crashes/thecrashes"``.
        frame: frame name, sent as the ``frames`` query parameter.

    Returns:
        A set of label strings (duplicates collapsed, unordered).

    Raises:
        requests.HTTPError: if either request returns an error status.
        IndexError: if an x-axis entry has fewer than three ``::`` fields.
    """
    switch = s.get(SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset})
    switch.raise_for_status()

    response = s.get(
        SOCIOFILLMORE_API + "/frame_freq",
        params={
            "auth_key": AUTH_KEY,
            "model": "lome_0shot",
            "frames": frame,
            "constructions": "",
            "group_by_cat": "n",
            "group_by_constr": "n",
            "group_by_role_expr": 2,
            "relative": "y",
            "plot_over_days_post": "n",
        },
    )
    response.raise_for_status()
    data = response.json()
    return {x_label.split("::")[2] for x_label in data["relevant_frame_counts"]["x"]}
# Per-language sampling settings: dataset, frame, roles to sample (in output
# order), and the CSV file the samples are written to.
_LANGUAGE_CONFIGS = {
    "it": (
        "femicides/rai",
        "Killing",
        ("Killer", "Victim"),
        "output/common/query_frame_samples/it_dep_samples.csv",
    ),
    "nl": (
        "crashes/thecrashes",
        "Cause_harm",
        ("Agent", "Victim"),
        "output/common/query_frame_samples/nl_dep_samples.csv",
    ),
}


def _collect_dep_samples(s, dataset, frame, roles, tag, n_samples=2):
    """Sample ``n_samples`` rows for every (label, role) pair of ``frame`` in ``dataset``."""
    print(f"Finding {tag} labels...")
    rows = []
    for label in sorted(get_labels(s, dataset, frame)):
        # Skip the "_UNK_DEP" label (presumably unresolved dependencies; TODO confirm).
        if label == "_UNK_DEP":
            continue
        print(f"Label ({tag}): {label}")
        for role in roles:
            rows.extend(get_sample(s, dataset, n_samples, frame, "*", role, label))
    return rows


def main(language):
    """Collect dependency-label samples for ``language`` and write them to CSV.

    For each label returned by ``get_labels`` (except ``_UNK_DEP``), two
    samples per configured role are fetched. The rows are saved with
    ``DataFrame.to_csv``, which includes the index, as in the original
    script. The output directory is created if it is missing.

    Args:
        language: language code; must be a key of ``_LANGUAGE_CONFIGS``
            (``"it"`` or ``"nl"``).

    Raises:
        ValueError: if ``language`` is not supported. Previously this was a
            silent no-op.
    """
    try:
        dataset, frame, roles, out_path = _LANGUAGE_CONFIGS[language]
    except KeyError:
        raise ValueError(
            f"unsupported language {language!r}; expected one of {sorted(_LANGUAGE_CONFIGS)}"
        ) from None

    # Context manager closes the session's pooled connections when done.
    with requests.Session() as s:
        rows = _collect_dep_samples(s, dataset, frame, roles, language.upper())

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    pd.DataFrame(rows).to_csv(out_path)
if __name__ == "__main__":
    # Require the language code as the first CLI argument instead of
    # crashing with an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} LANGUAGE  (e.g. it or nl)")
    main(language=sys.argv[1])
|