diff --git a/MiniGPT4_Train.md b/MiniGPT4_Train.md new file mode 100644 index 0000000..f9e8a5c --- /dev/null +++ b/MiniGPT4_Train.md @@ -0,0 +1,41 @@ +## Training of MiniGPT-4 + +The training of MiniGPT-4 contains two alignment stages. + +**1. First pretraining stage** + +In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets +to align the vision and language model. To download and prepare the datasets, please check +our [first stage dataset preparation instruction](dataset/README_1_STAGE.md). +After the first stage, the visual features are mapped and can be understood by the language +model. +To launch the first stage training, run the following command. In our experiments, we use 4 A100. +You can change the save path in the config file +[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml) + +```bash +torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml +``` + +A MiniGPT-4 checkpoint with only stage one training can be downloaded +[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link). +Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently. + + +**2. Second finetuning stage** + +In the second stage, we use a small high quality image-text pair dataset created by ourselves +and convert it to a conversation format to further align MiniGPT-4. +To download and prepare our second stage dataset, please check our +[second stage dataset preparation instruction](dataset/README_2_STAGE.md). +To launch the second stage alignment, +first specify the path to the checkpoint file trained in stage 1 in +[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml). +You can also specify the output path there. +Then, run the following command. In our experiments, we use 1 A100. + +```bash +torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml +``` + +After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly. diff --git a/MiniGPT_4.pdf b/MiniGPT_4.pdf deleted file mode 100644 index 5450815..0000000 Binary files a/MiniGPT_4.pdf and /dev/null differ diff --git a/MiniGPTv2.pdf b/MiniGPTv2.pdf new file mode 100644 index 0000000..04de5e8 Binary files /dev/null and b/MiniGPTv2.pdf differ diff --git a/PrepareVicuna.md b/PrepareVicuna.md deleted file mode 100644 index 0585e62..0000000 --- a/PrepareVicuna.md +++ /dev/null @@ -1,35 +0,0 @@ -## How to Prepare Vicuna Weight -Vicuna is an open-source LLAMA-based LLM that has a performance close to ChatGPT. -We currently use the v0 version of Vicuna-13B. - -To prepare Vicuna’s weight, first download Vicuna’s **delta** weight from [https://huggingface.co/lmsys/vicuna-13b-delta-v0](https://huggingface.co/lmsys/vicuna-13b-delta-v0). -In case you have git-lfs installed (https://git-lfs.com), this can be done by - -``` -git lfs install -git clone https://huggingface.co/lmsys/vicuna-13b-delta-v0 # more powerful, need at least 24G gpu memory -# or -git clone https://huggingface.co/lmsys/vicuna-7b-delta-v0 # smaller, need 12G gpu memory -``` - -Note that this is not directly the working weight, but the difference between the working weight and the original weight of LLAMA-13B. 
(Due to LLAMA’s rules, we cannot distribute the weight of LLAMA.) - -Then, you need to obtain the original LLAMA-7B or LLAMA-13B weights in the HuggingFace format -either following the instruction provided by HuggingFace -[here](https://huggingface.co/docs/transformers/main/model_doc/llama) or from the Internet. - -When these two weights are ready, we can use tools from Vicuna’s team to create the real working weight. -First, Install their library that is compatible with v0 Vicuna by - -``` -pip install git+https://github.com/lm-sys/FastChat.git@v0.1.10 -``` - -Then, run the following command to create the final working weight - -``` -python -m fastchat.model.apply_delta --base /path/to/llama-13bOR7b-hf/ --target /path/to/save/working/vicuna/weight/ --delta /path/to/vicuna-13bOR7b-delta-v0/ -``` - -Now you are good to go! - diff --git a/README.md b/README.md index 02bc504..8d85ecc 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,48 @@ -# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models -[Deyao Zhu](https://tsutikgiau.github.io/)* , [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution +# MiniGPT-V -**King Abdullah University of Science and Technology** +**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning** + +Jun Chen, Deyao Zhu, Xiaoqian Shen, Xiang Li, Zechun Liu, Pengchuan Zhang, Raghuraman Krishnamoorthi, Vikas Chandra, Yunyang Xiong☨, Mohamed Elhoseiny☨ + +☨equal last author + + [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=atFCwV2hSY4) + + +**MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models** + +Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny + +*equal contribution [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) +*King Abdullah University of Science and Technology* + ## 💡 Get help - [Q&A](https://github.com/Vision-CAIR/MiniGPT-4/discussions/categories/q-a) or [Discord 💬](https://discord.gg/5WdJkjbAeE) ## News -We now provide a llama 2 version of MiniGPT-4 +[Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2 +[Aug.28 2023] We now provide a llama 2 version of MiniGPT-4 ## Online Demo +Click the image to chat with MiniGPT-v2 around your images +[![demo](figs/minigpt2_demo.png)](https://minigpt-v2.github.io/) + Click the image to chat with MiniGPT-4 around your images [![demo](figs/online_demo.png)](https://minigpt-4.github.io) -## Examples +## MiniGPT-v2 Examples + +![MiniGPT-v2 demos](figs/demo.png) + + + +## MiniGPT-4 Examples | | | :-------------------------:|:-------------------------: ![find wild](figs/examples/wop_2.png) | ![write story](figs/examples/ad_2.png) @@ -28,17 +52,6 @@ More examples can be found in the [project page](https://minigpt-4.github.io). -## Introduction -- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer. -- We train MiniGPT-4 with two stages. The first traditional pretraining stage is trained using roughly 5 million aligned image-text pairs in 10 hours using 4 A100s. 
After the first stage, Vicuna is able to understand the image. But the generation ability of Vicuna is heavily impacted. -- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset. -- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100. -- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4. - - -![overview](figs/overview.png) - - ## Getting Started ### Installation @@ -56,42 +69,62 @@ conda activate minigpt4 **2. Prepare the pretrained LLM weights** -Currently, we provide both Vicuna V0 and Llama 2 version of MiniGPT-4. +**MiniGPT-v2** is based on Llama2 Chat 7B. For **MiniGPT-4**, we have both Vicuna V0 and Llama 2 version. Download the corresponding LLM weights from the following huggingface space via clone the repository using git-lfs. -| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B | +| Llama 2 Chat 7B | Vicuna V0 13B | Vicuna V0 7B | :------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------: - [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) +[Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) | [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) -Then, set the path to the vicuna weight in the model config file -[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18 -and/or the path to the llama2 weight in the model config file +Then, set the variable *llama_model* in the model config file to the LLM weight path. + +* For MiniGPT-v2, set the LLM path +[here](minigpt4/configs/models/minigpt_v2.yaml#L15) at Line 14. + +* For MiniGPT-4 (Llama2), set the LLM path [here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15. -**3. Prepare the pretrained MiniGPT-4 checkpoint** +* For MiniGPT-4 (Vicuna), set the LLM path +[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18 -Download the pretrained checkpoints according to the Vicuna model you prepare. +**3. 
Prepare the pretrained model checkpoints** -| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B | -:------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------: - [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) +Download the pretrained model checkpoints -Then, set the path to the pretrained checkpoint in the evaluation config file +| MiniGPT-v2 (LLaMA-2 Chat 7B) | +|------------------------------| +| [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) | + +For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file +in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8. + + + +| MiniGPT-4 (Vicuna 13B) | MiniGPT-4 (Vicuna 7B) | MiniGPT-4 (LLaMA-2 Chat 7B) | +|----------------------------|---------------------------|---------------------------------| +| [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) | + +For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version. ### Launching Demo Locally -Try out our demo [demo.py](demo.py) for the vicuna version on your local machine by running +For MiniGPT-v2, run +``` +python demo_v2.py --cfg-path eval_configs/minigpt4v2_eval.yaml --gpu-id 0 +``` + +For MiniGPT-4 (Vicuna version), run ``` python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0 ``` -or for Llama 2 version by +For MiniGPT-4 (Llama2 version), run ``` python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0 @@ -101,52 +134,17 @@ python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0 To save GPU memory, LLMs loads as 8 bit by default, with a beam search width of 1. This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memory for 7B LLM. For more powerful GPUs, you can run the model -in 16 bit by setting `low_resource` to `False` in the relevant config file -(line 6 of either [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6) if using Vicuna or [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6) if using Llama 2) and use a larger beam search width. 
+in 16 bit by setting `low_resource` to `False` in the relevant config file: -Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) +* MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#6) +* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6) +* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6) + +Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) ### Training -The training of MiniGPT-4 contains two alignment stages. - -**1. First pretraining stage** - -In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets -to align the vision and language model. To download and prepare the datasets, please check -our [first stage dataset preparation instruction](dataset/README_1_STAGE.md). -After the first stage, the visual features are mapped and can be understood by the language -model. -To launch the first stage training, run the following command. In our experiments, we use 4 A100. -You can change the save path in the config file -[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml) - -```bash -torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml -``` - -A MiniGPT-4 checkpoint with only stage one training can be downloaded -[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link). -Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently. - - -**2. Second finetuning stage** - -In the second stage, we use a small high quality image-text pair dataset created by ourselves -and convert it to a conversation format to further align MiniGPT-4. -To download and prepare our second stage dataset, please check our -[second stage dataset preparation instruction](dataset/README_2_STAGE.md). -To launch the second stage alignment, -first specify the path to the checkpoint file trained in stage 1 in -[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml). -You can also specify the output path there. -Then, run the following command. In our experiments, we use 1 A100. - -```bash -torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml -``` - -After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly. +For training details of MiniGPT-4, check [here](MiniGPT4_Train.md). @@ -156,10 +154,19 @@ After the second stage alignment, MiniGPT-4 is able to talk about the image cohe + [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check this great open-source work if you don't know it before! + [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis! + [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source! ++ [LLaMA](https://github.com/facebookresearch/llama) The strong open-sourced LLaMA 2 language model. 
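
Referring back to the 8-bit/16-bit note in the demo section above, here is a minimal sketch of the config keys involved. The key names follow `eval_configs/minigptv2_eval.yaml` added in this patch; the checkpoint path is a placeholder to fill in yourself.

```yaml
# Sketch only: adjust eval_configs/minigptv2_eval.yaml (or the MiniGPT-4 eval configs) like this.
model:
  arch: minigpt_v2
  model_type: pretrain
  low_resource: False   # False loads the LLM in 16 bit; True (the shipped default) loads it in 8 bit
  ckpt: '/path/to/pretrained_checkpoint.pth'   # placeholder path
```
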
-If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX: +If you're using MiniGPT-4/MiniGPT-v2 in your research or applications, please cite using this BibTeX: ```bibtex + +@article{Chen2023minigpt, + title={MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning}, + author={Chen, Jun and Zhu, Deyao and Shen, Xiaoqian and Li, Xiang and Liu, Zechu and Zhang, Pengchuan and Krishnamoorthi, Raghuraman and Chandra, Vikas and Xiong, Yunyang and Elhoseiny, Mohamed}, + journal={github}, + year={2023} +} + @article{zhu2023minigpt, title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models}, author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed}, diff --git a/demo.py b/demo.py index 483b56c..c7646c4 100644 --- a/demo.py +++ b/demo.py @@ -7,10 +7,12 @@ import torch import torch.backends.cudnn as cudnn import gradio as gr +from transformers import StoppingCriteriaList + from minigpt4.common.config import Config from minigpt4.common.dist_utils import get_rank from minigpt4.common.registry import registry -from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2 +from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub # imports modules for registration from minigpt4.datasets.builders import * @@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type] vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) -chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) + +stop_words_ids = [[835], [2277, 29937]] +stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids] +stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + +chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria) print('Initialization Finished') @@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state): chat_state = CONV_VISION.copy() img_list = [] llm_message = chat.upload_img(gr_img, chat_state, img_list) + chat.encode_img(img_list) return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list @@ -124,7 +132,7 @@ with gr.Blocks() as demo: gr.Markdown(article) with gr.Row(): - with gr.Column(scale=0.5): + with gr.Column(scale=1): image = gr.Image(type="pil") upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") clear = gr.Button("Restart") @@ -147,7 +155,7 @@ with gr.Blocks() as demo: label="Temperature", ) - with gr.Column(): + with gr.Column(scale=2): chat_state = gr.State() img_list = gr.State() chatbot = gr.Chatbot(label='MiniGPT-4') diff --git a/demo_v2.py b/demo_v2.py new file mode 100644 index 0000000..4f66d53 --- /dev/null +++ b/demo_v2.py @@ -0,0 +1,662 @@ +import argparse +import os +import random +from collections import defaultdict + +import cv2 +import re + +import numpy as np +from PIL import Image +import torch +import html +import gradio as gr + +import torchvision.transforms as T +import torch.backends.cudnn as cudnn + +from minigpt4.common.config import Config + +from minigpt4.common.registry import registry +from minigpt4.conversation.conversation import Conversation, 
SeparatorStyle, Chat + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml', + help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +random.seed(42) +np.random.seed(42) +torch.manual_seed(42) + +cudnn.benchmark = False +cudnn.deterministic = True + +print('Initializing Chat') +args = parse_args() +cfg = Config(args) + +device = 'cuda:{}'.format(args.gpu_id) + +model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to(device) +bounding_box_size = 100 + +vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + +model = model.eval() + +CONV_VISION = Conversation( + system="", + roles=(r"[INST] ", r" [/INST]"), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="", +) + + +def extract_substrings(string): + # first check if there is no-finished bracket + index = string.rfind('}') + if index != -1: + string = string[:index + 1] + + pattern = r'
<p>
(.*?)\}(?!<)' + matches = re.findall(pattern, string) + substrings = [match for match in matches] + + return substrings + + +def is_overlapping(rect1, rect2): + x1, y1, x2, y2 = rect1 + x3, y3, x4, y4 = rect2 + return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4) + + +def computeIoU(bbox1, bbox2): + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + intersection_x1 = max(x1, x3) + intersection_y1 = max(y1, y3) + intersection_x2 = min(x2, x4) + intersection_y2 = min(y2, y4) + intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) + bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) + bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) + union_area = bbox1_area + bbox2_area - intersection_area + iou = intersection_area / union_area + return iou + + +def save_tmp_img(visual_img): + file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg" + file_path = "/tmp/" + file_name + visual_img.save(file_path) + return file_path + + +def mask2bbox(mask): + if mask is None: + return '' + mask = mask.resize([100, 100], resample=Image.NEAREST) + mask = np.array(mask)[:, :, 0] + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + if rows.sum(): + # Get the top, bottom, left, and right boundaries + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax) + else: + bbox = '' + + return bbox + + +def escape_markdown(text): + # List of Markdown special characters that need to be escaped + md_chars = ['<', '>'] + + # Escape each special character + for char in md_chars: + text = text.replace(char, '\\' + char) + + return text + + +def reverse_escape(text): + md_chars = ['\\<', '\\>'] + + for char in md_chars: + text = text.replace(char, char[1:]) + + return text + + +colors = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (210, 210, 0), + (255, 0, 255), + (0, 255, 255), + (114, 128, 250), + (0, 165, 255), + (0, 128, 0), + (144, 238, 144), + (238, 238, 175), + (255, 191, 0), + (0, 128, 0), + (226, 43, 138), + (255, 0, 255), + (0, 215, 255), +] + +color_map = { + f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for + color_id, color in enumerate(colors) +} + +used_colors = colors + + +def visualize_all_bbox_together(image, generation): + if image is None: + return None, '' + + generation = html.unescape(generation) + print('gen begin', generation) + + image_width, image_height = image.size + image = image.resize([500, int(500 / image_width * image_height)]) + image_width, image_height = image.size + + string_list = extract_substrings(generation) + if string_list: # it is grounding or detection + mode = 'all' + entities = defaultdict(list) + i = 0 + j = 0 + for string in string_list: + try: + obj, string = string.split('
</p>
') + except ValueError: + print('wrong string: ', string) + continue + bbox_list = string.split('') + flag = False + for bbox_string in bbox_list: + integers = re.findall(r'-?\d+', bbox_string) + if len(integers) == 4: + x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3]) + left = x0 / bounding_box_size * image_width + bottom = y0 / bounding_box_size * image_height + right = x1 / bounding_box_size * image_width + top = y1 / bounding_box_size * image_height + + entities[obj].append([left, bottom, right, top]) + + j += 1 + flag = True + if flag: + i += 1 + else: + integers = re.findall(r'-?\d+', generation) + + if len(integers) == 4: # it is refer + mode = 'single' + + entities = list() + x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3]) + left = x0 / bounding_box_size * image_width + bottom = y0 / bounding_box_size * image_height + right = x1 / bounding_box_size * image_width + top = y1 / bounding_box_size * image_height + entities.append([left, bottom, right, top]) + else: + # don't detect any valid bbox to visualize + return None, '' + + if len(entities) == 0: + return None, '' + + if isinstance(image, Image.Image): + image_h = image.height + image_w = image.width + image = np.array(image) + + elif isinstance(image, str): + if os.path.exists(image): + pil_img = Image.open(image).convert("RGB") + image = np.array(pil_img)[:, :, [2, 1, 0]] + image_h = pil_img.height + image_w = pil_img.width + else: + raise ValueError(f"invaild image path, {image}") + elif isinstance(image, torch.Tensor): + + image_tensor = image.cpu() + reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None] + reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None] + image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean + pil_img = T.ToPILImage()(image_tensor) + image_h = pil_img.height + image_w = pil_img.width + image = np.array(pil_img)[:, :, [2, 1, 0]] + else: + raise ValueError(f"invaild image format, {type(image)} for {image}") + + indices = list(range(len(entities))) + + new_image = image.copy() + + previous_bboxes = [] + # size of text + text_size = 0.5 + # thickness of text + text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1)) + box_line = 2 + (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line) + base_height = int(text_height * 0.675) + text_offset_original = text_height - base_height + text_spaces = 2 + + # num_bboxes = sum(len(x[-1]) for x in entities) + used_colors = colors # random.sample(colors, k=num_bboxes) + + color_id = -1 + for entity_idx, entity_name in enumerate(entities): + if mode == 'single' or mode == 'identify': + bboxes = entity_name + bboxes = [bboxes] + else: + bboxes = entities[entity_name] + color_id += 1 + for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes): + skip_flag = False + orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm) + + color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist()) + new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line) + + if mode == 'all': + l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1 + + x1 = orig_x1 - l_o + y1 = orig_y1 - l_o + + if y1 < text_height + text_offset_original + 2 * text_spaces: + y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces + x1 = orig_x1 + r_o + + # add 
text background + (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, + text_line) + text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - ( + text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1 + + for prev_bbox in previous_bboxes: + if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \ + prev_bbox['phrase'] == entity_name: + skip_flag = True + break + while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']): + text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces) + text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces) + y1 += (text_height + text_offset_original + 2 * text_spaces) + + if text_bg_y2 >= image_h: + text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces)) + text_bg_y2 = image_h + y1 = image_h + break + if not skip_flag: + alpha = 0.5 + for i in range(text_bg_y1, text_bg_y2): + for j in range(text_bg_x1, text_bg_x2): + if i < image_h and j < image_w: + if j < text_bg_x1 + 1.35 * c_width: + # original color + bg_color = color + else: + # white + bg_color = [255, 255, 255] + new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype( + np.uint8) + + cv2.putText( + new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), + cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA + ) + + previous_bboxes.append( + {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name}) + + if mode == 'all': + def color_iterator(colors): + while True: + for color in colors: + yield color + + color_gen = color_iterator(colors) + + # Add colors to phrases and remove
<p></p>
+        def colored_phrases(match):
+            phrase = match.group(1)
+            color = next(color_gen)
+            return f'<span style="color:rgb{color}">{phrase}</span>'
+
+        print('gen before', generation)
+        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
+        print('gen after', generation)
+        generation_colored = re.sub(r'
<p>(.*?)</p>
', colored_phrases, generation) + else: + generation_colored = '' + + pil_image = Image.fromarray(new_image) + return pil_image, generation_colored + + +def gradio_reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat', + interactive=True), chat_state, img_list + + +def image_upload_trigger(upload_flag, replace_flag, img_list): + # set the upload flag to true when receive a new image. + # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. + print('flag', upload_flag, replace_flag) + print("SET UPLOAD FLAG!") + upload_flag = 1 + if img_list: + print("SET REPLACE FLAG!") + replace_flag = 1 + print('flag', upload_flag, replace_flag) + return upload_flag, replace_flag + + +def example_trigger(text_input, image, upload_flag, replace_flag, img_list): + # set the upload flag to true when receive a new image. + # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. + print('flag', upload_flag, replace_flag) + print("SET UPLOAD FLAG!") + upload_flag = 1 + if img_list or replace_flag == 1: + print("SET REPLACE FLAG!") + replace_flag = 1 + + print('flag', upload_flag, replace_flag) + return upload_flag, replace_flag + + +def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag): + if isinstance(gr_img, dict): + gr_img, mask = gr_img['image'], gr_img['mask'] + else: + mask = None + + if '[identify]' in user_message: + # check if user provide bbox in the text input + integers = re.findall(r'-?\d+', user_message) + if len(integers) != 4: # no bbox in text + bbox = mask2bbox(mask) + user_message = user_message + bbox + + if len(user_message) == 0: + return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state + + if chat_state is None: + chat_state = CONV_VISION.copy() + + print('upload flag: {}'.format(upload_flag)) + if upload_flag: + if replace_flag: + print('RESET!!!!!!!') + chat_state = CONV_VISION.copy() # new image, reset everything + replace_flag = 0 + chatbot = [] + print('UPLOAD IMAGE!!') + img_list = [] + llm_message = chat.upload_img(gr_img, chat_state, img_list) + upload_flag = 0 + + chat.ask(user_message, chat_state) + + chatbot = chatbot + [[user_message, None]] + + if '[identify]' in user_message: + visual_img, _ = visualize_all_bbox_together(gr_img, user_message) + if visual_img is not None: + print('Visualizing the input') + file_path = save_tmp_img(visual_img) + chatbot = chatbot + [[(file_path,), None]] + + return '', chatbot, chat_state, img_list, upload_flag, replace_flag + + +def gradio_answer(chatbot, chat_state, img_list, temperature): + llm_message = chat.answer(conv=chat_state, + img_list=img_list, + temperature=temperature, + max_new_tokens=500, + max_length=2000)[0] + chatbot[-1][1] = llm_message + return chatbot, chat_state + + +def gradio_stream_answer(chatbot, chat_state, img_list, temperature): + print('chat state', chat_state.get_prompt()) + if not isinstance(img_list[0], torch.Tensor): + chat.encode_img(img_list) + streamer = chat.stream_answer(conv=chat_state, + img_list=img_list, + temperature=temperature, + max_new_tokens=500, + max_length=2000) + output = '' + for new_output in streamer: + escapped = escape_markdown(new_output) + output += escapped + chatbot[-1][1] = output + yield chatbot, chat_state + # 
print('message: ', chat_state.messages) + chat_state.messages[-1][1] = '
' + return chatbot, chat_state + + +def gradio_visualize(chatbot, gr_img): + if isinstance(gr_img, dict): + gr_img, mask = gr_img['image'], gr_img['mask'] + + unescaped = reverse_escape(chatbot[-1][1]) + visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped) + if visual_img is not None: + print('Visualizing the output') + if len(generation_color): + chatbot[-1][1] = generation_color + file_path = save_tmp_img(visual_img) + chatbot = chatbot + [[None, (file_path,)]] + + return chatbot + + +def gradio_taskselect(idx): + prompt_list = [ + '', + '[grounding] describe this image in detail', + '[refer] ', + '[detection] ', + '[identify] what is this ', + '[vqa] ' + ] + instruct_list = [ + '**Hint:** Type in whatever you want', + '**Hint:** Send the command to generate a grounded image description', + '**Hint:** Type in a phrase about an object in the image and send the command', + '**Hint:** Type in a caption or phrase, and see object locations in the image', + '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" botton on the top right of the image before redraw', + '**Hint:** Send a question to get a short answer', + ] + return prompt_list[idx], instruct_list[idx] + + + + +chat = Chat(model, vis_processor, device=device) + +title = """
<h1 align="center">MiniGPT-v2 Demo</h1>
""" +description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!' +# article = """

""" +article = """

""" + +introduction = ''' +For Abilities Involving Visual Grounding: +1. Grounding: CLICK **Send** to generate a grounded image description. +2. Refer: Input a referring object and CLICK **Send**. +3. Detection: Write a caption or phrase, and CLICK **Send**. +4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time). +5. VQA: Input a visual question and CLICK **Send**. +6. No Tag: Input whatever you want and CLICK **Send** without any tagging + +You can also simply chat in free form! +''' + +text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False, + scale=8) +with gr.Blocks() as demo: + gr.Markdown(title) + # gr.Markdown(description) + gr.Markdown(article) + + with gr.Row(): + with gr.Column(scale=0.5): + image = gr.Image(type="pil", tool='sketch', brush_radius=20) + + temperature = gr.Slider( + minimum=0.1, + maximum=2.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + ) + + clear = gr.Button("Restart") + + gr.Markdown(introduction) + + with gr.Column(): + chat_state = gr.State(value=None) + img_list = gr.State(value=[]) + chatbot = gr.Chatbot(label='MiniGPT-v2') + + dataset = gr.Dataset( + components=[gr.Textbox(visible=False)], + samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']], + type="index", + label='Task Shortcuts', + ) + task_inst = gr.Markdown('**Hint:** Upload your image and chat') + with gr.Row(): + text_input.render() + send = gr.Button("Send", variant='primary', size='sm', scale=1) + + upload_flag = gr.State(value=0) + replace_flag = gr.State(value=0) + image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag]) + + with gr.Row(): + with gr.Column(): + gr.Examples(examples=[ + ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag, + img_list], + ["examples_v2/sofa.jpg", "[detection] sofas", upload_flag, replace_flag, img_list], + ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag, + img_list], + ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag, + replace_flag, img_list], + ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger, + outputs=[upload_flag, replace_flag]) + with gr.Column(): + gr.Examples(examples=[ + ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek", + upload_flag, replace_flag, img_list], + ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list], + ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list], + ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag, + replace_flag, img_list], + ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger, + outputs=[upload_flag, replace_flag]) + + dataset.click( + gradio_taskselect, + inputs=[dataset], + outputs=[text_input, task_inst], + show_progress="hidden", + postprocess=False, + queue=False, + ) + + text_input.submit( + gradio_ask, + [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], + [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False + ).success( + gradio_stream_answer, + [chatbot, chat_state, img_list, temperature], + 
[chatbot, chat_state] + ).success( + gradio_visualize, + [chatbot, image], + [chatbot], + queue=False, + ) + + send.click( + gradio_ask, + [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], + [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False + ).success( + gradio_stream_answer, + [chatbot, chat_state, img_list, temperature], + [chatbot, chat_state] + ).success( + gradio_visualize, + [chatbot, image], + [chatbot], + queue=False, + ) + + clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False) + +demo.launch(share=True, enable_queue=True) diff --git a/environment.yml b/environment.yml index d5cfcf8..cf90e89 100644 --- a/environment.yml +++ b/environment.yml @@ -7,57 +7,27 @@ dependencies: - python=3.9 - cudatoolkit - pip - - pytorch=1.12.1 - - pytorch-mutex=1.0=cuda - - torchaudio=0.12.1 - - torchvision=0.13.1 - pip: - - accelerate==0.16.0 - - aiohttp==3.8.4 - - aiosignal==1.3.1 - - async-timeout==4.0.2 - - attrs==22.2.0 - - bitsandbytes==0.37.0 - - cchardet==2.1.7 - - chardet==5.1.0 - - contourpy==1.0.7 - - cycler==0.11.0 - - filelock==3.9.0 - - fonttools==4.38.0 - - frozenlist==1.3.3 - - huggingface-hub==0.13.4 - - importlib-resources==5.12.0 - - kiwisolver==1.4.4 + - torch==2.0.0 + - torchaudio + - torchvision + - huggingface-hub==0.18.0 - matplotlib==3.7.0 - - multidict==6.0.4 - - openai==0.27.0 - - packaging==23.0 - psutil==5.9.4 - - pycocotools==2.0.6 - - pyparsing==3.0.9 - - python-dateutil==2.8.2 + - iopath - pyyaml==6.0 - regex==2022.10.31 - tokenizers==0.13.2 - tqdm==4.64.1 - - transformers==4.28.0 + - transformers==4.30.0 - timm==0.6.13 - - spacy==3.5.1 - webdataset==0.2.48 - - scikit-learn==1.2.2 - - scipy==1.10.1 - - yarl==1.8.2 - - zipp==3.14.0 - omegaconf==2.3.0 - opencv-python==4.7.0.72 - - iopath==0.1.10 - decord==0.6.0 - - tenacity==8.2.2 - - peft - - pycocoevalcap + - peft==0.2.0 - sentence-transformers - - umap-learn - - notebook - - gradio==3.24.1 - - gradio-client==0.0.8 + - gradio==3.47.1 + - accelerate==0.20.3 + - bitsandbytes==0.37.0 - wandb diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml index b653eb7..d73ba3a 100644 --- a/eval_configs/minigpt4_eval.yaml +++ b/eval_configs/minigpt4_eval.yaml @@ -1,11 +1,11 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_vicuna0 max_txt_len: 160 end_sym: "###" low_resource: True prompt_template: '###Human: {} ###Assistant: ' - ckpt: '/path/to/checkpoint/' + ckpt: 'please set this value to the path of pretrained checkpoint' datasets: diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml index eea99d3..93efab1 100644 --- a/eval_configs/minigpt4_llama2_eval.yaml +++ b/eval_configs/minigpt4_llama2_eval.yaml @@ -1,11 +1,11 @@ model: - arch: mini_gpt4 + arch: minigpt4 model_type: pretrain_llama2 max_txt_len: 160 end_sym: "" low_resource: True prompt_template: '[INST] {} [/INST] ' - ckpt: '/path/to/checkpoint/' + ckpt: 'please set this value to the path of pretrained checkpoint' datasets: diff --git a/eval_configs/minigptv2_eval.yaml b/eval_configs/minigptv2_eval.yaml new file mode 100644 index 0000000..0479f2a --- /dev/null +++ b/eval_configs/minigptv2_eval.yaml @@ -0,0 +1,24 @@ +model: + arch: minigpt_v2 + model_type: pretrain + max_txt_len: 160 + end_sym: "" + low_resource: True + prompt_template: '[INST] {} [/INST]' + ckpt: 'please set this value to the path of pretrained checkpoint' + lora_r: 64 + lora_alpha: 16 + + +datasets: + 
cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/examples_v2/2000x1372_wmkn_0012149409555.jpg b/examples_v2/2000x1372_wmkn_0012149409555.jpg new file mode 100755 index 0000000..1250f7f Binary files /dev/null and b/examples_v2/2000x1372_wmkn_0012149409555.jpg differ diff --git a/examples_v2/KFC-20-for-20-Nuggets.jpg b/examples_v2/KFC-20-for-20-Nuggets.jpg new file mode 100755 index 0000000..0ec641c Binary files /dev/null and b/examples_v2/KFC-20-for-20-Nuggets.jpg differ diff --git a/examples_v2/cockdial.png b/examples_v2/cockdial.png new file mode 100755 index 0000000..935f98e Binary files /dev/null and b/examples_v2/cockdial.png differ diff --git a/examples_v2/float.png b/examples_v2/float.png new file mode 100755 index 0000000..900dcb0 Binary files /dev/null and b/examples_v2/float.png differ diff --git a/examples_v2/glip_test.jpg b/examples_v2/glip_test.jpg new file mode 100755 index 0000000..f9198f2 Binary files /dev/null and b/examples_v2/glip_test.jpg differ diff --git a/examples_v2/office.jpg b/examples_v2/office.jpg new file mode 100755 index 0000000..e35bdc2 Binary files /dev/null and b/examples_v2/office.jpg differ diff --git a/examples_v2/sofa.jpg b/examples_v2/sofa.jpg new file mode 100755 index 0000000..8610591 Binary files /dev/null and b/examples_v2/sofa.jpg differ diff --git a/examples_v2/thief.png b/examples_v2/thief.png new file mode 100755 index 0000000..579ee52 Binary files /dev/null and b/examples_v2/thief.png differ diff --git a/figs/demo.png b/figs/demo.png new file mode 100644 index 0000000..2e573de Binary files /dev/null and b/figs/demo.png differ diff --git a/minigpt4/configs/models/minigpt4_llama2.yaml b/minigpt4/configs/models/minigpt4_llama2.yaml index c201bdc..fdd25e0 100644 --- a/minigpt4/configs/models/minigpt4_llama2.yaml +++ b/minigpt4/configs/models/minigpt4_llama2.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 # vit encoder image_size: 224 @@ -12,7 +12,7 @@ model: # generation configs prompt: "" - llama_model: "/path/to/llama2/weight" + llama_model: "please set this value to the path of llama2-chat-7b" preprocess: vis_processor: diff --git a/minigpt4/configs/models/minigpt4_vicuna0.yaml b/minigpt4/configs/models/minigpt4_vicuna0.yaml index 34bd2ed..718054c 100644 --- a/minigpt4/configs/models/minigpt4_vicuna0.yaml +++ b/minigpt4/configs/models/minigpt4_vicuna0.yaml @@ -1,5 +1,5 @@ model: - arch: mini_gpt4 + arch: minigpt4 # vit encoder image_size: 224 @@ -15,7 +15,7 @@ model: # generation configs prompt: "" - llama_model: "/path/to/vicuna/weight" + llama_model: "please set this value to the path of vicuna model" preprocess: vis_processor: diff --git a/minigpt4/configs/models/minigpt_v2.yaml b/minigpt4/configs/models/minigpt_v2.yaml new file mode 100755 index 0000000..1d85d20 --- /dev/null +++ b/minigpt4/configs/models/minigpt_v2.yaml @@ -0,0 +1,31 @@ +model: + arch: minigpt_v2 + + # vit encoder + image_size: 448 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # generation configs + prompt: "" + + llama_model: "please set this value to the path of llama2-chat-7b" + lora_r: 64 + lora_alpha: 16 + + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + eval: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git 
a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 7678814..9c27c78 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -1,10 +1,11 @@ import argparse import time +from threading import Thread from PIL import Image import torch from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer -from transformers import StoppingCriteria, StoppingCriteriaList +from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer import dataclasses from enum import auto, Enum @@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation( class Chat: - def __init__(self, model, vis_processor, device='cuda:0'): + def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None): self.device = device self.model = model self.vis_processor = vis_processor - stop_words_ids = [torch.tensor([835]).to(self.device), - torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. - self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + + if stopping_criteria is not None: + self.stopping_criteria = stopping_criteria + else: + stop_words_ids = [torch.tensor([2]).to(self.device)] + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) def ask(self, text, conv): if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ @@ -144,8 +148,8 @@ class Chat: else: conv.append_message(conv.roles[0], text) - def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, - repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000): conv.append_message(conv.roles[1], None) embs = self.get_context_emb(conv, img_list) @@ -154,10 +158,9 @@ class Chat: print('Warning: The number of tokens in current conversation exceeds the max length. ' 'The model will not see the contexts outside the range.') begin_idx = max(0, current_max_len - max_length) - embs = embs[:, begin_idx:] - outputs = self.model.llama_model.generate( + generation_kwargs = dict( inputs_embeds=embs, max_new_tokens=max_new_tokens, stopping_criteria=self.stopping_criteria, @@ -169,18 +172,31 @@ class Chat: length_penalty=length_penalty, temperature=temperature, ) - output_token = outputs[0] - if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it - output_token = output_token[1:] - if output_token[0] == 1: # some users find that there is a start token at the beginning. 
remove it - output_token = output_token[1:] - output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) + return generation_kwargs + + def answer(self, conv, img_list, **kargs): + generation_dict = self.answer_prepare(conv, img_list, **kargs) + + output_token = self.model.llama_model.generate(**generation_dict)[0] + output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_text = output_text.split('###')[0] # remove the stop sign '###' output_text = output_text.split('Assistant:')[-1].strip() + conv.messages[-1][1] = output_text return output_text, output_token.cpu().numpy() - def upload_img(self, image, conv, img_list): + def stream_answer(self, conv, img_list, **kargs): + generation_kwargs = self.answer_prepare(conv, img_list, **kargs) + streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True) + generation_kwargs['streamer'] = streamer + thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs) + thread.start() + return streamer + + def encode_img(self, img_list): + image = img_list[0] + img_list.pop(0) if isinstance(image, str): # is a image path raw_image = Image.open(image).convert('RGB') image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) @@ -194,9 +210,12 @@ class Chat: image_emb, _ = self.model.encode_img(image) img_list.append(image_emb) + + def upload_img(self, image, conv, img_list): conv.append_message(conv.roles[0], "") + img_list.append(image) msg = "Received." - # self.conv.append_message(self.conv.roles[1], msg) + return msg def get_context_emb(self, conv, img_list): @@ -209,7 +228,9 @@ class Chat: # only add bos to the first seg for i, seg in enumerate(prompt_segs) ] - seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] + print('debug device: ', self.device) + print('debug model device: ', self.model.device) + seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens] mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] mixed_embs = torch.cat(mixed_embs, dim=1) return mixed_embs diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py index 54acd24..bc01b56 100644 --- a/minigpt4/models/__init__.py +++ b/minigpt4/models/__init__.py @@ -11,16 +11,18 @@ from omegaconf import OmegaConf from minigpt4.common.registry import registry from minigpt4.models.base_model import BaseModel -from minigpt4.models.blip2 import Blip2Base -from minigpt4.models.mini_gpt4 import MiniGPT4 +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.minigpt4 import MiniGPT4 +from minigpt4.models.minigpt_v2 import MiniGPTv2 from minigpt4.processors.base_processor import BaseProcessor __all__ = [ "load_model", "BaseModel", - "Blip2Base", + "MiniGPTBase", "MiniGPT4", + "MiniGPTv2" ] diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py index 2a13393..fd1d636 100644 --- a/minigpt4/models/base_model.py +++ b/minigpt4/models/base_model.py @@ -5,15 +5,26 @@ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ -import logging import os +import logging +import contextlib +from omegaconf import OmegaConf import numpy as np import torch import torch.nn as nn +from transformers import BertTokenizer, LlamaTokenizer +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_int8_training, +) + from 
minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized from minigpt4.common.utils import get_abs_path, is_url -from omegaconf import OmegaConf +from minigpt4.models.eva_vit import create_eva_vit_g + class BaseModel(nn.Module): @@ -117,131 +128,121 @@ class BaseModel(nn.Module): else: return tot + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") -class BaseEncoder(nn.Module): - """ - Base class for primitive encoders, such as ViT, TimeSformer, etc. - """ + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() - def __init__(self): - super().__init__() + @classmethod + def init_vision_encoder( + cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze + ): + logging.info('Loading VIT') - def forward_features(self, samples, **kwargs): - raise NotImplementedError + assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" + if not freeze: + precision = "fp32" # fp16 is not for training - @property - def device(self): - return list(self.parameters())[0].device + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) + + ln_vision = LayerNorm(visual_encoder.num_features) + + if freeze: + for name, param in visual_encoder.named_parameters(): + param.requires_grad = False + visual_encoder = visual_encoder.eval() + visual_encoder.train = disabled_train + for name, param in ln_vision.named_parameters(): + param.requires_grad = False + ln_vision = ln_vision.eval() + ln_vision.train = disabled_train + logging.info("freeze vision encoder") + + logging.info('Loading VIT Done') + return visual_encoder, ln_vision + + def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0, + lora_target_modules=["q_proj","v_proj"], **lora_kargs): + logging.info('Loading LLAMA') + llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) + llama_tokenizer.pad_token = "$$" + + if low_resource: + llama_model = LlamaForCausalLM.from_pretrained( + llama_model_path, + torch_dtype=torch.float16, + load_in_8bit=True, + device_map={'': low_res_device} + ) + else: + llama_model = LlamaForCausalLM.from_pretrained( + llama_model_path, + torch_dtype=torch.float16, + ) + + if lora_r > 0: + llama_model = prepare_model_for_int8_training(llama_model) + loraconfig = LoraConfig( + r=lora_r, + bias="none", + task_type="CAUSAL_LM", + target_modules=lora_target_modules, + **lora_kargs + ) + llama_model = get_peft_model(llama_model, loraconfig) + + llama_model.print_trainable_parameters() + + else: + for name, param in llama_model.named_parameters(): + param.requires_grad = False + logging.info('Loading LLAMA Done') + return llama_model, llama_tokenizer -class SharedQueueMixin: - @torch.no_grad() - def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): - # gather keys before updating queue - image_feats = concat_all_gather(image_feat) - text_feats = concat_all_gather(text_feat) + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise 
RuntimeError("checkpoint url or path is invalid") - batch_size = image_feats.shape[0] + state_dict = checkpoint["model"] - ptr = int(self.queue_ptr) - assert self.queue_size % batch_size == 0 # for simplicity + msg = self.load_state_dict(state_dict, strict=False) - # replace the keys at ptr (dequeue and enqueue) - self.image_queue[:, ptr : ptr + batch_size] = image_feats.T - self.text_queue[:, ptr : ptr + batch_size] = text_feats.T + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) - if idxs is not None: - idxs = concat_all_gather(idxs) - self.idx_queue[:, ptr : ptr + batch_size] = idxs.T - - ptr = (ptr + batch_size) % self.queue_size # move pointer - self.queue_ptr[0] = ptr + return msg -class MomentumDistilationMixin: - @torch.no_grad() - def copy_params(self): - for model_pair in self.model_pairs: - for param, param_m in zip( - model_pair[0].parameters(), model_pair[1].parameters() - ): - param_m.data.copy_(param.data) # initialize - param_m.requires_grad = False # not update by gradient - - @torch.no_grad() - def _momentum_update(self): - for model_pair in self.model_pairs: - for param, param_m in zip( - model_pair[0].parameters(), model_pair[1].parameters() - ): - param_m.data = param_m.data * self.momentum + param.data * ( - 1.0 - self.momentum - ) +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self -class GatherLayer(torch.autograd.Function): - """ - Gather tensors from all workers with support for backward propagation: - This implementation does not cut the gradients as torch.distributed.all_gather does. - """ +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" - @staticmethod - def forward(ctx, x): - output = [ - torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) - ] - torch.distributed.all_gather(output, x) - return tuple(output) - - @staticmethod - def backward(ctx, *grads): - all_gradients = torch.stack(grads) - torch.distributed.all_reduce(all_gradients) - return all_gradients[torch.distributed.get_rank()] + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) -def all_gather_with_grad(tensors): - """ - Performs all_gather operation on the provided tensors. - Graph remains connected for backward grad computation. - """ - # Queue the gathered tensors - world_size = torch.distributed.get_world_size() - # There is no need for reduction in the single-proc case - if world_size == 1: - return tensors - - # tensor_all = GatherLayer.apply(tensors) - tensor_all = GatherLayer.apply(tensors) - - return torch.cat(tensor_all, dim=0) -@torch.no_grad() -def concat_all_gather(tensor): - """ - Performs all_gather operation on the provided tensors. - *** Warning ***: torch.distributed.all_gather has no gradient. 
- """ - # if use distributed training - if not is_dist_avail_and_initialized(): - return tensor - tensors_gather = [ - torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) - ] - torch.distributed.all_gather(tensors_gather, tensor, async_op=False) - - output = torch.cat(tensors_gather, dim=0) - return output - - -def tile(x, dim, n_tile): - init_dim = x.size(dim) - repeat_idx = [1] * x.dim() - repeat_idx[dim] = n_tile - x = x.repeat(*(repeat_idx)) - order_index = torch.LongTensor( - np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) - ) - return torch.index_select(x, dim, order_index.to(x.device)) diff --git a/minigpt4/models/blip2.py b/minigpt4/models/blip2.py deleted file mode 100644 index ee4a9dc..0000000 --- a/minigpt4/models/blip2.py +++ /dev/null @@ -1,221 +0,0 @@ -""" - Copyright (c) 2023, salesforce.com, inc. - All rights reserved. - SPDX-License-Identifier: BSD-3-Clause - For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" -import contextlib -import logging -import os -import time -import datetime - -import torch -import torch.nn as nn -import torch.distributed as dist -import torch.nn.functional as F - -import minigpt4.common.dist_utils as dist_utils -from minigpt4.common.dist_utils import download_cached_file -from minigpt4.common.utils import is_url -from minigpt4.common.logger import MetricLogger -from minigpt4.models.base_model import BaseModel -from minigpt4.models.Qformer import BertConfig, BertLMHeadModel -from minigpt4.models.eva_vit import create_eva_vit_g -from transformers import BertTokenizer - - -class Blip2Base(BaseModel): - @classmethod - def init_tokenizer(cls): - tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") - tokenizer.add_special_tokens({"bos_token": "[DEC]"}) - return tokenizer - - def maybe_autocast(self, dtype=torch.float16): - # if on cpu, don't use autocast - # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 - enable_autocast = self.device != torch.device("cpu") - - if enable_autocast: - return torch.cuda.amp.autocast(dtype=dtype) - else: - return contextlib.nullcontext() - - @classmethod - def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): - encoder_config = BertConfig.from_pretrained("bert-base-uncased") - encoder_config.encoder_width = vision_width - # insert cross-attention layer every other block - encoder_config.add_cross_attention = True - encoder_config.cross_attention_freq = cross_attention_freq - encoder_config.query_length = num_query_token - Qformer = BertLMHeadModel(config=encoder_config) - query_tokens = nn.Parameter( - torch.zeros(1, num_query_token, encoder_config.hidden_size) - ) - query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) - return Qformer, query_tokens - - @classmethod - def init_vision_encoder( - cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision - ): - assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" - visual_encoder = create_eva_vit_g( - img_size, drop_path_rate, use_grad_checkpoint, precision - ) - - ln_vision = LayerNorm(visual_encoder.num_features) - return visual_encoder, ln_vision - - def load_from_pretrained(self, url_or_filename): - if is_url(url_or_filename): - cached_file = download_cached_file( - url_or_filename, check_hash=False, progress=True - ) - checkpoint = torch.load(cached_file, map_location="cpu") - elif 
os.path.isfile(url_or_filename): - checkpoint = torch.load(url_or_filename, map_location="cpu") - else: - raise RuntimeError("checkpoint url or path is invalid") - - state_dict = checkpoint["model"] - - msg = self.load_state_dict(state_dict, strict=False) - - # logging.info("Missing keys {}".format(msg.missing_keys)) - logging.info("load checkpoint from %s" % url_or_filename) - - return msg - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -def compute_sim_matrix(model, data_loader, **kwargs): - k_test = kwargs.pop("k_test") - - metric_logger = MetricLogger(delimiter=" ") - header = "Evaluation:" - - logging.info("Computing features for evaluation...") - start_time = time.time() - - texts = data_loader.dataset.text - num_text = len(texts) - text_bs = 256 - text_ids = [] - text_embeds = [] - text_atts = [] - for i in range(0, num_text, text_bs): - text = texts[i : min(num_text, i + text_bs)] - text_input = model.tokenizer( - text, - padding="max_length", - truncation=True, - max_length=35, - return_tensors="pt", - ).to(model.device) - text_feat = model.forward_text(text_input) - text_embed = F.normalize(model.text_proj(text_feat)) - text_embeds.append(text_embed) - text_ids.append(text_input.input_ids) - text_atts.append(text_input.attention_mask) - - text_embeds = torch.cat(text_embeds, dim=0) - text_ids = torch.cat(text_ids, dim=0) - text_atts = torch.cat(text_atts, dim=0) - - vit_feats = [] - image_embeds = [] - for samples in data_loader: - image = samples["image"] - - image = image.to(model.device) - image_feat, vit_feat = model.forward_image(image) - image_embed = model.vision_proj(image_feat) - image_embed = F.normalize(image_embed, dim=-1) - - vit_feats.append(vit_feat.cpu()) - image_embeds.append(image_embed) - - vit_feats = torch.cat(vit_feats, dim=0) - image_embeds = torch.cat(image_embeds, dim=0) - - sims_matrix = [] - for image_embed in image_embeds: - sim_q2t = image_embed @ text_embeds.t() - sim_i2t, _ = sim_q2t.max(0) - sims_matrix.append(sim_i2t) - sims_matrix = torch.stack(sims_matrix, dim=0) - - score_matrix_i2t = torch.full( - (len(data_loader.dataset.image), len(texts)), -100.0 - ).to(model.device) - - num_tasks = dist_utils.get_world_size() - rank = dist_utils.get_rank() - step = sims_matrix.size(0) // num_tasks + 1 - start = rank * step - end = min(sims_matrix.size(0), start + step) - - for i, sims in enumerate( - metric_logger.log_every(sims_matrix[start:end], 50, header) - ): - topk_sim, topk_idx = sims.topk(k=k_test, dim=0) - image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) - score = model.compute_itm( - image_inputs=image_inputs, - text_ids=text_ids[topk_idx], - text_atts=text_atts[topk_idx], - ).float() - score_matrix_i2t[start + i, topk_idx] = score + topk_sim - - sims_matrix = sims_matrix.t() - score_matrix_t2i = torch.full( - (len(texts), len(data_loader.dataset.image)), -100.0 - ).to(model.device) - - step = sims_matrix.size(0) // num_tasks + 1 - start = rank * step - end = min(sims_matrix.size(0), start + step) - - for i, sims in enumerate( - metric_logger.log_every(sims_matrix[start:end], 50, header) - ): - topk_sim, topk_idx = sims.topk(k=k_test, dim=0) - image_inputs = 
vit_feats[topk_idx.cpu()].to(model.device) - score = model.compute_itm( - image_inputs=image_inputs, - text_ids=text_ids[start + i].repeat(k_test, 1), - text_atts=text_atts[start + i].repeat(k_test, 1), - ).float() - score_matrix_t2i[start + i, topk_idx] = score + topk_sim - - if dist_utils.is_dist_avail_and_initialized(): - dist.barrier() - torch.distributed.all_reduce( - score_matrix_i2t, op=torch.distributed.ReduceOp.SUM - ) - torch.distributed.all_reduce( - score_matrix_t2i, op=torch.distributed.ReduceOp.SUM - ) - - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logging.info("Evaluation time {}".format(total_time_str)) - - return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/minigpt4/models/blip2_outputs.py b/minigpt4/models/blip2_outputs.py deleted file mode 100644 index e8722b1..0000000 --- a/minigpt4/models/blip2_outputs.py +++ /dev/null @@ -1,110 +0,0 @@ -""" - Copyright (c) 2022, salesforce.com, inc. - All rights reserved. - SPDX-License-Identifier: BSD-3-Clause - For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - -from dataclasses import dataclass -from typing import Optional - -import torch -from transformers.modeling_outputs import ( - ModelOutput, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, -) - - -@dataclass -class BlipSimilarity(ModelOutput): - sim_i2t: torch.FloatTensor = None - sim_t2i: torch.FloatTensor = None - - sim_i2t_m: Optional[torch.FloatTensor] = None - sim_t2i_m: Optional[torch.FloatTensor] = None - - sim_i2t_targets: Optional[torch.FloatTensor] = None - sim_t2i_targets: Optional[torch.FloatTensor] = None - - -@dataclass -class BlipIntermediateOutput(ModelOutput): - """ - Data class for intermediate outputs of BLIP models. - - image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). - text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). - - image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). - text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). - - encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. - encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. - - decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. - decoder_labels (torch.LongTensor): labels for the captioning loss. - - itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). 
- itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) - - """ - - # uni-modal features - image_embeds: torch.FloatTensor = None - text_embeds: Optional[torch.FloatTensor] = None - - image_embeds_m: Optional[torch.FloatTensor] = None - text_embeds_m: Optional[torch.FloatTensor] = None - - # intermediate outputs of multimodal encoder - encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None - encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None - - itm_logits: Optional[torch.FloatTensor] = None - itm_labels: Optional[torch.LongTensor] = None - - # intermediate outputs of multimodal decoder - decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None - decoder_labels: Optional[torch.LongTensor] = None - - -@dataclass -class BlipOutput(ModelOutput): - # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. - sims: Optional[BlipSimilarity] = None - - intermediate_output: BlipIntermediateOutput = None - - loss: Optional[torch.FloatTensor] = None - - loss_itc: Optional[torch.FloatTensor] = None - - loss_itm: Optional[torch.FloatTensor] = None - - loss_lm: Optional[torch.FloatTensor] = None - - -@dataclass -class BlipOutputFeatures(ModelOutput): - """ - Data class of features from BlipFeatureExtractor. - - Args: - image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional - image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional - text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional - text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional - - The first embedding or feature is for the [CLS] token. - - Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. - """ - - image_embeds: Optional[torch.FloatTensor] = None - image_embeds_proj: Optional[torch.FloatTensor] = None - - text_embeds: Optional[torch.FloatTensor] = None - text_embeds_proj: Optional[torch.FloatTensor] = None - - multimodal_embeds: Optional[torch.FloatTensor] = None diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py deleted file mode 100644 index faed3d5..0000000 --- a/minigpt4/models/mini_gpt4.py +++ /dev/null @@ -1,384 +0,0 @@ -import logging -import random - -import torch -from torch.cuda.amp import autocast as autocast -import torch.nn as nn - -from minigpt4.common.registry import registry -from minigpt4.models.blip2 import Blip2Base, disabled_train -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers import LlamaTokenizer - -from peft import ( - LoraConfig, - get_peft_model, - get_peft_model_state_dict, - prepare_model_for_int8_training, - set_peft_model_state_dict, -) - - -@registry.register_model("mini_gpt4") -class MiniGPT4(Blip2Base): - """ - BLIP2 GPT-LLAMA model. 
- """ - - PRETRAINED_MODEL_CONFIG_DICT = { - "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml", - "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", - } - - def __init__( - self, - vit_model="eva_clip_g", - q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", - img_size=224, - drop_path_rate=0, - use_grad_checkpoint=False, - vit_precision="fp16", - freeze_vit=True, - has_qformer=True, - freeze_qformer=True, - num_query_token=32, - llama_model="", - prompt_path="", - prompt_template="", - max_txt_len=32, - end_sym='\n', - low_resource=False, # use 8 bit and put vit in cpu - device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. - lora_r=0, - lora_target_modules=["q_proj", "v_proj"], - lora_alpha=16, - lora_dropout=0.05, - ): - super().__init__() - - self.tokenizer = self.init_tokenizer() - self.low_resource = low_resource - - print('Loading VIT') - self.visual_encoder, self.ln_vision = self.init_vision_encoder( - vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision - ) - if freeze_vit: - for name, param in self.visual_encoder.named_parameters(): - param.requires_grad = False - self.visual_encoder = self.visual_encoder.eval() - self.visual_encoder.train = disabled_train - for name, param in self.ln_vision.named_parameters(): - param.requires_grad = False - self.ln_vision = self.ln_vision.eval() - self.ln_vision.train = disabled_train - logging.info("freeze vision encoder") - print('Loading VIT Done') - - self.has_qformer = has_qformer - if self.has_qformer: - print('Loading Q-Former') - self.Qformer, self.query_tokens = self.init_Qformer( - num_query_token, self.visual_encoder.num_features - ) - self.Qformer.cls = None - self.Qformer.bert.embeddings.word_embeddings = None - self.Qformer.bert.embeddings.position_embeddings = None - for layer in self.Qformer.bert.encoder.layer: - layer.output = None - layer.intermediate = None - self.load_from_pretrained(url_or_filename=q_former_model) - - if freeze_qformer: - for name, param in self.Qformer.named_parameters(): - param.requires_grad = False - self.Qformer = self.Qformer.eval() - self.Qformer.train = disabled_train - self.query_tokens.requires_grad = False - logging.info("freeze Qformer") - - img_f_dim = self.Qformer.config.hidden_size - print('Loading Q-Former Done') - else: - img_f_dim = self.visual_encoder.num_features * 4 - print('Do not use Q-Former here.') - - print('Loading LLAMA') - self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) - self.llama_tokenizer.pad_token = "$$" - - if self.low_resource: - self.llama_model = LlamaForCausalLM.from_pretrained( - llama_model, - torch_dtype=torch.float16, - load_in_8bit=True, - device_map={'': device_8bit} - ) - else: - self.llama_model = LlamaForCausalLM.from_pretrained( - llama_model, - torch_dtype=torch.float16, - ) - - if lora_r > 0: - self.llama_model = prepare_model_for_int8_training(self.llama_model) - loraconfig = LoraConfig( - r=lora_r, - lora_alpha=lora_alpha, - target_modules=lora_target_modules, - lora_dropout=lora_dropout, - bias="none", - task_type="CAUSAL_LM" - ) - self.llama_model = get_peft_model(self.llama_model, loraconfig) - - # if ckpt_path: - # print('load the llm under lora') - # ckpt = torch.load(ckpt_path) - # set_peft_model_state_dict(self.llama_model,ckpt) - self.llama_model.print_trainable_parameters() - - else: - for name, param in self.llama_model.named_parameters(): - 
param.requires_grad = False - print('Loading LLAMA Done') - - self.llama_proj = nn.Linear( - img_f_dim, self.llama_model.config.hidden_size - ) - self.max_txt_len = max_txt_len - self.end_sym = end_sym - - if prompt_path: - with open(prompt_path, 'r') as f: - raw_prompts = f.read().splitlines() - filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] - self.prompt_list = [prompt_template.format(p) for p in filted_prompts] - print('Load {} training prompts'.format(len(self.prompt_list))) - print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) - else: - self.prompt_list = [] - - def vit_to_cpu(self): - self.ln_vision.to("cpu") - self.ln_vision.float() - self.visual_encoder.to("cpu") - self.visual_encoder.float() - - def encode_img(self, image): - device = image.device - if self.low_resource: - self.vit_to_cpu() - image = image.to("cpu") - - with self.maybe_autocast(): - image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) - if self.has_qformer: - image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_output = self.Qformer.bert( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_atts, - return_dict=True, - ) - - inputs_llama = self.llama_proj(query_output.last_hidden_state) - else: - image_embeds = image_embeds[:, 1:, :] - bs, pn, hs = image_embeds.shape - image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4)) - - inputs_llama = self.llama_proj(image_embeds) - atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) - return inputs_llama, atts_llama - - def get_context_emb(self, prompt, img_list): - device = img_list[0].device - prompt_segs = prompt.split('') - assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." 
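The `get_context_emb` method above splits a prompt on an image placeholder marker and interleaves the embedded text segments with the image embeddings before feeding them to the LLM. A self-contained sketch of that interleaving, with the placeholder literal (an `<ImageHere>`-style token in the upstream prompts) and the text-embedding callable treated as assumed stand-ins:

```python
import torch

IMG_TOKEN = "<ImageHere>"  # assumed placeholder literal used in the training prompts

def interleave_prompt(prompt: str, img_embs: list, embed_text) -> torch.Tensor:
    # Split the prompt on the image placeholder, embed each text segment with the
    # provided callable (expected to return tensors of shape [1, L, D]), then
    # interleave text and image embeddings: t0, img0, t1, img1, ..., tN.
    segs = prompt.split(IMG_TOKEN)
    assert len(segs) == len(img_embs) + 1, "Unmatched numbers of image placeholders and images."
    text_embs = [embed_text(seg) for seg in segs]
    mixed = [e for pair in zip(text_embs[:-1], img_embs) for e in pair] + [text_embs[-1]]
    return torch.cat(mixed, dim=1)
```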
- seg_tokens = [ - self.llama_tokenizer( - seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids - # only add bos to the first seg - for i, seg in enumerate(prompt_segs) - ] - seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] - - mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] - mixed_embs = torch.cat(mixed_embs, dim=1) - return mixed_embs - - def prompt_wrap(self, img_embeds, atts_img, prompts): - if prompts: - emb_lists = [] - if isinstance(prompts, str): - prompts = [prompts] * len(img_embeds) - - for each_img_embed, each_prompt in zip(img_embeds, prompts): - p_before, p_after = each_prompt.split('') - - p_before_tokens = self.llama_tokenizer( - p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - p_after_tokens = self.llama_tokenizer( - p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - p_before_embed = self.embed_tokens(p_before_tokens.input_ids) - p_after_embed = self.embed_tokens(p_after_tokens.input_ids) - wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1) - emb_lists.append(wrapped_emb) - emb_lens = [emb.shape[1] for emb in emb_lists] - pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) - wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone() - wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device) - for i, emb in enumerate(emb_lists): - wrapped_embs[i, :emb_lens[i]] = emb - wrapped_atts[i, :emb_lens[i]] = 1 - return wrapped_embs, wrapped_atts - else: - return img_embeds, atts_img - - def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): - input_lens = [] - cat_embs = [] - cat_atts = [] - for i in range(input_embs.size(0)): - input_len = input_atts[i].sum() - input_lens.append(input_len) - cat_embs.append( - torch.cat([ - input_embs[i][:input_len], - output_embs[i], - input_embs[i][input_len:] - ]) - ) - cat_atts.append( - torch.cat([ - input_atts[i][:input_len], - output_atts[i], - input_atts[i][input_len:] - ]) - ) - cat_embs = torch.stack(cat_embs) - cat_atts = torch.stack(cat_atts) - return cat_embs, cat_atts, input_lens - - def forward(self, samples): - image = samples["image"] - img_embeds, atts_img = self.encode_img(image) - - if self.prompt_list: - instruction = random.choice(self.prompt_list) - else: - instruction = samples["instruction_input"] if "instruction_input" in samples else None - - img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction) - - self.llama_tokenizer.padding_side = "right" - text = [t + self.end_sym for t in samples["answer"]] - - to_regress_tokens = self.llama_tokenizer( - text, - return_tensors="pt", - padding="longest", - truncation=True, - max_length=self.max_txt_len, - add_special_tokens=False - ).to(image.device) - - batch_size = img_embeds.shape[0] - bos = torch.ones([batch_size, 1], - dtype=to_regress_tokens.input_ids.dtype, - device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id - bos_embeds = self.embed_tokens(bos) - atts_bos = atts_img[:, :1] - - to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids) - inputs_embeds, attention_mask, input_lens = \ - self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask) - inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) - attention_mask = torch.cat([atts_bos, attention_mask], 
dim=1) - - part_targets = to_regress_tokens.input_ids.masked_fill( - to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 - ) - targets = ( - torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], - dtype=torch.long).to(image.device).fill_(-100) - ) - - for i, target in enumerate(part_targets): - targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target # plus 1 for bos - - with self.maybe_autocast(): - outputs = self.llama_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - return_dict=True, - labels=targets, - ) - loss = outputs.loss - - return {"loss": loss} - - def embed_tokens(self, token_ids): - if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model - embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) - else: - embeds = self.llama_model.base_model.embed_tokens(token_ids) - return embeds - - @classmethod - def from_config(cls, cfg): - vit_model = cfg.get("vit_model", "eva_clip_g") - q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") - img_size = cfg.get("image_size") - num_query_token = cfg.get("num_query_token") - llama_model = cfg.get("llama_model") - - drop_path_rate = cfg.get("drop_path_rate", 0) - use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) - vit_precision = cfg.get("vit_precision", "fp16") - freeze_vit = cfg.get("freeze_vit", True) - has_qformer = cfg.get("has_qformer", True) - freeze_qformer = cfg.get("freeze_qformer", True) - low_resource = cfg.get("low_resource", False) - device_8bit = cfg.get("device_8bit", 0) - - prompt_path = cfg.get("prompt_path", "") - prompt_template = cfg.get("prompt_template", "") - max_txt_len = cfg.get("max_txt_len", 32) - end_sym = cfg.get("end_sym", '\n') - - lora_r = cfg.get("lora_r", 0) - lora_alpha = cfg.get("lora_alpha", 32) - - model = cls( - vit_model=vit_model, - q_former_model=q_former_model, - img_size=img_size, - drop_path_rate=drop_path_rate, - use_grad_checkpoint=use_grad_checkpoint, - vit_precision=vit_precision, - freeze_vit=freeze_vit, - has_qformer=has_qformer, - freeze_qformer=freeze_qformer, - num_query_token=num_query_token, - llama_model=llama_model, - prompt_path=prompt_path, - prompt_template=prompt_template, - max_txt_len=max_txt_len, - end_sym=end_sym, - low_resource=low_resource, - device_8bit=device_8bit, - lora_r=lora_r, - lora_alpha=lora_alpha, - ) - - ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 - if ckpt_path: - print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path)) - ckpt = torch.load(ckpt_path, map_location="cpu") - msg = model.load_state_dict(ckpt['model'], strict=False) - - return model diff --git a/minigpt4/models/minigpt4.py b/minigpt4/models/minigpt4.py new file mode 100644 index 0000000..a2e4798 --- /dev/null +++ b/minigpt4/models/minigpt4.py @@ -0,0 +1,195 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import disabled_train +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel + + +@registry.register_model("minigpt4") +class MiniGPT4(MiniGPTBase): + """ + MiniGPT-4 model + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml", + "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", + } + + def __init__( + 
self, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + has_qformer=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + ): + super().__init__( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + ) + + self.has_qformer = has_qformer + if self.has_qformer: + print('Loading Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features, freeze_qformer + ) + self.load_from_pretrained(url_or_filename=q_former_model) # load q-former weights here + + img_f_dim = self.Qformer.config.hidden_size + print('Loading Q-Former Done') + else: + img_f_dim = self.visual_encoder.num_features * 4 + print('Do not use Q-Former here.') + + self.llama_proj = nn.Linear( + img_f_dim, self.llama_model.config.hidden_size + ) + + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] + self.prompt_list = [prompt_template.format(p) for p in filted_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, freeze): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = 2 + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + + Qformer.cls = None + Qformer.bert.embeddings.word_embeddings = None + Qformer.bert.embeddings.position_embeddings = None + for layer in Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + if freeze: + for name, param in Qformer.named_parameters(): + param.requires_grad = False + Qformer = Qformer.eval() + Qformer.train = disabled_train + query_tokens.requires_grad = False + logging.info("freeze Qformer") + + return Qformer, query_tokens + + def encode_img(self, image): + device = image.device + + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + if self.has_qformer: + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + 
) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + else: + image_embeds = image_embeds[:, 1:, :] + bs, pn, hs = image_embeds.shape + image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4)) + + inputs_llama = self.llama_proj(image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + has_qformer = cfg.get("has_qformer", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 32) + end_sym = cfg.get("end_sym", '\n') + + model = cls( + vit_model=vit_model, + q_former_model=q_former_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + has_qformer=has_qformer, + freeze_qformer=freeze_qformer, + num_query_token=num_query_token, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py new file mode 100644 index 0000000..77c919a --- /dev/null +++ b/minigpt4/models/minigpt_base.py @@ -0,0 +1,401 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import BaseModel + + + +class MiniGPTBase(BaseModel): + """ + Base class for MiniGPT-4 and MiniGPT-v2 + """ + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + max_txt_len=32, + max_context_len=3800, + prompt_template="", + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. 
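In the branch above that skips the Q-Former, `encode_img` compresses the ViT output by concatenating every four neighbouring patch tokens before projecting to the LLM hidden size. A toy illustration of that reshape, assuming eva_clip_g-like shapes (224px input, 1408-dim features); the tensor values are random and only the shapes matter:

```python
import torch

# Dummy ViT output: batch of 2, 257 tokens (CLS + 256 patches), hidden size 1408.
image_embeds = torch.randn(2, 257, 1408)
image_embeds = image_embeds[:, 1:, :]            # drop the CLS token, keep 256 patch tokens
bs, pn, hs = image_embeds.shape
merged = image_embeds.view(bs, pn // 4, hs * 4)  # 4 neighbouring tokens -> 1 token of width 4*hs
print(merged.shape)                              # torch.Size([2, 64, 5632])
```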
+ lora_r=0, # lora_r means lora is not used + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, + ): + super().__init__() + + self.llama_model, self.llama_tokenizer = self.init_llm( + llama_model_path=llama_model, + low_resource=low_resource, + low_res_device=device_8bit, + lora_r=lora_r, + lora_target_modules=lora_target_modules, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit + ) + + self.max_txt_len = max_txt_len + self.max_context_len = max_context_len + self.end_sym = end_sym + + self.prompt_template = prompt_template + self.prompt_list = [] + + def vit_to_cpu(self): + self.ln_vision.to("cpu") + self.ln_vision.float() + self.visual_encoder.to("cpu") + self.visual_encoder.float() + + def get_context_emb(self, prompt, img_list): + device = img_list[0].device + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] + + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): + if prompts is None or len(prompts) == 0: + # prompts is not provided, just return the original image embedding + return img_embeds, atts_img + elif img_embeds is None: + # prompt is provided but there is no image embedding. 
return the prompt embedding in right padding + self.llama_tokenizer.padding_side = "right" + prompt_tokens = self.llama_tokenizer( + prompts, + return_tensors="pt", + padding="longest", + add_special_tokens=False + ).to(self.device) + prompt_embeds = self.embed_tokens(prompt_tokens.input_ids) + atts_prompt = prompt_tokens.attention_mask + return prompt_embeds, atts_prompt + else: + # return the multi-modal embedding in right padding + emb_lists = [] + if isinstance(prompts, str): + prompts = [prompts] * len(img_embeds) + + for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)): + pn = each_img_embed.shape[-2] + if lengths is not None: + each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1]) + each_img_embed = each_img_embed[:lengths[idx] * pn] + p_segs = each_prompt.split('') + interleave_emb = [] + for idx, seg in enumerate(p_segs[:-1]): + p_tokens = self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1)) + wrapped_emb = torch.cat(interleave_emb, dim=1) + p_tokens = self.llama_tokenizer( + p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_embed = self.embed_tokens(p_tokens.input_ids) + wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1) + emb_lists.append(wrapped_emb) + + emb_lens = [emb.shape[1] for emb in emb_lists] + pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) + + max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len + wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() + wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) + + for i, emb in enumerate(emb_lists): + length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len + wrapped_embs[i, :length] = emb[:, :length] + wrapped_atts[i, :length] = 1 + return wrapped_embs, wrapped_atts + + def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): + """ + Concatenate the batched input embedding and batched output embedding together. + Both the input and the output embedding should be right padded. 
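`prompt_wrap` above right-pads the per-sample (text plus image) embedding sequences to a common length and builds the matching attention mask; `concat_emb_input_output`, whose body follows, then splices the answer embeddings in directly after each sample's valid tokens. The padding step alone, on toy tensors (shapes and values are made up for the demo):

```python
import torch

# Three "wrapped" samples of different lengths, embedding dim 8.
emb_lists = [torch.randn(n, 8) for n in (5, 3, 7)]
pad_emb = torch.zeros(8)  # stands in for the embedded pad token

max_len = max(e.shape[0] for e in emb_lists)
wrapped = pad_emb.expand(len(emb_lists), max_len, 8).clone()
atts = torch.zeros(len(emb_lists), max_len, dtype=torch.int)
for i, e in enumerate(emb_lists):
    wrapped[i, : e.shape[0]] = e   # real tokens on the left, pads stay on the right
    atts[i, : e.shape[0]] = 1
print(wrapped.shape, atts.sum(dim=1))  # torch.Size([3, 7, 8]) and per-row valid counts [5, 3, 7]
```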
+ """ + input_lens = [] + cat_embs = [] + cat_atts = [] + for i in range(input_embs.size(0)): + input_len = input_atts[i].sum() + input_lens.append(input_len) + cat_embs.append( + torch.cat([ + input_embs[i][:input_len], + output_embs[i], + input_embs[i][input_len:] + ]) + ) + cat_atts.append( + torch.cat([ + input_atts[i][:input_len], + output_atts[i], + input_atts[i][input_len:] + ]) + ) + cat_embs = torch.stack(cat_embs) + cat_atts = torch.stack(cat_atts) + return cat_embs, cat_atts, input_lens + + def tokenize_conversation(self, conv_q, conv_a): + """concatenate conversation and make sure the model is only trained to regress the answer""" + + to_regress_token_ids_list = [] + targets_list = [] + + batch_size = len(conv_q) + for batch_idx in range(batch_size): + questions, answers = conv_q[batch_idx], conv_a[batch_idx] + questions = [self.llama_tokenizer(q, + return_tensors="pt", + add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it + answers = [self.llama_tokenizer(q, + return_tensors="pt", + add_special_tokens=False).to(self.device) for q in answers] + cur_id = [] + cur_target = [] + for i in range(len(questions)): + cur_id.append(answers[i].input_ids) + cur_target.append(answers[i].input_ids) + cur_id.append(questions[i].input_ids) + cur_target.append(torch.ones_like(questions[i].input_ids) * -100) + + cur_id.append(answers[-1].input_ids) + cur_target.append(answers[-1].input_ids) + + cur_id = torch.cat(cur_id, dim=1) + cur_target = torch.cat(cur_target, dim=1) + to_regress_token_ids_list.append(cur_id) + targets_list.append(cur_target) + + max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) + to_regress_token_ids = torch.ones([batch_size, max_len], + dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id + targets = torch.ones([batch_size, max_len], + dtype=cur_id.dtype, device=self.device) * -100 + for batch_idx in range(batch_size): + cur_len = to_regress_token_ids_list[batch_idx].shape[1] + to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len] + targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] + + to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int) + + return to_regress_token_ids, to_regress_token_attn, targets + + def preparing_embedding(self, samples): + ### prepare input tokens + if 'image' in samples: + img_embeds, img_atts = self.encode_img(samples["image"]) + else: + img_embeds = img_atts = None + + if 'conv_q' in samples: + # handeling conversation datasets + conv_q, conv_a = samples['conv_q'], samples['conv_a'] + + connect_sym = samples['connect_sym'][0] + conv_q = [q.split(connect_sym)for q in conv_q] + conv_a = [a.split(connect_sym) for a in conv_a] + + conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q] + + cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q]) + regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a) + + else: + if "instruction_input" in samples: + instruction = samples["instruction_input"] + elif self.prompt_list: + instruction = random.choice(self.prompt_list) + else: + instruction = None + + if self.chat_template: + instruction = [self.prompt_template.format(instruct) for instruct in instruction] + + if 'length' in samples: + # the input is a image train (like videos) + bsz, pn, hs = img_embeds.shape + img_embeds = 
img_embeds.reshape(len(samples['image']), -1, pn, hs)
+                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
+            else:
+                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
+
+        ### prepare target tokens
+        self.llama_tokenizer.padding_side = "right"
+        text = [t + self.end_sym for t in samples["answer"]]
+
+        regress_tokens = self.llama_tokenizer(
+            text,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_txt_len,
+            add_special_tokens=False
+        ).to(self.device)
+
+        regress_token_ids = regress_tokens.input_ids
+        regress_atts = regress_tokens.attention_mask
+        part_targets = regress_token_ids.masked_fill(
+            regress_token_ids == self.llama_tokenizer.pad_token_id, -100
+        )
+
+        regress_embeds = self.embed_tokens(regress_token_ids)
+
+        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
+
+    def forward(self, samples, reduction='mean'):
+        # prepare the embedding to condition and the embedding to regress
+        cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
+            self.preparing_embedding(samples)
+
+        # concat the embedding to condition and the embedding to regress
+        inputs_embeds, attention_mask, input_lens = \
+            self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
+
+        # get bos token embedding
+        bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
+        bos_embeds = self.embed_tokens(bos)
+        bos_atts = cond_atts[:, :1]
+
+        # add bos token at the beginning
+        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
+        attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
+
+        # assemble the final targets (-100 positions are ignored by the LM loss)
+        targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
+                             dtype=torch.long).to(self.device).fill_(-100)
+
+        for i, target in enumerate(part_targets):
+            targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target  # plus 1 for bos
+
+        with self.maybe_autocast():
+            outputs = self.llama_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                return_dict=True,
+                labels=targets,
+                reduction=reduction
+            )
+        loss = outputs.loss
+
+        return {"loss": loss}
+
+    def embed_tokens(self, token_ids):
+        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
+            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
+        else:
+            embeds = self.llama_model.base_model.embed_tokens(token_ids)
+        return embeds
+
+
+    @torch.no_grad()
+    def generate(
+        self,
+        images,
+        texts,
+        num_beams=1,
+        max_new_tokens=20,
+        min_length=1,
+        top_p=0.9,
+        repetition_penalty=1,
+        length_penalty=1,
+        temperature=1,
+        do_sample=False,
+        stop_words_ids=[2],
+    ):
+        '''
+        function for generate test use
+        '''
+
+        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
+            stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
+
+        img_embeds, atts_img = self.encode_img(images.to(self.device))
+        image_lists = [[image_emb[None]] for image_emb in img_embeds]
+
+        batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
+
+        batch_size = len(batch_embs)
+        max_len = max([emb.shape[1] for emb in batch_embs])
+        emb_dim = batch_embs[0].shape[2]
+        dtype = batch_embs[0].dtype
+        device = batch_embs[0].device
+
+        embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
+        attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
+        for i, emb in enumerate(batch_embs):
+            emb_len = emb.shape[1]
+            embs[i, 
-emb_len:] = emb[0] + attn_mask[i, -emb_len:] = 1 + + with self.maybe_autocast(): + outputs = self.llama_model.generate( + inputs_embeds=embs, + attention_mask=attn_mask, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + length_penalty=length_penalty, + temperature=temperature, + do_sample=do_sample, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty + # stopping_criteria=stopping_criteria, + ) + + answers = [] + for output_token in outputs: + if output_token[0] == 0: + output_token = output_token[1:] + output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) + output_texts = output_texts.split('')[0] # remove the stop sign + output_texts = output_texts.replace("", "") + output_texts = output_texts.split(r'[/INST]')[-1].strip() + answers.append(output_texts) + + return answers + + @torch.no_grad() + def multi_select(self, images, texts, answers, num_cand=None): + all_losses = [] + for answer in answers: + choice_samples = { + 'image': images, + 'instruction_input': texts, + 'answer': answer + } + loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) + all_losses.append(loss) + torch.cuda.empty_cache() + all_losses = torch.cat(all_losses, dim=-1) + if num_cand is not None: + for i in range(all_losses.shape[0]): + all_losses[i, num_cand[i]:] = 9999 + output_class_ranks = torch.argsort(all_losses, dim=-1) + return output_class_ranks.tolist() \ No newline at end of file diff --git a/minigpt4/models/minigpt_v2.py b/minigpt4/models/minigpt_v2.py new file mode 100644 index 0000000..a046b0b --- /dev/null +++ b/minigpt4/models/minigpt_v2.py @@ -0,0 +1,139 @@ +import logging +import random + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import disabled_train +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel + + +@registry.register_model("minigpt_v2") +class MiniGPTv2(MiniGPTBase): + """ + MiniGPT-v2 model + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain": "configs/models/minigpt_v2.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=448, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + llama_model="", + prompt_template='[INST] {} [/INST]', + max_txt_len=300, + end_sym='\n', + lora_r=64, + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, + chat_template=False, + use_grad_checkpoint_llm=False, + max_context_len=3800, + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. 
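`multi_select` above scores every candidate answer by its per-sample language-model loss (computed with `reduction='none'`) and returns the candidates ranked from lowest to highest loss. The ranking step in isolation, on a made-up loss matrix:

```python
import torch

# Per-candidate losses for a batch of 2 samples and 4 candidate answers (made-up numbers).
all_losses = torch.tensor([[2.1, 0.7, 3.5, 1.9],
                           [1.2, 4.0, 0.9, 2.2]])
num_cand = [3, 4]                          # sample 0 only has 3 valid candidates
for i, n in enumerate(num_cand):
    all_losses[i, n:] = 9999               # push padded candidates to the end of the ranking
ranks = torch.argsort(all_losses, dim=-1)  # lowest loss first = most likely answer first
print(ranks.tolist())                      # [[1, 0, 2, 3], [2, 0, 3, 1]]
```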
+ ): + super().__init__( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + max_txt_len=max_txt_len, + max_context_len=max_context_len, + end_sym=end_sym, + prompt_template=prompt_template, + low_resource=low_resource, + device_8bit=device_8bit, + lora_r=lora_r, + lora_target_modules=lora_target_modules, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + img_f_dim = self.visual_encoder.num_features * 4 + self.llama_proj = nn.Linear( + img_f_dim, self.llama_model.config.hidden_size + ) + self.chat_template = chat_template + + if use_grad_checkpoint_llm: + self.llama_model.gradient_checkpointing_enable() + + def encode_img(self, image): + device = image.device + + if len(image.shape) > 4: + image = image.reshape(-1, *image.shape[-3:]) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_embeds = image_embeds[:, 1:, :] + bs, pn, hs = image_embeds.shape + image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4)) + + inputs_llama = self.llama_proj(image_embeds) + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama, atts_llama + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + img_size = cfg.get("image_size") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + low_resource = cfg.get("low_resource", False) + + prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]') + max_txt_len = cfg.get("max_txt_len", 300) + end_sym = cfg.get("end_sym", '\n') + + lora_r = cfg.get("lora_r", 64) + lora_alpha = cfg.get("lora_alpha", 16) + chat_template = cfg.get("chat_template", False) + + use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False) + max_context_len = cfg.get("max_context_len", 3800) + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + low_resource=low_resource, + end_sym=end_sym, + lora_r=lora_r, + lora_alpha=lora_alpha, + chat_template=chat_template, + use_grad_checkpoint_llm=use_grad_checkpoint_llm, + max_context_len=max_context_len, + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + if ckpt_path: + print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(ckpt['model'], strict=False) + + return model diff --git a/minigpt4/models/modeling_llama.py b/minigpt4/models/modeling_llama.py index 12d980e..6d28020 100644 --- a/minigpt4/models/modeling_llama.py +++ b/minigpt4/models/modeling_llama.py @@ -1,628 +1,17 @@ -# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - -""" PyTorch LLaMA model.""" import math from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss -from 
transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC +from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. 
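The bulk of the copied LLaMA implementation deleted in this file is replaced by a thin subclass of Hugging Face's stock `LlamaForCausalLM` (imported above as `LlamaForCausalLMOrig`), presumably so the loss reduction can be chosen by the caller, as the `reduction=reduction` argument in `MiniGPTBase.forward` suggests. A hedged sketch of such a wrapper (an illustration of the pattern, not necessarily the exact code later in this diff):

```python
import torch
from torch.nn import CrossEntropyLoss
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig


class LlamaForCausalLMWithReduction(LlamaForCausalLMOrig):
    def forward(self, input_ids=None, labels=None, reduction="mean", **kwargs):
        # Run the stock forward without labels, then compute the shifted LM loss
        # ourselves so that reduction='none' yields a per-sample score.
        outputs = super().forward(input_ids=input_ids, labels=None, **kwargs)
        if labels is not None:
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = CrossEntropyLoss(reduction=reduction)  # ignore_index defaults to -100
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            if reduction == "none":
                # Fold back to (batch, seq_len-1) and average per sample; averaging over all
                # positions (including ignored ones) is a simplification for this sketch.
                loss = loss.view(shift_labels.size()).mean(dim=1)
            outputs.loss = loss
        return outputs
```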
- self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] - gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) - cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) - sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - 
-            intermediate_size=config.intermediate_size,
-            hidden_act=config.hidden_act,
-        )
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
-
-
-LLAMA_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`LlamaConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-@add_start_docstrings(
-    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-    LLAMA_START_DOCSTRING,
-)
-class LlamaModel(LlamaPreTrainedModel):
-    """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
-
-    Args:
-        config: LlamaConfig
-    """
-
-    def __init__(self, config: LlamaConfig):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-        self.gradient_checkpointing = False
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
-    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
-    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                inputs_embeds.dtype,
-                device=inputs_embeds.device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                inputs_embeds.device
-            )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        query_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-        if query_embeds is not None:
-            inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
-            batch_size, seq_length, _ = inputs_embeds.shape
-
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
-
-        # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
-            )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
-
-        hidden_states = inputs_embeds
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
-
-        for idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
-            if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                    None,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-
-class LlamaForCausalLM(LlamaPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        self.model = LlamaModel(config)
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def set_decoder(self, decoder):
-        self.model = decoder
-
-    def get_decoder(self):
-        return self.model
+class LlamaForCausalLM(LlamaForCausalLMOrig):
 
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -633,12 +22,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        query_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        reduction: Optional[str] = "mean",
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -657,13 +46,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
 
-        >>> prompt = "Hey, are you consciours? Can you talk to me?"
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
 
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -679,7 +68,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            query_embeds=query_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -687,7 +75,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         )
 
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
 
         loss = None
         if labels is not None:
@@ -695,12 +89,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
+            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
+            if reduction == "none":
+                loss = loss.view(logits.size(0), -1).mean(1)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -713,43 +109,3 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
-    def prepare_inputs_for_generation(
-        self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-                query_embeds = None
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "query_embeds": query_embeds,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
-
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
-        return reordered_past
-
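Taken together, the hunks above drop the vendored copy of the LLaMA stack (and its `query_embeds` plumbing) in favor of a thin subclass of the upstream `LlamaForCausalLMOrig` that only overrides `forward`, adding an optional `reduction` argument for the loss. A minimal sketch of how that argument might be used to obtain one loss value per sequence is below; the helper name and the tensors passed in are placeholders for illustration, not code from this repository.

```python
import torch


def per_sample_loss(model, input_ids, attention_mask, labels):
    """Hypothetical helper: per-sequence LM loss via the new `reduction` argument."""
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        reduction="none",  # token-level cross-entropy, averaged per sequence inside forward()
    )
    return outputs.loss  # expected shape: (batch_size,)
```

With the default `reduction="mean"`, the behavior matches the stock Hugging Face head, so existing training code that ignores the new argument should be unaffected.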
diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
index 6920aab..f3981b8 100644
--- a/train_configs/minigpt4_llama2_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2
 
 
diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml
index 9a6ac2d..fa2b578 100644
--- a/train_configs/minigpt4_llama2_stage2_finetune.yaml
+++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2
 
   max_txt_len: 160
diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml
index 4ec1597..be87b77 100644
--- a/train_configs/minigpt4_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_stage1_pretrain.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0
 
 
diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml
index 54cedb4..404dfd6 100644
--- a/train_configs/minigpt4_stage2_finetune.yaml
+++ b/train_configs/minigpt4_stage2_finetune.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0
 
   max_txt_len: 160
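The config hunks above only rename the registered architecture key from `mini_gpt4` to `minigpt4`; presumably, older config files that still use the underscored name would need the same one-line update to resolve to a model class. A hypothetical sketch of the kind of name-to-class lookup such an `arch` key typically feeds into (the dictionary and function below are illustrative stand-ins, not the repository's actual registry API):

```python
# Illustrative only: a minimal stand-in for a model registry keyed by the YAML `arch` field.
MODEL_REGISTRY = {
    "minigpt4": "path.to.MiniGPT4Model",  # placeholder target; the new key matches the updated configs
}


def resolve_arch(model_cfg: dict) -> str:
    """Map the `arch` value from a training config to its registered model entry."""
    arch = model_cfg["arch"]
    if arch not in MODEL_REGISTRY:
        raise KeyError(f"Unknown arch '{arch}'; the updated configs use 'minigpt4'.")
    return MODEL_REGISTRY[arch]


print(resolve_arch({"arch": "minigpt4"}))  # -> "path.to.MiniGPT4Model"
```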