Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05 02:20:47 +00:00)

Commit 59eafed6ca (parent e66c4ee3e3): add mvtec dataset
minigpt4/configs/datasets/mvtec/default.yaml (new executable file, 6 lines)
@@ -0,0 +1,6 @@
datasets:
  mvtec_ad:
    data_type: images
    build_info:
      image_path: /mnt/Repository/MiniGPT-4/MVTEC_det/images
      ann_path: /mnt/Repository/MiniGPT-4/MVTEC_det/mvtech_ad_data_for_regression.json
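The annotation file itself is not part of this commit. Judging from the fields `MVTecDataset.__getitem__` (below) reads, it is presumably a JSON list of records with `image_path`, `bbox`, and `is_broken` keys. A hypothetical entry, written out as a sketch:

import json

# Hypothetical annotation record, inferred from the fields MVTecDataset reads;
# the real mvtech_ad_data_for_regression.json is not shown in this commit.
sample_ann = [
    {
        "image_path": "bottle/test/broken_large/000.png",  # relative to build_info.image_path
        "bbox": [120, 88, 260, 310],                       # assumed [x1, y1, x2, y2] pixels
        "is_broken": True,                                 # defect vs. non-defect flag
    }
]

with open("mvtech_ad_data_for_regression.json", "w") as f:
    json.dump(sample_ann, f, indent=2)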
@@ -18,7 +18,7 @@ from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset
 from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
 from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
 from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
+from minigpt4.datasets.datasets.mvtec_dataset import MVTecDataset


 @registry.register_builder("multitask_conversation")
 class MultitaskConversationBuilder(BaseDatasetBuilder):
@@ -394,6 +394,29 @@ class CaptionToPhraseBuilder(BaseDatasetBuilder):
         return datasets


+@registry.register_builder("mvtec_ad")
+class MVTECADBuilder(BaseDatasetBuilder):
+    train_dataset_cls = MVTecDataset
+    DATASET_CONFIG_DICT = {
+        "default": "configs/datasets/mvtec/default.yaml",
+    }
+
+    def build_datasets(self):
+        logging.info("Building datasets...")
+        self.build_processors()
+        build_info = self.config.build_info
+        datasets = dict()
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets['train'] = dataset_cls(
+            vis_processor=self.vis_processors["train"],
+            text_processor=self.text_processors["train"],
+            ann_path=build_info.ann_path,
+            vis_root=build_info.image_path,
+        )
+
+        return datasets
+
+
 class DocumentVQABuilder(BaseDatasetBuilder):
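The `@registry.register_builder("mvtec_ad")` decorator is what ties the `mvtec_ad` key used in the YAML configs to this builder class. A minimal sketch of the lookup, assuming the LAVIS-style registry API MiniGPT-4 uses elsewhere (`get_builder_class` and the config-taking constructor are assumptions, not shown in this diff):

from omegaconf import OmegaConf

from minigpt4.common.registry import registry

# Assumed flow: the "mvtec_ad" key under `datasets:` in the training config is
# resolved through the registry, and the matched builder assembles the train
# split from the paths in configs/datasets/mvtec/default.yaml.
cfg = OmegaConf.load("minigpt4/configs/datasets/mvtec/default.yaml")
builder_cls = registry.get_builder_class("mvtec_ad")  # -> MVTECADBuilder
builder = builder_cls(cfg.datasets.mvtec_ad)          # constructor signature assumed
datasets = builder.build_datasets()                   # {'train': MVTecDataset(...)}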
minigpt4/datasets/datasets/mvtec_dataset.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import os
import json
import random

from PIL import Image
from torch.utils.data import Dataset


class MVTecDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the annotation JSON file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.instruction_pool = [
            '[detection] {}',
        ]

        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        info = self.ann[index]
        gt_bbox = info["bbox"]

        image_path = os.path.join(self.vis_root, info['image_path'])
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        input = "detect defect or non-defect and return the bounding box"

        # target string: the class label followed by the box corners
        ans_cls = "defect" if info["is_broken"] else "non-defect"
        answer = f"{ans_cls}<{gt_bbox[0]}><{gt_bbox[1]}><{gt_bbox[2]}><{gt_bbox[3]}>"

        instruction = random.choice(self.instruction_pool).format(input)
        instruction = "<Img><ImageHere></Img> {} ".format(instruction)

        return {
            "image": image,
            "instruction_input": instruction,
            "answer": answer,
            "image_id": info['image_path'],
        }
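The answer string packs the class and box into `defect<x1><y1><x2><y2>`, so evaluation code needs to invert that format. A hypothetical parser, not part of this commit (it assumes the numeric coordinates the f-string above emits):

import re

# Hypothetical helper: invert the "defect<x1><y1><x2><y2>" /
# "non-defect<x1><y1><x2><y2>" target format produced by MVTecDataset.
_ANSWER_RE = re.compile(r"^(defect|non-defect)<([\d.]+)><([\d.]+)><([\d.]+)><([\d.]+)>$")

def parse_answer(answer: str):
    match = _ANSWER_RE.match(answer.strip())
    if match is None:
        raise ValueError(f"unparsable answer: {answer!r}")
    ans_cls = match.group(1)
    bbox = [float(v) for v in match.groups()[1:]]
    return ans_cls, bbox

print(parse_answer("defect<120><88><260><310>"))  # ('defect', [120.0, 88.0, 260.0, 310.0])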
train_configs/minigptv2_finetune_mvtec.yaml (new file, 55 lines)
@@ -0,0 +1,55 @@
model:
  arch: minigpt_v2
  model_type: pretrain
  max_txt_len: 1024
  image_size: 448
  end_sym: "</s>"
  llama_model: "meta-llama/Llama-2-7b-chat-hf"
  ckpt: "./ckpt/checkpoint_stage3.pth"
  use_grad_checkpoint: True
  chat_template: True
  lora_r: 64
  lora_alpha: 16

datasets:
  mvtec_ad:
    batch_size: 2
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 448
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 100

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 1e-5
  min_lr: 1e-6
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 50
  num_workers: 6
  warmup_steps: 1000
  iters_per_epoch: 1000

  seed: 42
  output_dir: "mvtec_outputs"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True

  wandb_log: True
  job_name: minigptv2_finetune