diff --git a/README.md b/README.md
index ae5d7fc..51d5942 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
## Scaling Up to Excellence: Practicing Model Scaling for Photo-Realistic Image Restoration In the Wild
-> [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [Online Demo (Coming soon)]
+> [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[Replicate Demo](https://replicate.com/cjwbw/supir)]
> Fanghua Yu, [Jinjin Gu](https://www.jasongt.com/), Zheyuan Li, Jinfan Hu, Xiangtao Kong, [Xintao Wang](https://xinntao.github.io/), [Jingwen He](https://scholar.google.com.hk/citations?user=GUxrycUAAAAJ), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ)
> Shenzhen Institute of Advanced Technology; Shanghai AI Laboratory; University of Sydney; The Hong Kong Polytechnic University; ARC Lab, Tencent PCG; The Chinese University of Hong Kong
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..4a4486f
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,41 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+ gpu: true
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_version: "3.11"
+ python_packages:
+ - sentencepiece==0.1.98
+ - tokenizers==0.13.3
+ - torch>=2.1.0
+ - torchvision>=0.16.0
+ - uvicorn==0.21.1
+ - transformers==4.28.1
+ - accelerate==0.18.0
+ - scikit-learn==1.2.2
+ - einops==0.7.0
+ - einops-exts==0.0.4
+ - timm==0.9.8
+ - openai-clip==1.0.1
+ - kornia==0.6.9
+ - matplotlib==3.7.1
+ - ninja==1.11.1
+ - omegaconf==2.3.0
+ - open-clip-torch==2.17.1
+ - opencv-python==4.7.0.72
+ - pandas==2.0.1
+ - Pillow==9.4.0
+ - pytorch-lightning==2.1.2
+ - PyYAML==6.0
+ - scipy==1.12.0
+ - tqdm==4.65.0
+ - triton==2.1.0
+ - webdataset==0.2.48
+ - xformers>=0.0.20
+ run:
+ - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
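+  # pget (installed above) is used by predict.py's download_weights() helper to
+  # fetch model checkpoints at setup time; its -x flag also extracts tar archives.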
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..51e166f
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,213 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
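+#
+# Typical local invocation (assuming Cog and an NVIDIA GPU are available):
+#   cog predict -i image=@low-quality.png -i model_name="SUPIR-v0Q"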
+
+import os
+import subprocess
+import time
+from omegaconf import OmegaConf
+from PIL import Image
+from cog import BasePredictor, Input, Path
+
+from SUPIR.util import (
+ create_SUPIR_model,
+ PIL2Tensor,
+ Tensor2PIL,
+ convert_dtype,
+)
+from llava.llava_agent import LLavaAgent
+import CKPT_PTH
+
+SUPIR_v0Q_URL = "https://weights.replicate.delivery/default/SUPIR-v0Q.ckpt"
+SUPIR_v0F_URL = "https://weights.replicate.delivery/default/SUPIR-v0F.ckpt"
+LLAVA_URL = "https://weights.replicate.delivery/default/llava-v1.5-13b.tar"
+LLAVA_CLIP_URL = (
+ "https://weights.replicate.delivery/default/clip-vit-large-patch14-336.tar"
+)
+SDXL_URL = "https://weights.replicate.delivery/default/stable-diffusion-xl-base-1.0/sd_xl_base_1.0_0.9vae.safetensors"
+SDXL_CLIP1_URL = "https://weights.replicate.delivery/default/clip-vit-large-patch14.tar"
+SDXL_CLIP2_URL = (
+ "https://weights.replicate.delivery/default/CLIP-ViT-bigG-14-laion2B-39B-b160k.tar"
+)
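+
+# The .tar URLs above are unpacked on download (pget -x); the single-file
+# .ckpt/.safetensors checkpoints are fetched as-is (extract=False below).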
+
+MODEL_CACHE = "/opt/data/private/AIGC_pretrain/" # Follow the default in CKPT_PTH.py
+LLAVA_CLIP_PATH = CKPT_PTH.LLAVA_CLIP_PATH
+LLAVA_MODEL_PATH = CKPT_PTH.LLAVA_MODEL_PATH
+SDXL_CLIP1_PATH = CKPT_PTH.SDXL_CLIP1_PATH
+SDXL_CLIP2_CACHE = f"{MODEL_CACHE}/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k"
+SDXL_CKPT = f"{MODEL_CACHE}/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors"
+SUPIR_CKPT_F = f"{MODEL_CACHE}/SUPIR_cache/SUsPIR-v0F.ckpt"
+SUPIR_CKPT_Q = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0Q.ckpt"
+
+
+def download_weights(url, dest, extract=True):
+    """Fetch weights with pget; extract=True unpacks a tar archive into dest."""
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    args = ["pget"]
+    if extract:
+        args.append("-x")  # -x extracts the downloaded tar archive
+    subprocess.check_call(args + [url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
+
+class Predictor(BasePredictor):
+ def setup(self) -> None:
+ """Load the model into memory to make running multiple predictions efficient"""
+ for model_dir in [
+ MODEL_CACHE,
+ f"{MODEL_CACHE}/SUPIR_cache",
+ f"{MODEL_CACHE}/SDXL_cache",
+ ]:
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+ if not os.path.exists(SUPIR_CKPT_Q):
+ download_weights(SUPIR_v0Q_URL, SUPIR_CKPT_Q, extract=False)
+ if not os.path.exists(SUPIR_CKPT_F):
+ download_weights(SUPIR_v0F_URL, SUPIR_CKPT_F, extract=False)
+ if not os.path.exists(LLAVA_MODEL_PATH):
+ download_weights(LLAVA_URL, LLAVA_MODEL_PATH)
+ if not os.path.exists(LLAVA_CLIP_PATH):
+ download_weights(LLAVA_CLIP_URL, LLAVA_CLIP_PATH)
+ if not os.path.exists(SDXL_CLIP1_PATH):
+ download_weights(SDXL_CLIP1_URL, SDXL_CLIP1_PATH)
+ if not os.path.exists(SDXL_CKPT):
+ download_weights(SDXL_URL, SDXL_CKPT, extract=False)
+        if not os.path.exists(SDXL_CLIP2_CACHE):
+            download_weights(SDXL_CLIP2_URL, SDXL_CLIP2_CACHE)
+
+ self.supir_device = "cuda:0"
+ self.llava_device = "cuda:0"
+ ae_dtype = "bf16" # Inference data type of AutoEncoder
+ diff_dtype = "bf16" # Inference data type of Diffusion
+
+        # Load both checkpoints up front so predict() can switch between them
+        # without reloading.
+        self.models = {
+            k: create_SUPIR_model("options/SUPIR_v0.yaml", SUPIR_sign=k).to(
+                self.supir_device
+            )
+            for k in ["Q", "F"]
+        }
+
+        for k in ["Q", "F"]:
+            self.models[k].ae_dtype = convert_dtype(ae_dtype)
+            self.models[k].model.dtype = convert_dtype(diff_dtype)
+
+ # load LLaVA
+ self.llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=self.llava_device)
+
+ def predict(
+ self,
+        model_name: str = Input(
+            description="Choose a model. SUPIR-v0Q uses the paper's default training settings and gives high generalization and high image quality in most cases. SUPIR-v0F is trained with light-degradation settings; its Stage1 encoder preserves more detail on lightly degraded inputs.",
+            choices=["SUPIR-v0Q", "SUPIR-v0F"],
+            default="SUPIR-v0Q",
+        ),
+ image: Path = Input(description="Low quality input image."),
+        upscale: int = Input(
+            description="Upscaling factor applied to the input image.", default=1
+        ),
+ min_size: float = Input(
+ description="Minimum resolution of output images.", default=1024
+ ),
+ edm_steps: int = Input(
+ description="Number of steps for EDM Sampling Schedule.",
+ ge=1,
+ le=500,
+ default=50,
+ ),
+ use_llava: bool = Input(
+ description="Use LLaVA model to get captions.", default=True
+ ),
+ a_prompt: str = Input(
+ description="Additive positive prompt for the inputs.",
+ default="Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
+ ),
+ n_prompt: str = Input(
+ description="Negative prompt for the inputs.",
+ default="painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
+ ),
+        color_fix_type: str = Input(
+            description="Color fixing type.",
+            choices=["None", "AdaIn", "Wavelet"],
+            default="Wavelet",
+        ),
+ s_stage1: int = Input(
+ description="Control Strength of Stage1 (negative means invalid).",
+ default=-1,
+ ),
+        s_churn: float = Input(
+            description="Original churn hyper-parameter of EDM.", default=5
+        ),
+        s_noise: float = Input(
+            description="Original noise hyper-parameter of EDM.", default=1.003
+        ),
+ s_cfg: float = Input(
+ description=" Classifier-free guidance scale for prompts.",
+ ge=1,
+ le=20,
+ default=7.5,
+ ),
+ s_stage2: float = Input(description="Control Strength of Stage2.", default=1.0),
+ linear_CFG: bool = Input(
+ description="Linearly (with sigma) increase CFG from 'spt_linear_CFG' to s_cfg.",
+ default=False,
+ ),
+ linear_s_stage2: bool = Input(
+ description="Linearly (with sigma) increase s_stage2 from 'spt_linear_s_stage2' to s_stage2.",
+ default=False,
+ ),
+ spt_linear_CFG: float = Input(
+ description="Start point of linearly increasing CFG.", default=1.0
+ ),
+ spt_linear_s_stage2: float = Input(
+ description="Start point of linearly increasing s_stage2.", default=0.0
+ ),
+        seed: int = Input(
+            description="Random seed. Leave blank to randomize the seed.", default=None
+        ),
+ ) -> Path:
+ """Run a single prediction on the model"""
+
+ if seed is None:
+ seed = int.from_bytes(os.urandom(2), "big")
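+            # two random bytes give a seed in [0, 65535]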
+ print(f"Using seed: {seed}")
+
+ model = self.models["Q"] # if model_name == "SUPIR-v0Q" else self.models["F"]
+
+ lq_img = Image.open(str(image))
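+        # n.b. "upsacle" is the actual keyword spelling in SUPIR.util.PIL2Tensor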
+ lq_img, h0, w0 = PIL2Tensor(lq_img, upsacle=upscale, min_size=min_size)
+ lq_img = lq_img.unsqueeze(0).to(self.supir_device)[:, :3, :, :]
+
+        # step 1: pre-denoise for LLaVA
+ clean_imgs = model.batchify_denoise(lq_img)
+ clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0)
+
+ # step 2: LLaVA
+ captions = [""]
+ if use_llava:
+ captions = self.llava_agent.gen_image_caption([clean_PIL_img])
+ print(f"Captions from LLaVA: {captions}")
+
+ # step 3: Diffusion Process
+ samples = model.batchify_sample(
+ lq_img,
+ captions,
+ num_steps=edm_steps,
+ restoration_scale=s_stage1,
+ s_churn=s_churn,
+ s_noise=s_noise,
+ cfg_scale=s_cfg,
+ control_scale=s_stage2,
+ seed=seed,
+ num_samples=1,
+ p_p=a_prompt,
+ n_p=n_prompt,
+ color_fix_type=color_fix_type,
+ use_linear_CFG=linear_CFG,
+ use_linear_control_scale=linear_s_stage2,
+ cfg_scale_start=spt_linear_CFG,
+ control_scale_start=spt_linear_s_stage2,
+ )
+
+ out_path = "/tmp/out.png"
+ Tensor2PIL(samples[0], h0, w0).save(out_path)
+ return Path(out_path)