diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..ddbd354
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,184 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+wandb/
+jobs/logs/
+*.out
+*ipynb
+.history/
+*.json
+*.sh
+.ipynb_common
+logs/
+results/
+prompts/
+output/
+ckpt/
+divide_vqa.py
+jobs/
+
+*.slurm
+slurm*
+sbatch_generate*
+eval_data/
+dataset/Evaluation.md
+jupyter_notebook.slurm
diff --git a/MiniGPTv2_Train.md b/MiniGPTv2_Train.md
new file mode 100644
index 0000000..254d680
--- /dev/null
+++ b/MiniGPTv2_Train.md
@@ -0,0 +1,24 @@
+## Finetuning MiniGPT-v2
+
+
+You first need to prepare the dataset. Follow the steps in our [dataset preparation guide](dataset/README_MINIGPTv2_FINETUNE.md).
+
+In train_configs/minigptv2_finetune.yaml, set the following paths:
+
+- llama_model checkpoint path: "/path/to/llama_checkpoint"
+
+- ckpt (the pretrained checkpoint to finetune from): "/path/to/pretrained_checkpoint"
+
+- ckpt save path: "/path/to/save_checkpoint"
+
+For ckpt, you can start from one of our pretrained model checkpoints:
+| MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo) |
+|------------------------------|------------------------------|------------------------------|
+| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1jAbxUiyl04SFJMN4sF1vvUU69Etuz4qa/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
+
+
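+Before launching, you can optionally sanity-check that the checkpoint paths you filled in actually exist. This is a minimal sketch (not part of the repo); it assumes the config nests these fields under a top-level model section, so adjust the keys if your train_configs/minigptv2_finetune.yaml differs:
+
+```python
+import os
+import yaml  # requires pyyaml
+
+with open("train_configs/minigptv2_finetune.yaml") as f:
+    cfg = yaml.safe_load(f)
+
+# assumed nesting: model.llama_model and model.ckpt hold the two checkpoint paths
+for key in ("llama_model", "ckpt"):
+    path = cfg["model"][key]
+    print(key, path, "OK" if os.path.exists(path) else "MISSING")
+```
+
+Then launch the finetuning run: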
+```bash
+torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigptv2_finetune.yaml
+```
+
diff --git a/README.md b/README.md
index e5a4a1b..d24923d 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,8 @@ Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
## News
+[Oct.24 2023] We release the finetuning code of our MiniGPT-v2.
+
[Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2
[Aug.28 2023] We now provide a llama 2 version of MiniGPT-4
@@ -63,7 +65,7 @@ Git clone our repository, creating a python environment and activate it via the
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
cd MiniGPT-4
conda env create -f environment.yml
-conda activate minigpt4
+conda activate minigptv
```
@@ -93,9 +95,10 @@ Then, set the variable *llama_model* in the model config file to the LLM weight
Download the pretrained model checkpoints
-| MiniGPT-v2 (LLaMA-2 Chat 7B) |
-|------------------------------|
-| [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
+| MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo)|
+|------------------------------|------------------------------|------------------------------|
+| [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1jAbxUiyl04SFJMN4sF1vvUU69Etuz4qa/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
+
For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8.
@@ -146,6 +149,7 @@ Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run Mini
### Training
For training details of MiniGPT-4, check [here](MiniGPT4_Train.md).
+For finetuning details of MiniGPT-v2, check [here](MiniGPTv2_Train.md).
diff --git a/dataset/README_MINIGPTv2_FINETUNE.md b/dataset/README_MINIGPTv2_FINETUNE.md
new file mode 100644
index 0000000..5a539f4
--- /dev/null
+++ b/dataset/README_MINIGPTv2_FINETUNE.md
@@ -0,0 +1,285 @@
+## Download the datasets for finetuning MiniGPT-v2
+
+
+Download the following datasets:
+
+Image source | Download path
+--- | :---:
+COCO 2014 images | images captions
+COCO VQA | vqa train vqa val
+Visual Genome | images part1 images part2 image meta data
+TextCaps | images annotations
+RefCOCO | annotations
+RefCOCO+ | annotations
+RefCOCOg | annotations
+OKVQA | annotations
+AOK-VQA | annotations
+OCR-VQA | annotations
+GQA | images annotations
+Filtered Flickr-30k |  annotations
+Multi-task conversation |  annotations
+Filtered unnatural instruction |  annotations
+LLaVA | Complex reasoning, Detailed description, Conversation
+
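+After downloading, it helps to place everything under a single ${MINIGPTv2_DATASET} root. The expected sub-folders are listed in the sections below; the following minimal sketch (assuming that layout and those folder names) prints any that are still missing:
+
+```python
+import os
+
+# assumed: all datasets live under one root, e.g. export MINIGPTv2_DATASET=/path/to/root
+root = os.environ.get("MINIGPTv2_DATASET", "/path/to/MINIGPTv2_DATASET")
+expected = [
+    "coco", "coco_captions", "vqav2", "visual_genome", "textcaps",
+    "refcoco_annotations", "okvqa", "aokvqa", "ocrvqa", "gqa",
+    "filtered_flickr", "multitask_conversation", "unnatural_instructions", "llava",
+]
+for name in expected:
+    if not os.path.isdir(os.path.join(root, name)):
+        print("missing:", name)
+```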
+
+
+### COCO captions
+Download the COCO 2014 images and captions
+
+COCO 2014 images path:
+
+```
+${MINIGPTv2_DATASET}
+├── coco
+│ ├── images
+...
+```
+
+
+COCO caption annotation path:
+
+```
+${MINIGPTv2_DATASET}
+├── coco_captions
+│ └── annotations
+│ ├── coco_karpathy_train.json
+...
+```
+
+Set **image_path** to the COCO 2014 image folder.
+Similarly, set **ann_path** to the coco_karpathy_train.json path
+- [minigpt4/configs/datasets/coco/caption.yaml](../minigpt4/configs/datasets/coco/caption.yaml)
+
+### COCO VQA
+Download the vqa v2 train and validation json files
+
+```
+├── ${MINIGPTv2_DATASET}
+│ ├── vqav2
+│ ├── vqa_train.json
+| ├── vqa_val.json
+```
+
+Set **image_path** to the COCO 2014 image folder.
+Similarly, set **ann_path** to the vqa_train.json and vqa_val.json path
+- [minigpt4/configs/datasets/coco/defaults_vqa.yaml](../minigpt4/configs/datasets/coco/defaults_vqa.yaml)
+
+
+### Visual Genome
+Download the Visual Genome images and annotation files
+
+```
+${MINIGPTv2_DATASET}
+├── visual_genome
+│ ├── VG_100K
+│ ├── VG_100K_2
+│ └── region_descriptions.json
+│ └── image_data.json
+...
+```
+
+Set **image_path** to the visual_genome folder.
+Similarly, set **ann_path** to the visual_genome folder.
+
+- [minigpt4/configs/datasets/vg/ref.yaml](../minigpt4/configs/datasets/vg/ref.yaml)
+
+
+### TextCaps
+Download the TextCaps images and annotation files
+
+```
+├── ${MINIGPTv2_DATASET}
+│ ├── textcaps
+│ ├── train_images
+│ ├── TextCaps_0.1_train.json
+```
+
+Set **image_path** to TextCaps train_images folder.
+Similarly, set **ann_path** to the TextCaps_0.1_train.json path
+
+- [minigpt4/configs/datasets/textcaps/caption.yaml](../minigpt4/configs/datasets/textcaps/caption.yaml)
+
+### RefCOCO, RefCOCO+, RefCOCOg
+Download the RefCOCO, RefCOCO+, RefCOCOg annotation files
+
+```
+
+${MINIGPTv2_DATASET}
+├── refcoco_annotations
+│ ├── refcoco
+│ │ ├── instances.json
+│ │ ├── refs(google).p
+│ │ └── refs(unc).p
+│ ├── refcoco+
+│ │ ├── instances.json
+│ │ └── refs(unc).p
+│ └── refcocog
+│ ├── instances.json
+│ ├── refs(google).p
+│       └── refs(umd).p
+...
+```
+
+
+Set **image_path** to the COCO 2014 image folder.
+Similarly, set **ann_path** in all of the following configs to the *refcoco_annotations* folder above, which contains refcoco, refcoco+, and refcocog (a quick loading check is sketched after the list).
+
+- [minigpt4/configs/datasets/coco_bbox/refcoco.yaml](../minigpt4/configs/datasets/coco_bbox/refcoco.yaml)
+- [minigpt4/configs/datasets/coco_bbox/refcocog.yaml](../minigpt4/configs/datasets/coco_bbox/refcocog.yaml)
+- [minigpt4/configs/datasets/coco_bbox/refcocop.yaml](../minigpt4/configs/datasets/coco_bbox/refcocop.yaml)
+- [minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml)
+- [minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml)
+- [minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml)
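+
+A minimal sketch (not part of the repo) to confirm the annotation files load; it assumes the folder layout above:
+
+```python
+import json
+import os
+import pickle
+
+ann_root = "/path/to/refcoco_annotations"  # same placeholder as in the configs
+for dataset, split_by in [("refcoco", "unc"), ("refcoco+", "unc"), ("refcocog", "umd")]:
+    refs = pickle.load(open(os.path.join(ann_root, dataset, f"refs({split_by}).p"), "rb"))
+    instances = json.load(open(os.path.join(ann_root, dataset, "instances.json")))
+    print(dataset, len(refs), "refs,", len(instances["annotations"]), "annotations")
+```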
+
+
+
+
+### OKVQA
+
+
+```
+Location_you_like
+├── ${MINIGPTv2_DATASET}
+│ ├── okvqa
+│ ├── okvqa_train.json
+```
+
+Set **image_path** to the COCO 2014 image folder.
+Similarly, set **ann_path** to the location of the OKVQA dataset
+- [minigpt4/configs/datasets/okvqa/defaults.yaml](../minigpt4/configs/datasets/okvqa/defaults.yaml)
+
+
+Download the original OK-VQA questions and annotations:
+
+- [OK-VQA Input Questions](https://okvqa.allenai.org/static/data/OpenEnded_mscoco_train2014_questions.json.zip)
+- [OK-VQA Annotations](https://okvqa.allenai.org/static/data/mscoco_train2014_annotations.json.zip)
+
+
+### AOK-VQA
+Download the AOK-VQA annotation dataset
+
+```
+export AOKVQA_DIR=YOUR_DATASET_PATH
+mkdir -p ${AOKVQA_DIR}
+curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
+```
+
+```
+Location_you_like
+├── ${MINIGPTv2_DATASET}
+│ ├── aokvqa
+│ ├── aokvqa_v1p0_train.json
+```
+
+
+Set **image_path** to the COCO 2014 image folder.
+Similarly, set **ann_path** to the location of the AOKVQA dataset
+- [minigpt4/configs/datasets/aokvqa/defaults.yaml](../minigpt4/configs/datasets/aokvqa/defaults.yaml)
+
+
+
+### OCR-VQA
+Download the OCR-VQA annotation files, and download the images with the loadDataset.py script.
+
+```
+Location_you_like
+├── ${MINIGPTv2_DATASET}
+│ ├── ocrvqa
+│ ├── images
+│ ├── dataset.json
+```
+
+Set **image_path** as the ocrvqa/images folder.
+Similarly, set **ann_path** to the dataset.json
+- [minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml](../minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml)
+
+### GQA
+Download the GQA annotation files and images
+
+```
+Location_you_like
+├── ${MINIGPTv2_DATASET}
+│ ├── gqa
+│ ├── images
+│ ├── train_balanced_questions.json
+```
+
+Set **image_path** as the gqa/images folder.
+Similarly, set **ann_path** to the train_balanced_questions.json
+- [minigpt4/configs/datasets/gqa/balanced_val.yaml](../minigpt4/configs/datasets/gqa/balanced_val.yaml)
+
+
+
+### Filtered Flickr-30k
+Download the filtered Flickr-30k images (fill in this [form](https://forms.illinois.edu/sec/229675) on the official website, or download from [Kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset/download?datasetVersionNumber=1)) and the annotation files
+
+```
+${MINIGPTv2_DATASET}
+├── filtered_flickr
+│ ├── images
+│ ├── captiontobbox.json
+│ ├── groundedcaption.json
+│ └── phrasetobbox.json
+...
+```
+
+Set **image_path** to the Flickr-30k images folder.
+Similarly, set **ann_path** to groundedcaption.json, captiontobbox.json, and phrasetobbox.json for the
+grounded image caption, caption-to-bbox, and phrase-to-bbox datasets, respectively.
+
+- [minigpt4/configs/datasets/flickr/default.yaml](../minigpt4/configs/datasets/flickr/default.yaml)
+- [minigpt4/configs/datasets/flickr/caption_to_phrase.yaml](../minigpt4/configs/datasets/flickr/caption_to_phrase.yaml)
+- [minigpt4/configs/datasets/flickr/object_to_phrase.yaml](../minigpt4/configs/datasets/flickr/object_to_phrase.yaml)
+
+
+### Multi-task conversation
+Download the multi-task conversation dataset
+
+```
+Location_you_like
+${MINIGPTv2_DATASET}
+├── multitask_conversation
+│ └── multitask_conversation.json
+...
+```
+
+Set **image_path** as the COCO 2014 images folder.
+Similarly, set **ann_path** to the multitask_conversation.json file path
+
+- [minigpt4/configs/datasets/multitask_conversation/default.yaml](../minigpt4/configs/datasets/multitask_conversation/default.yaml)
+
+### Unnatural instruction
+Download the filtered unnatural instruction annotation files (we removed very long sentences from the original Unnatural Instructions dataset)
+
+```
+Location_you_like
+├── ${MINIGPTv2_DATASET}
+│ ├── unnatural_instructions
+│ ├── filtered_unnatural_instruction.json
+```
+
+This dataset is text-only, so there is no image path to set.
+Set **ann_path** to the filtered_unnatural_instruction.json file path
+
+- [minigpt4/configs/datasets/nlp/unnatural_instruction.yaml](../minigpt4/configs/datasets/nlp/unnatural_instruction.yaml)
+
+### LLaVA
+
+```
+Location_you_like
+├── ${MINIGPTv2_DATASET}
+│ ├── llava
+│ ├── conversation_58k.json
+│ ├── detail_23k.json
+│ ├── complex_reasoning_77k.json
+```
+
+Set **image_path** to the COCO 2014 image folder.
+Similarly, set **ann_path** to the locations of the previously downloaded conversation_58k.json,
+detail_23k.json, and complex_reasoning_77k.json in conversation.yaml, detail.yaml, and reason.yaml, respectively.
+
+
+- [minigpt4/configs/datasets/llava/conversation.yaml](../minigpt4/configs/datasets/llava/conversation.yaml)
+- [minigpt4/configs/datasets/llava/detail.yaml](../minigpt4/configs/datasets/llava/detail.yaml)
+- [minigpt4/configs/datasets/llava/reason.yaml](../minigpt4/configs/datasets/llava/reason.yaml)
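+
+Finally, every dataset config above ships with placeholder paths. A minimal sketch (run from the repository root) to list configs whose placeholders you have not replaced yet:
+
+```python
+import pathlib
+
+for cfg in sorted(pathlib.Path("minigpt4/configs/datasets").rglob("*.yaml")):
+    if "/path/to" in cfg.read_text():
+        print("still contains placeholder paths:", cfg)
+```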
diff --git a/environment.yml b/environment.yml
index cf90e89..8f94afe 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,4 +1,4 @@
-name: minigpt4
+name: minigptv
channels:
- pytorch
- defaults
diff --git a/minigpt4/configs/datasets/aokvqa/defaults.yaml b/minigpt4/configs/datasets/aokvqa/defaults.yaml
new file mode 100755
index 0000000..767fdd4
--- /dev/null
+++ b/minigpt4/configs/datasets/aokvqa/defaults.yaml
@@ -0,0 +1,20 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+datasets:
+ aok_vqa:
+ # data_dir: ${env.data_dir}/datasets
+ data_type: images # [images|videos|features]
+
+ build_info:
+ # Be careful not to append minus sign (-) before split to avoid itemizing
+ annotations:
+ train:
+ url:
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json
+ storage:
+ - /path/to/aokvqa_v1p0_train.json
+ images:
+ storage: /path/to/coco/images
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/coco/caption.yaml b/minigpt4/configs/datasets/coco/caption.yaml
new file mode 100644
index 0000000..ac072a4
--- /dev/null
+++ b/minigpt4/configs/datasets/coco/caption.yaml
@@ -0,0 +1,21 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+datasets:
+ coco_caption: # name of the dataset builder
+ # dataset_card: dataset_card/coco_caption.md
+ # data_dir: ${env.data_dir}/datasets
+ data_type: images # [images|videos|features]
+
+ build_info:
+ # Be careful not to append minus sign (-) before split to avoid itemizing
+ annotations:
+ train:
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json
+ md5: aa31ac474cf6250ebb81d18348a07ed8
+ storage: /path/to/coco_caption/coco_karpathy_train.json
+ images:
+ storage: /path/to/coco/images
+
diff --git a/minigpt4/configs/datasets/coco/defaults_vqa.yaml b/minigpt4/configs/datasets/coco/defaults_vqa.yaml
new file mode 100755
index 0000000..457e0a3
--- /dev/null
+++ b/minigpt4/configs/datasets/coco/defaults_vqa.yaml
@@ -0,0 +1,24 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+datasets:
+ coco_vqa:
+ # data_dir: ${env.data_dir}/datasets
+ data_type: images # [images|videos|features]
+
+ build_info:
+
+ annotations:
+ train:
+ url:
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json
+ storage:
+ - /path/to/vqav2/vqa_train.json
+ - /path/to/vqav2/vqa_val.json
+ images:
+ storage: /path/to/coco/images
+
+
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml b/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml
new file mode 100755
index 0000000..8325efc
--- /dev/null
+++ b/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml
@@ -0,0 +1,8 @@
+datasets:
+ invrefcoco:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/refcoco_annotations
+ dataset: invrefcoco
+ splitBy: unc
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml b/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml
new file mode 100755
index 0000000..e562240
--- /dev/null
+++ b/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml
@@ -0,0 +1,8 @@
+datasets:
+ invrefcocog:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/refcoco_annotations
+ dataset: invrefcocog
+ splitBy: umd
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml b/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml
new file mode 100755
index 0000000..1c57c8e
--- /dev/null
+++ b/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml
@@ -0,0 +1,8 @@
+datasets:
+ invrefcocop:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/refcoco_annotations
+ dataset: invrefcoco+
+ splitBy: unc
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/coco_bbox/refcoco.yaml b/minigpt4/configs/datasets/coco_bbox/refcoco.yaml
new file mode 100755
index 0000000..fc96f6d
--- /dev/null
+++ b/minigpt4/configs/datasets/coco_bbox/refcoco.yaml
@@ -0,0 +1,8 @@
+datasets:
+ refcoco:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/refcoco_annotations
+ dataset: refcoco
+ splitBy: unc
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/coco_bbox/refcocog.yaml b/minigpt4/configs/datasets/coco_bbox/refcocog.yaml
new file mode 100755
index 0000000..bb751cb
--- /dev/null
+++ b/minigpt4/configs/datasets/coco_bbox/refcocog.yaml
@@ -0,0 +1,8 @@
+datasets:
+ refcocog:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/refcoco_annotations
+ dataset: refcocog
+ splitBy: umd
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/coco_bbox/refcocop.yaml b/minigpt4/configs/datasets/coco_bbox/refcocop.yaml
new file mode 100755
index 0000000..36c574e
--- /dev/null
+++ b/minigpt4/configs/datasets/coco_bbox/refcocop.yaml
@@ -0,0 +1,8 @@
+datasets:
+ refcocop:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/refcoco_annotations
+ dataset: refcoco+
+ splitBy: unc
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/flickr/caption_to_phrase.yaml b/minigpt4/configs/datasets/flickr/caption_to_phrase.yaml
new file mode 100755
index 0000000..61243f9
--- /dev/null
+++ b/minigpt4/configs/datasets/flickr/caption_to_phrase.yaml
@@ -0,0 +1,6 @@
+datasets:
+ flickr_CaptionToPhrase:
+ data_type: images
+ build_info:
+      image_path: /path/to/filtered_flickr/images
+ ann_path: /path/to/filtered_flickr/captiontobbox.json
diff --git a/minigpt4/configs/datasets/flickr/default.yaml b/minigpt4/configs/datasets/flickr/default.yaml
new file mode 100755
index 0000000..25868c0
--- /dev/null
+++ b/minigpt4/configs/datasets/flickr/default.yaml
@@ -0,0 +1,6 @@
+datasets:
+ flickr_grounded_caption:
+ data_type: images
+ build_info:
+      image_path: /path/to/filtered_flickr/images
+      ann_path: /path/to/filtered_flickr/groundedcaption.json
diff --git a/minigpt4/configs/datasets/flickr/object_to_phrase.yaml b/minigpt4/configs/datasets/flickr/object_to_phrase.yaml
new file mode 100755
index 0000000..3a317dc
--- /dev/null
+++ b/minigpt4/configs/datasets/flickr/object_to_phrase.yaml
@@ -0,0 +1,6 @@
+datasets:
+ flickr_ObjectToPhrase:
+ data_type: images
+ build_info:
+      image_path: /path/to/filtered_flickr/images
+      ann_path: /path/to/filtered_flickr/phrasetobbox.json
diff --git a/minigpt4/configs/datasets/gqa/balanced_val.yaml b/minigpt4/configs/datasets/gqa/balanced_val.yaml
new file mode 100644
index 0000000..f4c8765
--- /dev/null
+++ b/minigpt4/configs/datasets/gqa/balanced_val.yaml
@@ -0,0 +1,21 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+datasets:
+ gqa:
+ # data_dir: ${env.data_dir}/datasets
+ data_type: images # [images|videos|features]
+
+ build_info:
+ # Be careful not to append minus sign (-) before split to avoid itemizing
+ annotations:
+ train:
+ url:
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json
+ storage:
+ - /path/to/gqa/train_balanced_questions.json
+
+ images:
+ storage: /path/to/gqa/images
diff --git a/minigpt4/configs/datasets/llava/conversation.yaml b/minigpt4/configs/datasets/llava/conversation.yaml
new file mode 100755
index 0000000..35c327b
--- /dev/null
+++ b/minigpt4/configs/datasets/llava/conversation.yaml
@@ -0,0 +1,7 @@
+datasets:
+
+ llava_conversation:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/llava/conversation_58k.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/llava/detail.yaml b/minigpt4/configs/datasets/llava/detail.yaml
new file mode 100755
index 0000000..896df39
--- /dev/null
+++ b/minigpt4/configs/datasets/llava/detail.yaml
@@ -0,0 +1,6 @@
+datasets:
+ llava_detail:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/llava/detail_23k.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/llava/reason.yaml b/minigpt4/configs/datasets/llava/reason.yaml
new file mode 100755
index 0000000..f5fb674
--- /dev/null
+++ b/minigpt4/configs/datasets/llava/reason.yaml
@@ -0,0 +1,7 @@
+datasets:
+
+ llava_reason:
+ data_type: images
+ build_info:
+ image_path: /path/to/coco/images
+ ann_path: /path/to/llava/complex_reasoning_77k.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/multitask_conversation/default.yaml b/minigpt4/configs/datasets/multitask_conversation/default.yaml
new file mode 100644
index 0000000..9d5ee72
--- /dev/null
+++ b/minigpt4/configs/datasets/multitask_conversation/default.yaml
@@ -0,0 +1,7 @@
+datasets:
+ multitask_conversation:
+ data_type: images
+ build_info:
+
+ image_path: /path/to/coco/images
+      ann_path: /path/to/multitask_conversation/multitask_conversation.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml b/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml
new file mode 100644
index 0000000..67464e5
--- /dev/null
+++ b/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml
@@ -0,0 +1,5 @@
+datasets:
+ unnatural_instruction:
+ data_type: text
+ build_info:
+ ann_path: /path/to/unnatural_instructions/filtered_unnatural_instruction.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml b/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml
new file mode 100755
index 0000000..0b651dd
--- /dev/null
+++ b/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml
@@ -0,0 +1,6 @@
+datasets:
+ ocrvqa:
+ data_type: images
+ build_info:
+ image_path: /path/to/ocrvqa/images
+ ann_path: /path/to/ocrvqa/dataset.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/okvqa/defaults.yaml b/minigpt4/configs/datasets/okvqa/defaults.yaml
new file mode 100755
index 0000000..e536366
--- /dev/null
+++ b/minigpt4/configs/datasets/okvqa/defaults.yaml
@@ -0,0 +1,21 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+datasets:
+ ok_vqa:
+ # data_dir: ${env.data_dir}/datasets
+ data_type: images # [images|videos|features]
+
+ build_info:
+ # Be careful not to append minus sign (-) before split to avoid itemizing
+ annotations:
+ train:
+ url:
+ # TODO make this order insensitive
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json
+ storage:
+ - /path/to/okvqa/okvqa_train.json
+ images:
+ storage: /path/to/coco/images
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/textcaps/caption.yaml b/minigpt4/configs/datasets/textcaps/caption.yaml
new file mode 100755
index 0000000..9a732b4
--- /dev/null
+++ b/minigpt4/configs/datasets/textcaps/caption.yaml
@@ -0,0 +1,9 @@
+datasets:
+ textcaps_caption:
+ data_type: images
+
+ build_info:
+ image_path: /path/to/textcaps/train_images
+ ann_path: /path/to/textcaps/TextCaps_0.1_train.json
+
+
diff --git a/minigpt4/configs/datasets/vg/ref.yaml b/minigpt4/configs/datasets/vg/ref.yaml
new file mode 100755
index 0000000..008ae72
--- /dev/null
+++ b/minigpt4/configs/datasets/vg/ref.yaml
@@ -0,0 +1,5 @@
+datasets:
+ refvg:
+ data_type: images
+ build_info:
+ data_dir: /path/to/visual_genome
\ No newline at end of file
diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py
index e5d66b8..fb344f1 100644
--- a/minigpt4/datasets/builders/image_text_pair_builder.py
+++ b/minigpt4/datasets/builders/image_text_pair_builder.py
@@ -6,6 +6,425 @@ from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.laion_dataset import LaionDataset
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
+from minigpt4.datasets.datasets.text_caps import TextCapDataset
+from minigpt4.datasets.datasets.llava_dataset import LlavaDetailDataset, LlavaReasonDataset, LlavaConversationDataset
+from minigpt4.datasets.datasets.unnatural_instruction import UnnaturalDataset
+from minigpt4.datasets.datasets.multitask_conversation import MultiTaskConversationDataset
+from minigpt4.datasets.datasets.flickr import GroundedDetailDataset,CaptionToObjectDataset,PhraseToObjectDataset
+from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset
+from minigpt4.datasets.datasets.coco_dataset import ReferCOCODataset, InvReferCOCODataset
+from minigpt4.datasets.datasets.gqa_datasets import GQADataset
+from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset
+from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
+from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
+from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
+
+
+@registry.register_builder("multitask_conversation")
+class MultitaskConversationBuilder(BaseDatasetBuilder):
+ train_dataset_cls = MultiTaskConversationDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/multitask_conversation/default.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+
+@registry.register_builder("unnatural_instruction")
+class UnnaturalInstructionBuilder(BaseDatasetBuilder):
+ train_dataset_cls = UnnaturalDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/nlp/unnatural_instruction.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ )
+
+ return datasets
+
+
+
+@registry.register_builder("llava_detail")
+class LlavaDetailBuilder(BaseDatasetBuilder):
+ train_dataset_cls = LlavaDetailDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/llava/detail.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+
+
+@registry.register_builder("llava_reason")
+class LlavaReasonBuilder(BaseDatasetBuilder):
+ train_dataset_cls = LlavaReasonDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/llava/reason.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+@registry.register_builder("llava_conversation")
+class LlavaConversationBuilder(BaseDatasetBuilder):
+ train_dataset_cls = LlavaConversationDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/llava/conversation.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+
+class AllRefCOCOBuilder(BaseDatasetBuilder):
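+    """Shared builder for the RefCOCO-family referring-expression datasets; subclasses set train_dataset_cls and DATASET_CONFIG_DICT."""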
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+
+ build_info = self.config.build_info
+ image_path = build_info.image_path
+ ann_path = build_info.ann_path
+
+ datasets = dict()
+
+ if not os.path.exists(image_path):
+ warnings.warn("image path {} does not exist.".format(image_path))
+ if not os.path.exists(ann_path):
+ warnings.warn("ann path {} does not exist.".format(ann_path))
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=ann_path,
+ vis_root=image_path,
+ dataset=build_info.dataset,
+ splitBy=build_info.splitBy
+ )
+
+ return datasets
+
+
+@registry.register_builder("refcoco")
+class RefCOCOBuilder(AllRefCOCOBuilder):
+ train_dataset_cls = ReferCOCODataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco_bbox/refcoco.yaml",
+ }
+
+@registry.register_builder("refcocop")
+class RefCOCOPBuilder(AllRefCOCOBuilder):
+ train_dataset_cls = ReferCOCODataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco_bbox/refcocop.yaml",
+ }
+
+
+@registry.register_builder("refcocog")
+class RefCOCOGBuilder(AllRefCOCOBuilder):
+ train_dataset_cls = ReferCOCODataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco_bbox/refcocog.yaml",
+ }
+
+@registry.register_builder("invrefcoco")
+class InvRefCOCOBuilder(AllRefCOCOBuilder):
+ train_dataset_cls = InvReferCOCODataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco_bbox/invrefcoco.yaml",
+ }
+
+
+@registry.register_builder("invrefcocop")
+class InvRefCOCOPBuilder(AllRefCOCOBuilder):
+ train_dataset_cls = InvReferCOCODataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco_bbox/invrefcocop.yaml",
+ }
+
+
+@registry.register_builder("invrefcocog")
+class InvRefCOCOGBuilder(AllRefCOCOBuilder):
+ train_dataset_cls = InvReferCOCODataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco_bbox/invrefcocog.yaml",
+ }
+
+@registry.register_builder("refvg")
+class RefVisualGenomeBuilder(BaseDatasetBuilder):
+ train_dataset_cls = ReferVisualGenomeDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/vg/ref.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+
+ build_info = self.config.build_info
+ data_dir = build_info.data_dir
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ data_dir=data_dir,
+ )
+
+ return datasets
+
+
+@registry.register_builder("textcaps_caption")
+class TextcapCaptionBuilder(BaseDatasetBuilder):
+ train_dataset_cls = TextCapDataset
+
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/textcaps/caption.yaml"}
+
+ def _download_ann(self):
+ pass
+
+ def _download_vis(self):
+ pass
+
+ def build(self):
+ self.build_processors()
+
+ build_info = self.config.build_info
+
+ datasets = dict()
+ split = "train"
+
+ # create datasets
+ # [NOTE] return inner_datasets (wds.DataPipeline)
+ dataset_cls = self.train_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=self.vis_processors[split],
+ text_processor=self.text_processors[split],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+@registry.register_builder("coco_vqa")
+class COCOVQABuilder(BaseDatasetBuilder):
+ train_dataset_cls = COCOVQADataset
+
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco/defaults_vqa.yaml",
+ }
+
+@registry.register_builder("ok_vqa")
+class OKVQABuilder(COCOVQABuilder):
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/okvqa/defaults.yaml",
+ }
+
+
+@registry.register_builder("aok_vqa")
+class AOKVQABuilder(BaseDatasetBuilder):
+ train_dataset_cls = AOKVQADataset
+
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"}
+
+
+@registry.register_builder("gqa")
+class GQABuilder(BaseDatasetBuilder):
+ train_dataset_cls = GQADataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/gqa/balanced_val.yaml",
+ }
+
+
+
+
+@registry.register_builder("flickr_grounded_caption")
+class GroundedCaptionBuilder(BaseDatasetBuilder):
+ train_dataset_cls = GroundedDetailDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/flickr/default.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+
+@registry.register_builder("flickr_CaptionToPhrase")
+class CaptionToPhraseBuilder(BaseDatasetBuilder):
+ train_dataset_cls = CaptionToObjectDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/flickr/caption_to_phrase.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+@registry.register_builder("flickr_ObjectToPhrase")
+class ObjectToPhraseBuilder(BaseDatasetBuilder):
+ train_dataset_cls = PhraseToObjectDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/flickr/object_to_phrase.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+
+
+
+class DocumentVQABuilder(BaseDatasetBuilder):
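+    """Generic builder for document VQA datasets (e.g. OCR-VQA) configured with image_path and ann_path."""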
+ def _download_ann(self):
+ pass
+
+ def _download_vis(self):
+ pass
+
+ def build(self):
+ self.build_processors()
+ build_info = self.config.build_info
+
+ datasets = dict()
+ split = "train"
+
+ dataset_cls = self.train_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=self.vis_processors[split],
+ text_processor=self.text_processors[split],
+ vis_root=build_info.image_path,
+ ann_path=build_info.ann_path
+ )
+
+ return datasets
+
+
+@registry.register_builder("ocrvqa")
+class OCRVQABuilder(DocumentVQABuilder):
+ train_dataset_cls = OCRVQADataset
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/ocrvqa/ocrvqa.yaml"}
@registry.register_builder("cc_sbu")
@@ -72,6 +491,17 @@ class LaionBuilder(BaseDatasetBuilder):
return datasets
+
+@registry.register_builder("coco_caption")
+class COCOCapBuilder(BaseDatasetBuilder):
+ train_dataset_cls = COCOCapDataset
+
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/coco/caption.yaml",
+ }
+
+
+
@registry.register_builder("cc_sbu_align")
class CCSBUAlignBuilder(BaseDatasetBuilder):
train_dataset_cls = CCSBUAlignDataset
diff --git a/minigpt4/datasets/datasets/aok_vqa_datasets.py b/minigpt4/datasets/datasets/aok_vqa_datasets.py
new file mode 100755
index 0000000..00ed06d
--- /dev/null
+++ b/minigpt4/datasets/datasets/aok_vqa_datasets.py
@@ -0,0 +1,116 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from collections import OrderedDict
+import json
+import os
+import random
+import torch
+
+from PIL import Image
+
+from minigpt4.datasets.datasets.vqa_datasets import VQADataset #, VQAEvalDataset
+
+
+class __DisplMixin:
+ def displ_item(self, index):
+ sample, ann = self.__getitem__(index), self.annotation[index]
+ return OrderedDict(
+ {
+ "file": ann["image"],
+ "question": ann["question"],
+ "question_id": ann["question_id"],
+ "direct_answers": "; ".join(ann["direct_answers"]),
+ "choices": "; ".join(ann["choices"]),
+ "correct_choice": ann["choices"][ann["correct_choice_idx"]],
+ "image": sample["image"],
+ }
+ )
+
+
+class AOKVQADataset(VQADataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ self.instruction_pool =[
+ "[vqa] {}",
+ "[vqa] Based on the image, respond to this question with a short answer: {}"
+ ]
+
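+        # keep only annotations whose image file is actually present under vis_root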
+ exist_annotation = []
+ for ann in self.annotation:
+ image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
+ if os.path.exists(image_path):
+ exist_annotation.append(ann)
+ self.annotation = exist_annotation
+
+ def get_data(self, index):
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ question = self.text_processor(ann["question"])
+
+ answer_key = "direct_answers"
+
+ answer_weight = {}
+ for answer in ann[answer_key]:
+ if answer in answer_weight.keys():
+ answer_weight[answer] += 1 / len(ann[answer_key])
+ else:
+ answer_weight[answer] = 1 / len(ann[answer_key])
+
+ answers = list(answer_weight.keys())
+ weights = list(answer_weight.values())
+
+ answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights
+
+ return {
+ "image": image,
+ "question": question,
+ "answer": answer,
+ }
+
+ def __getitem__(self, index):
+ data = self.get_data(index)
+ question = self.text_processor(data["question"])
+ instruction = random.choice(self.instruction_pool).format(question)
+
+ instruction = "
{} ".format(instruction)
+ answer = self.text_processor(data['answer'])
+
+ return {
+ "image": data['image'],
+ "instruction_input": instruction,
+ "answer": answer,
+ }
+
+
+class AOKVQGDataset(AOKVQADataset):
+
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+ self.instruction_pool = [
+ 'Given the image, generate a question whose answer is: {}',
+ 'Based on the image, provide a question with the answer: {}',
+ 'Given the visual representation, create a question for which the answer is "{}"',
+ 'From the image provided, craft a question that leads to the reply: {}',
+ 'Considering the picture, come up with a question where the answer is: {}',
+            'Taking the image into account, generate a question that has the answer: {}'
+ ]
+
+ def __getitem__(self, index):
+ data = self.get_data(index)
+ instruction = random.choice(self.instruction_pool).format(data['answer'])
+
+ return {
+ "image": data['image'],
+ "instruction_input": instruction,
+ "answer": data['question'],
+ }
diff --git a/minigpt4/datasets/datasets/base_dataset.py b/minigpt4/datasets/datasets/base_dataset.py
index ae2a8d0..97aed82 100644
--- a/minigpt4/datasets/datasets/base_dataset.py
+++ b/minigpt4/datasets/datasets/base_dataset.py
@@ -12,6 +12,8 @@ from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate
+
+
class BaseDataset(Dataset):
def __init__(
self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
@@ -23,9 +25,16 @@ class BaseDataset(Dataset):
self.vis_root = vis_root
self.annotation = []
+ # print("ann paths", ann_paths)
for ann_path in ann_paths:
- self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
-
+ # print("ann_path", ann_path)
+ ann = json.load(open(ann_path, "r"))
+ if isinstance(ann, dict):
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
+ # self.annotation.extend(json.load(open(ann_path, "r")))
+ else:
+ self.annotation.extend(json.load(open(ann_path, "r")))
+
self.vis_processor = vis_processor
self.text_processor = text_processor
@@ -46,6 +55,7 @@ class BaseDataset(Dataset):
ann[key] = str(idx)
+
class ConcatDataset(ConcatDataset):
def __init__(self, datasets: Iterable[Dataset]) -> None:
super().__init__(datasets)
diff --git a/minigpt4/datasets/datasets/caption_datasets.py b/minigpt4/datasets/datasets/caption_datasets.py
index 78bab66..6a432a7 100644
--- a/minigpt4/datasets/datasets/caption_datasets.py
+++ b/minigpt4/datasets/datasets/caption_datasets.py
@@ -10,6 +10,7 @@ from collections import OrderedDict
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from PIL import Image
+import random
class __DisplMixin:
@@ -60,6 +61,71 @@ class CaptionDataset(BaseDataset, __DisplMixin):
}
+
+class COCOCaptionDataset(BaseDataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ self.img_ids = {}
+ n = 0
+
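+        # keep only captions whose image path contains "train", i.e. the COCO train split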
+        self.filter_annotation = []
+
+        for ann in self.annotation:
+            if "train" in ann["image"]:
+                self.filter_annotation.append(ann)
+        self.annotation = self.filter_annotation
+
+ for ann in self.annotation:
+ img_id = ann["image_id"]
+ if img_id not in self.img_ids.keys():
+ self.img_ids[img_id] = n
+ n += 1
+
+ self.instruction_pool = [
+ 'Briefly describe this image.',
+ 'Provide a concise depiction of this image.',
+ 'Present a short description of this image.',
+ 'Summarize this image in a few words.',
+ 'A short image caption:',
+ 'A short image description:',
+ 'A photo of ',
+ 'An image that shows ',
+ 'Write a short description for the image. ',
+ 'Write a description for the photo.',
+ 'Provide a description of what is presented in the photo.',
+ 'Briefly describe the content of the image.',
+ 'Can you briefly explain what you see in the image?',
+ 'Could you use a few words to describe what you perceive in the photo?',
+ 'Please provide a short depiction of the picture.',
+ 'Using language, provide a short account of the image.',
+ 'Use a few words to illustrate what is happening in the picture.',
+ ]
+ def __getitem__(self, index):
+
+ # TODO this assumes image input, not general enough
+ ann = self.annotation[index]
+
+ img_file = ann["image"].split("/")[-1]
+ image_path = os.path.join(self.vis_root, img_file)
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ caption = self.text_processor(ann["caption"])
+
+ instruction = random.choice(self.instruction_pool)
+ instruction = "
[caption] {} ".format(instruction)
+
+ return {
+ "image": image,
+ "answer": caption,
+ "instruction_input": instruction,
+ }
+
class CaptionEvalDataset(BaseDataset, __DisplMixin):
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
"""
diff --git a/minigpt4/datasets/datasets/coco_caption.py b/minigpt4/datasets/datasets/coco_caption.py
new file mode 100755
index 0000000..76f86e4
--- /dev/null
+++ b/minigpt4/datasets/datasets/coco_caption.py
@@ -0,0 +1,120 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import json
+import torch
+import numpy as np
+
+from PIL import Image
+from PIL import ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset
+
+COCOCapDataset = COCOCaptionDataset
+
+
+
+
+
+class COCOCapEvalDataset(CaptionEvalDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ def __getitem__(self, index):
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+
+ img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
+
+ return {
+ "image": image,
+ "image_id": img_id,
+ "instance_id": ann["instance_id"],
+ }
+
+
+class NoCapsEvalDataset(CaptionEvalDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ def __getitem__(self, index):
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+
+ img_id = ann["img_id"]
+
+ return {
+ "image": image,
+ "image_id": img_id,
+ "instance_id": ann["instance_id"],
+ }
+
+
+class RefCOCOEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = data['img_id']
+ sent = data['sents']
+ image_path = os.path.join(self.root_path, f'{img_id[:27]}.jpg')
+ image = Image.open(image_path).convert('RGB')
+ image = self.vis_processor(image)
+ question = f"[refer] tell me the location of {sent}?"
+ return image, question, img_id
+
+class EvalCaptionData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+ ann = dict()
+ for item in self.loaded_data:
+ image_id = item['image_id']
+ ann[image_id] = item['image']
+ self.ann = [{'image_id':image_id, 'image': ann[image_id]} for image_id in ann]
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, idx):
+ data = self.ann[idx]
+ image_id = data['image_id']
+ img_file = data['image'].split('/')[-1]
+ image_path = os.path.join(self.root_path, img_file)
+ image = Image.open(image_path).convert('RGB')
+
+ image = self.vis_processor(image)
+ question = f"[caption] please describe this image?"
+ return image, question, image_id
diff --git a/minigpt4/datasets/datasets/coco_dataset.py b/minigpt4/datasets/datasets/coco_dataset.py
new file mode 100755
index 0000000..16f03f0
--- /dev/null
+++ b/minigpt4/datasets/datasets/coco_dataset.py
@@ -0,0 +1,348 @@
+import os
+import json
+import pickle
+import random
+import time
+import itertools
+
+import numpy as np
+from PIL import Image
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from torch.utils.data import Dataset
+import webdataset as wds
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+
+class ReferCOCODataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path, dataset='refcoco', splitBy='unc'):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.refer = REFER(ann_path, vis_root, dataset, splitBy)
+ self.ref_ids = self.refer.getRefIds(split="train")
+
+ self.instruction_pool = [
+ "[refer] {}",
+ "[refer] give me the location of {}",
+ "[refer] where is {} ?",
+ "[refer] from this image, tell me the location of {}",
+ "[refer] the location of {} is",
+ "[refer] could you tell me the location for {} ?",
+ "[refer] where can I locate the {} ?",
+ ]
+
+
+ def __len__(self):
+ return len(self.ref_ids)
+
+ def preprocess(self, index):
+ ref_id = self.ref_ids[index]
+ ref = self.refer.loadRefs(ref_id)[0]
+
+ image_file = 'COCO_train2014_{:0>12}.jpg'.format(ref["image_id"])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image_orig_size = image.size
+ image = self.vis_processor(image)
+ image_new_size = [image.shape[1], image.shape[2]]
+
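+        # bounding boxes are expressed on a fixed 100x100 coordinate grid rather than in pixels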
+ image_new_size = [100,100]
+
+ sample_sentence = random.choice(ref['sentences'])['raw']
+ refer_sentence = self.text_processor(sample_sentence)
+
+
+ bbox = self.refer.getRefBox(ref['ref_id'])
+ bbox = [
+ bbox[0] / image_orig_size[0] * image_new_size[0],
+ bbox[1] / image_orig_size[1] * image_new_size[1],
+ (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0],
+ (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1]
+ ]
+ bbox = [int(x) for x in bbox]
+ bbox = "{{<{}><{}><{}><{}>}}".format(*bbox)
+ return {
+ "image": image,
+ "refer_sentence": refer_sentence,
+ "bbox": bbox,
+ "image_id": ref['image_id'],
+ }
+
+ def __getitem__(self, index):
+ data = self.preprocess(index)
+ instruction = random.choice(self.instruction_pool).format(data['refer_sentence'])
+
+ instruction = "
{} ".format(instruction)
+
+ return {
+ "image": data['image'],
+ "instruction_input": instruction,
+ "answer": data['bbox'],
+ "image_id": data['image_id'],
+ }
+
+
+class InvReferCOCODataset(ReferCOCODataset):
+ def __init__(self, *args, **kwargs):
+ super(InvReferCOCODataset, self).__init__(*args, **kwargs)
+
+ self.instruction_pool = [
+ "[identify] {}",
+ "[identify] what object is in this location {}",
+ "[identify] identify the object present at this location {}",
+ "[identify] what is it in {}",
+ "[identify] describe this object in {}",
+ "[identify] this {} is",
+ "[identify] the object in {} is",
+ ]
+
+ def __getitem__(self, index):
+ data = self.preprocess(index)
+
+ instruction = random.choice(self.instruction_pool).format(data['bbox'])
+
+ instruction = "
{} ".format(instruction)
+
+ return {
+ "image": data['image'],
+ "instruction_input": instruction,
+ "answer": self.text_processor(data['refer_sentence']),
+ "image_id": data['image_id'],
+ }
+
+
+class REFER:
+ def __init__(self, data_root, vis_root, dataset='refcoco', splitBy='unc'):
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
+ # also provide dataset name and splitBy information
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
+ dataset = dataset.split('inv')[-1] # inv dataset is stored in the same path as normal dataset
+ print('loading dataset %s into memory...' % dataset)
+ self.ann_dir = os.path.join(data_root, dataset)
+ if dataset in ['refcoco', 'refcoco+', 'refcocog']:
+ self.vis_root = vis_root
+ elif dataset == 'refclef':
+            raise ValueError('No RefClef image data')
+        else:
+            raise ValueError('No refer dataset is called [%s]' % dataset)
+
+ # load refs from data/dataset/refs(dataset).json
+ tic = time.time()
+ ref_file = os.path.join(self.ann_dir, 'refs(' + splitBy + ').p')
+ self.data = {}
+ self.data['dataset'] = dataset
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'))
+
+ # load annotations from data/dataset/instances.json
+ instances_file = os.path.join(self.ann_dir, 'instances.json')
+ instances = json.load(open(instances_file, 'r'))
+ self.data['images'] = instances['images']
+ self.data['annotations'] = instances['annotations']
+ self.data['categories'] = instances['categories']
+
+ # create index
+ self.createIndex()
+ print('DONE (t=%.2fs)' % (time.time() - tic))
+
+ def createIndex(self):
+ # create sets of mapping
+ # 1) Refs: {ref_id: ref}
+ # 2) Anns: {ann_id: ann}
+ # 3) Imgs: {image_id: image}
+ # 4) Cats: {category_id: category_name}
+ # 5) Sents: {sent_id: sent}
+ # 6) imgToRefs: {image_id: refs}
+ # 7) imgToAnns: {image_id: anns}
+ # 8) refToAnn: {ref_id: ann}
+ # 9) annToRef: {ann_id: ref}
+ # 10) catToRefs: {category_id: refs}
+ # 11) sentToRef: {sent_id: ref}
+ # 12) sentToTokens: {sent_id: tokens}
+ print('creating index...')
+ # fetch info from instances
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
+ for ann in self.data['annotations']:
+ Anns[ann['id']] = ann
+ imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
+ for img in self.data['images']:
+ Imgs[img['id']] = img
+ for cat in self.data['categories']:
+ Cats[cat['id']] = cat['name']
+
+ # fetch info from refs
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
+ Sents, sentToRef, sentToTokens = {}, {}, {}
+ for ref in self.data['refs']:
+ # ids
+ ref_id = ref['ref_id']
+ ann_id = ref['ann_id']
+ category_id = ref['category_id']
+ image_id = ref['image_id']
+
+ # add mapping related to ref
+ Refs[ref_id] = ref
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
+ refToAnn[ref_id] = Anns[ann_id]
+ annToRef[ann_id] = ref
+
+ # add mapping of sent
+ for sent in ref['sentences']:
+ Sents[sent['sent_id']] = sent
+ sentToRef[sent['sent_id']] = ref
+ sentToTokens[sent['sent_id']] = sent['tokens']
+
+ # create class members
+ self.Refs = Refs
+ self.Anns = Anns
+ self.Imgs = Imgs
+ self.Cats = Cats
+ self.Sents = Sents
+ self.imgToRefs = imgToRefs
+ self.imgToAnns = imgToAnns
+ self.refToAnn = refToAnn
+ self.annToRef = annToRef
+ self.catToRefs = catToRefs
+ self.sentToRef = sentToRef
+ self.sentToTokens = sentToTokens
+ print('index created.')
+
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
+ refs = self.data['refs']
+ else:
+ if not len(image_ids) == 0:
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
+ else:
+ refs = self.data['refs']
+ if not len(cat_ids) == 0:
+ refs = [ref for ref in refs if ref['category_id'] in cat_ids]
+ if not len(ref_ids) == 0:
+ refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
+ if not len(split) == 0:
+ if split in ['testA', 'testB', 'testC']:
+ refs = [ref for ref in refs if
+ split[-1] in ref['split']] # we also consider testAB, testBC, ...
+ elif split in ['testAB', 'testBC', 'testAC']:
+ refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess...
+ elif split == 'test':
+ refs = [ref for ref in refs if 'test' in ref['split']]
+ elif split == 'train' or split == 'val':
+ refs = [ref for ref in refs if ref['split'] == split]
+ else:
+ raise ValueError('No such split [%s]' % split)
+ ref_ids = [ref['ref_id'] for ref in refs]
+ return ref_ids
+
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
+ else:
+ if not len(image_ids) == 0:
+ lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
+ anns = list(itertools.chain.from_iterable(lists))
+ else:
+ anns = self.data['annotations']
+ if not len(cat_ids) == 0:
+ anns = [ann for ann in anns if ann['category_id'] in cat_ids]
+ ann_ids = [ann['id'] for ann in anns]
+ if not len(ref_ids) == 0:
+ ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
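+ # note: the intersection above is computed but not applied, so the ref_ids filter has no effect on the returned ann_ids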
+ return ann_ids
+
+ def getImgIds(self, ref_ids=[]):
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if not len(ref_ids) == 0:
+ image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
+ else:
+ image_ids = self.Imgs.keys()
+ return image_ids
+
+ def getCatIds(self):
+ return self.Cats.keys()
+
+ def loadRefs(self, ref_ids=[]):
+ if type(ref_ids) == list:
+ return [self.Refs[ref_id] for ref_id in ref_ids]
+ elif type(ref_ids) == int:
+ return [self.Refs[ref_ids]]
+
+ def loadAnns(self, ann_ids=[]):
+ if type(ann_ids) == list:
+ return [self.Anns[ann_id] for ann_id in ann_ids]
+ elif type(ann_ids) == int:
+ return [self.Anns[ann_ids]]
+
+ def loadImgs(self, image_ids=[]):
+ if type(image_ids) == list:
+ return [self.Imgs[image_id] for image_id in image_ids]
+ elif type(image_ids) == int:
+ return [self.Imgs[image_ids]]
+
+ def loadCats(self, cat_ids=[]):
+ if type(cat_ids) == list:
+ return [self.Cats[cat_id] for cat_id in cat_ids]
+ elif type(cat_ids) == int:
+ return [self.Cats[cat_ids]]
+
+ def getRefBox(self, ref_id):
+ ref = self.Refs[ref_id]
+ ann = self.refToAnn[ref_id]
+ return ann['bbox'] # [x, y, w, h]
+
+ def showRef(self, ref, seg_box='box'):
+ ax = plt.gca()
+ # show image
+ image = self.Imgs[ref['image_id']]
+ I = io.imread(os.path.join(self.vis_root, image['file_name']))
+ ax.imshow(I)
+ # show refer expression
+ for sid, sent in enumerate(ref['sentences']):
+ print('%s. %s' % (sid + 1, sent['sent']))
+ # show segmentations
+ if seg_box == 'seg':
+ ann_id = ref['ann_id']
+ ann = self.Anns[ann_id]
+ polygons = []
+ color = []
+ c = 'none'
+ if type(ann['segmentation'][0]) == list:
+ # polygon used for refcoco*
+ for seg in ann['segmentation']:
+ poly = np.array(seg).reshape((len(seg) // 2, 2))
+ polygons.append(Polygon(poly, True, alpha=0.4))
+ color.append(c)
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 1, 0, 0), linewidths=3, alpha=1)
+ ax.add_collection(p) # thick yellow polygon
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 0, 0, 0), linewidths=1, alpha=1)
+ ax.add_collection(p) # thin red polygon
+ else:
+ # mask used for refclef
+ raise NotImplementedError('RefClef is not downloaded')
+ # show bounding-box
+ elif seg_box == 'box':
+ ann_id = ref['ann_id']
+ ann = self.Anns[ann_id]
+ bbox = self.getRefBox(ref['ref_id'])
+ box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
+ ax.add_patch(box_plot)
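+
+# Example usage (a rough sketch; the paths below are placeholders):
+#   refer = REFER(data_root='path/to/refcoco_annotations', vis_root='path/to/coco/train2014', dataset='refcoco', splitBy='unc')
+#   train_ref_ids = refer.getRefIds(split='train')
+#   ref = refer.loadRefs(train_ref_ids[0])[0]
+#   box = refer.getRefBox(ref['ref_id'])  # [x, y, w, h]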
diff --git a/minigpt4/datasets/datasets/coco_vqa_datasets.py b/minigpt4/datasets/datasets/coco_vqa_datasets.py
new file mode 100755
index 0000000..2dbe056
--- /dev/null
+++ b/minigpt4/datasets/datasets/coco_vqa_datasets.py
@@ -0,0 +1,145 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import json
+import random
+
+from PIL import Image
+
+from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset
+
+from collections import OrderedDict
+
+
+class __DisplMixin:
+ def displ_item(self, index):
+ sample, ann = self.__getitem__(index), self.annotation[index]
+
+ return OrderedDict(
+ {
+ "file": ann["image"],
+ "question": ann["question"],
+ "question_id": ann["question_id"],
+ "answers": "; ".join(ann["answer"]),
+ "image": sample["image"],
+ }
+ )
+
+
+class COCOVQADataset(VQADataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ self.instruction_pool =[
+ "[vqa] {}",
+ "[vqa] Based on the image, respond to this question with a short answer: {}"
+ ]
+
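+ # keep only the QA annotations whose image file is actually present under vis_root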
+ exist_annotation = []
+ for ann in self.annotation:
+ image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
+ if os.path.exists(image_path):
+ exist_annotation.append(ann)
+ self.annotation = exist_annotation
+
+
+ def get_data(self, index):
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ question = self.text_processor(ann["question"])
+ question_id = ann["question_id"]
+
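+ # weight each distinct ground-truth answer by its frequency among the annotations,
+ # then sample one answer for this training example according to those weights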
+ answer_weight = {}
+ for answer in ann["answer"]:
+ if answer in answer_weight.keys():
+ answer_weight[answer] += 1 / len(ann["answer"])
+ else:
+ answer_weight[answer] = 1 / len(ann["answer"])
+
+ answers = list(answer_weight.keys())
+ weights = list(answer_weight.values())
+
+ answer = random.choices(answers, weights=weights, k=1)[0]  # randomly sample an answer according to its weight
+
+
+ return {
+ "image": image,
+ "question": question,
+ "question_id": question_id,
+ "answer": answer,
+ }
+
+ def __getitem__(self, index):
+ data = self.get_data(index)
+ instruction = random.choice(self.instruction_pool).format(data['question'])
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+
+ return {
+ "image": data['image'],
+ "question_id": data["question_id"],
+ "instruction_input": instruction,
+ "answer": self.text_processor(data['answer']),
+ }
+
+
+class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+
+ self.instruction_pool = [
+ 'Question: {} Short answer:',
+ ]
+ self.vis_root = vis_root
+
+ self.annotation = json.load(open(ann_paths[0]))
+
+ answer_list_path = ann_paths[1]
+ if os.path.exists(answer_list_path):
+ self.answer_list = json.load(open(answer_list_path))
+ else:
+ self.answer_list = None
+
+ try:
+ self.coco_fmt_qust_file = ann_paths[2]
+ self.coco_fmt_anno_file = ann_paths[3]
+ except IndexError:
+ self.coco_fmt_qust_file = None
+ self.coco_fmt_anno_file = None
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self._add_instance_ids()
+
+ def __getitem__(self, index):
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ question = self.text_processor(ann["question"])
+
+ instruction = random.choice(self.instruction_pool).format(question)
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+
+ return {
+ "image": image,
+ 'image_path': image_path,
+ "question": question,
+ "question_id": ann["question_id"],
+ "instruction_input": instruction,
+ "instance_id": ann["instance_id"],
+ }
diff --git a/minigpt4/datasets/datasets/flickr.py b/minigpt4/datasets/datasets/flickr.py
new file mode 100755
index 0000000..b6283d3
--- /dev/null
+++ b/minigpt4/datasets/datasets/flickr.py
@@ -0,0 +1,159 @@
+import os
+import json
+import pickle
+import random
+import time
+import itertools
+
+import numpy as np
+from PIL import Image
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from torch.utils.data import Dataset
+import webdataset as wds
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+
+class GroundedDetailDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.instruction_pool = [
+ '[grounding] please describe this image in details',
+ '[grounding] describe this image as detailed as possible',
+ '[grounding] summarize this image in details',
+ '[grounding] give a thorough description of what you see in this image',
+ ]
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+
+ # image_file = 'COCO_train2014_{}.jpg'.format(info['image_id'])
+ image_file = '{}.jpg'.format(info['image_id'])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ answer = info['grounded_caption']
+ instruction = random.choice(self.instruction_pool)
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": info['image_id'],
+ }
+
+
+
+
+class CaptionToObjectDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.instruction_pool = [
+ '[detection] {}',
+ ]
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+
+ image_file = '{}.jpg'.format(info['image_id'])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ input = info["caption"]
+ answer = info["output"]
+
+ instruction = random.choice(self.instruction_pool).format(input)
+
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+
+ print("CaptionToObject instruction", instruction)
+ print("CaptionToObject answer", answer)
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": info['image_id'],
+ }
+
+
+
+
+class PhraseToObjectDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.instruction_pool = [
+ '[detection] {}',
+ ]
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+ image_file = '{}.jpg'.format(info['image_id'])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ input = info["phrase"]
+ answer = ""+input+"
"+info["bbox"]
+ instruction = random.choice(self.instruction_pool).format(input)
+
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+
+ print("PhraseToObject instruction", instruction)
+ print("PhraseToObject answer", answer)
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": info['image_id'],
+ }
diff --git a/minigpt4/datasets/datasets/gqa_datasets.py b/minigpt4/datasets/datasets/gqa_datasets.py
new file mode 100755
index 0000000..b5e835a
--- /dev/null
+++ b/minigpt4/datasets/datasets/gqa_datasets.py
@@ -0,0 +1,60 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import json
+
+from PIL import Image
+
+from minigpt4.datasets.datasets.vqa_datasets import VQADataset
+
+from collections import OrderedDict
+import random
+
+class __DisplMixin:
+ def displ_item(self, index):
+ sample, ann = self.__getitem__(index), self.annotation[index]
+
+ return OrderedDict(
+ {
+ "file": ann["image"],
+ "question": ann["question"],
+ "question_id": ann["question_id"],
+ "answers": "; ".join(ann["answer"]),
+ "image": sample["image"],
+ }
+ )
+
+
+class GQADataset(VQADataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+ self.instruction_pool =[
+ "[vqa] {}",
+ "[vqa] Based on the image, respond to this question with a short answer: {}"
+ ]
+
+ def __getitem__(self, index):
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ question = self.text_processor(ann["question"])
+
+ instruction = random.choice(self.instruction_pool).format(question)
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+
+ answers = self.text_processor(ann["answer"])
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answers,
+ }
+
diff --git a/minigpt4/datasets/datasets/llava_dataset.py b/minigpt4/datasets/datasets/llava_dataset.py
new file mode 100755
index 0000000..2766189
--- /dev/null
+++ b/minigpt4/datasets/datasets/llava_dataset.py
@@ -0,0 +1,150 @@
+import os
+import json
+import pickle
+import random
+import time
+# import itertools
+import numpy as np
+from PIL import Image
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from torch.utils.data import Dataset
+import webdataset as wds
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+class LlavaDetailDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+
+ image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ answer = info['conversations'][1]['value']
+ instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
+
+ instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": info['id'],
+ }
+
+class LlavaReasonDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+
+ image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ answer = info['conversations'][1]['value']
+ instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
+
+ instruction = '<Img><ImageHere></Img> {} '.format(self.text_processor(instruction))
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": info['id'],
+ }
+
+
+
+
+class LlavaConversationDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.ann=[]
+
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
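+ # delimiter used to join the multi-turn questions/answers into single strings; they are split back on this symbol downstream when the conversation batch is unpacked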
+ self.connect_sym = "!@#"
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+
+ image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
+ first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)
+
+ questions = [first_instruction]
+ answers = []
+
+ for i, item in enumerate(info["conversations"][1:]):
+ if i % 2 ==0: # assistant
+ assistant_answer = item["value"]
+ answers.append(assistant_answer)
+ else:
+ human_instruction = item["value"]+" "
+ questions.append(human_instruction)
+
+ questions = self.connect_sym.join(questions)
+ answers = self.connect_sym.join(answers)
+
+
+ return {
+ "image": image,
+ "conv_q": questions,
+ 'conv_a': answers,
+ "image_id": info['id'],
+ "connect_sym": self.connect_sym
+ }
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/multitask_conversation.py b/minigpt4/datasets/datasets/multitask_conversation.py
new file mode 100644
index 0000000..3b13e52
--- /dev/null
+++ b/minigpt4/datasets/datasets/multitask_conversation.py
@@ -0,0 +1,75 @@
+import os
+import json
+import pickle
+import random
+import time
+import itertools
+
+import numpy as np
+from PIL import Image
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from torch.utils.data import Dataset
+import webdataset as wds
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+
+
+
+class MultiTaskConversationDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ self.connect_sym = "!@#"
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+
+ image_file = 'COCO_train2014_{}.jpg'.format(info['id'])
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ first_instruction = info['conversations'][0]['value'].replace('<image>', '').replace('\n', '').strip()
+ first_instruction = '<Img><ImageHere></Img> {} '.format(first_instruction)
+
+ questions = [first_instruction]
+ answers = []
+
+ for i, item in enumerate(info["conversations"][1:]):
+ if i % 2 ==0: # assistant
+ assistant_answer = item["value"]
+ answers.append(assistant_answer)
+ else:
+ human_instruction = item["value"]+" "
+ questions.append(human_instruction)
+
+ questions = self.connect_sym.join(questions)
+ answers = self.connect_sym.join(answers)
+
+
+ return {
+ "image": image,
+ "conv_q": questions,
+ 'conv_a': answers,
+ "image_id": info['id'],
+ "connect_sym": self.connect_sym
+ }
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/ocrvqa_dataset.py b/minigpt4/datasets/datasets/ocrvqa_dataset.py
new file mode 100755
index 0000000..00ce03d
--- /dev/null
+++ b/minigpt4/datasets/datasets/ocrvqa_dataset.py
@@ -0,0 +1,77 @@
+import os
+import json
+import pickle
+import random
+import time
+import itertools
+
+import numpy as np
+from PIL import Image
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from torch.utils.data import Dataset
+import webdataset as wds
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+
+class OCRVQADataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+ self.data = self.create_data(ann_path)
+
+ self.instruction_pool =[
+ "[vqa] {}",
+ "[vqa] Based on the image, respond to this question with a short answer: {}"
+ ]
+
+ def create_data(self, ann_path):
+ processed_data = []
+ with open(ann_path, 'r') as f:
+ data = json.load(f)
+ for k in data.keys():
+ if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test
+ ext = os.path.splitext(data[k]['imageURL'])[1]
+ imageFile = k + ext
+ assert len(data[k]['questions']) == len(data[k]['answers'])
+ for q, a in zip(data[k]['questions'], data[k]['answers']):
+ processed_data.append(
+ {'question': q,
+ 'answer': a,
+ 'image_path': imageFile,
+ 'image_id': k,
+ 'title': data[k]['title'],
+ 'genre': data[k]['genre'],
+ }
+ )
+ return processed_data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ sample = self.data[index]
+ image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
+ image = self.vis_processor(image)
+ question = self.text_processor(sample["question"])
+ answer = self.text_processor(sample["answer"])
+
+ instruction = random.choice(self.instruction_pool).format(question)
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": sample['image_id']
+ }
+
diff --git a/minigpt4/datasets/datasets/text_caps.py b/minigpt4/datasets/datasets/text_caps.py
new file mode 100755
index 0000000..47a87f1
--- /dev/null
+++ b/minigpt4/datasets/datasets/text_caps.py
@@ -0,0 +1,77 @@
+import os
+import json
+import pickle
+import random
+import time
+import itertools
+
+import numpy as np
+from PIL import Image
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from torch.utils.data import Dataset
+import webdataset as wds
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+
+
+class TextCapDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.instruction_pool = [
+ 'Briefly describe this image.',
+ 'Provide a concise depiction of this image.',
+ 'Present a short description of this image.',
+ 'Summarize this image in a few words.',
+ 'A short image caption:',
+ 'A short image description:',
+ 'A photo of ',
+ 'An image that shows ',
+ 'Write a short description for the image. ',
+ 'Write a description for the photo.',
+ 'Provide a description of what is presented in the photo.',
+ 'Briefly describe the content of the image.',
+ 'Can you briefly explain what you see in the image?',
+ 'Could you use a few words to describe what you perceive in the photo?',
+ 'Please provide a short depiction of the picture.',
+ 'Using language, provide a short account of the image.',
+ 'Use a few words to illustrate what is happening in the picture.',
+ ]
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+
+ def __len__(self):
+ return len(self.ann["data"])
+
+
+ def __getitem__(self, index):
+ info = self.ann["data"][index]
+
+ image_file = '{}.jpg'.format(info['image_id'])
+
+ image_path = os.path.join(self.vis_root, image_file)
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+
+ caption = info["caption_str"]
+ caption = self.text_processor(caption)
+ instruction = "<Img><ImageHere></Img> [caption] {} ".format(random.choice(self.instruction_pool))
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": caption,
+ }
diff --git a/minigpt4/datasets/datasets/unnatural_instruction.py b/minigpt4/datasets/datasets/unnatural_instruction.py
new file mode 100755
index 0000000..3fcf9ac
--- /dev/null
+++ b/minigpt4/datasets/datasets/unnatural_instruction.py
@@ -0,0 +1,46 @@
+import os
+import json
+import pickle
+import random
+import time
+import itertools
+
+import numpy as np
+from PIL import Image
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from torch.utils.data import Dataset
+import webdataset as wds
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+
+
+class UnnaturalDataset(Dataset):
+ def __init__(self, text_processor, ann_path):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.text_processor = text_processor
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]["instances"][0]
+ instruction = info["instruction_with_input"]
+ constraints = info["constraints"]
+ answer = info["output"]
+ if constraints is not None:
+ instruction = instruction+" "+constraints
+
+ return {
+ "instruction_input": self.text_processor(instruction),
+ "answer": self.text_processor(answer),
+ }
diff --git a/minigpt4/datasets/datasets/vg_dataset.py b/minigpt4/datasets/datasets/vg_dataset.py
new file mode 100755
index 0000000..16823c0
--- /dev/null
+++ b/minigpt4/datasets/datasets/vg_dataset.py
@@ -0,0 +1,90 @@
+import os
+import json
+import pickle
+import random
+import time
+import itertools
+
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from visual_genome import local
+
+
+
+
+class ReferVisualGenomeDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, data_dir):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.data_dir = data_dir
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ all_regions = local.get_all_region_descriptions(self.data_dir)
+ all_regions = [region for regions in all_regions for region in regions]
+
+ # following OFA practice, only regions smaller than 16384 pixels are used for referring expressions
+ self.regions = [region for region in all_regions if region.width * region.height < 16384]
+
+
+ self.instruction_pool = [
+ "[refer] {}",
+ "[refer] give me the location of {}",
+ "[refer] where is {} ?",
+ "[refer] from this image, tell me the location of {}",
+ "[refer] the location of {} is",
+ "[refer] could you tell me the location for {} ?",
+ "[refer] where can I locate the {} ?",
+ ]
+
+
+ def __len__(self):
+ return len(self.regions)
+
+ def preprocess(self, index):
+ region = self.regions[index]
+ image_file = region.image.url.split('/')[-2:]
+ image_path = os.path.join(self.data_dir, *image_file)
+ image = Image.open(image_path).convert("RGB")
+ image_orig_size = image.size
+ image = self.vis_processor(image)
+ image_new_size = [100,100]
+
+ sample_sentence = region.phrase
+ refer_sentence = self.text_processor(sample_sentence)
+
+ bbox = [region.x, region.y, region.width, region.height]
+
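+ # same convention as ReferCOCODataset: convert [x, y, w, h] to corner format on the fixed 100x100 grid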
+ bbox = [
+ bbox[0] / image_orig_size[0] * image_new_size[0],
+ bbox[1] / image_orig_size[1] * image_new_size[1],
+ (bbox[0] + bbox[2]) / image_orig_size[0] * image_new_size[0],
+ (bbox[1] + bbox[3]) / image_orig_size[1] * image_new_size[1]
+ ]
+ bbox = [int(x) for x in bbox]
+ bbox = "{{<{}><{}><{}><{}>}}".format(*bbox)
+ return {
+ "image": image,
+ "refer_sentence": refer_sentence,
+ "bbox": bbox,
+ "image_id": region.image.id,
+ }
+
+ def __getitem__(self, index):
+ data = self.preprocess(index)
+ instruction = random.choice(self.instruction_pool).format(data['refer_sentence'])
+
+ instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+
+ return {
+ "image": data['image'],
+ "instruction_input": instruction,
+ "answer": data['bbox'],
+ "image_id": data['image_id'],
+ }
+
+
diff --git a/minigpt4/datasets/datasets/vqa_datasets.py b/minigpt4/datasets/datasets/vqa_datasets.py
new file mode 100755
index 0000000..5cdc0fa
--- /dev/null
+++ b/minigpt4/datasets/datasets/vqa_datasets.py
@@ -0,0 +1,223 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import torch
+from PIL import Image
+import os
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+
+
+class VQADataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ # def collater(self, samples):
+ # image_list, question_list, answer_list, weight_list = [], [], [], []
+
+ # num_answers = []
+
+ # for sample in samples:
+ # image_list.append(sample["image"])
+ # question_list.append(sample["question"])
+
+ # weight_list.extend(sample["weights"])
+
+ # answers = sample["answer"]
+
+ # answer_list.extend(answers)
+ # num_answers.append(len(answers))
+
+ # return {
+ # "image": torch.stack(image_list, dim=0),
+ # "text_input": question_list,
+ # "answer": answer_list,
+ # "weight": torch.Tensor(weight_list),
+ # "n_answers": torch.LongTensor(num_answers),
+ # }
+
+
+class VQAEvalDataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+
+class OKVQAEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = data['image_id']
+ question = data['question']
+ question_id = data['question_id']
+ img_file = '{:0>12}.jpg'.format(img_id)
+ image_path = os.path.join(self.root_path, img_file)
+ image = Image.open(image_path).convert('RGB')
+ image = self.vis_processor(image)
+ question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
+ # question = f"[vqa] {question} "
+ return image, question, question_id, img_id
+
+class VizWizEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = data['image']
+ question = data['question']
+ answers = data['answers']
+ answers = '_'.join([answer['answer'] for answer in answers])
+ image_path = os.path.join(self.root_path, img_id)
+ image = Image.open(image_path).convert('RGB')
+ image = self.vis_processor(image)
+ # question = f"[vqa] Based on the image, respond to this question with a short answer: {question} "
+ question = f"[vqa] Based on the image, respond to this question with a short answer: {question} and reply 'unanswerable' if you could not answer it"
+ return image, question, answers
+
+class AOKVQADAEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_file = data['image']
+ question = data['question']
+ question_id = data['question_id']
+ image_path = os.path.join(self.root_path, img_file)
+ image = Image.open(image_path).convert('RGB')
+ image = self.vis_processor(image)
+ question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
+ # question = f"[vqa] {question} "
+ return image, question, question_id
+
+class AOKVQAMCEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_file = data['image']
+ question = data['question']
+ question_id = data['question_id']
+ image_path = os.path.join(self.root_path, img_file)
+ image = Image.open(image_path).convert('RGB')
+ image = self.vis_processor(image).half().cuda()
+ candidates=data['choices']
+ # question = f"Given this image, choose one correct answer from {candidates} for this question: {question} "
+ question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
+ # question = f"[vqa] {question} "
+ return image, question, question_id, candidates
+
+class IconQAEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ image_id = data['image_id']
+ question = data['question']
+ image_path = os.path.join(self.root_path, image_id, 'image.png')
+ image = Image.open(image_path).convert('RGB')
+ image = self.vis_processor(image).half().cuda()
+ candidates = '_'.join(data['choices'])
+ answer = data['answer']
+ # question = f"Given this image, choose one correct answer from {candidates} for this question: {question} "
+ question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
+ # question = f"[vqa] {question} "
+ return image, question, candidates, answer
+
+class GQAEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ ann = self.loaded_data[idx]
+ image_id = ann["image"]
+ image_path = os.path.join(self.root_path, f"{image_id}")
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+ question = ann["question"]
+ # question = f'Question: {question} Short answer: '
+ question = f"[vqa] Based on the image, respond to this question with a short answer: {question}"
+ # question = f"[vqa] {question} "
+ labels = ann["answer"]
+
+ return image, question, labels
+
+class HMEvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ ann = self.loaded_data[idx]
+ image_id = ann["img"]
+ image_path = os.path.join(self.root_path, f"{image_id}")
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+ question = ann["text"]
+ question = f"[vqa] This is an image with: \"{question}\" written on it. Is it hateful? Answer:"
+ labels = ann["label"]
+
+ return image, question, labels
+
+class VSREvalData(torch.utils.data.Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ ann = self.loaded_data[idx]
+ image_path = os.path.join(self.root_path, ann["image"])
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+ question = ann["caption"]
+ question = f'[vqa] Based on the image, is this statement true or false? {question}'
+ question_id = ann["image"].split('.')[0]
+ labels = 'true' if ann["label"] == 1 else 'false'
+
+ return image, question, labels
\ No newline at end of file
diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py
index 58edb1a..cd051ec 100644
--- a/minigpt4/models/minigpt_base.py
+++ b/minigpt4/models/minigpt_base.py
@@ -172,12 +172,12 @@ class MiniGPTBase(BaseModel):
batch_size = len(conv_q)
for batch_idx in range(batch_size):
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
- questions = [self.llama_tokenizer(q,
+ questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it
- answers = [self.llama_tokenizer(q,
- return_tensors="pt",
- add_special_tokens=False).to(self.device) for q in answers]
+ answers = [self.llama_tokenizer(a + self.end_sym,
+ return_tensors="pt",
+ add_special_tokens=False).to(self.device) for a in answers]
cur_id = []
cur_target = []
for i in range(len(questions)):
diff --git a/minigpt4/tasks/base_task.py b/minigpt4/tasks/base_task.py
index e5eb32b..1cfa46c 100644
--- a/minigpt4/tasks/base_task.py
+++ b/minigpt4/tasks/base_task.py
@@ -14,7 +14,7 @@ from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process
from minigpt4.common.logger import MetricLogger, SmoothedValue
from minigpt4.common.registry import registry
from minigpt4.datasets.data_utils import prepare_sample
-
+import wandb
class BaseTask:
def __init__(self, **kwargs):
@@ -234,7 +234,9 @@ class BaseTask:
else:
optimizer.step()
optimizer.zero_grad()
-
+ # if self.cfg.wandb_log:
+ if self.cfg.run_cfg.wandb_log:
+ wandb.log({"epoch": inner_epoch, "loss": loss})
metric_logger.update(loss=loss.item())
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
diff --git a/train.py b/train.py
index a90cb3f..4dead8e 100644
--- a/train.py
+++ b/train.py
@@ -12,6 +12,7 @@ import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
+import wandb
import minigpt4.tasks as tasks
from minigpt4.common.config import Config
@@ -43,10 +44,7 @@ def parse_args():
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)
-
args = parser.parse_args()
- # if 'LOCAL_RANK' not in os.environ:
- # os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
@@ -77,22 +75,25 @@ def main():
# set before init_distributed_mode() to ensure the same job_id shared across all ranks.
job_id = now()
-
- cfg = Config(parse_args())
+ args = parse_args()
+ cfg = Config(args)
init_distributed_mode(cfg.run_cfg)
-
setup_seeds(cfg)
# set after init_distributed_mode() to only log on master.
setup_logger()
-
cfg.pretty_print()
task = tasks.setup_task(cfg)
datasets = task.build_datasets(cfg)
model = task.build_model(cfg)
+ if cfg.run_cfg.wandb_log:
+ wandb.login()
+ wandb.init(project="minigptv", name=cfg.run_cfg.job_name)
+ wandb.watch(model)
+
runner = get_runner_class(cfg)(
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
)
diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
index c13d31f..bcc458e 100644
--- a/train_configs/minigpt4_llama2_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
@@ -52,4 +52,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
- distributed: True
\ No newline at end of file
+ distributed: True
+
+ wandb_log: True
+ job_name: minigpt4_llama2_pretrain
\ No newline at end of file
diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml
index 8c138ae..29b5358 100644
--- a/train_configs/minigpt4_llama2_stage2_finetune.yaml
+++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml
@@ -46,4 +46,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
- distributed: True
\ No newline at end of file
+ distributed: True
+
+ wandb_log: True
+ job_name: minigpt4_llama2_finetune
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml
index ce8bc87..bd9a451 100644
--- a/train_configs/minigpt4_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_stage1_pretrain.yaml
@@ -52,4 +52,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
- distributed: True
\ No newline at end of file
+ distributed: True
+
+ wandb_log: True
+ job_name: minigpt4_pretrain
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml
index 531a3a0..89d1100 100644
--- a/train_configs/minigpt4_stage2_finetune.yaml
+++ b/train_configs/minigpt4_stage2_finetune.yaml
@@ -46,4 +46,7 @@ run:
device: "cuda"
world_size: 1
dist_url: "env://"
- distributed: True
\ No newline at end of file
+ distributed: True
+
+ wandb_log: True
+ job_name: minigpt4_finetune
\ No newline at end of file
diff --git a/train_configs/minigptv2_finetune.yaml b/train_configs/minigptv2_finetune.yaml
new file mode 100644
index 0000000..114d7e9
--- /dev/null
+++ b/train_configs/minigptv2_finetune.yaml
@@ -0,0 +1,294 @@
+model:
+ arch: minigpt_v2
+ model_type: pretrain
+ max_txt_len: 1024
+ image_size: 448
+ end_sym: ""
+ llama_model: "/path/to/llama_checkpoint"
+ ckpt: "/path/to/pretrained_checkpoint"
+ use_grad_checkpoint: True
+ chat_template: True
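+ # LoRA adapter hyperparameters: rank (lora_r) and scaling factor (lora_alpha)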
+ lora_r: 64
+ lora_alpha: 16
+
+datasets:
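+ # each dataset block sets its own batch_size; sample_ratio controls how often that dataset is drawn relative to the others during training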
+ multitask_conversation:
+ batch_size: 2
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 50
+
+ llava_conversation:
+ batch_size: 2
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 30
+
+ unnatural_instruction:
+ batch_size: 1
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 10
+
+
+ refvg:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 40
+
+ llava_detail:
+ batch_size: 4
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 20
+
+ llava_reason:
+ batch_size: 4
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 80
+
+
+ flickr_grounded_caption:
+ batch_size: 2
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 80
+
+ flickr_CaptionToPhrase:
+ batch_size: 2
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 80
+
+ flickr_ObjectToPhrase:
+ batch_size: 2
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 80
+
+ coco_caption:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 10
+
+
+ textcaps_caption:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 30
+
+ refcoco:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 25
+
+
+ refcocop:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 25
+
+ refcocog:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 25
+
+
+
+ invrefcoco:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 10
+
+ invrefcocop:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 10
+
+ invrefcocog:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 10
+
+
+ coco_vqa:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 15
+
+ ok_vqa:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 8
+
+ aok_vqa:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 12
+
+ gqa:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 50
+
+ ocrvqa:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 30
+
+
+run:
+ task: image_text_pretrain
+ # optimizer
+ lr_sched: "linear_warmup_cosine_lr"
+ init_lr: 1e-5
+ min_lr: 8e-5
+ warmup_lr: 1e-6
+
+ weight_decay: 0.05
+ max_epoch: 50
+ num_workers: 6
+ warmup_steps: 1000
+ iters_per_epoch: 1000
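+ # warmup_steps equals iters_per_epoch, so the linear warmup spans roughly the first epoch before cosine decay begins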
+
+ seed: 42
+ output_dir: "/path/to/save_checkpoint"
+
+ amp: True
+ resume_ckpt_path: null
+
+ evaluate: False
+ train_splits: ["train"]
+
+ device: "cuda"
+ world_size: 1
+ dist_url: "env://"
+ distributed: True
+
+ wandb_log: True
+ job_name: minigptv2_finetune
\ No newline at end of file