百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

Janus:DeepSeek 在多模态理解与生成的新突破

ztj100 2025-02-11 14:27 16 浏览 0 评论

DeepSeek 的爆发使其在多方面的研究都得到了更多关注;同时,人工智能领域的多模态技术正逐渐成为研究的热点,多模态理解与生成旨在让机器能够同时处理和理解多种类型的数据,如文本、图像和视频等,并生成有意义的输出。

DeepSeek 的 Janus 项目在这一领域崭露头角,其致力于实现统一的多模态理解和生成,推动人工智能技术的进一步发展。



简介

Janus 是 DeepSeek 团队推出的、旨在统一多模态理解和生成的一系列模型,项目地址为
https://github.com/deepseek-ai/Janus
。Janus 系列目前包括 3 个模型,分别是:Janus、JanusFlow 和 Janus-Pro。

Janus 模型 是一个新颖的自回归框架,其将多模态理解和生成统一起来,基于论文 《Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation》。

Janus 通过将视觉编码分解为独立的路径,同时仍然使用单一的统一 Transformer 架构进行处理,解决了先前方法的一些局限性。这种分解不仅缓解了视觉编码器在理解和生成任务之间的冲突,还增强了该框架的灵活性。Janus 超越了先前的统一模型,其与特定任务模型相比起来性能相当甚至更优。Janus 的简单性、高度灵活性和有效性使其成为下一代统一多模态模型的有力候选者。

JanusFlow 模型 引入了一种极简架构,该架构将自回归语言模型与修正流(一种生成式建模中的前沿方法)相结合,基于论文《JanusFlow: Harmonizing Autoregression and Rectified Flow for Unified Multimodal Understanding and Generation》

JanusFlow 模型表明修正流可以在大语言模型框架内直接进行训练,无需进行复杂的架构修改。大量实验表明,JanusFlow 在各自领域的性能与专业模型相当或更优,同时在标准基准测试中显著超越了现有的统一方法。这项工作朝着更高效、更通用的视觉 - 语言模型迈出了一步。

Janus-Pro 模型是先前作品 Janus 的进阶版本,基于论文《Janus-Pro: Unified Multimodal Understanding and Generation with Data and Model Scaling》。

Janus-Pro 相比 Janus 融入了以下几点:(1)优化的训练策略;(2)扩充的训练数据;(3)扩大模型规模。通过这些改进,Janus-Pro 在多模态理解和文本到图像的指令遵循能力方面都取得了显著进步,同时还提高了文本到图像生成的稳定性。



使用

Janus 模型可以直接从 Huggingface 下载使用:

  • Janus-1.3B:https://huggingface.co/deepseek-ai/Janus-1.3B
  • JanusFlow-1.3B:https://huggingface.co/deepseek-ai/JanusFlow-1.3B
  • Janus-Pro-1B:https://huggingface.co/deepseek-ai/Janus-Pro-1B
  • Janus-Pro-7B:https://huggingface.co/deepseek-ai/Janus-Pro-7B

此外,开发者们也可以自行安装部署,首先拉取仓库代码:

git clone https://github.com/deepseek-ai/Janus.git

Janus 模型要求 Python >= 3.8,首先安装依赖:

pip install -e .

我们可以简单地写一个从文本生成图片的样例。首先引入依赖:

import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor

然后指定模型:

model_path = "deepseek-ai/Janus-1.3B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

提供一个 prompt,描述指定要求生成的图片:

conversation = [
    {
        "role": "User",
        "content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
    },
    {"role": "Assistant", "content": ""},
]
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag

然后注册一个 torch 的生成方法,自行构建模型中的各层:

@torch.inference_mode()
def generate(mmgpt: MultiModalityCausalLM, vl_chat_processor: VLChatProcessor, prompt: str, temperature: float = 1, parallel_size: int = 16, cfg_weight: float = 5, image_token_num_per_image: int = 576, img_size: int = 384, patch_size: int = 16):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size*2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
        hidden_states = outputs.last_hidden_state
        
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        
        logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)


    dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs('generated_samples', exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)

完成后,调用 generate 进行生成:

generate(vl_gpt, vl_chat_processor, prompt)


对于 JanusFlow,则可以指定对应的模型,其他使用方法基本保持不变:

