# 2024-02-23 17:54:47 +01:00
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import subprocess
import time
from omegaconf import OmegaConf
from PIL import Image
from cog import BasePredictor , Input , Path
from SUPIR . util import (
create_SUPIR_model ,
PIL2Tensor ,
Tensor2PIL ,
convert_dtype ,
)
from llava . llava_agent import LLavaAgent
import CKPT_PTH
# Remote locations of the weight files that Predictor.setup() fetches when the
# corresponding local path is missing.
SUPIR_v0Q_URL = "https://weights.replicate.delivery/default/SUPIR-v0Q.ckpt"
SUPIR_v0F_URL = "https://weights.replicate.delivery/default/SUPIR-v0F.ckpt"
LLAVA_URL = "https://weights.replicate.delivery/default/llava-v1.5-13b.tar"
LLAVA_CLIP_URL = (
    "https://weights.replicate.delivery/default/clip-vit-large-patch14-336.tar"
)
SDXL_URL = "https://weights.replicate.delivery/default/stable-diffusion-xl-base-1.0/sd_xl_base_1.0_0.9vae.safetensors"
SDXL_CLIP1_URL = "https://weights.replicate.delivery/default/clip-vit-large-patch14.tar"
SDXL_CLIP2_URL = (
    "https://weights.replicate.delivery/default/CLIP-ViT-bigG-14-laion2B-39B-b160k.tar"
)

# Local locations of the weights. Three paths are taken from CKPT_PTH; the
# rest are built under MODEL_CACHE.
MODEL_CACHE = "/opt/data/private/AIGC_pretrain/"  # Follow the default in CKPT_PTH.py
LLAVA_CLIP_PATH = CKPT_PTH.LLAVA_CLIP_PATH
LLAVA_MODEL_PATH = CKPT_PTH.LLAVA_MODEL_PATH
SDXL_CLIP1_PATH = CKPT_PTH.SDXL_CLIP1_PATH
# NOTE: MODEL_CACHE ends with "/", so these paths contain "//"; the strings
# are kept exactly as written.
SDXL_CLIP2_CACHE = f"{MODEL_CACHE}/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k"
SDXL_CKPT = f"{MODEL_CACHE}/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors"
SUPIR_CKPT_F = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0F.ckpt"
SUPIR_CKPT_Q = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0Q.ckpt"
def download_weights(url, dest, extract=True):
    """Fetch ``url`` into ``dest`` by running the ``pget`` command-line tool.

    Prints the URL, the destination and the elapsed wall-clock time.

    Args:
        url: Remote location of the file or archive to fetch.
        dest: Local path handed to ``pget`` as the download target.
        extract: When true, ``-x`` is added to the ``pget`` command line
            (presumably asking it to unpack the archive; callers in this
            file pass ``False`` for single checkpoint files).

    Raises:
        subprocess.CalledProcessError: If ``pget`` exits with a non-zero status.
    """
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)

    # Argument list (no shell): [pget, optional -x, url, dest].
    cmd = ["pget"]
    if extract:
        cmd.append("-x")
    cmd += [url, dest]

    subprocess.check_call(cmd, close_fds=False)
    print("downloading took: ", time.time() - start)
