Merge pull request #12 from ThuanNaN/textvqa-eval

Add eval script for TextVQA dataset
Nguyen Thuan Duong 2025-01-20 06:32:18 +07:00 committed by GitHub
commit c5796f091b
6 changed files with 321 additions and 0 deletions

.gitignore

@@ -183,5 +183,7 @@ dataset/Evaluation.md
jupyter_notebook.slurm
MVTEC/
MVTEC_det/
TextVQA/
TextVQA_tiny/
*.pth
log.*

Makefile

@@ -6,3 +6,11 @@ train_mvtec:
	CUDA_VISIBLE_DEVICES=0 \
	python train.py --cfg-path train_configs/minigptv2_finetune_mvtec.yaml

eval_textvqa:
	CUDA_VISIBLE_DEVICES=0 \
	python eval_textvqa.py --cfg-path ./eval_configs/minigptv2_benchmark_evaluation.yaml

train_textvqa:
	CUDA_VISIBLE_DEVICES=0 \
	python train.py --cfg-path train_configs/minigptv2_finetune_textvqa.yaml

Jupyter notebook (new file; filename not shown)

@@ -0,0 +1,184 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import shutil\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"BASE_DIR = Path(\"./TextVQA\")\n",
"NEW_DIR = Path(\"./TextVQA_tiny\")\n",
"NEW_DIR.mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"BASE_IMG_DIR = BASE_DIR / \"train_images\"\n",
"NEW_IMG_DIR = NEW_DIR / \"images\"\n",
"NEW_IMG_DIR.mkdir(exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Read and filter the answers for each question"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from typing import List, Optional\n",
"from collections import Counter\n",
"\n",
"\n",
"def filter_answers(answers: List[str]) -> Optional[str]:\n",
"    \"\"\"\n",
"    Pick the most common answer from a list of annotator answers.\n",
"    Ties are broken at random, except that \"unanswerable\" always loses a tie;\n",
"    if \"unanswerable\" is the single most common answer, return None so the\n",
"    sample can be dropped.\n",
"    :param answers: list of raw annotator answers\n",
"    :return: the selected answer, or None\n",
"    \"\"\"\n",
"    frequency = Counter(answers)\n",
"    max_freq = max(frequency.values())\n",
"    tie_terms = [term for term, count in frequency.items() if count == max_freq]\n",
"    if \"unanswerable\" in tie_terms:\n",
"        if len(tie_terms) > 1:\n",
"            tie_terms.remove(\"unanswerable\")\n",
"            return tie_terms[0]\n",
"        return None\n",
"    return random.choice(tie_terms)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(34120, 4922)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target_fields = [\"question\", \"image_id\", \"image_width\", \"image_height\", \"answers\"]\n",
"\n",
"with open(BASE_DIR / \"TextVQA_0.5.1_train.json\", \"r\") as f:\n",
"    train_set = json.load(f)[\"data\"]\n",
"\n",
"train_set = [{k: v for k, v in d.items() if k in target_fields} for d in train_set]\n",
"for d in train_set:\n",
"    d[\"answer\"] = filter_answers(d[\"answers\"])\n",
"# Keep only samples with a usable (non-None) answer.\n",
"train_set = [d for d in train_set if d[\"answer\"] is not None]\n",
"\n",
"with open(BASE_DIR / \"TextVQA_0.5.1_val.json\", \"r\") as f:\n",
"    val_set = json.load(f)[\"data\"]\n",
"\n",
"val_set = [{k: v for k, v in d.items() if k in target_fields} for d in val_set]\n",
"for d in val_set:\n",
"    d[\"answer\"] = filter_answers(d[\"answers\"])\n",
"val_set = [d for d in val_set if d[\"answer\"] is not None]\n",
"\n",
"len(train_set), len(val_set)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"random.seed(42)\n",
"random.shuffle(train_set)\n",
"random.shuffle(val_set)\n",
"\n",
"train_set = train_set[:4000]\n",
"val_set = val_set[:1000]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Copy the selected images into the tiny dataset folder.\n",
"for d in train_set + val_set:\n",
"    img_id = d[\"image_id\"]\n",
"    shutil.copy(BASE_IMG_DIR / f\"{img_id}.jpg\", NEW_IMG_DIR / f\"{img_id}.jpg\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"with open(NEW_DIR / \"train.json\", \"w\") as f:\n",
"    json.dump(train_set, f, indent=4)\n",
"\n",
"with open(NEW_DIR / \"val.json\", \"w\") as f:\n",
"    json.dump(val_set, f, indent=4)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "minigptv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
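
The tie-breaking in filter_answers is the one subtle piece of the notebook above, so here is a minimal standalone sanity check of that logic (the function body mirrors the notebook cell; the sample answer lists are made up for illustration):

import random
from collections import Counter
from typing import List, Optional

def filter_answers(answers: List[str]) -> Optional[str]:
    # Most common answer wins; "unanswerable" loses ties, and if it is the
    # sole winner the sample is dropped (None).
    frequency = Counter(answers)
    max_freq = max(frequency.values())
    tie_terms = [term for term, count in frequency.items() if count == max_freq]
    if "unanswerable" in tie_terms:
        if len(tie_terms) > 1:
            tie_terms.remove("unanswerable")
            return tie_terms[0]
        return None
    return random.choice(tie_terms)

random.seed(42)
assert filter_answers(["coke", "coke", "pepsi"]) == "coke"                         # clear majority
assert filter_answers(["unanswerable", "sign", "unanswerable", "sign"]) == "sign"  # real answer wins the tie
assert filter_answers(["unanswerable", "unanswerable", "sign"]) is None            # sample gets dropped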

eval_configs/minigptv2_benchmark_evaluation.yaml

@@ -28,6 +28,13 @@ evaluation_datasets:
    eval_file_path: ./data/MVTEC_det/val_data.json
    max_new_tokens: 40
  textvqa:
    batch_size: 4
    eval_file_path: ./data/TextVQA_tiny/val.json
    img_path: ./data/TextVQA_tiny/images
    max_new_tokens: 20

run:
  task: image_text_pretrain
  name: minigptv2_evaluation
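
For a quick check that the new block parses as intended, the config can also be loaded directly with PyYAML (a sketch, not how the script consumes it: eval_textvqa.py below goes through minigpt4's Config wrapper, and the top-level evaluation_datasets key is taken from the hunk header above):

import yaml  # PyYAML

with open("./eval_configs/minigptv2_benchmark_evaluation.yaml") as f:
    cfg = yaml.safe_load(f)

# These keys mirror what eval_textvqa.py reads via cfg.evaluation_datasets_cfg["textvqa"].
textvqa_cfg = cfg["evaluation_datasets"]["textvqa"]
assert textvqa_cfg["batch_size"] == 4
assert textvqa_cfg["max_new_tokens"] == 20
print(textvqa_cfg["eval_file_path"], textvqa_cfg["img_path"])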

eval_textvqa.py (new file)

@@ -0,0 +1,120 @@
import os
import json
from collections import Counter

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

from minigpt4.common.eval_utils import (
    prepare_texts,
    init_model,
    eval_parser,
)
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.common.config import Config

def list_of_str(arg):
    return list(map(str, arg.split(",")))


parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default="refcoco", help="dataset to evaluate")
parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
parser.add_argument("--resample", action="store_true", help="resample images in refcoco")
args = parser.parse_args()
cfg = Config(args)

model, vis_processor = init_model(args)
conv_temp = CONV_VISION_minigptv2.copy()
conv_temp.system = ""
model.eval()

save_path = cfg.run_cfg.save_path
os.makedirs(save_path, exist_ok=True)

class EvalTextVQAData(torch.utils.data.Dataset):
    def __init__(self, loaded_data, image_processor):
        self.loaded_data = loaded_data
        self.image_processor = image_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        data = self.loaded_data[idx]
        question = data["question"]
        # The [vqa] tag selects MiniGPT-v2's VQA task template.
        question = f"[vqa] {question}"
        image = Image.open(data["image_path"]).convert("RGB")
        image = self.image_processor(image)
        return image, question, data["image_id"], data["answers"]

eval_file_path = cfg.evaluation_datasets_cfg["textvqa"]["eval_file_path"]
img_path = cfg.evaluation_datasets_cfg["textvqa"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["textvqa"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["textvqa"]["max_new_tokens"]

with open(eval_file_path, "r") as f:
    eval_data = json.load(f)

data = []
for item in eval_data:
    data.append(
        {
            "question": item["question"],
            "image_id": item["image_id"],
            "image_path": os.path.join(img_path, item["image_id"] + ".jpg"),
            "answers": item["answers"],
        }
    )

textvqa = EvalTextVQAData(data, vis_processor)
eval_dataloader = DataLoader(textvqa, batch_size=batch_size, shuffle=False)

count = 0
total = 0
minigpt4_predict = []

print("Evaluating on TextVQA dataset")
for images, texts, image_id, labels in tqdm(eval_dataloader):
    # Wrap the raw questions with the conversation template.
    texts = prepare_texts(texts, conv_temp)
    answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)

    # The default collate groups the ground-truth answers by position;
    # transpose so each entry holds all answers for one sample.
    labels = [list(x) for x in zip(*labels)]
    for answer, gt_answers in zip(answers, labels):
        result = dict()
        result["pred"] = answer.lower().replace("<unk>", "").strip()
        result["gt"] = Counter(gt_answers).most_common(1)[0][0]
        minigpt4_predict.append(result)
        if result["pred"] == result["gt"]:
            count += 1
        total += 1

        # BLEU over whitespace tokens, with every ground-truth answer as a reference.
        chencherry = SmoothingFunction()
        result["bleu"] = sentence_bleu(
            [gt.split() for gt in gt_answers],
            result["pred"].split(),
            smoothing_function=chencherry.method1,
        )

print("Saving predictions to", save_path)
file_save_path = os.path.join(save_path, "textvqa.json")
with open(file_save_path, "w") as f:
    json.dump(minigpt4_predict, f)

print("Top 1 Accuracy:", count / total * 100, flush=True)
print("Average BLEU score:",
      np.mean([pred["bleu"] for pred in minigpt4_predict]),
      flush=True)
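
Because every per-sample record (with its "pred", "gt", and "bleu" fields) is dumped to textvqa.json, the reported numbers can be recomputed offline. A small sketch, assuming the run config set save_path to ./results (the actual value comes from the YAML config):

import json
import numpy as np

with open("./results/textvqa.json") as f:  # adjust to your save_path
    preds = json.load(f)

accuracy = sum(p["pred"] == p["gt"] for p in preds) / len(preds) * 100
mean_bleu = np.mean([p["bleu"] for p in preds])
print(f"Top-1 accuracy: {accuracy:.2f}% | mean BLEU: {mean_bleu:.4f}")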