diff --git a/README.md b/README.md index c271fa5..ae5d7fc 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,6 @@ --- ## 🔧 Dependencies and Installation - 1. Clone repo ```bash git clone https://github.com/Fanghua-Yu/SUPIR.git @@ -37,11 +36,11 @@ #### Models we provided: -* `SUPIR-v0Q`: (Coming Soon) Google Drive, Baidu Netdisk +* `SUPIR-v0Q`: [Baidu Netdisk](https://pan.baidu.com/s/1lnefCZhBTeDWijqbj1jIyw?pwd=pjq6), Google Drive (Coming Soon) Default training settings with paper. High generalization and high image quality in most cases. -* `SUPIR-v0F`: (Coming Soon) Google Drive, Baidu Netdisk +* `SUPIR-v0F`: [Baidu Netdisk](https://pan.baidu.com/s/1AECN8NjiVuE3hvO8o-Ua6A?pwd=k2uz), Google Drive (Coming Soon) Training with light degradation settings. Stage1 encoder of `SUPIR-v0F` remains more details when facing light degradations. @@ -53,11 +52,11 @@ --- ## ⚡ Quick Inference - +### Val Dataset +RealPhoto60: [Baidu Netdisk](https://pan.baidu.com/s/1CJKsPGtyfs8QEVCQ97voBA?pwd=aocg), Google Drive (Coming Soon) ### Usage of SUPIR - -```console +```Shell Usage: -- python test.py [options] -- python gradio_demo.py [interactive options] @@ -102,13 +101,16 @@ CUDA_VISIBLE_DEVICES=0,1 python test.py --img_dir '/opt/data/private/LV_Dataset/ ### Gradio Demo ```Shell -CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 +CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_image_slider --log_history ``` +

+ +

