麻花痒 发表于 2025-4-19 07:42:25

JAVA实现从零实现扩散模子stable diffusion系列(一)

概要

在上一个文章咱们已经实现了一个基于llama3.1的大语言模子(LLM模子)。本日咱们继承来利用Omega-AI深度学习引擎从零实现一个stable diffusion模子,并实现文生图场景应用。
Omega-AI深度学习引擎

Omega-AI:基于java打造的深度学习框架,资助你快速搭建神经网络,实现训练或测试模子,支持多卡训练,框架现在支持BP神经网络、卷积神经网络、循环神经网络、vgg16、resnet、yolo、lstm、transformer、diffusion、gpt、llama、llava等模子的构建,现在引擎最新版本支持CUDA和CUDNN两种GPU加速方式,关于GPU加速的环境配置与jcuda版本jar包的对应依靠。
Omega-AI简介
JAVA实现从零大语言模子llama3
结果展示

基于stable diffusion模子实现文生图

文生图演示图

文本1图片1文本2图片2a highly detailed anime landscape,big tree on the water, epic sky,golden grass,detailed.https://i-blog.csdnimg.cn/direct/4349aa305af24cc99b124ab27208ce58.png#pic_center3d art of a golden tree in the river,with intricate flora and flowing water,detailed.https://i-blog.csdnimg.cn/direct/757f0b4e688649048601ebebf249c74a.png#pic_centera vibrant anime mountain landshttps://i-blog.csdnimg.cn/direct/f713e0c4875545d59753f519d0f67274.png#pic_centera dark warrior in epic armor stands among glowing crimson leaves in a mystical forest.https://i-blog.csdnimg.cn/direct/3a8c9954c84b42f1984b4122f8fc2747.png#pic_centercute fluffy panda, anime, ghibli style, pastel colors, soft shadows, detailed fur, vibrant eyes, fantasy setting, digital art, 3d, by kazuo ogahttps://i-blog.csdnimg.cn/direct/8fc80070edd749be891f61e8a818e3c9.png#pic_centera epic city,3d,detailed.https://i-blog.csdnimg.cn/direct/e710ad40c7f44d8896851cb5afef46f7.png#pic_center Quick Start

环境配置



[*]JDK1.8以上
[*]CUDA11.X/12.X
// 检查当前安装的CUDA版本
nvcc --version
安装CUDA与CUDNN
https://developer.nvidia.com/cuda-toolkit-archive
下载与配置Omega-AI深度学习引擎



[*]下载Omega-AI深度学习引擎
git clone https://github.com/dromara/Omega-AI.git
git clone https://gitee.com/dromara/omega-ai.git


[*]根据当前CUDA版本配置JCUDA依靠
打开Omega-AI pom.xml文件,根据当前CUDA版本修改依靠
提示:如您安装的cuda版本为12.x,请利用jcuda12.0.0版本
        <properties>
          <java.version>1.8</java.version>
          <!--当前cuda版本为11.8.x,对应jcuda版本为11.8.0-->
                <jcuda.version>11.8.0</jcuda.version>
                <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
                <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
                <resource.delimiter>@</resource.delimiter>
          <maven.compiler.source>${java.version}</maven.compiler.source>
          <maven.compiler.target>${java.version}</maven.compiler.target>
        </properties>
stable diffusion架构

传统的扩散模子有两大限定:1.输入图片尺寸与计算量巨细限定导致效率低下,2.只能输入随机噪声导致无法控制输出结果。而stable diffsuion引入了latent space概念,使得在其可以在较少的内存占用完成高清的图片天生。在解决只能输入随机噪声的题目上,stable diffusion利用了clip text的text encoder把文本信息作为条件输入到text conditioned lantent unet当中,并利用cross attention把text条件与图像融合。总结以上内容,stable diffsuion总共分为三大组件:VAE(变分自编码器)负责把图片编码成相对较小的latent space数据和解码latent space还原成正常巨细的图片。CLIP TEXT当中的text encoder负责把文本内容编码成77*512的 token embeddings。lantent unet负责结合条件天生latent space,与传统的diffusion模子的unet相比,stable diffusion的unet利用的是cross attention机制,目标就是为了融合条件信息。以下是stable diffusion流程图:
https://i-blog.csdnimg.cn/direct/3868c1ed95bd43fbb9df6b3ccdae97ec.png#pic_center
1 STEP 训练VQ-VAE(变分自编码器)

