From 377af5b01d4f113ccc22cd832ac824917922689a Mon Sep 17 00:00:00 2001 From: Fanghua-Yu <1901213025@pku.edu.cn> Date: Mon, 5 Feb 2024 19:35:12 +0800 Subject: [PATCH] 20240205 --- gradio_demo.py | 97 +++++++++++++++++++++++++++++------------------- requirements.txt | 5 ++- 2 files changed, 61 insertions(+), 41 deletions(-) diff --git a/gradio_demo.py b/gradio_demo.py index 0784c26..87b194e 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -1,23 +1,37 @@ import gradio as gr +from gradio_imageslider import ImageSlider import argparse from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype import numpy as np import torch -if torch.cuda.device_count() >= 2: - use_llava = True -else: - use_llava = False from SUPIR.util import create_SUPIR_model, load_QF_ckpt from PIL import Image from llava.llava_agent import LLavaAgent from CKPT_PTH import LLAVA_MODEL_PATH import einops -SUPIR_device = 'cuda:0' -LLaVA_device = 'cuda:1' +parser = argparse.ArgumentParser() +parser.add_argument("--ip", type=str, default='0.0.0.0') +parser.add_argument("--port", type=int, default='6688') +parser.add_argument("--no_llava", action='store_true', default=False) +parser.add_argument("--use_image_slider", action='store_true', default=False) +args = parser.parse_args() +server_ip = args.ip +server_port = args.port +use_llava = not args.no_llava + +if torch.cuda.device_count() >= 2: + SUPIR_device = 'cuda:0' + LLaVA_device = 'cuda:1' +elif torch.cuda.device_count() == 1: + SUPIR_device = 'cuda:0' + LLaVA_device = 'cuda:0' +else: + raise ValueError('Currently support CUDA only.') # load SUPIR -model = create_SUPIR_model('options/SUPIR_v0.yaml').to(SUPIR_device) +model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device) +model.current_model = 'v0-Q' ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml') # load LLaVA if use_llava: @@ -25,15 +39,13 @@ if use_llava: else: llava_agent = None -parser = argparse.ArgumentParser() -parser.add_argument("--ip", type=str, default='0.0.0.0') -parser.add_argument("--port", type=int, default='6688') -args = parser.parse_args() - -server_ip = args.ip -server_port = args.port def stage1_process(input_image, gamma_correction): + # force to v0-Q + if model.current_model != 'v0-Q': + print('load v0-Q') + model.load_state_dict(ckpt_Q, strict=False) + model.current_model = 'v0-Q' LQ = HWC3(input_image) LQ = fix_resize(LQ, 512) # stage1 @@ -59,7 +71,16 @@ def llave_process(input_image, temperature, top_p, qs=None): def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, - linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2): + linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select): + if model_select != model.current_model: + if model_select == 'v0-Q': + print('load v0-Q') + model.load_state_dict(ckpt_Q, strict=False) + model.current_model = 'v0-Q' + elif model_select == 'v0-F': + print('load v0-F') + model.load_state_dict(ckpt_F, strict=False) + model.current_model = 'v0-F' input_image = HWC3(input_image) input_image = upscale_image(input_image, upscale, unit_resolution=32) @@ -85,17 +106,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale results = [x_samples[i] for i in range(num_samples)] return [input_image] + results -def load_and_reset(model_select, model_info): - _model_select = model_info.replace('
', '').replace('
', '').replace('Current Model: ', '').strip() - if model_select != _model_select: - if model_select == 'v0-Q': - print('load v0-Q') - model.load_state_dict(ckpt_Q, strict=False) - elif model_select == 'v0-F': - print('load v0-F') - model.load_state_dict(ckpt_F, strict=False) - model_info = model_info.replace(_model_select, model_select) - +def load_and_reset(param_setting): + edm_steps = 50 s_stage2 = 1.0 s_stage1 = -1.0 s_churn = 5 @@ -109,17 +121,17 @@ def load_and_reset(model_select, model_info): color_fix_type = 'Wavelet' spt_linear_CFG = 1.0 spt_linear_s_stage2 = 0.0 - if model_select == 'v0-Q': + if param_setting == "Quality": s_cfg = 7.5 linear_CFG = False linear_s_stage2 = True - elif model_select == 'v0-F': + elif param_setting == "Fidelity": s_cfg = 4.0 linear_CFG = True linear_s_stage2 = False else: raise NotImplementedError - return model_info, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \ + return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \ linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2 block = gr.Blocks(title='SUPIR').queue() @@ -131,7 +143,7 @@ with block: with gr.Row(equal_height=True): with gr.Column(): gr.Markdown("