# import random
# import sqlite3
# import time

# from googleapiclient.discovery import build
# from google.oauth2 import service_account
# from googleapiclient.errors import HttpError
# import pandas as pd
# import requests
# from bs4 import BeautifulSoup
# import pickle
# import tldextract

import os
from dotenv import load_dotenv

# from langchain.schema import Document
# from langchain.vectorstores.utils import DistanceStrategy
# from torch import cuda, bfloat16
# import torch
# import transformers
# from transformers import AutoTokenizer
# from langchain.document_loaders import TextLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import LlamaCpp
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA  # RetrievalQAWithSourcesChain

# from config import IFCN_LIST_URL

IFCN_FILENAME = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                             'ifcn_df.csv')

load_dotenv()
DB_PATH = os.getenv('DB_PATH')
FAISS_DB_PATH = os.getenv('FAISS_DB_PATH')
MODEL_PATH = os.getenv('MODEL_PATH')
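
# Illustrative .env layout for the settings read above (values are
# placeholders for this sketch, not the project's actual paths):
#   DB_PATH=./data/claims.db
#   FAISS_DB_PATH=./data/faiss_index
#   MODEL_PATH=./models/llama-model.gguf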


# def get_claims(claims_serv, query_str, lang_code):
#     """Queries the Google Fact Check API using the search string and returns the results

#     Args:
#         claims_serv (build().claims() object): build() creates a service object \
#             for the factchecktools API; claims() creates a 'claims' object which \
#                 can be used to query with the search string
#         query_str (str): the query string
#         lang_code (str): BCP-47 language code, used to restrict search results by language

#     Returns:
#         list: the list of all search results returned by the API
#     """    
#     claims = []
#     req = claims_serv.search(query=query_str, languageCode=lang_code)
#     try:
#         res = req.execute()
#         claims = res['claims']  # FIXME: is returning KeyError, perhaps when Google API is unresponsive
#     except HttpError as e:
#         print('Error response status code : {0}, reason : {1}'.format(e.status_code, e.error_details))

#     # Aggregate all the results pages into one object
#     while 'nextPageToken' in res.keys():        
#         req = claims_serv.search_next(req, res)
#         res = req.execute()
#         claims.extend(res['claims'])
    
#     # TODO: Also return any basic useful metrics based on the results

#     return claims


# def reformat_claims(claims):
#     """Reformats the list of nested claims / search results into a DataFrame

#     Args:
#         claims (list): list of nested claims / search results

#     Returns:
#         pd.DataFrame: DataFrame containing search results, one per each row
#     """
#     # Format the results object into a format that is convenient to use
#     df = pd.DataFrame(claims)
#     df = df.explode('claimReview').reset_index(drop=True)
#     claim_review_df = pd.json_normalize(df['claimReview'])
#     return pd.concat([df.drop('claimReview', axis=1), claim_review_df], axis=1)


# def certify_claims(claims_df):
#     """Certifies all the search results from the API against a list of verified IFCN signatories

#     Args:
#         claims_df (pd.DataFrame): DataFrame object containing all search results from the API

#     Returns:
#         pd.DataFrame: claims dataframe filtered to include only IFCN-certified claims
#     """
#     ifcn_to_use = get_ifcn_to_use()
#     claims_df['ifcn_check'] = claims_df['publisher.site'].apply(remove_subdomain).isin(ifcn_to_use)
#     return claims_df[claims_df['ifcn_check']].drop('ifcn_check', axis=1)


# def get_ifcn_data():
#     """Standalone function to update the IFCN signatories CSV file that is stored locally"""
#     r = requests.get(IFCN_LIST_URL)
#     soup = BeautifulSoup(r.content, 'html.parser')
#     cats_list = soup.find_all('div', class_='row mb-5')
    
#     active = cats_list[0].find_all('div', class_='media')
#     active = extract_ifcn_df(active, 'active')
    
#     under_review = cats_list[1].find_all('div', class_='media')
#     under_review = extract_ifcn_df(under_review, 'under_review')
    
#     expired = cats_list[2].find_all('div', class_='media')
#     expired = extract_ifcn_df(expired, 'expired')
    
#     ifcn_df = pd.concat([active, under_review, expired], axis=0, ignore_index=True)
#     ifcn_df['country'] = ifcn_df['country'].str.strip('from ')
#     ifcn_df['verified_date'] = ifcn_df['verified_date'].str.strip('Verified on ')

#     ifcn_df.to_csv(IFCN_FILENAME, index=False)


# def extract_ifcn_df(ifcn_list, status):
#     """Returns useful info from a list of IFCN signatories

#     Args:
#         ifcn_list (list): list of IFCN signatories
#         status (str): status code to be used for all signatories in this list

#     Returns:
#         pd.DataFrame: a dataframe of IFCN signatories' data
#     """
#     ifcn_data = [{
#         'url': x.a['href'], 
#         'name': x.h5.text, 
#         'country': x.h6.text, 
#         'verified_date': x.find_all('span', class_='small')[1].text, 
#         'ifcn_profile_url': 
#             x.find('a', class_='btn btn-sm btn-outline btn-link mb-0')['href'], 
#         'status': status
#         } for x in ifcn_list]
#     return pd.DataFrame(ifcn_data)


# def remove_subdomain(url):
#     """Removes the subdomain from a URL hostname - useful when comparing two URLs

#     Args:
#         url (str): URL hostname

#     Returns:
#         str: URL with subdomain removed
#     """
#     extract = tldextract.extract(url)
#     return extract.domain + '.' + extract.suffix


# def get_ifcn_to_use():
#     """Returns the IFCN data for non-expired signatories

#     Returns:
#         pd.Series: URls of non-expired IFCN signatories
#     """
#     ifcn_df = pd.read_csv(IFCN_FILENAME)
#     ifcn_url = ifcn_df.loc[ifcn_df.status.isin(['active', 'under_review']), 'url']
#     return [remove_subdomain(x) for x in ifcn_url]


# def get_gapi_service():
#     """Returns a Google Fact-Check API-specific service object used to query the API

#     Returns:
#         googleapiclient.discovery.Resource: API-specific service object
#     """
#     load_dotenv()
#     environment = os.getenv('ENVIRONMENT')
#     if environment == 'DEVELOPMENT':
#         api_key = os.getenv('API_KEY')
#         service = build('factchecktools', 'v1alpha1', developerKey=api_key)
#     elif environment == 'PRODUCTION':
#         google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
#         # FIXME: The below credentials not working, the HTTP request throws HTTPError 400
#         # credentials = service_account.Credentials.from_service_account_file(
#         #     GOOGLE_APPLICATION_CREDENTIALS)
#         credentials = service_account.Credentials.from_service_account_file(
#             google_application_credentials,
#             scopes=['https://www.googleapis.com/auth/userinfo.email',
#                     'https://www.googleapis.com/auth/cloud-platform'])
#         service = build('factchecktools', 'v1alpha1', credentials=credentials)
#     return service


# # USED IN update_database.py ----
# def get_claims_by_site(claims_serv, publisher_site, lang_code):
#     # TODO: Any HTTP or other errors in this function need to be handled better
#     req = claims_serv.search(reviewPublisherSiteFilter=publisher_site,
#                              languageCode=lang_code)
#     while True:
#         try:
#             res = req.execute()
#             break
#         except HttpError as e:
#             print('Error response status code : {0}, reason : {1}'.
#                   format(e.status_code, e.error_details))
#             time.sleep(random.randint(50, 60))
#     if 'claims' in res:
#         claims = res['claims']  # FIXME: is returning KeyError when Google API is unresponsive?
#         print('first 10')
#         req_prev, req = req, None
#         res_prev, res = res, None
#     else:
#         print('No data')
#         return []

#     # Aggregate all the results pages into one object
#     while 'nextPageToken' in res_prev.keys():
#         req = claims_serv.search_next(req_prev, res_prev)
#         try:
#             res = req.execute()
#             claims.extend(res['claims'])
#             req_prev, req = req, None
#             res_prev, res = res, None
#             print('another 10')
#         except HttpError as e:
#             print('Error in while loop : {0}, \
#                     reason : {1}'.format(e.status_code, e.error_details))
#             time.sleep(random.randint(50, 60))

#     return claims


# def rename_claim_attrs(df):
#     return df.rename(
#         columns={'claimDate': 'claim_date',
#                  'reviewDate': 'review_date',
#                  'textualRating': 'textual_rating',
#                  'languageCode': 'language_code',
#                  'publisher.name': 'publisher_name',
#                  'publisher.site': 'publisher_site'}
#     )


# def clean_claims(df):
#     pass


# def write_claims_to_db(df):
#     with sqlite3.connect(DB_PATH) as db_con:
#         df.to_sql('claims', db_con, if_exists='append', index=False)
#     # FIXME: The id variable is not getting auto-incremented


# def generate_and_store_embeddings(df, embed_model, overwrite):
#     # TODO: Combine "text" & "textual_rating" to generate useful statements
#     df['fact_check'] = 'The fact-check result for the claim "' + df['text'] \
#         + '" is "' + df['textual_rating'] + '"'
#     # TODO: Are ids required?

#     df.rename(columns={'text': 'claim'}, inplace=True)
#     docs = \
#         [Document(page_content=row['fact_check'],
#                   metadata=row.drop('fact_check').to_dict())
#          for idx, row in df.iterrows()]

#     if overwrite == True:        
#         db = FAISS.from_documents(docs, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
#         # FIXME: MAX_INNER_PRODUCT is not being used currently, only EUCLIDEAN_DISTANCE
#         db.save_local(FAISS_DB_PATH)
#     elif overwrite == False:
#         db = FAISS.load_local(FAISS_DB_PATH, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
#         db.add_documents(docs)
#         db.save_local(FAISS_DB_PATH)


def get_rag_chain():
    """Builds a RetrievalQA chain backed by the local FAISS index and a llama.cpp model

    Returns:
        RetrievalQA: a retrieval-augmented QA chain that answers queries using \
            the fact-check statements stored in the FAISS vector store
    """
    # Embedding model used to encode queries; must match the model used when
    # the FAISS index was built
    model_name = "sentence-transformers/all-mpnet-base-v2"
    model_kwargs = {"device": "cpu"}
    embed_model = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

    # Local llama.cpp model used to generate the final answer
    llm = LlamaCpp(model_path=MODEL_PATH)

    # Load the persisted FAISS index and expose it as a retriever
    db_vector = FAISS.load_local(FAISS_DB_PATH, embed_model)
    retriever = db_vector.as_retriever()

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True
    )
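

# Minimal usage sketch, kept under a __main__ guard so importing this module
# has no side effects. Assumes the FAISS index and the llama.cpp model already
# exist at the paths configured in .env; the query string is illustrative only.
if __name__ == '__main__':
    qa_chain = get_rag_chain()
    output = qa_chain({'query': 'Has this claim been fact-checked?'})
    print(output['result'])
    for doc in output['source_documents']:
        # Metadata carries the claim attributes stored alongside each embedding
        print(doc.metadata)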