2023-04-16 22:04:16 +00:00
import argparse
import os
import random
import numpy as np
import torch
import torch . backends . cudnn as cudnn
import gradio as gr
2023-10-13 00:14:35 +00:00
from transformers import StoppingCriteriaList
2023-04-16 22:04:16 +00:00
from minigpt4 . common . config import Config
from minigpt4 . common . dist_utils import get_rank
from minigpt4 . common . registry import registry
2023-10-13 00:14:35 +00:00
from minigpt4 . conversation . conversation import Chat , CONV_VISION_Vicuna0 , CONV_VISION_LLama2 , StoppingCriteriaSub
2023-04-16 22:04:16 +00:00
# imports modules for registration
from minigpt4 . datasets . builders import *
from minigpt4 . models import *
from minigpt4 . processors import *
from minigpt4 . runners import *
from minigpt4 . tasks import *
def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            existing ``parse_args()`` call keeps working unchanged.

    Returns:
        argparse.Namespace with the attributes ``cfg_path`` (required),
        ``gpu_id`` (int, default 0) and ``options`` (list of strings or
        ``None``).
    """
    parser = argparse.ArgumentParser(description="Demo")
    # Option strings must start with '-' and carry no surrounding whitespace,
    # otherwise argparse treats them as positional names.
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args(argv)
def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch random generators.

    The seed is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()``. cuDNN is additionally switched to deterministic mode
    with benchmarking turned off.
    """
    base_seed = config.run_cfg.seed
    seed = base_seed + get_rank()

    # Apply the same seed to every RNG the script imports.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cudnn.deterministic = True
    cudnn.benchmark = False
# ========================================
# Model Initialization
# ========================================
2023-08-28 18:26:00 +00:00
# Maps the config's `model_type` to the conversation template to start from.
conv_dict = {
    'pretrain_vicuna0': CONV_VISION_Vicuna0,
    'pretrain_llama2': CONV_VISION_LLama2,
}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

# Build the device string once; the model, the stop-word tensors and the Chat
# wrapper must all live on the same GPU.
device = 'cuda:{}'.format(args.gpu_id)

model_config = cfg.model_cfg
# The GPU index is also recorded on the model config before the model is
# built (presumably consumed by the 8-bit loading path -- confirm in the
# model class).
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)

# Raises KeyError if `model_type` is not one of the keys of `conv_dict`.
CONV_VISION = conv_dict[model_config.model_type]

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

# Token-id sequences that end generation.
# NOTE(review): these ids are tokenizer-specific (they look like the "###"
# separator of the LLaMA vocabulary) -- confirm if the tokenizer changes.
stop_words_ids = [[835], [2277, 29937]]
stop_words_ids = [torch.tensor(ids).to(device=device) for ids in stop_words_ids]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

chat = Chat(model, vis_processor, device=device, stopping_criteria=stopping_criteria)
print('Initialization Finished')
2023-08-28 18:26:00 +00:00
2023-04-16 22:04:16 +00:00
# ========================================
# Gradio Setting
# ========================================
2023-08-28 18:26:00 +00:00
2023-04-16 22:04:16 +00:00
def gradio_reset(chat_state, img_list):
    """Clear the conversation and return the widgets to their pre-upload state.

    Returns a 6-tuple: the chatbot value (``None``), an update re-enabling the
    image widget, an update disabling the textbox, an update re-enabling the
    upload button, the (emptied) ``chat_state`` and the (emptied) ``img_list``.
    ``None`` inputs are passed through as ``None``.
    """
    if chat_state is not None:
        # Drop the message history but keep the conversation object itself.
        chat_state.messages = []
    cleared_images = None if img_list is None else []

    image_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder=' Please upload your image first ', interactive=False)
    button_update = gr.update(value=" Upload & Start Chat ", interactive=True)
    return None, image_update, textbox_update, button_update, chat_state, cleared_images
