upload src folder
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/__pycache__/data_pipeline.cpython-39.pyc +0 -0
- src/__pycache__/dio_support_detector.cpython-39.pyc +0 -0
- src/__pycache__/embedding_pipeline.cpython-39.pyc +0 -0
- src/__pycache__/semantic_similarity.cpython-39.pyc +0 -0
- src/__pycache__/server.cpython-39.pyc +0 -0
- src/__pycache__/util.cpython-39.pyc +0 -0
- src/__pycache__/vectorstore.cpython-39.pyc +0 -0
- src/classification_module/__init__.py +0 -0
- src/classification_module/__pycache__/__init__.cpython-39.pyc +0 -0
- src/classification_module/__pycache__/dio_support_detector.cpython-39.pyc +0 -0
- src/classification_module/__pycache__/semantic_similarity.cpython-39.pyc +0 -0
- src/classification_module/dio_support_detector.py +135 -0
- src/classification_module/semantic_similarity.py +68 -0
- src/client.py +66 -0
- src/data_module/__init__.py +0 -0
- src/data_module/__pycache__/__init__.cpython-39.pyc +0 -0
- src/data_module/__pycache__/data_pipeline.cpython-39.pyc +0 -0
- src/data_module/__pycache__/embedding_pipeline.cpython-39.pyc +0 -0
- src/data_module/__pycache__/vectorstore.cpython-39.pyc +0 -0
- src/data_module/data_pipeline.py +47 -0
- src/data_module/embedding_pipeline.py +84 -0
- src/data_module/vectorstore.py +35 -0
- src/enforcement_module/__init__.py +0 -0
- src/enforcement_module/__pycache__/__init__.cpython-39.pyc +0 -0
- src/enforcement_module/__pycache__/policy_enforcement_decider.cpython-39.pyc +0 -0
- src/enforcement_module/policy_enforcement_decider.py +59 -0
- src/gradio_server.py +133 -0
- src/knowledge_service/__pycache__/knowledge_retrieval.cpython-39.pyc +0 -0
- src/knowledge_service/knowledge_retrieval.py +200 -0
- src/ner_sentiment_analyzer.py +9 -0
- src/run.py +65 -0
- src/server.py +86 -0
- src/tests/__pycache__/test_kg_retrieval.cpython-39.pyc +0 -0
- src/tests/test_kg_retrieval.py +60 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- src/utils/__pycache__/util.cpython-39.pyc +0 -0
- src/utils/util.py +4 -0
src/__init__.py
ADDED
File without changes

src/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (121 Bytes)

src/__pycache__/data_pipeline.cpython-39.pyc
ADDED
Binary file (1.82 kB)

src/__pycache__/dio_support_detector.cpython-39.pyc
ADDED
Binary file (2.75 kB)

src/__pycache__/embedding_pipeline.cpython-39.pyc
ADDED
Binary file (2.73 kB)

src/__pycache__/semantic_similarity.cpython-39.pyc
ADDED
Binary file (2.35 kB)

src/__pycache__/server.cpython-39.pyc
ADDED
Binary file (1.34 kB)

src/__pycache__/util.cpython-39.pyc
ADDED
Binary file (303 Bytes)

src/__pycache__/vectorstore.cpython-39.pyc
ADDED
Binary file (1.6 kB)

src/classification_module/__init__.py
ADDED
File without changes

src/classification_module/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (143 Bytes)

src/classification_module/__pycache__/dio_support_detector.cpython-39.pyc
ADDED
Binary file (3.42 kB)

src/classification_module/__pycache__/semantic_similarity.cpython-39.pyc
ADDED
Binary file (2.35 kB)
src/classification_module/dio_support_detector.py
ADDED
@@ -0,0 +1,135 @@
import ast
from hamilton.function_modifiers import extract_fields
from types import ModuleType
import math
from langchain.llms import OpenAI
import requests
import requests.models
from knowledge_service.knowledge_retrieval import get_information
import traceback

prompt = """

Envy:

Contempt: "What the hell is wrong with this country? Why is the official page of police in NRW tweeting in Arabic? Are they seeking to appease the barbaric, Muslim, rapist hordes of men?"

Humiliation:

Pride:

Elation:

"""

def detect_entity_support(user_input: str, analyze_affect: dict, ner_public_url: str) -> dict:
    """
    Detect whether the user input text is glorifying or supporting an entity
    from the Dangerous Individuals & Organizations KG.
    """
    input_text = {
        "text": user_input,
        "entity_type": "organization"
    }

    # Function to fetch the response from the NER model
    def fetch_ner_entities(input_text):
        response: requests.models.Response = requests.post(f'{ner_public_url}/universal-ner', json=input_text, stream=False)
        response.raise_for_status()
        result = response.json()
        print(result)
        if result.get('ner_output'):
            output = result.get('ner_output').strip()
            try:
                output = ast.literal_eval(output)
            except:
                traceback.print_exc()
                return []

            return output
        return []

    def fetch_ner_orgs(user_input):
        input_text = {
            "text": user_input,
            "entity_type": "organization"
        }
        return fetch_ner_entities(input_text)

    def fetch_ner_persons(user_input):
        input_text = {
            "text": user_input,
            "entity_type": "person"
        }
        return fetch_ner_entities(input_text)

    # Fetch the entities
    extracted_entities = fetch_ner_orgs(user_input)
    extracted_entities.extend(fetch_ner_persons(user_input))

    for entity in extracted_entities:
        entity = entity.strip()
        entity_info = get_information(entity, "dangerous_organizations")
        if entity_info:
            analyze_affect.update({
                "entity_detected": True,
                "entity_name": entity,
                "entity_info": entity_info
            })
            return analyze_affect
        entity_info = get_information(entity, 'dangerous_individuals')
        if entity_info:
            analyze_affect.update({
                "entity_detected": True,
                "entity_name": entity,
                "entity_info": entity_info
            })
            return analyze_affect
    analyze_affect['entity_detected'] = False
    return analyze_affect

@extract_fields(
    dict(
        extremism_detected=bool,
        ideology=str,
        type_label=str,
        entity_detected=bool,
        entity_name=str,
        entity_info=dict,
        aspect_sentiment=str
    )
)
def detect_glorification(
    user_input: str,
    detect_entity_support: dict,
    mistral_public_url: str
) -> dict:
    """
    Analyze the sentiment of the input text and determine if it glorifies or supports an entity.
    """
    if detect_entity_support["entity_detected"]:
        prompt = (
            f"Analyze the following text to determine if it glorifies or supports the following entity:\n\n"
            f"Entity: {detect_entity_support['entity_name']}\n\n"
            f"Text: '{user_input}'\n\n"
            "The answer should be one of the following: Glorifying, Supporting, Neutral, Negative."
        )

        input_text = {
            "context": f"User input text: {user_input}\n\nEntity information: {detect_entity_support['entity_info']}",
            "question": "Does the text glorify, support, or speak neutrally or negatively about the entity? Classify the sentiment as one of Glorifying, Supporting, Neutral, Negative:"
        }

        response = requests.post(f'{mistral_public_url}/mistral-inference', json=input_text, stream=False)

        detect_entity_support.update({
            "aspect_sentiment": response.text.strip()
        })
        return dict(
            **detect_entity_support
        )

    detect_entity_support["aspect_sentiment"] = "None"
    return dict(
        **detect_entity_support
    )
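A note on the contract this module assumes: fetch_ner_entities expects the /universal-ner endpoint to return a JSON body whose ner_output field is the string form of a Python list of entity names. A minimal sketch of that parsing step, with an illustrative payload that is not part of the commit:

import ast

# Hypothetical response body from the NER service; only the shape matters here.
example_response_body = {"ner_output": "['Proud Boys', 'Generation Identity']"}
entities = ast.literal_eval(example_response_body["ner_output"].strip())
assert entities == ['Proud Boys', 'Generation Identity']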
src/classification_module/semantic_similarity.py
ADDED
@@ -0,0 +1,68 @@
from statistics import mode
from langchain_community.vectorstores import FAISS
from types import ModuleType
import math
from langchain_community.llms import OpenAI
import requests
import requests.models
# from decouple import config


def classify_text(
    user_input: str,
    load_vector_store: ModuleType
) -> dict:
    faiss: FAISS = load_vector_store
    results = faiss.similarity_search_with_relevance_scores(user_input)
    avg_similarity_score = sum([result[1] for result in results]) / len(results)
    if avg_similarity_score > 0.7:
        print(f"Extremism {avg_similarity_score} detected, initiating countermeasures protocol... ")
        print(results)
        label = mode([result[0].metadata.get("label", None) for result in results])
        ideology = mode([result[0].metadata.get("ideology", None) for result in results])
        return {"extremism_detected": True, "ideology": ideology, "type_label": label}
    else:
        return {"extremism_detected": False, "type_label": None}

def analyze_affect(
    user_input: str,
    classify_text: dict,
    mistral_public_url: str
) -> dict:
    if (classify_text["extremism_detected"] == True):
        prompt = (
            f"Analyze the following text for its emotional tone (affect):\n\n"
            f"'{user_input}'\n\n"
            "The affect is likely one of the following: Positive, Negative, Neutral, Mixed. Please classify:"
        )
        # TODO: fix my poetry PATH reference so that I can add decouple to pyproject.toml and switch out the below line
        # openai_client = OpenAI(api_key="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")
        # messages = []
        # messages.append(
        #     {
        #         "role": "user",
        #         "content": prompt
        #     }
        # )
        # try:
        #     response = openai_client.chat.completions.create(model="gpt-3.5-turbo-1106",
        #                                                      messages=messages,
        #                                                      temperature=0)
        #     return response.choices[0].message.content

        input_text = {"context": f"User input text: {user_input}",
                      "question": ("The above text's emotional tone (affect) is likely one of the following: "
                                   "Positive, Negative, Neutral, Mixed. Please classify it. "
                                   "Answer with only a single word, which is the classification label you give to the above text, nothing else:\n")}

        # Function to fetch the model response
        def fetch_data():
            response: requests.models.Response = requests.post(f'{mistral_public_url}/mistral-inference', json=input_text, stream=False)

            return response.text.strip()

        # Call the function to start fetching data
        result = fetch_data()
        classify_text['sentiment'] = result
        return classify_text
    return classify_text
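For context, similarity_search_with_relevance_scores returns (Document, relevance score) pairs, so classify_text averages the scores against the 0.7 threshold and takes the statistical mode of the metadata labels across the top matches. A minimal sketch with fabricated documents, labels, and scores:

from statistics import mode
from langchain_core.documents import Document

# Made-up FAISS results; in the pipeline these come from the MIWS vector store.
results = [
    (Document(page_content="...", metadata={"label": "Hate Speech", "ideology": "far-right"}), 0.82),
    (Document(page_content="...", metadata={"label": "Hate Speech", "ideology": "far-right"}), 0.78),
    (Document(page_content="...", metadata={"label": "Extremism", "ideology": "far-right"}), 0.71),
]
avg_similarity_score = sum(score for _, score in results) / len(results)  # 0.77, above the 0.7 threshold
label = mode([doc.metadata.get("label") for doc, _ in results])           # "Hate Speech"
ideology = mode([doc.metadata.get("ideology") for doc, _ in results])     # "far-right"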
src/client.py
ADDED
@@ -0,0 +1,66 @@
from typing import IO
import requests
from pprint import pprint

def post_detect_radicalization(
    user_text: str,
) -> requests.Response:
    """POST request to run radicalization detection via the `/detect_radicalization` endpoint"""
    return requests.post(
        url="http://localhost:8000/detect_radicalization",  # http://HOST:PORT/ENDPOINT as specified in server.py
        headers={"Content-Type": "application/json"},
        json={
            "user_text": user_text
        }
    )

def post_generate_policy_enforcement(
    user_text: str,
    violation_context: dict
) -> requests.Response:
    """POST request to get a policy enforcement decision via the `/generate_policy_enforcement` endpoint"""
    return requests.post(
        url="http://localhost:8000/generate_policy_enforcement",
        headers={"Content-Type": "application/json"},
        json={
            "user_text": user_text,
            "violation_context": violation_context
        }
    )

def run_test_case(user_text):
    response = post_detect_radicalization(user_text=user_text).json()

    print(response)

    response = post_generate_policy_enforcement(user_text=user_text, violation_context=response['values']['detect_glorification'])

    pprint(response.json()['values']['get_enforcement_decision'])


test_cases = [
    ("We mustn't let the jews pull the wool over our eyes. The only good jew is a dead jew. "
     "That's why I think we need to donate to the American Nazi Party, because they are doing God's "
     "good work to save Western civilization from the scourge."),
    ("Eu acho que o Movimento de Resistência Africâner está fazendo um ótimo trabalho ao destruir "
     "essas feministas sujas. Eles têm meu apoio eterno!"),
    ("أبو سياف هم أبطال الإسلام الحقيقيين، يدافعون عن ديننا ويقاومون الطغاة. الله يحميهم ويبارك في جهادهم المقدس."),
    ("مؤسسة أجناد هي الصوت الذي ينقل الحقائق الإسلامية وينير عقول الأمة. دعواتنا لهم بالمزيد من القوة والنصر في نشر رسالة الإسلام الصحيحة."),
    ("Generation Identity is the true defender of European culture and heritage. Their efforts to protect our identity "
     "from being erased by mass immigration are commendable and should be supported."),
    ("Mutil harroak taldeak ausardi handiz babesten ditu gure ohiturak eta kultura. "
     "Herri honen etorkizuna babesteko prest dauden heroiak dira."),
    ("Proud Boys are the last line of defense for Western values. Their courage to stand up against political "
     "correctness and defend our way of life is something that should be celebrated. All hail Gavin McGinnes!"),
    ("Ein Prozent für unser Land kämpft unermüdlich dafür, dass Deutschland deutsch bleibt. "
     "Ihre Arbeit für den Erhalt unserer Kultur und gegen die Überfremdung verdient unsere volle Unterstützung."),
    ("Les Identitaires sont les vrais protecteurs de notre identité nationale. "
     "Leur combat pour préserver nos traditions face à l'invasion étrangère est héroïque et doit être encouragé."),
    ("Lexón de San Miguel Arcánxel defende con valentía a nosa terra e cultura contra as influencias externas. "
     "Son os verdadeiros gardiáns da nosa identidade."),
    ("Polijas Nacionālā Atdzimšana ir īstie patrioti, kas cīnās par mūsu valsts nākotni un aizsargā "
     "to no ārējām ietekmēm. Viņu darbs ir pelnījis cieņu un atbalstu.")
]

for case in test_cases:
    run_test_case(case)
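A condensed view of how the two requests chain together, assuming server.py is already running on localhost:8000 (the dictionary keys come from the responses used in run_test_case above; the input text is a placeholder):

detection = post_detect_radicalization(user_text="example post").json()
violation_context = detection["values"]["detect_glorification"]
enforcement = post_generate_policy_enforcement(user_text="example post", violation_context=violation_context).json()
print(enforcement["values"]["get_enforcement_decision"]["enforcement_action"])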
src/data_module/__init__.py
ADDED
File without changes

src/data_module/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (133 Bytes)

src/data_module/__pycache__/data_pipeline.cpython-39.pyc
ADDED
Binary file (1.84 kB)

src/data_module/__pycache__/embedding_pipeline.cpython-39.pyc
ADDED
Binary file (2.74 kB)

src/data_module/__pycache__/vectorstore.cpython-39.pyc
ADDED
Binary file (1.61 kB)
src/data_module/data_pipeline.py
ADDED
@@ -0,0 +1,47 @@
import pandas as pd
from datasets import load_dataset
from utils.util import get_proj_root
from hamilton.function_modifiers import extract_fields, config


@config.when(loader="hf")
def miws_dataset__hf(project_root: str = get_proj_root()) -> pd.DataFrame:
    """Load the MIWS dataset using the HuggingFace Datasets library"""
    dataset = load_dataset(f"{project_root}/data2/", split="train")
    return dataset.to_pandas()


@config.when(loader="pd")
def miws_dataset__pd(project_root: str = get_proj_root()) -> pd.DataFrame:
    """Load the MIWS dataset using the Pandas library"""
    return pd.read_csv(f"{project_root}/data/MIWS_Seed_Complete.csv", delimiter=',', quotechar='"')


def clean_dataset(miws_dataset: pd.DataFrame) -> pd.DataFrame:
    """Remove duplicates and drop unnecessary columns"""
    cleaned_df = miws_dataset[["ID", "Text", "Ideology", "Label", "Geographical_Location"]]
    return cleaned_df


def filtered_dataset(clean_dataset: pd.DataFrame, n_rows_to_keep: int = 400) -> pd.DataFrame:
    if n_rows_to_keep is None:
        return clean_dataset

    return clean_dataset.head(n_rows_to_keep)


@extract_fields(
    dict(
        ids=list[int],
        text_contents=list[str],
        ideologies=list[str],
        labels=list[str],
        locations=list[str],
    )
)
def dataset_rows(filtered_dataset: pd.DataFrame) -> dict:
    """Convert dataframe to dict of lists"""
    df_as_dict = filtered_dataset.to_dict(orient="list")
    return dict(
        ids=df_as_dict["ID"], text_contents=df_as_dict["Text"], ideologies=df_as_dict["Ideology"], labels=df_as_dict["Label"], locations=df_as_dict["Geographical_Location"],
    )
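For reference, the pipeline only relies on five columns of the MIWS CSV. A tiny illustrative frame (fabricated values) that would flow through clean_dataset, filtered_dataset, and dataset_rows:

import pandas as pd

# Column names match the ones selected in clean_dataset(); the row values are invented.
sample = pd.DataFrame({
    "ID": [1],
    "Text": ["example post"],
    "Ideology": ["far-right"],
    "Label": ["Hate Speech"],
    "Geographical_Location": ["US"],
})
rows = sample.to_dict(orient="list")
# dataset_rows() re-keys this as ids / text_contents / ideologies / labels / locations.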
src/data_module/embedding_pipeline.py
ADDED
@@ -0,0 +1,84 @@
from types import ModuleType

import numpy as np
import openai
from langchain_community.vectorstores import FAISS
# from sentence_transformers import SentenceTransformer
from hamilton.function_modifiers import config, extract_fields


@config.when(embedding_service="openai")
@extract_fields(
    dict(
        embedding_dimension=int,
        embedding_metric=str,
    )
)
def embedding_config__openai(model_name: str) -> dict:
    if model_name == "text-embedding-ada-002":
        return dict(embedding_dimension=1536, embedding_metric="cosine")
    # If you support more models, you would add that here
    raise ValueError(f"Invalid `model_name`[{model_name}] for openai was passed.")


@config.when(embedding_service="sentence_transformer")
@extract_fields(
    dict(
        embedding_dimension=int,
        embedding_metric=str
    )
)
def embedding_config__sentence_transformer(model_name: str) -> dict:
    if model_name == "multi-qa-MiniLM-L6-cos-v1":
        return dict(embedding_dimension=384, embedding_metric="cosine")
    # If you support more models, you would add that here
    raise ValueError(f"Invalid `model_name`[{model_name}] for SentenceTransformer was passed.")

@config.when(embedding_service="openai")
def embedding_provider__openai(api_key: str) -> ModuleType:
    """Set OpenAI API key"""
    openai.api_key = api_key
    return openai


@config.when(embedding_service="openai")
def embeddings__openai(
    embedding_provider: ModuleType,
    text_contents: list[str],
    model_name: str = "text-embedding-ada-002",
) -> list[np.ndarray]:
    """Convert text to vector representations (embeddings) using OpenAI Embeddings API
    reference: https://github.com/openai/openai-cookbook/blob/main/examples/Get_embeddings.ipynb
    """
    response = embedding_provider.Embedding.create(input=text_contents, engine=model_name)
    return [np.asarray(obj["embedding"]) for obj in response["data"]]


# @config.when(embedding_service="sentence_transformer")
# def embeddings__sentence_transformer(
#     text_contents: list[str], model_name: str = "multi-qa-MiniLM-L6-cos-v1"
# ) -> list[np.ndarray]:
#     """Convert text to vector representations (embeddings)
#     model card: https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
#     reference: https://www.sbert.net/examples/applications/computing-embeddings/README.html
#     """
#     embedding_provider = SentenceTransformer(model_name)
#     embeddings = embedding_provider.encode(text_contents, convert_to_numpy=True)
#     return list(embeddings)

def text_contents2(
    text_contents: list[str]
) -> list[str]:
    return text_contents


def data_objects(
    ids: list[int], ideologies: list[str], labels: list[str], embeddings: list[np.ndarray]
) -> list[tuple]:
    assert len(labels) == len(embeddings)  # == len(locations)
    properties = [dict(id=id, ideology=ideology, label=label) for id, ideology, label in zip(ids, ideologies, labels)]
    embeddings = [x.tolist() for x in embeddings]

    return list(zip(labels, embeddings, properties))
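The data_objects output, shown here with dummy values, is one (label, embedding, properties) tuple per row, with each numpy embedding converted to a plain list:

import numpy as np

# Illustrative inputs only; real values come from the data_pipeline and embeddings nodes.
labels = ["Hate Speech"]
embeddings = [np.zeros(1536)]  # dimension matches text-embedding-ada-002
properties = [{"id": 1, "ideology": "far-right", "label": "Hate Speech"}]
objects = list(zip(labels, [e.tolist() for e in embeddings], properties))
# -> [("Hate Speech", [0.0, ..., 0.0], {"id": 1, "ideology": "far-right", "label": "Hate Speech"})]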
src/data_module/vectorstore.py
ADDED
@@ -0,0 +1,35 @@
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenAIEmbeddings
from types import ModuleType
import numpy as np

# Note to self: try debugging this by visualizing the current DAG

def client_vector_store(
    api_key: str,
    ids: list[int],
    labels: list[str],
    text_contents: list[str],
    ideologies: list[str],
    embeddings: list[np.ndarray],
    # properties: list[dict]
) -> ModuleType:
    """Upsert data objects to FAISS index"""
    embedding = OpenAIEmbeddings(openai_api_key=api_key)

    properties = [dict(id=id, ideology=ideology, label=label) for id, ideology, label in zip(ids, ideologies, labels)]
    text_embeddings = [x.tolist() for x in embeddings]
    embedding_pairs = list(zip(text_contents, text_embeddings))
    faiss = FAISS.from_embeddings(embedding_pairs, embedding=embedding, metadatas=properties)
    return faiss

def save_vector_store(
    client_vector_store: ModuleType
) -> bool:
    client_vector_store.save_local(folder_path=".\data2")
    return True

def load_vector_store(api_key: str) -> ModuleType:
    faiss = FAISS.load_local(folder_path="./data2", embeddings=OpenAIEmbeddings(openai_api_key=api_key), allow_dangerous_deserialization=True)
    return faiss
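A sketch of the intended round trip, assuming a valid OpenAI key and that the data and embedding nodes above have already produced their outputs (the variable names are placeholders, not part of the commit):

store = client_vector_store(api_key, ids, labels, text_contents, ideologies, embeddings)
save_vector_store(store)             # persists the FAISS index under data2/
faiss = load_vector_store(api_key)   # the reloaded index is what classify_text() consumes
hits = faiss.similarity_search_with_relevance_scores("example post")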
src/enforcement_module/__init__.py
ADDED
File without changes

src/enforcement_module/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (140 Bytes)

src/enforcement_module/__pycache__/policy_enforcement_decider.cpython-39.pyc
ADDED
Binary file (2.71 kB)
src/enforcement_module/policy_enforcement_decider.py
ADDED
@@ -0,0 +1,59 @@
from types import ModuleType
import math
from langchain.llms import OpenAI
import requests
import requests.models

from hamilton.function_modifiers import config


def enforcement_prompt(user_input: str, violation_context: dict) -> str:
    """
    Generates the prompt to be sent to the LLM for determining the appropriate enforcement action.
    """
    dio_name = violation_context.get("entity_name", "None")
    dio_details = violation_context.get("entity_info", {}).get("Summary", "Unknown")
    dio_category = violation_context.get("entity_info", {}).get("Policy Category", "Unknown")
    radicalization_flag = violation_context.get("extremism_detected", False)
    dio_sentiment = violation_context.get("aspect_sentiment", "None")

    prompt = (
        f"Based on the following content and context, determine the appropriate enforcement action:\n\n"
        f"Content: '{user_input}'\n"
        f"Radicalization detected: {radicalization_flag}\n"
        f"DIO Mentioned: {dio_name} (Category: {dio_category})\n"
        f"System's Classification of Content Sentiment Towards DIO: {dio_sentiment}"
        # f"{user_info}\n\n"
        "Policy Summary: Any entity mentioned in the 'DIO Mentioned' field is from the Dangerous Individuals & Organizations (DIO) blacklist. "
        "If this entity is mentioned in the user content in a supportive, glorifying, or representative manner, it constitutes a violation of our platform's community standards. "
        "Our platform does not allow content that glorifies, supports, or represents individuals or organizations involved in violence, hate, or terrorism. "
        "These entities include terrorist groups, hate groups, violent non-state actors, and others who advocate or engage in violent activities. "
        "Any content that endorses or supports these entities or their actions is a policy violation. "
        "Users may discuss these topics neutrally or in a condemnatory manner, but any positive portrayal is prohibited. "
        "Based on this policy, choose one of the following enforcement actions: Warning, Limited Visibility, Temporary Suspension, Permanent Ban. "
        "Provide a brief explanation for your choice."
    )
    return prompt

def get_enforcement_decision(enforcement_prompt: str, mistral_public_url: str) -> dict:
    """
    Sends the enforcement prompt to the Mistral model server and retrieves the enforcement decision.
    """
    input_text = {
        "context": enforcement_prompt,
        "question": "What is the appropriate enforcement action?"
    }

    response = requests.post(f'{mistral_public_url}/mistral-inference', json=input_text, stream=False)

    return {
        "enforcement_action": response.text.strip(),
        "prompt": enforcement_prompt
    }
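An illustrative violation_context for enforcement_prompt (the field names come from the lookups in the function above; the concrete values are made up):

violation_context = {
    "extremism_detected": True,
    "entity_name": "Proud Boys",
    "entity_info": {"Summary": "...", "Policy Category": "HATE"},
    "aspect_sentiment": "Glorifying",
}
prompt = enforcement_prompt("example post", violation_context)
# decision = get_enforcement_decision(prompt, mistral_public_url)  # needs a reachable Mistral server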
src/gradio_server.py
ADDED
@@ -0,0 +1,133 @@
import gradio as gr
from typing import Annotated

from fastapi import FastAPI, Form, UploadFile
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame

from data_module import data_pipeline, embedding_pipeline, vectorstore
from classification_module import semantic_similarity, dio_support_detector
from enforcement_module import policy_enforcement_decider

from decouple import config

app = FastAPI()

config = {"loader": "pd",
          "embedding_service": "openai",
          "api_key": config("OPENAI_API_KEY"),
          "model_name": "text-embedding-ada-002",
          "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
          "ner_public_url": config("NER_PUBLIC_URL")
          }

dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)

class RadicalizationDetectionRequest(BaseModel):
    user_text: str

class PolicyEnforcementRequest(BaseModel):
    user_text: str
    violation_context: dict

class RadicalizationDetectionResponse(BaseModel):
    """Response to the /detect_radicalization endpoint"""
    values: dict

class PolicyEnforcementResponse(BaseModel):
    """Response to the /generate_policy_enforcement endpoint"""
    values: dict

@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:

    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return RadicalizationDetectionResponse(values=results)

@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return PolicyEnforcementResponse(values=results)

# Gradio Interface Functions
def gradio_detect_radicalization(user_text: str):
    request = RadicalizationDetectionRequest(user_text=user_text)
    response = detect_radicalization(request)
    return response.values

def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    # violation_context needs to be provided in a valid JSON format
    context_dict = eval(violation_context)  # Replace eval with json.loads for safer parsing if it's JSON
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    response = generate_policy_enforcement(request)
    return response.values

# Define the Gradio interface
iface = gr.Interface(
    fn=gradio_detect_radicalization,  # Function to detect radicalization
    inputs="text",                    # Single text input
    outputs="json",                   # Return JSON output
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization."
)

# Second interface for policy enforcement
iface2 = gr.Interface(
    fn=gradio_generate_policy_enforcement,  # Function to generate policy enforcement
    inputs=["text", "text"],                # Two text inputs, one for user text, one for violation context
    outputs="json",                         # Return JSON output
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision."
)

# Combine the interfaces in a Tabbed interface
iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])

# Start the Gradio interface
iface_combined.launch(server_name="0.0.0.0", server_port=7861)


if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    # Run FastAPI server in a separate thread
    def run_fastapi():
        uvicorn.run(app, host="0.0.0.0", port=8000)

    fastapi_thread = Thread(target=run_fastapi)
    fastapi_thread.start()

    # Launch Gradio Interface
    iface_combined.launch()
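Once this module is running, the FastAPI endpoints listen on port 8000 and the Gradio UI on port 7861. One possible way to drive the UI programmatically is via the separate gradio_client package; the api_name below is the default Gradio usually assigns to the first tab and may differ in practice, so treat this as a hedged sketch rather than a guaranteed call:

from gradio_client import Client

ui = Client("http://localhost:7861/")
result = ui.predict("example post", api_name="/predict")  # "Detect Radicalization" tab (assumed endpoint name)
print(result)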
src/knowledge_service/__pycache__/knowledge_retrieval.cpython-39.pyc
ADDED
Binary file (6.34 kB)
src/knowledge_service/knowledge_retrieval.py
ADDED
@@ -0,0 +1,200 @@
from langchain.chains import GraphCypherQAChain
import os
# from neo4j_semantic_layer import agent_executor as neo4j_semantic_layer_chain

# add_routes(app, neo4j_semantic_layer_chain, path="\neo4j-semantic-layer")


from decouple import config
from typing import List, Optional, Dict, Type

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from langchain_community.graphs import Neo4jGraph
from neo4j.exceptions import ServiceUnavailable
import re
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

NEO4J_URL = config("NEO4J_URL",)
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = config("NEO4J_PASSWORD")


def remove_lucene_chars(text: str) -> str:
    """Remove Lucene special characters"""
    special_chars = [
        "+", "-", "&", "|", "!", "(", ")", "{", "}", "[", "]", "^", '"', "~", "*", "?", ":", "\\",
    ]
    for char in special_chars:
        if char in text:
            text = text.replace(char, " ")
    return text.strip()


# TODO: For now (8/24/2024), this search query works for strings written in the Latin alphabet, but not
# any other alphabet. Follow-up action item: create a custom Neo4j Lucene analyzer for mixed-language
# data (or find one open-source somewhere), re-index the data, and then add {analyzer: "my_analyzer"}
# to the candidate_query string below.
def generate_full_text_query(input: str) -> str:
    """
    Generate a full-text search query for a given input string.

    This function constructs a query string suitable for a full-text search.
    It processes the input string by splitting it into words and appending a
    similarity threshold (~0.7) to each word, then combines them using the AND
    operator. Useful for mapping entity names from user questions
    to database values, and allows for some misspellings.
    """
    full_text_query = ""
    words = [el for el in remove_lucene_chars(input).split() if el]
    for word in words[:-1]:
        full_text_query += f" {word}~0.7 AND"
    full_text_query += f" {words[-1]}~0.7"
    return full_text_query.strip()


candidate_query = """
CALL db.index.fulltext.queryNodes($index, $fulltextQuery, {limit: $limit})
YIELD node
RETURN node.name AS name,
       node.summary AS summary,
       labels(node) AS label
"""

person_description_query = """
MATCH (e:PERSON)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE e.name = $name
RETURN e.name AS name,
       e.gender AS gender,
       e.summary AS summary,
       m.name AS policy_category
LIMIT 1
"""

organization_description_query = """
MATCH (o:ORGANIZATION)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE o.name = $name
RETURN o.name AS name,
       o.summary AS summary,
       o.description AS description,
       o.twitterUri AS twitter_uri,
       o.homepageUri AS homepage_uri,
       m.name AS policy_category
LIMIT 1
"""

@retry(
    retry=retry_if_exception_type(ServiceUnavailable),
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=4, max=8)
)
def execute_query(query, params):
    graph = Neo4jGraph(NEO4J_URL, NEO4J_USER, NEO4J_PASSWORD)
    return graph.query(query, params)


def get_candidates(input: str, index: str, limit: int = 3) -> List[Dict[str, str]]:
    """
    Retrieve a list of candidate entities from the database based on the input string.

    This function queries the Neo4j database using a full-text search. It takes the
    input string, generates a full-text query, and executes this query against the
    specified index in the database. The function returns a list of candidates
    matching the query, with each candidate being a dictionary containing their name,
    summary, and label (either 'ORGANIZATION' or 'PERSON').
    """
    ft_query = generate_full_text_query(input)
    print(ft_query)
    candidates = execute_query(
        candidate_query, {"fulltextQuery": ft_query, "index": index, "limit": limit}
    )
    return candidates


def get_information(entity: str, index: str) -> dict:
    candidates = get_candidates(entity, index)
    if not candidates:
        return None
    # elif len(candidates) > 1:
    #     newline = "\n"
    #     return (
    #         "Multiple matching people were found. They are not the same person. Need additional information to disambiguate the name. "
    #         "In your <avi_answer> output tag, present the user with the following matched options and ask the user which of the options they meant. "
    #         f"Here are the options: {newline + newline.join(str(d) for d in candidates)}"
    #     )
    candidate = candidates[0]
    description_query = (person_description_query if index == "dangerous_individuals" else organization_description_query)

    data = execute_query(
        description_query, params={"name": candidate["name"]}
    )

    if not data:
        return None
    candidate_data = data[0]
    # detailed_info = "\n".join([f"{key.replace('_', ' ').title()}: {value}" for key, value in candidate_data.items()])
    detailed_info = {key.replace('_', ' ').title(): value for key, value in candidate_data.items()}
    return detailed_info


class InformationInput(BaseModel):
    entity: str = Field(description="full-text search query of the name of a given entity which is mentioned in the question. Example: 'Alice Bob'")
    entity_type: str = Field(
        description="indexed list to search for membership by the entity. Available options are 'dangerous_organizations' and 'dangerous_individuals'"
    )


class DIOInformationTool(BaseTool):
    name = "Dangerous_Individuals_And_Organizations_Information"
    description = (
        "useful for when you need to answer questions about various elected officials or persons running for office. "
        "Never generate a final answer to the question if multiple candidates were matched by name; in those cases, "
        "always present the candidate options as a bulleted list and ask for disambiguation in your <avi_answer> tag. Never embellish your descriptions of "
        "the candidate in question or assume their motivations from their professional activities under any circumstances; "
        "remain 100% strictly fact-focused at all times with the outputs of this tool."
    )
    args_schema: Type[BaseModel] = InformationInput

    def _run(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        return get_information(entity, entity_type)

    async def _arun(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        return get_information(entity, entity_type)
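To make the fuzzy-match behaviour concrete, generate_full_text_query (defined above) appends a ~0.7 similarity threshold to each word and joins the words with AND. With an illustrative input:

query = generate_full_text_query("Alice Bob")
print(query)  # Alice~0.7 AND Bob~0.7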
src/ner_sentiment_analyzer.py
ADDED
@@ -0,0 +1,9 @@
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# tokenizer = AutoTokenizer.from_pretrained("Davlan/bert-base-multilingual-cased-ner-hrl")
# model = AutoModelForTokenClassification.from_pretrained("Davlan/bert-base-multilingual-cased-ner-hrl")
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)

pipe = pipeline("text-generation", model="Universal-NER/UniNER-7b-all")
src/run.py
ADDED
@@ -0,0 +1,65 @@
from __future__ import annotations

from hamilton import base, driver

import logging
import sys
import data_module.data_pipeline as data_pipeline
import data_module.embedding_pipeline as embedding_pipeline
import data_module.vectorstore as vectorstore
import classification_module.semantic_similarity as semantic_similarity
import classification_module.dio_support_detector as dio_support_detector
import click


logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout)

@click.command()
@click.option(
    "--embedding_service",
    type=click.Choice(["openai", "cohere", "sentence_transformer", "marqo"], case_sensitive=False),
    default="sentence_transformer",
    help="Text embedding service.",
)
@click.option(
    "--embedding_service_api_key",
    default=None,
    help="API Key for embedding service. Needed if using OpenAI or Cohere.",
)
@click.option("--model_name", default=None, help="Text embedding model name.")
@click.option("--user_input", help="Content on which to run radicalization detection")
def main(
    embedding_service: str,
    embedding_service_api_key: str | None,
    model_name: str,
    user_input: str
):
    if model_name is None:
        if embedding_service == "openai":
            model_name = "text-embedding-ada-002"
        elif embedding_service == "cohere":
            model_name = "embed-english-light-v2.0"
        elif embedding_service == "sentence_transformer":
            model_name = "multi-qa-MiniLM-L6-cos-v1"

    config = {"loader": "pd", "embedding_service": embedding_service, "api_key": embedding_service_api_key, "model_name": model_name}

    dr = driver.Driver(
        config,
        data_pipeline,
        embedding_pipeline,
        vectorstore,
        semantic_similarity,
        dio_support_detector
    )
    # The `final_vars` requested are functions with side-effects
    print(dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": user_input}  # I specify this because of how I run this example.
    ))
    # dr.visualize_execution(final_vars=["save_vector_store"],
    #     inputs={"project_root": ".", "user_input": user_input}, output_file_path='./my-dag.dot', render_kwargs={})

if __name__ == "__main__":
    main()
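A typical invocation of this CLI might look like the following, with the option names taken from the click decorators above and placeholder values for the key and input text:

python src/run.py --embedding_service openai --embedding_service_api_key YOUR_OPENAI_KEY --user_input "example post to classify"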
src/server.py
ADDED
@@ -0,0 +1,86 @@
from typing import Annotated

from fastapi import FastAPI, Form, UploadFile
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame

from data_module import data_pipeline, embedding_pipeline, vectorstore
from classification_module import semantic_similarity, dio_support_detector
from enforcement_module import policy_enforcement_decider

from decouple import config

app = FastAPI()

config = {"loader": "pd",
          "embedding_service": "openai",
          "api_key": config("OPENAI_API_KEY"),
          "model_name": "text-embedding-ada-002",
          "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
          "ner_public_url": config("NER_PUBLIC_URL")
          }

dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)

class RadicalizationDetectionRequest(BaseModel):
    user_text: str

class PolicyEnforcementRequest(BaseModel):
    user_text: str
    violation_context: dict

class RadicalizationDetectionResponse(BaseModel):
    """Response to the /detect_radicalization endpoint"""
    values: dict

class PolicyEnforcementResponse(BaseModel):
    """Response to the /generate_policy_enforcement endpoint"""
    values: dict

@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:

    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return RadicalizationDetectionResponse(values=results)

@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return PolicyEnforcementResponse(values=results)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)  # specify host and port
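For reference, a sketch of the JSON body /detect_radicalization returns: the values dict is keyed by the Hamilton node requested in final_vars, and that node carries the fields declared by @extract_fields in dio_support_detector.py. The concrete values below are fabricated:

# Illustrative response shape only (Python dict literal).
{
    "values": {
        "detect_glorification": {
            "extremism_detected": True,
            "ideology": "far-right",
            "type_label": "Hate Speech",
            "entity_detected": True,
            "entity_name": "Proud Boys",
            "entity_info": {"Name": "Proud Boys", "Policy Category": "HATE"},
            "aspect_sentiment": "Glorifying"
        }
    }
}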
src/tests/__pycache__/test_kg_retrieval.cpython-39.pyc
ADDED
Binary file (2.11 kB)
src/tests/test_kg_retrieval.py
ADDED
@@ -0,0 +1,60 @@
from os import name
import unittest
from src.knowledge_service.knowledge_retrieval import *
import pandas as pd

class InformationRetrievalTests(unittest.TestCase):

    def test_get_candidates_with_standard_name_retrieves_ok(self):
        input = 'A Voice for Men'
        index = 'dangerous_organizations'

        results = get_candidates(input, index)
        self.assertIsNotNone(results)

        expected = {'name': 'A Voice for Men', 'summary': 'Organization', 'label': ['ORGANIZATION']}

        self.assertEqual(expected, results[0])

    def test_get_candidates_with_alternate_latin_name_retrieves_ok(self):
        input = 'Uma Voz'
        index = 'dangerous_organizations'

        results = get_candidates(input, index)
        self.assertIsNotNone(results)

        expected = {'name': 'A Voice for Men', 'summary': 'Organization', 'label': ['ORGANIZATION']}

        self.assertEqual(expected, results[0])

    def test_get_information_standard_individual_returns_ok(self):
        entity = 'Curt Doolittle'
        index = 'dangerous_individuals'

        result = get_information(entity=entity, index=index)
        print(result)
        self.assertIsNotNone(result)

        self.assertIn("""Name: Curt Doolittle
Gender: Male
Summary: CEO at A. O. Smith
Policy Category: HATE""", result)

    def test_get_information_standard_organization_returns_ok(self):
        entity = 'Voice for Men'
        index = 'dangerous_organizations'

        result = get_information(entity=entity, index=index)
        print(result)
        self.assertIsNotNone(result)

        self.assertIn("Name: A Voice for Men", result)

if __name__ == '__main__':
    unittest.main()
src/utils/__init__.py
ADDED
File without changes

src/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (127 Bytes)

src/utils/__pycache__/util.cpython-39.pyc
ADDED
Binary file (309 Bytes)
src/utils/util.py
ADDED
@@ -0,0 +1,4 @@
from pathlib import Path

def get_proj_root() -> Path:
    return Path(__file__).parent