🚀 简介

StarVector 是一个用于生成 SVG(可缩放矢量图形)的多模态视觉语言模型。它可以执行 Image2SVG 和 Text2SVG 生成任务。我们将图像生成视为代码生成任务,利用多模态 VLM(视觉语言模型)的强大能力。

摘要:SVG 在现代图像渲染中至关重要,因为它具备可扩展性和多功能性。以往的 SVG 生成方法主要依赖于基于曲线的矢量化技术,但缺乏语义理解,容易产生伪影,并且难以处理 path 曲线以外的 SVG 基元。为了解决这些问题,我们推出了 StarVector,一个用于 SVG 生成的多模态大语言模型。它通过理解图像语义并使用 SVG 基元进行矢量化,从而生成紧凑且精确的输出。与传统方法不同,StarVector 直接在 SVG 代码空间中操作,利用视觉理解生成准确的 SVG 基元。为了训练 StarVector,我们创建了 SVG-Stack 数据集,包含 200 万个样本,能够在多种矢量化任务中泛化,并精确使用椭圆、多边形和文本等基元。同时,我们解决了 SVG 评估中的难题,发现像 MSE 这种基于像素的度量无法有效衡量矢量图形的独特特性。因此,我们推出了 SVG-Bench 基准,包括 10 个数据集和 3 项任务:图像到 SVG、文本到 SVG 生成,以及图表生成。在这些任务中,StarVector 实现了 SOTA(最先进)性能,生成更紧凑且语义丰富的 SVG。

多模态架构

StarVector 使用多模态架构处理图像和文本。在执行 Image2SVG(图像矢量化)任务时,图像会被映射为视觉标记(visual tokens),并生成对应的 SVG 代码。在 Text2SVG 任务中,模型仅接收文本指令(不提供图像),生成新的 SVG。该模型基于 StarCoder 构建,我们借助其编码能力将其扩展至 SVG 生成任务。

简介 - 图1

安装

  1. 克隆项目仓库并进入 star-vector 文件夹

    1. git clone https://github.com/joanrod/star-vector.git
    2. cd star-vector
  2. 安装依赖包

    1. conda create -n starvector python=3.11.3 -y
    2. conda activate starvector
    3. pip install --upgrade pip # 启用 PEP 660 支持
    4. pip install -e .
  3. 安装训练所需的额外依赖

    1. pip install -e ".[train]"

更新到最新代码版本

  1. git pull
  2. pip install -e .

快速开始 - Image2SVG 生成

  1. from PIL import Image
  2. from starvector.model.starvector_arch import StarVectorForCausalLM
  3. from starvector.data.util import process_and_rasterize_svg
  4. model_name = "starvector/starvector-8b-im2svg"
  5. starvector = StarVectorForCausalLM.from_pretrained(model_name)
  6. starvector.cuda()
  7. starvector.eval()
  8. image_pil = Image.open('assets/examples/sample-0.png')
  9. image = starvector.process_images([image_pil])[0].cuda()
  10. batch = {"image": image}
  11. raw_svg = starvector.generate_im2svg(batch, max_length=1000)[0]
  12. svg, raster_image = process_and_rasterize_svg(raw_svg)

使用 HuggingFace AutoModel

  1. from PIL import Image
  2. from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
  3. from starvector.data.util import process_and_rasterize_svg
  4. import torch
  5. model_name = "starvector/starvector-8b-im2svg"
  6. starvector = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True)
  7. processor = starvector.model.processor
  8. tokenizer = starvector.model.svg_transformer.tokenizer
  9. starvector.cuda()
  10. starvector.eval()
  11. image_pil = Image.open('assets/examples/sample-18.png')
  12. image = processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
  13. if not image.shape[0] == 1:
  14. image = image.squeeze(0)
  15. batch = {"image": image}
  16. raw_svg = starvector.generate_im2svg(batch, max_length=4000)[0]
  17. svg, raster_image = process_and_rasterize_svg(raw_svg)

模型

我们提供了 Hugging Face 🤗 模型检查点,用于 Image2SVG 矢量化任务,包括 💫 StarVector-8B 和 💫 StarVector-1B。以下为它们在 SVG-Bench 上的表现,使用 DinoScore 度量。

方法 SVG-Stack SVG-Fonts SVG-Icons SVG-Emoji SVG-Diagrams
AutoTrace 0.942 0.954 0.946 0.975 0.874
Potrace 0.898 0.967 0.972 0.882 0.875
VTracer 0.954 0.964 0.940 0.981 0.882
Im2Vec 0.692 0.733 0.754 0.732 -
LIVE 0.934 0.956 0.959 0.969 0.870
DiffVG 0.810 0.821 0.952 0.814 0.822
GPT-4-V 0.852 0.842 0.848 0.850 -
💫 StarVector-1B 0.926 0.978 0.975 0.929 0.943
💫 StarVector-8B 0.966 0.982 0.984 0.981 0.959

:StarVector 不适用于自然图像或插图,因为它们未经过相关训练。它在矢量化图标、标志、技术图表、图形和图表方面表现出色。