mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 02:20:47 +00:00
Merge pull request #12 from ThuanNaN/textvqa-eval
Add eval script for TextVQA dataset
This commit is contained in:
commit
c5796f091b
2
.gitignore
vendored
2
.gitignore
vendored
@ -183,5 +183,7 @@ dataset/Evaluation.md
|
|||||||
jupyter_notebook.slurm
|
jupyter_notebook.slurm
|
||||||
MVTEC/
|
MVTEC/
|
||||||
MVTEC_det/
|
MVTEC_det/
|
||||||
|
TextVQA/
|
||||||
|
TextVQA_tiny/
|
||||||
*.pth
|
*.pth
|
||||||
log.*
|
log.*
|
8
Makefile
8
Makefile
@ -6,3 +6,11 @@ train_mvtec:
|
|||||||
CUDA_VISIBLE_DEVICES=0 \
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
python train.py --cfg-path train_configs/minigptv2_finetune_mvtec.yaml
|
python train.py --cfg-path train_configs/minigptv2_finetune_mvtec.yaml
|
||||||
|
|
||||||
|
|
||||||
|
eval_textvqa:
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python eval_textvqa.py --cfg-path ./eval_configs/minigptv2_benchmark_evaluation.yaml
|
||||||
|
|
||||||
|
train_textvqa:
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python train.py --cfg-path train_configs/minigptv2_finetune_textvqa.yaml
|
||||||
|
184
data/create_textvqa_dataset.ipynb
Normal file
184
data/create_textvqa_dataset.ipynb
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import json\n",
|
||||||
|
"from pathlib import Path"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"BASE_DIR = Path(\"./TextVQA\")\n",
|
||||||
|
"NEW_DIR = Path(\"./TextVQA_tiny\")\n",
|
||||||
|
"NEW_DIR.mkdir(exist_ok=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"BASE_IMG_DIR = BASE_DIR / \"train_images\"\n",
|
||||||
|
"NEW_IMG_DIR = NEW_DIR / \"images\"\n",
|
||||||
|
"NEW_IMG_DIR.mkdir(exist_ok=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 1. READ and FILTER the answer for each question"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import random\n",
|
||||||
|
"from typing import List\n",
|
||||||
|
"from collections import Counter\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def filter_answers(answers: List[str]) -> str:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Filter out answers that are most common in a list of answers\n",
|
||||||
|
" :param answers: List of answers\n",
|
||||||
|
" :return: Most common answer\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" frequency = Counter(answers)\n",
|
||||||
|
" max_freq = max(frequency.values())\n",
|
||||||
|
" tie_terms = [term for term, count in frequency.items() if count == max_freq]\n",
|
||||||
|
" if len(tie_terms) == 0:\n",
|
||||||
|
" return None\n",
|
||||||
|
" if \"unanswerable\" in tie_terms:\n",
|
||||||
|
" if len(tie_terms) > 1:\n",
|
||||||
|
" tie_terms.remove(\"unanswerable\")\n",
|
||||||
|
" return tie_terms[0]\n",
|
||||||
|
" else:\n",
|
||||||
|
" return None\n",
|
||||||
|
" else:\n",
|
||||||
|
" random_term = random.choice(tie_terms)\n",
|
||||||
|
" return random_term"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"(34120, 4922)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"target_fields = [\"question\", \"image_id\", \"image_width\", \"image_height\", \"answers\"]\n",
|
||||||
|
"\n",
|
||||||
|
"with open(BASE_DIR / \"TextVQA_0.5.1_train.json\", \"r\") as f:\n",
|
||||||
|
" train_set = json.load(f)[\"data\"]\n",
|
||||||
|
" train_set = [{k: v for k, v in d.items() if k in target_fields} for d in train_set]\n",
|
||||||
|
" for d in train_set:\n",
|
||||||
|
" filtered_answer = filter_answers(d[\"answers\"])\n",
|
||||||
|
" if filtered_answer is None:\n",
|
||||||
|
" train_set.remove(d)\n",
|
||||||
|
" else:\n",
|
||||||
|
" d[\"answer\"] = filter_answers(d[\"answers\"])\n",
|
||||||
|
"\n",
|
||||||
|
"with open(BASE_DIR / \"TextVQA_0.5.1_val.json\", \"r\") as f:\n",
|
||||||
|
" val_set = json.load(f)[\"data\"]\n",
|
||||||
|
" val_set = [{k: v for k, v in d.items() if k in target_fields} for d in val_set]\n",
|
||||||
|
" for d in val_set:\n",
|
||||||
|
" filtered_answer = filter_answers(d[\"answers\"])\n",
|
||||||
|
" if filtered_answer is None:\n",
|
||||||
|
" val_set.remove(d)\n",
|
||||||
|
" else:\n",
|
||||||
|
" d[\"answer\"] = filter_answers(d[\"answers\"])\n",
|
||||||
|
" \n",
|
||||||
|
"\n",
|
||||||
|
"len(train_set), len(val_set)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"random.seed(42)\n",
|
||||||
|
"random.shuffle(train_set)\n",
|
||||||
|
"random.shuffle(val_set)\n",
|
||||||
|
"\n",
|
||||||
|
"train_set = train_set[:4000]\n",
|
||||||
|
"val_set = val_set[:1000]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for d in train_set:\n",
|
||||||
|
" img_id = d[\"image_id\"]\n",
|
||||||
|
" os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")\n",
|
||||||
|
"\n",
|
||||||
|
"for d in val_set:\n",
|
||||||
|
" img_id = d[\"image_id\"]\n",
|
||||||
|
" os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"with open(NEW_DIR / \"train.json\", \"w\") as f:\n",
|
||||||
|
" json.dump(train_set, f, indent=4)\n",
|
||||||
|
"\n",
|
||||||
|
"with open(NEW_DIR / \"val.json\", \"w\") as f:\n",
|
||||||
|
" json.dump(val_set, f, indent=4)\n",
|
||||||
|
" "
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "minigptv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.21"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -28,6 +28,13 @@ evaluation_datasets:
|
|||||||
eval_file_path: ./data/MVTEC_det/val_data.json
|
eval_file_path: ./data/MVTEC_det/val_data.json
|
||||||
max_new_tokens: 40
|
max_new_tokens: 40
|
||||||
|
|
||||||
|
textvqa:
|
||||||
|
batch_size: 4
|
||||||
|
eval_file_path: ./data/TextVQA_tiny/val.json
|
||||||
|
img_path: ./data/TextVQA_tiny/images
|
||||||
|
max_new_tokens: 20
|
||||||
|
|
||||||
|
|
||||||
run:
|
run:
|
||||||
task: image_text_pretrain
|
task: image_text_pretrain
|
||||||
name: minigptv2_evaluation
|
name: minigptv2_evaluation
|
||||||
|
120
eval_textvqa.py
Normal file
120
eval_textvqa.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from collections import Counter
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||||
|
from minigpt4.common.eval_utils import (
|
||||||
|
prepare_texts,
|
||||||
|
init_model,
|
||||||
|
eval_parser,
|
||||||
|
)
|
||||||
|
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
|
||||||
|
from minigpt4.common.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
def list_of_str(arg):
    """Parse a comma-separated command-line value into a list of strings.

    Surrounding whitespace is stripped from every entry and empty entries
    (for example from a trailing comma) are dropped, so ``"a, b,"`` yields
    ``["a", "b"]``.

    :param arg: raw option value, a single comma-separated string
    :return: list of non-empty, stripped entries in their original order
    """
    stripped_entries = (entry.strip() for entry in arg.split(","))
    return [entry for entry in stripped_entries if entry]
|
||||||
|
|
||||||
|
|
||||||
|
parser = eval_parser()
|
||||||
|
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
|
||||||
|
parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
|
||||||
|
parser.add_argument("--resample", action='store_true', help="resolution used in refcoco")
|
||||||
|
args = parser.parse_args()
|
||||||
|
cfg = Config(args)
|
||||||
|
|
||||||
|
model, vis_processor = init_model(args)
|
||||||
|
conv_temp = CONV_VISION_minigptv2.copy()
|
||||||
|
conv_temp.system = ""
|
||||||
|
model.eval()
|
||||||
|
save_path = cfg.run_cfg.save_path
|
||||||
|
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
class EvalTextVQAData(torch.utils.data.Dataset):
    """Dataset that yields one TextVQA evaluation sample per index.

    Every record in ``loaded_data`` is read through the keys ``question``,
    ``image_path``, ``image_id`` and ``answers``.
    """

    def __init__(self, loaded_data, image_processor):
        """Store the samples and the image preprocessing callable.

        :param loaded_data: indexable, sized collection of sample dicts
        :param image_processor: callable applied to the RGB PIL image
        """
        self.loaded_data = loaded_data
        self.image_processor = image_processor

    def __len__(self):
        """Return the number of samples."""
        return len(self.loaded_data)

    def __getitem__(self, idx):
        """Return ``(processed_image, prompt, image_id, answers)`` for ``idx``.

        The prompt is the question prefixed with the ``[vqa]`` task tag.
        """
        data = self.loaded_data[idx]
        question = f"[vqa] {data['question']}"
        # Open inside a context manager so the file handle is released as
        # soon as the pixels are converted, instead of waiting for garbage
        # collection; convert() returns an independent, loaded image.
        with Image.open(data["image_path"]) as raw_image:
            image = raw_image.convert("RGB")
        image = self.image_processor(image)
        return image, question, data["image_id"], data["answers"]
|
||||||
|
|
||||||
|
|
||||||
|
# Read the TextVQA evaluation settings from the config.
textvqa_cfg = cfg.evaluation_datasets_cfg["textvqa"]
eval_file_path = textvqa_cfg["eval_file_path"]
img_path = textvqa_cfg["img_path"]
batch_size = textvqa_cfg["batch_size"]
max_new_tokens = textvqa_cfg["max_new_tokens"]

# Load the annotation records of the evaluation split.
with open(eval_file_path, "r") as f:
    train_data = json.load(f)

# Keep only the fields the dataset needs and resolve each image file path.
data = [
    {
        "question": item["question"],
        "image_id": item["image_id"],
        "image_path": os.path.join(img_path, item["image_id"] + ".jpg"),
        "answers": item["answers"],
    }
    for item in train_data
]

# Batches are served in file order (no shuffling).
textvqa = EvalTextVQAData(data, vis_processor)
eval_dataloader = DataLoader(textvqa, batch_size=batch_size, shuffle=False)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
total = 0
|
||||||
|
minigpt4_predict = []
|
||||||
|
print("Evaluating on TextVQA dataset")
|
||||||
|
# Build the smoothing function once; it is reused for every sample.
chencherry = SmoothingFunction()

for images, texts, image_id, labels in tqdm(eval_dataloader):
    # Wrap the questions with the conversation template.
    texts = prepare_texts(texts, conv_temp)
    answers = model.generate(
        images,
        texts,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )

    # Transpose the collated labels so that each entry holds all reference
    # answers of one sample. A separate name keeps the batch-level `labels`
    # from being shadowed by the per-sample loop variable below.
    batch_labels = [list(x) for x in zip(*labels)]

    for answer, sample_labels in zip(answers, batch_labels):
        # Normalise the generated text; this is the string that is saved.
        pred = answer.lower().replace("<unk>", "").strip()
        # Ground truth is the most frequent reference answer.
        gt = Counter(sample_labels).most_common(1)[0][0]

        # Score the cleaned prediction, not the raw generation: leftover
        # "<unk>" tokens or whitespace would otherwise turn a correct answer
        # into a miss even though the saved "pred" equals "gt".
        if pred == gt:
            count += 1
        total += 1

        # NOTE(review): references and hypothesis are plain strings, so the
        # n-grams are taken over characters rather than words — confirm this
        # is the intended metric.
        bleu_score = sentence_bleu(
            sample_labels, pred, smoothing_function=chencherry.method1
        )

        minigpt4_predict.append({"pred": pred, "gt": gt, "bleu": bleu_score})
|
||||||
|
|
||||||
|
|
||||||
|
# Persist the per-sample predictions next to the other run outputs.
print("Saving predictions to", save_path)
file_save_path = os.path.join(save_path, "textvqa.json")
with open(file_save_path, "w") as f:
    json.dump(minigpt4_predict, f)

# Guard against an empty evaluation set: avoids a ZeroDivisionError for the
# accuracy and the NaN (plus RuntimeWarning) np.mean gives for an empty list.
accuracy = count / total * 100 if total else 0.0
if minigpt4_predict:
    average_bleu = np.mean([pred["bleu"] for pred in minigpt4_predict])
else:
    average_bleu = 0.0

print("Top 1 Accuracy:", accuracy, flush=True)
print("Average BLEU score: ", average_bleu, flush=True)
|
Loading…
Reference in New Issue
Block a user