Fanghua-Yu 2024-02-05 19:35:12 +08:00
parent 12b8fa4c77
commit 377af5b01d
2 changed files with 61 additions and 41 deletions

File 1 of 2: Gradio demo script

@@ -1,23 +1,37 @@
 import gradio as gr
+from gradio_imageslider import ImageSlider
 import argparse
 from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype
 import numpy as np
 import torch
-if torch.cuda.device_count() >= 2:
-    use_llava = True
-else:
-    use_llava = False
 from SUPIR.util import create_SUPIR_model, load_QF_ckpt
 from PIL import Image
 from llava.llava_agent import LLavaAgent
 from CKPT_PTH import LLAVA_MODEL_PATH
 import einops
 
-SUPIR_device = 'cuda:0'
-LLaVA_device = 'cuda:1'
+parser = argparse.ArgumentParser()
+parser.add_argument("--ip", type=str, default='0.0.0.0')
+parser.add_argument("--port", type=int, default='6688')
+parser.add_argument("--no_llava", action='store_true', default=False)
+parser.add_argument("--use_image_slider", action='store_true', default=False)
+args = parser.parse_args()
+server_ip = args.ip
+server_port = args.port
+use_llava = not args.no_llava
+if torch.cuda.device_count() >= 2:
+    SUPIR_device = 'cuda:0'
+    LLaVA_device = 'cuda:1'
+elif torch.cuda.device_count() == 1:
+    SUPIR_device = 'cuda:0'
+    LLaVA_device = 'cuda:0'
+else:
+    raise ValueError('Currently support CUDA only.')
 
 # load SUPIR
-model = create_SUPIR_model('options/SUPIR_v0.yaml').to(SUPIR_device)
+model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
+model.current_model = 'v0-Q'
 ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
 # load LLaVA
 if use_llava:
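The startup block now reads its configuration from the command line: LLaVA is toggled with --no_llava rather than inferred from the GPU count, a single-GPU machine runs both SUPIR and LLaVA on cuda:0, and a host without CUDA raises immediately. Assuming the changed script is the repository's Gradio demo entry point (its filename is not shown in this diff), a single-GPU launch could look like `python gradio_demo.py --ip 127.0.0.1 --port 6688 --no_llava --use_image_slider`.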
@@ -25,15 +39,13 @@ if use_llava:
 else:
     llava_agent = None
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--ip", type=str, default='0.0.0.0')
-parser.add_argument("--port", type=int, default='6688')
-args = parser.parse_args()
-server_ip = args.ip
-server_port = args.port
 
 def stage1_process(input_image, gamma_correction):
+    # force to v0-Q
+    if model.current_model != 'v0-Q':
+        print('load v0-Q')
+        model.load_state_dict(ckpt_Q, strict=False)
+        model.current_model = 'v0-Q'
     LQ = HWC3(input_image)
     LQ = fix_resize(LQ, 512)
     # stage1
@@ -59,7 +71,16 @@ def llave_process(input_image, temperature, top_p, qs=None):
 
 def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
                    s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
-                   linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2):
+                   linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
+    if model_select != model.current_model:
+        if model_select == 'v0-Q':
+            print('load v0-Q')
+            model.load_state_dict(ckpt_Q, strict=False)
+            model.current_model = 'v0-Q'
+        elif model_select == 'v0-F':
+            print('load v0-F')
+            model.load_state_dict(ckpt_F, strict=False)
+            model.current_model = 'v0-F'
     input_image = HWC3(input_image)
     input_image = upscale_image(input_image, upscale, unit_resolution=32)
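stage2_process now receives the model choice directly from the UI and reloads weights lazily, only when the requested variant differs from the one already resident on the GPU. A minimal sketch of that pattern as a standalone helper (switch_model is a hypothetical name; model, ckpt_Q and ckpt_F are the globals defined earlier in this diff):

```python
def switch_model(model, model_select, ckpt_Q, ckpt_F):
    """Load the requested SUPIR weights only if they are not already active."""
    if model_select == getattr(model, 'current_model', None):
        return model  # requested variant already loaded, skip the reload
    ckpt = {'v0-Q': ckpt_Q, 'v0-F': ckpt_F}[model_select]  # raises KeyError on an unknown name
    print(f'load {model_select}')
    model.load_state_dict(ckpt, strict=False)  # strict=False, as in the diff above
    model.current_model = model_select
    return model
```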
@@ -85,17 +106,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
     results = [x_samples[i] for i in range(num_samples)]
     return [input_image] + results
 
-def load_and_reset(model_select, model_info):
-    _model_select = model_info.replace('<p>', '').replace('</p>', '').replace('Current Model: ', '').strip()
-    if model_select != _model_select:
-        if model_select == 'v0-Q':
-            print('load v0-Q')
-            model.load_state_dict(ckpt_Q, strict=False)
-        elif model_select == 'v0-F':
-            print('load v0-F')
-            model.load_state_dict(ckpt_F, strict=False)
-        model_info = model_info.replace(_model_select, model_select)
+def load_and_reset(param_setting):
+    edm_steps = 50
     s_stage2 = 1.0
     s_stage1 = -1.0
     s_churn = 5
@@ -109,17 +121,17 @@ def load_and_reset(model_select, model_info):
     color_fix_type = 'Wavelet'
     spt_linear_CFG = 1.0
     spt_linear_s_stage2 = 0.0
-    if model_select == 'v0-Q':
+    if param_setting == "Quality":
         s_cfg = 7.5
         linear_CFG = False
         linear_s_stage2 = True
-    elif model_select == 'v0-F':
+    elif param_setting == "Fidelity":
         s_cfg = 4.0
         linear_CFG = True
         linear_s_stage2 = False
     else:
         raise NotImplementedError
-    return model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
+    return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
         linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2
 
 block = gr.Blocks(title='SUPIR').queue()
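With the checkpoint switch moved into stage2_process, load_and_reset is now a pure parameter reset keyed by the "Quality" / "Fidelity" choice. The values visible in this function's hunks can be summarised as a small table; the dicts below are an illustration only, not part of the commit (s_noise and the prompt defaults are set in lines outside this excerpt):

```python
# Defaults shared by both presets, as visible in the hunks above.
COMMON_RESET = dict(edm_steps=50, s_stage2=1.0, s_stage1=-1.0, s_churn=5,
                    color_fix_type='Wavelet', spt_linear_CFG=1.0, spt_linear_s_stage2=0.0)

# Values that differ between the two presets.
PARAM_PRESETS = {
    'Quality':  dict(s_cfg=7.5, linear_CFG=False, linear_s_stage2=True),
    'Fidelity': dict(s_cfg=4.0, linear_CFG=True, linear_s_stage2=False),
}
```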
@@ -131,7 +143,7 @@ with block:
     with gr.Row(equal_height=True):
         with gr.Column():
             gr.Markdown("<center>Input</center>")
-            input_image = gr.Image(sources='upload', type="numpy", elem_id="image-input")
+            input_image = gr.Image(type="numpy", elem_id="image-input")
         with gr.Column():
             gr.Markdown("<center>Stage1 Output</center>")
             denoise_image = gr.Image(type="numpy", elem_id="image-s1")
@@ -143,7 +155,8 @@ with block:
                 top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
                 qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.")
             with gr.Accordion("Stage2 options", open=False):
-                num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4, value=1, step=1)
+                num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
+                                        , value=1, step=1)
                 upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
                 edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
                 s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
@@ -181,10 +194,16 @@ with block:
                 with gr.Column():
                     color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
                                               interactive=True)
+                with gr.Column():
+                    model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
+                                            interactive=True)
         with gr.Column():
             gr.Markdown("<center>Stage2 Output</center>")
-            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1", rows=2, columns=1)
+            if not args.use_image_slider:
+                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
+            else:
+                result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1")
     with gr.Row():
         with gr.Column():
             denoise_button = gr.Button(value="Stage1 Run")
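The Stage2 output widget is now chosen at startup: a gr.Gallery by default, or an ImageSlider from gradio_imageslider when --use_image_slider is passed. Because stage2_process returns [input_image] + results and the Num Samples slider is capped at 1 in slider mode, the slider component receives exactly two images, which matches the before/after pair an image-comparison widget expects (that two-image contract is an assumption about gradio_imageslider, not something shown in this diff). A toy illustration of the two output shapes:

```python
import numpy as np

input_image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for the uploaded image
sample = np.full((512, 512, 3), 255, dtype=np.uint8)   # stand-in for one restored sample

gallery_value = [input_image, sample, sample]  # gallery mode: the input plus up to num_samples results
slider_value = [input_image, sample]           # slider mode: num_samples capped at 1, a before/after pair
```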
@@ -194,20 +213,20 @@ with block:
             diffusion_button = gr.Button(value="Stage2 Run")
     with gr.Row():
         with gr.Column():
-            model_select = gr.Dropdown(["v0-Q", "v0-F"], interactive=True, label="Model List",
-                                       value="v0-Q")
+            param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
+                                        value="Quality")
         with gr.Column():
-            restart_button = gr.Button(value="Load & Reset")
-            model_info = gr.Markdown(f"Current Model: {model_select.value}")
+            restart_button = gr.Button(value="Reset Param")
 
     llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
     denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
                          outputs=[denoise_image])
     stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
                   s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
-                  linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2]
+                  linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
     diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery])
-    restart_button.click(fn=load_and_reset, inputs=[model_select, model_info],
-                         outputs=[model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
+    restart_button.click(fn=load_and_reset, inputs=[param_setting],
+                         outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
                                   color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
 block.launch(server_name=server_ip, server_port=server_port)

File 2 of 2: requirements file

@@ -1,5 +1,6 @@
 fastapi==0.95.1
-gradio==3.4.0
+gradio==4.16.0
+gradio_imageslider==0.0.17
 Markdown==3.4.1
 numpy==1.24.2
 requests==2.28.2
@@ -35,4 +36,4 @@ tqdm==4.65.0
 triton==2.1.0
 urllib3==1.26.15
 webdataset==0.2.48
 xformers>=0.0.20
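The gradio bump to 4.16.0 and the new gradio_imageslider pin belong together: the ImageSlider component imported at the top of the demo comes from that package. In an existing environment the two can be installed in one step, for example `pip install gradio==4.16.0 gradio_imageslider==0.0.17`.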