|
import json
import os
import sys

import pandas as pd
import requests
|
|
|
|
|
# Base URL of the SocioFillmore server queried by this script. It can be
# overridden with the SOCIOFILLMORE_API environment variable; the default is
# the local address the script has always used.
SOCIOFILLMORE_API = os.environ.get("SOCIOFILLMORE_API", "http://127.0.0.1:5000")

# Value sent as the `auth_key` query parameter. Prefer supplying it through
# the SOCIOFILLMORE_AUTH_KEY environment variable so the credential does not
# have to live in the source tree; the literal below is kept only as a
# backward-compatible default.
# NOTE(review): if this key protects anything beyond a local dev server,
# rotate it and drop the default.
AUTH_KEY = os.environ.get("SOCIOFILLMORE_AUTH_KEY", "3TrJ397oh#^")
|
|
|
|
|
def get_sample(s, dataset, n_samples, frame, construction, role, dependency):
    """Fetch sample sentences for one frame/role/dependency query as flat rows.

    Switches the server to ``dataset``, queries ``/sample_frame`` (model
    ``lome_0shot``) and keeps, for every returned sentence, each frame
    structure whose frame equals ``frame`` and that contains at least one
    role labelled ``role``. When several roles carry that label, the first
    one is used.

    Args:
        s: Session-like object whose ``get(url, params=...)`` returns a
            response exposing ``raise_for_status()`` and ``json()``
            (``main`` passes a ``requests.Session``).
        dataset: Dataset identifier sent to ``/switch_dataset`` and copied
            into every row.
        n_samples: Value sent as the ``n`` query parameter.
        frame: Frame name, used both in the query and to filter structures.
        construction: Value sent as the ``construction`` query parameter.
        role: Role label, used both in the query and to pick the role span.
        dependency: Value sent as the ``dependency`` query parameter and
            copied into every row.

    Returns:
        A list of dicts with the keys ``dataset``, ``sentence``, ``frame``,
        ``target``, ``role_label``, ``role_span`` and ``dependency``. Token
        lists from the response are joined with single spaces.

    Raises:
        Whatever ``raise_for_status()`` raises for an unsuccessful response
        (``requests.HTTPError`` with a real session).
    """
    # Fail loudly if the switch is rejected. The /sample_frame request does
    # not name a dataset itself, so after a failed switch the rows would be
    # tagged with `dataset` without any guarantee they came from it.
    switch_response = s.get(
        SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset}
    )
    switch_response.raise_for_status()

    sample_response = s.get(
        SOCIOFILLMORE_API + "/sample_frame",
        params={
            "auth_key": AUTH_KEY,
            "frame": frame,
            "construction": construction,
            "role": role,
            "dependency": dependency,
            "model": "lome_0shot",
            "n": n_samples,
        },
    )
    # Surface HTTP errors directly instead of as a confusing JSON/key error.
    sample_response.raise_for_status()
    sentences = sample_response.json()

    rows_out = []
    for sent in sentences:
        for fns in sent["fn_structures"]:
            if fns["frame"] != frame:
                continue

            # Each role is a (label, span) pair; use the first matching label.
            target_role = next((r for r in fns["roles"] if r[0] == role), None)
            if target_role is None:
                continue

            rows_out.append(
                {
                    "dataset": dataset,
                    "sentence": " ".join(sent["sentence"]),
                    "frame": frame,
                    "target": " ".join(fns["target"]["tokens_str"]),
                    "role_label": role,
                    "role_span": " ".join(target_role[1]["tokens_str"]),
                    "dependency": dependency,
                }
            )

    return rows_out
