diff --git a/SUPIR/models/SUPIR_model.py b/SUPIR/models/SUPIR_model.py index 25367b2..ca97b03 100644 --- a/SUPIR/models/SUPIR_model.py +++ b/SUPIR/models/SUPIR_model.py @@ -45,9 +45,12 @@ class SUPIRModel(DiffusionEngine): return z @torch.no_grad() - def encode_first_stage_with_denoise(self, x, use_sample=True): + def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False): with torch.autocast("cuda", dtype=self.ae_dtype): - h = self.first_stage_model.denoise_encoder(x) + if is_stage1: + h = self.first_stage_model.denoise_encoder_s1(x) + else: + h = self.first_stage_model.denoise_encoder(x) moments = self.first_stage_model.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) if use_sample: @@ -65,11 +68,11 @@ class SUPIRModel(DiffusionEngine): return out.float() @torch.no_grad() - def batchify_denoise(self, x): + def batchify_denoise(self, x, is_stage1=False): ''' [N, C, H, W], [-1, 1], RGB ''' - x = self.encode_first_stage_with_denoise(x, use_sample=False) + x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1) return self.decode_first_stage(x) @torch.no_grad() diff --git a/gradio_demo.py b/gradio_demo.py index 3046d96..578a632 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -9,6 +9,7 @@ from PIL import Image from llava.llava_agent import LLavaAgent from CKPT_PTH import LLAVA_MODEL_PATH import einops +import copy parser = argparse.ArgumentParser() parser.add_argument("--ip", type=str, default='0.0.0.0') @@ -31,6 +32,7 @@ else: # load SUPIR model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device) +model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder) model.current_model = 'v0-Q' ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml') # load LLaVA @@ -41,17 +43,12 @@ else: def stage1_process(input_image, gamma_correction): - # force to v0-Q - if model.current_model != 'v0-Q': - print('load v0-Q') - model.load_state_dict(ckpt_Q, strict=False) - model.current_model = 'v0-Q' LQ = HWC3(input_image) LQ = fix_resize(LQ, 512) # stage1 LQ = np.array(LQ) / 255 * 2 - 1 LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :] - LQ = model.batchify_denoise(LQ) + LQ = model.batchify_denoise(LQ, is_stage1=True) LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8) # gamma correction LQ = LQ / 255.0