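Preface
I have used Stable Diffusion for a long time but never properly dissected its internal structure, so I am writing this post to record it, hehe.
Source code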
https://github.com/bubbliiiing/stable-diffusion
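If you like it, feel free to give it a star.
Building the network
Part 1. What is Stable Diffusion (SD)
Stable Diffusion is a relatively new diffusion model. Despite the word "stable" in its name, changing the seed gives a completely different result, so it is not very stable at all, haha.
Stable Diffusion was first used for text-to-image generation (txt2img). As the technology developed, it came to support image-to-image generation (img2img) as well as control methods such as ControlNet for customizing the generated image.
Stable Diffusion is based on diffusion models, so it inevitably involves repeated denoising; for img2img there is also a noising process. The classic DDPM figure is the usual reference here:
Compared with DDPM, Stable Diffusion can use the DDIM sampler, performs diffusion in latent space, and is pretrained on subsets of the very large LAION-5B dataset.
Fine-tuning Stable Diffusion directly costs more than most readers can afford, but there are many lightweight fine-tuning options, such as LoRA and Textual Inversion. That is a topic for later.
This post mainly dissects the structure of the whole SD model and walks through a single diffusion step and the multi-step diffusion flow.
Large models and AIGC are the current industry trend; if you do not keep up, it is easy to be left behind, hh.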
The principle of txt2img is covered in the post "Diffusion Model Study 2: Stable Diffusion Structure Analysis, Using Text-to-Image (txt2img) as an Example".
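Part 2. Components of Stable Diffusion
Stable Diffusion consists of four main parts.
1. Sampler.
2. Variational Autoencoder (VAE).
3. UNet, the main network and noise predictor.
4. CLIPEmbedder, the text encoder.
Every part matters. We use image-to-image generation as the running example. Since it is img2img, there are two inputs: a text prompt and an image.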
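Part 3. The img2img generation flow
The generation flow has four parts:
1. Encode the image with the VAE and add noise according to the denoise value.
2. Encode the prompt text.
3. Run a number of sampling steps determined by the denoise value.
4. Decode with the VAE.
Compared with txt2img, the input changes: instead of initializing from Gaussian noise, img2img initializes from the noised image features, which injects information into the model in the form of an image.
In detail, as shown in the figure above:
- Step one encodes the input image with the VAE to get its latent features, then noises that latent with the DDIM Sampler, giving the noised features of the input image. Suppose denoise is 0.8 and the total number of steps is 20: the latent is noised to the level of step 0.8 x 20 = 16 and the remaining 4 steps are skipped. This can be understood as scrambling 80% of the features and keeping 20%.
- Step two encodes the input text to get the text features.
- Step three runs a number of sampling steps, determined by the denoise value, on the noised features from step one. With denoise at 0.8 as above, only 0.8 x 20 = 16 steps of noise were added, so only 16 sampling steps are needed to recover an image.
- Step four restores the sampled latent to an image with the VAE Decoder.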
- with torch.no_grad():
-     if seed == -1:
-         seed = random.randint(0, 65535)
-     seed_everything(seed)
-     # ----------------------- #
-     # Encode the input image and add noise
-     # ----------------------- #
-     if image_path is not None:
-         img = HWC3(np.array(img, np.uint8))
-         # Map uint8 pixels from [0, 255] to [-1, 1]
-         img = torch.from_numpy(img.copy()).float().cuda() / 127.5 - 1.0
-         img = torch.stack([img for _ in range(num_samples)], dim=0)
-         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
-         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
-         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
-         z = model.get_first_stage_encoding(model.encode_first_stage(img))
-         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
-     # ----------------------- #
-     # Get the encoded prompt
-     # ----------------------- #
-     cond = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
-     un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
-     H, W = input_shape
-     shape = (4, H // 8, W // 8)
-     if image_path is not None:
-         samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
-     else:
-         # ----------------------- #
-         # Run sampling
-         # ----------------------- #
-         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
-                                                      shape, cond, verbose=False, eta=eta,
-                                                      unconditional_guidance_scale=scale,
-                                                      unconditional_conditioning=un_cond)
-     # ----------------------- #
-     # Decode
-     # ----------------------- #
-     x_samples = model.decode_first_stage(samples)
-     x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
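The link between the denoise value and the number of sampling steps is plain arithmetic, so it can be checked in isolation. The sketch below mirrors the `t_enc` expression from the code above; the function name `img2img_steps` is hypothetical and exists only for this illustration.

```python
def img2img_steps(denoise_strength: float, ddim_steps: int) -> int:
    """Number of DDIM steps to noise (and later denoise) in img2img.

    Mirrors `t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)`.
    """
    return min(int(denoise_strength * ddim_steps), ddim_steps - 1)


if __name__ == "__main__":
    # Print the step count for a few denoise values with 20 total steps.
    for strength in (0.5, 0.8, 1.0):
        print(strength, img2img_steps(strength, 20))
```

Note the cap at `ddim_steps - 1`: even with a denoise value of 1.0 the step count stays one below the total.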
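1. Encoding the input image
In img2img we first specify a reference image and then work on top of it:
1. Encode the reference image with the VAE encoder so that it enters latent space. The UNet operates on latents rather than pixels, so the image has to be in latent space before the network can use it.
2. Noise that latent with the DDIM Sampler to get the noised features of the input image. The noising logic is:
- denoise can be regarded as the proportion to be redrawn: 1 means redraw everything, 0 means redraw nothing;
- Suppose denoise is 0.8 and the total number of steps is 20. The latent is noised to the level of step 0.8 x 20 = 16 and the remaining 4 steps are skipped, which can be understood as scrambling 80% of the features and keeping 20%. Even at the maximum noise level a little of the original image's information survives (the code also caps t_enc at ddim_steps - 1), so it is never completely discarded.
We now have the noised image in latent space, and the following sampling steps start from this noised latent.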
- with torch.no_grad():
-     if seed == -1:
-         seed = random.randint(0, 65535)
-     seed_everything(seed)
-     # ----------------------- #
-     # Encode the input image and add noise
-     # ----------------------- #
-     if image_path is not None:
-         img = HWC3(np.array(img, np.uint8))
-         # Map uint8 pixels from [0, 255] to [-1, 1]
-         img = torch.from_numpy(img.copy()).float().cuda() / 127.5 - 1.0
-         img = torch.stack([img for _ in range(num_samples)], dim=0)
-         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
-         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
-         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
-         z = model.get_first_stage_encoding(model.encode_first_stage(img))
-         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
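Although the text speaks of adding noise "16 times", the forward process of a diffusion model has a closed form, so the latent can be taken to a given noise level in one shot. The sketch below shows that idea in pure Python on a flat list of floats. It is an illustration under stated assumptions, not the repository's `stochastic_encode`: the helper names are hypothetical, and the scaled-linear beta schedule with `beta_start=0.00085` and `beta_end=0.012` is assumed to match the SD v1 configuration.

```python
import math
import random


def make_alphas_cumprod(n_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta_t) for an assumed scaled-linear schedule."""
    alphas_cumprod, running = [], 1.0
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    for i in range(n_steps):
        # Interpolate linearly in sqrt(beta) space, then square.
        beta = (lo + (hi - lo) * i / (n_steps - 1)) ** 2
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def noise_latent(z0, alpha_bar_t, rng):
    """Jump straight to noise level t: z_t = sqrt(a) * z0 + sqrt(1 - a) * eps."""
    signal, sigma = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [signal * z + sigma * rng.gauss(0.0, 1.0) for z in z0]


if __name__ == "__main__":
    alphas = make_alphas_cumprod()
    rng = random.Random(0)
    z0 = [1.0, -0.5, 0.25, 0.0]          # stand-in for a flattened VAE latent
    print(noise_latent(z0, alphas[800], rng))
    print("signal weight at the last step:", math.sqrt(alphas[-1]))
```

Under these assumed values the signal weight at the last timestep is small but not zero, which is consistent with the remark that a trace of the input image always remains.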
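2. Text encoding
Text encoding is straightforward: we encode directly with CLIP's text encoder. The code defines a FrozenCLIPEmbedder class that uses CLIPTokenizer and CLIPTextModel from the transformers library.
In the forward pass, the input text is first tokenized with CLIPTokenizer and then passed through CLIPTextModel for feature extraction. FrozenCLIPEmbedder returns a feature tensor of shape [batch_size, 77, 768].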
- class FrozenCLIPEmbedder(AbstractEncoder):
-     """Uses the CLIP transformer encoder for text (from huggingface)"""
-     LAYERS = [
-         "last",
-         "pooled",
-         "hidden"
-     ]
-     def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
-                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
-         super().__init__()
-         assert layer in self.LAYERS
-         # Define the text tokenizer and transformer
-         self.tokenizer = CLIPTokenizer.from_pretrained(version)
-         self.transformer = CLIPTextModel.from_pretrained(version)
-         self.device = device
-         self.max_length = max_length
-         # Freeze the model parameters
-         if freeze:
-             self.freeze()
-         self.layer = layer
-         self.layer_idx = layer_idx
-         if layer == "hidden":
-             assert layer_idx is not None
-             assert 0 <= abs(layer_idx) <= 12
-     def freeze(self):
-         self.transformer = self.transformer.eval()
-         # self.train = disabled_train
-         for param in self.parameters():
-             param.requires_grad = False
-     def forward(self, text):
-         # Tokenize and encode the input text, padding directly to a length of 77.
-         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
-                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
-         # Take input_ids and pass them to the transformer for feature extraction.
-         tokens = batch_encoding["input_ids"].to(self.device)
-         outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-         # Take the features of all tokens
-         if self.layer == "last":
-             z = outputs.last_hidden_state
-         elif self.layer == "pooled":
-             z = outputs.pooler_output[:, None, :]
-         else:
-             z = outputs.hidden_states[self.layer_idx]
-         return z
-     def encode(self, text):
-         return self(text)
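The fixed length of 77 comes from the combination of `truncation=True`, `max_length=77` and `padding="max_length"`. The toy sketch below imitates that behaviour on a plain list of token ids so the resulting shape is easy to see. It does not call the real tokenizer: `pad_or_truncate` is a hypothetical helper, and the special-token ids (49406 for start, 49407 for end and padding) are assumptions about the CLIP vocabulary.

```python
BOS_ID = 49406      # assumed id of <|startoftext|>
EOS_ID = 49407      # assumed id of <|endoftext|>, also assumed to be the pad token
MAX_LENGTH = 77


def pad_or_truncate(token_ids, max_length=MAX_LENGTH):
    """Wrap ids with BOS/EOS, truncate to max_length, then pad with EOS."""
    body = list(token_ids)[: max_length - 2]   # leave room for BOS and EOS
    ids = [BOS_ID] + body + [EOS_ID]
    return ids + [EOS_ID] * (max_length - len(ids))


if __name__ == "__main__":
    short_prompt = [320, 2242, 2368]           # made-up ids for a short prompt
    long_prompt = list(range(1000, 1200))      # made-up ids for an overlong prompt
    batch = [pad_or_truncate(p) for p in (short_prompt, long_prompt)]
    print([len(row) for row in batch])         # every row has the same length
```

Each row then maps to one 768-dimensional vector per token, which is where the [batch_size, 77, 768] shape comes from.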
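3. Sampling flow
a. Generating the initial noise
In img2img the initial noise comes from the reference image, so the noised latent from step 1 serves as the starting point.
b. Sampling the noise N times
Since Stable Diffusion is a process of repeated diffusion, repeated denoising is unavoidable, so the question is how to denoise.
In the previous step we obtained the img2img starting latent, which is the encoded image mixed with Gaussian noise. We start denoising from it.
We take the first t_start entries of ddim_timesteps and reverse their order, because we are now removing noise rather than adding it, and then loop over them. Since this is no longer the txt2img sampling flow, we use another method of the sampler, decode. The loop looks like this: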
- @torch.no_grad()
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
-            use_original_steps=False):
-     # Use the DDIM timesteps.
-     # This looks like a lot, but it simply takes self.ddim_timesteps, keeps the first t_start entries and reverses them.
-     timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
-     timesteps = timesteps[:t_start]
-     time_range = np.flip(timesteps)
-     total_steps = timesteps.shape[0]
-     print(f"Running DDIM Sampling with {total_steps} timesteps")
-     iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
-     x_dec = x_latent
-     for i, step in enumerate(iterator):
-         index = total_steps - i - 1
-         ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
-         # Run a single sampling step
-         x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
-                                       unconditional_guidance_scale=unconditional_guidance_scale,
-                                       unconditional_conditioning=unconditional_conditioning)
-     return x_dec
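To make the timestep handling in `decode` concrete, the sketch below lists which timesteps the loop would visit. It assumes the "uniform" DDIM discretization (every `1000 // ddim_steps`-th training timestep, shifted by one); both helper names are hypothetical.

```python
def ddim_timesteps(num_ddim_steps, num_ddpm_steps=1000):
    """Uniformly spaced DDIM timesteps, shifted by one (assumed discretization)."""
    stride = num_ddpm_steps // num_ddim_steps
    return [t + 1 for t in range(0, num_ddpm_steps, stride)]


def decode_time_range(num_ddim_steps, t_start):
    """Timesteps visited by decode: the first t_start entries, in reverse order."""
    return ddim_timesteps(num_ddim_steps)[:t_start][::-1]


if __name__ == "__main__":
    # 20 DDIM steps with denoise 0.8 gives t_start = 16.
    print(decode_time_range(20, 16))
```

With 20 steps and `t_start = 16`, the loop starts at the noisiest of the 16 kept timesteps and walks down to the cleanest one; the 4 highest timesteps are never visited.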
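c. A single sampling step
I. Predicting the noise
Before running a single sampling step, we first check whether there is a neg prompt. If there is, the neg prompt must be processed as well; otherwise only the pos prompt is processed. In practice a neg prompt is usually present (results tend to be better), so that branch is taken by default.
When handling the neg prompt, we duplicate the incoming latent and timestep: one copy belongs to the pos prompt, the other to the neg prompt. torch.cat concatenates along dimension 0 by default, so the copies are stacked along the batch_size dimension and do not affect each other. We then stack the pos prompt and neg prompt into one batch, also along the batch_size dimension.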
- # First check whether there is a neg prompt; unconditional_conditioning is derived from the neg prompt
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
-     e_t = self.model.apply_model(x, t, c)
- else:
-     # A neg prompt is usually present, so this branch is taken
-     # Duplicate the latent and the timestep: one copy for the pos prompt, one for the neg prompt
-     # torch.cat concatenates along dim 0 by default, i.e. the batch dimension, so the copies do not interfere
-     x_in = torch.cat([x] * 2)
-     t_in = torch.cat([t] * 2)
-     # Then stack the neg prompt and the pos prompt into one batch
-     if isinstance(c, dict):
-         assert isinstance(unconditional_conditioning, dict)
-         c_in = dict()
-         for k in c:
-             if isinstance(c[k], list):
-                 c_in[k] = [
-                     torch.cat([unconditional_conditioning[k][i], c[k][i]])
-                     for i in range(len(c[k]))
-                 ]
-             else:
-                 c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
-     else:
-         c_in = torch.cat([unconditional_conditioning, c])
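After stacking, the latent, the timestep and the prompt condition are passed into the network together, and the result is split along the batch dimension with chunk.
Because the neg prompt was placed first when stacking, the first half after splitting, e_t_uncond, comes from the neg prompt and the second half, e_t, comes from the pos prompt. We want to amplify the influence of the pos prompt and move away from the neg prompt. So we compute the difference e_t - e_t_uncond, enlarge it by scale, and add it back onto e_t_uncond to get the final predicted noise.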
- # After stacking, pass the latent, timestep and prompt condition into the network and split the result along the batch dimension with chunk
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
The resulting e_t is the predicted noise obtained jointly from the latent and the prompt.
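The guidance rule above is a one-line extrapolation, and a tiny numeric sketch makes its behaviour at different scales easy to see. The function name `guided_noise` is hypothetical, and plain lists stand in for the noise tensors.

```python
def guided_noise(e_uncond, e_cond, scale):
    """Classifier-free guidance: move from the neg-prompt prediction toward
    (and past) the pos-prompt prediction by a factor of `scale`."""
    return [u + scale * (c - u) for u, c in zip(e_uncond, e_cond)]


if __name__ == "__main__":
    e_uncond = [0.0, 1.0]   # toy prediction for the neg prompt
    e_cond = [1.0, 1.0]     # toy prediction for the pos prompt
    for scale in (0.0, 1.0, 9.0):
        print(scale, guided_noise(e_uncond, e_cond, scale))
```

A scale of 1 returns the pos-prompt prediction unchanged, which is why the code skips the doubled batch in that case; larger scales push further along the same direction, and components where the two predictions agree are left untouched.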
II. Applying the noise
Is predicting the noise enough? Clearly not. The newly predicted noise still has to be combined with the current noisy latent in the right proportion.
This is best read alongside the DDIM formula. We need $\bar{\alpha}_t$, $\bar{\alpha}_{t-1}$, $\sigma_t$ and $\sqrt{1-\bar{\alpha}_t}$.
In the code these parameters have already been precomputed, so we only need to look them up. Below, a_t is the $\bar{\alpha}_t$ in the denominator of the formula, a_prev is $\bar{\alpha}_{t-1}$, sigma_t is $\sigma_t$, and sqrt_one_minus_at is $\sqrt{1-\bar{\alpha}_t}$.
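For reference, the DDIM update that these coefficients plug into can be written as follows, in the same $\bar{\alpha}$ notation. This is the standard form from the DDIM paper, restated here on the assumption that it is the formula the text refers to; $\epsilon_\theta$ is the predicted noise (e_t in the code) and $\epsilon_t$ is fresh Gaussian noise.

$$
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \left( \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}} \right) + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t, t) + \sigma_t\,\epsilon_t
$$

The first term is the predicted $x_0$ rescaled to the previous noise level, the second is the direction pointing to $x_t$, and the third is random noise, which vanishes when eta is 0 and therefore $\sigma_t = 0$.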
- # Choose the parameters according to the sampler
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
- # Pick the parameters for the current step.
- # index here is total_steps - i - 1 from the loop above
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
This step merely pulls out all the coefficients the formula needs, which makes the arithmetic that follows easier. We then implement the formula in code.
- # current prediction for x_0
- # First term of the formula (the part inside the parentheses)
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
-     pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- # Middle term of the formula
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- # Last term of the formula
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
-     noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- # Return the result of applying the formula
- return x_prev, pred_x0
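The same arithmetic can be tried on a single scalar to see what one step does. The sketch below is a hypothetical stand-alone version of the update (no quantization, temperature or dropout), not the repository's `p_sample_ddim`. It also checks a useful property: when the predicted noise equals the true noise and sigma is 0, the step recovers x_0 and lands on the less noisy latent with the same noise direction.

```python
import math


def ddim_step(x_t, e_t, a_t, a_prev, sigma_t=0.0, noise=0.0):
    """One scalar DDIM update; a_t and a_prev are cumulative alphas."""
    pred_x0 = (x_t - math.sqrt(1.0 - a_t) * e_t) / math.sqrt(a_t)   # predicted clean sample
    dir_xt = math.sqrt(1.0 - a_prev - sigma_t ** 2) * e_t           # direction pointing to x_t
    x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt + sigma_t * noise
    return x_prev, pred_x0


if __name__ == "__main__":
    x0, eps = 0.5, -1.2                # toy clean value and the noise mixed into it
    a_t, a_prev = 0.3, 0.6             # toy cumulative alphas (a_prev is less noisy)
    x_t = math.sqrt(a_t) * x0 + math.sqrt(1.0 - a_t) * eps
    x_prev, pred_x0 = ddim_step(x_t, eps, a_t, a_prev)
    print(pred_x0, x_prev)
```

In the real sampler the noise estimate is imperfect, so pred_x0 changes from step to step and the loop refines it gradually.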
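d. Network structure used during noise prediction
I. The apply_model method
In the noise prediction step of 3.c we used model.apply_model to predict the noise, but what that method actually does was hidden. Let us look at it in detail.
The apply_model method is in ldm.models.diffusion.ddpm.py. Inside apply_model, x_noisy is passed to self.model to predict the noise.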
- x_recon = self.model(x_noisy, t, **cond)
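self.model is an instance of a prebuilt class, defined at line 1416 of ldm.models.diffusion.ddpm.py, which contains the Stable Diffusion UNet. It works somewhat like a wrapper: according to the conditioning mode the model selects, it fuses the text condition with the noisy latent produced above.
c_concat means the condition is fused by concatenation, and c_crossattn means it is fused through attention.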
- class DiffusionWrapper(pl.LightningModule):
-     def __init__(self, diff_model_config, conditioning_key):
-         super().__init__()
-         self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
-         # The Stable Diffusion UNet
-         self.diffusion_model = instantiate_from_config(diff_model_config)
-         self.conditioning_key = conditioning_key
-         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
-     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
-         if self.conditioning_key is None:
-             out = self.diffusion_model(x, t)
-         elif self.conditioning_key == 'concat':
-             xc = torch.cat([x] + c_concat, dim=1)
-             out = self.diffusion_model(xc, t)
-         elif self.conditioning_key == 'crossattn':
-             if not self.sequential_cross_attn:
-                 cc = torch.cat(c_crossattn, 1)
-             else:
-                 cc = c_crossattn
-             out = self.diffusion_model(x, t, context=cc)
-         elif self.conditioning_key == 'hybrid':
-             xc = torch.cat([x] + c_concat, dim=1)
-             cc = torch.cat(c_crossattn, 1)
-             out = self.diffusion_model(xc, t, context=cc)
-         elif self.conditioning_key == 'hybrid-adm':
-             assert c_adm is not None
-             xc = torch.cat([x] + c_concat, dim=1)
-             cc = torch.cat(c_crossattn, 1)
-             out = self.diffusion_model(xc, t, context=cc, y=c_adm)
-         elif self.conditioning_key == 'crossattn-adm':
-             assert c_adm is not None
-             cc = torch.cat(c_crossattn, 1)
-             out = self.diffusion_model(x, t, context=cc, y=c_adm)
-         elif self.conditioning_key == 'adm':
-             cc = c_crossattn[0]
-             out = self.diffusion_model(x, t, y=cc)
-         else:
-             raise NotImplementedError()
-         return out
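self.diffusion_model in the code is the Stable Diffusion UNet. Its structure is defined by the UNetModel class in ldm.modules.diffusionmodules.openaimodel.py.
II. The UNetModel
UNetModel's main job is to combine the timestep t and the text embedding to predict the noise at the current step. Although the idea of a UNet is simple, in Stable Diffusion the UNetModel is built from ResBlock and Transformer modules, so overall it is more complex than a plain UNet.
The prompt goes through the Frozen CLIP Text Encoder to produce the Text Embedding; the timesteps are given a sinusoidal encoding and then passed through fully connected layers (an MLP) to produce the Timesteps Embedding;
ResBlock is used to incorporate the Timesteps Embedding, and the Transformer module is used to incorporate the Text Embedding.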
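Before the MLP, the integer timestep is turned into a vector by a sinusoidal encoding. The sketch below shows that encoding in pure Python for one timestep. The name `sinusoidal_embedding` is hypothetical, and the cosine-first ordering and `max_period=10000` are assumptions about the helper used in the repository.

```python
import math


def sinusoidal_embedding(t, dim, max_period=10000):
    """Encode a scalar timestep as [cos(t * f_0..), sin(t * f_0..)] with
    geometrically spaced frequencies f_i = max_period ** (-i / half)."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


if __name__ == "__main__":
    emb = sinusoidal_embedding(751, 320)   # 320 = model_channels in the SD v1 config
    print(len(emb), emb[:3])
```

In UNetModel this vector of length model_channels is then passed through `self.time_embed`, which expands it to `model_channels * 4` before it is handed to every ResBlock.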
I put a large figure here so you can see how the shapes change inside the network.

The UNet code is as follows:
- class UNetModel(nn.Module):
-     """
-     The full UNet model with attention and timestep embedding.
-     :param in_channels: channels in the input Tensor.
-     :param model_channels: base channel count for the model.
-     :param out_channels: channels in the output Tensor.
-     :param num_res_blocks: number of residual blocks per downsample.
-     :param attention_resolutions: a collection of downsample rates at which
-         attention will take place. May be a set, list, or tuple.
-         For example, if this contains 4, then at 4x downsampling, attention
-         will be used.
-     :param dropout: the dropout probability.
-     :param channel_mult: channel multiplier for each level of the UNet.
-     :param conv_resample: if True, use learned convolutions for upsampling and
-         downsampling.
-     :param dims: determines if the signal is 1D, 2D, or 3D.
-     :param num_classes: if specified (as an int), then this model will be
-         class-conditional with `num_classes` classes.
-     :param use_checkpoint: use gradient checkpointing to reduce memory usage.
-     :param num_heads: the number of attention heads in each attention layer.
-     :param num_heads_channels: if specified, ignore num_heads and instead use
-                                a fixed channel width per attention head.
-     :param num_heads_upsample: works with num_heads to set a different number
-                                of heads for upsampling. Deprecated.
-     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
-     :param resblock_updown: use residual blocks for up/downsampling.
-     :param use_new_attention_order: use a different attention pattern for potentially
-                                     increased efficiency.
-     """
-     def __init__(
-         self,
-         image_size,
-         in_channels,
-         model_channels,
-         out_channels,
-         num_res_blocks,
-         attention_resolutions,
-         dropout=0,
-         channel_mult=(1, 2, 4, 8),
-         conv_resample=True,
-         dims=2,
-         num_classes=None,
-         use_checkpoint=False,
-         use_fp16=False,
-         num_heads=-1,
-         num_head_channels=-1,
-         num_heads_upsample=-1,
-         use_scale_shift_norm=False,
-         resblock_updown=False,
-         use_new_attention_order=False,
-         use_spatial_transformer=False,    # custom transformer support
-         transformer_depth=1,              # custom transformer support
-         context_dim=None,                 # custom transformer support
-         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
-         legacy=True,
-     ):
-         super().__init__()
-         if use_spatial_transformer:
-             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
-         if context_dim is not None:
-             assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
-             from omegaconf.listconfig import ListConfig
-             if type(context_dim) == ListConfig:
-                 context_dim = list(context_dim)
-         if num_heads_upsample == -1:
-             num_heads_upsample = num_heads
-         if num_heads == -1:
-             assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
-         if num_head_channels == -1:
-             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
-         self.image_size = image_size
-         self.in_channels = in_channels
-         self.model_channels = model_channels
-         self.out_channels = out_channels
-         self.num_res_blocks = num_res_blocks
-         self.attention_resolutions = attention_resolutions
-         self.dropout = dropout
-         self.channel_mult = channel_mult
-         self.conv_resample = conv_resample
-         self.num_classes = num_classes
-         self.use_checkpoint = use_checkpoint
-         self.dtype = th.float16 if use_fp16 else th.float32
-         self.num_heads = num_heads
-         self.num_head_channels = num_head_channels
-         self.num_heads_upsample = num_heads_upsample
-         self.predict_codebook_ids = n_embed is not None
-         # Used to compute the embedding of the current sampling timestep t
-         time_embed_dim = model_channels * 4
-         self.time_embed = nn.Sequential(
-             linear(model_channels, time_embed_dim),
-             nn.SiLU(),
-             linear(time_embed_dim, time_embed_dim),
-         )
-         if self.num_classes is not None:
-             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
-
-         # Define the first convolution of the input blocks
-         # TimestepEmbedSequential can also be seen as a wrapper that fuses the timestep or the text depending on the layer type.
-         self.input_blocks = nn.ModuleList(
-             [
-                 TimestepEmbedSequential(
-                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
-                 )
-             ]
-         )
-         self._feature_size = model_channels
-         input_block_chans = [model_channels]
-         ch = model_channels
-         ds = 1
-         # Loop over channel_mult, which has four values: the channel expansion ratios of the four UNet stages
-         # [1, 2, 4, 4]
-         for level, mult in enumerate(channel_mult):
-             # Each stage loops num_res_blocks times (twice here),
-             # adding one ResBlock and one attention block each time
-             for _ in range(num_res_blocks):
-                 # First add a ResBlock
-                 # It adjusts the channel count of the noisy input and fuses the features of t
-                 layers = [
-                     ResBlock(
-                         ch,
-                         time_embed_dim,
-                         dropout,
-                         out_channels=mult * model_channels,
-                         dims=dims,
-                         use_checkpoint=use_checkpoint,
-                         use_scale_shift_norm=use_scale_shift_norm,
-                     )
-                 ]
-                 # ch is the output channel count of the ResBlock above
-                 ch = mult * model_channels
-                 if ds in attention_resolutions:
-                     # num_heads=8
-                     if num_head_channels == -1:
-                         dim_head = ch // num_heads
-                     else:
-                         num_heads = ch // num_head_channels
-                         dim_head = num_head_channels
-                     if legacy:
-                         #num_heads = 1
-                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-                     # Use the SpatialTransformer (self-attention plus cross-attention) to strengthen global features and fuse the text features
-                     layers.append(
-                         AttentionBlock(
-                             ch,
-                             use_checkpoint=use_checkpoint,
-                             num_heads=num_heads,
-                             num_head_channels=dim_head,
-                             use_new_attention_order=use_new_attention_order,
-                         ) if not use_spatial_transformer else SpatialTransformer(
-                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-                         )
-                     )
-                 self.input_blocks.append(TimestepEmbedSequential(*layers))
-                 self._feature_size += ch
-                 input_block_chans.append(ch)
-             # Every stage except the last of the four is followed by downsampling.
-             if level != len(channel_mult) - 1:
-                 out_ch = ch
-                 # Downsample here
-                 # Usually the Downsample module is used directly
-                 self.input_blocks.append(
-                     TimestepEmbedSequential(
-                         ResBlock(
-                             ch,
-                             time_embed_dim,
-                             dropout,
-                             out_channels=out_ch,
-                             dims=dims,
-                             use_checkpoint=use_checkpoint,
-                             use_scale_shift_norm=use_scale_shift_norm,
-                             down=True,
-                         )
-                         if resblock_updown
-                         else Downsample(
-                             ch, conv_resample, dims=dims, out_channels=out_ch
-                         )
-                     )
-                 )
-                 # Set up the parameters for the next stage.
-                 ch = out_ch
-                 input_block_chans.append(ch)
-                 ds *= 2
-                 self._feature_size += ch
-         if num_head_channels == -1:
-             dim_head = ch // num_heads
-         else:
-             num_heads = ch // num_head_channels
-             dim_head = num_head_channels
-         if legacy:
-             #num_heads = 1
-             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-         # Define the middle block
-         # ResBlock + SpatialTransformer + ResBlock
-         self.middle_block = TimestepEmbedSequential(
-             ResBlock(
-                 ch,
-                 time_embed_dim,
-                 dropout,
-                 dims=dims,
-                 use_checkpoint=use_checkpoint,
-                 use_scale_shift_norm=use_scale_shift_norm,
-             ),
-             AttentionBlock(
-                 ch,
-                 use_checkpoint=use_checkpoint,
-                 num_heads=num_heads,
-                 num_head_channels=dim_head,
-                 use_new_attention_order=use_new_attention_order,
-             ) if not use_spatial_transformer else SpatialTransformer(
-                 ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-             ),
-             ResBlock(
-                 ch,
-                 time_embed_dim,
-                 dropout,
-                 dims=dims,
-                 use_checkpoint=use_checkpoint,
-                 use_scale_shift_norm=use_scale_shift_norm,
-             ),
-         )
-         self._feature_size += ch
-         # Define the UNet upsampling path
-         self.output_blocks = nn.ModuleList([])
-         # Loop over channel_mult in reverse
-         for level, mult in list(enumerate(channel_mult))[::-1]:
-             # During upsampling each stage loops num_res_blocks + 1 times (three here)
-             for i in range(num_res_blocks + 1):
-                 ich = input_block_chans.pop()
-                 # First add a ResBlock
-                 layers = [
-                     ResBlock(
-                         ch + ich,
-                         time_embed_dim,
-                         dropout,
-                         out_channels=model_channels * mult,
-                         dims=dims,
-                         use_checkpoint=use_checkpoint,
-                         use_scale_shift_norm=use_scale_shift_norm,
-                     )
-                 ]
-                 ch = model_channels * mult
-                 # Then apply the SpatialTransformer attention
-                 if ds in attention_resolutions:
-                     if num_head_channels == -1:
-                         dim_head = ch // num_heads
-                     else:
-                         num_heads = ch // num_head_channels
-                         dim_head = num_head_channels
-                     if legacy:
-                         #num_heads = 1
-                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-                     layers.append(
-                         AttentionBlock(
-                             ch,
-                             use_checkpoint=use_checkpoint,
-                             num_heads=num_heads_upsample,
-                             num_head_channels=dim_head,
-                             use_new_attention_order=use_new_attention_order,
-                         ) if not use_spatial_transformer else SpatialTransformer(
-                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-                         )
-                     )
-                 # If this is not the first entry of channel_mult
-                 # and
-                 # it is the last iteration of the num_res_blocks loop, upsample
-                 if level and i == num_res_blocks:
-                     out_ch = ch
-                     layers.append(
-                         ResBlock(
-                             ch,
-                             time_embed_dim,
-                             dropout,
-                             out_channels=out_ch,
-                             dims=dims,
-                             use_checkpoint=use_checkpoint,
-                             use_scale_shift_norm=use_scale_shift_norm,
-                             up=True,
-                         )
-                         if resblock_updown
-                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
-                     )
-                     ds //= 2
-                 self.output_blocks.append(TimestepEmbedSequential(*layers))
-                 self._feature_size += ch
-         # Final convolution in the output head
-         self.out = nn.Sequential(
-             normalization(ch),
-             nn.SiLU(),
-             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
-         )
-         if self.predict_codebook_ids:
-             self.id_predictor = nn.Sequential(
-                 normalization(ch),
-                 conv_nd(dims, model_channels, n_embed, 1),
-                 #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
-             )
-     def convert_to_fp16(self):
-         """
-         Convert the torso of the model to float16.
-         """
-         self.input_blocks.apply(convert_module_to_f16)
-         self.middle_block.apply(convert_module_to_f16)
-         self.output_blocks.apply(convert_module_to_f16)
-     def convert_to_fp32(self):
-         """
-         Convert the torso of the model to float32.
-         """
-         self.input_blocks.apply(convert_module_to_f32)
-         self.middle_block.apply(convert_module_to_f32)
-         self.output_blocks.apply(convert_module_to_f32)
-     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
-         """
-         Apply the model to an input batch.
-         :param x: an [N x C x ...] Tensor of inputs.
-         :param timesteps: a 1-D batch of timesteps.
-         :param context: conditioning plugged in via crossattn
-         :param y: an [N] Tensor of labels, if class-conditional.
-         :return: an [N x C x ...] Tensor of outputs.
-         """
-         assert (y is not None) == (
-             self.num_classes is not None
-         ), "must specify y if and only if the model is class-conditional"
-         hs = []
-         # Compute the embedding of the current sampling timestep t
-         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-         emb = self.time_embed(t_emb)
-         if self.num_classes is not None:
-             assert y.shape == (x.shape[0],)
-             emb = emb + self.label_emb(y)
-         # Loop over the input blocks: downsample and fuse the timestep and text features.
-         h = x.type(self.dtype)
-         for module in self.input_blocks:
-             h = module(h, emb, context)
-             hs.append(h)
-         # Feature extraction in the middle block
-         h = self.middle_block(h, emb, context)
-         # Feature extraction in the upsampling blocks
-         for module in self.output_blocks:
-             h = th.cat([h, hs.pop()], dim=1)
-             h = module(h, emb, context)
-         h = h.type(x.dtype)
-         # Output head
-         if self.predict_codebook_ids:
-             return self.id_predictor(h)
-         else:
-             return self.out(h)
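The nested loops that build `input_blocks` are easier to follow with the bookkeeping pulled out on its own. The sketch below replays only the counters (`ch`, `ds` and the spatial size) without creating any layers. The default values (model_channels 320, channel_mult [1, 2, 4, 4], two ResBlocks per stage, attention at downsample rates 4, 2 and 1, a 64x64 latent) are assumed to match the SD v1 configuration, and `unet_encoder_layout` is a hypothetical helper.

```python
def unet_encoder_layout(model_channels=320, channel_mult=(1, 2, 4, 4),
                        num_res_blocks=2, attention_resolutions=(4, 2, 1),
                        latent_size=64):
    """Return (kind, out_channels, spatial_size, has_attention) per input block."""
    blocks = [("conv", model_channels, latent_size, False)]
    ds, size = 1, latent_size
    for level, mult in enumerate(channel_mult):
        ch = mult * model_channels
        for _ in range(num_res_blocks):
            # A ResBlock, followed by attention if this downsample rate is listed.
            blocks.append(("res", ch, size, ds in attention_resolutions))
        if level != len(channel_mult) - 1:
            # Every stage except the last ends with a downsampling block.
            ds *= 2
            size //= 2
            blocks.append(("down", ch, size, False))
    return blocks


if __name__ == "__main__":
    for i, block in enumerate(unet_encoder_layout()):
        print(i, block)
```

Under these assumptions the encoder has 12 input blocks, and the deepest stage (8x8) has no attention because a downsample rate of 8 is not in `attention_resolutions`. The decoder mirrors this with `num_res_blocks + 1` blocks per stage, each consuming one skip connection from `hs`.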
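4. Decoding the latent into an image
After the steps above we have the result of multiple sampling steps, and we can decode it from latent space into an image.
Decoding is very simple: pass the result of the sampling loop to the decode_first_stage method to generate the image.
Inside decode_first_stage, the network calls the VAE to decode the 64x64x4 latent into a 512x512x3 image.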
- @torch.no_grad()
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
-     if predict_cids:
-         if z.dim() == 4:
-             z = torch.argmax(z.exp(), dim=1).long()
-         z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
-         z = rearrange(z, 'b h w c -> b c h w').contiguous()
-     z = 1. / self.scale_factor * z
-     # The input usually does not need to be split, so z goes straight to the decoder in the else branch below
-     if hasattr(self, "split_input_params"):
-         ......
-     else:
-         if isinstance(self.first_stage_model, VQModelInterface):
-             return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
-         else:
-             return self.first_stage_model.decode(z)
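Two small pieces of arithmetic surround the decoder: the latent shape that goes in, and the pixel range that comes out. The sketch below works through both with hypothetical helpers. `latent_shape` mirrors `shape = (4, H // 8, W // 8)` from the prediction code, and `to_uint8` mirrors the `* 127.5 + 127.5`, clip and uint8 cast applied to the decoder output, on one scalar.

```python
def latent_shape(height, width, channels=4, factor=8):
    """Latent (C, H, W) for an image of the given size."""
    return (channels, height // factor, width // factor)


def to_uint8(value):
    """Map a decoder output in [-1, 1] to a pixel value in [0, 255]."""
    scaled = value * 127.5 + 127.5
    return int(min(max(scaled, 0.0), 255.0))   # clip, then truncate like a uint8 cast


if __name__ == "__main__":
    print(latent_shape(512, 512))
    print([to_uint8(v) for v in (-1.0, 0.0, 1.0, 1.3)])
```

Values slightly outside [-1, 1] are common in practice, which is why the clip step matters before the cast.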
Image-to-image prediction code
The full prediction code is as follows:
- import os
- import random
- import cv2
- import einops
- import numpy as np
- import torch
- from PIL import Image
- from pytorch_lightning import seed_everything
- from ldm_hacked import *
- # ----------------------- #
- # Parameters
- # ----------------------- #
- # Path to the config
- config_path = "model_data/sd_v15.yaml"
- # Path to the model weights
- model_path = "model_data/v1-5-pruned-emaonly.safetensors"
- # fp16 speeds things up and saves GPU memory
- sd_fp16 = True
- vae_fp16 = True
- # ----------------------- #
- # Image generation parameters
- # ----------------------- #
- # Size of the generated image; for img2img the input is center cropped to input_shape
- input_shape = [512, 512]
- # Number of images generated at once
- num_samples = 1
- # Number of sampling steps
- ddim_steps = 20
- # Sampling seed; -1 means random.
- seed = 12345
- # eta
- eta = 0
- # Denoise strength, for img2img
- denoise_strength = 1.0
- # ----------------------- #
- # Prompt parameters
- # ----------------------- #
- # Prompt
- prompt = "a cute cat, with yellow leaf, trees"
- # Positive prompt
- a_prompt = "best quality, extremely detailed"
- # Negative prompt
- n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
- # Guidance scale between the positive and negative prompts
- scale = 9
- # Used for img2img; set to None if you do not want img2img.
- image_path = None
- # ----------------------- #
- # Save path
- # ----------------------- #
- save_path = "imgs/outputs_imgs"
- # ----------------------- #
- # Create the model
- # ----------------------- #
- model = create_model(config_path).cpu()
- model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
- model = model.cuda()
- ddim_sampler = DDIMSampler(model)
- if sd_fp16:
-     model = model.half()
- if image_path is not None:
-     img = Image.open(image_path)
-     img = crop_and_resize(img, input_shape[0], input_shape[1])
- with torch.no_grad():
-     if seed == -1:
-         seed = random.randint(0, 65535)
-     seed_everything(seed)
-
-     # ----------------------- #
-     # Encode the input image and add noise
-     # ----------------------- #
-     if image_path is not None:
-         img = HWC3(np.array(img, np.uint8))
-         # Map uint8 pixels from [0, 255] to [-1, 1]
-         img = torch.from_numpy(img.copy()).float().cuda() / 127.5 - 1.0
-         img = torch.stack([img for _ in range(num_samples)], dim=0)
-         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
-         if vae_fp16:
-             img = img.half()
-             model.first_stage_model = model.first_stage_model.half()
-         else:
-             model.first_stage_model = model.first_stage_model.float()
-         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
-         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
-         z = model.get_first_stage_encoding(model.encode_first_stage(img))
-         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
-         z_enc = z_enc.half() if sd_fp16 else z_enc.float()
-     # ----------------------- #
-     # Get the encoded prompt
-     # ----------------------- #
-     cond = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
-     un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
-     H, W = input_shape
-     shape = (4, H // 8, W // 8)
-     if image_path is not None:
-         samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
-     else:
-         # ----------------------- #
-         # Run sampling
-         # ----------------------- #
-         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
-                                                      shape, cond, verbose=False, eta=eta,
-                                                      unconditional_guidance_scale=scale,
-                                                      unconditional_conditioning=un_cond)
-     # ----------------------- #
-     # Decode
-     # ----------------------- #
-     x_samples = model.decode_first_stage(samples.half() if vae_fp16 else samples.float())
-     x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
- # ----------------------- #
- # Save the images
- # ----------------------- #
- if not os.path.exists(save_path):
-     os.makedirs(save_path)
- for index, image in enumerate(x_samples):
-     # The samples are RGB, while cv2.imwrite expects BGR
-     cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))