add argument to switch 8bit

Repository: https://github.com/Vision-CAIR/MiniGPT-4.git (mirror)
Parent: 3e03c8327f
Commit: b76d5c5a50
@@ -5,6 +5,7 @@ model:
   freeze_qformer: True
   max_txt_len: 160
   end_sym: "###"
+  low_resource: True
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
   ckpt: '/path/to/pretrained/ckpt/'

@@ -36,11 +36,13 @@ class MiniGPT4(Blip2Base):
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
+        low_resource=False,  # use 8 bit and put vit in cpu
         end_sym='\n',
     ):
         super().__init__()

         self.tokenizer = self.init_tokenizer()
+        self.low_resource = low_resource

         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(

@@ -83,10 +85,19 @@ class MiniGPT4(Blip2Base):
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

-        self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16,
-            load_in_8bit=True, device_map="auto"
-        )
+        if self.low_resource:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto"
+            )
+        else:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+            )
+
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')

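For context, the branch above chooses between int8 and plain fp16 loading through the standard transformers API. A minimal standalone sketch of the same two calls follows; the checkpoint path is a placeholder, 8-bit loading additionally requires the bitsandbytes and accelerate packages, and newer transformers releases express the same option via quantization_config=BitsAndBytesConfig(load_in_8bit=True).

import torch
from transformers import LlamaForCausalLM

llama_model = "/path/to/llama/weights"  # placeholder checkpoint path
low_resource = True                     # mirrors the new constructor argument

if low_resource:
    # int8 weights via bitsandbytes; device_map="auto" lets accelerate
    # spread layers across the available GPU(s) and CPU.
    model = LlamaForCausalLM.from_pretrained(
        llama_model,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    # full fp16 weights; device placement is left to the caller
    model = LlamaForCausalLM.from_pretrained(
        llama_model,
        torch_dtype=torch.float16,
    )
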
@@ -107,18 +118,22 @@ class MiniGPT4(Blip2Base):
         else:
             self.prompt_list = []

-    def encode_img(self, image):
-        device = image.device
-        self.ln_vision.to("cpu")
-        self.ln_vision.float()
-        self.visual_encoder.to("cpu")
-        self.visual_encoder.float()
-        image = image.to("cpu")
-
-        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+    def vit_to_cpu(self):
+        self.ln_vision.to("cpu")
+        self.ln_vision.float()
+        self.visual_encoder.to("cpu")
+        self.visual_encoder.float()
+
+    def encode_img(self, image):
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")

         with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,

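The hunk above also explains the .float() calls in vit_to_cpu(): half-precision kernels have limited support on CPU, so an offloaded module is cast back to float32 before its forward runs there, and the resulting activations are moved back to the original device. An illustrative sketch of that offload pattern, not taken from the diff, with a small Linear standing in for the fp16 ViT:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = nn.Linear(8, 8).half().to(device)   # stand-in for the fp16 visual encoder

encoder.to("cpu").float()                     # offload + upcast, like vit_to_cpu()
image = torch.randn(1, 8)                     # inputs stay fp32 on CPU
features = encoder(image).to(device)          # move outputs back, like encode_img()
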
@@ -216,6 +231,7 @@ class MiniGPT4(Blip2Base):
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
+        low_resource = cfg.get("low_resource", False)

         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")

@@ -236,6 +252,7 @@ class MiniGPT4(Blip2Base):
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
+            low_resource=low_resource,
             end_sym=end_sym
         )

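Taken together, the change threads one boolean from the eval config (first hunk) through from_config and into the constructor. A toy sketch of that plumbing; the class and config below are illustrative stand-ins, not the real MiniGPT4 signature, which takes many more arguments:

class ToyMiniGPT4:
    def __init__(self, max_txt_len=32, low_resource=False, end_sym='\n'):
        # low_resource switches on 8-bit LLaMA loading and CPU ViT offload
        self.low_resource = low_resource

    @classmethod
    def from_config(cls, cfg):
        return cls(
            max_txt_len=cfg.get("max_txt_len", 32),
            low_resource=cfg.get("low_resource", False),
            end_sym=cfg.get("end_sym", '\n'),
        )

# the relevant slice of the eval config's model: block, as a plain dict
cfg = {"max_txt_len": 160, "end_sym": "###", "low_resource": True}
model = ToyMiniGPT4.from_config(cfg)
assert model.low_resource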