+ ### Online Demo (Coming Soon) - --- ## BibTeX diff --git a/assets/DemoGuide.png b/assets/DemoGuide.png new file mode 100644 index 0000000..42b2248 Binary files /dev/null and b/assets/DemoGuide.png differ diff --git a/gradio_demo.py b/gradio_demo.py index 578a632..d72056d 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -1,3 +1,5 @@ +import os + import gradio as gr from gradio_imageslider import ImageSlider import argparse @@ -10,12 +12,14 @@ from llava.llava_agent import LLavaAgent from CKPT_PTH import LLAVA_MODEL_PATH import einops import copy +import time parser = argparse.ArgumentParser() -parser.add_argument("--ip", type=str, default='0.0.0.0') +parser.add_argument("--ip", type=str, default='127.0.0.1') parser.add_argument("--port", type=int, default='6688') parser.add_argument("--no_llava", action='store_true', default=False) parser.add_argument("--use_image_slider", action='store_true', default=False) +parser.add_argument("--log_history", action='store_true', default=False) args = parser.parse_args() server_ip = args.ip server_port = args.port @@ -41,7 +45,6 @@ if use_llava: else: llava_agent = None - def stage1_process(input_image, gamma_correction): LQ = HWC3(input_image) LQ = fix_resize(LQ, 512) @@ -69,6 +72,15 @@ def llave_process(input_image, temperature, top_p, qs=None): def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select): + event_id = str(time.time_ns()) + event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt, + 'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps, + 's_stage1': s_stage1, 's_stage2': s_stage2, 's_cfg': s_cfg, 'seed': seed, 's_churn': s_churn, + 's_noise': s_noise, 'color_fix_type': color_fix_type, 'diff_dtype': diff_dtype, 'ae_dtype': ae_dtype, + 'gamma_correction': gamma_correction, 'linear_CFG': linear_CFG, 'linear_s_stage2': linear_s_stage2, + 'spt_linear_CFG': spt_linear_CFG, 'spt_linear_s_stage2': spt_linear_s_stage2, + 'model_select': model_select} + if model_select != model.current_model: if model_select == 'v0-Q': print('load v0-Q') @@ -79,7 +91,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale model.load_state_dict(ckpt_F, strict=False) model.current_model = 'v0-F' input_image = HWC3(input_image) - input_image = upscale_image(input_image, upscale, unit_resolution=32) + input_image = upscale_image(input_image, upscale, unit_resolution=32, + min_size=1024) LQ = np.array(input_image) / 255.0 LQ = np.power(LQ, gamma_correction) @@ -101,7 +114,16 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip( 0, 255).astype(np.uint8) results = [x_samples[i] for i in range(num_samples)] - return [input_image] + results + + if args.log_history: + os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True) + with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f: + f.write(str(event_dict)) + f.close() + Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png') + for i, result in enumerate(results): + Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png') + return [input_image] + results, event_id, 3, '' def load_and_reset(param_setting): edm_steps = 50 @@ -118,23 +140,56 @@ def load_and_reset(param_setting): color_fix_type = 'Wavelet' spt_linear_CFG = 1.0 spt_linear_s_stage2 = 0.0 + linear_s_stage2 = False if param_setting == "Quality": s_cfg = 7.5 linear_CFG = False - linear_s_stage2 = True elif param_setting == "Fidelity": s_cfg = 4.0 linear_CFG = True - linear_s_stage2 = False else: raise NotImplementedError return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \ linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2 + +def submit_feedback(event_id, fb_score, fb_text): + if args.log_history: + with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'r') as f: + event_dict = eval(f.read()) + f.close() + event_dict['feedback'] = {'score': fb_score, 'text': fb_text} + with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f: + f.write(str(event_dict)) + f.close() + return 'Submit successfully, thank you for your comments!' + else: + return 'Submit failed, the server is not set to log history.' + +title_md = """ +# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration** + +⚠️SUPIR is still a research project under tested and is not yet a stable commercial product. + +[[Paper](https://arxiv.org/abs/2401.13627)]   [[Project Page](http://supir.xpixel.group/)]   [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)] +""" + + +claim_md = """ +## **Terms of use** + +By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit a feedback to us if you get any inappropriate answer! We will collect those to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. + +## **License** + +The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR. +""" + + block = gr.Blocks(title='SUPIR').queue() with block: with gr.Row(): - gr.Markdown("
SUPIR Playground
") + gr.Markdown(title_md) with gr.Row(): with gr.Column(): with gr.Row(equal_height=True): @@ -150,7 +205,8 @@ with block: with gr.Accordion("LLaVA options", open=False): temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1) top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1) - qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner.") + qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. " + "The image is a realistic photography, not an art painting.") with gr.Accordion("Stage2 options", open=False): num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1 , value=1, step=1) @@ -178,7 +234,7 @@ with block: spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0, maximum=9.0, value=1.0, step=0.5) with gr.Column(): - linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=True) + linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False) spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0., maximum=1., value=0., step=0.05) with gr.Row(): @@ -213,8 +269,15 @@ with block: param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting", value="Quality") with gr.Column(): - restart_button = gr.Button(value="Reset Param") - + restart_button = gr.Button(value="Reset Param", scale=2) + with gr.Accordion("Feedback", open=True): + fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1, + interactive=True) + fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.') + submit_button = gr.Button(value="Submit Feedback") + with gr.Row(): + gr.Markdown(claim_md) + event_id = gr.Textbox(label="Event ID", value="", visible=False) llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt]) denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction], @@ -222,8 +285,9 @@ with block: stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select] - diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery]) + diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text]) restart_button.click(fn=load_and_reset, inputs=[param_setting], outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2]) + submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text]) block.launch(server_name=server_ip, server_port=server_port) diff --git a/options/SUPIR_v0.yaml b/options/SUPIR_v0.yaml index c220bb4..ff80312 100644 --- a/options/SUPIR_v0.yaml +++ b/options/SUPIR_v0.yaml @@ -150,7 +150,7 @@ model: jpeg artifacts, deformed, lowres, over-smooth' SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors -SUPIR_CKPT_F: /opt/data/private/code/D1ff91v/experiments/1162-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_3/checkpoints/SUPIR-v0F.ckpt -SUPIR_CKPT_Q: /opt/data/private/code/D1ff91v/experiments/1163-XL_base-ControlGLV-Sharp_RealESR-Mix_P512_B256/lightning_logs/version_4/checkpoints/epoch=0-step=40000_dumped-ema.ckpt +SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt +SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt SUPIR_CKPT: ~ diff --git a/test.py b/test.py index 20b8b44..ac1968a 100644 --- a/test.py +++ b/test.py @@ -6,12 +6,13 @@ from llava.llava_agent import LLavaAgent from CKPT_PTH import LLAVA_MODEL_PATH import os if torch.cuda.device_count() >= 2: - use_llava = True + SUPIR_device = 'cuda:0' + LLaVA_device = 'cuda:1' +elif torch.cuda.device_count() == 1: + SUPIR_device = 'cuda:0' + LLaVA_device = 'cuda:0' else: - use_llava = False - -SUPIR_device = 'cuda:0' -LLaVA_device = 'cuda:1' + raise ValueError('Currently support CUDA only.') # hyparams here parser = argparse.ArgumentParser() @@ -45,8 +46,10 @@ parser.add_argument("--spt_linear_CFG", type=float, default=1.0) parser.add_argument("--spt_linear_s_stage2", type=float, default=0.) parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16']) parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16']) +parser.add_argument("--no_llava", action='store_true', default=False) args = parser.parse_args() print(args) +use_llava = not args.no_llava # load SUPIR model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device)