Hasan Iqbal
Added OpenFactCheck library
8360ec7 unverified
raw
history blame
2.6 kB
"""Utils for running question generation."""
import os
import time
from typing import List
import openai
# Import-time side effect: configure the OpenAI client's API key from the
# OPENAI_API_KEY environment variable. os.getenv returns None when the variable
# is unset, in which case the key is assigned None here.
openai.api_key = os.getenv("OPENAI_API_KEY")
def parse_api_response(api_response: str) -> List[str]:
    """Pull the generated questions out of a raw GPT-3 completion.

    The question-generation prompt makes the model answer with an ordered list
    in which every question is introduced by the marker "I googled:". Each
    line of the response is inspected independently: lines without the marker
    are dropped, and for the remaining lines the text that follows the first
    marker (up to a second marker, if one occurs on the same line) is kept
    with surrounding whitespace stripped.

    Args:
        api_response: Raw text of the question-generation completion.

    Returns:
        The extracted questions, in the order they appear in the response.
    """
    marker = "I googled:"
    return [
        line.split(marker)[1].strip()
        for line in api_response.split("\n")
        if marker in line
    ]
def run_rarr_question_generation(
    claim: str,
    model: str,
    prompt: str,
    temperature: float,
    num_rounds: int,
    context: str = None,
    num_retries: int = 5,
) -> List[str]:
    """Generates questions that interrogate the information in a claim.

    Given a piece of text (claim), we use GPT-3 to generate questions that question the
    information in the claim. We run num_rounds of sampling to get a diverse set of
    questions; duplicates across rounds are collapsed.

    Args:
        claim: Text to generate questions off of.
        model: Name of the OpenAI GPT-3 model to use.
        prompt: The prompt template to query GPT-3 with. It is filled in with
            ``str.format``, receiving ``claim`` and, when a context is supplied,
            ``context``.
        temperature: Temperature to use for sampling questions. 0 represents greedy
            decoding.
        num_rounds: Number of times to sample questions.
        context: Optional text substituted for ``{context}`` in the template. When
            left as None only ``claim`` is passed to the template, so a template
            that requires ``{context}`` raises KeyError. An empty string is a valid
            context and is substituted as-is.
        num_retries: Maximum number of API attempts per round. A round whose
            attempts all fail with an OpenAI error contributes no questions.

    Returns:
        questions: A sorted list of the distinct questions gathered over all rounds.
    """
    # Compare against None rather than relying on truthiness: an empty-string
    # context must still be handed to the template, otherwise a template that
    # contains "{context}" fails with KeyError. Templates without the placeholder
    # are unaffected because str.format ignores unused keyword arguments.
    if context is not None:
        gpt3_input = prompt.format(context=context, claim=claim).strip()
    else:
        gpt3_input = prompt.format(claim=claim).strip()

    questions = set()
    for _ in range(num_rounds):
        for _attempt in range(num_retries):
            try:
                response = openai.completions.create(
                    model=model,
                    prompt=gpt3_input,
                    temperature=temperature,
                    max_tokens=256,
                )
                cur_round_questions = parse_api_response(
                    response.choices[0].text.strip()
                )
                questions.update(cur_round_questions)
                # This round succeeded; move on to the next one.
                break
            except openai.OpenAIError as exception:
                # Best-effort: report the failure, pause briefly, then retry.
                print(f"{exception}. Retrying...")
                time.sleep(1)

    # Sorting turns the de-duplicated set into a deterministic ordering.
    return sorted(questions)