diff --git a/minigpt4/models/minigpt_v2.py b/minigpt4/models/minigpt_v2.py index a046b0b..2f6c94c 100644 --- a/minigpt4/models/minigpt_v2.py +++ b/minigpt4/models/minigpt_v2.py @@ -99,6 +99,7 @@ class MiniGPTv2(MiniGPTBase): vit_precision = cfg.get("vit_precision", "fp16") freeze_vit = cfg.get("freeze_vit", True) low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]') max_txt_len = cfg.get("max_txt_len", 300) @@ -122,6 +123,7 @@ class MiniGPTv2(MiniGPTBase): prompt_template=prompt_template, max_txt_len=max_txt_len, low_resource=low_resource, + device_8bit=device_8bit, end_sym=end_sym, lora_r=lora_r, lora_alpha=lora_alpha,