技术框架概述

本项目的技术方案主要借助 StyleGAN 的图像生成能力与最近OpenAI开源基于对比学习的大规模图文预训练模型 (CLIP) 的力量,以实现直观的基于文本的语义图像操作,既不受限于预定义的操作方向(发型等),也不需要额外的手动开发新控件

  1. 首先上传一张人像照,以及对应的描述文本;
  2. 利用 StyleGAN 获取人像照的潜在空间;StyleGAN具有强大的图像表达和生成能力,它使用风格(style)来影响人脸的姿态、身份特征等,用噪声 ( noise ) 来影响头发丝、皱纹、肤色等细节部分。
  3. 通过最小化 CLIP 空间中计算的损失(即生成图像的 CLIP 嵌入表示与输入文本之间的余弦距离),来优化 StyleGAN 中图像的隐藏编码(latent code),从而操纵生成逼真的图像效果;CLIP 模型是基于transformer结构,从网络上收集的 4 亿个图像文本对上进行预训练的,它能够将自然语言所表达更广泛的视觉概念应用在图像的潜在空间。
  4. 为避免隐藏编码与原始人脸像偏离过大,采取L2范数来进行约束,保留原始图像的特征
  5. 最终,优化后的StyleGAN 可以响应用户提供的文本提示,生成风格化的人脸像

技术细节

https://github.com/ndb796/StyleCLIP-Tutorial https://analyticsarora.com/how-to-use-styleclip-to-generate-images-from-text/

What is CLIP?

  • CLIP 使用大型数据集(4 亿组“图像+文本”)联合训练图像编码器和文本编码器,基于对比学习的大规模图文预训练模型;
  • CLIP 的工作原理是创建一个嵌入空间,在嵌入空间中,余弦相似度可用于比较图像和描述。如果图像和文本特征具有相似的语义,则它们之间的余弦相似度会很高。

113474963-ed837300-94ad-11eb-98af-632f4ce4feb9.png

What is StyleGAN?

https://zhuanlan.zhihu.com/p/263554045

  • StyleGAN中的“Style”是指数据集中人脸的主要属性,比如表情、人脸朝向、发型、肤色。
  • StyleGAN 最显着的区别在于其生成器函数的结构。StyleGAN 用风格(style)来影响人脸的姿态、身份特征等,用噪声 ( noise ) 来影响头发丝、皱纹、肤色等细节部分。
  • StyleGAN 的网络结构包含两个部分,第一个是Mapping network,由8个全连接层组成,通过一系列仿射变换,由 z 得到 w。它要做的事就是对隐藏空间(latent space)进行解耦,得到的隐藏特征,即latent code(这个latent code输入到网络中去,能够复原指定图像)。它被用于控制生成图像的style,即风格;

styleCLIP论文解读 - 图2

  • 第二个是Synthesis network,它的作用是生成图像,图像生成其实是学习从一个分布到目标分布的迁移过程

styleCLIP论文解读 - 图3

StyleCLIP Methods

  • StyleCLIP 利用 CLIP 作为损失函数来提供生成对抗网络反馈。
  • StyleCLIP 提供了基于各种先前研究的三种方法。
  • 教程:讲义
  • 教程:视频:论文解释

1. Latent Optimization

  • Google Colab 教程源代码
  • 这是一种利用 CLIP 指导图像处理的简单方法,该方案利用基于 CLIP 的损失来修改输入潜在向量以响应用户提供的文本提示。
  • 生成器StyleGAN的输出与 CLIP 嵌入空间中的目标文本进行比较。余弦距离用于通知更新步骤,用于更新初始权重 w 并继续循环。

113475055-5e2a8f80-94ae-11eb-8298-2ee1d36e251e.png

  • 给定源latent code w_s∈W+,以及自然语言中的指令或文本提示 t,我们迭代地最小化如下的三个损失总和:
    • CLIP Loss:StyleGAN 生成的图像与文本查询之间的距离(Dclip),并优化 StyleGAN 上的潜在变量 w ∈ W +;
    • L2 Loss:正则化损失惩罚源向量 w_s 产生的大偏差;
    • Identity Loss:确保生成的人脸的身份与原始人脸的身份相同。这是通过最小化 ArcFace 模型嵌入空间中图像之间的距离来完成的!
  • 这种优化方法需要200 - 300 次迭代,花费大概几分钟

styleCLIP论文解读 - 图5

2. Latent Mapper

  • 根据文本提示进行训练(10 小时)后,Mapper便可以通过一次前向传递来操作属性(推理快)
  • 训练三个独立的mapping functions以生成残差(蓝色),这些残差添加到 w 以产生目标代码,通过预训练 StyleGAN(绿色)生成一个图像(右),最后由 CLIP and identity losses 进行评估