1.1下载与预处理训练数据



[*]本次使命将利用开源动画风格的图文对数据集【rapidata】点击下载
[*]处理图片巨细统一为256 * 256或者512 * 512
[*]制作元数据并存储为json文件,数据格式为:[{“id”: “0”, “en”: “cinematic bokeh: ironcat, sharp focus on the cat’s eyes, blurred background, dramatic chiaroscuro lighting, deep shadows, high contrast, rich textures, high resolution”}]
提示:可下载已经处理好的数据集
点击下载已处理后的数据集
[*]利用数据加载器读取训练数据,代码如下:
        int batchSize = 2;
           int imageSize = 256;
           float[] mean = new float[] {0.5f, 0.5f, 0.5f};
    float[] std = new float[] {0.5f, 0.5f, 0.5f};
           String imgDirPath = "I:\\dataset\\sd-anime\\anime_op\\256\\";
           DiffusionImageDataLoader dataLoader = new DiffusionImageDataLoader(imgDirPath, imageSize, imageSize, batchSize, true, false, mean, std);
1.2 创建VQ-VAE模子

                /**
                  * LossType lossType: 损失函数
                  * UpdaterType updater, 参数更新方法
                  * int z_dims, 输出latent space维度
                  * int latendDim, 输出latent space通道数
                  * latent space形状为
                  * int num_res_blocks, 每个采样层所包含的residual层数
                  * int imageSize, 输入图片大小
                  * int[] ch_mult, unet上下采样层通道倍数
                  * int ch, unet上下采样层通道基数
                  */
                   VQVAE2 network = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imageSize, ch_mult, ch, num_res_blocks);
1.3 创建LPIPS模子

为了增强vae的还原图片的清晰度,在训练vae模子的过程中添加lpips(感知丧失),该模子用于度量两张图片之间的差别。
/**
   * LossType lossType: 损失函数(均方差损失函数)
   * UpdaterType updater, 参数更新方法
   * int imageSize, 输入图片大小
   */
LPIPS lpips = new LPIPS(LossType.MSE, UpdaterType.adamw, imageSize);
完整训练代码如下:

public static void anime_vqvae2_lpips_gandisc_32_nogan() {
        try {
                nt batchSize = 16;
                int imageSize = 256;
                int z_dims = 32;
                int latendDim = 4;
                int num_vq_embeddings = 512;
                int num_res_blocks = 1;
                int[] ch_mult = new int[] {1, 2, 2, 4};
                int ch = 32;
                float[] mean = new float[] {0.5f, 0.5f, 0.5f};
                float[] std = new float[] {0.5f, 0.5f, 0.5f};
               
                String imgDirPath = "I:\\dataset\\sd-anime\\anime_op\\256\\";
                DiffusionImageDataLoader dataLoader = new DiffusionImageDataLoader(imgDirPath, imageSize, imageSize, batchSize, true, false, mean, std);
                /**
               * LossType lossType: 损失函数
               * UpdaterType updater, 参数更新方法
               * int z_dims, 输出latent space维度
               * int latendDim, 输出latent space通道数
               * latent space形状为
               * int num_res_blocks, 每个采样层所包含的residual层数
               * int imageSize, 输入图片大小
               * int[] ch_mult, unet上下采样层通道倍数
               * int ch, unet上下采样层通道基数
               */
                VQVAE2 network = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imageSize, ch_mult, ch, num_res_blocks);
                network.CUDNN = true;
                network.learnRate = 0.001f;
               
                LPIPS lpips = new LPIPS(LossType.MSE, UpdaterType.adamw, imageSize);
                //加载权重
                String lpipsWeight = "H:\\model\\lpips.json";
                LPIPSTest.loadLPIPSWeight(LagJsonReader.readJsonFileSmallWeight(lpipsWeight), lpips, false);
                lpips.CUDNN = true;
               
                MBSGDOptimizer optimizer = new MBSGDOptimizer(network, 200, 0.00001f, batchSize, LearnRateUpdate.CONSTANT, false);
                optimizer.trainVQVAE2_lpips_nogan(dataLoader, lpips);

                String save_model_path = "/omega/models/anime_vqvae2_256.model";
                ModelUtils.saveModel(network, save_model_path);

        } catch (Exception e) {
                // TODO: handle exception
                e.printStackTrace();
        }
}
VQ-VAE演示图

