技术框架概述
本项目的技术方案主要借助 StyleGAN 的图像生成能力与最近OpenAI开源基于对比学习的大规模图文预训练模型 (CLIP) 的力量,以实现直观的基于文本的语义图像操作,既不受限于预定义的操作方向(发型等),也不需要额外的手动开发新控件
- 首先上传一张人像照,以及对应的描述文本;
- 利用 StyleGAN 获取人像照的潜在空间;StyleGAN具有强大的图像表达和生成能力,它使用风格(style)来影响人脸的姿态、身份特征等,用噪声 ( noise ) 来影响头发丝、皱纹、肤色等细节部分。
- 通过最小化 CLIP 空间中计算的损失(即生成图像的 CLIP 嵌入表示与输入文本之间的余弦距离),来优化 StyleGAN 中图像的隐藏编码(latent code),从而操纵生成逼真的图像效果;CLIP 模型是基于transformer结构,从网络上收集的 4 亿个图像文本对上进行预训练的,它能够将自然语言所表达更广泛的视觉概念应用在图像的潜在空间。
- 为避免隐藏编码与原始人脸像偏离过大,采取L2范数来进行约束,保留原始图像的特征
- 最终,优化后的StyleGAN 可以响应用户提供的文本提示,生成风格化的人脸像
技术细节
https://github.com/ndb796/StyleCLIP-Tutorial https://analyticsarora.com/how-to-use-styleclip-to-generate-images-from-text/
What is CLIP?
- CLIP 使用大型数据集(4 亿组“图像+文本”)联合训练图像编码器和文本编码器,基于对比学习的大规模图文预训练模型;
- CLIP 的工作原理是创建一个嵌入空间,在嵌入空间中,余弦相似度可用于比较图像和描述。如果图像和文本特征具有相似的语义,则它们之间的余弦相似度会很高。
What is StyleGAN?
- StyleGAN中的“Style”是指数据集中人脸的主要属性,比如表情、人脸朝向、发型、肤色。
- StyleGAN 最显着的区别在于其生成器函数的结构。StyleGAN 用风格(style)来影响人脸的姿态、身份特征等,用噪声 ( noise ) 来影响头发丝、皱纹、肤色等细节部分。
- StyleGAN 的网络结构包含两个部分,第一个是Mapping network,由8个全连接层组成,通过一系列仿射变换,由 z 得到 w。它要做的事就是对隐藏空间(latent space)进行解耦,得到的隐藏特征,即latent code(这个latent code输入到网络中去,能够复原指定图像)。它被用于控制生成图像的style,即风格;
- 第二个是Synthesis network,它的作用是生成图像,图像生成其实是学习从一个分布到目标分布的迁移过程
StyleCLIP Methods
1. Latent Optimization
- Google Colab 教程源代码
- 这是一种利用 CLIP 指导图像处理的简单方法,该方案利用基于 CLIP 的损失来修改输入潜在向量以响应用户提供的文本提示。
- 生成器StyleGAN的输出与 CLIP 嵌入空间中的目标文本进行比较。余弦距离用于通知更新步骤,用于更新初始权重 w 并继续循环。
- 给定源latent code
w_s∈W+
,以及自然语言中的指令或文本提示t
,我们迭代地最小化如下的三个损失总和:- CLIP Loss:StyleGAN 生成的图像与文本查询之间的距离(Dclip),并优化 StyleGAN 上的潜在变量 w ∈ W +;
- L2 Loss:正则化损失惩罚源向量 w_s 产生的大偏差;
- Identity Loss:确保生成的人脸的身份与原始人脸的身份相同。这是通过最小化 ArcFace 模型嵌入空间中图像之间的距离来完成的!
- 这种优化方法需要200 - 300 次迭代,花费大概几分钟
2. Latent Mapper
- 根据文本提示进行训练(10 小时)后,Mapper便可以通过一次前向传递来操作属性(推理快)
- 训练三个独立的mapping functions以生成残差(蓝色),这些残差添加到 w 以产生目标代码,通过预训练 StyleGAN(绿色)生成一个图像(右),最后由 CLIP and identity losses 进行评估
- Loss如下:
3. Global Directions
- 对一组图像上从映射器获得的操作方向之间的相似性分析表明,从一个图像到另一个图像的指令/权重更新没有太大差异。简而言之,无论输入如何,每个映射器都会进行类似的操作,这表明全局方向是可能的。
- 文本提示的
CLIP 嵌入空间中的向量 Δt
,图像的CLIP 嵌入空间中的向量 Δi
,将他们映射到 StyleGAN 的样式空间S 中,获得与输入无关的全局方向Δs(也叫广义操纵方向) - 找到全局方向后,我们可以将此全局方向应用于任何潜在向量s,使得 G(s + α∆s)
- alpha(Manipulation strength):正值(↑)沿目标方向 Δs 移动,即更突出文本描述;负值(↓)沿目标反方向 -Δs 移动,即更远离文本描述;
- Beta(Disentanglement threshold):值越小,大量的通道将被操纵,相关属性也会发生变化(如皱纹、肤色、眼镜)
注意事项
- 使用base64格式加密传输,把一张图片数据加密成一串字符,使用该字符串代替图像:
#base64编码
with open("1.png", 'rb') as f:
base64_data = base64.b64encode(f.read())
'''注意编码类型问题,byte->string '''
base64_data = base64_data.decode()
#进行base64解码工作 base64->数组
image_decode = base64.b64decode(image_b64)
#fromstring实现了字符串到Ascii码的转换
nparr = np.fromstring(image_decode, np.uint8)
#从nparr中读取数据,并把数据转换(解码)成图像格式
img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
cv2.imwrite('test.jpg',img_np)
- 在 Google Colab 中运行 FastAPI / Uvicorn?可以使用ngrok将端口导出为外部 url。基本上,ngrok 获取本地主机上可用/托管的内容,并使用临时公共 URL 将其公开给 Internet。
!pip install fastapi nest-asyncio pyngrok uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
@app.get('/')
async def root():
return {'hello': 'world'}
import nest_asyncio
from pyngrok import ngrok
import uvicorn
ngrok_tunnel = ngrok.connect(8000)
print('Public URL:', ngrok_tunnel.public_url)
nest_asyncio.apply()
uvicorn.run(app, port=8000)
- 图片流保存至本地,并上传至OSS获取公共可访问URL
class oss(object):
"""对象存储类,将模型传至阿里云端"""
def __init__(self, access_key_id, access_key_secret, endpoint, bucket_name):
self.auth = oss2.Auth(access_key_id, access_key_secret)
self.bucket = oss2.Bucket(self.auth, endpoint, bucket_name) # 连接OSS
def put_file(self, file_path, oss_path):
with open("{}".format(file_path), "rb") as f:
put_result = self.bucket.put_object(oss_path, f.read())
if put_result.status == 200:
# 若此时的status状态为200,则说明上传成功;
print("put success")
ret = self.bucket.sign_url('GET', oss_path, 60*60*24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
def get_file(self, file_path, oss_path):
# param1:oss上bucket中的文件名
# param2:保存在当地的文件路径+文件名
get_result = self.bucket.get_object_to_file(oss_path, file_path)
if get_result.status == 200:
print("get success")
else:
print("get failed")
# oss配置
CFG = read_config(config_file="config.ini")
oss_server = oss(
access_key_id=CFG.get("oss", "AccessKey"),
access_key_secret=CFG.get("oss", "AccessKeySecret"),
endpoint=CFG.get("oss", "EndPoint"),
bucket_name=CFG.get("oss", "Bucket"),
)
print("通过OSS上传model....")
root = "/Users/admin/Pictures/"
img_path = root + "click1.png"
oss_server.put_file(img_path, "hackson-nsx/test.png")