Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git, synced 2025-04-05 02:20:47 +00:00
Add files via upload
This commit is contained in: parent 081fb02424, commit c70775af99
eval_scripts/eval_sciencevqa.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import argparse
import os
import random
import requests
from io import BytesIO

import numpy as np
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
from minigpt4.common.logger import setup_logger
from minigpt4.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import now
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", required=False, default='scienceqa_eval.yaml', help="path to configuration file.")
parser.add_argument("--ckpt_path", required=False, help="path to the model checkpoint to evaluate.")
parser.add_argument("--lora_r", type=int, default=64, help="rank of the LoRA adapter.")
parser.add_argument("--lora_alpha", type=int, default=16, help="scaling factor (alpha) of the LoRA adapter.")
parser.add_argument("--name", type=str, help="experiment name used in the result file name.")

parser.add_argument(
    "--options",
    nargs="+",
    help="override some settings in the used config, the key-value pair "
    "in xxx=yyy format will be merged into the config file (deprecated), "
    "change to --cfg-options instead.",
)

print('Initializing Chat')
args = parser.parse_args()
cfg = Config(args)

ckpt_list = [args.ckpt_path]

print('evaluating config:', args.cfg_path)
print('evaluating checkpoint:', args.ckpt_path)

for ckpt in ckpt_list:
    cfg.model_cfg.ckpt = ckpt
    cfg.model_cfg.lora_r = args.lora_r
    cfg.model_cfg.lora_alpha = args.lora_alpha

    model_config = cfg.model_cfg
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')
    model.eval()
    print('Initialization Finished')

vis_processor_cfg = cfg.datasets_cfg.coco_vqa.vis_processor.eval
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

text_processor_cfg = cfg.datasets_cfg.coco_vqa.text_processor.eval
text_processor = registry.get_processor_class(text_processor_cfg.name).from_config(text_processor_cfg)

task = tasks.setup_task(cfg)

from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset

class textVQAEvalDataset(VQADataset):
    def __init__(self, vis_processor, text_processor, vis_root=None, ann_paths=None):
        # super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        from datasets import load_dataset
        self.annotation = load_dataset("derek-thomas/ScienceQA", split='test')

        # keep only annotations that come with an image
        self.annotation = [ann for ann in self.annotation if ann['image']]

        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.instruction_pool = [
            '[vqa] Question: {} Answer:',
        ]

        self.alphabet_options = ["A", "B", "C", "D", "E"]
        self.max_num_choices = len(self.alphabet_options)

    def __getitem__(self, index):
        ann = self.annotation[index]
        image = ann['image'].convert("RGB")

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])
        lecture = self.text_processor(ann['lecture'])
        hint = self.text_processor(ann['hint'])
        options = ann['choices']

        num_choices = len(ann['choices'])
        # pad the choice list with empty strings so every sample exposes
        # self.max_num_choices options
        if len(ann['choices']) < self.max_num_choices:
            options = options + (self.max_num_choices - len(ann['choices'])) * [""]

        options = [self.text_processor(opt) for opt in options]

        # note: the template has a single placeholder, so only the question is
        # inserted; the extra lecture argument is ignored by str.format
        instruction = random.choice(self.instruction_pool).format(question, lecture)

        instruction = "<s>[INST]<Img><ImageHere></Img> {} [/INST]".format(instruction)

        answer = options[ann['answer']]

        return {
            "image": image,
            "question": question,
            "choices": options,
            # "choices": self.alphabet_options,
            "num_choices": num_choices,
            "answer": answer,
            "instruction_input": instruction,
            "question_id": index,
        }

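# A quick worked example (hypothetical values) of the padding and answer
# lookup performed above:
choices, answer_idx = ["red", "blue"], 1
padded = choices + (5 - len(choices)) * [""]
assert padded == ["red", "blue", "", "", ""]
assert padded[answer_idx] == "blue"
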
dataset = textVQAEvalDataset(vis_processor, text_processor)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.run_cfg.batch_size_eval,
                                         num_workers=cfg.run_cfg.num_workers)

print(len(dataset), len(dataloader))
result_dir = cfg.run_cfg.output_dir
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# derive the epoch and experiment id from the checkpoint file name
epoch = cfg.model_cfg.ckpt.split('/')[-1].split('.')[0].split('_')[1]
exp_id = cfg.model_cfg.ckpt.split('/')[-2] + "_" + epoch

val_result = task.evaluation(model, dataloader, cuda_enabled=True)
task.after_evaluation(val_result, split_name=f'{args.name}_scienceqa_val', result_dir=result_dir)

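# Worked example (hypothetical path) of the checkpoint-name parsing above;
# it assumes checkpoints are saved as <exp_name>/checkpoint_<epoch>.pth:
ckpt_example = '/save/minigpt4_sqa/checkpoint_9.pth'
assert ckpt_example.split('/')[-1].split('.')[0].split('_')[1] == '9'   # epoch
assert ckpt_example.split('/')[-2] == 'minigpt4_sqa'                    # exp_id prefix
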
from minigpt4.common.vqa_tools.vqa import VQA
from minigpt4.common.vqa_tools.vqa_eval import VQAEval
import json

result_file = '{}/{}_scienceqa_val.json'.format(result_dir, args.name)
with open(result_file, "r") as f:
    results = json.load(f)
acc = []
vqa_tool = VQAEval()

for res in results:
    gt_ans = res["gt_ans"]
    pred = res["pred_ans"]

    # normalize the prediction before the exact-match comparison
    pred = vqa_tool.processPunctuation(pred)
    pred = vqa_tool.processDigitArticle(pred)

    vqa_acc = 1 if pred == gt_ans else 0
    acc.append(vqa_acc)

accuracy = sum(acc) / len(acc) * 100
print('scienceqa val acc: ', accuracy, flush=True)
eval_scripts/eval_textvqa.py (new file, 185 lines)
@@ -0,0 +1,185 @@
import argparse
import os
import random
import requests
from io import BytesIO

import numpy as np
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

import minigpt4.tasks as tasks
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
from minigpt4.common.logger import setup_logger
from minigpt4.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import now
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", required=False, default='textvqa_eval.yaml', help="path to configuration file.")
parser.add_argument("--ckpt_path", required=False, help="path to the model checkpoint to evaluate.")
parser.add_argument("--lora_r", type=int, default=64, help="rank of the LoRA adapter.")
parser.add_argument("--lora_alpha", type=int, default=16, help="scaling factor (alpha) of the LoRA adapter.")
parser.add_argument("--name", type=str, help="experiment name used in the result file name.")
parser.add_argument("--img_path", type=str, help="path to the image root.")
parser.add_argument("--eval_file_path", type=str, help="root directory of the evaluation annotation files.")
parser.add_argument(
    "--options",
    nargs="+",
    help="override some settings in the used config, the key-value pair "
    "in xxx=yyy format will be merged into the config file (deprecated), "
    "change to --cfg-options instead.",
)

print('Initializing Chat')
args = parser.parse_args()
cfg = Config(args)

ckpt_list = [args.ckpt_path]

print('evaluating config:', args.cfg_path)
print('evaluating checkpoint:', args.ckpt_path)

import json

# load the TextOCR validation annotations; their OCR strings are later
# prepended to the prompt when use_ocr is enabled
with open(f'{args.eval_file_path}/textocr/TextOCR_0.1_val.json', 'r') as f:
    textcap_dataset_val = json.load(f)

all_img_ids = list(textcap_dataset_val['imgs'].keys())
ann_ids = list(textcap_dataset_val['anns'].keys())

for ckpt in ckpt_list:
    cfg.model_cfg.ckpt = ckpt
    cfg.model_cfg.lora_r = args.lora_r
    cfg.model_cfg.lora_alpha = args.lora_alpha

    model_config = cfg.model_cfg
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')
    model.eval()
    print('Initialization Finished')

vis_processor_cfg = cfg.datasets_cfg.coco_vqa.vis_processor.eval
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

text_processor_cfg = cfg.datasets_cfg.coco_vqa.text_processor.eval
text_processor = registry.get_processor_class(text_processor_cfg.name).from_config(text_processor_cfg)
task = tasks.setup_task(cfg)

from collections import Counter

def find_most_frequent_item(items):
    # return the most common element (ties resolved by first occurrence),
    # or None for an empty sequence
    counter = Counter(items)
    most_common = counter.most_common(1)
    if most_common:
        return most_common[0][0]
    else:
        return None

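# Usage sketch: TextVQA provides ten annotator answers per question, and the
# helper above picks the consensus one:
assert find_most_frequent_item(["stop", "stop", "go"]) == "stop"
assert find_most_frequent_item([]) is None
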
from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset

class textVQAEvalDataset(VQADataset):
    def __init__(self, vis_processor, text_processor, vis_root=None, ann_paths=None, use_ocr=False):
        # super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        from datasets import load_dataset
        self.annotation = load_dataset("textvqa", split="validation", cache_dir=f"{args.eval_file_path}/textvqa_cache")
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.instruction_pool = [
            '[vqa] Question: {} Short answer:',
        ]

        self.use_ocr = use_ocr

    def __getitem__(self, index):
        ann = self.annotation[index]
        image_path = ann['flickr_original_url']
        image = ann["image"].convert("RGB")

        # gather the TextOCR strings belonging to this image and join them
        # into a single OCR-token string
        img_id = ann['image_id']
        sel_anns = [textcap_dataset_val['anns'][ann] for ann in ann_ids if ann.startswith(img_id)]
        sel_caps = [ann['utf8_string'] for ann in sel_anns]
        sel_caps = [cap for cap in sel_caps if cap not in ['\n', '.']]
        captions = " ".join(sel_caps)

        image = self.vis_processor(image)
        question = self.text_processor(ann["question"])
        captions = self.text_processor(captions)

        instruction = random.choice(self.instruction_pool).format(question)

        if self.use_ocr and img_id in all_img_ids:
            instruction = "<s>[INST] <Img><ImageHere></Img> OCR tokens: {}. {} [/INST]".format(captions, instruction)
        else:
            instruction = "<s>[INST] <Img><ImageHere></Img> {} [/INST]".format(instruction)

        answers = find_most_frequent_item(ann["answers"])

        return {
            "image": image,
            "text_input": question,
            "answer": answers,
            'image_path': image_path,
            "instruction_input": instruction,
            "question_id": ann["question_id"],
            # "instance_id": ann["instance_id"],
        }

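# With use_ocr enabled, the instruction built above ends up as, e.g.
# (hypothetical question and OCR strings):
#   <s>[INST] <Img><ImageHere></Img> OCR tokens: STOP 7th Ave. [vqa] Question: what does the sign say? Short answer: [/INST]
# and without OCR (or when the image has no TextOCR entry):
#   <s>[INST] <Img><ImageHere></Img> [vqa] Question: what does the sign say? Short answer: [/INST]
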
dataset = textVQAEvalDataset(vis_processor, text_processor, use_ocr=cfg.run_cfg.use_ocr)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.run_cfg.batch_size_eval, num_workers=cfg.run_cfg.num_workers)

result_dir = cfg.run_cfg.output_dir
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

epoch = cfg.model_cfg.ckpt.split('/')[-1].split('.')[0].split('_')[1]
exp_id = cfg.model_cfg.ckpt.split('/')[-2] + "_" + epoch

val_result = task.evaluation(model, dataloader, cuda_enabled=True)
task.after_evaluation(val_result, split_name=f'{args.name}_textvqa_val', result_dir=result_dir)

anno_files = f'{args.eval_file_path}/TextVQA/gt_ann.json'
ques_files = f'{args.eval_file_path}/TextVQA/gt_ques.json'

from minigpt4.common.vqa_tools.vqa import VQA
from minigpt4.common.vqa_tools.vqa_eval import VQAEval
import logging

result_file = '{}/{}_textvqa_val.json'.format(result_dir, args.name)
vqa = VQA(anno_files, ques_files)
vqa_result = vqa.loadRes(
    resFile=result_file, quesFile=ques_files
)

# create vqaEval object by taking vqa and vqaRes
# n is precision of accuracy (number of places after decimal), default is 2
vqa_scorer = VQAEval(vqa, vqa_result, n=2)
logging.info("Start VQA evaluation.")
vqa_scorer.evaluate()

# print accuracies
overall_acc = vqa_scorer.accuracy["overall"]
print('textvqa val acc: ', overall_acc, flush=True)
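The script above reads several fixed paths under --eval_file_path. A minimal pre-flight sketch (relative paths taken from the code above; the root is a hypothetical placeholder) to confirm the layout before a long run:

import os

eval_root = "/path/to/eval_files"     # hypothetical --eval_file_path value
expected = [
    "textocr/TextOCR_0.1_val.json",   # TextOCR annotations used for OCR tokens
    "TextVQA/gt_ann.json",            # ground-truth answers for VQAEval
    "TextVQA/gt_ques.json",           # ground-truth questions for VQAEval
]
# note: {eval_root}/textvqa_cache is created automatically by load_dataset
for rel in expected:
    path = os.path.join(eval_root, rel)
    print(path, "OK" if os.path.exists(path) else "MISSING")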
eval_scripts/scienceqa_eval.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
model:
  arch: minigpt4
  model_type: pretrain_llama2
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_template: '[INST] {} [/INST] '
  ckpt: 'please set this value to the path of pretrained checkpoint'


datasets:
  coco_vqa:    # not used
    type: eval
    vis_processor:
      eval:
        name: "blip2_image_eval"
        image_size: 336
    text_processor:
      eval:
        name: "blip_caption"

run:
  task: scienceqa
  # optimization-specific
  batch_size_train: 16
  batch_size_eval: 32
  num_workers: 8

  # inference-specific
  num_ans_candidates: 5
  max_len: 10
  min_len: 1
  num_beams: 5
  inference_method: "generate"
  prompt: "Question: {} Short answer:"

  seed: 42
  output_dir: "results"

  evaluate: True
  test_splits: ["val"]

  # distribution-specific
  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
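Putting eval_sciencevqa.py and this config together, a minimal launch sketch (the checkpoint path and run name are hypothetical; the flags are the ones defined in the script's parser):

import subprocess

subprocess.run([
    "python", "eval_scripts/eval_sciencevqa.py",
    "--cfg-path", "eval_scripts/scienceqa_eval.yaml",
    "--ckpt_path", "/save/minigpt4_sqa/checkpoint_9.pth",  # overrides the yaml ckpt value
    "--lora_r", "64",
    "--lora_alpha", "16",
    "--name", "minigpt4_sqa",
], check=True)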
eval_scripts/textvqa_eval.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
model:
  arch: minigpt4
  model_type: pretrain_llama2
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_template: '[INST] {} [/INST] '
  ckpt: 'please set this value to the path of pretrained checkpoint'

datasets:
  coco_vqa:    # not used
    type: eval
    vis_processor:
      eval:
        name: "blip2_image_eval"
        image_size: 336
    text_processor:
      eval:
        name: "blip_caption"

run:
  task: vqa
  # optimization-specific
  batch_size_train: 32
  batch_size_eval: 128
  # without OCR tokens: batch=32, Evaluation Total time: 0:23:47 (9.0954 s / it)
  use_ocr: True
  # with OCR tokens: batch_size=16; Evaluation Total time: 0:36:34 (7.0116 s / it)
  num_workers: 8

  # inference-specific
  max_len: 10
  min_len: 1
  num_beams: 1
  inference_method: "generate"
  prompt: "Question: {} Short answer:"

  seed: 42
  output_dir: "results"

  evaluate: True
  test_splits: ["val"]

  # distribution-specific
  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True