release v0 ckpt

This commit is contained in:
Fanghua-Yu 2024-02-22 23:33:17 +08:00
parent c1c5728601
commit 9e8f6f8f76
5 changed files with 96 additions and 27 deletions

View file

@ -12,7 +12,6 @@
---
## 🔧 Dependencies and Installation
1. Clone repo
```bash
git clone https://github.com/Fanghua-Yu/SUPIR.git
@ -37,11 +36,11 @@
#### Models we provided:
* `SUPIR-v0Q`: (Coming Soon) Google Drive, Baidu Netdisk
* `SUPIR-v0Q`: [Baidu Netdisk](https://pan.baidu.com/s/1lnefCZhBTeDWijqbj1jIyw?pwd=pjq6), Google Drive (Coming Soon)
Default training settings with paper. High generalization and high image quality in most cases.
* `SUPIR-v0F`: (Coming Soon) Google Drive, Baidu Netdisk
* `SUPIR-v0F`: [Baidu Netdisk](https://pan.baidu.com/s/1AECN8NjiVuE3hvO8o-Ua6A?pwd=k2uz), Google Drive (Coming Soon)
Training with light degradation settings. Stage1 encoder of `SUPIR-v0F` remains more details when facing light degradations.
@ -53,11 +52,11 @@
---
## ⚡ Quick Inference
### Val Dataset
RealPhoto60: [Baidu Netdisk](https://pan.baidu.com/s/1CJKsPGtyfs8QEVCQ97voBA?pwd=aocg), Google Drive (Coming Soon)
### Usage of SUPIR
```console
```Shell
Usage:
-- python test.py [options]
-- python gradio_demo.py [interactive options]
@ -102,13 +101,16 @@ CUDA_VISIBLE_DEVICES=0,1 python test.py --img_dir '/opt/data/private/LV_Dataset/
### Gradio Demo
```Shell
CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688
CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_image_slider --log_history
```
<p align="center">
<img src="assets/DemoGuide.png">
</p>
### Online Demo (Coming Soon)
---
## BibTeX

BIN
assets/DemoGuide.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 161 KiB

View file

@ -1,3 +1,5 @@
import os
import gradio as gr
from gradio_imageslider import ImageSlider
import argparse
@ -10,12 +12,14 @@ from llava.llava_agent import LLavaAgent
from CKPT_PTH import LLAVA_MODEL_PATH
import einops
import copy
import time
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='0.0.0.0')
parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--port", type=int, default='6688')
parser.add_argument("--no_llava", action='store_true', default=False)
parser.add_argument("--use_image_slider", action='store_true', default=False)
parser.add_argument("--log_history", action='store_true', default=False)
args = parser.parse_args()
server_ip = args.ip
server_port = args.port
@ -41,7 +45,6 @@ if use_llava:
else:
llava_agent = None
def stage1_process(input_image, gamma_correction):
LQ = HWC3(input_image)
LQ = fix_resize(LQ, 512)
@ -69,6 +72,15 @@ def llave_process(input_image, temperature, top_p, qs=None):
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
event_id = str(time.time_ns())
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
's_stage1': s_stage1, 's_stage2': s_stage2, 's_cfg': s_cfg, 'seed': seed, 's_churn': s_churn,
's_noise': s_noise, 'color_fix_type': color_fix_type, 'diff_dtype': diff_dtype, 'ae_dtype': ae_dtype,
'gamma_correction': gamma_correction, 'linear_CFG': linear_CFG, 'linear_s_stage2': linear_s_stage2,
'spt_linear_CFG': spt_linear_CFG, 'spt_linear_s_stage2': spt_linear_s_stage2,
'model_select': model_select}
if model_select != model.current_model:
if model_select == 'v0-Q':
print('load v0-Q')
@ -79,7 +91,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.load_state_dict(ckpt_F, strict=False)
model.current_model = 'v0-F'
input_image = HWC3(input_image)
input_image = upscale_image(input_image, upscale, unit_resolution=32)
input_image = upscale_image(input_image, upscale, unit_resolution=32,
min_size=1024)
LQ = np.array(input_image) / 255.0
LQ = np.power(LQ, gamma_correction)
@ -101,7 +114,16 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]
return [input_image] + results
if args.log_history:
os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
f.write(str(event_dict))
f.close()
Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png')
for i, result in enumerate(results):
Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
return [input_image] + results, event_id, 3, ''
def load_and_reset(param_setting):
edm_steps = 50
@ -118,23 +140,56 @@ def load_and_reset(param_setting):
color_fix_type = 'Wavelet'
spt_linear_CFG = 1.0
spt_linear_s_stage2 = 0.0
linear_s_stage2 = False
if param_setting == "Quality":
s_cfg = 7.5
linear_CFG = False
linear_s_stage2 = True
elif param_setting == "Fidelity":
s_cfg = 4.0
linear_CFG = True
linear_s_stage2 = False
else:
raise NotImplementedError
return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2
def submit_feedback(event_id, fb_score, fb_text):
if args.log_history:
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'r') as f:
event_dict = eval(f.read())
f.close()
event_dict['feedback'] = {'score': fb_score, 'text': fb_text}
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
f.write(str(event_dict))
f.close()
return 'Submit successfully, thank you for your comments!'
else:
return 'Submit failed, the server is not set to log history.'
title_md = """
# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
SUPIR is still a research project under tested and is not yet a stable commercial product.
[[Paper](https://arxiv.org/abs/2401.13627)] &emsp; [[Project Page](http://supir.xpixel.group/)] &emsp; [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
"""
claim_md = """
## **Terms of use**
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit a feedback to us if you get any inappropriate answer! We will collect those to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
## **License**
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
"""
block = gr.Blocks(title='SUPIR').queue()
with block:
with gr.Row():
gr.Markdown("<center><font size=5>SUPIR Playground</font></center>")
gr.Markdown(title_md)
with gr.Row():
with gr.Column():
with gr.Row(equal_height=True):
@ -150,7 +205,8 @@ with block:
with gr.Accordion("LLaVA options", open=False):
temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.")
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
"The image is a realistic photography, not an art painting.")
with gr.Accordion("Stage2 options", open=False):
num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
, value=1, step=1)
@ -178,7 +234,7 @@ with block:
spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
maximum=9.0, value=1.0, step=0.5)
with gr.Column():
linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=True)
linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
maximum=1., value=0., step=0.05)
with gr.Row():
@ -213,8 +269,15 @@ with block:
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
value="Quality")
with gr.Column():
restart_button = gr.Button(value="Reset Param")
restart_button = gr.Button(value="Reset Param", scale=2)
with gr.Accordion("Feedback", open=True):
fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
interactive=True)
fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
submit_button = gr.Button(value="Submit Feedback")
with gr.Row():
gr.Markdown(claim_md)
event_id = gr.Textbox(label="Event ID", value="", visible=False)
llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
@ -222,8 +285,9 @@ with block:
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery])
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text])
restart_button.click(fn=load_and_reset, inputs=[param_setting],
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
block.launch(server_name=server_ip, server_port=server_port)

View file

@ -150,7 +150,7 @@ model:
jpeg artifacts, deformed, lowres, over-smooth'
SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors
SUPIR_CKPT_F: /opt/data/private/code/D1ff91v/experiments/1162-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_3/checkpoints/SUPIR-v0F.ckpt
SUPIR_CKPT_Q: /opt/data/private/code/D1ff91v/experiments/1163-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_4/checkpoints/epoch=0-step=40000_dumped-ema.ckpt
SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt
SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt
SUPIR_CKPT: ~

13
test.py
View file

@ -6,12 +6,13 @@ from llava.llava_agent import LLavaAgent
from CKPT_PTH import LLAVA_MODEL_PATH
import os
if torch.cuda.device_count() >= 2:
use_llava = True
SUPIR_device = 'cuda:0'
LLaVA_device = 'cuda:1'
elif torch.cuda.device_count() == 1:
SUPIR_device = 'cuda:0'
LLaVA_device = 'cuda:0'
else:
use_llava = False
SUPIR_device = 'cuda:0'
LLaVA_device = 'cuda:1'
raise ValueError('Currently support CUDA only.')
# hyparams here
parser = argparse.ArgumentParser()
@ -45,8 +46,10 @@ parser.add_argument("--spt_linear_CFG", type=float, default=1.0)
parser.add_argument("--spt_linear_s_stage2", type=float, default=0.)
parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16'])
parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16'])
parser.add_argument("--no_llava", action='store_true', default=False)
args = parser.parse_args()
print(args)
use_llava = not args.no_llava
# load SUPIR
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device)