# llama_output / check.py
# rookiemango's picture
# Upload folder using huggingface_hub
# 3e37441 verified
import json
import os
from collections import Counter
# Directory that holds the shard files 0.json .. 7.json read by this script;
# the repetition report is written into the same directory.
base_dir = "generate_result/zero_shot/bd_math/generation/llama3.1/1"
def has_repetition(text, threshold=3, *, phrase_length=15):
    """
    Find word phrases that occur repeatedly in the given text.

    The text is split on whitespace and every window of ``phrase_length``
    consecutive words is counted. Words are re-joined with single spaces, so
    the original spacing between words does not affect the comparison.

    :param text: The text to check for repetition
    :param threshold: Minimum number of occurrences for a phrase to be reported
    :param phrase_length: Number of consecutive words that make up one phrase
    :return: The distinct phrases that occur at least ``threshold`` times, in
        order of first appearance; an empty list if there are none (including
        when the text has fewer than ``phrase_length`` words)
    :raises ValueError: If ``phrase_length`` is smaller than 1
    """
    if phrase_length < 1:
        raise ValueError("phrase_length must be at least 1")

    words = text.split()
    # Counter remembers insertion order, so the result follows the order in
    # which each phrase first shows up in the text.
    window_counts = Counter(
        " ".join(words[start : start + phrase_length])
        for start in range(len(words) - phrase_length + 1)
    )
    return [phrase for phrase, count in window_counts.items() if count >= threshold]
# Running totals across all shard files.
total_items = 0  # records whose model output could be extracted
items_with_repetition = 0  # records with at least one repeated phrase
repetition_data = []  # one report entry per record that shows repetition

# Shards are looked up as 0.json .. 7.json; each line is parsed as one JSON record.
for i in range(8):
    file_path = os.path.join(base_dir, f"{i}.json")
    if not os.path.exists(file_path):
        print(f"File {file_path} does not exist. Skipping.")
        continue
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(file_path, "r", encoding="utf-8") as file:
        for line_number, line in enumerate(file, 1):
            if not line.strip():
                # A blank line (e.g. a trailing newline) is not a record.
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                print(f"Error decoding JSON in file {i}.json, line {line_number}")
                continue
            # Best-effort: one malformed record must not abort the whole scan.
            try:
                outputs = data.get("total output")
                if not outputs:
                    print(
                        f"Error processing file {i}.json, line {line_number}: "
                        "missing or empty 'total output'"
                    )
                    continue
                # NOTE(review): only the first element of "total output" is
                # analysed - presumably a list of generations; confirm against
                # whatever produces these files.
                model_output = outputs[0]
                total_items += 1
                repetitive_phrases = has_repetition(model_output)
                if repetitive_phrases:
                    items_with_repetition += 1
                    repetition_data.append(
                        {
                            "file": f"{i}.json",
                            "line": line_number,
                            "prompt": data.get("prompt", ""),
                            "repetitive_phrases": repetitive_phrases,
                        }
                    )
            except Exception as e:
                print(f"Error processing file {i}.json, line {line_number}: {str(e)}")

# Share of analysed records that contain repetition (0 when nothing was read).
ratio = items_with_repetition / total_items if total_items > 0 else 0
print(
    f"Ratio of items with repetition: {ratio:.2f} ({items_with_repetition}/{total_items})"
)

# Save the per-record findings next to the input shards.
output_file = "repetition_analysis.json"
output_path = os.path.join(base_dir, output_file)
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(repetition_data, f, indent=2)
# Report the full path: the file is written under base_dir, not the working directory.
print(f"Repetition analysis completed. Results saved to {output_path}")