原图VQ-VAE原图VQ-VAEhttps://i-blog.csdnimg.cn/direct/1ec7fdb7d8b64de2b6fd7f29d8e4c589.png#pic_centerhttps://i-blog.csdnimg.cn/direct/5b5ebd062ef9481eb0a85cad605a3b36.png#pic_centerhttps://i-blog.csdnimg.cn/direct/454b213fb23e4f7fbfd9adebbd96b0a7.png#pic_centerhttps://i-blog.csdnimg.cn/direct/9fe0adf2219a444b9577290b074900be.png#pic_centerhttps://i-blog.csdnimg.cn/direct/819716755b0249dbbfaf047a215bd619.png#pic_centerhttps://i-blog.csdnimg.cn/direct/d8de20cea1d445beb432566aebebdbe1.png#pic_centerhttps://i-blog.csdnimg.cn/direct/7b939475880a40739742aa85e3a54d69.png#pic_centerhttps://i-blog.csdnimg.cn/direct/45e822f2476b403baa09a67328274445.png#pic_center 2 STEP 训练diffusion unet cond(条件扩散模子)

2.1 创建与加载Clip Text Encoder

本次使命利用clip-vit-base-patch32的encoder部分作为text encoder。
        /**
       * clipText shape
       */
        int time = maxContextLen;//文本最大token长度
        int maxPositionEmbeddingsSize = 77;//文本最大token长度
        int vocabSize = 49408;//tokenizer词表长度
        int headNum = 8;//多头注意力头数
        int n_layers = 12;//CLIPEncoderLayer编码层层数
        int textEmbedDim = 512;//文本嵌入输出维度
        ClipTextModel clip = new ClipTextModel(LossType.MSE, UpdaterType.adamw, headNum, time, vocabSize, textEmbedDim, maxPositionEmbeddingsSize, n_layers);
        clip.CUDNN = true;
        clip.time = time;
        clip.RUN_MODEL = RunModel.EVAL; //设置推理模式
       
        String clipWeight = "H:\\model\\clip-vit-base-patch32.json";
        ClipModelUtils.loadWeight(LagJsonReader.readJsonFileSmallWeight(clipWeight), clip, true);
2.2 创建与加载VQ-VAE模子

        /**
       * LossType lossType: 损失函数
       * UpdaterType updater, 参数更新方法
       * int z_dims, 输出latent space维度
       * int latendDim, 输出latent space通道数
       * latent space形状为
       * int num_res_blocks, 每个采样层所包含的residual层数
       * int imageSize, 输入图片大小
       * int[] ch_mult, unet上下采样层通道倍数
       * int ch, unet上下采样层通道基数
       */
        VQVAE2 vae = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imageSize, ch_mult, ch, num_res_blocks);
        vae.RUN_MODEL = RunModel.EVAL;//设置推理模式
        //加载已训练好的vae模型权重
        String vaeModel = "anime_vqvae2_256.model";
        ModelUtils.loadModel(vae, vaeModel);
