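"""Generate Google search queries that can be used to verify extracted claims.

For each JSON file of classified claims, an OpenAI chat model is asked to
propose one Google search query per objective claim, and the claims
(augmented with their queries) are written to the search-query directory.

Example usage:
    python generate_search_queries.py --model gpt-4 --refresh
"""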
import argparse
from pathlib import Path
import numpy as np
from pipeline_paths import PIPELINE_PATHS
import json
from zsvision.zs_utils import BlockTimer
from typing import Dict, List
from llm_api_utils import (
call_openai_with_exponetial_backoff,
estimate_cost_of_text_generation_api_call,
init_openai_with_api_key,
)


def generate_search_queries(args, src_path: Path, dest_path: Path):
    """
    Generate a Google search query for each objective claim in src_path and
    write the claims, augmented with their queries, to dest_path as JSON.
    """
init_openai_with_api_key(api_key_path=args.api_key_path)
with open(src_path, "r") as f:
claims_and_sources = json.load(f)
# exclude subjective claims
original_num_claims = len(claims_and_sources)
claims_and_sources = [
claim_and_source
for claim_and_source in claims_and_sources
if claim_and_source["label"] == "objective"
]
num_claims = len(claims_and_sources)
    print(
        f"Filtered from {original_num_claims} claims to {num_claims} objective claims"
    )
    if num_claims == 0:
        # nothing to query (and np.array_split requires at least one section)
        print(f"No objective claims found in {src_path}, skipping")
        return
    # we limit the number of claims per api call (otherwise GPT-4 can choke)
    num_batches = int(np.ceil(num_claims / args.max_claims_per_api_call))
claims_and_sources_batches = [
batch.tolist() for batch in np.array_split(claims_and_sources, num_batches)
]
queries = []
all_claims_str = "\n".join([claim["claim"] for claim in claims_and_sources])
for idx, claims_and_sources_batch in enumerate(claims_and_sources_batches):
print(
f"Processing batch {idx+1} of {len(claims_and_sources_batches)} (containing {len(claims_and_sources_batch)} claims)"
)
claim_str = "\n".join([claim["claim"] for claim in claims_and_sources_batch])
num_batch_claims = len(claims_and_sources_batch)
# we provide the full list of claims as context (to help resolve ambiguity), but only ask for queries for the current batch
prompt = f"""\
You are working as part of a team and your individual task is to help check a subset of the following claims:\n
{all_claims_str}
Your individual task is as follows. \
For each of the {num_batch_claims} claims made below, provide a suitable Google search query that would enable a human to verify the claim. \
Note that Google can perform calculations and conversions, so you can use it to check numerical claims. \
If you think no Google query will be useful, then write "no suitable query". \
Each proposed Google search query should be on a separate line (do not prefix your queries with bullet points or numbers). \
There should be {num_batch_claims} queries in total.\n \
{claim_str}
"""
persona = "You are a careful research assistant who helps with fact-checking and editing informative articles."
system_message = {"role": "system", "content": persona}
user_message = {"role": "user", "content": prompt}
messages = [system_message, user_message]
        with BlockTimer(f"Using OpenAI API to generate queries with {args.model}"):
response = call_openai_with_exponetial_backoff(
model=args.model,
temperature=args.temperature,
messages=messages,
)
cost = estimate_cost_of_text_generation_api_call(
model=args.model, response=response, verbose=True
)
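        # the prompt asks for exactly one query per line (no bullets or numbers),
        # so the per-claim queries can be recovered with a newline split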
proposed_queries = response.choices[0].message.content
batch_queries = proposed_queries.split("\n")
        assert (
            len(batch_queries) == num_batch_claims
        ), f"Expected {num_batch_claims} queries, but got {len(batch_queries)}"
print(f"Generated {len(batch_queries)} queries (cost: {cost:.4f} USD)")
queries.extend(batch_queries)
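    # re-attach each generated query to its claim (claim order is preserved
    # across batches, so a simple zip lines them up)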
querysets = []
for claim_and_source, query in zip(claims_and_sources, queries):
queryset = {**claim_and_source, "search_query": query}
querysets.append(queryset)
dest_path.parent.mkdir(exist_ok=True, parents=True)
with open(dest_path, "w") as f:
json.dump(querysets, f, indent=4, sort_keys=True)


def main():
args = parse_args()
src_paths = list(
PIPELINE_PATHS["extracted_claims_with_classifications_dir"].glob("**/*.json")
)
print(
f"Found {len(src_paths)} claim files in {PIPELINE_PATHS['extracted_claims_with_classifications_dir']}"
)
dest_dir = PIPELINE_PATHS["search_queries_for_evidence"]
for src_path in src_paths:
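        # mirror the layout of the claim files under the destination directory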
dest_path = dest_dir / src_path.relative_to(
PIPELINE_PATHS["extracted_claims_with_classifications_dir"]
)
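        # only (re)generate queries if the output is missing or --refresh is set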
if not dest_path.exists() or args.refresh:
generate_search_queries(
args=args,
src_path=src_path,
dest_path=dest_path,
)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument(
"--model", default="gpt-3.5-turbo", choices=["gpt-4", "gpt-3.5-turbo"]
)
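    # note: --dest_dir is currently not read by main(), which writes under
    # PIPELINE_PATHS["search_queries_for_evidence"] instead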
parser.add_argument("--dest_dir", default="data/search_queries", type=Path)
parser.add_argument("--api_key_path", default="OPENAI_API_KEY.txt")
parser.add_argument("--max_claims_per_api_call", type=int, default=10)
parser.add_argument("--refresh", action="store_true")
return parser.parse_args()


if __name__ == "__main__":
main()