model_path = "deepseek-ai/JanusFlow-1.3B"

Janus-Pro 也是类似的,一般来说 Janus-Pro 可能会提供更好的模型效果:

model_path = "deepseek-ai/Janus-Pro-7B"


更多的模型样例,可以参考 Huggingface 的模型下载页面,DeepSeek 提供了不少的在线 demo 可供参考使用。


总结

Janus 项目在多模态理解和生成领域具有重要的应用价值,是 DeepSeek 的武器库中重要的一员。其可以应用于多种场景,如根据文字提供设计稿、理解图片生成文字描述、完成多模态的理解任务等。通过统一的框架,Janus 为多模态任务的处理提供了更加便捷和高效的方式,能在未来的人工智能应用中发挥更大的作用,推动多模态技术的进一步发展。

相关推荐

告别手动操作:一键多工作表合并的实用方法

通常情况下,我们需要将同一工作簿内不同工作表中的数据进行合并处理。如何快速有效地完成这些数据的整合呢?这主要取决于需要合并的源数据的结构。...

【MySQL技术专题】「优化技术系列」常用SQL的优化方案和技术思路

概述前面我们介绍了MySQL中怎么样通过索引来优化查询。日常开发中,除了使用查询外,我们还会使用一些其他的常用SQL,比如INSERT、GROUPBY等。对于这些SQL语句,我们该怎么样进行优化呢...

9.7寸视网膜屏原道M9i双系统安装教程

泡泡网平板电脑频道4月17日原道M9i采用Win8安卓双系统,对于喜欢折腾的朋友来说,刷机成了一件难事,那么原道M9i如何刷机呢?下面通过详细地图文,介绍原道M9i的刷机操作过程,在刷机的过程中,要...

如何做好分布式任务调度——Scheduler 的一些探索

作者:张宇轩,章逸,曾丹初识Scheduler找准定位:分布式任务调度平台...

mysqldump备份操作大全及相关参数详解

mysqldump简介mysqldump是用于转储MySQL数据库的实用程序,通常我们用来迁移和备份数据库;它自带的功能参数非常多,文中列举出几乎所有常用的导出操作方法,在文章末尾将所有的参数详细说明...

大厂面试冲刺,Java“实战”问题三连,你碰到了哪个?

推荐学习...

亿级分库分表,如何丝滑扩容、如何双写灰度

以下是基于亿级分库分表丝滑扩容与双写灰度设计方案,结合架构图与核心流程说明:一、总体设计目标...

MYSQL表设计规范(mysql表设计原则)

日常工作总结,不是通用规范一、表设计库名、表名、字段名必须使用小写字母,“_”分割。...

怎么解决MySQL中的Duplicate entry错误?

在使用MySQL数据库时,我们经常会遇到Duplicateentry错误,这是由于插入或更新数据时出现了重复的唯一键值。这种错误可能会导致数据的不一致性和完整性问题。为了解决这个问题,我们可以采取以...

高并发下如何防重?(高并发如何防止重复)

前言最近测试给我提了一个bug,说我之前提供的一个批量复制商品的接口,产生了重复的商品数据。...

性能压测数据告诉你MySQL和MariaDB该怎么选

1.压测环境为了尽可能的客观公正,本次选择同一物理机上的两台虚拟机,一台用作数据库服务器,一台用作运行压测工具mysqlslap,操作系统均为UbuntuServer22.04LTS。...

屠龙之技 --sql注入 不值得浪费超过十天 实战中sqlmap--lv 3通杀全国

MySQL小结发表于2020-09-21分类于知识整理阅读次数:本文字数:67k阅读时长≈1:01...

破防了,谁懂啊家人们:记一次 mysql 问题排查

作者:温粥一、前言谁懂啊家人们,作为一名java开发,原来以为mysql这东西,写写CRUD,不是有手就行吗;你说DDL啊,不就是设计个表结构,搞几个索引吗。...

SpringBoot系列Mybatis之批量插入的几种姿势

...

MySQL 之 Performance Schema(mysql安装及配置超详细教程)

MySQL之PerformanceSchema介绍PerformanceSchema提供了在数据库运行时实时检查MySQL服务器的内部执行情况的方法,通过监视MySQL服务器的事件来实现监视内...

取消回复欢迎 发表评论: