# Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
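#
# Example invocation (a sketch: the script filename and config path below are
# assumptions based on the stock MiniGPT-4 eval config and may differ in your
# checkout):
#
#   python demo_cli.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
#
# Type "stop" at either prompt to exit; type "clear" during a conversation to
# start over with a new image.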
import argparse
import os
import random
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
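# NOTE: the wildcard imports above look unused, but importing those packages
# registers the dataset builders, models, processors, runners and tasks with
# minigpt4.common.registry, which is how registry.get_model_class /
# get_processor_class resolve the names found in the config below.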
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--num-beams", type=int, default=2, help="number of beams for beam search decoding.")
    parser.add_argument("--temperature", type=float, default=0.9, help="sampling temperature used for generation.")
    # NOTE: argparse's type=bool treats any non-empty string as True, so to get
    # the Chinese prompts pass an empty string: --english ""
    parser.add_argument("--english", type=bool, default=True, help="use English prompts if true, Chinese otherwise.")
    parser.add_argument("--prompt-en", type=str, default="can you describe the current picture?", help="default English prompt asked about each image.")
    parser.add_argument("--prompt-zh", type=str, default="你能描述一下当前的图片?", help="default Chinese prompt asked about each image.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; the key-value pairs "
        "in xxx=yyy format will be merged into the config file (deprecated, "
        "use --cfg-options instead).",
    )
    args = parser.parse_args()
    return args
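# Note: setup_seeds below is defined but never called in this script; if
# deterministic behaviour is wanted, call setup_seeds(cfg) after the Config is
# built in the initialization section (assuming the config provides run.seed).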
def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True
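# Everything below runs at module import time; there is no
# `if __name__ == "__main__"` guard, so this file is meant to be executed
# directly as a script.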
# ========================================
# Model Initialization
# ========================================

print('Initializing Chat')
args = parse_args()
cfg = Config(args)
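# Resolve the model class registered under the `arch` key of the model config
# (e.g. "mini_gpt4" in the stock eval config; the exact name depends on your
# config file) and load it onto the requested GPU. `device_8bit` tells the
# model which GPU to use when its language model is loaded in 8-bit
# low-resource mode.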
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
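# Build the image preprocessor from the cc_sbu_align dataset's train-time
# vis_processor config and wrap the model and processor in the Chat helper
# that drives the interactive loop below.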
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
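# ========================================
# Interactive CLI loop
# ========================================
# Outer loop: ask for an image path or URL ("stop" exits, empty input starts a
# text-only conversation). Inner loop: one exchange per turn about that image
# ("clear" drops the current context and asks for a new image, "stop" exits).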
while True:
    if not args.english:
        # (Chinese UI) "Please enter the image path or URL (press Enter for plain text conversation): "
        image_path = input("请输入图像路径或URL(回车进入纯文本对话): ")
    else:
        image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ")

    if image_path == 'stop':
        break
    if len(image_path) > 0:
        query = args.prompt_en if args.english else args.prompt_zh
    else:
        # No image given: start a text-only conversation with a user-typed query.
        query = input("user:")

    while True:
        if query == "clear":
            break
        if query == "stop":
            sys.exit(0)
        img_list = []
        chat_state = CONV_VISION.copy()
        if len(image_path) > 0:
            chat.upload_img(image_path, chat_state, img_list)
        chat.ask(query, chat_state)
        llm_message = chat.answer(
            conv=chat_state,
            img_list=img_list,
            num_beams=args.num_beams,
            temperature=args.temperature,
            max_new_tokens=300,
            max_length=2000
        )[0]
        # chatbot[-1][1] = llm_message
        print("MiniGPT4:" + llm_message)
        query = input("user:")