# PoC / src/utils.py
# Last commit: 269a93e by deekshith-rj — "Fixing a Runtime Error"
# import random
# import sqlite3
# import time
# from googleapiclient.discovery import build
# from google.oauth2 import service_account
# from googleapiclient.errors import HttpError
# import pandas as pd
# import requests
# from bs4 import BeautifulSoup
# import pickle
# import tldextract
import os
from dotenv import load_dotenv
# from langchain.schema import Document
# from langchain.vectorstores.utils import DistanceStrategy
# from torch import cuda, bfloat16
# import torch
# import transformers
# from transformers import AutoTokenizer
# from langchain.document_loaders import TextLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import LlamaCpp
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA # RetrievalQAWithSourcesChain
# from config import IFCN_LIST_URL
# Locally cached CSV of IFCN signatories, stored in the directory one level
# above the one that contains this module.
IFCN_FILENAME = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    'ifcn_df.csv',
)

# Populate os.environ from a .env file (if one is found) before reading the
# settings below; each setting is None when its variable is not defined.
load_dotenv()
DB_PATH, FAISS_DB_PATH, MODEL_PATH = (
    os.getenv(var_name) for var_name in ('DB_PATH', 'FAISS_DB_PATH', 'MODEL_PATH')
)
# def get_claims(claims_serv, query_str, lang_code):
# """Queries the Google Fact Check API using the search string and returns the results
# Args:
# claims_serv (build().claims() object): build() creates a service object \
# for the factchecktools API; claims() creates a 'claims' object which \
# can be used to query with the search string
# query_str (str): the query string
# lang_code (str): BCP-47 language code, used to restrict search results by language
# Returns:
# list: the list of all search results returned by the API
# """
# claims = []
# req = claims_serv.search(query=query_str, languageCode=lang_code)
# try:
# res = req.execute()
# claims = res['claims'] # FIXME: is returning KeyError, perhaps when Google API is unresponsive
# except HttpError as e:
# print('Error response status code : {0}, reason : {1}'.format(e.status_code, e.error_details))
# # Aggregate all the results pages into one object
# while 'nextPageToken' in res.keys():
# req = claims_serv.search_next(req, res)
# res = req.execute()
# claims.extend(res['claims'])
# # TODO: Also return any basic useful metrics based on the results
# return claims
# def reformat_claims(claims):
# """Reformats the list of nested claims / search results into a DataFrame
# Args:
# claims (list): list of nested claims / search results
# Returns:
# pd.DataFrame: DataFrame containing search results, one per each row
# """
# # Format the results object into a format that is convenient to use
# df = pd.DataFrame(claims)
# df = df.explode('claimReview').reset_index(drop=True)
# claim_review_df = pd.json_normalize(df['claimReview'])
# return pd.concat([df.drop('claimReview', axis=1), claim_review_df], axis=1)
# def certify_claims(claims_df):
# """Certifies all the search results from the API against a list of verified IFCN signatories
# Args:
# claims_df (pd.DataFrame): DataFrame object containing all search results from the API
# Returns:
# pd.DataFrame: claims dataframe filtered to include only IFCN-certified claims
# """
# ifcn_to_use = get_ifcn_to_use()
# claims_df['ifcn_check'] = claims_df['publisher.site'].apply(remove_subdomain).isin(ifcn_to_use)
# return claims_df[claims_df['ifcn_check']].drop('ifcn_check', axis=1)
# def get_ifcn_data():
# """Standalone function to update the IFCN signatories CSV file that is stored locally"""
# r = requests.get(IFCN_LIST_URL)
# soup = BeautifulSoup(r.content, 'html.parser')
# cats_list = soup.find_all('div', class_='row mb-5')
# active = cats_list[0].find_all('div', class_='media')
# active = extract_ifcn_df(active, 'active')
# under_review = cats_list[1].find_all('div', class_='media')
# under_review = extract_ifcn_df(under_review, 'under_review')
# expired = cats_list[2].find_all('div', class_='media')
# expired = extract_ifcn_df(expired, 'expired')
# ifcn_df = pd.concat([active, under_review, expired], axis=0, ignore_index=True)
# ifcn_df['country'] = ifcn_df['country'].str.strip('from ')
# ifcn_df['verified_date'] = ifcn_df['verified_date'].str.strip('Verified on ')
# ifcn_df.to_csv(IFCN_FILENAME, index=False)
# def extract_ifcn_df(ifcn_list, status):
# """Returns useful info from a list of IFCN signatories
# Args:
# ifcn_list (list): list of IFCN signatories
# status (str): status code to be used for all signatories in this list
# Returns:
# pd.DataFrame: a dataframe of IFCN signatories' data
# """
# ifcn_data = [{
# 'url': x.a['href'],
# 'name': x.h5.text,
# 'country': x.h6.text,
# 'verified_date': x.find_all('span', class_='small')[1].text,
# 'ifcn_profile_url':
# x.find('a', class_='btn btn-sm btn-outline btn-link mb-0')['href'],
# 'status': status
# } for x in ifcn_list]
# return pd.DataFrame(ifcn_data)
# def remove_subdomain(url):
# """Removes the subdomain from a URL hostname - useful when comparing two URLs
# Args:
# url (str): URL hostname
# Returns:
# str: URL with subdomain removed
# """
# extract = tldextract.extract(url)
# return extract.domain + '.' + extract.suffix
# def get_ifcn_to_use():
# """Returns the domains of non-expired (active or under-review) IFCN signatories
# Returns:
# list: signatory URLs with their subdomains removed (via remove_subdomain)
# """
# ifcn_df = pd.read_csv(IFCN_FILENAME)
# ifcn_url = ifcn_df.loc[ifcn_df.status.isin(['active', 'under_review']), 'url']
# return [remove_subdomain(x) for x in ifcn_url]
# def get_gapi_service():
# """Returns a Google Fact-Check API-specific service object used to query the API
# Returns:
# googleapiclient.discovery.Resource: API-specific service object
# """
# load_dotenv()
# environment = os.getenv('ENVIRONMENT')
# if environment == 'DEVELOPMENT':
# api_key = os.getenv('API_KEY')
# service = build('factchecktools', 'v1alpha1', developerKey=api_key)
# elif environment == 'PRODUCTION':
# google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
# # FIXME: The below credentials not working, the HTTP request throws HTTPError 400
# # credentials = service_account.Credentials.from_service_account_file(
# # GOOGLE_APPLICATION_CREDENTIALS)
# credentials = service_account.Credentials.from_service_account_file(
# google_application_credentials,
# scopes=['https://www.googleapis.com/auth/userinfo.email',
# 'https://www.googleapis.com/auth/cloud-platform'])
# service = build('factchecktools', 'v1alpha1', credentials=credentials)
# return service
# # USED IN update_database.py ----
# def get_claims_by_site(claims_serv, publisher_site, lang_code):
# # TODO: Any HTTP or other errors in this function need to be handled better
# req = claims_serv.search(reviewPublisherSiteFilter=publisher_site,
# languageCode=lang_code)
# while True:
# try:
# res = req.execute()
# break
# except HttpError as e:
# print('Error response status code : {0}, reason : {1}'.
# format(e.status_code, e.error_details))
# time.sleep(random.randint(50, 60))
# if 'claims' in res:
# claims = res['claims'] # FIXME: is returning KeyError when Google API is unresponsive?
# print('first 10')
# req_prev, req = req, None
# res_prev, res = res, None
# else:
# print('No data')
# return []
# # Aggregate all the results pages into one object
# while 'nextPageToken' in res_prev.keys():
# req = claims_serv.search_next(req_prev, res_prev)
# try:
# res = req.execute()
# claims.extend(res['claims'])
# req_prev, req = req, None
# res_prev, res = res, None
# print('another 10')
# except HttpError as e:
# print('Error in while loop : {0}, \
# reason : {1}'.format(e.status_code, e.error_details))
# time.sleep(random.randint(50, 60))
# return claims
# def rename_claim_attrs(df):
# return df.rename(
# columns={'claimDate': 'claim_date',
# 'reviewDate': 'review_date',
# 'textualRating': 'textual_rating',
# 'languageCode': 'language_code',
# 'publisher.name': 'publisher_name',
# 'publisher.site': 'publisher_site'}
# )
# def clean_claims(df):
# pass
# def write_claims_to_db(df):
# with sqlite3.connect(DB_PATH) as db_con:
# df.to_sql('claims', db_con, if_exists='append', index=False)
# # FIXME: The id variable is not getting auto-incremented
# def generate_and_store_embeddings(df, embed_model, overwrite):
# # TODO: Combine "text" & "textual_rating" to generate useful statements
# df['fact_check'] = 'The fact-check result for the claim "' + df['text'] \
# + '" is "' + df['textual_rating'] + '"'
# # TODO: Are ids required?
# df.rename(columns={'text': 'claim'}, inplace=True)
# docs = \
# [Document(page_content=row['fact_check'],
# metadata=row.drop('fact_check').to_dict())
# for idx, row in df.iterrows()]
# if overwrite == True:
# db = FAISS.from_documents(docs, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
# # FIXME: MAX_INNER_PRODUCT is not being used currently, only EUCLIDEAN_DISTANCE
# db.save_local(FAISS_DB_PATH)
# elif overwrite == False:
# db = FAISS.load_local(FAISS_DB_PATH, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
# db.add_documents(docs)
# db.save_local(FAISS_DB_PATH)
def get_rag_chain(*, embedding_model_name="sentence-transformers/all-mpnet-base-v2",
                  device="cpu", model_path=None, faiss_db_path=None):
    """Build a retrieval-augmented QA chain over the local FAISS index.

    Loads a HuggingFace sentence-embedding model, a local llama.cpp LLM and
    the FAISS vector store saved on disk, then wires them into a LangChain
    ``RetrievalQA`` chain using the "stuff" chain type (all retrieved
    documents are placed into a single prompt).

    Calling with no arguments behaves like the original, which used the
    defaults below.

    Args:
        embedding_model_name (str): HuggingFace model used to embed queries.
            This should match the model the FAISS index was built with, or
            query vectors will not be comparable to the stored ones.
        device (str): device the embedding model runs on (e.g. "cpu").
        model_path (str | None): path to the llama.cpp model file. Defaults
            to the module-level ``MODEL_PATH`` setting (read from the
            environment / .env).
        faiss_db_path (str | None): folder containing the saved FAISS index.
            Defaults to the module-level ``FAISS_DB_PATH`` setting.

    Returns:
        RetrievalQA: the QA chain. It is built with
        ``return_source_documents=True``, so its output also contains the
        retrieved source documents.

    Raises:
        ValueError: if the LLM model path or the FAISS index path is missing
            or empty.
    """
    if model_path is None:
        model_path = MODEL_PATH
    if faiss_db_path is None:
        faiss_db_path = FAISS_DB_PATH

    # Fail fast with an actionable message. Otherwise the error surfaces
    # deep inside LlamaCpp / FAISS, after the slow embedding model loads.
    missing = [name for name, value in (('MODEL_PATH', model_path),
                                        ('FAISS_DB_PATH', faiss_db_path))
               if not value]
    if missing:
        raise ValueError('Missing required configuration: ' + ', '.join(missing)
                         + ' (set it in the environment or the .env file)')

    embed_model = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                        model_kwargs={"device": device})
    llm = LlamaCpp(model_path=model_path)

    # SECURITY: load_local unpickles the stored docstore, and unpickling can
    # execute arbitrary code. This is acceptable only because the index is a
    # locally built, trusted file. Never point this at untrusted data.
    db_vector = FAISS.load_local(faiss_db_path, embed_model,
                                 allow_dangerous_deserialization=True)
    retriever = db_vector.as_retriever()

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )