This commit is contained in:
Fanghua-Yu 2024-02-05 20:47:41 +08:00
parent f6a31e9563
commit c1c5728601
2 changed files with 10 additions and 10 deletions

View file

@ -45,8 +45,11 @@ class SUPIRModel(DiffusionEngine):
return z return z
@torch.no_grad() @torch.no_grad()
def encode_first_stage_with_denoise(self, x, use_sample=True): def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
with torch.autocast("cuda", dtype=self.ae_dtype): with torch.autocast("cuda", dtype=self.ae_dtype):
if is_stage1:
h = self.first_stage_model.denoise_encoder_s1(x)
else:
h = self.first_stage_model.denoise_encoder(x) h = self.first_stage_model.denoise_encoder(x)
moments = self.first_stage_model.quant_conv(h) moments = self.first_stage_model.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments) posterior = DiagonalGaussianDistribution(moments)
@ -65,11 +68,11 @@ class SUPIRModel(DiffusionEngine):
return out.float() return out.float()
@torch.no_grad() @torch.no_grad()
def batchify_denoise(self, x): def batchify_denoise(self, x, is_stage1=False):
''' '''
[N, C, H, W], [-1, 1], RGB [N, C, H, W], [-1, 1], RGB
''' '''
x = self.encode_first_stage_with_denoise(x, use_sample=False) x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
return self.decode_first_stage(x) return self.decode_first_stage(x)
@torch.no_grad() @torch.no_grad()

View file

@ -9,6 +9,7 @@ from PIL import Image
from llava.llava_agent import LLavaAgent from llava.llava_agent import LLavaAgent
from CKPT_PTH import LLAVA_MODEL_PATH from CKPT_PTH import LLAVA_MODEL_PATH
import einops import einops
import copy
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='0.0.0.0') parser.add_argument("--ip", type=str, default='0.0.0.0')
@ -31,6 +32,7 @@ else:
# load SUPIR # load SUPIR
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device) model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
model.current_model = 'v0-Q' model.current_model = 'v0-Q'
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml') ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
# load LLaVA # load LLaVA
@ -41,17 +43,12 @@ else:
def stage1_process(input_image, gamma_correction): def stage1_process(input_image, gamma_correction):
# force to v0-Q
if model.current_model != 'v0-Q':
print('load v0-Q')
model.load_state_dict(ckpt_Q, strict=False)
model.current_model = 'v0-Q'
LQ = HWC3(input_image) LQ = HWC3(input_image)
LQ = fix_resize(LQ, 512) LQ = fix_resize(LQ, 512)
# stage1 # stage1
LQ = np.array(LQ) / 255 * 2 - 1 LQ = np.array(LQ) / 255 * 2 - 1
LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :] LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
LQ = model.batchify_denoise(LQ) LQ = model.batchify_denoise(LQ, is_stage1=True)
LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8) LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
# gamma correction # gamma correction
LQ = LQ / 255.0 LQ = LQ / 255.0