2023-05-25 04:34:00 +00:00
|
|
|
import torch
|
2023-05-26 03:44:18 +00:00
|
|
|
import torchaudio
|
2023-05-25 04:34:00 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
def load_image(image, image_processor):
    """Convert an image input into a tensor with a leading batch dimension.

    Args:
        image: One of
            * ``str`` - path of an image file; it is opened with PIL and
              converted to RGB before processing,
            * ``PIL.Image.Image`` - passed to ``image_processor`` as-is,
            * ``torch.Tensor`` - used directly; a 3-dimensional tensor gets
              a new leading dimension, any other rank is left untouched.
        image_processor: Callable applied to the PIL image. Its result must
            support ``.unsqueeze(0)``.

    Returns:
        The processed image with a leading dimension of size 1 (path / PIL
        input), the possibly-unsqueezed tensor (tensor input), or ``image``
        itself, unchanged, when it is none of the supported types.
    """
    if isinstance(image, torch.Tensor):
        # Only 3-dimensional tensors are expanded; other ranks pass through.
        return image.unsqueeze(0) if image.dim() == 3 else image

    if isinstance(image, str):  # is an image path
        # ``convert`` returns an independent image, so the file handle can be
        # released as soon as the conversion is done instead of waiting for
        # garbage collection.
        with Image.open(image) as opened:
            image = opened.convert('RGB')

    if isinstance(image, Image.Image):
        return image_processor(image).unsqueeze(0)

    # Unsupported input types are deliberately returned unchanged, matching
    # the historical behaviour of this helper.
    return image
|
2023-05-26 03:44:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
def load_audio(audio, audio_processor):
    """Run ``audio_processor`` on an audio file identified by its path.

    Args:
        audio: Path of the audio file (``str``). No other input type is
            supported yet.
        audio_processor: Callable that is invoked with the path.

    Returns:
        Whatever ``audio_processor`` returns for the given path.

    Raises:
        NotImplementedError: If ``audio`` is not a ``str``.
    """
    if not isinstance(audio, str):
        raise NotImplementedError

    # The file is decoded once up front, as before, so an unreadable file
    # still fails here rather than inside the processor.
    # NOTE(review): the decoded result was never used - the processor has
    # always received the path. Confirm whether the processor expects a path
    # or the decoded audio, then either drop this call or pass its result on.
    torchaudio.load(audio)

    # Previously the processed value was computed but never returned, so
    # callers always received None.
    return audio_processor(audio)
|