20240205
This commit is contained in:
parent
12b8fa4c77
commit
377af5b01d
2 changed files with 61 additions and 41 deletions
|
@ -1,23 +1,37 @@
|
|||
import gradio as gr
|
||||
from gradio_imageslider import ImageSlider
|
||||
import argparse
|
||||
from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype
|
||||
import numpy as np
|
||||
import torch
|
||||
if torch.cuda.device_count() >= 2:
|
||||
use_llava = True
|
||||
else:
|
||||
use_llava = False
|
||||
from SUPIR.util import create_SUPIR_model, load_QF_ckpt
|
||||
from PIL import Image
|
||||
from llava.llava_agent import LLavaAgent
|
||||
from CKPT_PTH import LLAVA_MODEL_PATH
|
||||
import einops
|
||||
|
||||
SUPIR_device = 'cuda:0'
|
||||
LLaVA_device = 'cuda:1'
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
||||
parser.add_argument("--port", type=int, default='6688')
|
||||
parser.add_argument("--no_llava", action='store_true', default=False)
|
||||
parser.add_argument("--use_image_slider", action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
server_ip = args.ip
|
||||
server_port = args.port
|
||||
use_llava = not args.no_llava
|
||||
|
||||
if torch.cuda.device_count() >= 2:
|
||||
SUPIR_device = 'cuda:0'
|
||||
LLaVA_device = 'cuda:1'
|
||||
elif torch.cuda.device_count() == 1:
|
||||
SUPIR_device = 'cuda:0'
|
||||
LLaVA_device = 'cuda:0'
|
||||
else:
|
||||
raise ValueError('Currently support CUDA only.')
|
||||
|
||||
# load SUPIR
|
||||
model = create_SUPIR_model('options/SUPIR_v0.yaml').to(SUPIR_device)
|
||||
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
|
||||
model.current_model = 'v0-Q'
|
||||
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
|
||||
# load LLaVA
|
||||
if use_llava:
|
||||
|
@ -25,15 +39,13 @@ if use_llava:
|
|||
else:
|
||||
llava_agent = None
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
||||
parser.add_argument("--port", type=int, default='6688')
|
||||
args = parser.parse_args()
|
||||
|
||||
server_ip = args.ip
|
||||
server_port = args.port
|
||||
|
||||
def stage1_process(input_image, gamma_correction):
|
||||
# force to v0-Q
|
||||
if model.current_model != 'v0-Q':
|
||||
print('load v0-Q')
|
||||
model.load_state_dict(ckpt_Q, strict=False)
|
||||
model.current_model = 'v0-Q'
|
||||
LQ = HWC3(input_image)
|
||||
LQ = fix_resize(LQ, 512)
|
||||
# stage1
|
||||
|
@ -59,7 +71,16 @@ def llave_process(input_image, temperature, top_p, qs=None):
|
|||
|
||||
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
||||
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2):
|
||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
|
||||
if model_select != model.current_model:
|
||||
if model_select == 'v0-Q':
|
||||
print('load v0-Q')
|
||||
model.load_state_dict(ckpt_Q, strict=False)
|
||||
model.current_model = 'v0-Q'
|
||||
elif model_select == 'v0-F':
|
||||
print('load v0-F')
|
||||
model.load_state_dict(ckpt_F, strict=False)
|
||||
model.current_model = 'v0-F'
|
||||
input_image = HWC3(input_image)
|
||||
input_image = upscale_image(input_image, upscale, unit_resolution=32)
|
||||
|
||||
|
@ -85,17 +106,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
|
|||
results = [x_samples[i] for i in range(num_samples)]
|
||||
return [input_image] + results
|
||||
|
||||
def load_and_reset(model_select, model_info):
|
||||
_model_select = model_info.replace('<p>', '').replace('</p>', '').replace('Current Model: ', '').strip()
|
||||
if model_select != _model_select:
|
||||
if model_select == 'v0-Q':
|
||||
print('load v0-Q')
|
||||
model.load_state_dict(ckpt_Q, strict=False)
|
||||
elif model_select == 'v0-F':
|
||||
print('load v0-F')
|
||||
model.load_state_dict(ckpt_F, strict=False)
|
||||
model_info = model_info.replace(_model_select, model_select)
|
||||
|
||||
def load_and_reset(param_setting):
|
||||
edm_steps = 50
|
||||
s_stage2 = 1.0
|
||||
s_stage1 = -1.0
|
||||
s_churn = 5
|
||||
|
@ -109,17 +121,17 @@ def load_and_reset(model_select, model_info):
|
|||
color_fix_type = 'Wavelet'
|
||||
spt_linear_CFG = 1.0
|
||||
spt_linear_s_stage2 = 0.0
|
||||
if model_select == 'v0-Q':
|
||||
if param_setting == "Quality":
|
||||
s_cfg = 7.5
|
||||
linear_CFG = False
|
||||
linear_s_stage2 = True
|
||||
elif model_select == 'v0-F':
|
||||
elif param_setting == "Fidelity":
|
||||
s_cfg = 4.0
|
||||
linear_CFG = True
|
||||
linear_s_stage2 = False
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
|
||||
return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
|
||||
linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2
|
||||
|
||||
block = gr.Blocks(title='SUPIR').queue()
|
||||
|
@ -131,7 +143,7 @@ with block:
|
|||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
gr.Markdown("<center>Input</center>")
|
||||
input_image = gr.Image(sources='upload', type="numpy", elem_id="image-input")
|
||||
input_image = gr.Image(type="numpy", elem_id="image-input")
|
||||
with gr.Column():
|
||||
gr.Markdown("<center>Stage1 Output</center>")
|
||||
denoise_image = gr.Image(type="numpy", elem_id="image-s1")
|
||||
|
@ -143,7 +155,8 @@ with block:
|
|||
top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
|
||||
qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.")
|
||||
with gr.Accordion("Stage2 options", open=False):
|
||||
num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4, value=1, step=1)
|
||||
num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
|
||||
, value=1, step=1)
|
||||
upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
|
||||
edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
|
||||
s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
|
||||
|
@ -181,10 +194,16 @@ with block:
|
|||
with gr.Column():
|
||||
color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
|
||||
interactive=True)
|
||||
with gr.Column():
|
||||
model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
|
||||
interactive=True)
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("<center>Stage2 Output</center>")
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1", rows=2, columns=1)
|
||||
if not args.use_image_slider:
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
|
||||
else:
|
||||
result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
denoise_button = gr.Button(value="Stage1 Run")
|
||||
|
@ -194,20 +213,20 @@ with block:
|
|||
diffusion_button = gr.Button(value="Stage2 Run")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
model_select = gr.Dropdown(["v0-Q", "v0-F"], interactive=True, label="Model List",
|
||||
value="v0-Q")
|
||||
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
|
||||
value="Quality")
|
||||
with gr.Column():
|
||||
restart_button = gr.Button(value="Load & Reset")
|
||||
model_info = gr.Markdown(f"Current Model: {model_select.value}")
|
||||
restart_button = gr.Button(value="Reset Param")
|
||||
|
||||
|
||||
llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
|
||||
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
|
||||
outputs=[denoise_image])
|
||||
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
||||
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2]
|
||||
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
|
||||
diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery])
|
||||
restart_button.click(fn=load_and_reset, inputs=[model_select, model_info],
|
||||
outputs=[model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
|
||||
restart_button.click(fn=load_and_reset, inputs=[param_setting],
|
||||
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
|
||||
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
|
||||
block.launch(server_name=server_ip, server_port=server_port)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
fastapi==0.95.1
|
||||
gradio==3.4.0
|
||||
gradio==4.16.0
|
||||
gradio_imageslider==0.0.17
|
||||
Markdown==3.4.1
|
||||
numpy==1.24.2
|
||||
requests==2.28.2
|
||||
|
@ -35,4 +36,4 @@ tqdm==4.65.0
|
|||
triton==2.1.0
|
||||
urllib3==1.26.15
|
||||
webdataset==0.2.48
|
||||
xformers>=0.0.20
|
||||
xformers>=0.0.20
|
||||
|
|
Loading…
Add table
Reference in a new issue