From 377af5b01d4f113ccc22cd832ac824917922689a Mon Sep 17 00:00:00 2001 From: Fanghua-Yu <1901213025@pku.edu.cn> Date: Mon, 5 Feb 2024 19:35:12 +0800 Subject: [PATCH] 20240205 --- gradio_demo.py | 97 +++++++++++++++++++++++++++++------------------- requirements.txt | 5 ++- 2 files changed, 61 insertions(+), 41 deletions(-) diff --git a/gradio_demo.py b/gradio_demo.py index 0784c26..87b194e 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -1,23 +1,37 @@ import gradio as gr +from gradio_imageslider import ImageSlider import argparse from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype import numpy as np import torch -if torch.cuda.device_count() >= 2: - use_llava = True -else: - use_llava = False from SUPIR.util import create_SUPIR_model, load_QF_ckpt from PIL import Image from llava.llava_agent import LLavaAgent from CKPT_PTH import LLAVA_MODEL_PATH import einops -SUPIR_device = 'cuda:0' -LLaVA_device = 'cuda:1' +parser = argparse.ArgumentParser() +parser.add_argument("--ip", type=str, default='0.0.0.0') +parser.add_argument("--port", type=int, default='6688') +parser.add_argument("--no_llava", action='store_true', default=False) +parser.add_argument("--use_image_slider", action='store_true', default=False) +args = parser.parse_args() +server_ip = args.ip +server_port = args.port +use_llava = not args.no_llava + +if torch.cuda.device_count() >= 2: + SUPIR_device = 'cuda:0' + LLaVA_device = 'cuda:1' +elif torch.cuda.device_count() == 1: + SUPIR_device = 'cuda:0' + LLaVA_device = 'cuda:0' +else: + raise ValueError('Currently support CUDA only.') # load SUPIR -model = create_SUPIR_model('options/SUPIR_v0.yaml').to(SUPIR_device) +model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device) +model.current_model = 'v0-Q' ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml') # load LLaVA if use_llava: @@ -25,15 +39,13 @@ if use_llava: else: llava_agent = None -parser = argparse.ArgumentParser() -parser.add_argument("--ip", type=str, default='0.0.0.0') -parser.add_argument("--port", type=int, default='6688') -args = parser.parse_args() - -server_ip = args.ip -server_port = args.port def stage1_process(input_image, gamma_correction): + # force to v0-Q + if model.current_model != 'v0-Q': + print('load v0-Q') + model.load_state_dict(ckpt_Q, strict=False) + model.current_model = 'v0-Q' LQ = HWC3(input_image) LQ = fix_resize(LQ, 512) # stage1 @@ -59,7 +71,16 @@ def llave_process(input_image, temperature, top_p, qs=None): def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, - linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2): + linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select): + if model_select != model.current_model: + if model_select == 'v0-Q': + print('load v0-Q') + model.load_state_dict(ckpt_Q, strict=False) + model.current_model = 'v0-Q' + elif model_select == 'v0-F': + print('load v0-F') + model.load_state_dict(ckpt_F, strict=False) + model.current_model = 'v0-F' input_image = HWC3(input_image) input_image = upscale_image(input_image, upscale, unit_resolution=32) @@ -85,17 +106,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale results = [x_samples[i] for i in range(num_samples)] return [input_image] + results -def load_and_reset(model_select, model_info): - _model_select = model_info.replace('

', '').replace('

', '').replace('Current Model: ', '').strip() - if model_select != _model_select: - if model_select == 'v0-Q': - print('load v0-Q') - model.load_state_dict(ckpt_Q, strict=False) - elif model_select == 'v0-F': - print('load v0-F') - model.load_state_dict(ckpt_F, strict=False) - model_info = model_info.replace(_model_select, model_select) - +def load_and_reset(param_setting): + edm_steps = 50 s_stage2 = 1.0 s_stage1 = -1.0 s_churn = 5 @@ -109,17 +121,17 @@ def load_and_reset(model_select, model_info): color_fix_type = 'Wavelet' spt_linear_CFG = 1.0 spt_linear_s_stage2 = 0.0 - if model_select == 'v0-Q': + if param_setting == "Quality": s_cfg = 7.5 linear_CFG = False linear_s_stage2 = True - elif model_select == 'v0-F': + elif param_setting == "Fidelity": s_cfg = 4.0 linear_CFG = True linear_s_stage2 = False else: raise NotImplementedError - return model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \ + return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \ linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2 block = gr.Blocks(title='SUPIR').queue() @@ -131,7 +143,7 @@ with block: with gr.Row(equal_height=True): with gr.Column(): gr.Markdown("
Input
") - input_image = gr.Image(sources='upload', type="numpy", elem_id="image-input") + input_image = gr.Image(type="numpy", elem_id="image-input") with gr.Column(): gr.Markdown("
Stage1 Output
") denoise_image = gr.Image(type="numpy", elem_id="image-s1") @@ -143,7 +155,8 @@ with block: top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1) qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.") with gr.Accordion("Stage2 options", open=False): - num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4, value=1, step=1) + num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1 + , value=1, step=1) upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1) edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1) s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1) @@ -181,10 +194,16 @@ with block: with gr.Column(): color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet", interactive=True) + with gr.Column(): + model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q", + interactive=True) with gr.Column(): gr.Markdown("
Stage2 Output
") - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1", rows=2, columns=1) + if not args.use_image_slider: + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1") + else: + result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1") with gr.Row(): with gr.Column(): denoise_button = gr.Button(value="Stage1 Run") @@ -194,20 +213,20 @@ with block: diffusion_button = gr.Button(value="Stage2 Run") with gr.Row(): with gr.Column(): - model_select = gr.Dropdown(["v0-Q", "v0-F"], interactive=True, label="Model List", - value="v0-Q") + param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting", + value="Quality") with gr.Column(): - restart_button = gr.Button(value="Load & Reset") - model_info = gr.Markdown(f"Current Model: {model_select.value}") + restart_button = gr.Button(value="Reset Param") + llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt]) denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction], outputs=[denoise_image]) stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, - linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2] + linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select] diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery]) - restart_button.click(fn=load_and_reset, inputs=[model_select, model_info], - outputs=[model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, + restart_button.click(fn=load_and_reset, inputs=[param_setting], + outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2]) block.launch(server_name=server_ip, server_port=server_port) diff --git a/requirements.txt b/requirements.txt index ac21b35..f277bba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ fastapi==0.95.1 -gradio==3.4.0 +gradio==4.16.0 +gradio_imageslider==0.0.17 Markdown==3.4.1 numpy==1.24.2 requests==2.28.2 @@ -35,4 +36,4 @@ tqdm==4.65.0 triton==2.1.0 urllib3==1.26.15 webdataset==0.2.48 -xformers>=0.0.20 \ No newline at end of file +xformers>=0.0.20