2023-08-28 18:26:00 +00:00
2023-04-16 22:04:16 +00:00
def upload_img(gr_img, text_input, chat_state):
    """Start a new conversation around the uploaded image.

    ``text_input`` is accepted but not read by this function.

    Returns a 5-tuple: image update, textbox update, upload-button update,
    the conversation state and the image list. When no image was supplied
    the incoming ``chat_state`` is returned untouched and the image list is
    ``None``.
    """
    if gr_img is not None:
        # Fresh conversation copied from the template, plus a new image list
        # that the chat helper fills in and then encodes.
        chat_state = CONV_VISION.copy()
        img_list = []
        chat.upload_img(gr_img, chat_state, img_list)
        chat.encode_img(img_list)

        image_update = gr.update(interactive=False)
        textbox_update = gr.update(interactive=True, placeholder=' Type and press Enter ')
        button_update = gr.update(value=" Start Chatting ", interactive=False)
        return image_update, textbox_update, button_update, chat_state, img_list

    # Nothing uploaded: leave the button clickable so the user can retry.
    return None, None, gr.update(interactive=True), chat_state, None
2023-08-28 18:26:00 +00:00
2023-04-16 22:04:16 +00:00
def gradio_ask(user_message, chatbot, chat_state):
    """Append the user's message to the conversation and the chat display.

    Args:
        user_message: Text submitted from the textbox.
        chatbot: Current chat history as a list of ``[user, bot]`` pairs.
        chat_state: Conversation object passed on to ``chat.ask``.

    Returns:
        A 3-tuple ``(textbox value or update, chatbot, chat_state)``. For an
        empty (or ``None``) message nothing is sent and the textbox only gets
        a reminder placeholder.
    """
    if not user_message:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state

    chat.ask(user_message, chat_state)
    # The bot half of the pair stays None until the answer step fills it in.
    chatbot = chatbot + [[user_message, None]]
    # Return a truly empty string so the textbox is cleared; leftover
    # whitespace would slip past the emptiness check on the next submit.
    return '', chatbot, chat_state
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """Generate the model's reply and write it into the last chat entry.

    Calls ``chat.answer`` with the given beam count and temperature and fixed
    limits of 300 new tokens / 2000 total length, takes the first element of
    its result and stores it as the bot half of ``chatbot[-1]``.

    Returns ``(chatbot, chat_state, img_list)``.
    """
    generation_kwargs = dict(
        conv=chat_state,
        img_list=img_list,
        num_beams=num_beams,
        temperature=temperature,
        max_new_tokens=300,
        max_length=2000,
    )
    outputs = chat.answer(**generation_kwargs)
    reply = outputs[0]

    last_turn = chatbot[-1]
    last_turn[1] = reply
    return chatbot, chat_state, img_list
2023-08-28 18:26:00 +00:00
2023-04-16 22:04:16 +00:00
# Static HTML snippets rendered at the top of the demo page via gr.Markdown.
# Attribute values must not carry padding inside their quotes, otherwise
# `align` and the link/badge URLs are not interpreted as intended.
title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
# Three badge links: project page, GitHub repository and paper PDF.
article = (
    "<p><a href='https://minigpt-4.github.io'>"
    "<img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"
    "<p><a href='https://github.com/Vision-CAIR/MiniGPT-4'>"
    "<img src='https://img.shields.io/badge/Github-Code-blue'></a></p>"
    "<p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'>"
    "<img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>\n"
)
#TODO show examples below
# Page layout: header, then a narrow control column (image, buttons,
# generation sliders) next to a wider chat column.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")

            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

        with gr.Column(scale=2):
            # Per-session state: the conversation object and the image list.
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='MiniGPT-4')
            # Disabled until an image has been uploaded (see upload_img).
            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

    # Event wiring: the output lists must match the tuples returned by the
    # corresponding callbacks, element for element.
    upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])

    # Submitting text first records the question, then generates the answer.
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)

# `launch(enable_queue=...)` is deprecated in Gradio; enable queuing through
# `queue()` instead, which is the supported equivalent.
demo.queue()
demo.launch(share=True)