class Predictor(BasePredictor):
    """Cog predictor that runs a SUPIR model on a low-quality input image.

    ``setup()`` downloads any missing weights and loads the SUPIR models and
    the LLaVA agent; ``predict()`` processes one image and returns the path of
    the saved result.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        for model_dir in (
            MODEL_CACHE,
            f"{MODEL_CACHE}/SUPIR_cache",
            f"{MODEL_CACHE}/SDXL_cache",
        ):
            os.makedirs(model_dir, exist_ok=True)

        # (local path, source URL, pass -x to pget?) for every weight file.
        # Each entry is fetched only when its own local path is missing.
        required_weights = [
            (SUPIR_CKPT_Q, SUPIR_v0Q_URL, False),
            (SUPIR_CKPT_F, SUPIR_v0F_URL, False),
            (LLAVA_MODEL_PATH, LLAVA_URL, True),
            (LLAVA_CLIP_PATH, LLAVA_CLIP_URL, True),
            (SDXL_CLIP1_PATH, SDXL_CLIP1_URL, True),
            (SDXL_CKPT, SDXL_URL, False),
            # The second SDXL CLIP archive has its own cache location; it must
            # not be guarded by (or written to) the SDXL checkpoint path.
            (SDXL_CLIP2_CACHE, SDXL_CLIP2_URL, True),
        ]
        for local_path, url, extract in required_weights:
            if not os.path.exists(local_path):
                download_weights(url, local_path, extract=extract)

        self.supir_device = "cuda:0"
        self.llava_device = "cuda:0"
        ae_dtype = "bf16"  # Inference data type of AutoEncoder
        diff_dtype = "bf16"  # Inference data type of Diffusion

        # Load both variants: predict() offers "SUPIR-v0Q" (its default) and
        # "SUPIR-v0F" and looks them up under the keys "Q" and "F".
        self.models = {}
        for sign in ("Q", "F"):
            supir = create_SUPIR_model(
                "options/SUPIR_v0.yaml", SUPIR_sign=sign
            ).to(self.supir_device)
            supir.ae_dtype = convert_dtype(ae_dtype)
            supir.model.dtype = convert_dtype(diff_dtype)
            self.models[sign] = supir

        # load LLaVA
        self.llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=self.llava_device)

    def predict(
        self,
        model_name: str = Input(
            description="Choose a model. SUPIR-v0Q is the default training settings with paper. SUPIR-v0F is high generalization and high image quality in most cases. Training with light degradation settings. Stage1 encoder of SUPIR-v0F remains more details when facing light degradations.",
            choices=["SUPIR-v0Q", "SUPIR-v0F"],
            default="SUPIR-v0Q",
        ),
        image: Path = Input(description="Low quality input image."),
        upscale: int = Input(
            description="Upsampling ratio of given inputs.", default=1
        ),
        min_size: float = Input(
            description="Minimum resolution of output images.", default=1024
        ),
        edm_steps: int = Input(
            description="Number of steps for EDM Sampling Schedule.",
            ge=1,
            le=500,
            default=50,
        ),
        use_llava: bool = Input(
            description="Use LLaVA model to get captions.", default=True
        ),
        a_prompt: str = Input(
            description="Additive positive prompt for the inputs.",
            default="Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
        ),
        n_prompt: str = Input(
            description="Negative prompt for the inputs.",
            default="painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
        ),
        color_fix_type: str = Input(
            description="Color Fixing Type..",
            choices=["None", "AdaIn", "Wavelet"],
            default="Wavelet",
        ),
        s_stage1: int = Input(
            description="Control Strength of Stage1 (negative means invalid).",
            default=-1,
        ),
        s_churn: float = Input(
            description="Original churn hy-param of EDM.", default=5
        ),
        s_noise: float = Input(
            description="Original noise hy-param of EDM.", default=1.003
        ),
        s_cfg: float = Input(
            description="Classifier-free guidance scale for prompts.",
            ge=1,
            le=20,
            default=7.5,
        ),
        s_stage2: float = Input(description="Control Strength of Stage2.", default=1.0),
        linear_CFG: bool = Input(
            description="Linearly (with sigma) increase CFG from 'spt_linear_CFG' to s_cfg.",
            default=False,
        ),
        linear_s_stage2: bool = Input(
            description="Linearly (with sigma) increase s_stage2 from 'spt_linear_s_stage2' to s_stage2.",
            default=False,
        ),
        spt_linear_CFG: float = Input(
            description="Start point of linearly increasing CFG.", default=1.0
        ),
        spt_linear_s_stage2: float = Input(
            description="Start point of linearly increasing s_stage2.", default=0.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model

        The input image is converted to a tensor, pre-denoised, optionally
        captioned by LLaVA, and then passed through the selected model's
        sampler. Each parameter is described by its ``Input(description=...)``.

        Returns:
            Path to the result, which is always written to ``/tmp/out.png``.
        """
        if seed is None:
            # Two random bytes, so the generated seed lies in 0..65535.
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        model = self.models["Q"] if model_name == "SUPIR-v0Q" else self.models["F"]

        lq_img = Image.open(str(image))
        # NOTE(review): the keyword is spelled `upsacle`; presumably that is
        # the actual parameter name of SUPIR.util.PIL2Tensor -- confirm before
        # renaming it.
        lq_img, h0, w0 = PIL2Tensor(lq_img, upsacle=upscale, min_size=min_size)
        # Add a leading axis and keep only the first three entries of the
        # second axis (presumably dropping an alpha channel -- TODO confirm).
        lq_img = lq_img.unsqueeze(0).to(self.supir_device)[:, :3, :, :]

        # step 1: Pre-denoise for LLaVA
        clean_imgs = model.batchify_denoise(lq_img)
        clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0)

        # step 2: LLaVA
        captions = [""]  # empty caption when LLaVA is disabled
        if use_llava:
            captions = self.llava_agent.gen_image_caption([clean_PIL_img])
            print(f"Captions from LLaVA: {captions}")

        # step 3: Diffusion Process
        samples = model.batchify_sample(
            lq_img,
            captions,
            num_steps=edm_steps,
            restoration_scale=s_stage1,
            s_churn=s_churn,
            s_noise=s_noise,
            cfg_scale=s_cfg,
            control_scale=s_stage2,
            seed=seed,
            num_samples=1,
            p_p=a_prompt,
            n_p=n_prompt,
            color_fix_type=color_fix_type,
            use_linear_CFG=linear_CFG,
            use_linear_control_scale=linear_s_stage2,
            cfg_scale_start=spt_linear_CFG,
            control_scale_start=spt_linear_s_stage2,
        )

        out_path = "/tmp/out.png"
        Tensor2PIL(samples[0], h0, w0).save(out_path)
        return Path(out_path)