2.3 创建Diffusion UNet Cond模子(条件扩散模子)

        int unetHeadNum = 8;//unet多头注意力头数
        int[] downChannels = new int[] {128, 256, 512, 768};//下采样通道数
        int numLayer = 2;//每层采样层的ResidualBlock个数
        int timeSteps = 1000;//扩散时间步数
        int tEmbDim = 512;//时序嵌入维度
        int latentSize = 32;//latent space维度
        int groupNum = 32;//group norm分组数
               
        DiffusionUNetCond2 unet = new DiffusionUNetCond2(LossType.MSE, UpdaterType.adamw, latendDim, latentSize, latentSize, downChannels, unetHeadNum, numLayer, timeSteps, tEmbDim, maxContextLen, textEmbedDim, groupNum);
        unet.CUDNN = true;
        unet.learnRate = 0.0001f;
完整训练代码如下:

        public static void tiny_sd_train_anime_32() throws Exception {
                String labelPath = "I:\\dataset\\sd-anime\\anime_op\\data.json";
                String imgDirPath = "I:\\dataset\\sd-anime\\anime_op\\256\\";
                boolean horizontalFilp = true;
                int imgSize = 256;
                int maxContextLen = 77;
                int batchSize = 8;
                float[] mean = new float[] {0.5f, 0.5f,0.5f};
                float[] std = new float[] {0.5f, 0.5f,0.5f};
                //加载bpe tokenizer分词器
                String vocabPath = "H:\\model\\bpe_tokenizer\\vocab.json";
                String mergesPath = "H:\\model\\bpe_tokenizer\\merges.txt";
                BPETokenizerEN bpe = new BPETokenizerEN(vocabPath, mergesPath, 49406, 49407);
               
                SDImageDataLoaderEN dataLoader = new SDImageDataLoaderEN(bpe, labelPath, imgDirPath, imgSize, imgSize, maxContextLen, batchSize, horizontalFilp, mean, std);
               
                /**
               * clipText shape
               */
                int time = maxContextLen;//文本最大token长度
                int maxPositionEmbeddingsSize = 77;//文本最大token长度
                int vocabSize = 49408;//tokenizer词表长度
                int headNum = 8;//多头注意力头数
                int n_layers = 12;//CLIPEncoderLayer编码层层数
                int textEmbedDim = 512;//文本嵌入输出维度
                ClipTextModel clip = new ClipTextModel(LossType.MSE, UpdaterType.adamw, headNum, time, vocabSize, textEmbedDim, maxPositionEmbeddingsSize, n_layers);
                clip.CUDNN = true;
                clip.time = time;
                clip.RUN_MODEL = RunModel.EVAL;
               
                String clipWeight = "H:\\model\\clip-vit-base-patch32.json";
                ClipModelUtils.loadWeight(LagJsonReader.readJsonFileSmallWeight(clipWeight), clip, true);
               
                int z_dims = 128;
                int latendDim = 4;
                int num_vq_embeddings = 512;
                int num_res_blocks = 2;
                int[] ch_mult = new int[] {1, 2, 2, 4};
                int ch = 128;
               
                VQVAE2 vae = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imgSize, ch_mult, ch, num_res_blocks);
                vae.CUDNN = true;
                vae.learnRate = 0.001f;
                vae.RUN_MODEL = RunModel.EVAL;
                String vaeModel = "anime_vqvae2_256.model";
                ModelUtils.loadModel(vae, vaeModel);
               
                int unetHeadNum = 8;//unet多头注意力头数
                int[] downChannels = new int[] {128, 256, 512, 768};//下采样通道数
                int numLayer = 2;//每层采样层的ResidualBlock个数
                int timeSteps = 1000;//扩散时间步数
                int tEmbDim = 512;//时序嵌入维度
                int latentSize = 32;//latent space维度
                int groupNum = 32;//group norm分组数
               
                DiffusionUNetCond2 unet = new DiffusionUNetCond2(LossType.MSE, UpdaterType.adamw, latendDim, latentSize, latentSize, downChannels, unetHeadNum, numLayer, timeSteps, tEmbDim, maxContextLen, textEmbedDim, groupNum);
                unet.CUDNN = true;
                unet.learnRate = 0.0001f;
               
                MBSGDOptimizer optimizer = new MBSGDOptimizer(unet, 500, 0.00001f, batchSize, LearnRateUpdate.CONSTANT, false);
                optimizer.trainTinySD_Anime(dataLoader, vae, clip);
                //保存训练完成的权重文件
                String save_model_path = "/omega/models/sd_anime256.model";
                ModelUtils.saveModel(unet, save_model_path);
        }
