diff --git a/eval_textvqa.py b/eval_textvqa.py
new file mode 100644
index 0000000..eb39150
--- /dev/null
+++ b/eval_textvqa.py
@@ -0,0 +1,121 @@
+import os
+import json
+from collections import Counter
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from minigpt4.common.eval_utils import (
+    prepare_texts,
+    init_model,
+    eval_parser,
+)
+from minigpt4.conversation.conversation import CONV_VISION_minigptv2
+from minigpt4.common.config import Config
+
+
+def list_of_str(arg):
+    return list(map(str, arg.split(",")))
+
+
+parser = eval_parser()
+parser.add_argument("--dataset", type=list_of_str, default='textvqa', help="dataset to evaluate")
+parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
+parser.add_argument("--resample", action='store_true', help="resolution used in refcoco")
+args = parser.parse_args()
+cfg = Config(args)
+
+model, vis_processor = init_model(args)
+conv_temp = CONV_VISION_minigptv2.copy()
+conv_temp.system = ""
+model.eval()
+save_path = cfg.run_cfg.save_path
+
+os.makedirs(save_path, exist_ok=True)
+
+
+class EvalTextVQAData(torch.utils.data.Dataset):
+    def __init__(self, loaded_data, image_processor):
+        self.loaded_data = loaded_data
+        self.image_processor = image_processor
+
+    def __len__(self):
+        return len(self.loaded_data)
+
+    def __getitem__(self, idx):
+        data = self.loaded_data[idx]
+        question = data["question"]
+        question = f"[vqa] {question}"
+        image = Image.open(data["image_path"]).convert("RGB")
+        image = self.image_processor(image)
+        return image, question, data["image_id"], data["answers"]
+
+
+eval_file_path = cfg.evaluation_datasets_cfg["textvqa"]["eval_file_path"]
+img_path = cfg.evaluation_datasets_cfg["textvqa"]["img_path"]
+batch_size = cfg.evaluation_datasets_cfg["textvqa"]["batch_size"]
+max_new_tokens = cfg.evaluation_datasets_cfg["textvqa"]["max_new_tokens"]
+
+with open(eval_file_path, "r") as f:
+    eval_data = json.load(f)
+
+data = []
+for item in eval_data:
+    data.append(
+        {
+            "question": item["question"],
+            "image_id": item["image_id"],
+            "image_path": os.path.join(img_path, item["image_id"] + ".jpg"),
+            "answers": item["answers"],
+        }
+    )
+
+textvqa = EvalTextVQAData(data, vis_processor)
+eval_dataloader = DataLoader(textvqa, batch_size=batch_size, shuffle=False)
+
+count = 0
+total = 0
+minigpt4_predict = []
+chencherry = SmoothingFunction()
+print("Evaluating on TextVQA dataset")
+for images, texts, image_id, labels in tqdm(eval_dataloader):
+    texts = prepare_texts(texts, conv_temp)  # wrap the questions with the conversation template
+    answers = model.generate(images,
+                             texts,
+                             max_new_tokens=max_new_tokens,
+                             do_sample=False)
+
+    # The default collate groups answers by position; transpose so each
+    # entry holds all reference answers for a single sample.
+    labels = [list(x) for x in zip(*labels)]
+
+    for answer, sample_labels in zip(answers, labels):
+        result = dict()
+        result["pred"] = answer.lower().replace("<unk>", "").strip()
+        result["gt"] = Counter(sample_labels).most_common(1)[0][0]
+
+        if result["pred"] == result["gt"]:
+            count += 1
+        total += 1
+
+        # Calculate a word-level BLEU score against all reference answers
+        bleu_score = sentence_bleu(
+            [gt.split() for gt in sample_labels],
+            result["pred"].split(),
+            smoothing_function=chencherry.method1,
+        )
+        result["bleu"] = bleu_score
+        minigpt4_predict.append(result)
+
+
+print("Saving predictions to", save_path)
+file_save_path = os.path.join(save_path, "textvqa.json")
+with open(file_save_path, "w") as f:
+    json.dump(minigpt4_predict, f)
+
+print("Top 1 Accuracy:", count / total * 100, flush=True)
+print("Average BLEU score:",
+      np.mean([pred["bleu"] for pred in minigpt4_predict]),
+      flush=True)
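
Note for reviewers: the script reads its paths from cfg.evaluation_datasets_cfg["textvqa"] and cfg.run_cfg.save_path, so the evaluation config needs a textvqa entry. A minimal sketch of that section, assuming the same evaluation_datasets / run layout as the existing benchmark evaluation configs (paths and values below are illustrative placeholders, not part of this change):

    evaluation_datasets:
      textvqa:
        eval_file_path: /path/to/textvqa_val_annotations.json   # flat list of {question, image_id, answers}
        img_path: /path/to/textvqa/images                       # must contain <image_id>.jpg files
        max_new_tokens: 20
        batch_size: 10

    run:
      save_path: /path/to/eval_output

The annotation file must be a flat JSON list in which every item carries question, image_id, and answers fields, since that is all the loading loop looks for.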
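Assuming eval_parser() exposes the usual --cfg-path argument from minigpt4.common.eval_utils, the script would be launched the same way as the other evaluation scripts; the config filename here is a placeholder:

    torchrun --nproc_per_node 1 eval_textvqa.py \
        --cfg-path eval_configs/minigptv2_benchmark_evaluation.yaml --dataset textvqa

Predictions are written to <save_path>/textvqa.json, and the top-1 accuracy and mean BLEU score are printed at the end of the run.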