|
|
|
|
|
def get_labels(s, dataset, frame):
    """Return the set of labels the server reports for ``frame`` in ``dataset``.

    Switches the server to ``dataset``, queries ``/frame_freq`` (model
    ``lome_0shot``) and collects the third ``::``-separated field of every
    entry in the response's ``relevant_frame_counts["x"]`` list. ``main``
    passes these values on as the ``dependency`` argument of ``get_sample``.

    Args:
        s: Session-like object whose ``get(url, params=...)`` returns a
            response exposing ``raise_for_status()`` and ``json()``.
        dataset: Dataset identifier sent to ``/switch_dataset``.
        frame: Frame name sent as the ``frames`` query parameter.

    Returns:
        A set of label strings (duplicates in the response collapse).

    Raises:
        Whatever ``raise_for_status()`` raises for an unsuccessful response
        (``requests.HTTPError`` with a real session).
        IndexError: If an entry has fewer than three ``::``-separated fields.
    """
    # Check the switch explicitly: /frame_freq does not name a dataset, so a
    # failed switch would otherwise go unnoticed.
    switch_response = s.get(
        SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset}
    )
    switch_response.raise_for_status()

    freq_response = s.get(
        SOCIOFILLMORE_API + "/frame_freq",
        params={
            "auth_key": AUTH_KEY,
            "model": "lome_0shot",
            "frames": frame,
            "constructions": "",
            "group_by_cat": "n",
            "group_by_constr": "n",
            "group_by_role_expr": 2,
            "relative": "y",
            "plot_over_days_post": "n",
        },
    )
    # Surface HTTP errors directly instead of as a confusing JSON/key error.
    freq_response.raise_for_status()
    data = freq_response.json()

    # NOTE(review): entries look like "<a>::<b>::<label>"; the meaning of the
    # first two fields is not visible here - confirm against the server.
    return {entry.split("::")[2] for entry in data["relevant_frame_counts"]["x"]}
|
|
|
|
|
# Number of samples requested per (role, label) query.
_SAMPLES_PER_QUERY = 2

# Directory the per-language CSV files are written to.
_OUTPUT_DIR = "output/common/query_frame_samples"

# Per-language query setup: (tag used in progress messages, dataset, frame,
# roles to sample for every label).
_LANGUAGE_QUERIES = {
    "it": ("IT", "femicides/rai", "Killing", ("Killer", "Victim")),
    "nl": ("NL", "crashes/thecrashes", "Cause_harm", ("Agent", "Victim")),
}


def main(language):
    """Collect samples per label for ``language`` and write them to CSV.

    Looks up the dataset, frame and roles configured for ``language``,
    fetches the labels with ``get_labels`` and, for every label except
    ``_UNK_DEP`` (in sorted order), requests two samples per role with
    ``get_sample``. The collected rows are written to
    ``output/common/query_frame_samples/<language>_dep_samples.csv``.

    Args:
        language: ``"it"`` or ``"nl"``.

    Raises:
        ValueError: If ``language`` is not a configured language.
    """
    try:
        tag, dataset, frame, roles = _LANGUAGE_QUERIES[language]
    except KeyError:
        # Previously an unknown language silently did nothing.
        supported = ", ".join(sorted(_LANGUAGE_QUERIES))
        raise ValueError(
            f"unsupported language {language!r}; expected one of: {supported}"
        ) from None

    sample_rows = []
    # The context manager closes the session's connections when done.
    with requests.Session() as s:
        print(f"Finding {tag} labels...")
        labels = get_labels(s, dataset, frame)

        for label in sorted(labels):
            # NOTE(review): `_UNK_DEP` is presumably the placeholder for an
            # unknown dependency; it is excluded from sampling.
            if label == "_UNK_DEP":
                continue

            print(f"Label ({tag}): {label}")
            for role in roles:
                sample_rows.extend(
                    get_sample(
                        s, dataset, _SAMPLES_PER_QUERY, frame, "*", role, label
                    )
                )

    # Create the output directory up front so a fresh checkout does not fail
    # only after all queries have already run.
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    df_samples = pd.DataFrame(sample_rows)
    df_samples.to_csv(f"{_OUTPUT_DIR}/{language}_dep_samples.csv")
|
|
|
|
|
if __name__ == "__main__":
    # Usage: <script> LANGUAGE, where LANGUAGE is "it" or "nl".
    # Exit with a usage message instead of an IndexError traceback when the
    # argument is missing; extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} LANGUAGE  (it | nl)")
    main(language=sys.argv[1])
|
|