20240205
This commit is contained in:
parent
f6a31e9563
commit
c1c5728601
2 changed files with 10 additions and 10 deletions
|
@ -45,9 +45,12 @@ class SUPIRModel(DiffusionEngine):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def encode_first_stage_with_denoise(self, x, use_sample=True):
|
def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
|
||||||
with torch.autocast("cuda", dtype=self.ae_dtype):
|
with torch.autocast("cuda", dtype=self.ae_dtype):
|
||||||
h = self.first_stage_model.denoise_encoder(x)
|
if is_stage1:
|
||||||
|
h = self.first_stage_model.denoise_encoder_s1(x)
|
||||||
|
else:
|
||||||
|
h = self.first_stage_model.denoise_encoder(x)
|
||||||
moments = self.first_stage_model.quant_conv(h)
|
moments = self.first_stage_model.quant_conv(h)
|
||||||
posterior = DiagonalGaussianDistribution(moments)
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
if use_sample:
|
if use_sample:
|
||||||
|
@ -65,11 +68,11 @@ class SUPIRModel(DiffusionEngine):
|
||||||
return out.float()
|
return out.float()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def batchify_denoise(self, x):
|
def batchify_denoise(self, x, is_stage1=False):
|
||||||
'''
|
'''
|
||||||
[N, C, H, W], [-1, 1], RGB
|
[N, C, H, W], [-1, 1], RGB
|
||||||
'''
|
'''
|
||||||
x = self.encode_first_stage_with_denoise(x, use_sample=False)
|
x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
|
||||||
return self.decode_first_stage(x)
|
return self.decode_first_stage(x)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
|
@ -9,6 +9,7 @@ from PIL import Image
|
||||||
from llava.llava_agent import LLavaAgent
|
from llava.llava_agent import LLavaAgent
|
||||||
from CKPT_PTH import LLAVA_MODEL_PATH
|
from CKPT_PTH import LLAVA_MODEL_PATH
|
||||||
import einops
|
import einops
|
||||||
|
import copy
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
||||||
|
@ -31,6 +32,7 @@ else:
|
||||||
|
|
||||||
# load SUPIR
|
# load SUPIR
|
||||||
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
|
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
|
||||||
|
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
|
||||||
model.current_model = 'v0-Q'
|
model.current_model = 'v0-Q'
|
||||||
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
|
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
|
||||||
# load LLaVA
|
# load LLaVA
|
||||||
|
@ -41,17 +43,12 @@ else:
|
||||||
|
|
||||||
|
|
||||||
def stage1_process(input_image, gamma_correction):
|
def stage1_process(input_image, gamma_correction):
|
||||||
# force to v0-Q
|
|
||||||
if model.current_model != 'v0-Q':
|
|
||||||
print('load v0-Q')
|
|
||||||
model.load_state_dict(ckpt_Q, strict=False)
|
|
||||||
model.current_model = 'v0-Q'
|
|
||||||
LQ = HWC3(input_image)
|
LQ = HWC3(input_image)
|
||||||
LQ = fix_resize(LQ, 512)
|
LQ = fix_resize(LQ, 512)
|
||||||
# stage1
|
# stage1
|
||||||
LQ = np.array(LQ) / 255 * 2 - 1
|
LQ = np.array(LQ) / 255 * 2 - 1
|
||||||
LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
|
LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
|
||||||
LQ = model.batchify_denoise(LQ)
|
LQ = model.batchify_denoise(LQ, is_stage1=True)
|
||||||
LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
|
LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
|
||||||
# gamma correction
|
# gamma correction
|
||||||
LQ = LQ / 255.0
|
LQ = LQ / 255.0
|
||||||
|
|
Loading…
Reference in a new issue