mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-04 18:10:47 +00:00
Merge 1d92f757ff
into d94738a762
This commit is contained in:
commit
a7d2926618
103
cli_demo.py
Normal file
103
cli_demo.py
Normal file
@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import gradio as gr
|
||||
|
||||
from minigpt4.common.config import Config
|
||||
from minigpt4.common.dist_utils import get_rank
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.conversation.conversation import Chat, CONV_VISION
|
||||
|
||||
# imports modules for registration
|
||||
from minigpt4.datasets.builders import *
|
||||
from minigpt4.models import *
|
||||
from minigpt4.processors import *
|
||||
from minigpt4.runners import *
|
||||
from minigpt4.tasks import *
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Demo")
|
||||
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
|
||||
parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
|
||||
parser.add_argument("--num-beams", type=int, default=2, help="specify the gpu to load the model.")
|
||||
parser.add_argument("--temperature", type=int, default=0.9, help="specify the gpu to load the model.")
|
||||
parser.add_argument("--english", type=bool, default=True, help="chinese or english")
|
||||
parser.add_argument("--prompt-en", type=str, default="can you describe the current picture?", help="Can you describe the current picture?")
|
||||
parser.add_argument("--prompt-zh", type=str, default="你能描述一下当前的图片?", help="Can you describe the current picture?")
|
||||
parser.add_argument(
|
||||
"--options",
|
||||
nargs="+",
|
||||
help="override some settings in the used config, the key-value pair "
|
||||
"in xxx=yyy format will be merged into config file (deprecate), "
|
||||
"change to --cfg-options instead.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def setup_seeds(config):
|
||||
seed = config.run_cfg.seed + get_rank()
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
|
||||
|
||||
# ========================================
|
||||
# Model Initialization
|
||||
# ========================================
|
||||
|
||||
print('Initializing Chat')
|
||||
args = parse_args()
|
||||
cfg = Config(args)
|
||||
|
||||
model_config = cfg.model_cfg
|
||||
model_config.device_8bit = args.gpu_id
|
||||
model_cls = registry.get_model_class(model_config.arch)
|
||||
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
|
||||
|
||||
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
|
||||
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
||||
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
|
||||
print('Initialization Finished')
|
||||
|
||||
while True:
|
||||
if not args.english:
|
||||
image_path = input("请输入图像路径或URL(回车进入纯文本对话): ")
|
||||
else:
|
||||
image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ")
|
||||
|
||||
if image_path == 'stop':
|
||||
break
|
||||
if len(image_path) > 0:
|
||||
query = args.prompt_en if args.english else args.prompt_zh
|
||||
|
||||
|
||||
while True:
|
||||
if query == "clear":
|
||||
break
|
||||
if query == "stop":
|
||||
sys.exit(0)
|
||||
img_list = []
|
||||
chat_state = CONV_VISION.copy()
|
||||
chat.upload_img(image_path, chat_state, img_list)
|
||||
chat.ask(query, chat_state)
|
||||
llm_message = chat.answer(
|
||||
conv=chat_state,
|
||||
img_list=img_list,
|
||||
num_beams=args.num_beams,
|
||||
temperature=args.temperature,
|
||||
max_new_tokens=300,
|
||||
max_length=2000
|
||||
)[0]
|
||||
# chatbot[-1][1] = llm_message
|
||||
print("MiniGPT4:"+llm_message)
|
||||
query = input("user:")
|
Loading…
Reference in New Issue
Block a user