mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-17 11:10:46 +00:00
75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
import os
|
|
import json
|
|
import pickle
|
|
import random
|
|
import time
|
|
import itertools
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
import skimage.io as io
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.collections import PatchCollection
|
|
from matplotlib.patches import Polygon, Rectangle
|
|
from torch.utils.data import Dataset
|
|
import webdataset as wds
|
|
|
|
from minigpt4.datasets.datasets.base_dataset import BaseDataset
|
|
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
|
|
|
|
|
|
|
|
|
|
class MultiTaskConversationDataset(Dataset):
    """Multi-turn image conversation dataset backed by a JSON annotation file.

    Each annotation record is expected to provide an ``id`` (used to build the
    ``COCO_train2014_<id>.jpg`` file name) and a ``conversations`` list whose
    entries carry their text under ``value``. Turns are interpreted purely by
    position: entry 0 and every later even position are human instructions,
    odd positions are assistant answers.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        vis_processor (callable): applied to every image after it is loaded as RGB
        text_processor: stored on the instance; not used by this class itself
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path of the JSON annotation file to load
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # Decode explicitly as UTF-8 (the standard JSON encoding) instead of
        # relying on the platform's locale default, which breaks on non-ASCII
        # conversation text on some systems.
        with open(ann_path, 'r', encoding='utf-8') as f:
            self.ann = json.load(f)

        # Separator used to pack several turns into one string; it is also
        # returned with every sample so the consumer can split them again.
        self.connect_sym = "!@#"

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.ann)

    def __getitem__(self, index):
        """Load one sample: the processed image plus its packed conversation.

        Returns a dict with:
            image: output of ``vis_processor`` on the RGB image
            conv_q: human instructions joined with ``connect_sym``
            conv_a: assistant answers joined with ``connect_sym``
            image_id: the record's ``id``
            connect_sym: the separator used in ``conv_q`` / ``conv_a``
        """
        info = self.ann[index]

        image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
        image_path = os.path.join(self.vis_root, image_file)
        # Context manager guarantees the file handle is released even when
        # decoding/conversion fails part-way.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.vis_processor(image)

        turns = info['conversations']

        # The opening instruction has its '<image>' tag and newlines stripped
        # and is prefixed with the image placeholder expected in the prompt.
        first_instruction = turns[0]['value'].replace('<image>', '').replace('\n', '').strip()
        first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)

        # After the first turn the list alternates assistant, human, ...:
        # odd positions are answers, later even positions are follow-up
        # questions (each gets a trailing space, like the first one).
        # NOTE(review): a conversation ending on a human turn yields one more
        # question than answers; nothing here rebalances that.
        answers = [turn["value"] for turn in turns[1::2]]
        questions = [first_instruction]
        questions.extend(turn["value"] + " " for turn in turns[2::2])

        questions = self.connect_sym.join(questions)
        answers = self.connect_sym.join(answers)

        return {
            "image": image,
            "conv_q": questions,
            'conv_a': answers,
            "image_id": info['id'],
            "connect_sym": self.connect_sym
        }