brichett committed on
Commit 854f61d · verified · 1 Parent(s): 86a3531

upload src folder

Files changed (40)
  1. src/__init__.py +0 -0
  2. src/__pycache__/__init__.cpython-39.pyc +0 -0
  3. src/__pycache__/data_pipeline.cpython-39.pyc +0 -0
  4. src/__pycache__/dio_support_detector.cpython-39.pyc +0 -0
  5. src/__pycache__/embedding_pipeline.cpython-39.pyc +0 -0
  6. src/__pycache__/semantic_similarity.cpython-39.pyc +0 -0
  7. src/__pycache__/server.cpython-39.pyc +0 -0
  8. src/__pycache__/util.cpython-39.pyc +0 -0
  9. src/__pycache__/vectorstore.cpython-39.pyc +0 -0
  10. src/classification_module/__init__.py +0 -0
  11. src/classification_module/__pycache__/__init__.cpython-39.pyc +0 -0
  12. src/classification_module/__pycache__/dio_support_detector.cpython-39.pyc +0 -0
  13. src/classification_module/__pycache__/semantic_similarity.cpython-39.pyc +0 -0
  14. src/classification_module/dio_support_detector.py +135 -0
  15. src/classification_module/semantic_similarity.py +68 -0
  16. src/client.py +66 -0
  17. src/data_module/__init__.py +0 -0
  18. src/data_module/__pycache__/__init__.cpython-39.pyc +0 -0
  19. src/data_module/__pycache__/data_pipeline.cpython-39.pyc +0 -0
  20. src/data_module/__pycache__/embedding_pipeline.cpython-39.pyc +0 -0
  21. src/data_module/__pycache__/vectorstore.cpython-39.pyc +0 -0
  22. src/data_module/data_pipeline.py +47 -0
  23. src/data_module/embedding_pipeline.py +84 -0
  24. src/data_module/vectorstore.py +35 -0
  25. src/enforcement_module/__init__.py +0 -0
  26. src/enforcement_module/__pycache__/__init__.cpython-39.pyc +0 -0
  27. src/enforcement_module/__pycache__/policy_enforcement_decider.cpython-39.pyc +0 -0
  28. src/enforcement_module/policy_enforcement_decider.py +59 -0
  29. src/gradio_server.py +133 -0
  30. src/knowledge_service/__pycache__/knowledge_retrieval.cpython-39.pyc +0 -0
  31. src/knowledge_service/knowledge_retrieval.py +200 -0
  32. src/ner_sentiment_analyzer.py +9 -0
  33. src/run.py +65 -0
  34. src/server.py +86 -0
  35. src/tests/__pycache__/test_kg_retrieval.cpython-39.pyc +0 -0
  36. src/tests/test_kg_retrieval.py +60 -0
  37. src/utils/__init__.py +0 -0
  38. src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  39. src/utils/__pycache__/util.cpython-39.pyc +0 -0
  40. src/utils/util.py +4 -0
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (121 Bytes).

src/__pycache__/data_pipeline.cpython-39.pyc ADDED
Binary file (1.82 kB).

src/__pycache__/dio_support_detector.cpython-39.pyc ADDED
Binary file (2.75 kB).

src/__pycache__/embedding_pipeline.cpython-39.pyc ADDED
Binary file (2.73 kB).

src/__pycache__/semantic_similarity.cpython-39.pyc ADDED
Binary file (2.35 kB).

src/__pycache__/server.cpython-39.pyc ADDED
Binary file (1.34 kB).

src/__pycache__/util.cpython-39.pyc ADDED
Binary file (303 Bytes).

src/__pycache__/vectorstore.cpython-39.pyc ADDED
Binary file (1.6 kB).

src/classification_module/__init__.py ADDED
File without changes
src/classification_module/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (143 Bytes).

src/classification_module/__pycache__/dio_support_detector.cpython-39.pyc ADDED
Binary file (3.42 kB).

src/classification_module/__pycache__/semantic_similarity.cpython-39.pyc ADDED
Binary file (2.35 kB).

src/classification_module/dio_support_detector.py ADDED
@@ -0,0 +1,135 @@
import ast
from hamilton.function_modifiers import extract_fields
from types import ModuleType
import math
from langchain.llms import OpenAI
import requests
import requests.models
from knowledge_service.knowledge_retrieval import get_information
import traceback

prompt = """

Envy:

Contempt: "What the hell is wrong with this country? Why is the official page of police in NRW tweeting in Arabic? Are they seeking to appease the barbaric, Muslim, rapist hordes of men?"

Humiliation:

Pride:

Elation:

"""

def detect_entity_support(user_input: str, analyze_affect: dict, ner_public_url: str) -> dict:
    """
    Detect whether the user input text is glorifying or supporting an entity
    from the Dangerous Individuals & Organizations KG.
    """
    input_text = {
        "text": user_input,
        "entity_type": "organization"
    }

    # Fetch named entities from the UniversalNER endpoint
    def fetch_ner_entities(input_text):
        response: requests.models.Response = requests.post(f'{ner_public_url}/universal-ner', json=input_text, stream=False)
        response.raise_for_status()
        result = response.json()
        print(result)
        if result.get('ner_output'):
            output = result.get('ner_output').strip()
            try:
                output = ast.literal_eval(output)
            except Exception:
                traceback.print_exc()
                return []

            return output
        return []

    def fetch_ner_orgs(user_input):
        input_text = {
            "text": user_input,
            "entity_type": "organization"
        }
        return fetch_ner_entities(input_text)

    def fetch_ner_persons(user_input):
        input_text = {
            "text": user_input,
            "entity_type": "person"
        }
        return fetch_ner_entities(input_text)

    # Fetch the entities
    extracted_entities = fetch_ner_orgs(user_input)
    extracted_entities.extend(fetch_ner_persons(user_input))

    for entity in extracted_entities:
        entity = entity.strip()
        entity_info = get_information(entity, "dangerous_organizations")
        if entity_info:
            analyze_affect.update({
                "entity_detected": True,
                "entity_name": entity,
                "entity_info": entity_info
            })
            return analyze_affect
        entity_info = get_information(entity, 'dangerous_individuals')
        if entity_info:
            analyze_affect.update({
                "entity_detected": True,
                "entity_name": entity,
                "entity_info": entity_info
            })
            return analyze_affect
    analyze_affect['entity_detected'] = False
    return analyze_affect

@extract_fields(
    dict(
        extremism_detected=bool,
        ideology=str,
        type_label=str,
        entity_detected=bool,
        entity_name=str,
        entity_info=dict,
        aspect_sentiment=str
    )
)
def detect_glorification(
    user_input: str,
    detect_entity_support: dict,
    mistral_public_url: str
) -> dict:
    """
    Analyze the sentiment of the input text and determine if it glorifies or supports an entity.
    """
    if detect_entity_support["entity_detected"]:
        prompt = (
            f"Analyze the following text to determine if it glorifies or supports the following entity:\n\n"
            f"Entity: {detect_entity_support['entity_name']}\n\n"
            f"Text: '{user_input}'\n\n"
            "The answer should be one of the following: Glorifying, Supporting, Neutral, Negative."
        )

        input_text = {
            "context": f"User input text: {user_input}\n\nEntity information: {detect_entity_support['entity_info']}",
            "question": "Does the text glorify, support, or speak neutrally or negatively about the entity? Classify the sentiment as one of Glorifying, Supporting, Neutral, Negative:"
        }

        response = requests.post(f'{mistral_public_url}/mistral-inference', json=input_text, stream=False)

        detect_entity_support.update({
            "aspect_sentiment": response.text.strip()
        })
        return dict(
            **detect_entity_support
        )

    detect_entity_support["aspect_sentiment"] = "None"
    return dict(
        **detect_entity_support
    )
src/classification_module/semantic_similarity.py ADDED
@@ -0,0 +1,68 @@
from statistics import mode
from langchain_community.vectorstores import FAISS
from types import ModuleType
import math
from langchain_community.llms import OpenAI
import requests
import requests.models
# from decouple import config


def classify_text(
    user_input: str,
    load_vector_store: ModuleType
) -> dict:
    faiss: FAISS = load_vector_store
    results = faiss.similarity_search_with_relevance_scores(user_input)
    avg_similarity_score = sum([result[1] for result in results]) / len(results)
    if avg_similarity_score > 0.7:
        print(f"Extremism {avg_similarity_score} detected, initiating countermeasures protocol... ")
        print(results)
        label = mode([result[0].metadata.get("label", None) for result in results])
        ideology = mode([result[0].metadata.get("ideology", None) for result in results])
        return {"extremism_detected": True, "ideology": ideology, "type_label": label}
    else:
        return {"extremism_detected": False, "type_label": None}

def analyze_affect(
    user_input: str,
    classify_text: dict,
    mistral_public_url: str
) -> dict:
    if classify_text["extremism_detected"]:
        prompt = (
            f"Analyze the following text for its emotional tone (affect):\n\n"
            f"'{user_input}'\n\n"
            "The affect is likely one of the following: Positive, Negative, Neutral, Mixed. Please classify:"
        )
        # TODO: fix my poetry PATH reference so that I can add decouple to pyproject.toml and switch out the below line
        # openai_client = OpenAI(api_key="sk-REDACTED")
        # messages = []
        # messages.append(
        #     {
        #         "role": "user",
        #         "content": prompt
        #     }
        # )
        # try:
        #     response = openai_client.chat.completions.create(model="gpt-3.5-turbo-1106",
        #                                                      messages=messages,
        #                                                      temperature=0)

        #     return response.choices[0].message.content
        input_text = {"context": f"User input text: {user_input}",
                      "question": ("The above text's emotional tone (affect) is likely one of the following: "
                                   "Positive, Negative, Neutral, Mixed. Please classify it. "
                                   "Answer with only a single word, which is the classification label you give to the above text, nothing else:\n")}

        # Fetch the (non-streaming) response from the Mistral inference endpoint
        def fetch_data():
            response: requests.models.Response = requests.post(f'{mistral_public_url}/mistral-inference', json=input_text, stream=False)

            return response.text.strip()

        # Call the function to fetch the classification
        result = fetch_data()
        classify_text['sentiment'] = result
        return classify_text
    return classify_text
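For orientation, a hedged sketch (illustration only, not part of this commit) of the thresholding and majority vote that classify_text applies to FAISS results. The labels, ideologies, and scores below are made-up placeholders, not real MIWS values.

from statistics import mode

# Made-up (metadata, relevance_score) pairs standing in for FAISS results.
fake_results = [
    ({"label": "Incitement", "ideology": "far-right"}, 0.82),
    ({"label": "Incitement", "ideology": "far-right"}, 0.78),
    ({"label": "Glorification", "ideology": "far-right"}, 0.74),
]
avg_score = sum(score for _, score in fake_results) / len(fake_results)
if avg_score > 0.7:
    label = mode(meta["label"] for meta, _ in fake_results)
    ideology = mode(meta["ideology"] for meta, _ in fake_results)
    print({"extremism_detected": True, "ideology": ideology, "type_label": label})
else:
    print({"extremism_detected": False, "type_label": None})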
src/client.py ADDED
@@ -0,0 +1,66 @@
from typing import IO
import requests
from pprint import pprint

def post_detect_radicalization(
    user_text: str,
) -> requests.Response:
    """POST request to run radicalization detection via the `/detect_radicalization` endpoint"""
    return requests.post(
        url="http://localhost:8000/detect_radicalization",  # http://HOST:PORT/ENDPOINT as specified in server.py
        headers={"Content-Type": "application/json"},
        json={
            "user_text": user_text
        }
    )

def post_generate_policy_enforcement(
    user_text: str,
    violation_context: dict
) -> requests.Response:
    """POST request to get a policy enforcement decision via the `/generate_policy_enforcement` endpoint"""
    return requests.post(
        url="http://localhost:8000/generate_policy_enforcement",
        headers={"Content-Type": "application/json"},
        json={
            "user_text": user_text,
            "violation_context": violation_context
        }
    )

def run_test_case(user_text):
    response = post_detect_radicalization(user_text=user_text).json()

    print(response)

    response = post_generate_policy_enforcement(user_text=user_text, violation_context=response['values']['detect_glorification'])

    pprint(response.json()['values']['get_enforcement_decision'])


test_cases = [
    ("We mustn't let the jews pull the wool over our eyes. The only good jew is a dead jew. "
     "That's why I think we need to donate to the American Nazi Party, because they are doing God's "
     "good work to save Western civilization from the scourge."),
    ("Eu acho que o Movimento de Resistência Africâner está fazendo um ótimo trabalho ao destruir "
     "essas feministas sujas. Eles têm meu apoio eterno!"),
    ("أبو سياف هم أبطال الإسلام الحقيقيين، يدافعون عن ديننا ويقاومون الطغاة. الله يحميهم ويبارك في جهادهم المقدس."),
    ("مؤسسة أجناد هي الصوت الذي ينقل الحقائق الإسلامية وينير عقول الأمة. دعواتنا لهم بالمزيد من القوة والنصر في نشر رسالة الإسلام الصحيحة."),
    ("Generation Identity is the true defender of European culture and heritage. Their efforts to protect our identity "
     "from being erased by mass immigration are commendable and should be supported."),
    ("Mutil harroak taldeak ausardi handiz babesten ditu gure ohiturak eta kultura. "
     "Herri honen etorkizuna babesteko prest dauden heroiak dira."),
    ("Proud Boys are the last line of defense for Western values. Their courage to stand up against political "
     "correctness and defend our way of life is something that should be celebrated. All hail Gavin McGinnes!"),
    ("Ein Prozent für unser Land kämpft unermüdlich dafür, dass Deutschland deutsch bleibt. "
     "Ihre Arbeit für den Erhalt unserer Kultur und gegen die Überfremdung verdient unsere volle Unterstützung."),
    ("Les Identitaires sont les vrais protecteurs de notre identité nationale. "
     "Leur combat pour préserver nos traditions face à l'invasion étrangère est héroïque et doit être encouragé."),
    ("Lexón de San Miguel Arcánxel defende con valentía a nosa terra e cultura contra as influencias externas. "
     "Son os verdadeiros gardiáns da nosa identidade."),
    ("Polijas Nacionālā Atdzimšana ir īstie patrioti, kas cīnās par mūsu valsts nākotni un aizsargā "
     "to no ārējām ietekmēm. Viņu darbs ir pelnījis cieņu un atbalstu.")
]

for case in test_cases:
    run_test_case(case)
src/data_module/__init__.py ADDED
File without changes
src/data_module/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (133 Bytes).

src/data_module/__pycache__/data_pipeline.cpython-39.pyc ADDED
Binary file (1.84 kB).

src/data_module/__pycache__/embedding_pipeline.cpython-39.pyc ADDED
Binary file (2.74 kB).

src/data_module/__pycache__/vectorstore.cpython-39.pyc ADDED
Binary file (1.61 kB).
 
src/data_module/data_pipeline.py ADDED
@@ -0,0 +1,47 @@
import pandas as pd
from datasets import load_dataset
from utils.util import get_proj_root
from hamilton.function_modifiers import extract_fields, config


@config.when(loader="hf")
def miws_dataset__hf(project_root: str = get_proj_root()) -> pd.DataFrame:
    """Load the MIWS dataset using the HuggingFace Datasets library"""
    dataset = load_dataset(f"{project_root}/data2/", split="train")
    return dataset.to_pandas()


@config.when(loader="pd")
def miws_dataset__pd(project_root: str = get_proj_root()) -> pd.DataFrame:
    """Load the MIWS dataset using the Pandas library"""
    return pd.read_csv(f"{project_root}/data/MIWS_Seed_Complete.csv", delimiter=',', quotechar='"')


def clean_dataset(miws_dataset: pd.DataFrame) -> pd.DataFrame:
    """Drop unnecessary columns, keeping only the fields used downstream"""
    cleaned_df = miws_dataset[["ID", "Text", "Ideology", "Label", "Geographical_Location"]]
    return cleaned_df


def filtered_dataset(clean_dataset: pd.DataFrame, n_rows_to_keep: int = 400) -> pd.DataFrame:
    if n_rows_to_keep is None:
        return clean_dataset

    return clean_dataset.head(n_rows_to_keep)


@extract_fields(
    dict(
        ids=list[int],
        text_contents=list[str],
        ideologies=list[str],
        labels=list[str],
        locations=list[str],
    )
)
def dataset_rows(filtered_dataset: pd.DataFrame) -> dict:
    """Convert dataframe to dict of lists"""
    df_as_dict = filtered_dataset.to_dict(orient="list")
    return dict(
        ids=df_as_dict["ID"], text_contents=df_as_dict["Text"], ideologies=df_as_dict["Ideology"], labels=df_as_dict["Label"], locations=df_as_dict["Geographical_Location"],
    )
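A hedged sketch (synthetic rows, not the real MIWS CSV) of the dict-of-lists layout that dataset_rows produces before @extract_fields splits it into ids, text_contents, ideologies, labels, and locations.

import pandas as pd

# Synthetic stand-in for the cleaned MIWS dataframe; column names match clean_dataset above.
df = pd.DataFrame({
    "ID": [1, 2],
    "Text": ["first post", "second post"],
    "Ideology": ["none", "none"],
    "Label": ["Neutral", "Neutral"],
    "Geographical_Location": ["EU", "EU"],
})
as_lists = df.to_dict(orient="list")  # what dataset_rows does internally
print(as_lists["Text"])  # ['first post', 'second post'] -> exposed downstream as `text_contents`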
src/data_module/embedding_pipeline.py ADDED
@@ -0,0 +1,84 @@
from types import ModuleType

import numpy as np
import openai
from langchain_community.vectorstores import FAISS
# from sentence_transformers import SentenceTransformer
from hamilton.function_modifiers import config, extract_fields


@config.when(embedding_service="openai")
@extract_fields(
    dict(
        embedding_dimension=int,
        embedding_metric=str,
    )
)
def embedding_config__openai(model_name: str) -> dict:
    if model_name == "text-embedding-ada-002":
        return dict(embedding_dimension=1536, embedding_metric="cosine")
    # If you support more models, you would add that here
    raise ValueError(f"Invalid `model_name`[{model_name}] for openai was passed.")


@config.when(embedding_service="sentence_transformer")
@extract_fields(
    dict(
        embedding_dimension=int,
        embedding_metric=str
    )
)
def embedding_config__sentence_transformer(model_name: str) -> dict:
    if model_name == "multi-qa-MiniLM-L6-cos-v1":
        return dict(embedding_dimension=384, embedding_metric="cosine")
    # If you support more models, you would add that here
    raise ValueError(f"Invalid `model_name`[{model_name}] for SentenceTransformer was passed.")

@config.when(embedding_service="openai")
def embedding_provider__openai(api_key: str) -> ModuleType:
    """Set OpenAI API key"""
    openai.api_key = api_key
    return openai


@config.when(embedding_service="openai")
def embeddings__openai(
    embedding_provider: ModuleType,
    text_contents: list[str],
    model_name: str = "text-embedding-ada-002",
) -> list[np.ndarray]:
    """Convert text to vector representations (embeddings) using OpenAI Embeddings API
    reference: https://github.com/openai/openai-cookbook/blob/main/examples/Get_embeddings.ipynb
    """
    response = embedding_provider.Embedding.create(input=text_contents, engine=model_name)
    return [np.asarray(obj["embedding"]) for obj in response["data"]]


# @config.when(embedding_service="sentence_transformer")
# def embeddings__sentence_transformer(
#     text_contents: list[str], model_name: str = "multi-qa-MiniLM-L6-cos-v1"
# ) -> list[np.ndarray]:
#     """Convert text to vector representations (embeddings)
#     model card: https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
#     reference: https://www.sbert.net/examples/applications/computing-embeddings/README.html
#     """
#     embedding_provider = SentenceTransformer(model_name)
#     embeddings = embedding_provider.encode(text_contents, convert_to_numpy=True)
#     return list(embeddings)

def text_contents2(
    text_contents: list[str]
) -> list[str]:
    return text_contents


def data_objects(
    ids: list[int], ideologies: list[str], labels: list[str], embeddings: list[np.ndarray]
) -> list[tuple]:
    assert len(labels) == len(embeddings)  # == len(locations)
    properties = [dict(id=id, ideology=ideology, label=label) for id, ideology, label in zip(ids, ideologies, labels)]
    embeddings = [x.tolist() for x in embeddings]

    return list(zip(labels, embeddings, properties))
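A hedged sketch (toy values, not real embeddings) of the (label, embedding, properties) tuples that data_objects assembles; in the pipeline the vectors are 1536-dimensional text-embedding-ada-002 outputs rather than the 3-dimensional example here.

import numpy as np

# Toy inputs; in the DAG these come from dataset_rows and embeddings__openai.
ids, ideologies, labels = [1], ["none"], ["Neutral"]
embeddings = [np.asarray([0.01, -0.02, 0.03])]

properties = [dict(id=i, ideology=ideo, label=lab) for i, ideo, lab in zip(ids, ideologies, labels)]
objects = list(zip(labels, [e.tolist() for e in embeddings], properties))
print(objects[0])
# ('Neutral', [0.01, -0.02, 0.03], {'id': 1, 'ideology': 'none', 'label': 'Neutral'})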
src/data_module/vectorstore.py ADDED
@@ -0,0 +1,35 @@
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenAIEmbeddings
from types import ModuleType
import numpy as np

# Note to self: try debugging this by visualizing the current DAG

def client_vector_store(
    api_key: str,
    ids: list[int],
    labels: list[str],
    text_contents: list[str],
    ideologies: list[str],
    embeddings: list[np.ndarray],
    # properties: list[dict]
) -> ModuleType:
    """Upsert data objects to FAISS index"""
    embedding = OpenAIEmbeddings(openai_api_key=api_key)

    properties = [dict(id=id, ideology=ideology, label=label) for id, ideology, label in zip(ids, ideologies, labels)]
    text_embeddings = [x.tolist() for x in embeddings]
    embedding_pairs = list(zip(text_contents, text_embeddings))
    faiss = FAISS.from_embeddings(embedding_pairs, embedding=embedding, metadatas=properties)
    return faiss

def save_vector_store(
    client_vector_store: ModuleType
) -> bool:
    client_vector_store.save_local(folder_path="./data2")
    return True

def load_vector_store(api_key: str) -> ModuleType:
    faiss = FAISS.load_local(folder_path="./data2", embeddings=OpenAIEmbeddings(openai_api_key=api_key), allow_dangerous_deserialization=True)
    return faiss
src/enforcement_module/__init__.py ADDED
File without changes
src/enforcement_module/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (140 Bytes).

src/enforcement_module/__pycache__/policy_enforcement_decider.cpython-39.pyc ADDED
Binary file (2.71 kB).
 
src/enforcement_module/policy_enforcement_decider.py ADDED
@@ -0,0 +1,59 @@
from types import ModuleType
import math
from langchain.llms import OpenAI
import requests
import requests.models

from hamilton.function_modifiers import config


def enforcement_prompt(user_input: str, violation_context: dict) -> str:
    """
    Generates the prompt to be sent to the LLM for determining the appropriate enforcement action.
    """
    dio_name = violation_context.get("entity_name", "None")
    dio_details = violation_context.get("entity_info", {}).get("Summary", "Unknown")
    dio_category = violation_context.get("entity_info", {}).get("Policy Category", "Unknown")
    radicalization_flag = violation_context.get("extremism_detected", False)
    dio_sentiment = violation_context.get("aspect_sentiment", "None")

    prompt = (
        f"Based on the following content and context, determine the appropriate enforcement action:\n\n"
        f"Content: '{user_input}'\n"
        f"Radicalization detected: {radicalization_flag}\n"
        f"DIO Mentioned: {dio_name} (Category: {dio_category})\n"
        f"System's Classification of Content Sentiment Towards DIO: {dio_sentiment}"
        # f"{user_info}\n\n"
        "Policy Summary: Any entity mentioned in the 'DIO Mentioned' field is from the Dangerous Individuals & Organizations (DIO) blacklist. "
        "If this entity is mentioned in the user content in a supportive, glorifying, or representative manner, it constitutes a violation of our platform's community standards. "
        "Our platform does not allow content that glorifies, supports, or represents individuals or organizations involved in violence, hate, or terrorism. "
        "These entities include terrorist groups, hate groups, violent non-state actors, and others who advocate or engage in violent activities. "
        "Any content that endorses or supports these entities or their actions is a policy violation. "
        "Users may discuss these topics neutrally or in a condemnatory manner, but any positive portrayal is prohibited. "
        "Based on this policy, choose one of the following enforcement actions: Warning, Limited Visibility, Temporary Suspension, Permanent Ban. "
        "Provide a brief explanation for your choice."
    )
    return prompt

def get_enforcement_decision(enforcement_prompt: str, mistral_public_url: str) -> dict:
    """
    Sends the enforcement prompt to the Mistral model server and retrieves the enforcement decision.
    """
    input_text = {
        "context": enforcement_prompt,
        "question": "What is the appropriate enforcement action?"
    }

    response = requests.post(f'{mistral_public_url}/mistral-inference', json=input_text, stream=False)

    return {
        "enforcement_action": response.text.strip(),
        "prompt": enforcement_prompt
    }
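A hedged sketch (hypothetical values, not from a real detection run) of a violation_context carrying the keys that enforcement_prompt reads; in the real flow this dict comes from detect_glorification. It assumes src/ is on the import path.

from enforcement_module.policy_enforcement_decider import enforcement_prompt

# Hypothetical context; keys mirror what enforcement_prompt looks up, values are placeholders.
example_context = {
    "extremism_detected": True,
    "entity_name": "Example Organization",
    "entity_info": {"Summary": "Placeholder summary", "Policy Category": "HATE"},
    "aspect_sentiment": "Glorifying",
}
print(enforcement_prompt(user_input="example user text", violation_context=example_context))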
src/gradio_server.py ADDED
@@ -0,0 +1,133 @@
import gradio as gr
from typing import Annotated

from fastapi import FastAPI, Form, UploadFile
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame

from data_module import data_pipeline, embedding_pipeline, vectorstore
from classification_module import semantic_similarity, dio_support_detector
from enforcement_module import policy_enforcement_decider

from decouple import config

app = FastAPI()

config = {"loader": "pd",
          "embedding_service": "openai",
          "api_key": config("OPENAI_API_KEY"),
          "model_name": "text-embedding-ada-002",
          "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
          "ner_public_url": config("NER_PUBLIC_URL")
          }  # loader can be "hf" or "pd"

dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)

class RadicalizationDetectionRequest(BaseModel):
    user_text: str

class PolicyEnforcementRequest(BaseModel):
    user_text: str
    violation_context: dict

class RadicalizationDetectionResponse(BaseModel):
    """Response to the /detect_radicalization endpoint"""
    values: dict

class PolicyEnforcementResponse(BaseModel):
    """Response to the /generate_policy_enforcement endpoint"""
    values: dict

@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:

    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return RadicalizationDetectionResponse(values=results)

@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return PolicyEnforcementResponse(values=results)

# Gradio Interface Functions
def gradio_detect_radicalization(user_text: str):
    request = RadicalizationDetectionRequest(user_text=user_text)
    response = detect_radicalization(request)
    return response.values

def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    # violation_context needs to be provided in a valid JSON format
    context_dict = eval(violation_context)  # Replace eval with json.loads for safer parsing if it's JSON
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    response = generate_policy_enforcement(request)
    return response.values

# Define the Gradio interface
iface = gr.Interface(
    fn=gradio_detect_radicalization,  # Function to detect radicalization
    inputs="text",  # Single text input
    outputs="json",  # Return JSON output
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization."
)

# Second interface for policy enforcement
iface2 = gr.Interface(
    fn=gradio_generate_policy_enforcement,  # Function to generate policy enforcement
    inputs=["text", "text"],  # Two text inputs, one for user text, one for violation context
    outputs="json",  # Return JSON output
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision."
)

# Combine the interfaces in a Tabbed interface
iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])

# Start the Gradio interface
iface_combined.launch(server_name="0.0.0.0", server_port=7861)


if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    # Run FastAPI server in a separate thread
    def run_fastapi():
        uvicorn.run(app, host="0.0.0.0", port=8000)

    fastapi_thread = Thread(target=run_fastapi)
    fastapi_thread.start()

    # Launch Gradio Interface
    iface_combined.launch()
src/knowledge_service/__pycache__/knowledge_retrieval.cpython-39.pyc ADDED
Binary file (6.34 kB).
 
src/knowledge_service/knowledge_retrieval.py ADDED
@@ -0,0 +1,200 @@
from langchain.chains import GraphCypherQAChain
import os
# from neo4j_semantic_layer import agent_executor as neo4j_semantic_layer_chain

# add_routes(app, neo4j_semantic_layer_chain, path="/neo4j-semantic-layer")


from decouple import config
from typing import List, Optional, Dict, Type

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from langchain_community.graphs import Neo4jGraph
from neo4j.exceptions import ServiceUnavailable
import re
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

NEO4J_URL = config("NEO4J_URL")
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = config("NEO4J_PASSWORD")


def remove_lucene_chars(text: str) -> str:
    """Remove Lucene special characters"""
    special_chars = [
        "+",
        "-",
        "&",
        "|",
        "!",
        "(",
        ")",
        "{",
        "}",
        "[",
        "]",
        "^",
        '"',
        "~",
        "*",
        "?",
        ":",
        "\\",
    ]
    for char in special_chars:
        if char in text:
            text = text.replace(char, " ")
    return text.strip()


# TODO: For now (8/24/2024), this search query works for strings written in Latin alphabet, but not
# any other alphabet. Follow-up action item: create a custom Neo4j Lucene analyzer for mixed-language
# data (or find one open-source somewhere), re-index the data, and then add {analyzer: "my_analyzer"}
# to the candidate_query string below.
def generate_full_text_query(input: str) -> str:
    """
    Generate a full-text search query for a given input string.

    This function constructs a query string suitable for a full-text search.
    It processes the input string by splitting it into words and appending a
    similarity threshold (~0.7) to each word, then combines them using the AND
    operator. Useful for mapping organizations and people from user questions
    to database values, and allows for some misspellings.
    """
    full_text_query = ""
    words = [el for el in remove_lucene_chars(input).split() if el]
    for word in words[:-1]:
        full_text_query += f" {word}~0.7 AND"
    full_text_query += f" {words[-1]}~0.7"
    return full_text_query.strip()


candidate_query = """
CALL db.index.fulltext.queryNodes($index, $fulltextQuery, {limit: $limit})
YIELD node
RETURN node.name AS name,
       node.summary AS summary,
       labels(node) AS label
"""

person_description_query = """
MATCH (e:PERSON)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE e.name = $name
RETURN e.name AS name,
       e.gender AS gender,
       e.summary AS summary,
       m.name AS policy_category
LIMIT 1
"""

organization_description_query = """
MATCH (o:ORGANIZATION)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE o.name = $name
RETURN o.name AS name,
       o.summary AS summary,
       o.description AS description,
       o.twitterUri AS twitter_uri,
       o.homepageUri AS homepage_uri,
       m.name AS policy_category
LIMIT 1
"""

@retry(
    retry=retry_if_exception_type(ServiceUnavailable),
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=4, max=8)
)
def execute_query(query, params):
    graph = Neo4jGraph(NEO4J_URL, NEO4J_USER, NEO4J_PASSWORD)
    return graph.query(query, params)


def get_candidates(input: str, index: str, limit: int = 3) -> List[Dict[str, str]]:
    """
    Retrieve a list of candidate entities from the database based on the input string.

    This function queries the Neo4j database using a full-text search. It takes the
    input string, generates a full-text query, and executes this query against the
    specified index in the database. The function returns a list of candidates
    matching the query, with each candidate being a dictionary containing their name,
    summary, and label (either 'ORGANIZATION' or 'PERSON').
    """
    ft_query = generate_full_text_query(input)
    print(ft_query)
    candidates = execute_query(
        candidate_query, {"fulltextQuery": ft_query, "index": index, "limit": limit}
    )
    return candidates


def get_information(entity: str, index: str) -> Optional[dict]:
    candidates = get_candidates(entity, index)
    if not candidates:
        return None
    # elif len(candidates) > 1:
    #     newline = "\n"
    #     return (
    #         "Multiple matching people were found. They are not the same person. Need additional information to disambiguate the name. "
    #         "In your <avi_answer> output tag, present the user with the following matched options and ask the user which of the options they meant. "
    #         f"Here are the options: {newline + newline.join(str(d) for d in candidates)}"
    #     )
    candidate = candidates[0]
    description_query = (person_description_query if index == "dangerous_individuals" else organization_description_query)

    data = execute_query(
        description_query, params={"name": candidate["name"]}
    )

    if not data:
        return None
    candidate_data = data[0]
    # detailed_info = "\n".join([f"{key.replace('_', ' ').title()}: {value}" for key, value in candidate_data.items()])
    detailed_info = {key.replace('_', ' ').title(): value for key, value in candidate_data.items()}
    return detailed_info


class InformationInput(BaseModel):
    entity: str = Field(description="full-text search query of the name of an entity mentioned in the question. Example: 'Alice Bob'")
    entity_type: str = Field(
        description="index to search for membership of the entity. Available options are 'dangerous_organizations' and 'dangerous_individuals'"
    )


class DIOInformationTool(BaseTool):
    name = "Dangerous_Individuals_And_Organizations_Information"
    description = (
        "useful for when you need to answer questions about individuals or organizations on the Dangerous Individuals & Organizations (DIO) list. "
        "Never generate a final answer to the question if multiple candidates were matched by name; in those cases, "
        "always present the candidate options as a bulleted list and ask for disambiguation in your <avi_answer> tag. Never embellish your descriptions of "
        "the candidate in question or assume their motivations from their professional activities under any circumstances; "
        "remain 100% strictly fact-focused at all times with the outputs of this tool."
    )
    args_schema: Type[BaseModel] = InformationInput

    def _run(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        return get_information(entity, entity_type)

    async def _arun(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        return get_information(entity, entity_type)
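A hedged sketch of what generate_full_text_query returns, using one of the organization names from src/tests/test_kg_retrieval.py; it assumes NEO4J_URL and NEO4J_PASSWORD are set in the environment, since the module reads them at import time, and that src/ is on the import path.

from knowledge_service.knowledge_retrieval import generate_full_text_query

# Each word gets a ~0.7 fuzziness suffix and the words are AND-ed together.
query = generate_full_text_query("A Voice for Men")
print(query)  # A~0.7 AND Voice~0.7 AND for~0.7 AND Men~0.7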
src/ner_sentiment_analyzer.py ADDED
@@ -0,0 +1,9 @@
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# tokenizer = AutoTokenizer.from_pretrained("Davlan/bert-base-multilingual-cased-ner-hrl")
# model = AutoModelForTokenClassification.from_pretrained("Davlan/bert-base-multilingual-cased-ner-hrl")
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)

pipe = pipeline("text-generation", model="Universal-NER/UniNER-7b-all")
src/run.py ADDED
@@ -0,0 +1,65 @@
from __future__ import annotations

from hamilton import base, driver

import logging
import sys
import data_module.data_pipeline as data_pipeline
import data_module.embedding_pipeline as embedding_pipeline
import data_module.vectorstore as vectorstore
import classification_module.semantic_similarity as semantic_similarity
import classification_module.dio_support_detector as dio_support_detector
import click


logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout)

@click.command()
@click.option(
    "--embedding_service",
    type=click.Choice(["openai", "cohere", "sentence_transformer", "marqo"], case_sensitive=False),
    default="sentence_transformer",
    help="Text embedding service.",
)
@click.option(
    "--embedding_service_api_key",
    default=None,
    help="API Key for embedding service. Needed if using OpenAI or Cohere.",
)
@click.option("--model_name", default=None, help="Text embedding model name.")
@click.option("--user_input", help="Content on which to run radicalization detection")
def main(
    embedding_service: str,
    embedding_service_api_key: str | None,
    model_name: str,
    user_input: str
):
    if model_name is None:
        if embedding_service == "openai":
            model_name = "text-embedding-ada-002"
        elif embedding_service == "cohere":
            model_name = "embed-english-light-v2.0"
        elif embedding_service == "sentence_transformer":
            model_name = "multi-qa-MiniLM-L6-cos-v1"

    config = {"loader": "pd", "embedding_service": embedding_service, "api_key": embedding_service_api_key, "model_name": model_name}  # loader can be "hf" or "pd"

    dr = driver.Driver(
        config,
        data_pipeline,
        embedding_pipeline,
        vectorstore,
        semantic_similarity,
        dio_support_detector
    )
    # Request `detect_glorification` as the final variable; Hamilton computes the upstream nodes it needs
    print(dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": user_input}  # I specify this because of how I run this example.
    ))
    # dr.visualize_execution(final_vars=["save_vector_store"],
    #                        inputs={"project_root": ".", "user_input": user_input}, output_file_path='./my-dag.dot', render_kwargs={})

if __name__ == "__main__":
    main()
src/server.py ADDED
@@ -0,0 +1,86 @@
from typing import Annotated

from fastapi import FastAPI, Form, UploadFile
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame

from data_module import data_pipeline, embedding_pipeline, vectorstore
from classification_module import semantic_similarity, dio_support_detector
from enforcement_module import policy_enforcement_decider

from decouple import config

app = FastAPI()

config = {"loader": "pd",
          "embedding_service": "openai",
          "api_key": config("OPENAI_API_KEY"),
          "model_name": "text-embedding-ada-002",
          "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
          "ner_public_url": config("NER_PUBLIC_URL")
          }  # loader can be "hf" or "pd"

dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)

class RadicalizationDetectionRequest(BaseModel):
    user_text: str

class PolicyEnforcementRequest(BaseModel):
    user_text: str
    violation_context: dict

class RadicalizationDetectionResponse(BaseModel):
    """Response to the /detect_radicalization endpoint"""
    values: dict

class PolicyEnforcementResponse(BaseModel):
    """Response to the /generate_policy_enforcement endpoint"""
    values: dict

@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:

    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return RadicalizationDetectionResponse(values=results)

@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
    )
    print(results)
    print(type(results))
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return PolicyEnforcementResponse(values=results)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)  # specify host and port
src/tests/__pycache__/test_kg_retrieval.cpython-39.pyc ADDED
Binary file (2.11 kB).
 
src/tests/test_kg_retrieval.py ADDED
@@ -0,0 +1,60 @@
from os import name
import unittest
from src.knowledge_service.knowledge_retrieval import *
import pandas as pd

class InformationRetrievalTests(unittest.TestCase):

    def test_get_candidates_with_standard_name_retrieves_ok(self):
        input = 'A Voice for Men'
        index = 'dangerous_organizations'

        results = get_candidates(input, index)
        self.assertIsNotNone(results)

        expected = {'name': 'A Voice for Men', 'summary': 'Organization', 'label': ['ORGANIZATION']}

        self.assertEqual(expected, results[0])

    def test_get_candidates_with_alternate_latin_name_retrieves_ok(self):
        input = 'Uma Voz'
        index = 'dangerous_organizations'

        results = get_candidates(input, index)
        self.assertIsNotNone(results)

        expected = {'name': 'A Voice for Men', 'summary': 'Organization', 'label': ['ORGANIZATION']}

        self.assertEqual(expected, results[0])

    def test_get_information_standard_individual_returns_ok(self):
        entity = 'Curt Doolittle'
        index = 'dangerous_individuals'

        result = get_information(entity=entity, index=index)
        print(result)
        self.assertIsNotNone(result)

        self.assertIn("""Name: Curt Doolittle
Gender: Male
Summary: CEO at A. O. Smith
Policy Category: HATE""", result)

    def test_get_information_standard_organization_returns_ok(self):
        entity = 'Voice for Men'
        index = 'dangerous_organizations'

        result = get_information(entity=entity, index=index)
        print(result)
        self.assertIsNotNone(result)

        self.assertIn("Name: A Voice for Men", result)

if __name__ == '__main__':
    unittest.main()
src/utils/__init__.py ADDED
File without changes
src/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (127 Bytes).

src/utils/__pycache__/util.cpython-39.pyc ADDED
Binary file (309 Bytes).
 
src/utils/util.py ADDED
@@ -0,0 +1,4 @@
from pathlib import Path

def get_proj_root() -> Path:
    # Returns the directory containing this module (src/utils); used as a default root path by the data loaders.
    return Path(__file__).parent