release v0 ckpt

This commit is contained in:
Fanghua-Yu 2024-02-22 23:33:17 +08:00
parent c1c5728601
commit 9e8f6f8f76
5 changed files with 96 additions and 27 deletions

View file

@ -12,7 +12,6 @@
--- ---
## 🔧 Dependencies and Installation ## 🔧 Dependencies and Installation
1. Clone repo 1. Clone repo
```bash ```bash
git clone https://github.com/Fanghua-Yu/SUPIR.git git clone https://github.com/Fanghua-Yu/SUPIR.git
@ -37,11 +36,11 @@
#### Models we provided: #### Models we provided:
* `SUPIR-v0Q`: (Coming Soon) Google Drive, Baidu Netdisk * `SUPIR-v0Q`: [Baidu Netdisk](https://pan.baidu.com/s/1lnefCZhBTeDWijqbj1jIyw?pwd=pjq6), Google Drive (Coming Soon)
Default training settings with paper. High generalization and high image quality in most cases. Default training settings with paper. High generalization and high image quality in most cases.
* `SUPIR-v0F`: (Coming Soon) Google Drive, Baidu Netdisk * `SUPIR-v0F`: [Baidu Netdisk](https://pan.baidu.com/s/1AECN8NjiVuE3hvO8o-Ua6A?pwd=k2uz), Google Drive (Coming Soon)
Training with light degradation settings. Stage1 encoder of `SUPIR-v0F` remains more details when facing light degradations. Training with light degradation settings. Stage1 encoder of `SUPIR-v0F` remains more details when facing light degradations.
@ -53,11 +52,11 @@
--- ---
## ⚡ Quick Inference ## ⚡ Quick Inference
### Val Dataset
RealPhoto60: [Baidu Netdisk](https://pan.baidu.com/s/1CJKsPGtyfs8QEVCQ97voBA?pwd=aocg), Google Drive (Coming Soon)
### Usage of SUPIR ### Usage of SUPIR
```Shell
```console
Usage: Usage:
-- python test.py [options] -- python test.py [options]
-- python gradio_demo.py [interactive options] -- python gradio_demo.py [interactive options]
@ -102,13 +101,16 @@ CUDA_VISIBLE_DEVICES=0,1 python test.py --img_dir '/opt/data/private/LV_Dataset/
### Gradio Demo ### Gradio Demo
```Shell ```Shell
CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_image_slider --log_history
``` ```
<p align="center">
<img src="assets/DemoGuide.png">
</p>
### Online Demo (Coming Soon) ### Online Demo (Coming Soon)
--- ---
## BibTeX ## BibTeX

BIN
assets/DemoGuide.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 161 KiB

View file

@ -1,3 +1,5 @@
import os
import gradio as gr import gradio as gr
from gradio_imageslider import ImageSlider from gradio_imageslider import ImageSlider
import argparse import argparse
@ -10,12 +12,14 @@ from llava.llava_agent import LLavaAgent
from CKPT_PTH import LLAVA_MODEL_PATH from CKPT_PTH import LLAVA_MODEL_PATH
import einops import einops
import copy import copy
import time
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='0.0.0.0') parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--port", type=int, default='6688') parser.add_argument("--port", type=int, default='6688')
parser.add_argument("--no_llava", action='store_true', default=False) parser.add_argument("--no_llava", action='store_true', default=False)
parser.add_argument("--use_image_slider", action='store_true', default=False) parser.add_argument("--use_image_slider", action='store_true', default=False)
parser.add_argument("--log_history", action='store_true', default=False)
args = parser.parse_args() args = parser.parse_args()
server_ip = args.ip server_ip = args.ip
server_port = args.port server_port = args.port
@ -41,7 +45,6 @@ if use_llava:
else: else:
llava_agent = None llava_agent = None
def stage1_process(input_image, gamma_correction): def stage1_process(input_image, gamma_correction):
LQ = HWC3(input_image) LQ = HWC3(input_image)
LQ = fix_resize(LQ, 512) LQ = fix_resize(LQ, 512)
@ -69,6 +72,15 @@ def llave_process(input_image, temperature, top_p, qs=None):
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2, def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select): linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
event_id = str(time.time_ns())
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
's_stage1': s_stage1, 's_stage2': s_stage2, 's_cfg': s_cfg, 'seed': seed, 's_churn': s_churn,
's_noise': s_noise, 'color_fix_type': color_fix_type, 'diff_dtype': diff_dtype, 'ae_dtype': ae_dtype,
'gamma_correction': gamma_correction, 'linear_CFG': linear_CFG, 'linear_s_stage2': linear_s_stage2,
'spt_linear_CFG': spt_linear_CFG, 'spt_linear_s_stage2': spt_linear_s_stage2,
'model_select': model_select}
if model_select != model.current_model: if model_select != model.current_model:
if model_select == 'v0-Q': if model_select == 'v0-Q':
print('load v0-Q') print('load v0-Q')
@ -79,7 +91,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.load_state_dict(ckpt_F, strict=False) model.load_state_dict(ckpt_F, strict=False)
model.current_model = 'v0-F' model.current_model = 'v0-F'
input_image = HWC3(input_image) input_image = HWC3(input_image)
input_image = upscale_image(input_image, upscale, unit_resolution=32) input_image = upscale_image(input_image, upscale, unit_resolution=32,
min_size=1024)
LQ = np.array(input_image) / 255.0 LQ = np.array(input_image) / 255.0
LQ = np.power(LQ, gamma_correction) LQ = np.power(LQ, gamma_correction)
@ -101,7 +114,16 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip( x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
0, 255).astype(np.uint8) 0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)] results = [x_samples[i] for i in range(num_samples)]
return [input_image] + results
if args.log_history:
os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
f.write(str(event_dict))
f.close()
Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png')
for i, result in enumerate(results):
Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
return [input_image] + results, event_id, 3, ''
def load_and_reset(param_setting): def load_and_reset(param_setting):
edm_steps = 50 edm_steps = 50
@ -118,23 +140,56 @@ def load_and_reset(param_setting):
color_fix_type = 'Wavelet' color_fix_type = 'Wavelet'
spt_linear_CFG = 1.0 spt_linear_CFG = 1.0
spt_linear_s_stage2 = 0.0 spt_linear_s_stage2 = 0.0
linear_s_stage2 = False
if param_setting == "Quality": if param_setting == "Quality":
s_cfg = 7.5 s_cfg = 7.5
linear_CFG = False linear_CFG = False
linear_s_stage2 = True
elif param_setting == "Fidelity": elif param_setting == "Fidelity":
s_cfg = 4.0 s_cfg = 4.0
linear_CFG = True linear_CFG = True
linear_s_stage2 = False
else: else:
raise NotImplementedError raise NotImplementedError
return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \ return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2 linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2
def submit_feedback(event_id, fb_score, fb_text):
if args.log_history:
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'r') as f:
event_dict = eval(f.read())
f.close()
event_dict['feedback'] = {'score': fb_score, 'text': fb_text}
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
f.write(str(event_dict))
f.close()
return 'Submit successfully, thank you for your comments!'
else:
return 'Submit failed, the server is not set to log history.'
title_md = """
# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
SUPIR is still a research project under tested and is not yet a stable commercial product.
[[Paper](https://arxiv.org/abs/2401.13627)] &emsp; [[Project Page](http://supir.xpixel.group/)] &emsp; [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
"""
claim_md = """
## **Terms of use**
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit a feedback to us if you get any inappropriate answer! We will collect those to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
## **License**
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
"""
block = gr.Blocks(title='SUPIR').queue() block = gr.Blocks(title='SUPIR').queue()
with block: with block:
with gr.Row(): with gr.Row():
gr.Markdown("<center><font size=5>SUPIR Playground</font></center>") gr.Markdown(title_md)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
@ -150,7 +205,8 @@ with block:
with gr.Accordion("LLaVA options", open=False): with gr.Accordion("LLaVA options", open=False):
temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1) temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1) top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.") qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
"The image is a realistic photography, not an art painting.")
with gr.Accordion("Stage2 options", open=False): with gr.Accordion("Stage2 options", open=False):
num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1 num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
, value=1, step=1) , value=1, step=1)
@ -178,7 +234,7 @@ with block:
spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0, spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
maximum=9.0, value=1.0, step=0.5) maximum=9.0, value=1.0, step=0.5)
with gr.Column(): with gr.Column():
linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=True) linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0., spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
maximum=1., value=0., step=0.05) maximum=1., value=0., step=0.05)
with gr.Row(): with gr.Row():
@ -213,8 +269,15 @@ with block:
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting", param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
value="Quality") value="Quality")
with gr.Column(): with gr.Column():
restart_button = gr.Button(value="Reset Param") restart_button = gr.Button(value="Reset Param", scale=2)
with gr.Accordion("Feedback", open=True):
fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
interactive=True)
fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
submit_button = gr.Button(value="Submit Feedback")
with gr.Row():
gr.Markdown(claim_md)
event_id = gr.Textbox(label="Event ID", value="", visible=False)
llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt]) llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction], denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
@ -222,8 +285,9 @@ with block:
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2, stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select] linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery]) diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text])
restart_button.click(fn=load_and_reset, inputs=[param_setting], restart_button.click(fn=load_and_reset, inputs=[param_setting],
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2]) color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
block.launch(server_name=server_ip, server_port=server_port) block.launch(server_name=server_ip, server_port=server_port)

