102 lines
No EOL
3.9 KiB
Python
102 lines
No EOL
3.9 KiB
Python
from PIL import Image
|
|
from io import BytesIO
|
|
import base64
|
|
|
|
import torch
|
|
from transformers import StoppingCriteria
|
|
from llava.constants import IMAGE_TOKEN_INDEX
|
|
|
|
|
|
def load_image_from_base64(image):
    """Decode a base64-encoded image payload and open it with PIL.

    Args:
        image: Base64-encoded data accepted by ``base64.b64decode``.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
|
|
|
|
|
|
def expand2square(pil_img, background_color):
    """Pad an image onto a square canvas, centering it along the shorter axis.

    Args:
        pil_img: Image exposing ``size`` (width, height) and ``mode``.
        background_color: Fill color handed to ``Image.new`` for the padding.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer dimension, with ``pil_img`` pasted in.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    # The longer axis yields a zero offset; the shorter axis is centered.
    paste_at = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, paste_at)
    return canvas
|
|
|
|
|
|
def process_images(images, image_processor, model_cfg):
    """Preprocess a batch of images according to the model's aspect-ratio mode.

    When ``model_cfg.image_aspect_ratio`` is ``'pad'``, every image is first
    padded to a square (filled with the processor's mean color scaled to
    0-255) and then preprocessed individually. For any other value, including
    a missing attribute, the whole batch is handed to ``image_processor``
    directly.

    Args:
        images: Iterable of images to preprocess.
        image_processor: Processor that is callable on a batch and exposes
            ``preprocess`` and ``image_mean``.
        model_cfg: Config object; only ``image_aspect_ratio`` is read.

    Returns:
        In non-pad mode, the processor's ``'pixel_values'`` entry. In pad
        mode, a tensor stacked along dim 0 when all processed images share one
        shape, otherwise a list of per-image tensors. An empty input in pad
        mode yields an empty list.
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    if image_aspect_ratio != 'pad':
        return image_processor(images, return_tensors='pt')['pixel_values']

    images = list(images)
    if not images:
        # torch.stack rejects an empty sequence, so there is nothing to stack.
        return []

    # The fill color is the same for every image; compute it once.
    background_color = tuple(int(x * 255) for x in image_processor.image_mean)
    new_images = [
        image_processor.preprocess(
            expand2square(image, background_color), return_tensors='pt'
        )['pixel_values'][0]
        for image in images
    ]

    # Only stack when shapes agree; otherwise the caller gets the raw list.
    first_shape = new_images[0].shape
    if all(tensor.shape == first_shape for tensor in new_images):
        return torch.stack(new_images, dim=0)
    return new_images
|
|
|
|
|
|
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt, replacing each ``<image>`` marker with an image token id.

    The prompt is split on the literal ``<image>``, each piece is tokenized
    separately, and ``image_token_index`` is inserted between consecutive
    pieces.

    Args:
        prompt: Text that may contain ``<image>`` markers.
        tokenizer: Callable returning an object with ``input_ids``; must also
            expose ``bos_token_id``.
        image_token_index: Id inserted in place of every ``<image>`` marker.
        return_tensors: ``None`` for a plain list, ``'pt'`` for a long tensor.

    Returns:
        The token ids as a list, or as a ``torch.long`` tensor for ``'pt'``.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    chunks = [tokenizer(piece).input_ids for piece in prompt.split('<image>')]

    input_ids = []
    skip = 0
    # If the first piece starts with BOS, keep that single BOS up front and
    # drop the first token of every piece (presumably each piece was given its
    # own BOS by the tokenizer).
    if len(chunks) > 0 and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
        skip = 1
        input_ids.append(chunks[0][0])

    for position, chunk in enumerate(chunks):
        if position > 0:
            # One image token stands in for each '<image>' marker.
            input_ids.append(image_token_index)
        input_ids.extend(chunk[skip:])

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
|
|
|
def get_model_name_from_path(model_path):
    """Derive a model name from a filesystem-style model path.

    Leading and trailing slashes are ignored. When the last path component
    starts with ``checkpoint-`` and a parent component exists, the result is
    ``"<parent>_<checkpoint>"``; otherwise it is the last component alone.

    Args:
        model_path: Slash-separated path to a model or checkpoint directory.

    Returns:
        The derived model name.
    """
    parts = model_path.strip("/").split("/")
    last = parts[-1]
    # A bare "checkpoint-N" has no parent to prefix; indexing parts[-2] would
    # raise IndexError, so fall back to the component itself.
    if len(parts) >= 2 and last.startswith('checkpoint-'):
        return parts[-2] + "_" + last
    return last
|
|
|
|
|
|
|
|
|
|
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when any keyword appears at the end of the output.

    Two checks are made on every call: an exact token-id match of each keyword
    against the tail of the sequence, and a substring match of each keyword in
    the decoded text of the most recent tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Tokenize the keywords and record the initial sequence length.

        Args:
            keywords: Iterable of keyword strings that should stop generation.
            tokenizer: Callable returning an object with ``input_ids``; must
                also expose ``bos_token_id`` and ``batch_decode``.
            input_ids: Tensor whose second dimension is recorded as the
                starting length (presumably the prompt ids).
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Strip a leading BOS so the ids can match in the middle of a sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop.

        Args:
            output_ids: Token ids of shape ``(1, seq_len)``; only batch size 1
                is supported.
            scores: Unused; accepted for interface compatibility.

        Returns:
            True if a keyword's token ids equal the tail of ``output_ids`` or
            a keyword occurs in the decoded recent tokens, else False.
        """
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            tail = output_ids[0, -keyword_id.shape[0]:]
            # torch.equal also compares sizes, so a sequence shorter than the
            # keyword is a clean non-match instead of a shape error or an
            # accidental broadcast match.
            if torch.equal(tail, keyword_id):
                return True
        # NOTE(review): when offset is 0 the slice `-0:` spans the whole
        # sequence, so the full input is decoded - confirm this is intended.
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)