Spaces:
Runtime error
Runtime error
import os | |
import logging | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import nltk | |
def logging_handler(verbose, save_name, idx=0):
    """Configure and return the logger named ``str(idx)``.

    The logger is set to INFO and writes bare messages (``"%(message)s"``)
    to a stream handler. If ``save_name`` is given, it also writes them to
    ``results/{save_name}/{idx}.log``, creating the directory if needed.

    Args:
        verbose: Currently unused by this function (kept for caller
            compatibility).
        save_name: Sub-directory under ``results/`` for the log file, or
            ``None`` to log to the stream only.
        idx: Identifier used as the logger name and the log file stem.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(str(idx))
    logger.setLevel(logging.INFO)

    # getLogger returns the same object for the same name, so handlers from a
    # previous call would still be attached and every message would be
    # emitted once per call. Detach (and close) them before reconfiguring.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter("%(message)s")

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if save_name is not None:
        savepath = f"results/{save_name}"
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(savepath, exist_ok=True)
        file_handler = logging.FileHandler(f"{savepath}/{idx}.log")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
def image_saver(images, save_name, idx=0, interactive=True):
    """Render up to ten images on a 2x5 grid and save it as a PNG.

    Args:
        images: Indexable sequence of images accepted by ``Axes.imshow``.
            Only the first ten are drawn; missing slots are left blank.
        save_name: Output name. When ``interactive`` is true the file is
            ``{save_name}.png``; otherwise it is
            ``results/{save_name}/{idx}.png``.
        idx: File stem used in non-interactive mode.
        interactive: Selects which of the two output paths above is used.
    """
    fig, axes = plt.subplots(2, 5)
    fig.set_size_inches(30, 15)
    try:
        # axes.flat walks the grid row-major, i.e. slot i -> axes[i//5][i%5].
        for i, ax in enumerate(axes.flat):
            if i < len(images):
                ax.imshow(images[i])
            ax.axis('off')
            ax.set_aspect('equal')
        fig.tight_layout()
        fig.subplots_adjust(wspace=0, hspace=0)
        if not interactive:
            out_dir = f"results/{save_name}"
            # The directory may not exist yet if no file logger created it.
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f"{out_dir}/{idx}.png")
        else:
            fig.savefig(f"{save_name}.png")
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # call leaks a figure.
        plt.close(fig)
def assert_checks(args):
    """Validate mutually dependent command-line options.

    Raises:
        AssertionError: If ``args.question_strategy`` is ``"gpt3"`` but
            ``args.include_what`` is falsy.
    """
    # Raise explicitly instead of using a bare ``assert`` so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if args.question_strategy == "gpt3" and not args.include_what:
        raise AssertionError(
            "question_strategy 'gpt3' requires include_what to be set"
        )
def extract_nouns(sents):
    """Return, per sentence, the lowercased common nouns it contains.

    Each sentence is split on whitespace and tagged with ``nltk.pos_tag``.
    Tokens tagged ``NN`` or ``NNS`` are kept with all ``'.'`` characters
    removed, then lowercased.

    Args:
        sents: Iterable of sentence strings.

    Returns:
        A list with one list of noun strings per input sentence, in order.
    """
    common_noun_tags = ("NN", "NNS")
    nouns_per_sentence = []
    for sentence in sents:
        tagged_tokens = nltk.pos_tag(sentence.split())
        nouns_per_sentence.append([
            token.replace('.', '').lower()
            for token, tag in tagged_tokens
            if tag in common_noun_tags
        ])
    return nouns_per_sentence