# File size: 5,325 Bytes | 7a8b33f
import argparse
from pathlib import Path
import numpy as np
from pipeline_paths import PIPELINE_PATHS
import json
from zsvision.zs_utils import BlockTimer
from typing import Dict, List
from llm_api_utils import (
call_openai_with_exponetial_backoff,
estimate_cost_of_text_generation_api_call,
init_openai_with_api_key,
)
def generate_search_queries(args, src_path: Path, dest_path: Path):
    """
    Generate a search query that can be used to verify each objective claim.

    Loads the JSON list of claim records stored at `src_path`, keeps only the
    records whose "label" is "objective", asks the chat model (in batches) for
    one Google search query per remaining claim, and writes the records - each
    extended with a "search_query" entry - to `dest_path` as JSON.

    Args:
        args: parsed command-line namespace. This function reads
            `api_key_path`, `max_claims_per_api_call`, `model` and
            `temperature` from it.
        src_path: JSON file containing a list of dicts; each dict is accessed
            through its "claim" and "label" keys.
        dest_path: JSON file to write. Missing parent directories are created.

    Raises:
        AssertionError: if the model returns a number of queries that differs
            from the number of claims in the batch it was given.
    """
    init_openai_with_api_key(api_key_path=args.api_key_path)
    with open(src_path, "r", encoding="utf-8") as f:
        claims_and_sources = json.load(f)

    # exclude subjective claims
    original_num_claims = len(claims_and_sources)
    claims_and_sources = [
        claim_and_source
        for claim_and_source in claims_and_sources
        if claim_and_source["label"] == "objective"
    ]
    num_claims = len(claims_and_sources)
    print(
        f"Filtered from {original_num_claims} claims to {num_claims} objective claims"
    )

    # we limit the number of claims per api call (otherwise GPT-4 can choke)
    num_batches = int(np.ceil(num_claims / args.max_claims_per_api_call))
    if num_batches > 0:
        claims_and_sources_batches = [
            batch.tolist()
            for batch in np.array_split(claims_and_sources, num_batches)
        ]
    else:
        # np.array_split raises ValueError when asked for zero sections, so a
        # file without objective claims simply produces no batches (and an
        # empty output file is still written below).
        claims_and_sources_batches = []

    queries = []
    all_claims_str = "\n".join([claim["claim"] for claim in claims_and_sources])

    for idx, claims_and_sources_batch in enumerate(claims_and_sources_batches):
        print(
            f"Processing batch {idx+1} of {len(claims_and_sources_batches)} (containing {len(claims_and_sources_batch)} claims)"
        )
        claim_str = "\n".join([claim["claim"] for claim in claims_and_sources_batch])
        num_batch_claims = len(claims_and_sources_batch)

        # we provide the full list of claims as context (to help resolve ambiguity), but only ask for queries for the current batch
        prompt = f"""\
You are working as part of a team and your individual task is to help check a subset of the following claims:\n
{all_claims_str}
Your individual task is as follows. \
For each of the {num_batch_claims} claims made below, provide a suitable Google search query that would enable a human to verify the claim. \
Note that Google can perform calculations and conversions, so you can use it to check numerical claims. \
If you think no Google query will be useful, then write "no suitable query". \
Each proposed Google search query should be on a separate line (do not prefix your queries with bullet points or numbers). \
There should be {num_batch_claims} queries in total.\n \
{claim_str}
"""
        persona = "You are a careful research assistant who helps with fact-checking and editing informative articles."
        system_message = {"role": "system", "content": persona}
        user_message = {"role": "user", "content": prompt}
        messages = [system_message, user_message]

        with BlockTimer(
            f"Using OpenAI API to generate search queries with {args.model}"
        ):
            response = call_openai_with_exponetial_backoff(
                model=args.model,
                temperature=args.temperature,
                messages=messages,
            )

        cost = estimate_cost_of_text_generation_api_call(
            model=args.model, response=response, verbose=True
        )

        proposed_queries = response.choices[0].message.content
        # Replies may carry blank separator lines or trailing whitespace; these
        # are not queries and must not count towards the expected total.
        batch_queries = [
            line.strip() for line in proposed_queries.splitlines() if line.strip()
        ]
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; without it the zip below would silently
        # pair claims with the wrong queries.
        if len(batch_queries) != num_batch_claims:
            raise AssertionError(
                f"Expected {num_batch_claims} queries, but got {len(batch_queries)}"
            )
        print(f"Generated {len(batch_queries)} queries (cost: {cost:.4f} USD)")
        queries.extend(batch_queries)

    # attach each query to the claim record it was generated for
    querysets = [
        {**claim_and_source, "search_query": query}
        for claim_and_source, query in zip(claims_and_sources, queries)
    ]

    dest_path.parent.mkdir(exist_ok=True, parents=True)
    with open(dest_path, "w", encoding="utf-8") as f:
        json.dump(querysets, f, indent=4, sort_keys=True)
def main():
    """
    Generate search queries for every claim file that needs them.

    Walks the extracted-claims directory recursively, mirrors each JSON file's
    relative location under the search-queries directory, and calls
    `generate_search_queries` for files whose output is missing (or for all
    files when `--refresh` is given).
    """
    args = parse_args()

    claims_dir = PIPELINE_PATHS["extracted_claims_with_classifications_dir"]
    claim_files = list(claims_dir.glob("**/*.json"))
    print(f"Found {len(claim_files)} claim files in {claims_dir}")

    queries_dir = PIPELINE_PATHS["search_queries_for_evidence"]
    for claim_file in claim_files:
        # keep the same relative layout in the destination directory
        target = queries_dir / claim_file.relative_to(claims_dir)
        if target.exists() and not args.refresh:
            # output already present and no refresh requested
            continue
        generate_search_queries(
            args=args,
            src_path=claim_file,
            dest_path=target,
        )
def parse_args():
    """
    Parse the command-line options for this script.

    Returns:
        The `argparse.Namespace` with `temperature`, `model`, `dest_dir`,
        `api_key_path`, `max_claims_per_api_call` and `refresh` attributes.
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword options), registered in this order
    option_specs = (
        ("--temperature", {"type": float, "default": 0}),
        (
            "--model",
            {"default": "gpt-3.5-turbo", "choices": ["gpt-4", "gpt-3.5-turbo"]},
        ),
        ("--dest_dir", {"default": "data/search_queries", "type": Path}),
        ("--api_key_path", {"default": "OPENAI_API_KEY.txt"}),
        ("--max_claims_per_api_call", {"type": int, "default": 10}),
        ("--refresh", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
# Run the pipeline step only when this file is executed as a script.
if __name__ == "__main__":
    main()