推理代码如下:

        public static void tiny_sd_predict_anime_32() throws Exception {
               
                int imgSize = 256;
                int maxContextLen = 77;
                String vocabPath = "H:\\model\\bpe_tokenizer\\vocab.json";
                String mergesPath = "H:\\model\\bpe_tokenizer\\merges.txt";
                BPETokenizerEN tokenizer = new BPETokenizerEN(vocabPath, mergesPath, 49406, 49407);
               
                int time = maxContextLen;
                int maxPositionEmbeddingsSize = 77;
                int vocabSize = 49408;
                int headNum = 8;
                int n_layers = 12;
                int textEmbedDim = 512;
               
                ClipTextModel clip = new ClipTextModel(LossType.MSE, UpdaterType.adamw, headNum, time, vocabSize, textEmbedDim, maxPositionEmbeddingsSize, n_layers);
                clip.CUDNN = true;
                clip.time = time;
                clip.RUN_MODEL = RunModel.EVAL;
               
                String clipWeight = "H:\\model\\clip-vit-base-patch32.json";
                ClipModelUtils.loadWeight(LagJsonReader.readJsonFileSmallWeight(clipWeight), clip, true);
               
                int z_dims = 128;
                int latendDim = 4;
                int num_vq_embeddings = 512;
                int num_res_blocks = 2;
                int[] ch_mult = new int[] {1, 2, 2, 4};
                int ch = 128;
               
                VQVAE2 vae = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imgSize, ch_mult, ch, num_res_blocks);
                vae.CUDNN = true;
                vae.learnRate = 0.001f;
                vae.RUN_MODEL = RunModel.EVAL;
               
                String vaeModel = "H:\\model\\anime_vqvae2_256.model";
                ModelUtils.loadModel(vae, vaeModel);
               
                int unetHeadNum = 8;
                int[] downChannels = new int[] {64, 128, 256, 512};
                int numLayer = 2;
                int timeSteps = 1000;
                int tEmbDim = 512;
                int latendSize = 32;
                int groupNum = 32;

                int batchSize = 1;
               
                DiffusionUNetCond2 unet = new DiffusionUNetCond2(LossType.MSE, UpdaterType.adamw, latendDim, latendSize, latendSize, downChannels, unetHeadNum, numLayer, timeSteps, tEmbDim, maxContextLen, textEmbedDim, groupNum);
                unet.CUDNN = true;
                unet.learnRate = 0.0001f;
                unet.RUN_MODEL = RunModel.TEST;
               
                String model_path = "H:\\model\\sd_anime256.model";
                ModelUtils.loadModel(unet, model_path);
               
                Scanner scanner = new Scanner(System.in);
               
//                Tensor latent = new Tensor(batchSize, latendDim, latendSize, latendSize, true);
                Tensor t = new Tensor(batchSize, 1, 1, 1, true);
                Tensor label = new Tensor(batchSize * unet.maxContextLen, 1, 1, 1, true);
               
                Tensor input = new Tensor(batchSize, 3, imgSize, imgSize, true);
                Tensor latent = vae.encode(input);
               
                while (true) {
                        System.out.println("请输入英文:");
                        String input_txt = scanner.nextLine();
                        if(input_txt.equals("exit")){
                                break;
                        }
                        input_txt = input_txt.toLowerCase();
                       
                        loadLabels(input_txt, label, tokenizer, unet.maxContextLen);

                        Tensor condInput = clip.forward(label);

                        String[] labels = new String[] {input_txt, input_txt};
                        MBSGDOptimizer.testSD(input_txt, latent, t, condInput, unet, vae, labels, "H:\\vae_dataset\\anime_test256\\");
                }
                scanner.close();
        }
以上代码所需的文件请移步到百度云盘下载 点击下载

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: JAVA实现从零实现扩散模子stable diffusion系列(一)