PoC first release: no database-update procedures are included, only the app (plus its direct dependencies), which uses the pre-generated databases db_faiss and database.db.
85eaaaa
verified
# import random | |
# import sqlite3 | |
# import time | |
# from googleapiclient.discovery import build | |
# from google.oauth2 import service_account | |
# from googleapiclient.errors import HttpError | |
# import pandas as pd | |
# import requests | |
# from bs4 import BeautifulSoup | |
# import pickle | |
# import tldextract | |
import os | |
from dotenv import load_dotenv | |
# from langchain.schema import Document | |
# from langchain.vectorstores.utils import DistanceStrategy | |
# from torch import cuda, bfloat16 | |
# import torch | |
# import transformers | |
# from transformers import AutoTokenizer | |
# from langchain.document_loaders import TextLoader | |
# from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.llms import LlamaCpp | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.chains import RetrievalQA # RetrievalQAWithSourcesChain | |
# from config import IFCN_LIST_URL | |
# The IFCN signatories CSV lives in the parent of this module's directory.
IFCN_FILENAME = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    'ifcn_df.csv',
)

# Load settings from a .env file (python-dotenv) so the lookups below see them.
load_dotenv()

# Runtime configuration read from the environment; each value is None when
# the corresponding variable is not set.
DB_PATH, FAISS_DB_PATH, MODEL_PATH = (
    os.getenv(var_name)
    for var_name in ('DB_PATH', 'FAISS_DB_PATH', 'MODEL_PATH')
)
# def get_claims(claims_serv, query_str, lang_code): | |
# """Queries the Google Fact Check API using the search string and returns the results | |
# Args: | |
# claims_serv (build().claims() object): build() creates a service object \ | |
# for the factchecktools API; claims() creates a 'claims' object which \ | |
# can be used to query with the search string | |
# query_str (str): the query string | |
# lang_code (str): BCP-47 language code, used to restrict search results by language | |
# Returns: | |
# list: the list of all search results returned by the API | |
# """ | |
# claims = [] | |
# req = claims_serv.search(query=query_str, languageCode=lang_code) | |
# try: | |
# res = req.execute() | |
# claims = res['claims'] # FIXME: is returning KeyError, perhaps when Google API is unresponsive | |
# except HttpError as e: | |
# print('Error response status code : {0}, reason : {1}'.format(e.status_code, e.error_details)) | |
# # Aggregate all the results pages into one object | |
# while 'nextPageToken' in res.keys(): | |
# req = claims_serv.search_next(req, res) | |
# res = req.execute() | |
# claims.extend(res['claims']) | |
# # TODO: Also return any basic useful metrics based on the results | |
# return claims | |
# def reformat_claims(claims): | |
# """Reformats the list of nested claims / search results into a DataFrame | |
# Args: | |
# claims (list): list of nested claims / search results | |
# Returns: | |
# pd.DataFrame: DataFrame containing search results, one per each row | |
# """ | |
# # Format the results object into a format that is convenient to use | |
# df = pd.DataFrame(claims) | |
# df = df.explode('claimReview').reset_index(drop=True) | |
# claim_review_df = pd.json_normalize(df['claimReview']) | |
# return pd.concat([df.drop('claimReview', axis=1), claim_review_df], axis=1) | |
# def certify_claims(claims_df): | |
# """Certifies all the search results from the API against a list of verified IFCN signatories | |
# Args: | |
# claims_df (pd.DataFrame): DataFrame object containing all search results from the API | |
# Returns: | |
# pd.DataFrame: claims dataframe filtered to include only IFCN-certified claims | |
# """ | |
# ifcn_to_use = get_ifcn_to_use() | |
# claims_df['ifcn_check'] = claims_df['publisher.site'].apply(remove_subdomain).isin(ifcn_to_use) | |
# return claims_df[claims_df['ifcn_check']].drop('ifcn_check', axis=1) | |
# def get_ifcn_data(): | |
# """Standalone function to update the IFCN signatories CSV file that is stored locally""" | |
# r = requests.get(IFCN_LIST_URL) | |
# soup = BeautifulSoup(r.content, 'html.parser') | |
# cats_list = soup.find_all('div', class_='row mb-5') | |
# active = cats_list[0].find_all('div', class_='media') | |
# active = extract_ifcn_df(active, 'active') | |
# under_review = cats_list[1].find_all('div', class_='media') | |
# under_review = extract_ifcn_df(under_review, 'under_review') | |
# expired = cats_list[2].find_all('div', class_='media') | |
# expired = extract_ifcn_df(expired, 'expired') | |
# ifcn_df = pd.concat([active, under_review, expired], axis=0, ignore_index=True) | |
# ifcn_df['country'] = ifcn_df['country'].str.strip('from ') | |
# ifcn_df['verified_date'] = ifcn_df['verified_date'].str.strip('Verified on ') | |
# ifcn_df.to_csv(IFCN_FILENAME, index=False) | |
# def extract_ifcn_df(ifcn_list, status): | |
# """Returns useful info from a list of IFCN signatories | |
# Args: | |
# ifcn_list (list): list of IFCN signatories | |
# status (str): status code to be used for all signatories in this list | |
# Returns: | |
# pd.DataFrame: a dataframe of IFCN signatories' data | |
# """ | |
# ifcn_data = [{ | |
# 'url': x.a['href'], | |
# 'name': x.h5.text, | |
# 'country': x.h6.text, | |
# 'verified_date': x.find_all('span', class_='small')[1].text, | |
# 'ifcn_profile_url': | |
# x.find('a', class_='btn btn-sm btn-outline btn-link mb-0')['href'], | |
# 'status': status | |
# } for x in ifcn_list] | |
# return pd.DataFrame(ifcn_data) | |
# def remove_subdomain(url): | |
# """Removes the subdomain from a URL hostname - useful when comparing two URLs | |
# Args: | |
# url (str): URL hostname | |
# Returns: | |
# str: URL with subdomain removed | |
# """ | |
# extract = tldextract.extract(url) | |
# return extract.domain + '.' + extract.suffix | |
# def get_ifcn_to_use(): | |
# """Returns the IFCN data for non-expired signatories | |
# Returns: | |
# list: URLs (subdomains removed) of non-expired IFCN signatories | |
# """ | |
# ifcn_df = pd.read_csv(IFCN_FILENAME) | |
# ifcn_url = ifcn_df.loc[ifcn_df.status.isin(['active', 'under_review']), 'url'] | |
# return [remove_subdomain(x) for x in ifcn_url] | |
# def get_gapi_service(): | |
# """Returns a Google Fact-Check API-specific service object used to query the API | |
# Returns: | |
# googleapiclient.discovery.Resource: API-specific service object | |
# """ | |
# load_dotenv() | |
# environment = os.getenv('ENVIRONMENT') | |
# if environment == 'DEVELOPMENT': | |
# api_key = os.getenv('API_KEY') | |
# service = build('factchecktools', 'v1alpha1', developerKey=api_key) | |
# elif environment == 'PRODUCTION': | |
# google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS') | |
# # FIXME: The below credentials not working, the HTTP request throws HTTPError 400 | |
# # credentials = service_account.Credentials.from_service_account_file( | |
# # GOOGLE_APPLICATION_CREDENTIALS) | |
# credentials = service_account.Credentials.from_service_account_file( | |
# google_application_credentials, | |
# scopes=['https://www.googleapis.com/auth/userinfo.email', | |
# 'https://www.googleapis.com/auth/cloud-platform']) | |
# service = build('factchecktools', 'v1alpha1', credentials=credentials) | |
# return service | |
# # USED IN update_database.py ---- | |
# def get_claims_by_site(claims_serv, publisher_site, lang_code): | |
# # TODO: Any HTTP or other errors in this function need to be handled better | |
# req = claims_serv.search(reviewPublisherSiteFilter=publisher_site, | |
# languageCode=lang_code) | |
# while True: | |
# try: | |
# res = req.execute() | |
# break | |
# except HttpError as e: | |
# print('Error response status code : {0}, reason : {1}'. | |
# format(e.status_code, e.error_details)) | |
# time.sleep(random.randint(50, 60)) | |
# if 'claims' in res: | |
# claims = res['claims'] # FIXME: is returning KeyError when Google API is unresponsive? | |
# print('first 10') | |
# req_prev, req = req, None | |
# res_prev, res = res, None | |
# else: | |
# print('No data') | |
# return [] | |
# # Aggregate all the results pages into one object | |
# while 'nextPageToken' in res_prev.keys(): | |
# req = claims_serv.search_next(req_prev, res_prev) | |
# try: | |
# res = req.execute() | |
# claims.extend(res['claims']) | |
# req_prev, req = req, None | |
# res_prev, res = res, None | |
# print('another 10') | |
# except HttpError as e: | |
# print('Error in while loop : {0}, \ | |
# reason : {1}'.format(e.status_code, e.error_details)) | |
# time.sleep(random.randint(50, 60)) | |
# return claims | |
# def rename_claim_attrs(df): | |
# return df.rename( | |
# columns={'claimDate': 'claim_date', | |
# 'reviewDate': 'review_date', | |
# 'textualRating': 'textual_rating', | |
# 'languageCode': 'language_code', | |
# 'publisher.name': 'publisher_name', | |
# 'publisher.site': 'publisher_site'} | |
# ) | |
# def clean_claims(df): | |
# pass | |
# def write_claims_to_db(df): | |
# with sqlite3.connect(DB_PATH) as db_con: | |
# df.to_sql('claims', db_con, if_exists='append', index=False) | |
# # FIXME: The id variable is not getting auto-incremented | |
# def generate_and_store_embeddings(df, embed_model, overwrite): | |
# # TODO: Combine "text" & "textual_rating" to generate useful statements | |
# df['fact_check'] = 'The fact-check result for the claim "' + df['text'] \ | |
# + '" is "' + df['textual_rating'] + '"' | |
# # TODO: Are ids required? | |
# df.rename(columns={'text': 'claim'}, inplace=True) | |
# docs = \ | |
# [Document(page_content=row['fact_check'], | |
# metadata=row.drop('fact_check').to_dict()) | |
# for idx, row in df.iterrows()] | |
# if overwrite == True: | |
# db = FAISS.from_documents(docs, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT) | |
# # FIXME: MAX_INNER_PRODUCT is not being used currently, only EUCLIDEAN_DISTANCE | |
# db.save_local(FAISS_DB_PATH) | |
# elif overwrite == False: | |
# db = FAISS.load_local(FAISS_DB_PATH, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT) | |
# db.add_documents(docs) | |
# db.save_local(FAISS_DB_PATH) | |
def get_rag_chain(*, model_path=None, faiss_db_path=None,
                  embed_model_name="sentence-transformers/all-mpnet-base-v2",
                  device="cpu", verbose=True):
    """Build a retrieval-augmented QA chain over the saved FAISS index.

    The question is embedded with a Hugging Face sentence-transformers model.
    Matching documents are retrieved from the FAISS index stored on disk and
    "stuffed" into a single prompt for a local llama.cpp model.

    Every argument is optional and keyword-only. Calling ``get_rag_chain()``
    with no arguments behaves as before: it uses the module-level MODEL_PATH
    and FAISS_DB_PATH settings read from the environment.

    Args:
        model_path (str, optional): path to the llama.cpp model file.
            Defaults to the module-level MODEL_PATH.
        faiss_db_path (str, optional): folder containing the saved FAISS
            index. Defaults to the module-level FAISS_DB_PATH.
        embed_model_name (str): Hugging Face embedding model name. Query
            vectors must come from the same model that produced the vectors
            stored in the index, or retrieval results are meaningless.
        device (str): device the embedding model runs on, e.g. "cpu".
        verbose (bool): passed to the chain to enable verbose output.

    Returns:
        RetrievalQA: a chain built with ``return_source_documents=True``, so
            its output includes the retrieved source documents.

    Raises:
        ValueError: if the model path or the FAISS index path is missing/empty.
    """
    if model_path is None:
        model_path = MODEL_PATH
    if faiss_db_path is None:
        faiss_db_path = FAISS_DB_PATH

    # Fail fast with an actionable message instead of an obscure error raised
    # deep inside LlamaCpp / FAISS after the embedding model has been loaded.
    missing = [name for name, value in (('MODEL_PATH', model_path),
                                        ('FAISS_DB_PATH', faiss_db_path))
               if not value]
    if missing:
        raise ValueError(
            f"Missing configuration: {', '.join(missing)} must be set "
            "(e.g. in the .env file) or passed explicitly."
        )

    embed_model = HuggingFaceEmbeddings(model_name=embed_model_name,
                                        model_kwargs={"device": device})
    llm = LlamaCpp(model_path=model_path)
    # NOTE(review): the (commented-out) index builder created/loaded the index
    # with DistanceStrategy.MAX_INNER_PRODUCT, but no distance_strategy is
    # passed here, so LangChain's default is used -- confirm which metric the
    # saved index is meant to use.
    db_vector = FAISS.load_local(faiss_db_path, embed_model)
    # NOTE(review): "stuff" places every retrieved document into one prompt;
    # verify that this fits in the llama.cpp context window.
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db_vector.as_retriever(),
        return_source_documents=True,
        verbose=verbose,
    )