View file

@ -150,7 +150,7 @@ model:
jpeg artifacts, deformed, lowres, over-smooth' jpeg artifacts, deformed, lowres, over-smooth'
SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors
SUPIR_CKPT_F: /opt/data/private/code/D1ff91v/experiments/1162-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_3/checkpoints/SUPIR-v0F.ckpt SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt
SUPIR_CKPT_Q: /opt/data/private/code/D1ff91v/experiments/1163-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_4/checkpoints/epoch=0-step=40000_dumped-ema.ckpt SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt
SUPIR_CKPT: ~ SUPIR_CKPT: ~

11
test.py
View file

@ -6,12 +6,13 @@ from llava.llava_agent import LLavaAgent
from CKPT_PTH import LLAVA_MODEL_PATH from CKPT_PTH import LLAVA_MODEL_PATH
import os import os
if torch.cuda.device_count() >= 2: if torch.cuda.device_count() >= 2:
use_llava = True
else:
use_llava = False
SUPIR_device = 'cuda:0' SUPIR_device = 'cuda:0'
LLaVA_device = 'cuda:1' LLaVA_device = 'cuda:1'
elif torch.cuda.device_count() == 1:
SUPIR_device = 'cuda:0'
LLaVA_device = 'cuda:0'
else:
raise ValueError('Currently support CUDA only.')
# hyparams here # hyparams here
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -45,8 +46,10 @@ parser.add_argument("--spt_linear_CFG", type=float, default=1.0)
parser.add_argument("--spt_linear_s_stage2", type=float, default=0.) parser.add_argument("--spt_linear_s_stage2", type=float, default=0.)
parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16']) parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16'])
parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16']) parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16'])
parser.add_argument("--no_llava", action='store_true', default=False)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
use_llava = not args.no_llava
# load SUPIR # load SUPIR
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device) model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device)