Add code for fine-tuning on the ScienceQA dataset

This commit is contained in:
1429904852 2023-08-19 17:49:51 +08:00
parent bbd7883d1c
commit 71e52fcaf0
20 changed files with 575494 additions and 32 deletions

BIN
dataset/.DS_Store vendored Normal file

Binary file not shown.

File diff suppressed because it is too large

418394
dataset/ScienceQA/problems.json Normal file

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

BIN
eval_configs/.DS_Store vendored Normal file

Binary file not shown.

71
inference.py Normal file

@@ -0,0 +1,71 @@
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr
import json
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.response import Chat
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecated), "
             "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True


# Build the model and visual processor from the config, then run inference over a
# JSON-lines file whose records carry an image path and a text prompt.
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))

with open(model_config.output_path, 'r') as json_file:
    for line in json_file:
        item = json.loads(line)
        image_emb = chat.upload_img(item["image"])            # [1, 32, 4096]
        embedding = chat.get_context_emb(item["text"], image_emb)
        llm_message = chat.answer(embs=embedding, max_new_tokens=300, max_length=2000)[0]
        print(llm_message)
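For reference, a minimal sketch (not part of the diff) of the input this script expects: model_config.output_path should point to a JSON-lines file in which every record carries an "image" path and a "text" prompt, matching item["image"] and item["text"] in the loop above. The paths, prompt text, and config filename below are placeholders.

# Sketch: write a JSON-lines file the inference loop can consume.
import json

records = [
    {"image": "dataset/ScienceQA/test/12/image.png",   # placeholder path
     "text": "Question: Which property do these objects have in common? Options: (A) hard (B) soft"},
]
with open("scienceqa_eval.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

# Assumed invocation (flag names come from parse_args above; the yaml name is a guess):
#   python inference.py --cfg-path eval_configs/scienceqa_eval.yaml --gpu-id 0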

BIN
minigpt4/.DS_Store vendored Normal file

Binary file not shown.

BIN
minigpt4/configs/.DS_Store vendored Normal file

Binary file not shown.

BIN
minigpt4/configs/datasets/.DS_Store vendored Normal file

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,5 @@
datasets:
  ScienceQA:
    data_type: images
    build_info:
      storage: /path/to/MiniGPT-4/dataset/ScienceQA/
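As a hedged sketch (not shown in the diff itself), the ScienceQA builder added further down reads this storage path and expects roughly the following layout; the file and folder names come from its ann_paths and vis_root arguments, the rest is illustrative.

# Assumed layout under build_info.storage:
#   dataset/ScienceQA/
#       problems.json        # raw ScienceQA problems, added in this commit
#       train_QCM-A.json     # training annotations read by ScienceQADataset
#       train/               # image folders referenced by each record's "image_id"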


@@ -0,0 +1,107 @@
import argparse
import time
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
from minigpt4.common.registry import registry
class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def answer(self, embs, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]
        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find a start token <s> at the beginning; remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image):
        if isinstance(image, str):  # an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)
        image_emb, _ = self.model.encode_img(image)
        return image_emb

    def get_context_emb(self, text_list, img_list):
        system = "Give the following image: <Img>ImageContent</Img>. You will be able to see the image once I provide it to you. Please answer my questions." + "###"
        prompt = "Human" + ": " + "<Img><ImageHere></Img> " + text_list + "###"
        prompt = system + prompt
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # splice the [1, 32, 4096] image embedding between the two text segments
        mixed_embs = torch.cat([seg_embs[0], img_list, seg_embs[1]], dim=1)
        return mixed_embs
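To make the prompt assembly concrete, here is a minimal illustration (the question text is invented) of the string get_context_emb builds before splicing the image embedding in at <ImageHere>:

# Illustration only: reproduces the template used in get_context_emb above.
system = ("Give the following image: <Img>ImageContent</Img>. "
          "You will be able to see the image once I provide it to you. "
          "Please answer my questions." + "###")
text = "Question: Which property do these objects have in common?"  # hypothetical
prompt = system + "Human" + ": " + "<Img><ImageHere></Img> " + text + "###"

# The single '<ImageHere>' placeholder yields two text segments; their token
# embeddings are concatenated around the [1, 32, 4096] image embedding, and
# generation stops at the '###' separator enforced by StoppingCriteriaSub.
before, after = prompt.split('<ImageHere>')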

BIN
minigpt4/datasets/.DS_Store vendored Normal file

Binary file not shown.


@@ -103,3 +103,35 @@ class CCSBUAlignBuilder(BaseDatasetBuilder):
        )

        return datasets


@registry.register_builder("ScienceQA")
class ScienceQABuilder(BaseDatasetBuilder):
    train_dataset_cls = ScienceQADataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/ScienceQA/align.yaml",
    }

    def build_datasets(self):
        # at this point, all annotations and images/videos should already be downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        storage_path = build_info.storage

        datasets = dict()

        if not os.path.exists(storage_path):
            warnings.warn("storage path {} does not exist.".format(storage_path))

        # create the training dataset over the ScienceQA annotations and image folders
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(storage_path, 'train_QCM-A.json')],
            vis_root=os.path.join(storage_path, 'train'),
        )

        return datasets
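A hedged sketch of how this registration is typically exercised; the call names mirror the LAVIS-style registry used elsewhere in MiniGPT-4 and are assumptions, not part of this diff.

# Sketch: resolve the builder registered above and build the ScienceQA train split.
from minigpt4.common.registry import registry

builder_cls = registry.get_builder_class("ScienceQA")  # resolved via @registry.register_builder
builder = builder_cls()            # assumed to fall back to configs/datasets/ScienceQA/align.yaml
datasets = builder.build_datasets()
train_set = datasets["train"]      # ScienceQADataset over train_QCM-A.json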


@@ -36,7 +36,14 @@ class CaptionDataset(BaseDataset, __DisplMixin):
        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            if "image_id" in ann:
                img_id = ann["image_id"]
                if "/" in img_id:
                    image_id = img_id.split("/")[0]
                    if image_id not in self.img_ids.keys():
                        self.img_ids[image_id] = n
                        n += 1
                else:
                    if img_id not in self.img_ids.keys():
                        self.img_ids[img_id] = n
                        n += 1
@@ -45,19 +52,42 @@ class CaptionDataset(BaseDataset, __DisplMixin):
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        if "image_id" in ann:
            if "id" in ann:
                img_file = ann["image_id"]
                input_prompt = self.text_processor(ann["input"])
                image_path = os.path.join(self.vis_root, img_file)
                image = Image.open(image_path).convert("RGB")
                image = self.vis_processor(image)
                caption = self.text_processor(ann["caption"])
                return {
                    "image": image,
                    "input_prompt": input_prompt,
                    "text_input": caption,
                    "image_id": self.img_ids[ann["image_id"].split("/")[0]],
                }
            else:
                img_file = '{:0>12}.jpg'.format(ann["image_id"])
                image_path = os.path.join(self.vis_root, img_file)
                image = Image.open(image_path).convert("RGB")
                image = self.vis_processor(image)
                caption = self.text_processor(ann["caption"])
                return {
                    "image": image,
                    "text_input": caption,
                    "image_id": self.img_ids[ann["image_id"]],
                }
        else:
            input_prompt = self.text_processor(ann["input"])
            caption = self.text_processor(ann["caption"])
            return {
                "image": torch.zeros(3, 224, 224),
                "input_prompt": input_prompt,
                "text_input": caption,
                "image_id": -100,
            }
class CaptionEvalDataset(BaseDataset, __DisplMixin):
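The new '/' handling above matters because ScienceQA image ids are stored as '<question-id>/<filename>' rather than as plain COCO-style ids; a tiny illustration (example annotations are made up) of the resulting img_ids mapping:

# Mirrors the img_ids logic added above, with hypothetical ScienceQA-style ids.
annotation = [
    {"image_id": "1009/image.png"},     # keyed by the folder before '/': '1009' -> 0
    {"image_id": "1009/choice_0.png"},  # same folder, no new index
    {"image_id": "1010/image.png"},     # new folder: '1010' -> 1
]
img_ids, n = {}, 0
for ann in annotation:
    if "image_id" in ann:
        img_id = ann["image_id"]
        key = img_id.split("/")[0] if "/" in img_id else img_id
        if key not in img_ids:
            img_ids[key] = n
            n += 1
print(img_ids)  # {'1009': 0, '1010': 1}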


@@ -0,0 +1,42 @@
import os
from PIL import Image
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset

import torch


class ScienceQADataset(CaptionDataset):

    def __getitem__(self, index):
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        if "image_id" in ann:
            img_file = ann["image_id"]
            input_prompt = ann["input"]
            image_path = os.path.join(self.vis_root, img_file)
            image = Image.open(image_path).convert("RGB")
            image = self.vis_processor(image)
            caption = ann["caption"]
            return {
                "image": image,
                "input_prompt": input_prompt,
                "text_input": caption,
                "image_id": self.img_ids[ann["image_id"].split("/")[0]],
            }
        else:
            # text-only samples: return a dummy image tensor and a sentinel image_id
            input_prompt = ann["input"]
            caption = ann["caption"]
            return {
                "image": torch.zeros(3, 224, 224),
                "input_prompt": input_prompt,
                "text_input": caption,
                "image_id": -100,
            }
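For orientation, a hedged sketch of an annotation record this class consumes from train_QCM-A.json. Only the key names ("image_id", "input", "caption") are taken from the code above; the value formats follow the usual ScienceQA QCM-to-A convention (question, context, and choices as input; the answer as target) and are illustrative, not extracted from the committed data.

# Hypothetical train_QCM-A.json entry matching the keys read in __getitem__:
example_ann = {
    "image_id": "1009/image.png",     # joined with vis_root (= <storage>/train)
    "input": "Question: Which property do these objects have in common? "
             "Context: Select the better answer. "
             "Options: (A) hard (B) soft",
    "caption": "The answer is (A).",  # becomes text_input, i.e. the regression target
}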


@@ -16,7 +16,6 @@ class MiniGPT4(Blip2Base):
    """
    BLIP2 GPT-LLAMA model.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }
@@ -163,14 +162,50 @@ class MiniGPT4(Blip2Base):
        else:
            return img_embeds, atts_img

    def prompt_wrap_image(self, img_embeds, atts_img, prompt):
        # Wrap each sample's image embedding with its own per-sample prompt,
        # splitting every prompt at the '<ImageHere>' placeholder.
        p_before, p_after = [], []
        for prompt_item in prompt:
            before, after = prompt_item.split('<ImageHere>')
            p_before.append(before)
            p_after.append(after)
        p_before_tokens = self.llama_tokenizer(
            p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_after_tokens = self.llama_tokenizer(
            p_after, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False).to(img_embeds.device)
        p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids)
        p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img

    def prompt_wrap_no_image(self, prompt, device_gpu):
        # Text-only samples: embed the prompt directly, without an image embedding.
        p_tokens = self.llama_tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False).to(device_gpu)
        p_embeds = self.llama_model.model.embed_tokens(p_tokens.input_ids)
        # keep the attention mask on the same device as the embeddings
        p_atts = torch.ones(p_embeds.size()[:-1], dtype=torch.long, device=device_gpu)
        wrapped_p_atts = p_atts[:, :1].expand(-1, p_embeds.shape[1])
        return p_embeds, wrapped_p_atts
    def forward(self, samples):
        while True:
            if "image" in samples:
                device_gpu = samples["image"].device
                break

        if "image" in samples:
            image = samples["image"]
            img_embeds, atts_img = self.encode_img(image)

        if hasattr(samples, 'question_split'):  # VQA dataset
            print('VQA Batch')
            vqa_prompt = '###Human: <Img><ImageHere></Img> '
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
        elif self.prompt_list:
            if "input_prompt" in samples:
                prefix = '###Human: <Img><ImageHere></Img> '
                science_prompt = [prefix + item for item in samples["input_prompt"]]
                img_embeds, atts_img = self.prompt_wrap_image(img_embeds, atts_img, science_prompt)
            else:
                science_prompt = samples["input_prompt"]
                img_embeds, atts_img = self.prompt_wrap_no_image(science_prompt, device_gpu)
            if len(self.prompt_list) > 0:
                prompt = random.choice(self.prompt_list)
                img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
@@ -185,7 +220,7 @@
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
-        ).to(image.device)
+        ).to(img_embeds.device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
@@ -193,7 +228,7 @@
        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
-                       dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
+                       dtype=torch.long).to(img_embeds.device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)
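A minimal sketch of what prompt_wrap_image does to one batch element, to make the new control flow easier to follow; the prompt text is invented and the shapes are placeholders.

# Each per-sample prompt is split at '<ImageHere>' and the image embedding is
# spliced between the embedded text halves, so the LLaMA input effectively becomes
#   embed('###Human: <Img>') + img_embeds + embed('</Img> <question text>') + target tokens
prefix = '###Human: <Img><ImageHere></Img> '
input_prompt = "Question: Which property do these objects have in common? Options: (A) hard (B) soft"
science_prompt = prefix + input_prompt
p_before, p_after = science_prompt.split('<ImageHere>')
# p_before == '###Human: <Img>'
# p_after  == '</Img> Question: ... Options: (A) hard (B) soft'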


@@ -0,0 +1,51 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  max_txt_len: 160
  end_sym: "###"
  prompt_path: ""
  prompt_template: ''
  ckpt: '/path/to/stage1/checkpoint/'

datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 3e-5
  min_lr: 1e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  batch_size_train: 3
  batch_size_eval: 3
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
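Assuming this file sits alongside the stock stage-2 config under train_configs/ (the path is a guess; it is not shown in this diff), fine-tuning would be launched the same way as standard MiniGPT-4 stage-2 training:

# Assumed invocation, mirroring the standard MiniGPT-4 launch command:
#   torchrun --nproc-per-node 1 train.py --cfg-path train_configs/minigpt4_scienceqa_finetune.yaml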