This commit is contained in:
John Vandivier 2023-05-01 07:08:32 -04:00 committed by GitHub
commit 3749e8bbb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 7 additions and 15 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
**.pyc

View File

@ -57,11 +57,11 @@ cfg = Config(args)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
model = model_cls.from_config(model_config).to('cpu')
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
chat = Chat(model, vis_processor, device='cpu')
print('Initialization Finished')
# ========================================

View File

@ -3,12 +3,11 @@ channels:
- pytorch
- defaults
- anaconda
- conda-forge
dependencies:
- python=3.9
- cudatoolkit
- pip
- pytorch=1.12.1
- pytorch-mutex=1.0=cuda
- torchaudio=0.12.1
- torchvision=0.13.1
- pip:
@ -30,6 +29,7 @@ dependencies:
- kiwisolver==1.4.4
- matplotlib==3.7.0
- multidict==6.0.4
- numpy
- openai==0.27.0
- packaging==23.0
- psutil==5.9.4
@ -51,7 +51,6 @@ dependencies:
- omegaconf==2.3.0
- opencv-python==4.7.0.72
- iopath==0.1.10
- decord==0.6.0
- tenacity==8.2.2
- peft
- pycocoevalcap

View File

@ -5,20 +5,12 @@
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm
import decord
from decord import VideoReader
# import decord
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset
@ -26,7 +18,7 @@ from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.base_dataset import ConcatDataset
decord.bridge.set_bridge("torch")
# decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")