🚀 简介
StarVector 是一个用于生成 SVG(可缩放矢量图形)的多模态视觉语言模型。它可以执行 Image2SVG 和 Text2SVG 生成任务。我们将图像生成视为代码生成任务,利用多模态 VLM(视觉语言模型)的强大能力。
摘要:SVG 在现代图像渲染中至关重要,因为它具备可扩展性和多功能性。以往的 SVG 生成方法主要依赖于基于曲线的矢量化技术,但缺乏语义理解,容易产生伪影,并且难以处理
path
曲线以外的 SVG 基元。为了解决这些问题,我们推出了 StarVector,一个用于 SVG 生成的多模态大语言模型。它通过理解图像语义并使用 SVG 基元进行矢量化,从而生成紧凑且精确的输出。与传统方法不同,StarVector 直接在 SVG 代码空间中操作,利用视觉理解生成准确的 SVG 基元。为了训练 StarVector,我们创建了 SVG-Stack 数据集,包含 200 万个样本,能够在多种矢量化任务中泛化,并精确使用椭圆、多边形和文本等基元。同时,我们解决了 SVG 评估中的难题,发现像 MSE 这种基于像素的度量无法有效衡量矢量图形的独特特性。因此,我们推出了 SVG-Bench 基准,包括 10 个数据集和 3 项任务:图像到 SVG、文本到 SVG 生成,以及图表生成。在这些任务中,StarVector 实现了 SOTA(最先进)性能,生成更紧凑且语义丰富的 SVG。
多模态架构
StarVector 使用多模态架构处理图像和文本。在执行 Image2SVG(图像矢量化)任务时,图像会被映射为视觉标记(visual tokens),并生成对应的 SVG 代码。在 Text2SVG 任务中,模型仅接收文本指令(不提供图像),生成新的 SVG。该模型基于 StarCoder 构建,我们借助其编码能力将其扩展至 SVG 生成任务。
安装
克隆项目仓库并进入 star-vector 文件夹
git clone https://github.com/joanrod/star-vector.git
cd star-vector
安装依赖包
conda create -n starvector python=3.11.3 -y
conda activate starvector
pip install --upgrade pip # 启用 PEP 660 支持
pip install -e .
安装训练所需的额外依赖
pip install -e ".[train]"
更新到最新代码版本
git pull
pip install -e .
快速开始 - Image2SVG 生成
from PIL import Image
from starvector.model.starvector_arch import StarVectorForCausalLM
from starvector.data.util import process_and_rasterize_svg
model_name = "starvector/starvector-8b-im2svg"
starvector = StarVectorForCausalLM.from_pretrained(model_name)
starvector.cuda()
starvector.eval()
image_pil = Image.open('assets/examples/sample-0.png')
image = starvector.process_images([image_pil])[0].cuda()
batch = {"image": image}
raw_svg = starvector.generate_im2svg(batch, max_length=1000)[0]
svg, raster_image = process_and_rasterize_svg(raw_svg)
使用 HuggingFace AutoModel
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from starvector.data.util import process_and_rasterize_svg
import torch
model_name = "starvector/starvector-8b-im2svg"
starvector = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True)
processor = starvector.model.processor
tokenizer = starvector.model.svg_transformer.tokenizer
starvector.cuda()
starvector.eval()
image_pil = Image.open('assets/examples/sample-18.png')
image = processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
if not image.shape[0] == 1:
image = image.squeeze(0)
batch = {"image": image}
raw_svg = starvector.generate_im2svg(batch, max_length=4000)[0]
svg, raster_image = process_and_rasterize_svg(raw_svg)
模型
我们提供了 Hugging Face 🤗 模型检查点,用于 Image2SVG 矢量化任务,包括 💫 StarVector-8B 和 💫 StarVector-1B。以下为它们在 SVG-Bench 上的表现,使用 DinoScore 度量。
方法 | SVG-Stack | SVG-Fonts | SVG-Icons | SVG-Emoji | SVG-Diagrams |
---|---|---|---|---|---|
AutoTrace | 0.942 | 0.954 | 0.946 | 0.975 | 0.874 |
Potrace | 0.898 | 0.967 | 0.972 | 0.882 | 0.875 |
VTracer | 0.954 | 0.964 | 0.940 | 0.981 | 0.882 |
Im2Vec | 0.692 | 0.733 | 0.754 | 0.732 | - |
LIVE | 0.934 | 0.956 | 0.959 | 0.969 | 0.870 |
DiffVG | 0.810 | 0.821 | 0.952 | 0.814 | 0.822 |
GPT-4-V | 0.852 | 0.842 | 0.848 | 0.850 | - |
💫 StarVector-1B | 0.926 | 0.978 | 0.975 | 0.929 | 0.943 |
💫 StarVector-8B | 0.966 | 0.982 | 0.984 | 0.981 | 0.959 |
注:StarVector 不适用于自然图像或插图,因为它们未经过相关训练。它在矢量化图标、标志、技术图表、图形和图表方面表现出色。