Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05 18:40:46 +00:00)

Commit 71e52fcaf0 (parent bbd7883d1c): add code for fine-tuning on the ScienceQA dataset
dataset/.DS_Store  (new binary file, vendored; not shown)
dataset/ScienceQA/pid_splits.json  (new file, 40311 lines; diff suppressed because it is too large)
dataset/ScienceQA/problems.json  (new file, 418394 lines; diff suppressed because it is too large)
dataset/ScienceQA/test_QCM-A.json  (new file, 23226 lines; diff suppressed because it is too large)
dataset/ScienceQA/train_QCM-A.json  (new file, 69852 lines; diff suppressed because it is too large)
dataset/ScienceQA/val_QCM-A.json  (new file, 23306 lines; diff suppressed because it is too large)
eval_configs/.DS_Store  (new binary file, vendored; not shown)
inference.py  (new file, 71 lines)
@@ -0,0 +1,71 @@
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr
import json

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.response import Chat

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into the config file (deprecated), "
             "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True


args = parse_args()
cfg = Config(args)

# build the model on the requested GPU
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

# reuse the cc_sbu_align training image processor for inference
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))

# run the model over every record of the JSONL file at model_config.output_path
with open(model_config.output_path, 'r') as json_file:
    for line in json_file:
        item = json.loads(line)
        image_emb = chat.upload_img(item["image"])                # [1, 32, 4096]
        embedding = chat.get_context_emb(item["text"], image_emb)
        llm_message = chat.answer(embs=embedding, max_new_tokens=300, max_length=2000)[0]
        print(llm_message)
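For reference, the loop above treats model_config.output_path as a JSONL file: one JSON object per line, carrying at least an "image" path and a "text" prompt (the only two keys it reads). A purely illustrative record (the path and prompt wording are placeholders, not taken from the actual data):

    {"image": "dataset/ScienceQA/test/<question_id>/image.png", "text": "Question: ... Options: (A) ... (B) ... Answer:"}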
minigpt4/.DS_Store  (new binary file, vendored; not shown)
minigpt4/configs/.DS_Store  (new binary file, vendored; not shown)
minigpt4/configs/datasets/.DS_Store  (new binary file, vendored; not shown)
minigpt4/configs/datasets/ScienceQA/.DS_Store  (new binary file, vendored; not shown)
minigpt4/configs/datasets/ScienceQA/align.yaml  (new file, 5 lines)
@@ -0,0 +1,5 @@
datasets:
  ScienceQA:
    data_type: images
    build_info:
      storage: /path/to/MiniGPT-4/dataset/ScienceQA/
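The storage path above is read by the ScienceQABuilder added later in this commit, which looks for train_QCM-A.json and a train/ image folder directly under it. A sketch of the implied layout; the JSON files are the ones added under dataset/ScienceQA/ in this commit, while the train/ image folder is implied by the builder and dataset code rather than included here:

    dataset/ScienceQA/
    ├── pid_splits.json
    ├── problems.json
    ├── train_QCM-A.json
    ├── val_QCM-A.json
    ├── test_QCM-A.json
    └── train/          # image files referenced by each annotation's image_id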
minigpt4/conversation/response.py  (new file, 107 lines)
@@ -0,0 +1,107 @@
import argparse
import time
from PIL import Image

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

from minigpt4.common.registry import registry


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def answer(self, embs, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in the current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find a start token <s> at the beginning; remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image):
        if isinstance(image, str):  # an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)
        image_emb, _ = self.model.encode_img(image)
        return image_emb

    def get_context_emb(self, text_list, img_list):
        system = "Give the following image: <Img>ImageContent</Img>. You will be able to see the image once I provide it to you. Please answer my questions." + "###"
        prompt = "Human" + ": " + "<Img><ImageHere></Img> " + text_list + "###"
        prompt = system + prompt

        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # e.g. seg_embs[0]: [1, 42, 4096], seg_embs[-1]: [1, 13, 4096]
        mixed_embs = torch.cat([seg_embs[0], img_list, seg_embs[1]], dim=1)
        return mixed_embs
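A minimal usage sketch of this Chat class, mirroring what inference.py above does; model and vis_processor are assumed to be built as in that script, and the image path and question text are placeholders:

    from minigpt4.conversation.response import Chat

    chat = Chat(model, vis_processor, device='cuda:0')
    image_emb = chat.upload_img('path/to/question_image.png')       # -> [1, 32, 4096]
    embs = chat.get_context_emb('Question: ... Answer:', image_emb)
    answer_text, answer_tokens = chat.answer(embs=embs, max_new_tokens=300, max_length=2000)
    print(answer_text)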
minigpt4/datasets/.DS_Store  (new binary file, vendored; not shown)
@@ -103,3 +103,35 @@ class CCSBUAlignBuilder(BaseDatasetBuilder):
        )

        return datasets


@registry.register_builder("ScienceQA")
class ScienceQABuilder(BaseDatasetBuilder):
    # ScienceQADataset is defined in minigpt4/datasets/datasets/science_qa_dataset.py
    # (added in this commit) and must be imported at the top of this file.
    train_dataset_cls = ScienceQADataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/ScienceQA/align.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and images/videos should already be
        # downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        storage_path = build_info.storage

        datasets = dict()

        if not os.path.exists(storage_path):
            warnings.warn("storage path {} does not exist.".format(storage_path))

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(storage_path, 'train_QCM-A.json')],
            vis_root=os.path.join(storage_path, 'train'),
        )

        return datasets
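A rough sketch of how the builder registered above could be exercised on its own. This assumes the LAVIS-style registry and builder API that MiniGPT-4 inherits (in particular registry.get_builder_class and a no-argument constructor that loads the "default" config); neither is shown in this diff, so treat the sketch as an assumption:

    from minigpt4.common.registry import registry

    builder_cls = registry.get_builder_class("ScienceQA")   # the name registered with @registry.register_builder
    builder = builder_cls()                                  # loads configs/datasets/ScienceQA/align.yaml ("default")
    datasets = builder.build_datasets()                      # {"train": ScienceQADataset(...)}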
@@ -36,7 +36,14 @@ class CaptionDataset(BaseDataset, __DisplMixin):
        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            if "image_id" in ann:
                img_id = ann["image_id"]
                if "/" in img_id:
                    image_id = img_id.split("/")[0]
                    if image_id not in self.img_ids.keys():
                        self.img_ids[image_id] = n
                        n += 1
                else:
                    if img_id not in self.img_ids.keys():
                        self.img_ids[img_id] = n
                        n += 1
@@ -45,19 +52,42 @@ class CaptionDataset(BaseDataset, __DisplMixin):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        if "image_id" in ann:
            if "id" in ann:
                # annotation that carries an input prompt (e.g. the ScienceQA-style entries
                # added in this commit): image_id is a relative path under vis_root
                img_file = ann["image_id"]
                input_prompt = self.text_processor(ann["input"])
                image_path = os.path.join(self.vis_root, img_file)
                image = Image.open(image_path).convert("RGB")
                image = self.vis_processor(image)
                caption = self.text_processor(ann["caption"])
                return {
                    "image": image,
                    "input_prompt": input_prompt,
                    "text_input": caption,
                    "image_id": self.img_ids[ann["image_id"].split("/")[0]],
                }
            else:
                img_file = '{:0>12}.jpg'.format(ann["image_id"])
                image_path = os.path.join(self.vis_root, img_file)
                image = Image.open(image_path).convert("RGB")

                image = self.vis_processor(image)
                caption = self.text_processor(ann["caption"])

                return {
                    "image": image,
                    "text_input": caption,
                    "image_id": self.img_ids[ann["image_id"]],
                }
        else:
            # text-only sample: return a zero image and a sentinel image_id
            # (requires `import torch` at the top of this file)
            input_prompt = self.text_processor(ann["input"])
            caption = self.text_processor(ann["caption"])
            return {
                "image": torch.zeros(3, 224, 224),
                "input_prompt": input_prompt,
                "text_input": caption,
                "image_id": -100,
            }


class CaptionEvalDataset(BaseDataset, __DisplMixin):
minigpt4/datasets/datasets/science_qa_dataset.py  (new file, 42 lines)
@@ -0,0 +1,42 @@
import os
from PIL import Image
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
import torch

class ScienceQADataset(CaptionDataset):
    def __getitem__(self, index):
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]
        if "image_id" in ann:
            # if "id" in ann:
            #     img_file = ann["image_id"]
            #     input_prompt = ann["input"]
            # else:
            #     img_file = '{}.jpg'.format(ann["image_id"])
            img_file = ann["image_id"]
            input_prompt = ann["input"]

            image_path = os.path.join(self.vis_root, img_file)
            image = Image.open(image_path).convert("RGB")

            image = self.vis_processor(image)
            caption = ann["caption"]

            return {
                "image": image,
                "input_prompt": input_prompt,
                "text_input": caption,
                "image_id": self.img_ids[ann["image_id"].split("/")[0]]
            }
        else:
            # text-only ScienceQA question: no image in the annotation, so return
            # a zero image and a sentinel image_id
            input_prompt = ann["input"]
            caption = ann["caption"]
            return {
                "image": torch.zeros(3, 224, 224),
                "input_prompt": input_prompt,
                "text_input": caption,
                "image_id": -100,
            }
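Judging from the keys the dataset reads ("image_id", "input", "caption"), an entry of train_QCM-A.json presumably looks like the record below; the field values are illustrative, not copied from the actual file:

    # one annotation record as ScienceQADataset expects it (values are made up)
    {
        "image_id": "1/image.png",   # path relative to vis_root; key absent for text-only questions
        "input": "Question: ... Context: ... Options: (A) ... (B) ...",
        "caption": "The answer is (A).",
    }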
@@ -16,7 +16,6 @@ class MiniGPT4(Blip2Base):
    """
    BLIP2 GPT-LLAMA model.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }
@@ -163,14 +162,50 @@ class MiniGPT4(Blip2Base):
        else:
            return img_embeds, atts_img

    def prompt_wrap_image(self, img_embeds, atts_img, prompt):
        # wrap each image embedding with its per-sample ScienceQA prompt,
        # split around the <ImageHere> placeholder
        p_before, p_after = [], []

        for prompt_item in prompt:
            before, after = prompt_item.split('<ImageHere>')
            p_before.append(before)
            p_after.append(after)

        p_before_tokens = self.llama_tokenizer(
            p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_after_tokens = self.llama_tokenizer(
            p_after, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False).to(img_embeds.device)
        p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids)
        p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img

    def prompt_wrap_no_image(self, prompt, device_gpu):
        # embed the per-sample prompts directly when the batch carries no real image content
        p_tokens = self.llama_tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False).to(device_gpu)
        p_embeds = self.llama_model.model.embed_tokens(p_tokens.input_ids)
        p_atts = torch.ones(p_embeds.size()[:-1], dtype=torch.long)
        wrapped_p_atts = p_atts[:, :1].expand(-1, p_embeds.shape[1])
        return p_embeds, wrapped_p_atts

    def forward(self, samples):
        # the ScienceQA and cc_sbu_align batches always include an "image" entry
        # (a zero tensor for text-only ScienceQA questions), so this loop breaks on
        # the first iteration; it only serves to record the batch device.
        while True:
            if "image" in samples:
                device_gpu = samples["image"].device
                break
        if "image" in samples:
            image = samples["image"]
            img_embeds, atts_img = self.encode_img(image)
            if hasattr(samples, 'question_split'):  # VQA dataset
                print('VQA Batch')
                vqa_prompt = '###Human: <Img><ImageHere></Img> '
                img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
            elif self.prompt_list:
                if "input_prompt" in samples:
                    prefix = '###Human: <Img><ImageHere></Img> '
                    science_prompt = [prefix + item for item in samples["input_prompt"]]
                    img_embeds, atts_img = self.prompt_wrap_image(img_embeds, atts_img, science_prompt)
                else:
                    science_prompt = samples["input_prompt"]
                    img_embeds, atts_img = self.prompt_wrap_no_image(science_prompt, device_gpu)

        if len(self.prompt_list) > 0:
            prompt = random.choice(self.prompt_list)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
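A small, self-contained sketch of the tensor shapes prompt_wrap_image concatenates; the text-token counts are arbitrary, while the 32 query tokens and hidden size 4096 follow the shape comment in inference.py:

    import torch

    B, H = 3, 4096                            # batch size (arbitrary) and LLaMA hidden size
    p_before_embeds = torch.zeros(B, 7, H)    # embedded text before <ImageHere> (7 tokens, arbitrary)
    img_embeds = torch.zeros(B, 32, H)        # encode_img output: 32 query tokens per image
    p_after_embeds = torch.zeros(B, 15, H)    # embedded (padded) text after <ImageHere>

    wrapped = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
    print(wrapped.shape)                      # torch.Size([3, 54, 4096])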
@@ -185,7 +220,7 @@ class MiniGPT4(Blip2Base):
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
-        ).to(image.device)
+        ).to(img_embeds.device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
@@ -193,7 +228,7 @@ class MiniGPT4(Blip2Base):

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
-                       dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
+                       dtype=torch.long).to(img_embeds.device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)
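A toy illustration of how the label tensor is assembled here: padding tokens and all positions covered by the wrapped image/prompt embeddings are set to -100, which the language-model loss ignores. The token ids and sizes below are made up:

    import torch

    atts_img = torch.ones(2, 40, dtype=torch.long)                     # attention over the wrapped embeddings
    to_regress_ids = torch.tensor([[306, 626, 263, 0],                 # 0 stands in for pad_token_id
                                   [450, 4799, 338, 1781]])
    targets = to_regress_ids.masked_fill(to_regress_ids == 0, -100)    # ignore padding in the loss
    empty_targets = torch.full((atts_img.shape[0], atts_img.shape[1] + 1), -100, dtype=torch.long)  # +1 for bos
    targets = torch.cat([empty_targets, targets], dim=1)
    print(targets.shape)                                               # torch.Size([2, 45])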
train_configs/minigpt4_stage2_finetune_science.yaml  (new file, 51 lines)
@@ -0,0 +1,51 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  max_txt_len: 160
  end_sym: "###"
  prompt_path: ""
  prompt_template: ''
  ckpt: '/path/to/stage1/checkpoint/'


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 3e-5
  min_lr: 1e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  batch_size_train: 3
  batch_size_eval: 3
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
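Assuming the repository's usual train.py entry point (the same one used with the other stage-2 fine-tune configs), a single-GPU run with this config would presumably be launched as:

    torchrun --nproc-per-node 1 train.py --cfg-path train_configs/minigpt4_stage2_finetune_science.yaml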