styleCLIP论文解读 - 图6

  • Loss如下:

styleCLIP论文解读 - 图7

styleCLIP论文解读 - 图8

styleCLIP论文解读 - 图9

3. Global Directions

  • 对一组图像上从映射器获得的操作方向之间的相似性分析表明,从一个图像到另一个图像的指令/权重更新没有太大差异。简而言之,无论输入如何,每个映射器都会进行类似的操作,这表明全局方向是可能的。
  • 文本提示的CLIP 嵌入空间中的向量 Δt ,图像的CLIP 嵌入空间中的向量 Δi ,将他们映射到 StyleGAN 的样式空间S 中,获得与输入无关的全局方向Δs(也叫广义操纵方向)
  • 找到全局方向后,我们可以将此全局方向应用于任何潜在向量s,使得 G(s + α∆s)

styleCLIP论文解读 - 图10

  • alpha(Manipulation strength):正值(↑)沿目标方向 Δs 移动,即更突出文本描述;负值(↓)沿目标反方向 -Δs 移动,即更远离文本描述;
  • Beta(Disentanglement threshold):值越小,大量的通道将被操纵,相关属性也会发生变化(如皱纹、肤色、眼镜)

styleCLIP论文解读 - 图11

注意事项

  1. #base64编码
  2. with open("1.png", 'rb') as f:
  3. base64_data = base64.b64encode(f.read())
  4. '''注意编码类型问题,byte->string '''
  5. base64_data = base64_data.decode()
  6. #进行base64解码工作 base64->数组
  7. image_decode = base64.b64decode(image_b64)
  8. #fromstring实现了字符串到Ascii码的转换
  9. nparr = np.fromstring(image_decode, np.uint8)
  10. #从nparr中读取数据,并把数据转换(解码)成图像格式
  11. img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
  12. cv2.imwrite('test.jpg',img_np)
  1. !pip install fastapi nest-asyncio pyngrok uvicorn
  2. from fastapi import FastAPI
  3. from fastapi.middleware.cors import CORSMiddleware
  4. app = FastAPI()
  5. app.add_middleware(
  6. CORSMiddleware,
  7. allow_origins=['*'],
  8. allow_credentials=True,
  9. allow_methods=['*'],
  10. allow_headers=['*'],
  11. )
  12. @app.get('/')
  13. async def root():
  14. return {'hello': 'world'}
  15. import nest_asyncio
  16. from pyngrok import ngrok
  17. import uvicorn
  18. ngrok_tunnel = ngrok.connect(8000)
  19. print('Public URL:', ngrok_tunnel.public_url)
  20. nest_asyncio.apply()
  21. uvicorn.run(app, port=8000)
  • 图片流保存至本地,并上传至OSS获取公共可访问URL
  1. class oss(object):
  2. """对象存储类,将模型传至阿里云端"""
  3. def __init__(self, access_key_id, access_key_secret, endpoint, bucket_name):
  4. self.auth = oss2.Auth(access_key_id, access_key_secret)
  5. self.bucket = oss2.Bucket(self.auth, endpoint, bucket_name) # 连接OSS
  6. def put_file(self, file_path, oss_path):
  7. with open("{}".format(file_path), "rb") as f:
  8. put_result = self.bucket.put_object(oss_path, f.read())
  9. if put_result.status == 200:
  10. # 若此时的status状态为200,则说明上传成功;
  11. print("put success")
  12. ret = self.bucket.sign_url('GET', oss_path, 60*60*24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
  13. def get_file(self, file_path, oss_path):
  14. # param1:oss上bucket中的文件名
  15. # param2:保存在当地的文件路径+文件名
  16. get_result = self.bucket.get_object_to_file(oss_path, file_path)
  17. if get_result.status == 200:
  18. print("get success")
  19. else:
  20. print("get failed")
  21. # oss配置
  22. CFG = read_config(config_file="config.ini")
  23. oss_server = oss(
  24. access_key_id=CFG.get("oss", "AccessKey"),
  25. access_key_secret=CFG.get("oss", "AccessKeySecret"),
  26. endpoint=CFG.get("oss", "EndPoint"),
  27. bucket_name=CFG.get("oss", "Bucket"),
  28. )
  29. print("通过OSS上传model....")
  30. root = "/Users/admin/Pictures/"
  31. img_path = root + "click1.png"
  32. oss_server.put_file(img_path, "hackson-nsx/test.png")