release v0 ckpt
This commit is contained in:
parent
c1c5728601
commit
9e8f6f8f76
5 changed files with 96 additions and 27 deletions
18
README.md
18
README.md
|
@ -12,7 +12,6 @@
|
|||
---
|
||||
## 🔧 Dependencies and Installation
|
||||
|
||||
|
||||
1. Clone repo
|
||||
```bash
|
||||
git clone https://github.com/Fanghua-Yu/SUPIR.git
|
||||
|
@ -37,11 +36,11 @@
|
|||
|
||||
|
||||
#### Models we provided:
|
||||
* `SUPIR-v0Q`: (Coming Soon) Google Drive, Baidu Netdisk
|
||||
* `SUPIR-v0Q`: [Baidu Netdisk](https://pan.baidu.com/s/1lnefCZhBTeDWijqbj1jIyw?pwd=pjq6), Google Drive (Coming Soon)
|
||||
|
||||
Default training settings with paper. High generalization and high image quality in most cases.
|
||||
|
||||
* `SUPIR-v0F`: (Coming Soon) Google Drive, Baidu Netdisk
|
||||
* `SUPIR-v0F`: [Baidu Netdisk](https://pan.baidu.com/s/1AECN8NjiVuE3hvO8o-Ua6A?pwd=k2uz), Google Drive (Coming Soon)
|
||||
|
||||
Training with light degradation settings. Stage1 encoder of `SUPIR-v0F` remains more details when facing light degradations.
|
||||
|
||||
|
@ -53,11 +52,11 @@
|
|||
---
|
||||
|
||||
## ⚡ Quick Inference
|
||||
|
||||
### Val Dataset
|
||||
RealPhoto60: [Baidu Netdisk](https://pan.baidu.com/s/1CJKsPGtyfs8QEVCQ97voBA?pwd=aocg), Google Drive (Coming Soon)
|
||||
|
||||
### Usage of SUPIR
|
||||
|
||||
```console
|
||||
```Shell
|
||||
Usage:
|
||||
-- python test.py [options]
|
||||
-- python gradio_demo.py [interactive options]
|
||||
|
@ -102,13 +101,16 @@ CUDA_VISIBLE_DEVICES=0,1 python test.py --img_dir '/opt/data/private/LV_Dataset/
|
|||
|
||||
### Gradio Demo
|
||||
```Shell
|
||||
CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688
|
||||
CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_image_slider --log_history
|
||||
```
|
||||
<p align="center">
|
||||
<img src="assets/DemoGuide.png">
|
||||
</p>
|
||||
|
||||
|
||||
### Online Demo (Coming Soon)
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
## BibTeX
|
||||
|
|
BIN
assets/DemoGuide.png
Normal file
BIN
assets/DemoGuide.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 161 KiB |
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
import gradio as gr
|
||||
from gradio_imageslider import ImageSlider
|
||||
import argparse
|
||||
|
@ -10,12 +12,14 @@ from llava.llava_agent import LLavaAgent
|
|||
from CKPT_PTH import LLAVA_MODEL_PATH
|
||||
import einops
|
||||
import copy
|
||||
import time
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
||||
parser.add_argument("--ip", type=str, default='127.0.0.1')
|
||||
parser.add_argument("--port", type=int, default='6688')
|
||||
parser.add_argument("--no_llava", action='store_true', default=False)
|
||||
parser.add_argument("--use_image_slider", action='store_true', default=False)
|
||||
parser.add_argument("--log_history", action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
server_ip = args.ip
|
||||
server_port = args.port
|
||||
|
@ -41,7 +45,6 @@ if use_llava:
|
|||
else:
|
||||
llava_agent = None
|
||||
|
||||
|
||||
def stage1_process(input_image, gamma_correction):
|
||||
LQ = HWC3(input_image)
|
||||
LQ = fix_resize(LQ, 512)
|
||||
|
@ -69,6 +72,15 @@ def llave_process(input_image, temperature, top_p, qs=None):
|
|||
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
||||
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
|
||||
event_id = str(time.time_ns())
|
||||
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
|
||||
'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
|
||||
's_stage1': s_stage1, 's_stage2': s_stage2, 's_cfg': s_cfg, 'seed': seed, 's_churn': s_churn,
|
||||
's_noise': s_noise, 'color_fix_type': color_fix_type, 'diff_dtype': diff_dtype, 'ae_dtype': ae_dtype,
|
||||
'gamma_correction': gamma_correction, 'linear_CFG': linear_CFG, 'linear_s_stage2': linear_s_stage2,
|
||||
'spt_linear_CFG': spt_linear_CFG, 'spt_linear_s_stage2': spt_linear_s_stage2,
|
||||
'model_select': model_select}
|
||||
|
||||
if model_select != model.current_model:
|
||||
if model_select == 'v0-Q':
|
||||
print('load v0-Q')
|
||||
|
@ -79,7 +91,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
|
|||
model.load_state_dict(ckpt_F, strict=False)
|
||||
model.current_model = 'v0-F'
|
||||
input_image = HWC3(input_image)
|
||||
input_image = upscale_image(input_image, upscale, unit_resolution=32)
|
||||
input_image = upscale_image(input_image, upscale, unit_resolution=32,
|
||||
min_size=1024)
|
||||
|
||||
LQ = np.array(input_image) / 255.0
|
||||
LQ = np.power(LQ, gamma_correction)
|
||||
|
@ -101,7 +114,16 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
|
|||
x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
|
||||
0, 255).astype(np.uint8)
|
||||
results = [x_samples[i] for i in range(num_samples)]
|
||||
return [input_image] + results
|
||||
|
||||
if args.log_history:
|
||||
os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
|
||||
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
|
||||
f.write(str(event_dict))
|
||||
f.close()
|
||||
Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png')
|
||||
for i, result in enumerate(results):
|
||||
Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
|
||||
return [input_image] + results, event_id, 3, ''
|
||||
|
||||
def load_and_reset(param_setting):
|
||||
edm_steps = 50
|
||||
|
@ -118,23 +140,56 @@ def load_and_reset(param_setting):
|
|||
color_fix_type = 'Wavelet'
|
||||
spt_linear_CFG = 1.0
|
||||
spt_linear_s_stage2 = 0.0
|
||||
linear_s_stage2 = False
|
||||
if param_setting == "Quality":
|
||||
s_cfg = 7.5
|
||||
linear_CFG = False
|
||||
linear_s_stage2 = True
|
||||
elif param_setting == "Fidelity":
|
||||
s_cfg = 4.0
|
||||
linear_CFG = True
|
||||
linear_s_stage2 = False
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
|
||||
linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2
|
||||
|
||||
|
||||
def submit_feedback(event_id, fb_score, fb_text):
|
||||
if args.log_history:
|
||||
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'r') as f:
|
||||
event_dict = eval(f.read())
|
||||
f.close()
|
||||
event_dict['feedback'] = {'score': fb_score, 'text': fb_text}
|
||||
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
|
||||
f.write(str(event_dict))
|
||||
f.close()
|
||||
return 'Submit successfully, thank you for your comments!'
|
||||
else:
|
||||
return 'Submit failed, the server is not set to log history.'
|
||||
|
||||
title_md = """
|
||||
# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
|
||||
|
||||
⚠️SUPIR is still a research project under tested and is not yet a stable commercial product.
|
||||
|
||||
[[Paper](https://arxiv.org/abs/2401.13627)]   [[Project Page](http://supir.xpixel.group/)]   [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
|
||||
"""
|
||||
|
||||
|
||||
claim_md = """
|
||||
## **Terms of use**
|
||||
|
||||
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit a feedback to us if you get any inappropriate answer! We will collect those to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
||||
|
||||
## **License**
|
||||
|
||||
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
|
||||
"""
|
||||
|
||||
|
||||
block = gr.Blocks(title='SUPIR').queue()
|
||||
with block:
|
||||
with gr.Row():
|
||||
gr.Markdown("<center><font size=5>SUPIR Playground</font></center>")
|
||||
gr.Markdown(title_md)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row(equal_height=True):
|
||||
|
@ -150,7 +205,8 @@ with block:
|
|||
with gr.Accordion("LLaVA options", open=False):
|
||||
temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
|
||||
top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
|
||||
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.")
|
||||
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
|
||||
"The image is a realistic photography, not an art painting.")
|
||||
with gr.Accordion("Stage2 options", open=False):
|
||||
num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
|
||||
, value=1, step=1)
|
||||
|
@ -178,7 +234,7 @@ with block:
|
|||
spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
|
||||
maximum=9.0, value=1.0, step=0.5)
|
||||
with gr.Column():
|
||||
linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=True)
|
||||
linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
|
||||
spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
|
||||
maximum=1., value=0., step=0.05)
|
||||
with gr.Row():
|
||||
|
@ -213,8 +269,15 @@ with block:
|
|||
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
|
||||
value="Quality")
|
||||
with gr.Column():
|
||||
restart_button = gr.Button(value="Reset Param")
|
||||
|
||||
restart_button = gr.Button(value="Reset Param", scale=2)
|
||||
with gr.Accordion("Feedback", open=True):
|
||||
fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
|
||||
interactive=True)
|
||||
fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
|
||||
submit_button = gr.Button(value="Submit Feedback")
|
||||
with gr.Row():
|
||||
gr.Markdown(claim_md)
|
||||
event_id = gr.Textbox(label="Event ID", value="", visible=False)
|
||||
|
||||
llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
|
||||
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
|
||||
|
@ -222,8 +285,9 @@ with block:
|
|||
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
||||
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
|
||||
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery])
|
||||
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text])
|
||||
restart_button.click(fn=load_and_reset, inputs=[param_setting],
|
||||
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
|
||||
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
|
||||
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
|
||||
block.launch(server_name=server_ip, server_port=server_port)
|
||||
|
|
|
@ -150,7 +150,7 @@ model:
|
|||
jpeg artifacts, deformed, lowres, over-smooth'
|
||||
|
||||
SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors
|
||||
SUPIR_CKPT_F: /opt/data/private/code/D1ff91v/experiments/1162-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_3/checkpoints/SUPIR-v0F.ckpt
|
||||
SUPIR_CKPT_Q: /opt/data/private/code/D1ff91v/experiments/1163-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_4/checkpoints/epoch=0-step=40000_dumped-ema.ckpt
|
||||
SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt
|
||||
SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt
|
||||
SUPIR_CKPT: ~
|
||||
|
||||
|
|
11
test.py
11
test.py
|
@ -6,12 +6,13 @@ from llava.llava_agent import LLavaAgent
|
|||
from CKPT_PTH import LLAVA_MODEL_PATH
|
||||
import os
|
||||
if torch.cuda.device_count() >= 2:
|
||||
use_llava = True
|
||||
else:
|
||||
use_llava = False
|
||||
|
||||
SUPIR_device = 'cuda:0'
|
||||
LLaVA_device = 'cuda:1'
|
||||
elif torch.cuda.device_count() == 1:
|
||||
SUPIR_device = 'cuda:0'
|
||||
LLaVA_device = 'cuda:0'
|
||||
else:
|
||||
raise ValueError('Currently support CUDA only.')
|
||||
|
||||
# hyparams here
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -45,8 +46,10 @@ parser.add_argument("--spt_linear_CFG", type=float, default=1.0)
|
|||
parser.add_argument("--spt_linear_s_stage2", type=float, default=0.)
|
||||
parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16'])
|
||||
parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16'])
|
||||
parser.add_argument("--no_llava", action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
use_llava = not args.no_llava
|
||||
|
||||
# load SUPIR
|
||||
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device)
|
||||
|
|
Loading…
Reference in a new issue