Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-05 02:20:47 +00:00)

Commit fb8e2c656a (parent bbd7883d1c): include llama2

README.md (49 lines changed)
@@ -1,5 +1,5 @@
 # MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
-[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
+[Deyao Zhu](https://tsutikgiau.github.io/)*, [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution

 **King Abdullah University of Science and Technology**

@@ -7,7 +7,7 @@


 ## News
-We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB.
+We now provide a Llama 2 version of MiniGPT-4.



 ## Online Demo
@@ -52,49 +52,52 @@ conda activate minigpt4
 ```


-**2. Prepare the pretrained Vicuna weights**
+**2. Prepare the pretrained LLM weights**

-The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B.
-Please refer to our instruction [here](PrepareVicuna.md)
-to prepare the Vicuna weights.
-The final weights would be in a single folder in a structure similar to the following:
-
-```
-vicuna_weights
-├── config.json
-├── generation_config.json
-├── pytorch_model.bin.index.json
-├── pytorch_model-00001-of-00003.bin
-...
-```
+Currently, we provide both a Vicuna V0 and a Llama 2 version of MiniGPT-4.
+Download the corresponding LLM weights from one of the following Hugging Face repositories by cloning it with git-lfs.
+
+| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B |
+:--------------:|:------------:|:---------------:
+[Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat/tree/main)

 Then, set the path to the Vicuna weights in the model config file
-[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
+[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18,
+and/or the path to the Llama 2 weights in the model config file
+[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.

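As a hedged illustration (not part of this commit, which documents a git-lfs clone), the same weight repositories can also be fetched from Python with `huggingface_hub`; the repo id and the local folder name below are only examples, and the resulting folder is what the `llama_model` field in the model config should point to.

```python
# Illustration only: fetch one of the LLM weight repos listed above with
# huggingface_hub instead of a git-lfs clone.
from huggingface_hub import snapshot_download

local_dir = "./vicuna-7b-v0"  # any folder; this path then goes into `llama_model` in the model config
snapshot_download(repo_id="Vision-CAIR/vicuna-7b", local_dir=local_dir)
# Note: meta-llama/Llama-2-7b-chat is gated and requires accepting Meta's license first.
print("weights saved to", local_dir)
```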
 **3. Prepare the pretrained MiniGPT-4 checkpoint**

 Download the pretrained checkpoints according to the language model you prepared.

-| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
-:--------------:|:------------:
-[Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing)
+| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B |
+:--------------:|:------------:|:---------------:
+[Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing)

 Then, set the path to the pretrained checkpoint in the evaluation config file
-in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
+in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for the Vicuna version, or in [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for the Llama 2 version.


 ### Launching Demo Locally

-Try out our demo [demo.py](demo.py) on your local machine by running
+Try out our demo [demo.py](demo.py) for the Vicuna version on your local machine by running

 ```
 python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
 ```

-To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1.
-This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B.
+or for the Llama 2 version by
+
+```
+python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
+```
+
+To save GPU memory, the LLM loads in 8 bit by default, with a beam search width of 1.
+This configuration requires about 23G of GPU memory for the 13B LLM and 11.5G for the 7B LLM.
 For more powerful GPUs, you can run the model
 in 16 bit by setting low_resource to False in the config file
 [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
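For context, a hedged sketch of what the `low_resource` flag roughly controls when the LLM is loaded with transformers; the exact arguments MiniGPT-4 passes may differ, and 8-bit loading assumes bitsandbytes is installed.

```python
import torch
from transformers import LlamaForCausalLM

def load_llm(llama_path, low_resource=True, gpu_id=0):
    # Hedged sketch: roughly what low_resource: True/False in the eval configs toggles.
    if low_resource:
        # 8-bit weights on a single GPU (about 11.5G for a 7B model)
        return LlamaForCausalLM.from_pretrained(
            llama_path, load_in_8bit=True, device_map={"": gpu_id}
        )
    # full fp16 weights: more GPU memory, but allows a larger beam search width
    return LlamaForCausalLM.from_pretrained(llama_path, torch_dtype=torch.float16)
```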
demo.py (12 lines changed)
@@ -10,7 +10,7 @@ import gradio as gr
 from minigpt4.common.config import Config
 from minigpt4.common.dist_utils import get_rank
 from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION
+from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2

 # imports modules for registration
 from minigpt4.datasets.builders import *
@@ -50,6 +50,9 @@ def setup_seeds(config):
 # Model Initialization
 # ========================================

+conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
+             'pretrain_llama2': CONV_VISION_LLama2}
+
 print('Initializing Chat')
 args = parse_args()
 cfg = Config(args)
@@ -59,15 +62,19 @@ model_config.device_8bit = args.gpu_id
 model_cls = registry.get_model_class(model_config.arch)
 model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

+CONV_VISION = conv_dict[model_config.model_type]
+
 vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
 vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
 chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
 print('Initialization Finished')


 # ========================================
 # Gradio Setting
 # ========================================


 def gradio_reset(chat_state, img_list):
     if chat_state is not None:
         chat_state.messages = []
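The two additions above are what tie the eval config's `model_type` to a prompt format. A minimal hedged sketch (not the repo's code; the template strings are copied from the eval configs in this commit):

```python
# Minimal sketch: model_type selects the matching prompt_template / end_sym,
# mirroring the conv_dict lookup in demo.py (values taken from the eval yamls).
templates = {
    'pretrain_vicuna0': {'prompt_template': '###Human: {} ###Assistant: ', 'end_sym': '###'},
    'pretrain_llama2':  {'prompt_template': '[INST] {} [/INST] ',          'end_sym': '</s>'},
}

model_type = 'pretrain_llama2'  # e.g. from eval_configs/minigpt4_llama2_eval.yaml
cfg = templates[model_type]
print(cfg['prompt_template'].format('<Img><ImageHere></Img> Describe this image.'))
```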
@@ -75,6 +82,7 @@ def gradio_reset(chat_state, img_list):
     img_list = []
     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list


 def upload_img(gr_img, text_input, chat_state):
     if gr_img is None:
         return None, None, gr.update(interactive=True), chat_state, None
@@ -83,6 +91,7 @@ def upload_img(gr_img, text_input, chat_state):
     llm_message = chat.upload_img(gr_img, chat_state, img_list)
     return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list


 def gradio_ask(user_message, chatbot, chat_state):
     if len(user_message) == 0:
         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
@@ -101,6 +110,7 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
     chatbot[-1][1] = llm_message
     return chatbot, chat_state, img_list


 title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
 description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
 article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
eval_configs/minigpt4_eval.yaml

@@ -1,14 +1,11 @@
 model:
   arch: mini_gpt4
-  model_type: pretrain_vicuna
-  freeze_vit: True
-  freeze_qformer: True
+  model_type: pretrain_vicuna0
   max_txt_len: 160
   end_sym: "###"
   low_resource: True
-  prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
-  ckpt: '/path/to/pretrained/ckpt/'
+  ckpt: '/home/zhud/ibex/pretrained_minigpt4.pth'


 datasets:
eval_configs/minigpt4_llama2_eval.yaml (new file, 22 lines)

@@ -0,0 +1,22 @@
+model:
+  arch: mini_gpt4
+  model_type: pretrain_llama2
+  max_txt_len: 160
+  end_sym: "</s>"
+  low_resource: True
+  prompt_template: '[INST] {} [/INST] '
+  ckpt: '/home/zhud/c2090/zhud/project/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20230826182/checkpoint_4.pth'
+
+
+datasets:
+  cc_sbu_align:
+    vis_processor:
+      train:
+        name: "blip2_image_eval"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+
+run:
+  task: image_text_pretrain
minigpt4/common/dist_utils.py

@@ -55,7 +55,10 @@ def is_main_process():


 def init_distributed_mode(args):
-    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+    if args.distributed is False:
+        print("Not using distributed mode")
+        return
+    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         args.rank = int(os.environ["RANK"])
         args.world_size = int(os.environ["WORLD_SIZE"])
         args.gpu = int(os.environ["LOCAL_RANK"])
minigpt4/configs/models/minigpt4_llama2.yaml (new file, 29 lines)

@@ -0,0 +1,29 @@
+model:
+  arch: mini_gpt4
+
+  # vit encoder
+  image_size: 224
+  drop_path_rate: 0
+  use_grad_checkpoint: False
+  vit_precision: "fp16"
+  freeze_vit: True
+  has_qformer: False
+
+  # generation configs
+  prompt: ""
+
+  llama_model: "/path/to/llama2/weight"
+
+preprocess:
+  vis_processor:
+    train:
+      name: "blip2_image_train"
+      image_size: 224
+    eval:
+      name: "blip2_image_eval"
+      image_size: 224
+  text_processor:
+    train:
+      name: "blip_caption"
+    eval:
+      name: "blip_caption"
minigpt4/configs/models/minigpt4_vicuna0.yaml

@@ -12,12 +12,11 @@ model:
   # Q-Former
   num_query_token: 32

-  # Vicuna
-  llama_model: "/path/to/vicuna/weights/"
-
   # generation configs
   prompt: ""

+  llama_model: "/path/to/vicuna/weight"
+
 preprocess:
   vis_processor:
     train:
minigpt4/conversation/conversation.py

@@ -39,18 +39,18 @@ class Conversation:
             ret = self.system + self.sep
             for role, message in self.messages:
                 if message:
-                    ret += role + ": " + message + self.sep
+                    ret += role + message + self.sep
                 else:
-                    ret += role + ":"
+                    ret += role
             return ret
         elif self.sep_style == SeparatorStyle.TWO:
             seps = [self.sep, self.sep2]
             ret = self.system + seps[0]
             for i, (role, message) in enumerate(self.messages):
                 if message:
-                    ret += role + ": " + message + seps[i % 2]
+                    ret += role + message + seps[i % 2]
                 else:
-                    ret += role + ":"
+                    ret += role
             return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
@@ -106,16 +106,26 @@ class StoppingCriteriaSub(StoppingCriteria):
         return False


-CONV_VISION = Conversation(
+CONV_VISION_Vicuna0 = Conversation(
     system="Give the following image: <Img>ImageContent</Img>. "
            "You will be able to see the image once I provide it to you. Please answer my questions.",
-    roles=("Human", "Assistant"),
+    roles=("Human: ", "Assistant: "),
     messages=[],
     offset=2,
     sep_style=SeparatorStyle.SINGLE,
     sep="###",
 )

+CONV_VISION_LLama2 = Conversation(
+    system="Give the following image: <Img>ImageContent</Img>. "
+           "You will be able to see the image once I provide it to you. Please answer my questions.",
+    roles=("<s>[INST] ", " [/INST] "),
+    messages=[],
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="",
+)
+

 class Chat:
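Because get_prompt now concatenates roles and messages directly, the role strings carry their own separators. A small hedged sketch (reproducing the new concatenation logic above rather than importing the repo's class) shows the rendered prompts for one question under each template:

```python
# Hedged sketch: mimic the new `ret += role + message + sep` logic for one user
# turn followed by an empty assistant turn.
def render(system, roles, sep, question):
    ret = system + sep
    ret += roles[0] + question + sep   # user turn
    ret += roles[1]                    # empty assistant turn -> just the role
    return ret

system = ("Give the following image: <Img>ImageContent</Img>. "
          "You will be able to see the image once I provide it to you. Please answer my questions.")

print(render(system, ("Human: ", "Assistant: "), "###", "<Img><ImageHere></Img> What is this?"))
# ...###Human: <Img><ImageHere></Img> What is this?###Assistant:

print(render(system, ("<s>[INST] ", " [/INST] "), "", "<Img><ImageHere></Img> What is this?"))
# ...<s>[INST] <Img><ImageHere></Img> What is this? [/INST]
```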
minigpt4/datasets/datasets/cc_sbu_dataset.py

@@ -22,7 +22,7 @@ class CCSBUDataset(BaseDataset):
     def to_dict(self, sample):
         return {
             "image": sample[0],
-            "text_input": self.text_processor(sample[1]["caption"]),
+            "answer": self.text_processor(sample[1]["caption"]),
         }


@@ -42,6 +42,6 @@ class CCSBUAlignDataset(CaptionDataset):

         return {
             "image": image,
-            "text_input": caption,
+            "answer": caption,
             "image_id": self.img_ids[ann["image_id"]],
         }
minigpt4/datasets/datasets/laion_dataset.py

@@ -26,6 +26,6 @@ class LaionDataset(BaseDataset):
     def to_dict(self, sample):
         return {
             "image": sample[0],
-            "text_input": self.text_processor(sample[1]["caption"]),
+            "answer": self.text_processor(sample[1]["caption"]),
         }
minigpt4/models/mini_gpt4.py

@@ -7,9 +7,17 @@ import torch.nn as nn
 from minigpt4.common.registry import registry
 from minigpt4.models.blip2 import Blip2Base, disabled_train
-from minigpt4.models.modeling_llama import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers import LlamaTokenizer

+from peft import (
+    LoraConfig,
+    get_peft_model,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    set_peft_model_state_dict,
+)
+

 @registry.register_model("mini_gpt4")
 class MiniGPT4(Blip2Base):
@@ -18,7 +26,8 @@ class MiniGPT4(Blip2Base):
     """

     PRETRAINED_MODEL_CONFIG_DICT = {
-        "pretrain_vicuna": "configs/models/minigpt4.yaml",
+        "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
+        "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
     }

     def __init__(
@@ -30,6 +39,7 @@ class MiniGPT4(Blip2Base):
         use_grad_checkpoint=False,
         vit_precision="fp16",
         freeze_vit=True,
+        has_qformer=True,
         freeze_qformer=True,
         num_query_token=32,
         llama_model="",
@@ -39,6 +49,10 @@ class MiniGPT4(Blip2Base):
         end_sym='\n',
         low_resource=False,  # use 8 bit and put vit in cpu
         device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
+        lora_r=0,
+        lora_target_modules=["q_proj", "v_proj"],
+        lora_alpha=16,
+        lora_dropout=0.05,
     ):
         super().__init__()

@@ -61,30 +75,37 @@ class MiniGPT4(Blip2Base):
             logging.info("freeze vision encoder")
         print('Loading VIT Done')

-        print('Loading Q-Former')
-        self.Qformer, self.query_tokens = self.init_Qformer(
-            num_query_token, self.visual_encoder.num_features
-        )
-        self.Qformer.cls = None
-        self.Qformer.bert.embeddings.word_embeddings = None
-        self.Qformer.bert.embeddings.position_embeddings = None
-        for layer in self.Qformer.bert.encoder.layer:
-            layer.output = None
-            layer.intermediate = None
-        self.load_from_pretrained(url_or_filename=q_former_model)
-
-        if freeze_qformer:
-            for name, param in self.Qformer.named_parameters():
-                param.requires_grad = False
-            self.Qformer = self.Qformer.eval()
-            self.Qformer.train = disabled_train
-            self.query_tokens.requires_grad = False
-            logging.info("freeze Qformer")
-        print('Loading Q-Former Done')
+        self.has_qformer = has_qformer
+        if self.has_qformer:
+            print('Loading Q-Former')
+            self.Qformer, self.query_tokens = self.init_Qformer(
+                num_query_token, self.visual_encoder.num_features
+            )
+            self.Qformer.cls = None
+            self.Qformer.bert.embeddings.word_embeddings = None
+            self.Qformer.bert.embeddings.position_embeddings = None
+            for layer in self.Qformer.bert.encoder.layer:
+                layer.output = None
+                layer.intermediate = None
+            self.load_from_pretrained(url_or_filename=q_former_model)
+
+            if freeze_qformer:
+                for name, param in self.Qformer.named_parameters():
+                    param.requires_grad = False
+                self.Qformer = self.Qformer.eval()
+                self.Qformer.train = disabled_train
+                self.query_tokens.requires_grad = False
+                logging.info("freeze Qformer")
+
+            img_f_dim = self.Qformer.config.hidden_size
+            print('Loading Q-Former Done')
+        else:
+            img_f_dim = self.visual_encoder.num_features * 4
+            print('Do not use Q-Former here.')

         print('Loading LLAMA')
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
-        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+        self.llama_tokenizer.pad_token = "$$"

         if self.low_resource:
             self.llama_model = LlamaForCausalLM.from_pretrained(
@@ -99,12 +120,31 @@ class MiniGPT4(Blip2Base):
                 torch_dtype=torch.float16,
             )

-        for name, param in self.llama_model.named_parameters():
-            param.requires_grad = False
+        if lora_r > 0:
+            self.llama_model = prepare_model_for_int8_training(self.llama_model)
+            loraconfig = LoraConfig(
+                r=lora_r,
+                lora_alpha=lora_alpha,
+                target_modules=lora_target_modules,
+                lora_dropout=lora_dropout,
+                bias="none",
+                task_type="CAUSAL_LM"
+            )
+            self.llama_model = get_peft_model(self.llama_model, loraconfig)
+
+            # if ckpt_path:
+            #     print('load the llm under lora')
+            #     ckpt = torch.load(ckpt_path)
+            #     set_peft_model_state_dict(self.llama_model,ckpt)
+            self.llama_model.print_trainable_parameters()
+
+        else:
+            for name, param in self.llama_model.named_parameters():
+                param.requires_grad = False
         print('Loading LLAMA Done')

         self.llama_proj = nn.Linear(
-            self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
+            img_f_dim, self.llama_model.config.hidden_size
         )
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
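For context, a toy-sized, hedged illustration of what the new `lora_r > 0` branch does with peft, using a deliberately tiny randomly initialised Llama config (not the real weights) so it runs quickly; peft and transformers are assumed installed, and 8-bit preparation is omitted here.

```python
# Hedged toy example: wrap a tiny Llama model with the same LoraConfig arguments
# the new code path uses (rank, alpha, q_proj/v_proj targets).
from transformers import LlamaConfig, LlamaForCausalLM
from peft import LoraConfig, get_peft_model

tiny_cfg = LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                       num_hidden_layers=2, num_attention_heads=4)
llm = LlamaForCausalLM(tiny_cfg)            # stand-in for the frozen Vicuna / Llama 2 weights

loraconfig = LoraConfig(
    r=8,                                    # lora_r in the model config (0 disables LoRA)
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, loraconfig)
llm.print_trainable_parameters()            # only the small LoRA adapters are trainable
```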
@@ -133,50 +173,109 @@ class MiniGPT4(Blip2Base):

         with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
-
-            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-            query_output = self.Qformer.bert(
-                query_embeds=query_tokens,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=True,
-            )
-
-            inputs_llama = self.llama_proj(query_output.last_hidden_state)
+            if self.has_qformer:
+                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+                inputs_llama = self.llama_proj(query_output.last_hidden_state)
+            else:
+                image_embeds = image_embeds[:, 1:, :]
+                bs, pn, hs = image_embeds.shape
+                image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
+
+                inputs_llama = self.llama_proj(image_embeds)
             atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
         return inputs_llama, atts_llama

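Without the Q-Former, the Llama 2 variant drops the ViT's CLS token and concatenates every 4 neighbouring patch embeddings before the linear projection, which is why `img_f_dim` becomes `num_features * 4`. A hedged standalone sketch with assumed EVA-CLIP ViT-g dimensions (224px input, 1408-dim features, 256 patches) and an assumed 4096-dim LLM embedding space for a 7B model:

```python
import torch

# Hedged sketch (backbone dimensions assumed, not taken from this diff):
# a 224x224 image -> 1 CLS token + 256 patch tokens, each 1408-dim.
image_embeds = torch.randn(2, 257, 1408)      # (batch, tokens, features)

x = image_embeds[:, 1:, :]                    # drop the CLS token -> (2, 256, 1408)
bs, pn, hs = x.shape
x = x.view(bs, pn // 4, hs * 4)               # merge 4 neighbouring tokens -> (2, 64, 5632)

# A projection with in_features = num_features * 4 then maps the 64 merged
# tokens into the LLM embedding space.
llama_proj = torch.nn.Linear(hs * 4, 4096)
inputs_llama = llama_proj(x)                  # (2, 64, 4096)
print(inputs_llama.shape)
```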
-    def prompt_wrap(self, img_embeds, atts_img, prompt):
-        if prompt:
-            batch_size = img_embeds.shape[0]
-            p_before, p_after = prompt.split('<ImageHere>')
-            p_before_tokens = self.llama_tokenizer(
-                p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
-            p_after_tokens = self.llama_tokenizer(
-                p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
-            p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
-            p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
-            wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
-            wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
-            return wrapped_img_embeds, wrapped_atts_img
+    def get_context_emb(self, prompt, img_list):
+        device = img_list[0].device
+        prompt_segs = prompt.split('<ImageHere>')
+        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
+        seg_tokens = [
+            self.llama_tokenizer(
+                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
+            # only add bos to the first seg
+            for i, seg in enumerate(prompt_segs)
+        ]
+        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
+
+        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+        mixed_embs = torch.cat(mixed_embs, dim=1)
+        return mixed_embs
+
+    def prompt_wrap(self, img_embeds, atts_img, prompts):
+        if prompts:
+            emb_lists = []
+            if isinstance(prompts, str):
+                prompts = [prompts] * len(img_embeds)
+
+            for each_img_embed, each_prompt in zip(img_embeds, prompts):
+                p_before, p_after = each_prompt.split('<ImageHere>')
+
+                p_before_tokens = self.llama_tokenizer(
+                    p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+                p_after_tokens = self.llama_tokenizer(
+                    p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+                p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
+                p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
+                wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
+                emb_lists.append(wrapped_emb)
+            emb_lens = [emb.shape[1] for emb in emb_lists]
+            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
+            wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
+            wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
+            for i, emb in enumerate(emb_lists):
+                wrapped_embs[i, :emb_lens[i]] = emb
+                wrapped_atts[i, :emb_lens[i]] = 1
+            return wrapped_embs, wrapped_atts
         else:
             return img_embeds, atts_img

+    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
+        input_lens = []
+        cat_embs = []
+        cat_atts = []
+        for i in range(input_embs.size(0)):
+            input_len = input_atts[i].sum()
+            input_lens.append(input_len)
+            cat_embs.append(
+                torch.cat([
+                    input_embs[i][:input_len],
+                    output_embs[i],
+                    input_embs[i][input_len:]
+                ])
+            )
+            cat_atts.append(
+                torch.cat([
+                    input_atts[i][:input_len],
+                    output_atts[i],
+                    input_atts[i][input_len:]
+                ])
+            )
+        cat_embs = torch.stack(cat_embs)
+        cat_atts = torch.stack(cat_atts)
+        return cat_embs, cat_atts, input_lens
+
     def forward(self, samples):
         image = samples["image"]
         img_embeds, atts_img = self.encode_img(image)
-        if hasattr(samples, 'question_split'):  # VQA dataset
-            print('VQA Batch')
-            vqa_prompt = '###Human: <Img><ImageHere></Img> '
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
-        elif self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
+
+        if self.prompt_list:
+            instruction = random.choice(self.prompt_list)
+        else:
+            instruction = samples["instruction_input"] if "instruction_input" in samples else None
+
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)

         self.llama_tokenizer.padding_side = "right"
-        text = [t + self.end_sym for t in samples["text_input"]]
+        text = [t + self.end_sym for t in samples["answer"]]

         to_regress_tokens = self.llama_tokenizer(
             text,
@@ -187,26 +286,29 @@ class MiniGPT4(Blip2Base):
             add_special_tokens=False
         ).to(image.device)

-        targets = to_regress_tokens.input_ids.masked_fill(
-            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
-        )
-
-        empty_targets = (
-            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
-                       dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
-        )
-        targets = torch.cat([empty_targets, targets], dim=1)
-
         batch_size = img_embeds.shape[0]
         bos = torch.ones([batch_size, 1],
                          dtype=to_regress_tokens.input_ids.dtype,
                          device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
-        bos_embeds = self.llama_model.model.embed_tokens(bos)
+        bos_embeds = self.embed_tokens(bos)
         atts_bos = atts_img[:, :1]

-        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
-        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
-        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
+        inputs_embeds, attention_mask, input_lens = \
+            self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
+        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
+        attention_mask = torch.cat([atts_bos, attention_mask], dim=1)
+
+        part_targets = to_regress_tokens.input_ids.masked_fill(
+            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
+        )
+        targets = (
+            torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
+                       dtype=torch.long).to(image.device).fill_(-100)
+        )
+
+        for i, target in enumerate(part_targets):
+            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

         with self.maybe_autocast():
             outputs = self.llama_model(
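The label layout built above can be seen with a tiny hedged numeric sketch (made-up shapes, not the repo's tensors): every position is ignored with -100 except the answer tokens, offset by one for the BOS embedding.

```python
import torch

# Hedged toy example of the new label construction: ignore (-100) everything
# except the answer tokens, which start right after <bos> + prompt/image tokens.
input_lens = [5, 3]                                   # per-sample prompt+image lengths
part_targets = [torch.tensor([11, 12, 13]),           # answer token ids, sample 0
                torch.tensor([21, 22])]               # answer token ids, sample 1
seq_len = 1 + max(input_lens) + 3                     # bos + longest prompt + longest answer

targets = torch.full((2, seq_len), -100, dtype=torch.long)
for i, target in enumerate(part_targets):
    targets[i, input_lens[i] + 1:input_lens[i] + 1 + len(target)] = target  # +1 for bos
print(targets)
# tensor([[-100, -100, -100, -100, -100, -100,   11,   12,   13],
#         [-100, -100, -100, -100,   21,   22, -100, -100, -100]])
```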
@@ -219,6 +321,13 @@ class MiniGPT4(Blip2Base):

         return {"loss": loss}

+    def embed_tokens(self, token_ids):
+        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
+            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
+        else:
+            embeds = self.llama_model.base_model.embed_tokens(token_ids)
+        return embeds
+
     @classmethod
     def from_config(cls, cfg):
         vit_model = cfg.get("vit_model", "eva_clip_g")
@@ -231,6 +340,7 @@ class MiniGPT4(Blip2Base):
         use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
+        has_qformer = cfg.get("has_qformer", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
         low_resource = cfg.get("low_resource", False)
         device_8bit = cfg.get("device_8bit", 0)
@@ -240,6 +350,9 @@ class MiniGPT4(Blip2Base):
         max_txt_len = cfg.get("max_txt_len", 32)
         end_sym = cfg.get("end_sym", '\n')

+        lora_r = cfg.get("lora_r", 0)
+        lora_alpha = cfg.get("lora_alpha", 32)
+
         model = cls(
             vit_model=vit_model,
             q_former_model=q_former_model,
@@ -248,6 +361,7 @@ class MiniGPT4(Blip2Base):
             use_grad_checkpoint=use_grad_checkpoint,
             vit_precision=vit_precision,
             freeze_vit=freeze_vit,
+            has_qformer=has_qformer,
             freeze_qformer=freeze_qformer,
             num_query_token=num_query_token,
             llama_model=llama_model,
@@ -257,6 +371,8 @@ class MiniGPT4(Blip2Base):
             end_sym=end_sym,
             low_resource=low_resource,
             device_8bit=device_8bit,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
         )

         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
train_configs/minigpt4_llama2_stage1_pretrain.yaml (new file, 55 lines)

@@ -0,0 +1,55 @@
+model:
+  arch: mini_gpt4
+  model_type: pretrain_llama2
+
+
+datasets:
+  laion:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+    sample_ratio: 115
+  cc_sbu:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+    sample_ratio: 14
+
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 1e-4
+  min_lr: 8e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 4
+  batch_size_train: 64
+  batch_size_eval: 64
+  num_workers: 4
+  warmup_steps: 5000
+  iters_per_epoch: 5000
+
+  seed: 42
+  output_dir: "output/minigpt4_stage1_pretrain"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
train_configs/minigpt4_llama2_stage2_finetune.yaml (new file, 50 lines)

@@ -0,0 +1,50 @@
+model:
+  arch: mini_gpt4
+  model_type: pretrain_llama2
+
+  max_txt_len: 160
+  end_sym: "</s>"
+  prompt_path: "prompts/alignment.txt"
+  prompt_template: '[INST] {} [/INST] '
+  ckpt: '/path/to/stage1/checkpoint/'
+
+
+datasets:
+  cc_sbu_align:
+    vis_processor:
+      train:
+        name: "blip2_image_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "blip_caption"
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 3e-5
+  min_lr: 1e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 5
+  iters_per_epoch: 200
+  batch_size_train: 12
+  batch_size_eval: 12
+  num_workers: 4
+  warmup_steps: 200
+
+  seed: 42
+  output_dir: "output/minigpt4_stage2_finetune"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
train_configs/minigpt4_stage1_pretrain.yaml

@@ -1,8 +1,6 @@
 model:
   arch: mini_gpt4
-  model_type: pretrain_vicuna
-  freeze_vit: True
-  freeze_qformer: True
+  model_type: pretrain_vicuna0


 datasets:
train_configs/minigpt4_stage2_finetune.yaml

@@ -1,8 +1,7 @@
 model:
   arch: mini_gpt4
-  model_type: pretrain_vicuna
-  freeze_vit: True
-  freeze_qformer: True
+  model_type: pretrain_vicuna0
   max_txt_len: 160
   end_sym: "###"
   prompt_path: "prompts/alignment.txt"