20240205
This commit is contained in:
parent
12b8fa4c77
commit
377af5b01d
2 changed files with 61 additions and 41 deletions
|
@ -1,23 +1,37 @@
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from gradio_imageslider import ImageSlider
|
||||||
import argparse
|
import argparse
|
||||||
from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype
|
from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
if torch.cuda.device_count() >= 2:
|
|
||||||
use_llava = True
|
|
||||||
else:
|
|
||||||
use_llava = False
|
|
||||||
from SUPIR.util import create_SUPIR_model, load_QF_ckpt
|
from SUPIR.util import create_SUPIR_model, load_QF_ckpt
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from llava.llava_agent import LLavaAgent
|
from llava.llava_agent import LLavaAgent
|
||||||
from CKPT_PTH import LLAVA_MODEL_PATH
|
from CKPT_PTH import LLAVA_MODEL_PATH
|
||||||
import einops
|
import einops
|
||||||
|
|
||||||
SUPIR_device = 'cuda:0'
|
parser = argparse.ArgumentParser()
|
||||||
LLaVA_device = 'cuda:1'
|
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
||||||
|
parser.add_argument("--port", type=int, default='6688')
|
||||||
|
parser.add_argument("--no_llava", action='store_true', default=False)
|
||||||
|
parser.add_argument("--use_image_slider", action='store_true', default=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
server_ip = args.ip
|
||||||
|
server_port = args.port
|
||||||
|
use_llava = not args.no_llava
|
||||||
|
|
||||||
|
if torch.cuda.device_count() >= 2:
|
||||||
|
SUPIR_device = 'cuda:0'
|
||||||
|
LLaVA_device = 'cuda:1'
|
||||||
|
elif torch.cuda.device_count() == 1:
|
||||||
|
SUPIR_device = 'cuda:0'
|
||||||
|
LLaVA_device = 'cuda:0'
|
||||||
|
else:
|
||||||
|
raise ValueError('Currently support CUDA only.')
|
||||||
|
|
||||||
# load SUPIR
|
# load SUPIR
|
||||||
model = create_SUPIR_model('options/SUPIR_v0.yaml').to(SUPIR_device)
|
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
|
||||||
|
model.current_model = 'v0-Q'
|
||||||
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
|
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
|
||||||
# load LLaVA
|
# load LLaVA
|
||||||
if use_llava:
|
if use_llava:
|
||||||
|
@ -25,15 +39,13 @@ if use_llava:
|
||||||
else:
|
else:
|
||||||
llava_agent = None
|
llava_agent = None
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
|
||||||
parser.add_argument("--port", type=int, default='6688')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
server_ip = args.ip
|
|
||||||
server_port = args.port
|
|
||||||
|
|
||||||
def stage1_process(input_image, gamma_correction):
|
def stage1_process(input_image, gamma_correction):
|
||||||
|
# force to v0-Q
|
||||||
|
if model.current_model != 'v0-Q':
|
||||||
|
print('load v0-Q')
|
||||||
|
model.load_state_dict(ckpt_Q, strict=False)
|
||||||
|
model.current_model = 'v0-Q'
|
||||||
LQ = HWC3(input_image)
|
LQ = HWC3(input_image)
|
||||||
LQ = fix_resize(LQ, 512)
|
LQ = fix_resize(LQ, 512)
|
||||||
# stage1
|
# stage1
|
||||||
|
@ -59,7 +71,16 @@ def llave_process(input_image, temperature, top_p, qs=None):
|
||||||
|
|
||||||
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
||||||
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
||||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2):
|
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
|
||||||
|
if model_select != model.current_model:
|
||||||
|
if model_select == 'v0-Q':
|
||||||
|
print('load v0-Q')
|
||||||
|
model.load_state_dict(ckpt_Q, strict=False)
|
||||||
|
model.current_model = 'v0-Q'
|
||||||
|
elif model_select == 'v0-F':
|
||||||
|
print('load v0-F')
|
||||||
|
model.load_state_dict(ckpt_F, strict=False)
|
||||||
|
model.current_model = 'v0-F'
|
||||||
input_image = HWC3(input_image)
|
input_image = HWC3(input_image)
|
||||||
input_image = upscale_image(input_image, upscale, unit_resolution=32)
|
input_image = upscale_image(input_image, upscale, unit_resolution=32)
|
||||||
|
|
||||||
|
@ -85,17 +106,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
|
||||||
results = [x_samples[i] for i in range(num_samples)]
|
results = [x_samples[i] for i in range(num_samples)]
|
||||||
return [input_image] + results
|
return [input_image] + results
|
||||||
|
|
||||||
def load_and_reset(model_select, model_info):
|
def load_and_reset(param_setting):
|
||||||
_model_select = model_info.replace('<p>', '').replace('</p>', '').replace('Current Model: ', '').strip()
|
edm_steps = 50
|
||||||
if model_select != _model_select:
|
|
||||||
if model_select == 'v0-Q':
|
|
||||||
print('load v0-Q')
|
|
||||||
model.load_state_dict(ckpt_Q, strict=False)
|
|
||||||
elif model_select == 'v0-F':
|
|
||||||
print('load v0-F')
|
|
||||||
model.load_state_dict(ckpt_F, strict=False)
|
|
||||||
model_info = model_info.replace(_model_select, model_select)
|
|
||||||
|
|
||||||
s_stage2 = 1.0
|
s_stage2 = 1.0
|
||||||
s_stage1 = -1.0
|
s_stage1 = -1.0
|
||||||
s_churn = 5
|
s_churn = 5
|
||||||
|
@ -109,17 +121,17 @@ def load_and_reset(model_select, model_info):
|
||||||
color_fix_type = 'Wavelet'
|
color_fix_type = 'Wavelet'
|
||||||
spt_linear_CFG = 1.0
|
spt_linear_CFG = 1.0
|
||||||
spt_linear_s_stage2 = 0.0
|
spt_linear_s_stage2 = 0.0
|
||||||
if model_select == 'v0-Q':
|
if param_setting == "Quality":
|
||||||
s_cfg = 7.5
|
s_cfg = 7.5
|
||||||
linear_CFG = False
|
linear_CFG = False
|
||||||
linear_s_stage2 = True
|
linear_s_stage2 = True
|
||||||
elif model_select == 'v0-F':
|
elif param_setting == "Fidelity":
|
||||||
s_cfg = 4.0
|
s_cfg = 4.0
|
||||||
linear_CFG = True
|
linear_CFG = True
|
||||||
linear_s_stage2 = False
|
linear_s_stage2 = False
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
|
return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
|
||||||
linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2
|
linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2
|
||||||
|
|
||||||
block = gr.Blocks(title='SUPIR').queue()
|
block = gr.Blocks(title='SUPIR').queue()
|
||||||
|
@ -131,7 +143,7 @@ with block:
|
||||||
with gr.Row(equal_height=True):
|
with gr.Row(equal_height=True):
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("<center>Input</center>")
|
gr.Markdown("<center>Input</center>")
|
||||||
input_image = gr.Image(sources='upload', type="numpy", elem_id="image-input")
|
input_image = gr.Image(type="numpy", elem_id="image-input")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("<center>Stage1 Output</center>")
|
gr.Markdown("<center>Stage1 Output</center>")
|
||||||
denoise_image = gr.Image(type="numpy", elem_id="image-s1")
|
denoise_image = gr.Image(type="numpy", elem_id="image-s1")
|
||||||
|
@ -143,7 +155,8 @@ with block:
|
||||||
top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
|
top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
|
||||||
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.")
|
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.")
|
||||||
with gr.Accordion("Stage2 options", open=False):
|
with gr.Accordion("Stage2 options", open=False):
|
||||||
num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4, value=1, step=1)
|
num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
|
||||||
|
, value=1, step=1)
|
||||||
upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
|
upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
|
||||||
edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
|
edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
|
||||||
s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
|
s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
|
||||||
|
@ -181,10 +194,16 @@ with block:
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
|
color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
|
||||||
interactive=True)
|
interactive=True)
|
||||||
|
with gr.Column():
|
||||||
|
model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
|
||||||
|
interactive=True)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("<center>Stage2 Output</center>")
|
gr.Markdown("<center>Stage2 Output</center>")
|
||||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1", rows=2, columns=1)
|
if not args.use_image_slider:
|
||||||
|
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
|
||||||
|
else:
|
||||||
|
result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
denoise_button = gr.Button(value="Stage1 Run")
|
denoise_button = gr.Button(value="Stage1 Run")
|
||||||
|
@ -194,20 +213,20 @@ with block:
|
||||||
diffusion_button = gr.Button(value="Stage2 Run")
|
diffusion_button = gr.Button(value="Stage2 Run")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
model_select = gr.Dropdown(["v0-Q", "v0-F"], interactive=True, label="Model List",
|
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
|
||||||
value="v0-Q")
|
value="Quality")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
restart_button = gr.Button(value="Load & Reset")
|
restart_button = gr.Button(value="Reset Param")
|
||||||
model_info = gr.Markdown(f"Current Model: {model_select.value}")
|
|
||||||
|
|
||||||
llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
|
llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
|
||||||
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
|
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
|
||||||
outputs=[denoise_image])
|
outputs=[denoise_image])
|
||||||
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
||||||
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
||||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2]
|
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
|
||||||
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery])
|
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery])
|
||||||
restart_button.click(fn=load_and_reset, inputs=[model_select, model_info],
|
restart_button.click(fn=load_and_reset, inputs=[param_setting],
|
||||||
outputs=[model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
|
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
|
||||||
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
|
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
|
||||||
block.launch(server_name=server_ip, server_port=server_port)
|
block.launch(server_name=server_ip, server_port=server_port)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
fastapi==0.95.1
|
fastapi==0.95.1
|
||||||
gradio==3.4.0
|
gradio==4.16.0
|
||||||
|
gradio_imageslider==0.0.17
|
||||||
Markdown==3.4.1
|
Markdown==3.4.1
|
||||||
numpy==1.24.2
|
numpy==1.24.2
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
|
|
Loading…
Reference in a new issue