From 8ac7917c1508dd068abd26a491b21205894555b9 Mon Sep 17 00:00:00 2001
From: chenxwh
Date: Fri, 23 Feb 2024 16:54:47 +0000
Subject: [PATCH] replicate demo

---
 README.md  |   2 +-
 cog.yaml   |  40 ++++++++++
 predict.py | 214 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 255 insertions(+), 1 deletion(-)
 create mode 100644 cog.yaml
 create mode 100644 predict.py

diff --git a/README.md b/README.md
index ae5d7fc..51d5942 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 ## Scaling Up to Excellence: Practicing Model Scaling for Photo-Realistic Image Restoration In the Wild

-> [[Paper](https://arxiv.org/abs/2401.13627)]   [[Project Page](http://supir.xpixel.group/)]   [Online Demo (Coming soon)]
+> [[Paper](https://arxiv.org/abs/2401.13627)]   [[Project Page](http://supir.xpixel.group/)]   [[Replicate Demo](https://replicate.com/cjwbw/supir)]
> Fanghua, Yu, [Jinjin Gu](https://www.jasongt.com/), Zheyuan Li, Jinfan Hu, Xiangtao Kong, [Xintao Wang](https://xinntao.github.io/), [Jingwen He](https://scholar.google.com.hk/citations?user=GUxrycUAAAAJ), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ)
> Shenzhen Institute of Advanced Technology; Shanghai AI Laboratory; University of Sydney; The Hong Kong Polytechnic University; ARC Lab, Tencent PCG; The Chinese University of Hong Kong
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..4a4486f
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,40 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  gpu: true
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_version: "3.11"
+  python_packages:
+    - sentencepiece==0.1.98
+    - tokenizers==0.13.3
+    - torch>=2.1.0
+    - torchvision>=0.16.0
+    - uvicorn==0.21.1
+    - transformers==4.28.1
+    - accelerate==0.18.0
+    - scikit-learn==1.2.2
+    - einops==0.7.0
+    - einops-exts==0.0.4
+    - timm==0.9.8
+    - openai-clip==1.0.1
+    - kornia==0.6.9
+    - matplotlib==3.7.1
+    - ninja==1.11.1
+    - omegaconf==2.3.0
+    - open-clip-torch==2.17.1
+    - opencv-python==4.7.0.72
+    - pandas==2.0.1
+    - Pillow==9.4.0
+    - pytorch-lightning==2.1.2
+    - PyYAML==6.0
+    - scipy==1.12.0
+    - tqdm==4.65.0
+    - triton==2.1.0
+    - webdataset==0.2.48
+    - xformers>=0.0.20
+  run:
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
+predict: "predict.py:Predictor"
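The `run` step at the end of cog.yaml installs pget, the downloader that predict.py (next file) shells out to when fetching weights in setup(). For reference, this is the shape of the command that download_weights() assembles; the URL and destination here are illustrative, not ones used by this demo:

    pget -x https://example.com/weights.tar /dst/dir

pget's -x flag extracts a tar archive after downloading; download_weights() passes extract=False (omitting -x) for single-file checkpoints such as the SUPIR .ckpt and SDXL .safetensors weights.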
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..51e166f
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,214 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+import os
+import subprocess
+import time
+from omegaconf import OmegaConf
+from PIL import Image
+from cog import BasePredictor, Input, Path
+
+from SUPIR.util import (
+    create_SUPIR_model,
+    PIL2Tensor,
+    Tensor2PIL,
+    convert_dtype,
+)
+from llava.llava_agent import LLavaAgent
+import CKPT_PTH
+
+SUPIR_v0Q_URL = "https://weights.replicate.delivery/default/SUPIR-v0Q.ckpt"
+SUPIR_v0F_URL = "https://weights.replicate.delivery/default/SUPIR-v0F.ckpt"
+LLAVA_URL = "https://weights.replicate.delivery/default/llava-v1.5-13b.tar"
+LLAVA_CLIP_URL = (
+    "https://weights.replicate.delivery/default/clip-vit-large-patch14-336.tar"
+)
+SDXL_URL = "https://weights.replicate.delivery/default/stable-diffusion-xl-base-1.0/sd_xl_base_1.0_0.9vae.safetensors"
+SDXL_CLIP1_URL = "https://weights.replicate.delivery/default/clip-vit-large-patch14.tar"
+SDXL_CLIP2_URL = (
+    "https://weights.replicate.delivery/default/CLIP-ViT-bigG-14-laion2B-39B-b160k.tar"
+)
+
+MODEL_CACHE = "/opt/data/private/AIGC_pretrain"  # Follow the default in CKPT_PTH.py
+LLAVA_CLIP_PATH = CKPT_PTH.LLAVA_CLIP_PATH
+LLAVA_MODEL_PATH = CKPT_PTH.LLAVA_MODEL_PATH
+SDXL_CLIP1_PATH = CKPT_PTH.SDXL_CLIP1_PATH
+SDXL_CLIP2_CACHE = f"{MODEL_CACHE}/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k"
+SDXL_CKPT = f"{MODEL_CACHE}/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors"
+SUPIR_CKPT_F = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0F.ckpt"
+SUPIR_CKPT_Q = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0Q.ckpt"
+
+
+def download_weights(url, dest, extract=True):
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    args = ["pget"]
+    if extract:
+        args.append("-x")
+    subprocess.check_call(args + [url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+        for model_dir in [
+            MODEL_CACHE,
+            f"{MODEL_CACHE}/SUPIR_cache",
+            f"{MODEL_CACHE}/SDXL_cache",
+        ]:
+            if not os.path.exists(model_dir):
+                os.makedirs(model_dir)
+        if not os.path.exists(SUPIR_CKPT_Q):
+            download_weights(SUPIR_v0Q_URL, SUPIR_CKPT_Q, extract=False)
+        if not os.path.exists(SUPIR_CKPT_F):
+            download_weights(SUPIR_v0F_URL, SUPIR_CKPT_F, extract=False)
+        if not os.path.exists(LLAVA_MODEL_PATH):
+            download_weights(LLAVA_URL, LLAVA_MODEL_PATH)
+        if not os.path.exists(LLAVA_CLIP_PATH):
+            download_weights(LLAVA_CLIP_URL, LLAVA_CLIP_PATH)
+        if not os.path.exists(SDXL_CLIP1_PATH):
+            download_weights(SDXL_CLIP1_URL, SDXL_CLIP1_PATH)
+        if not os.path.exists(SDXL_CKPT):
+            download_weights(SDXL_URL, SDXL_CKPT, extract=False)
+        if not os.path.exists(SDXL_CLIP2_CACHE):
+            download_weights(SDXL_CLIP2_URL, SDXL_CLIP2_CACHE)
+
+        self.supir_device = "cuda:0"
+        self.llava_device = "cuda:0"
+        ae_dtype = "bf16"  # Inference data type of AutoEncoder
+        diff_dtype = "bf16"  # Inference data type of Diffusion
+
+        self.models = {
+            k: create_SUPIR_model("options/SUPIR_v0.yaml", SUPIR_sign=k).to(
+                self.supir_device
+            )
+            for k in ["Q", "F"]
+        }
+
+        for k in self.models:
+            self.models[k].ae_dtype = convert_dtype(ae_dtype)
+            self.models[k].model.dtype = convert_dtype(diff_dtype)
+
+        # load LLaVA
+        self.llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=self.llava_device)
+
+    def predict(
+        self,
+        model_name: str = Input(
+            description="Choose a model. SUPIR-v0Q uses the default training settings from the paper, with high generalization and high image quality in most cases. SUPIR-v0F is trained with light degradation settings; its Stage1 encoder keeps more detail when facing light degradations.",
+            choices=["SUPIR-v0Q", "SUPIR-v0F"],
+            default="SUPIR-v0Q",
+        ),
+        image: Path = Input(description="Low quality input image."),
+        upscale: int = Input(
+            description="Upsampling ratio of given inputs.", default=1
+        ),
+        min_size: float = Input(
+            description="Minimum resolution of output images.", default=1024
+        ),
+        edm_steps: int = Input(
+            description="Number of steps for EDM Sampling Schedule.",
+            ge=1,
+            le=500,
+            default=50,
+        ),
+        use_llava: bool = Input(
+            description="Use LLaVA model to get captions.", default=True
+        ),
+        a_prompt: str = Input(
+            description="Additive positive prompt for the inputs.",
+            default="Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
+        ),
+        n_prompt: str = Input(
+            description="Negative prompt for the inputs.",
+            default="painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
+        ),
+        color_fix_type: str = Input(
+            description="Color fixing type.",
+            choices=["None", "AdaIn", "Wavelet"],
+            default="Wavelet",
+        ),
+        s_stage1: int = Input(
+            description="Control strength of Stage1 (negative means invalid).",
+            default=-1,
+        ),
+        s_churn: float = Input(
+            description="Original churn hyper-parameter of EDM.", default=5
+        ),
+        s_noise: float = Input(
+            description="Original noise hyper-parameter of EDM.", default=1.003
+        ),
+        s_cfg: float = Input(
+            description="Classifier-free guidance scale for prompts.",
+            ge=1,
+            le=20,
+            default=7.5,
+        ),
+        s_stage2: float = Input(description="Control strength of Stage2.", default=1.0),
+        linear_CFG: bool = Input(
+            description="Linearly (with sigma) increase CFG from 'spt_linear_CFG' to s_cfg.",
+            default=False,
+        ),
+        linear_s_stage2: bool = Input(
+            description="Linearly (with sigma) increase s_stage2 from 'spt_linear_s_stage2' to s_stage2.",
+            default=False,
+        ),
s_stage2.", + default=False, + ), + spt_linear_CFG: float = Input( + description="Start point of linearly increasing CFG.", default=1.0 + ), + spt_linear_s_stage2: float = Input( + description="Start point of linearly increasing s_stage2.", default=0.0 + ), + seed: int = Input( + description="Random seed. Leave blank to randomize the seed", default=None + ), + ) -> Path: + """Run a single prediction on the model""" + + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + + model = self.models["Q"] # if model_name == "SUPIR-v0Q" else self.models["F"] + + lq_img = Image.open(str(image)) + lq_img, h0, w0 = PIL2Tensor(lq_img, upsacle=upscale, min_size=min_size) + lq_img = lq_img.unsqueeze(0).to(self.supir_device)[:, :3, :, :] + + # step 1: Pre-denoise for LLaVA) + clean_imgs = model.batchify_denoise(lq_img) + clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0) + + # step 2: LLaVA + captions = [""] + if use_llava: + captions = self.llava_agent.gen_image_caption([clean_PIL_img]) + print(f"Captions from LLaVA: {captions}") + + # step 3: Diffusion Process + samples = model.batchify_sample( + lq_img, + captions, + num_steps=edm_steps, + restoration_scale=s_stage1, + s_churn=s_churn, + s_noise=s_noise, + cfg_scale=s_cfg, + control_scale=s_stage2, + seed=seed, + num_samples=1, + p_p=a_prompt, + n_p=n_prompt, + color_fix_type=color_fix_type, + use_linear_CFG=linear_CFG, + use_linear_control_scale=linear_s_stage2, + cfg_scale_start=spt_linear_CFG, + control_scale_start=spt_linear_s_stage2, + ) + + out_path = "/tmp/out.png" + Tensor2PIL(samples[0], h0, w0).save(out_path) + return Path(out_path)