TTS语音克隆开源模型-Fish、F5、GPT、CosyVoice、MaskGCT

GPT-SoVITS

https://github.com/RVC-Boss/GPT-SoVITS

GPT-SoVITS项目是TTS克隆领域内效果常年霸榜的模型之一,具有以下功能:

  • 零样本文本到语音(TTS): 输入 5 秒的声音样本,即刻体验文本到语音转换。
  • 少样本 TTS:仅需 1 分钟的训练数据即可微调模型,提升声音相似度和真实感。
  • 跨语言支持:支持与训练数据集不同语言的推理,目前支持英语、日语、韩语、粤语和中文。
  • WebUI 工具:集成工具包括声音伴奏分离、自动训练集分割、中文自动语音识别(ASR)和文本标注,协助初学者创建训练数据集和 GPT/SoVITS 模型。

部署和使用教程:【34.8k点赞量!】TTS领域内明星模型GPT-SoVITS实操教程来啦;2秒语音就能克隆,效果过于惊艳,请谨慎使用!

  1. https://github.com/RVC-Boss/GPT-SoVITS
  2. https://huggingface.co/lj1995/GPT-SoVITS

MaskGCT

https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/README.md

文本转语音TTS系统通常被分为自回归和非自回归系统。自回归系统隐式地建模持续时间,但在鲁棒性方面存在一定的缺陷,并且缺乏持续时间的可控性。非自回归系统在训练期间需要显式的文本和语音之间的对齐信息,并预测语言单位(例如音素)的持续时间,这可能会影响其自然性。在10月24日,趣丸科技&香港中文大学提出一种完全非自回归的TTS模型——掩码生成编解码器变换器(MaskGCT),它消除了对文本和语音监督之间显式对齐信息的需求,以及对音素级别持续时间预测的需求。

MaskGCT模型框架如下:

  • 语音语义表示编解码器:这部分将语音转换为semantic tokens,这是将语音信号的声学特征抽象成更高层次的语义信息的过程。
  • 文本到语义模型:这个模型使用文本和提示semantic tokens来预测语义标记。它的作用是理解文本内容并将其映射到相应的语义空间。
  • 语义到声学模型:在得到语义标记后,这个模型会基于这些语义标记来预测声学标记,即将语义信息进一步转换为声学特征,这些声学特征更接近于实际的语音波形。
  • 语音声学编解码器:最后,这个部分负责从声学标记重建语音波形,即将预测的声学特征转换成可以被听到的语音信号。

在训练期间,MaskGCT学习基于给定条件和提示预测掩码的语义或声学标记。在推理期间,模型以并行方式生成指定长度的标记。MaskGCT模型是基于10万小时数据集Emilia训练而来的,精通中英日韩法德6种语言的跨语种合成。数据集Emilia是全球最大且最为多样的高质量多语种语音数据集之一。

MaskGCT模型实验性能

可以看出 MaskGCT模型整体性能超了CosyVoice,XTTS-v2模型性能。

MaskGCT模型运行占用显存(大约10G左右)

部署和使用教程:【又又一款王炸级别TTS模型】趣丸科技&港中大开源MaskGCT语音大模型,性能超过CosyVoice,XTTS-v2!

  1. https://arxiv.org/pdf/2409.00750
  2. https://hf-mirror.com/amphion/MaskGCT
  3. https://maskgct.github.io/
  4. github: https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct
  5. 在线体验: https://huggingface.co/spaces/amphion/maskgct

F5-TTS语音模型

https://github.com/SWivid/F5-TTS

E2-TTS 语音模型介绍

E2-TTS是由微软公司(Microsoft Corporation, USA)的研究团队开发的,具有以下特点:

  1. 简单架构:E2-TTS具有非常简单的架构,仅由填充标记的文本序列和基于流匹配的mel频谱图生成器组成。
  2. 无需额外组件:E2-TTS不需要额外的组件,例如持续时间模型(duration model)、字素到音素转换器(grapheme-to-phoneme converter)或复杂的对齐搜索技术(如单调对齐搜索)。
  3. 高性能:尽管架构简单,E2-TTS在零样本(zero-shot)TTS能力上达到了与之前工作相当或更好的性能,包括Voicebox和NaturalSpeech 3。
  4. 灵活性:E2-TTS在输入表示上具有灵活性,允许在推理期间提高可用性。

F5-TTS 语音模型介绍 :

F5-TTS是一款基于流匹配的全非自回归文本到语音转换模型,由上海交通大学(Shanghai Jiao Tong University)、剑桥大学(University of Cambridge)、以及极氪汽车研究院(Geely Automobile Research Institute (Ningbo) Company Ltd.)的研究团队联合开发的。具有以下特点:

  1. 改进的文本表示:F5-TTS使用ConvNeXt对输入文本进行细化,以改善与语音的对齐,解决了E2-TTS中存在的鲁棒性问题。
  2. Sway Sampling策略:F5-TTS提出了一种新的推理时采样策略,称为Sway Sampling,它显著提高了模型的性能和效率。这种采样策略可以轻松地应用于现有的基于流匹配的模型,而无需重新训练。
  3. 更快的训练与推理:F5-TTS的设计允许更快的训练,并且在推理时实现了0.15的实时因子(Real-Time Factor, RTF),与现有的基于扩散的TTS模型相比,这是一个显著的改进。
  4. 零样本能力:F5-TTS在公共100K小时多语言数据集上训练,展示了高度自然和富有表现力的零样本能力,以及无缝的代码切换能力。
  5. 开源:F5-TTS的代码和检查点被开源,以促进社区发展。

F5-TTS在E2-TTS的基础上进行了改进,特别是在文本表示的细化和推理时采样策略上。这些改进使得F5-TTS在保持简单架构的同时,提供了更好的性能和更快的推理速度。此外,F5-TTS的零样本能力更强,且完全开源。开源协议MIT。

F5-TTS模型性能介绍

这是F5-TTS和E2-TTS在测试集上的结果;

可以看出F5-TTS模型的整体效果是超过CosySense效果的;

部署和使用教程:【克隆TTS领域又更新啦】上海交大开源F5-TTS: 只需要2秒就能克隆语音,可商用,合成语音效果让我震惊不已!

  1. https://github.com/SWivid/F5-TTS
  2. F5-TTS: https://arxiv.org/pdf/2410.06885
  3. E2-TTS:https://arxiv.org/pdf/2406.18009
  4. https://hf-mirror.com/SWivid/F5-TTS
  5. https://hf-mirror.com/SWivid/E2-TTS

FishSpeech1.4模型

https://hf-mirror.com/fishaudio/fish-speech-1.4

https://github.com/fishaudio/fish-speech

fish.audio团队最新开源的FishSpeech1.4;支持中文、英文等8种语音,具有以下特点:

  • 零样本和少样本文本转语音(TTS):输入一个10到30秒的语音样本,即可生成高质量的TTS输出。有关详细指南,请参见语音克隆最佳实践。
  • 多语言和跨语言支持:只需将多语言文本复制粘贴到输入框中——无需担心语言问题。目前支持英语、日语、韩语、中文、法语、德语、阿拉伯语和西班牙语。
  • 无需音素依赖:该模型具有强大的泛化能力,不依赖于音素进行TTS。它可以处理任何语言脚本的文本。
  • 高度准确:对于5分钟的英文文本,实现了约2%的低CER(字符错误率)和WER(词错误率)。
  • 快速:借助fish-tech加速技术,在Nvidia RTX 4060笔记本电脑上实时因子约为1:5,在Nvidia RTX 4090上为1:15。
  • WebUI推理:功能强大,基于Gradio的Web UI,兼容Chrome、Firefox、Edge等浏览器。
  • GUI推理:提供与API服务器无缝协作的PyQt6图形界面。支持Linux、Windows和macOS。见GUI。
  • 部署友好:可以轻松设置推理服务器,原生支持Linux、Windows和MacOS,最小化速度损失。目前在huggingface社区下载量高达5.1K!

部署和使用教程: 【又一款王炸级别语音克隆TTS模型】FishSpeech重磅开源1.4版本!语音合成更逼真!跟最近爆火F5-TTS相比如何呢?

  1. https://hf-mirror.com/fishaudio/fish-speech-1.4
  2. https://speech.fish.audio/zh/inference/#_2
  3. https://github.com/fishaudio/fish-speech
  4. https://hf-mirror.com/SWivid/F5-TTS
  5.  https://github.com/SWivid/F5-TTS

CosyVoice模型

https://github.com/FunAudioLLM/CosyVoice

CosyVoice 是一个语音生成模型,能够合成自然声音,适用于多种应用。模型支持五种语言:中文、英语、日语、粤语和韩语。CosyVoice 包含三个开源模型:

  • CosyVoice-base-300M:擅长准确代表说话者身份,无需微调即可适应不同上下文,能够跨语言克隆声音。
  • CosyVoice-instruct-300M:能够生成富有情感表现力的语音,允许通过指令文本进行精细调整。
  • CosyVoice-sft-300M:已针对七位多语言说话者进行了微调,适合立即部署使用。

语音合成模型 CosyVoice 功能特点:

  • 多语言支持:CosyVoice 支持包括中文、英文、日语、粤语和韩语在内的五种语言。
  • 零样本学习:能够无需训练即可适应新说话者(zero-shot in-context learning),能够在不同语言之间复制声音。
  • 情感共鸣:能够创建情感共鸣的声音, CosyVoice-instruct 版本通过情感指令显著提高了情感控制的准确性。
  • 高质量语音合成:生成的样本在词错误率(WER)和说话者相似性方面达到人类水平。
  • 语音定制化:能够根据特定说话者生成多语言语音,适应新说话者而无需训练。
  • 语音克隆与风格迁移:支持在不同语言之间进行语音克隆和情感风格迁移。

部署和使用教程:【语音领域-又双叒更新】阿里开源FunAudioLLM: 2大核心模型、5大亮点功能!效果炸裂!手把手带你理论+实战部署推理!

  1. CosyVoice: https://github.com/FunAudioLLM/CosyVoice
  2. SenseVoice: https://github.com/FunAudioLLM/SenseVoice
  3. FunAudioLLM论文报告: https://fun-audio-llm.github.io/pdf/FunAudioLLM.pdf
  4. CosyVoice论文报告: https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf
  5. https://fun-audio-llm.github.io/
  6. https://www.modelscope.cn/studios/iic/SenseVoice
  7. https://www.modelscope.cn/studios/iic/CosyVoice-300M

参考链接

  1. https://github.com/SWivid/F5-TTS
  2. https://hf-mirror.com/amphion/MaskGCT
  3. https://hf-mirror.com/fishaudio/fish-speech-1.4
  4. https://github.com/RVC-Boss/GPT-SoVITS
  5. https://github.com/FunAudioLLM/CosyVoice

CleanS2S-语音到语音 (S2S) 的原型智能体

https://github.com/opendilab/CleanS2S

CleanS2S 是一个语音到语音 (S2S) 的原型智能体,提供高质量的流式交互,并采用单文件实现。其设计简洁明了,旨在提供类似 GPT-4o 风格的中文交互原型智能体。该项目希望让用户直接体验语言用户界面 (LUI) 的强大功能,并帮助研究人员快速探索和验证 S2S pipeline 的潜力。

功能

📜 单文件实现

每个智能体管道的细节都放在一个独立的文件中。无需额外配置依赖项或理解项目文件结构。这对于那些想快速了解 S2S 管道并直接验证新想法的人来说,是一个很好的参考实现。所有管道实现都易于修改和扩展,用户可以快速更换喜欢的模型(例如 LLM)、添加新组件或自定义管道。

实时流式接口

整个 S2S 管道主要由 ASR(自动语音识别)、LLM(大型语言模型)和 TTS(文本转语音)组成,配合两个 WebSockets 组件接收器(包含 VAD)和发送器。管道设计为实时流模式,用户可以像人与人对话一样实时与智能体互动。所有音频和文本信息通过 WebSocket 流式发送和接收。为此,我们利用多线程和队列机制确保流过程顺畅,避免阻塞问题。所有组件都设计为异步和非阻塞,处理输入队列的数据并将结果输出到另一个队列。

🧫 全双工交互与打断机制

基于 WebSockets 提供的强大机制,管道支持全双工交互,这意味着用户可以同时与智能体对话和听取回复。此外,管道支持中断,用户可以在对话中随时通过新语音输入打断智能体。智能体将停止当前处理,开始处理新输入,并结合之前的对话和中断内容进行处理。此外,我们发现聊天机器人常用的“助理风格”和“轮流式”回应是人类对话的主要缺点之一。我们为智能体添加了更有趣的策略,以使对话更具互动性和吸引力。

🌍 网络搜索和 RAG

通过集成网络搜索功能和检索增强生成(RAG)模型,管道得到了进一步增强。这些功能使智能体不仅能实时处理和响应用户输入,还能从网络中获取和整合外部信息到响应中。这为回答用户提出的各种实际问题提供了扩展和灵活性。

  • WebSearchHelper 类负责根据用户查询进行在线搜索或收集与对话相关的附加信息。这使智能体能够参考最新或外部数据,增强响应的丰富性和准确性。
  • RAG 类实现了检索增强生成方法,首先从数据库中检索相关信息,然后使用这些信息生成响应。这一两步过程确保智能体的回复基于相关的事实数据,使互动更加知情和符合上下文。

快速上手

后端

安装

## clone the repository
git clone https://github.com/opendilab/CleanS2S.git
cd CleanS2S/backend
pip install -r requirements.txt
  • 根据此处的说明安装 funasr 以支持 paraformer-zh
  • 根据此处的说明安装 cosyvoice 以支持 CosyVoice-300M

下载模型

您需要下载以下四个必要的模型(3个 ASR 模型 + 1个 TTS 模型),可以通过以下链接下载,并放置在合适的目录中。

对于 LLM,我们默认使用 LLM API,您也可以按照下方的说明定制自己的本地 LLM(如 DeepSeek-V2.5、Qwen2.5 等)。

删除 --enable_llm_api 和 --lm_model_url 参数,修改 --lm_model_name 参数为您的本地 LLM 模型路径(例如 --lm_model_name /home/users/deepseek-v2.5)。

您还需要准备一个参考音频目录,其中包含用于韵律和音色转换的参考音频。我们在此仓库中准备了一个示例参考音频目录

如果您想使用自己的参考音频,需要保持与示例参考音频目录相同的格式。音频应为 10~20 秒长,发音清晰。

运行服务器

以下是使用默认设置运行服务器的示例:

export LLM_API_KEY=<your-deepseek-api-key>
python3 -u s2s_server_pipeline.py \
        --recv_host 0.0.0.0 \
        --send_host 0.0.0.0 \
        --stt_model_name <your-asr-path> \
        --enable_llm_api \
        --lm_model_name "deepseek-chat" \
        --lm_model_url "https://api.deepseek.com" \
        --tts_model_name <your-tts-path> \
        --ref_dir <ref-audio-path> \
        --enable_interruption

ℹ️ 支持自定义LLM:在这里,我们使用 deepseek-chat 作为默认 LLM API ,您也可以根据 OpenAI 接口更改为其他 LLM API。(修改--lm_model_name--lm_model_url,设置您自己的 API 密钥)

ℹ️ 支持其他自定义:您可以参考后端管道文件(例如s2s_server_pipeline.py)中由argparse库实现的参数列表,根据自己的需求进行自定义。所有参数在其帮助属性中都有详细文档,易于理解。

使用 Websearch+RAG 运行服务器

您首先需要安装 Websearch 和 RAG 所需的依赖。

pip install -r backend/requirements-rag.txt

其次,为 RAG 中嵌入 Websearch 结果选择一个嵌入模型,例如以下嵌入模型:

git lfs install
git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

然后,为 Websearch 和 RAG 模块提供令牌,在s2s_server_pipeline_rag.py中,我们使用Serper作为 Websearch 工具,使用Deepseek进行 RAG 。

export LLM_API_KEY=''
export SERPER_API_KEY=''

最后,在运行服务器的示例代码中,将s2s_server_pipeline.py替换为s2s_server_pipeline_rag.py,并添加参数--embedding_model_name

这是使用默认设置和 Websearch+RAG 运行服务器的示例:

python3 -u s2s_server_pipeline_rag.py \
        --recv_host 0.0.0.0 \
        --send_host 0.0.0.0 \
        --stt_model_name <your-asr-path> \
        --enable_llm_api \
        --lm_model_name "deepseek-chat" \
        --lm_model_url "https://api.deepseek.com" \
        --tts_model_name <your-tts-path> \
        --embedding_model_name <embedding-model-path> \
        --ref_dir <ref-audio-path> \
        --enable_interruption

前端

我们建议使用Docker镜像来安装和运行客户端。以下是具体步骤:

## 运行基本的Docker镜像
docker run -it -p 3001:3001 amazonlinux:2023.2.20231011.0 sh
## 安装必要的包
dnf install vim git nodejs -y
npm install -g pnpm
git clone https://github.com/opendilab/CleanS2S.git
cd CleanS2S/frontend_nextjs
pnpm install

frontend_nextjs目录中准备适当的.env.local文件,您可以参考.env.example文件以获取所需的环境变量。

## 运行客户端
pnpm dev --port 3001

然后您可以在浏览器中访问客户端,地址为http://localhost:3001(推荐使用 Chrome 浏览器)。

附注:如果您想在本地运行客户端,请首先安装 node.js 和 pnpm ,然后使用 pnpm 安装必要的包并运行客户端。

MooER (摩尔): 基于8万小时训练数据的开源音频理解大模型

MooER: LLM-based Speech Recognition and Translation Models from Moore Threads

Github: https://github.com/MooreThreads/MooER
ModelScope: https://modelscope.cn/models/MooreThreadsSpeech/MooER-MTL-5K
Huggingface: https://huggingface.co/mtspeech/MooER-MTL-5K

paper:https://arxiv.org/abs/2408.05101

🎉🎉🎉我们发布了支持普通话输入的新 Omni (MooER-omni-v1) 和语音转语音翻译 (MooER-S2ST-v1) 模型。Omni 模型可以听到、思考和与您交谈!请在此处查看我们的演示

 Omni (MooER-omni-v1)

在本工作中,我们推出了摩耳大模型(英文名:MooER)—— 一个由摩尔线程开发的、基于大语言模型(Large Language Model,LLM)的语音识别和语音翻译系统。通过摩尔框架,您可以基于大语言模型,以端到端的方式,将输入语音自动转录为文本(即语音识别),并将其翻译为其它语言(即语音翻译)。关于MooER的具体效果,您可以查阅下文中有关评测结果的部分。在我们公布的技术报告中,我们提供了更详细的实验结果,并分享了我们对模型配置、训练策略等方面的理解。

MooER是业界首个基于国产全功能GPU进行训练和推理的大型开源语音模型。依托摩尔线程夸娥(KUAE)智算平台,MooER大模型仅用38小时便完成了5000小时音频数据和伪标签的训练,这一成就得益于自研的创新算法和高效计算资源的结合。

MooER不仅支持中文和英文的语音识别,还具备中译英的语音翻译能力。在多个语音识别领域的测试集中,MooER展现出领先或至少持平的优异表现。特别值得一提的是,在Covost2中译英测试集中,MooER-5K取得了25.2的BLEU分数,接近工业级效果。摩尔线程AI团队在该工作中开源了推理代码和5000小时数据训练的模型,并计划进一步开源训练代码和基于8万小时数据训练的模型,希望该工作能够在语音大模型的方法演进和技术落地方面为社区做出贡献。

MooER主要功能:

  • 语音识别:支持中文和英文的语音到文本的转换
  • 语音翻译:具备中文语音翻译成英文文本的能力
  • 高效率训练:在摩尔线程的智算平台上,快速完成大量数据的训练
  • 开源模型:推理代码和部分训练模型已经开源,便于社区使用和进一步研究。

MooER 模型、实验:

  • 深度学习架构:MoOER采用了深度学习技术,特别是神经网络来处理和理解语音信号端到端训练:模型从原始语音信号直接到文本输出,无需传统语音识别系统中的多个独立模块。
  • Encoder-Adapter-Decoder结构:
    • Encoder:负责将输入的语音信号转换成一系列高级特征表示。
    • Adapter:用于调整和优化模型对特定任务的适应性,提高型的泛化能力。
    • Decoder(Large Language Model,LLM):基于这些特征生成最终的文本输出。
  • LoRA技术:使用LoRA(Low-Rank Adaptation)技术,一种参数高效的模型微调方法,通过只更新模型中一小部分参数来提高训练效率和效果。
  • 伪标签训练:在训练过程中使用伪标签技术,即用模型自身的预测作为训练数据,以增强模型的学习能力。
  • 多语言支持:MOOER支持中文和英文的语音识别,以及中译英的语音翻译,显示出其多语言处理能

MooER的模型结构

包括Encoder、Adapter和Decoder(Large Language Model,LLM)三个部分。其中,由Encoder对输入的原始音频进行建模,提取特征并获取表征向量。Encoder的输出会送到Adapter进一步下采样,使得每120ms音频输出一组音频Embedding。音频Embedding和文本的Prompt Embedding拼接后,再送进LLM进行对应的下游任务,如语音识别(Automatic Speech Recognition,ASR)、语音翻译(Automatic Speech Translation,AST)等。在模型训练阶段,融合了语音模态和文本模态的数据会按以下形式输入到LLM:

训练数据格式

MooER的训练

我们使用开源的Paraformer语音编码器、Qwen2-7B-instruct大语言模型来初始化Encoder和LLM模块,并随机初始化Adapter模块。训练过程中,Encoder始终固定参数,Adapter和LLM会参与训练和梯度更新。利用自研的夸娥智算平台,我们使用DeepSpeed框架和Zero2策略,基于BF16精度进行训练和推理。经实验发现,训练过程中更新LLM参数能够提升最终音频理解任务的效果。为了提升训练效率,我们采用了LoRA技术,仅更新2%的LLM参数。具体的模型参数规模如下:

MooER 数据集:

该模型的训练数据MT5K(MT 5000h)由部分开源数据和内部数据构成,内部数据的语音识别标签均是由第三方云服务得到的伪标签。语音识别的伪标签经过一个文本翻译模型后,得到语音翻译的伪标签。我们没有对这些伪标签数据做任何的人工筛选。具体数据来源和对应的规模如下:

MooER实验结果:

我们将MooER与多个开源的音频理解大模型进行了对比,包括Paraformer、SenseVoice、Qwen-audio、Whisper-large-v3和SeamlessM4T-v2等。这些模型的训练规模从几万小时到上百万小时不等。对比结果显示,我们的开源模型MooER-5K在六个中文测试集上的CER(字错误率)达到4.21%,在六个英文测试集的WER(词错误率)为17.98%,与其它开源模型相比,MooER-5K的效果更优或几乎持平。特别是在Covost2 zh2en中译英测试集上,MooER的BLEU分数达到了25.2,显著优于其他开源模型,取得了可与工业水平相媲美的效果。基于内部8万小时数据训练的MooER-80k模型,在上述中文测试集上的CER达到了3.50%,在英文测试集上的WER到达了12.66%。

• Paraformer-large: 60,000 hours ASR data
• SenseVoice small: 300,000 hours ASR data
• Qwen-audio: 53,000 hours ASR data + 3700 hours S2TT data + …
• WhisperV3: 1000,000 hours weakly labels, 4000,000 hours pseudo labels
• SeamlessM4T2: 351,000 hours S2TT data, 145,000 hours S2ST data
• MooER-5K: 5,000 hours pseudo labels【伪标签】
• MooER-80K: 80,000 hours pseudo labels【伪标签】

建议

与此同时,我们还得到一些有趣的结论,可以为数据资源和计算资源有限的开发者提供一些建议:

▼Encoder的选择。我们分别对比了无监督(Self-Supervised Learning)训练的W2v-bert 2.0、半监督(Semi-Supervised Learning)训练的Whisper v3和有监督(Supervised Learning)训练的Paraformer。我们发现,采用无监督训练得到的Encoder必须参与到训练过程中,否则模型很难收敛。综合考虑模型效果、参数量以及训练和推理的效率,我们选择Paraformer作为Encoder。

▼音频建模粒度很关键。我们尝试使用240ms、180ms和120ms的粒度进行建模,并发现这一参数对音频与文本的融合效果具有重要影响,同时会影响模型的最终效果和训练的收敛速度。经过评估,我们最终选择每120ms输出一个音频Embedding

▼快速适应到目标垂类。我们仅使用了140h~150h的英文数据进行训练,可以在6个不同来源的英文的测试集上取得一定效果。同时我们尝试将任务迁移到语音翻译(AST)领域,取得了很好的效果。我们相信这个方法同样也适用于小语种、方言或其它低资源的音频理解任务。

▼LLM对音频理解任务的影响。我们发现,在模型训练过程中采用LoRA技术对LLM参数进行更新,可以使训练更快收敛,并且最终取得更好的效果。同时,音频理解任务上的效果也会随着基础LLM效果提升而提升。【LLM模型越大,效果越好。训练参数越多,效果越好】

是否冻结LLM,以及LLM模型的选择

加速训练:

优化了数据加载器部分,在相同配置下可以将训练速度提高4到5倍。同时,我们基于5000小时的训练优化了DeepSpeed的训练策略,并将其重新用于我们8wh内部数据的训练。对于需要解冻编码器的训练,我们使用梯度检查点技术以减少内存使用。我们使用基于Moore Threads的KUAE平台加速大型模型的训练。

训练参数:

应用场景:

  • 实时语音转写:在会议、讲座、课堂等场合,MOOER可以实时将语音转换为文字,便于记录和回顾。
  • 多语言翻译:支持中英文之间的语音翻译,适用于跨国会议、国际交流等场景。
  • 智能客服:在客户服务领域,MOOER可以通过语音识别和翻译功能,提高客服的响应效率和服务质量。
  • 语音助手:集成到智能手机、智能音箱等设备中,提供语音交互服务。
  • 教育辅助:在语言学习中,MOOER可以帮助学习者进行发音校正和语言翻译,

📝 路线图

  •  Technical report 技术报告
  •  Inference code and pretrained ASR/AST models using 5k hours of data
    使用 5k 小时数据的推理代码和预训练的 ASR/AST 模型
  •  Pretrained ASR model using 80k hours of data
    使用 80k 小时数据的预训练 ASR 模型
  •  Traning code for MooER MooER 的训练代码
  •  LLM-based speech-to-speech translation (S2ST, Mandrin Chinese to English)
    LLM 基于语音的语音转语音翻译(S2ST,Mandrin 中文到英文)
  •  GPT-4o-like audio-LLM supporting chat using speech
    类似 GPT-4o 的音频LLM 支持使用语音聊天
  •  Training code and technical report about our new Omni model
    有关我们新 Omni 模型的培训代码和技术报告
  •  Omni audio-LLM that supports multi-turn conversation
    Omni audio-LLM,支持多轮次对话
  •  Pretrained AST and multi-task models using 80k hours of data
    使用 80k 小时数据的预训练 AST 和多任务模型
  •  LLM-based timbre-preserving Speech-to-speech translation
    LLM 基于音色保留的语音到语音翻译

MaskGCT-国产最强TTS语音大模型

近期,港中大(深圳)联手趣丸科技联合推出了新一代大规模声音克隆 TTS 模型 ——MaskGCT。该模型在包含 10 万小时多语言数据的 Emilia 数据集上进行训练,展现出超自然的语音克隆、风格迁移以及跨语言生成能力,同时保持了较强的稳定性。MaskGCT 已在香港中文大学(深圳)与上海人工智能实验室联合开发的开源系统 Amphion 发布。

本文介绍了一种名为 Masked Generative Codec Transformer(MaskGCT)的全非自回归 TTS 模型。

现有大规模文本到语音(TTS)系统通常分为自回归和非自回归系统。自回归系统隐式地建模持续时间,但在鲁棒性和持续时间可控性方面存在一定缺陷。非自回归系统在训练过程中需要显式的文本与语音对齐信息,并预测语言单元(如音素)的持续时间,这可能会影响其自然度。

该模型消除了文本与语音监督之间的显式对齐需求,以及音素级持续时间预测。MaskGCT 是一个两阶段模型:在第一阶段,模型使用文本预测从语音自监督学习(SSL)模型中提取的语义标记;在第二阶段,模型基于这些语义标记预测声学标记。MaskGCT 遵循掩码预测学习范式。在训练过程中,MaskGCT 学习根据给定的条件和提示预测掩码的语义或声学标记。在推理过程中,模型以并行方式生成指定长度的标记。通过对 10 万小时的自然语音进行实验,结果表明 MaskGCT 在质量、相似度和可理解性方面优于当前最先进的零样本 TTS 系统。

一、方法

MaskGCT 模型由四个主要组件组成:

1. 语音语义表示编解码器:将语音转换为语义标记。

2. 语音声学编解码器:从声学标记重建波形。

3. 文本到语义模型【 非自回归Tranformer 】:使用文本和提示语义标记预测语义标记。

4. 语义到声学模型【非自回归Tranformer】:基于语义标记预测声学标记。

所提出的两阶段 MaskGCT 框架的概述。它由四个主要部分组成:(1)语音语义表示编解码器将语音转换为语义标记; (2)文本到语义模型用文本和提示语义标记来预测语义标记; (3) 语义到声学模型预测以语义标记为条件的声学标记; (4) 语音声学编解码器根据声学标记重建波形

语音语义表示编解码器用于将语音转换为离散的语义标记,这些标记通常通过离散化来自语音自监督学习(SSL)模型的特征获得。与以往使用 k-means 方法离散化语义特征相比,这种方法可能导致信息损失,从而影响高质量语音的重建或声学标记的精确预测,尤其是在音调丰富的语言中。为了最小化信息损失,本文训练了一个 VQ-VAE 模型来学习一个向量量化码本,该码本能够从语音 SSL 模型中重建语音语义表示。具体来说,使用 W2v-BERT 2.0 模型的第 17 层隐藏状态作为语音编码器的语义特征,编码器和解码器由多个 ConvNext 块组成。通过改进的 VQ-GAN 和 DAC 方法,使用因子分解码将编码器输出投影到低维潜在变量空间。

图 5:语义编解码器(左)和声学编解码器(右)概述。语义编解码器被训练为使用单个码本量化语义特征并重建语义特征声学编解码器经过训练,使用 RVQ 量化和重建语音波形,并使用时间和频谱鉴别器进一步提高重建质量

语音声学编解码器旨在将语音波形量化为多层离散标记,同时尽可能保留语音的所有信息。本文采用残差向量量化(Residual Vector Quantization, RVQ)方法,将 24K 采样率的语音波形压缩为 12 层的离散标记。此外,模型使用 Vocos 架构作为解码器,以提高训练和推理效率。

文本到语义模型采用非自回归掩码生成 Transformer而不使用自回归模型或任何文本到语音的对齐信息。在训练过程中,我们随机提取语义标记序列的前缀部分作为提示,以利用语言模型的上下文学习能力。我们使用 Llama 风格的 Transformer 作为模型的主干,结合门控线性单元(GLU)和 GELU 激活函数、旋转位置编码等,但将因果注意力替换为双向注意力。还使用了接受时间步 t 作为条件的自适应 RMSNorm。在推理过程中,我们生成任意指定长度的目标语义标记序列,条件是文本和提示语义标记序列。本文还训练了一个基于流匹配的持续时间预测模型,以预测基于文本和提示语音持续时间的总持续时间,利用上下文学习。

语义到声学模型同样采用非自回归掩码生成 Transformer,【基于 SoundStorm】,该模型以语义标记为条件,生成多层声学标记序列以重建高质量语音波形。对于 S2A 模型的输入,由于语义令牌序列中的帧数等于提示声学序列和目标声学序列中帧数的总和,我们简单地将语义令牌的嵌入和从层 1 到层 j的声学令牌的嵌入相加。在推理过程中,我们从粗到细为每层生成令牌,在每层内使用迭代并行解码。

图 2:T2S(左)和 S2A(右)模型的训练图概述。 T2S 模型经过训练,可以预测以文本和提示语义标记为前缀的屏蔽语义标记。 S2A 模型经过训练,可以根据提示声学标记、语义标记和前一层的声学标记来预测随机层的屏蔽声学标记
SoundStorm 架构

二、支持的功能

MaskGCT 能超自然地模拟参考音频音色与风格,并跨语言生成音频

Zero-shot In-context Learning 根据提示音频自动生成下文

MaskGCT 可以模仿名人或动画节目中角色的声音。

MaskGCT 可以学习提示语音的韵律、风格和情感。

MaskGCT 可以从提示语音中学习如何说话,包括情感和口音等风格。

MaskGCT 具有控制生成音频的总持续时间的能力,从而使我们能够将生成的语音的速度调节在合理的范围内。

与 AR 模型相比,MaskGCT 表现出更高的稳健性(更低的 WER),在一些具有挑战性的情况下(例如绕口令和 AR 模型容易产生幻觉的其他样本)表现出增强的稳定性。

Speech Editing 语音编辑。

基于掩码和预测机制,我们的文本到语义模型支持在文本-语音对齐器的帮助下进行零镜头语音内容编辑。通过使用对齐器,我们可以识别原始语义标记序列的编辑边界,屏蔽需要编辑的部分,然后使用编辑后的文本和未屏蔽的语义标记来预测被屏蔽的语义标记。

语音对话。MaskGCT 通过使用改进的训练策略微调 S2A (语义到声学)模型来支持零镜头语音转换。我们仍在努力提高语音转换的有效性。源和提示示例来自 Seed-TTS 的 demo 页面。

跨语言视频翻译。

三、实验结果

SOTA 的语音合成效果:MaskGCT 在三个 TTS 基准数据集上都达到了 SOTA 效果,在某些指标上甚至超过了人类水平。

此外,MaskGCT 在风格迁移(口音、情感)也达到了 SOTA 的水准:

我们还研究了 MaskGCT 在中、英外其它语言的能力:

四、应用场景

目前,MaskGCT 在短剧出海、智能助手、有声读物、辅助教育等领域拥有丰富的应用场景。为了加快落地应用,在安全合规下,趣丸科技打造了多语种速译智能视听平台 “趣丸千音”。一键上传视频即可快速翻译成多语种版本,并实现音话同步、口型同步、去字幕等功能。该产品进一步革新视频翻译制作流程,大幅降低过往昂贵的人工翻译成本和冗长的制作周期,成为影视、游戏、短剧等内容出海的理想选择平台。
《2024 年短剧出海白皮书》显示,短剧出海成为蓝海新赛道,2023 年海外市场规模高达 650 亿美元,约为国内市场的 12 倍,短剧出海成为蓝海新赛道。以 “趣丸千音” 为代表的产品的出现,将加速国产短剧 “走出去”,进一步推动中华文化在全球不同语境下的传播。

五、总结

MaskGCT 是一个大规模的零样本 TTS 系统,利用全非自回归掩码生成编解码器 Transformer,无需文本与语音的对齐监督和音素级持续时间预测。MaskGCT 通过文本预测从语音自监督学习(SSL)模型中提取的语义标记,然后基于这些语义标记预测声学标记,实现了高质量的文本到语音合成。实验表明,MaskGCT 在语音质量、相似度和可理解性方面优于最先进的 TTS 系统,并且在模型规模和训练数据量增加时表现更佳,同时能够控制生成语音的总时长。此外,我们还探索了 MaskGCT 在语音翻译、语音转换、情感控制和语音内容编辑等任务中的可扩展性,展示了 MaskGCT 作为语音生成基础模型的潜力。

Emilia:用于大规模语音生成的广泛、多语言和多样化的10wh+语音数据集

ArXiv: https://arxiv.org/abs/2407.05361
GitHub: https://github.com/open-mmlab/Amphion/tree/main/preprocessors/Emilia
Homepage: https://emilia-dataset.github.io/Emilia-Demo-Page/
HuggingFace: https://huggingface.co/datasets/amphion/Emilia

Emilia 数据集是一个全面的多语言数据集,具有以下功能:

  • 包含超过 101K 小时的语音数据;
  • 涵盖六种不同的语言:英语 (En)、中文 (Zh)、德语 (De)、法语 (Fr)、日语 (Ja) 和韩语 (Ko);
  • 包含来自互联网上各种视频平台和播客的各种说话风格的多样化语音数据,涵盖脱口秀、访谈、辩论、体育评论和有声读物等各种内容类型。
Language 语言Duration (hours) 持续时间 (小时)
English 英语46,828
Chinese 中文49,922
German 德语1,590
French 法语1,381
Japanese 日语1,715
Korean 朝鲜语217

Emilia 数据集结构:

|-- openemilia_all.tar.gz (all .JSONL files are gzipped with directory structure in this file)
|-- EN (114 batches)
|   |-- EN_B00000.jsonl
|   |-- EN_B00000 (= EN_B00000.tar.gz)
|   |   |-- EN_B00000_S00000
|   |   |   `-- mp3
|   |   |       |-- EN_B00000_S00000_W000000.mp3
|   |   |       `-- EN_B00000_S00000_W000001.mp3
|   |   |-- ...
|   |-- ...
|   |-- EN_B00113.jsonl
|   `-- EN_B00113
|-- ZH (92 batches)
|-- DE (9 batches)
|-- FR (10 batches)
|-- JA (7 batches)
|-- KO (4 batches)

JSONL 文件示例:

{"id": "EN_B00000_S00000_W000000", "wav": "EN_B00000/EN_B00000_S00000/mp3/EN_B00000_S00000_W000000.mp3", "text": " You can help my mother and you- No. You didn't leave a bad situation back home to get caught up in another one here. What happened to you, Los Angeles?", "duration": 6.264, "speaker": "EN_B00000_S00000", "language": "en", "dnsmos": 3.2927}
{"id": "EN_B00000_S00000_W000001", "wav": "EN_B00000/EN_B00000_S00000/mp3/EN_B00000_S00000_W000001.mp3", "text": " Honda's gone, 20 squads done. X is gonna split us up and put us on different squads. The team's come and go, but 20 squad, can't believe it's ending.", "duration": 8.031, "speaker": "EN_B00000_S00000", "language": "en", "dnsmos": 3.0442}

Emilia-Pipe 概述 👀

Emilia-Pipe 是第一个开源预处理管道,旨在将原始的野生语音数据转换为高质量的训练数据,并带有用于语音生成的注释。此管道可以在几分钟内将一小时的原始音频处理为模型就绪数据,只需要原始语音数据。
Emilia 和 Emilia-Pipe 的详细说明可以在我们的论文中找到。

The Emilia-Pipe includes the following major steps:
Emilia-Pipe 包括以下主要步骤:

  1. Standardization:Audio normalization
    标准化:音频标准化
  2. Source Separation: Long audio -> Long audio without BGM
    源分离: 长音频 -> 无 BGM 的长音频
  3. Speaker Diarization: Get medium-length single-speaker speech data
    说话人分类:获取中等长度的单个说话人语音数据
  4. Fine-grained Segmentation by VAD: Get 3-30s single-speaker speech segments
    按 VAD 进行精细分割:获取 3-30 秒的单说话人语音片段
  5. ASR: Get transcriptions of the speech segments
    ASR:获取语音段的听录
  6. Filtering: Obtain the final processed dataset
    筛选:获取最终处理后的数据集

具体使用的模型:

关于ASR:

缺乏文本转录限制了 Emilia 数据集在 TTS 任务中的直接使用。为了解决这个问题,我们应用 ASR 技术对分段的语音数据进行转录。为了平衡速度和准确性,我们使用了最先进的多语言 ASR 模型 Whisper-Medium。为了进一步提高效率,我们使用了 WhisperX ,它基于更快的 Whisper 后端和 CTranslate2 8 推理引擎。该设置的速度是官方 Whisper 实现的四倍,同时几乎保持相同的准确性。为了避免重复处理,我们绕过了 WhisperX 的 VAD 组件,直接使用VAD的结果。此外,我们还实现了基于更快 Whisper 后端的批处理推理,以并行转录语音数据。这些优化显著提高了整个流程的效率。

关于筛选:

在实际场景中,一些噪声可能无法通过源分离完全处理,Whisper 模型可能会出现幻觉,某些原始语音数据可能质量较低。为了确保生成数据集的质量,我们应用以下过滤标准。首先,我们使用 Whisper 模型的语言识别结果,丢弃任何未被预测为我们目标语言(英语、法语、德语、中文、日语、韩语)或模型语言置信度低于 80% 的语音数据。其次,我们使用 DNSMOS P.835 OVRL 评分来评估整体语音质量,仅保留评分高于 3.0 的语音数据。最后,对于每个原始语音样本,我们计算其对应片段的平均字符持续时间。平均音素持续时间超出第三四分位数上方 1.5 倍四分位距(IQR)或低于第一四分位数的片段被视为异常值,相关的语音片段将被丢弃。经过过滤后,生成的数据集将用于训练语音生成模型。

DNSMOS:评估噪声抑制效果的非侵入性客观语音质量指标。人类主观评估语音质量时音频质量的“黄金标准”。传统音频评估指标需要参考干净的语音信号,这无法对真实的应用环境中的音频作出准确评估。因为真实应用中干净的语音信号难以获得。然而,传统的无参考方法与人类主观评估质量相关性很差,没有被广泛采用。本文介绍了来自微软的深度学习多阶段噪声抑制评估方法

GLM-4-Voice:智谱新一代端到端语音大模型

GLM-4-Voice | 端到端中英语音对话模型

代码仓库:https://github.com/THUDM/GLM-4-Voice

继语言模型、图像理解、视频理解、图像生成、视频生成等模型之后,今天,智谱的多模态大模型家族再次加入新成员——GLM-4-Voice(端到端语音模型)。这一成果使得大模型具备了完整的感官系统,实现了机器与人交互的自然与流畅。
GLM-4-Voice 模型具备直接理解和生成中英文语音的能力,能够根据用户指令灵活调整语音的情感、语调、语速及方言等特征,且具有更低的延时,支持实时打断,进一步提升交互体验。
具体来说,GLM-4-Voice具备:

  1. 情感表达和情感共鸣:模拟不同的情感和语调,如高兴、悲伤、生气、害怕等情绪,用合适的情绪语气进行回复。传统 TTS 通常在情感表达上比较僵硬,声音缺少起伏和细腻的变化。
  2. 调节语速:在同一轮对话中,可以要求 TA 快点说 or 慢点说。
  3. 随时打断,灵活输入指令:根据实时的用户指令,调整语音输出的内容、风格和情感,支持更灵活的对话互动。例如,你可以随时打断 TA,让 TA 输出新的内容,更加符合日常对话情境。
  4. 多语言、多方言支持:目前 GLM-4-Voice 支持中英文语音以及中国各地方言,尤其擅长粤语、重庆话、北京话等。

技术细节

与传统的 ASR + LLM + TTS 的级联方案相比,端到端模型以音频 token 的形式直接建模语音,在一个模型里面同时完成语音的理解和生成,避免了级联方案“语音转文字再转语音” 的中间过程中带来的信息损失,也解锁了更高的能力上限。

GLM-4-Voice 由三个部分组成:

  • GLM-4-Voice-Tokenizer: 通过在 Whisper 的 Encoder 部分增加 Vector Quantization [单层量化]训练,通过在 ASR 数据上有监督训练的方式得到,将连续的语音输入转化为离散的 token,每秒音频转化为 12.5 个离散 token。
  • GLM-4-Voice-9B: 在 GLM-4-9B 的基础上进行语音模态的预训练和对齐,从而能够理解和生成离散化的语音。
  • GLM-4-Voice-Decoder: 基于 CosyVoice 的 Flow Matching 模型结构训练的支持流式推理的语音解码器,将离散化的语音 token 转化为连续的语音输出。最少只需要 10 个音频 token 即可开始生成,降低端到端对话延迟。
 CosyVoice  模型架构

具体来说,GLM-4-Voice 以离散 token 的方式表示音频,实现了音频的输入和输出的端到端建模。具体来说,我们基于语音识别(ASR)模型以有监督方式训练了音频 Tokenizer,能够在 12.5Hz(12.5 个音频 token)单码表的超低码率下准确保留语义信息,并包含语速,情感等副语言信息。

语音合成方面,我们采用 Flow Matching 模型流式从音频 token 合成音频,最低只需要 10 个 token 合成语音,最大限度降低对话延迟。

预训练方面,为了攻克模型在语音模态下的智商和合成表现力两个难关,我们将 Speech2Speech 任务解耦合为 Speech2Text(根据用户音频做出文本回复) 和 Text2Speech(根据文本回复和用户语音合成回复语音)两个任务,并设计两种预训练目标适配这两种任务形式:

  • Speech2Text:从文本数据中,随机选取文本句子转换为音频 token;
  • Text2Speech:从音频数据中,随机选取音频句子加入文本 transcription。

分别基于文本预训练数据和无监督音频数据合成语音-文本交错数据以适配这两种任务形式。

GLM-4-Voice 在 GLM-4-9B 的基座模型基础之上,经过了数百万小时音频和数千亿 token 的音频文本交错数据预训练,拥有很强的音频理解和建模能力。对齐方面,为了支持高质量的语音对话,我们设计了一套流式思考架构:输入用户语音,GLM-4-Voice 可以流式交替输出文本和语音两个模态的内容,其中语音模态以文本模态作为参照保证回复内容的高质量,并根据用户的语音指令变化感情需求,在保证智商的情况下仍然具有端到端建模的能力,同时保持低延迟性(最低只需要输出 20 个 token 便可以合成语音)。

Model List

ModelTypeDownload
GLM-4-Voice-TokenizerSpeech Tokenizer🤗 Huggingface 🤖 ModelScope
GLM-4-Voice-9BChat Model🤗 Huggingface 🤖 ModelScope
GLM-4-Voice-DecoderSpeech Decoder🤗 Huggingface 🤖 ModelScope

效果和其他说明

支持语音输入/文本输入,以及语音+文本交替输出

音频实时生成的质量较差,Gradio 的流式音频播放效果不稳定。在生成完成后点击对话框中的音频质量会更高。目前仅支持女声输出,指令遵循能力较强。

关于实时打断功能

作者目前还没给出实现方法和demo。有关打断问题,可以考虑参考开源项目 CleanS2S,虽然是级联式 pipeline【ASR+LLM+TTS】,但是相关代码逻辑应该可以结合 GLM-4-Voice 这样的 end-to-end 模型。目前支持实时输入语音打断和输入文字打断两种方式,后续还会设计更多有趣的打断模式(例如 agent 视角的主动打断)。

https://github.com/opendilab/CleanS2S/blob/main/README.zh.md

OpenAI o1复现项目进展报告

摘自:机器之心

O1 Replication Journey: A Strategic Progress Report

Report | Walnut Plan | Citation

团队介绍:本项目的核心开发团队主要由上海交通大学 GAIR 研究组的本科三年级、四年级学生以及直博一年级研究生组成。项目得到了来自 NYU 等一线大型语言模型领域顶尖研究科学家的指导。详细作者介绍见:https://github.com/GAIR-NLP/O1-Journey#about-the-team。


人工智能领域掀起巨浪的 OpenAI o1 模型发布三周后,一支由高校年轻研究者组成的团队今天发布了题为 “o1 Replication Journey: A Strategic Progress Report (o1 探索之旅:战略进展报告)” 的研究进展报告。这份报告的独特之处在于 (1)不仅提出并验证了 “旅程学习” 的技术的巨大潜力(研究者也认为是 o1 取得成功的关键技术):通过 327 条训练样本,鼓励模型学会反思、纠错、回溯,其在复杂数学题目上表现 绝对性能就超过了传统监督学习 8% 以上,相对性能提升超过 20%;(2)并且,其前所未有的透明度和即时性,不仅详细记录了团队在复现过程中的发现、挑战、试错和创新方法,更重要的是,它倡导了一种全新的 AI 研究范式。研究团队负责人表示:” 我们的主要目标不是达到与 OpenAI 的 o1 相当的性能 —— 考虑到可用资源有限,这是一个极具挑战性的任务。相反,我们的使命是透明地记录和分享我们的探索过程,聚焦于我们遇到的根本问题,发现新的科学问题,并识别导致 o1 的成功的关键因素,并与更广泛的 AI 社区分享我们的试错经验。o1 技术无疑会成为全球各大 AI 科技公司争相复现的目标。如果我们能够及早分享一些复现过程中的经验教训,就能帮助其他公司减少不必要的试错,从而降低全球范围内 o1 技术复现的总体成本和时间。这不仅有利于推动技术的快速发展,也能促进整个 AI 行业的共同进步。

团队提出的模型在同一道数学题上,与 OpenAI 的 o1-preview (答对)及 GPT-4o(答错)的比较实例,证明旅程学习不断试错、反思、自我纠正的能力在复杂推理任务场景上非常关键。
  • 技术报告链接:https://github.com/GAIR-NLP/O1-Journey/blob/main/resource/report.pdf
  • Github 链接:https://github.com/GAIR-NLP/O1-Journey
  • o1 讨论资源:https://github.com/GAIR-NLP/O1-Journey/tree/main/resource’

该报告发现了什么?从 “”捷径学习”” 到 “旅程学习”,从 “浮光掠影” 到 “深耕细作”

图:从 “捷径学习” 到 “旅程学习” 的范式转变。这是一个用于推理任务的搜索树。对于数学问题解决任务,根节点代表初始问题,而叶节点则是最终结论。绿色节点表示正确答案,红色节点表示错误答案。传统上,学习主要集中在对直接从根到叶的捷径路径进行监督训练。然而,本研究探索了对整个探索路径进行监督学习,这包括了试错和纠正的过程。

团队认为,大多数现有的机器学习或大模型训练方法(如监督式微调)都可以被归类为 “捷径学习” (Shortcut Learning),即模型学习到达正确答案的直接路径。这种传统范式虽然在特定、明确定义的任务中可能有效,但在面对复杂、动态和开放性问题时显示出明显的局限性。捷径学习具有以下几个关键特征:(1) 注重快速结果:强调在短时间内达到特定的性能指标或完成特定任务。(2) 高度依赖数据:性能改进通常依赖于增加训练数据量,而非改进学习算法本身。(3) 泛化能力有限:在训练数据分布之外的场景中,性能可能会急剧下降。(4) 缺乏自我纠正能力:这些系统通常缺乏识别和纠正自身错误的能力。尽管捷径学习推动了人工智能的许多进步,但它难以产生真正智能和可靠的人工智能系统,无法应对现实世界挑战的复杂性。随着我们追求更高级形式的人工智能甚至超级智能,这种方法的局限性变得越来越明显。
认识到这些缺点,本文提出了一种名为 “旅程学习”(Journey Learning) 的新范式。旅程学习旨在使人工智能系统能够通过学习、反思、回溯和适应不断进步,就像人类一样,从而展现出更高水平的智能。

如图所示,团队提出了 “旅程学习” 范式,它鼓励模型不仅学习捷径,还要学习完整的探索过程,包括试错、反思和回溯。仅使用 327 个训练样本,不借助任何额外训练技巧,旅程学习在 MATH 数据集上的表现就超过了传统监督学习 8% 以上,展示了其极其强大的潜力。作者也认为这是 o1 技术中最关键的组成部分

表:捷径学习和旅程学习的多维度比较

模型生成的例子:

技术细节是什么?o1 技术探索之旅

如图所示,从 OpenAI o1 9 月 12 日发布的过去三周内,该团队对 o1 技术已经完成了系统化、多阶段的探索。这个过程始于使用 OlympicArena 数据集对 o1 进行初步评估(如下表格),旨在全面了解其在多个学科领域的认知能力。研究的核心集中在 o1 思维结构的分析上,特别关注 “长思维” 这一关键概念整个探索技术涉及多个复杂的步骤,包括奖励模型的开发、在策略推理树的构建,以及将这些元素整合为连贯的长思维过程。整个研究过程采用了迭代和并行的方法。进行了多次尝试,不断调整和完善技术和方法。评估过程包括定量和定性分析,结合人工检查和专门的分析工具,以确保研究的准确性和有效性。

团队强调了探索过程的重要性,而不仅仅关注最终结果。这种重视科研探索过程的思路与团推提出的 “旅程学习” 范式相一致,强调了在复杂、动态环境中不断试错、纠错的持续学习和适应的重要性。通过这个过程,不仅获得了关于 o1 技术的深入理解,还开发了一套探索未知 AI 技术的系统方法。研究过程涉及决策分析、挑战识别以及创新解决方案的开发。最终,这项研究不仅仅是对 o1 技术的探索,更是对先进 AI 系统研究方法的一次实践和验证。通过分享研究过程,包括成功和失败的经验,旨在为 AI 研究社区提供有价值的见解,促进该领域的集体进步。
这个探索过程展示了开放、协作的 AI 研究在推动技术边界方面的重要性,为未来更复杂的 AI 系统研究提供了有益的参考和指导。
具体地,团队凝炼了复现 o1 过程中的几个关键问题,并做了非常细致的探索分享:

  • Q1: o1 的思维链是什么样子的?
  • Q2: 长思维 (Long thought) 是如何工作的?
  • Q3: 如何构建长思维?
  • Q4: 如何构建奖励模型?
  • Q5: 如何构建 on-policy 推理树?
  • Q6: 如何从推理树中推导出长思维?
  • Q7: 如何评估我们的尝试方法?
  • Q8: 如何训练我们的模型?
  • Q9: 什么是人类和 AI 协同标注的有效策略?

  • Q1: o1 的思维链是什么样子的?
  • 经过探索,团队确定需要构建的长思维数据应具有以下特征:

    • 迭代式问题解决:模型首先定义函数,然后逐步探索相关表达式,将复杂方程分解为更简单的组成部分,反映了一种结构化和有条理的方法。 
    • 关键思维指标:使用 “Therefore” 表示结论,”Alternatively” 探索不同路径,”Wait” 表示反思,以及 “Let me compute” 过渡到计算,突出了模型的推理阶段。 
    • 递归和反思方法:模型经常重新评估和验证中间结果,使用递归结构确保一致性,这在严谨的数学推理中很典型。 
    • 假设探索:模型测试不同的假设,随着获得更多信息而调整其方法,展示了推理过程中的灵活性
    • 结论和验证:最后,模型解方程并验证结果,强调在完成之前验证结论的重要性。 

    Q2: 长思维 (Long thought) 是如何工作的?
    这是团队认为重要的问题。然而,在当前的研究阶段,该团队仅仅提出了猜想。团队认为还没有足够的经验证据来验证它们的准确性,这也是未来需要重点展开的工作。
    o1 长思维方法的显著成功可以归因于在上述中介绍的旅程学习 (Journey Learning)。与传统的捷径学习 (Shortcut Learning) 不同,旅程学习允许模型探索整个决策轨迹,模仿人类的问题解决过程。这种全面的探索使 o1 能够考虑多种解决方案路径,从错误中学习,并理解完整的问题解决过程。通过经历正确和错误的路径,模型发展出强大的错误处理和自我纠正能力,增强了其适应新挑战的能力。这种方法培养了对问题领域更深入的理解,不仅仅是知道正确答案,而是理解为什么以及如何得出答案。旅程学习过程密切模拟人类的认知过程,包含试错、反思和调整。这大大增加了模型输出内容的可解释性,因为 o1 可以提供详细的解决步骤并解释其推理过程,包括如何从错误中恢复。因此,基于旅程学习的 o1 长思维过程不仅仅是计算时间的扩展,还代表了一种彻底的、人类般的推理探索。这种方法使 o1 能够处理更复杂的问题,提供更可靠和可解释的答案,并在面对新挑战时表现出更大的适应性,从而解释了它在各种任务中的卓越表现。


    Q3: 如何构建长思维?
    尝试 1:基于 LLM 和奖励的树搜索 根据在 Q1 中对长思维的观察,其最显著的特征是在推理产生错误时或遇到冗余的推理步骤时尝试反思和回溯。这类似于在推理树上搜索问题的解决方案,在错误节点处回溯,直到找到正确的解决路径。为实现这一点,需要构建一棵推理树,其中根节点代表问题,其他每个节点代表一个推理步骤。从根到任何节点的路径代表从问题到该结论的推理过程。此外,回溯和反思必须基于错误的推理步骤,这需要一个更细粒度的奖励模型(即过程级)来指示树中每个节点的正确性。通过在具有过程级奖励的推理树上执行搜索算法,可以将错误步骤整合到思维链中,从而构建包含回溯和反思等行为的长思维。
    尝试 2:提议 – 批评循环 尝试 1 通过基于预定义规则在树上执行搜索来构建长思维,但这限制了回溯和反思等行为的自由度。因此,团队尝试让模型选择自己当前的行为。团队构建了一个提议 – 批评循环,其中为模型预定义了一些可能的行为(即继续、回溯、反思、终止),并让模型自身选择行为来构建推理树。如果树没有达到最终答案,可以将这个负面信号告知模型,引导它反思和纠正其方法。
    尝试 3:多智能体方法 基于推理树构建长思维存在几个挑战,包括存在许多冗余的无效节点,以及存在不依赖于反思行为的推理步骤,从而引起构建的长思维逻辑不一致。为解决这个问题,团队设计了一个利用多智能体辩论的算法,其中一个智能体充当策略模型,持续推理,而另一个智能体充当评论模型,指示策略模型是否应该继续当前推理或执行回溯等行为。两个智能体进行持续对话,在找到正确答案时自然构建长思维数据集。
    尝试 4:完整的人类思维过程注释 当人类处理推理问题时,他们通常不会不断地向前推理直到解决问题或失败;相反,他们在无法继续时会反思、回溯和重写推理。这种行为与长思维的特征高度一致。因此,可以忠实且全面地记录人类解决推理任务的过程,从而产生高质量的长思维。


    Q4: 如何构建奖励模型?
    使用奖励模型的第一步是定义粒度。团队的目标不仅仅是关注最终结果,而是专门提高 LLMs 在反思、回溯和相关认知过程方面的能力。因此,团队将评估粒度定义在步骤层面。具体来说,团队使用来自 Abel 的微调数据,通过行号使解决方案变得清晰可辨。
    实现奖励模型的过程可以使用开源模型或是调用闭源模型的 api。团队比较了不同奖励模型在 PRM800K 和 MR-GSM8K 子集上的元评估表现。如下表格展示了结果,其中,o1-mini 在不同数据集上表现最佳,证明其是一个良好的奖励模型。

    Q5: 如何构建 on-policy 推理树?

    构建推理树需要一个能够执行单步推理的策略模型。给定一个问题及其相应的最终答案,策略模型从问题作为根节点开始,不断向树中添加新节点。它首先生成 w 个可能的第一步推理步骤作为根节点的子节点。然后,它迭代地进行前向推理,为每个当前节点(如第一步推理)生成 w 个可能的后续推理步骤作为该节点的子节点。这个过程重复进行,直到达到预设的最大深度或所有叶节点达到最终答案。

    策略模型和步骤分段 构建推理树需要清晰定义推理步骤。为此,团队采用 Abel 提出的数据格式,将数学问题解决方案转化为具有清晰步骤的形式,将答案分成多行,每行以行号开始,并包含该行内的推理。因此,使用 Abel 数据集对 DeepSeekMath-7B-Base 进行微调,得到 Abel-DSMath,作为策略模型。在这种特定格式数据上微调的模型可以方便地控制单个推理步骤的生成。

    奖励模型和剪枝 上述提出的树生成算法计算成本高昂。当设置后续推理步骤数目为 3 和深度为 10 时,最后一次迭代需要生成 3 的 10 次方个推理步骤。因此,使用奖励模型来剪除错误的推理步骤,提高操作效率。具体来说,团队采用束搜索,在每次迭代中只选择少量候选项保留到下一轮。根据使用的奖励模型,剪枝实现的细节有所不同。团队尝试了两个奖励模型:math-shepherd 和 o1-mini。

    Math-shepherd 为每个步骤提供一个介于 0 和 1 之间的实数,表示当前步骤正确的概率。在树生成的每次迭代中,对所有推理步骤进行评分,并选择得分最高的前 K 个进入下一次迭代。这将总生成次数进行剪枝。然而,math-shepherd 在评估困难问题的推理步骤时存在困难,需要一个更强大的奖励模型,能够为每个步骤提供高准确度的正确性指示。因此,最终使用 o1-mini 为每个步骤提供奖励,直接指示每个推理步骤是否正确。此时,在树生成的每次迭代中,利用来自 o1-mini 的奖励,选择最多 K 个正确的推理步骤进入下一次迭代。

    Q6: 如何从推理树中推导出长思维?
    一旦构建了推理树,目标就变为探索如何从推理树转换为包含试错过程的长思维。在该团队的框架中,推理树的每个节点都被奖励模型标注,指示该步骤是否正确或错误。具体的合成步骤如下:

    • 从推理树构建捷径 首先从推理树构建捷径,其中只包括正确答案和有效的中间步骤。从代表问题的根节点开始,找出通向正确答案叶节点的路径。如果有多个正确答案节点,则建立多条正确路径。
    • 遍历推理树 为了得到长思维,采用深度优先搜索(DFS)遍历树。这种遍历按 DFS 顺序构建路径,记录从根问题节点到正确答案叶节点的每一步,同时包括任何被标记为错误的节点的推理。DFS 的挑战在于它探索了庞大的搜索空间,产生了大量可能无法得到正确解决方案的试错路径。为了简化这一初始探索,团队还引入了具体的约束来缓解由于遍历路径过长导致的合成数据的复杂性。首先,根据节点是否位于正确路径(即捷径)上来标记树中的所有节点。遍历遵循以下规则: 
    • 正确路径上的节点:DFS 遇到正确路径上的节点时,它可能会探索导致错误结果的子节点,从而模拟试错的过程。一旦这个节点到达叶节点并被确定为错误,算法就会回溯并切换到正确的路径继续遍历。 
    • 不在正确路径上的节点:随机选择一个子节点进行探索,并不产生试错的分支。

     为进一步简化过程,应用了一个额外的约束:正确路径上的每个节点最多允许 K 次试错 —— 一次在错误路径上的试错和一次在正确路径上的探索。 这些约束确保 DFS 遍历专注有意义的试错探索,同时避免过度探索错误路径。在未来的实验中,计划移除或调整这些约束,以研究试错路径长度与最终模型性能之间的关系。

    • 从遍历路径得到长思维 生成遍历路径并将推理附加到错误节点后,通过连接路径中的所有步骤来构建长思维,其中还包含了每个错误步骤的推理。然而,初步实验表明,使用这个形式的长思维数据来训练模型的性能不佳。为解决这个问题,团队尝试使用 GPT-4o 来修改草稿。GPT-4o 在保留所有推理步骤(包括错误步骤、反思和修正)的同时,增强了思维过程的连贯性和流畅性。这种方法确保最终的长思维不仅准确,而且自然流畅,模拟了包含正确和错误步骤的人类问题解决过程。

    Q7: 如何评估我们的尝试方法?

    除了使用特定评估指标在基准测试上测试准确率分数外,人工审查实际案例(输入输出)是评估数据和模型的关键步骤。因此,为了提供一种更直观的方式来评估模型在特定问题上的表现,团队构建了一个可视化数据分析平台。
    具体来说,可视化平台包括合成树及其对应长思维的可视化,以及训练模型的输出。此外,在可视化结果时,支持详细的条件过滤,例如过滤正确或错误回答的问题,或输出是否包含表示反思或犹豫的关键词(如 “wait”)。另外,可视化平台支持不同迭代轮次的合成数据和模型输出之间的比较,这使得团队可以非常直观地验证新一轮的数据或模型是否有效。

    Q8: 如何训练我们的模型?
    团队实验使用预训练语言模型 deepseek-math-7b-base(更多其他模型已经在等待列表中)。训练过程分为两个主要阶段:监督微调(SFT)和直接偏好学习(DPO)。
    第一阶段:监督微调(SFT): 
    SFT 过程包括两个阶段:

    • 初始阶段:在这个初始阶段,团队专注于使用只包含正确中间步骤和最终正确答案的响应来微调模型。在 Abel 数据集和 PRM800K 数据集上微调 Deepseek-math-7b-base。对于 PRM800K 中的每个问题,使用单个正确的逐步解决方案,丢弃不导向最终答案的回复。在这个阶段,对每个数据集进行一个 epoch 的微调,主要目的是让模型熟悉所需的响应格式。
    • 旅程学习:在第二阶段,使用构建的长思维(包含 327 个示例)进一步微调初始阶段的 SFT 模型。这个阶段旨在增强模型发现错误、自我反思、自我修正和执行回溯的能力。通过在合成的包含试错、反思的长思维数据上训练,模型对更长推理链中涉及的复杂性有更深入的理解。为了比较,团队还在从同一推理树生成的相应捷径上 (Shortcut Learning) 微调模型(同样是 327 个),从而更直观的比较旅程学习相比捷径学习所带来的增益。

    第二阶段:直接偏好学习(DPO)
    在这个阶段,使用核采样(top_p = 0.95 和温度 T = 0.7)从 MATH Train 数据集为每个问题生成 20 个回复。这 20 个回复根据最终答案的正确性分类为正面和负面响应。从中,随机选择 5 个正面响应和 5 个负面响应来创建 5 对偏好对。然后,使用这些偏好对和 DPO 损失来训练模型,使其能够从正确和错误答案的比较中学习。


    Q9: 什么是人类和 AI 协同标注的有效策略?
    团队开发了一种人类和 AI 协作的数据标注流程,用于生成基于 MATH 数据集的高质量、长文本推理数据。通过这个流程,我们将短短几行人类标注的解题方案扩展为包含数千个 token 的、符合 “旅程学习” 范式的详细推理过程。在构建流程的过程中,我们发现了下面几种有效的标注技巧:

    • 完整的思维过程:标注者不必详细记录每一个想到的词语,但必须记录每一个尝试、反思、联想和修正的过程。这些发散的认知路径在日常思考中可能并未被表达成文字,甚至没有被显式认知。然而,捕捉这些思维转变以及背后的原因是至关重要的。这种规划和理解认知转换的能力是大语言模型从我们的数据中必须学习的核心技能。
    • 补充解释常识:人类在用语中经常省略一些可以从上下文中推断的信息,比如对前述公式的引用,或是对广为人知的理论的应用。然而,当大语言模型尝试解读人类标注时,这种省略可能导致幻觉。因此,高质量的数据必须包括对常识性知识的明确解释,以防止大模型的误解。

    遵循以上两个关键要素,人类专家即可完成数据标注,这些数据精简但准确,非常利于大模型做进一步增强。下一阶段,通过设计复杂的提示词,我们通过大语言模型实现了数据扩展和增强。我们的提示词包含以下关键点:

    • 数据颗粒度的增强:提示词强调将问题解决过程分解为更细小的步骤。通过将过程拆解成细粒度且易于理解的步骤块,大语言模型能更好地掌握和内化每个概念,确保在每个阶段都有深入的理解。
    • 逐步推理:提示词控制大语言模型需频繁暂停,反思已知信息或提出下一步的操作。这种停顿模仿了学生在思考问题时的自然过程,帮助他们保持参与感和对推理过程的连接感,而不仅仅是被动地遵循指令。
    • 探索者视角:与直接呈现答案不同,大语言模型被鼓励以探索的语气进行推理,即假设自己是第一次思考这个问题。这种方式可以激发某种程度的 “好奇心”,鼓励模型批判性思考,使他们感觉自己是学习过程的一部分,而不是简单地接收信息。

    为什么科学进展报告很重要?
    研究团队表示:传统发论文方无法适应新的科研范式,人工智能技术的快速发展开创了一个新的研究范式时代,其特点是长期的、基于团队的努力,通常持续六个月或更长时间。这种转变虽然有利于突破性创新,但无意中给科学过程带来了新的挑战。长期团队合作的内向性经常导致向更广泛科学界信息流动的减少。此外,这些项目的长期性质往往导致研究人员满足感的延迟,可能在整个研究过程中培养焦虑和动力减弱。另外,大规模团队项目的复杂性使得认可个人贡献变得复杂,可能侵蚀传统的学术激励结构。团队的进展报告方法旨在通过增强透明度、促进实时反馈和认可,以及鼓励对长期研究计划的持续承诺来解决这些新出现的挑战。在这样的背景下,团队认为 ”Scientific Progress Report“ (科研进展报告)是一种比 现在”Scentific Paper“ (科研论文)更有价值的科研产出和成果分享的组织形式。团队科学探索过程的细致记录,尤其在 AI 能力快速发展的背景下,具有深远意义。通过全面记录探索过程,包括成功和失败,团队正在培育一个独特而宝贵的数据集。这份全面的记录对于训练真正理解科学方法的 AI 模型至关重要。o1 的成功强调了 AI 系统不仅要学习结果,还要学习完整的科学探索过程,包括试错的重要性。通过科研进展报告,不仅可以捕捉技术细节,还包括决策理由、灵感来源和思维过程。这些 “人类因素” 对于训练能够进行真实科学发现的 AI 模型至关重要。
    下一步探索
    团队根据的研究时间线和取得的进展,确定了几个未来探索和发展的关键方向:

    • 扩展长思维的合成: 基于在长思维合成方面的成功迭代,团队计划进行第三轮的数据集成。这将涉及处理更复杂和多样的思维模式,可能揭示 o1 能力的新维度。
    • 长思维扩展定律实验: 这个研究流程旨在理解模型的性能和能力如何随着数据、模型大小和计算资源的增加而扩展。对这个规律的掌握对优化方法和挖掘超级 AI 系统背后的基本原理至关重要。
    • 细粒度、以思考为中心的评估: 计划开发和实施更复杂的评估方法,专注于细粒度、以思考为中心的评估。这种方法将让我们更准确地衡量生成的长思维的质量和连贯性,为模型推理能力提供更深入的洞察。
    • 人机协作以提高思考质量: 未来计划的一个关键部分是探索和增强人机协作,以产生更贴近人类思维的高质量思考数据。这涉及开发利用人类智能和 AI 能力的共同优势,促进 AI 能力的突破。
    • 持续改进奖励和批评模型: 基于过程级奖励模型和评论模型设置,旨在进一步完善这些系统。这个持续的过程将涉及迭代改进,以更好地提供细粒度的监督信号。
    • 推理树的合成优化: 计划探索从推理树中推导和集成长思维更复杂、有效的方法。这将涉及探索更加先进高效的算法来遍历并利用复杂结构中的信息。
    • 扩展训练方法: 未来计划包括进一步实验和完善训练流程。这包括增加预训练阶段、迭代训练、强化学习、偏好学习和 DPO(直接偏好优化)。
    • 持续的透明度和资源共享: 将继续分享在整个科研旅程中开发的资源、观察到的结论和工具。这种持续的做法旨在促进更广泛的 AI 研究社区的协作和加速进展。
    • 探索多代理方法: 基于在多代理系统方面的初步尝试,计划深入研究这一领域,发现建模复杂推理和决策过程潜在的新方法。
    • 完善分析工具: 旨在进一步开发和增强分析工具。这些工具对解释模型输出、跟踪进展和指导未来研究方向至关重要。

    通过追求这些途径,不仅推进我们对 o1 能力的理解和复制,还要推动 AI 研究方法的边界。

    核桃计划:

    团队借本项目正式引出 “核桃计划” (https://gair-nlp.github.io/walnut-plan),团队成员表示:“对 o1 技术路线的探索及复现工作,仅仅是我们核桃计划的一部分。核桃计划旨在成为人工智能复杂推理和深度思考能力研究的开放先锋,致力于推动 AI 从简单的信息处理工具演变为具备 “牛顿” 和 “爱因斯坦” 级别深度思考能力的智能系统。我们将着眼于更长远的研究,最终的伟大愿景是让未来可以呈现 AI 驱动的科研范式,即 AI 完全具备参与人类科研的水准,从而更好地服务人类、改变世界。”

    OpenMusic:音乐生成更高质量,更有乐感

    中科大&科大讯飞重磅开源OpenMusic

    文章链接:https://arxiv.org/pdf/2405.15863
    代码链接:https://github.com/ivcylc/qa-mdt
    Huggingface链接:https://huggingface.co/spaces/jadechoghari/OpenMusic
    Demo链接:https://qa-mdt.github.io/  (chatgpt * 30, musiccaps * 30)

    • 提出了一种质量感知训练范式,使模型在训练过程中能够感知数据集的质量,从而在音乐性(美学角度)和音频质量方面实现卓越的音乐生成效果。
    • 创新性地将masked扩散Transformer引入到音乐信号中,展示了其在建模音乐潜在空间上的独特效果,以及其在质量控制感知方面的卓越能力,从而进一步提升了生成音乐的质量和音乐性。
    • 解决了大型音乐数据集中文本与音频低相关性的问题,有效提高了文本对齐度和生成的多样性。

    背景

    近年来,基于扩散的文本到音乐(TTM)生成方法逐渐受到重视,提供了一种创新的方法,将文本描述合成音乐内容。要在这一生成过程中实现高准确性和多样性,必须依赖大量高质量的数据,包括高保真音频波形和详细的文本描述,但这些通常仅占现有数据集中的一小部分。在开源数据集中,低质量音乐波形、标签错误、弱标签和无标签数据等问题显著阻碍了音乐生成模型的发展。为了解决这些挑战,今天和大家分享一种全新的高质量音乐生成范式,该范式结合了质量感知训练策略,使生成模型能够在训练过程中辨别输入音乐波形的质量。利用音乐信号的独特特性,首先针对TTM任务调整并实现了一个掩码扩散Transformer(MDT)模型,展现出其在质量控制和音乐性增强方面的独特能力。此外,还通过字幕优化数据处理方法解决了TTM中低质量字幕的问题。实验结果表明,在MusicCaps和Song-Describer数据集上取得了当前最先进的(SOTA)性能。

    当前音乐生成(音效生成)领域的问题为质量低,具体来说分为三个方面:

    • 大部分的开源数据集音质低(FMA,AudioSet,MSD),旋律杂乱
    • 音乐性(美学角度)差
    • 文本对齐度低,大多数的音频处于少标签,弱标签,错标签。其中, 第1点可以由下图蓝色分布CLAP分数表征,2,3点可以由数据集的平均MOS分布表征(颜色由 μ +α * σ 分割)
    图 1:大规模开源音乐数据库 AudioSet 和 FMA 的 CLAP 相似性和伪 MOS 的分布曲线,其中较暗的区域代表较高的文本音频对齐或音频质量。

    创新方法及思路

    质量信息注入

    解决: 引入质量感知训练策略。采用主观数据集中的MOS分训练出的质量评分模型,在训练过程中注入(伪MOS分)音频质量信息。

    两种注入方法:

    • 利用 text encoder 对分级后的 low quality, medium quality, high quality 质量文本进行cross attn嵌入 【粗粒度,适配unet架构和transformer类架构】
    • 参考U-ViT内 时间信息和label信息的融入方式,以量化(阈值由 决定)后转换为quality embedding, 以token 形式进行控制注入,【细粒度,并且只适配transformer类架构】

    结论:质量感知策略允许了在推理阶段以高质量文本和质量token进行引导,从而生成显著高于训练集平均质量的音频。

    以类似解耦的方式在训练中感知音频的质量(类似TTS中分离出音色训练),从而更好地促进了模型的训练(大幅降低FAD,KL,并提升IS,REL,CLAP等指标)

    我们还发现,粗粒度文本控制和细粒度token控制相结合,更有助于模型训练中解耦,感知,并控制更高质量音频的生成,从而解决训练数据集影响的问题

    质量感知型 masked扩散Transformer

    解决:从音乐性建模角度,我们发现 U-ViT/DiT 类架构对频谱隐空间建模也具有图像上表达的scale ability,并能更好建模谐波,音色等方面(反应在主观评分)

    优化

    • 对频谱切片而言,此类结构的收敛速度慢。消融数据集中,20w步时依然不能很好控制收敛,推测来源于时域/频域相关性弱。故在预训练阶段加入掩码,加速训练速度和频谱关联性。微调阶段以高质量数据进一步强化模型(5W步就有收敛迹象)。
    • 相比于U-Net,transformer based架构对text encoder的质量信息感知能力增强,并且U-ViT 式 token 质量融入策略显著有效进一步提升质量并降低客观指标
    • 图像中切块未考虑 overlap,探究了overlap策略在合成中的作用(大幅降低FAD,但在主观听感上有trade off)

    优化音乐标注描述

    解决:首次在音乐生成领域使用预训练标注模型(LP-Musiccaps)进行大规模标注优化

    • 考虑到标注模型的不充分训练导致错标,以CLAP文本-音频分数+阈值筛选低分数据
    • 考虑到原始标注中有些词(例如说American,R&B等标注器不一定能标注出的词)。使用CLAP分数过滤出生成的与原始的文本相似度低低数据,利用语言模型 融合原始标注中有用信息

    实验

    总体对比与,对比U-net架构和transformer based架构

    对比overlap策略和patch size:

    质量感知消融

    此图证明了相比于无质量感知,大幅提升了生成质量和客观指标。并且,MDT(我们的架构)比 U-Net 在文本质量控制感知上的独特优势(生成质量更高,总体客观指标更好)

    左图展示了 token as control 的准确感知控制生成能力,生成的高质量数据(黄色区域)显著高于训练集MOS分。

    右图展示了文本质量控制和token质量控制的结合效果与单纯token和文本控制的对比。

    主观评测结果

    • PO:产品运营
    • PMP:专业音乐制作人
    • VE:视频编辑人
    • BEGINNERS:不懂音乐的小白

    各个人的评分下,均有优势。

    结论与展望

    本研究识别出大规模音频质量不均和文本标注未对齐所带来的挑战,这些挑战阻碍了基于扩散的文本到音乐(TTM)生成的发展。通过采用基于p-MOS的新型质量感知学习方法,以及以masked扩散Transformer作为扩散过程的主干,在音乐生成中实现了更高的生成质量和音乐性。

    为什么BERT输入的最大长度要限制为512?

    来自 知乎

    总的来说是为了计算效率,所以把长度限制在了512。

    To speed up pretraing in our experiments, we pre-train the model with sequence length of 128 for 90% of the steps. Then, we train the rest 10% of the steps of sequence of 512 to learn the positional embeddings.

    当然,最最本质的原因还是在于BERT中的Positional Embedding和Transformer中的Positional Embedding实现方式不一样,后者是通过公式计算得到的,而前者本质上则是一个可学习的参数,也就是说每个位置所对应的向量就类似于Token Embedding中每个词对应的词向量。

    class PositionalEmbedding(nn.Module):
        """
        位置编码。
          *** 注意: Bert中的位置编码完全不同于Transformer中的位置编码,
                    前者本质上也是一个普通的Embedding层,而后者是通过公式计算得到,
                    而这也是为什么Bert只能接受长度为512字符的原因,因为位置编码的最大size为512 ***
          # Since the position embedding table is a learned variable, we create it
          # using a (long) sequence length `max_position_embeddings`. The actual
          # sequence length might be shorter than this, for faster training of
          # tasks that do not have long sequences.
                                                     ————————  GoogleResearch
        https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/modeling.py
        """
    
        def __init__(self, hidden_size, max_position_embeddings=512, initializer_range=0.02):
            super(PositionalEmbedding, self).__init__()
            self.embedding = nn.Embedding(max_position_embeddings, hidden_size)
            self._reset_parameters(initializer_range)

    因此,除非是你自己从零开始训练一个模型,否则如果你使用的是谷歌开源的预训练模型,那么这个词表的大小将会被限制在512。 当然,你依旧可以突破这个限制,那就是自己在从新初始化Positional Embedding中的向量,然后对前512个进行替换,后面的就自己在对于语料上训练微调。

    Triton Inference Server推理引擎

    https://github.com/triton-inference-server/server

    Triton 从入门到精通

    https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/index.html

    深度学习部署神器——triton-inference-server入门教程指北

    triton-inference-server的backend(一)——关于推理框架的一些讨论

    Triton Inference Server是Nvida开源的机器学习推理引擎(可以理解为同TF Serving对等的产品),其提供了多种开箱即用的功能帮助我们快速落地AI模型到生产环境以提供业务使用。当我们团队人手资源受限或开发时间不足的情况下,没有能力自研一套机器学习推理引擎,Triton是一个非常好的选择。

    从Triton的不断完善的进程来,其已经逐渐的具有了一个深度学习推理引擎应该所具有的的全部功能(当然还是有一些不够完善的地方)。

    Triton Inference Server核心的几个功能:

    • 多框架支持:Triton支持了几乎所有主流的机器学习框架,例如Tensorflow、TensorRT、Pytorch、Python、ONNX等;同时也可以custom backend的方式来扩展解码引擎。
    • 高性能:Triton提供了dynamic batching、concurrent execution、optimal model configuration、model ensemble、dali model 等策略来提升在线推理的性能;同时也提供了perf analyze和model analyze工具来辅助我们进行性能调优。
    • MLOps:Triton提供了Prometheus exporter、模型在线更新、http server 、grpc server等多种在线服务策略以满足用户生产场景多样化的部署和运维需求

    triton可以充当服务框架去部署你的深度学习模型,其他用户可以通过http或者grpc去请求你搭建的服务,相当于你用flask搭了个服务供别人请求,当然相比flask的性能高很多了。triton也可以摘出C-API充当多线程推理服务框架,去除http和grpc部分,适合本地部署多模型,比如你有很多模型要部署,然后分时段调用,或者有pipeline,有了triton就省去你处理显存、内存和线程的麻烦。

    借用官方的图,triton的使用场景结构如下:

    涉及到运维部分,我也不是很懂,抛去K8S后,结构清爽了些:

    triton的一些优点

    通过上述的两个结构图,可以大概知道triton的一些功能和特点

    • 支持HTTP/GRPC
    • 支持多backend,TensorRT、libtorch、onnx、paddle、tvm啥的都支持,也可以自己custom,所以理论上所有backend都可以支持
    • 单GPU、多GPU都可以支持,CPU也支持
    • 模型可以在CPU层面并行执行
    • 很多基本的服务框架的功能都有,模型管理比如热加载、模型版本切换、动态batch,类似于之前的tensorflow server
    • 开源,可以自定义修改,很多问题可以直接issue,官方回复及时
    • NVIDIA官方出品,对NVIDIA系列GPU比较友好,也是大厂购买NVIDIA云服务器推荐使用的框架
    • 很多公司都在用triton,真的很多,不管是互联网大厂还是NVIDIA的竞品都在用,用户多代表啥不用我多说了吧
    大厂都在用triton

    如何学习triton

    两年前开始学习的时候,官方资料比较匮乏, 只能通过看源码来熟悉triton的使用方式,所幸知乎上有个关于TensorRT serving不错的教程[2],跟着看了几篇大致了解了triton的框架结构。那会triton叫做TensorRT serving,专门针对TensorRT设计的服务器框架,后来才变为triton,支持其他推理后端的。

    现在triton的教程比较多了,官方的docs写着比较详细,还有issue中各种用例可以参考,B站上也有视频教程[3],比两年前的生态要好了不少。

    当然,最重要的,还是上手使用,然后看源码, 然后客制化。

    源码学习

    从triton的源码中可以学到:

    • C++各种高级语法
    • 设计模式
    • 不同backend(libtorch、TensorRT、onnxruntime等)如何正确创建推理端,如何多线程推理
    • C++多线程编程/互斥/队列
    • API接口暴露/SDK设计
    • CMAKE高级用法

    等等等等,不列举了,对于程序员来说,好的源码就是好的学习资料。当然,也可以看老潘的文章哈。

    triton系列教程计划

    triton相关系列也会写一些文章,目前大概规划是这些:

    • 什么是triton以及triton入门、triton编译、triton运行
    • triton管理模型、调度模型的方式
    • triton的backend介绍、自定义backend
    • 自定义客户端,python和c++
    • 高级特性、优先级、rate limiter等等

    编译和安装

    一般来说,如果想快速使用triton,直接使用官方的镜像[4]最快。但是官方镜像有个尴尬点,那就是编译好的镜像需要的环境一般都是最新的,和你的不一定一致

    比如22.09版本的镜像需要的显卡驱动为520及以上,如果想满足自己的显卡驱动,就需要自行编译了。官方也提供了使用镜像的快速使用方法

    # 第一步,创建 model repository 
    git clone -b r22.09 https://github.com/triton-inference-server/server.git
    cd server/docs/examples
    ./fetch_models.sh

    # 第二步,从 NGC Triton container 中拉取最新的镜像并启动
    docker run --gpus=1 --rm --net=host -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:22.09-py3 tritonserver --model-repository=/models

    # 第三步,发送
    # In a separate console, launch the image_client example from the NGC Triton SDK container
    docker run -it --rm --net=host nvcr.io/nvidia/tritonserver:22.09-py3-sdk
    /workspace/install/bin/image_client -m densenet_onnx -c 3 -s INCEPTION /workspace/images/mug.jpg

    # Inference should return the following
    Image '/workspace/images/mug.jpg':
        15.346230 (504) = COFFEE MUG
        13.224326 (968) = CUP
        10.422965 (505) = COFFEEPOT

    triton官方仓库

    两年前的triton只有一个大仓库,tensorrt_backend也默认在triton主仓库中,但是现在tensorrt_backend被拆分出来了,很显然triton除了支持tensorrt还支持很多其他的后端,我们可以自定义使用很多后端。

    现在是目前的triton包含的一些仓库:

    • [server[5]] triton服务外层框架,包含了http收发请求,服务内存分配等一些功能代码
    • [core[6]] triton主框架,如果处理请求、后端管理、模型调度啥的全在这里
    • [common[7]] 通用工具,没啥好说的,打日志的代码在这里
    • [backend[8]] backend后端框架代码,存放了一些后端通用父类,自定义后端可以集成这些类仿写新的后端
    • [third_party[9]] triton使用的第三方库的汇总,主要是cmake里头会包含
    • [tensorrt_backend[10]] tensorrt后端代码
    • [pytorch_backend[11]] libtorch后端代码

    最开始的时候,server、core、common、backend这些代码仓库都是合在一起的,后来都拆分出来了,增加了triton的灵活性。

    比如,上述的core仓库可以单独暴露出cAPI作为动态链接库供其他程序调用,去掉http、grpc的外层请求接口,直接一步到位调用。

    一般来说,我们都是从最主要的server开始编,编译的时候会链接core、common、backend中的代码,其他自定义backend(比如tensorrt_backend)在编译的时候也需要带上common、core、backend这三个仓库,这些关系我们可以从相应的CMakeList中找到。

    自行编译

    如果想要研究源码,修改源码实现客制化,那么自行编译是必须的。

    triton的编译和安装其实很简单,唯一的难点就是需要加速,因为triton在编译的时候会clone很多第三方库,第三方库也会克隆它们需要的第三方库,这些库当然都是国外的,所以有个好的网络环境很重要。

    比如在编译triton的时候需要下载grpc这个库,grpc又依赖很多第三方其他库,网络不好的话,会经常遇到下面的问题:

    Failed to recurse into submodule path 'third_party/bloaty'
    CMake Error at /tmp/tritonbuild/tritonserver/build/_deps/repo-third-party-build/grpc-repo/tmp/grpc-repo-gitclone.cmake:52 (message):
      Failed to update submodules in:
      '/tmp/tritonbuild/tritonserver/build/_deps/repo-third-party-build/grpc-repo/src/grpc'


    make[3]: *** [_deps/repo-third-party-build/CMakeFiles/grpc-repo.dir/build.make:99: _deps/repo-third-party-build/grpc-repo/src/grpc-repo-stamp/grpc-repo-download] Error 1
    make[3]: Leaving directory '/tmp/tritonbuild/tritonserver/build'
    make[2]: *** [CMakeFiles/Makefile2:590: _deps/repo-third-party-build/CMakeFiles/grpc-repo.dir/all] Error 2
    make[2]: Leaving directory '/tmp/tritonbuild/tritonserver/build'
    make[1]: *** [CMakeFiles/Makefile2:145: CMakeFiles/server.dir/rule] Error 2
    make[1]: Leaving directory '/tmp/tritonbuild/tritonserver/build'

    开加速是最好的办法,不管是UI还是命令行,都有相应的软件可以用,比如clash。

    如果你的服务器实在是开不了加速,也有其他办法,那就是将triton库中大部分重量级库的git地址全换为国内的

    怎么替换,我是在gitee中,同步github上的仓库,比如triton的core仓库,同步过来,就可以使用国内的地址了。

    当然也需要将这些库的submodule中的库也修改为国内源,比如grpc这个库依赖很多第三方库,克隆的时候,这是要一个一个下载的:

    改起来稍微麻烦,还需要注意,要改特定commit分支的git地址:

    ver/build/_deps/repo-third-party-build/grpc-repo/src/grpc/third_part目录然后手动git clone xxx,然后执行一下git submodule init / git submodule update下就可以带进去。

    示例:

    root@64da25af2629:/tmp/tritonbuild/tritonserver/build/_deps/repo-third-party-build/grpc-repo/src/grpc# git submodule init
    root@64da25af2629:/tmp/tritonbuild/tritonserver/build/_deps/repo-third-party-build/grpc-repo/src/grpc# git submodule update
    Submodule path 'third_party/googletest': checked out 'c9ccac7cb7345901884aabf5d1a786cfa6e2f397'

    太麻烦了,不过确实为一种办法呃呃。

    还有一点,triton每次build都会clone,是因为其用了cmake中的ExternalProject_Add指令,假如我们已经有下载好的grpc,那么直接替换到server/build/_deps/repo-third-party-build/grpc-repo/src中然后将/data/oldpan/software/server/build/_deps/repo-third-party-src/CMakeLists.txt

    注释掉git下载部分,修改自己本地的就行,就不需要每次再clone一遍了。

    #
    # Get the protobuf and grpc source used for the GRPC endpoint. We must
    # use v1.25.0 because later GRPC has significant performance
    # regressions (e.g. resnet50 bs128).
    #
    ExternalProject_Add(grpc-repo
      PREFIX grpc-repo
      # GIT_REPOSITORY "https://gitee.com/Oldpann/grpc.git"
      # GIT_TAG "v1.25.x"
      SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/grpc-repo/src/grpc"
      CONFIGURE_COMMAND ""
      BUILD_COMMAND ""
      INSTALL_COMMAND ""
      TEST_COMMAND ""
      PATCH_COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/tools/install_src.py --src <SOURCE_DIR> ${INSTALL_SRC_DEST_ARG} --dest-basename=grpc_1.25.0
    )

    说了这么多,总之,最好的办法当然还是开科学,全局一下就OK,省去那么多麻烦事儿。

    搞定好网络问题,编译triton就很简单了!

    git clone --recursive https://github.com/triton-inference-server/server.git
    cd server
    python build.py  --enable-logging --enable-stats --enable-tracing --enable-gpu  --endpoint=http --repo-tag=common:r22.06 --repo-tag=core:r22.06 --repo-tag=backend:r22.06 --repo-tag=thirdparty:r22.06 --backend=ensemble --backend=tensorrt

    在克隆好的server的目录下执行以上命令(下面是我的设置,我们可以个根据自己的需求进行修改)就可以了。

    执行这个命令后triton就会构建docker在docker中编译,最终会创建3个镜像:

    • tritonserver:latest
    • tritonserver_buildbase:latest
    • tritonserver_cibase:latest

    最终编译好的tritonserver_buildbase:latest镜像,我们可以在其中开发,因为环境都帮忙配好了,只需要再执行编译命令,就可以编译了,我们也可以自定义源码进行个性功能的开发。

    在镜像中开发

    需要注意,在编译的时候需要pull官方默认的镜像,而这个镜像是有显卡驱动限制的,比如r22.06需要显卡驱动版本为470。

    同志们看看自己的显卡驱动,别下了不能用hhh

    可以通过triton镜像历史[12]查看镜像版本要求:

    接上,我们不是编译好了triton镜像,直接进去就可以开发了:

    docker run -v/home/oldpan/code:/code -v/home/oldpan/software:/software  -d tritonserver_buildbase:latest /usr/bin/sh -c "while true; do echo hello world; sleep 20;done"

    在docker中修改triton的源码,继续执行以下命令就可以编译,和之前的区别就是加了 --no-container-build参数。

    python build.py  --enable-logging --enable-stats --enable-tracing --enable-gpu  --endpoint=http --repo-tag=common:r22.06 --repo-tag=core:r22.06 --repo-tag=backend:r22.06 --repo-tag=thirdparty:r22.06 --backend=ensemble --no-container-build --build-dir=./build

    我们如果想编译debug版本的triton,可以在命令中添加:--build-type=Debug

    另外,原始triton镜像中已经有tensorrt,如果想换版本,可以删除原始docker中的旧的tensorrt,自行安装新的tensorrt即可:

    • https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html[13]

    说下运行流程吧!

    讲了这么多铺垫,接下来简单说下运行流程。

    这里通过代码简单梳理下triton运行的整体流程,之后的具体细节,放到接下来的篇章讲解

    首先一开始,main函数在servers/main.cc下,triton在启动的时候会执行以下函数:

    // src/servers/main.cc 经过简化
    int
    main(int argc, char** argv)
    {
      // 解析参数
      TRITONSERVER_ServerOptions* server_options = nullptr;
      if (!Parse(&server_options, argc, argv)) {
        exit(1);
      }
      ...
      // 这里创建server
      TRITONSERVER_Server* server_ptr = nullptr;
      FAIL_IF_ERR(
          TRITONSERVER_ServerNew(&server_ptr, server_options), "creating server"); // 这里创建server
      FAIL_IF_ERR(
          TRITONSERVER_ServerOptionsDelete(server_options),
          "deleting server options");
      std::shared_ptr<TRITONSERVER_Server> server(
          server_ptr, TRITONSERVER_ServerDelete);
      ...
      // 启动HTTP, GRPC, 以及性能统计的端口
      if (!StartEndpoints(server, trace_manager, shm_manager)) {
        exit(1);
      }
      // Trap SIGINT and SIGTERM to allow server to exit gracefully
      signal(SIGINT, SignalHandler);
      signal(SIGTERM, SignalHandler);
      // 等待kill信号区关闭triton 
      while (!exiting_) {
       ...
          // 做一些监控模型仓库是否变动的操作
      }
      // 优雅地关闭triton
      TRITONSERVER_Error* stop_err = TRITONSERVER_ServerStop(server_ptr);
      // 如果无法优雅地关掉,旧直接exit即可
      if (stop_err != nullptr) {
        LOG_TRITONSERVER_ERROR(stop_err, "failed to stop server");
        exit(1);
      }
      // 停止监控http、grpc
      StopEndpoints();
      ...
      return 0;
    }

    TRITONSERVER_ServerNew这个函数中,会:

    • new一个triton类InferenceServer对象
    • 根据参数设置配置一下,执行一堆Set函数
    • 配置好参数后,Init服务,这里初始化服务的状态,校验参数
    • 创建各种模块,经常使用的有后端管理TritonBackendManager以及模型仓库管理ModelRepositoryManager
    • 再进行一些检查、配置一些状态

    在启动过程中最重要的是模型仓库,运行triton当然你要有模型,要不然你开它干嘛?

    这里我使用的模型仓库目录结构如下(是一个识别姿态的hrnet,hrnet官方有很多预训练模型,转tensorrt也很简单):

    debug目录下有一个模型文件夹叫做hrnet-pose-estimate-debug的模型文件夹,这个文件夹地址(/path/to/hrnet-pose-estimate-debug)需要传给triton启动命令行,文件夹内的四个子模型文件夹,会被triton检测到并且一一加载。

    需要注意的是,除了hrnet_pose_estimate这个其余三个在目录的1子目录下有个so或者model.plan,这代表hrnet-trt-staticimage_preprocess还有pose_postprocess都属于model,使用了backend,backend会在各自的config中指明:

    name: "hrnet-trt-static"
    backend: "tensorrt"

    因为hrnet-trt-static是tensorrt的模型,所以backend设置为tensorrt,model.plan就是tensorrt的engine。其backend的so文件我放到了其他位置(放到和model.plan同目录也是可以的),而另外两个预处理和后处理的backend就放到了模型仓库中,也就是libtorch_image_preprocess.solibtriton_pose_postprocess,包含了你的backend代码,封装成so供triton调用

    关于backend、model以及modelinstanc的关系,说实话稍微复杂点,各自有完整的生命周期,这个嘛,之后文章说,感兴趣的也可以提前看官方文档的介绍:

    • https://github.com/triton-inference-server/backend[14]

    然后我们就启动triton吧!

    # 执行以下函数,模型目录通过 --model-repository 指定     tensorrt的backend通过  --backend-directory 指定
    ./tritonserver --model-repository=/path/to/hrnet-pose-estimate-debug --backend-directory=/workspace/backends/tensorrt_backend/ 

    模型加载成功之后会输出:

    ...
    I1016 08:25:37.952055 51771 server.cc:587] 
    +------------------+----------------------------------------------------------------+----------------------------------------------------------------+
    | Backend          | Path                                                           | Config                                                         |
    +------------------+----------------------------------------------------------------+----------------------------------------------------------------+
    | image_preprocess | /workspace/triton-models/debug/hrnet-pose-estimate-debug/image | {"cmdline":{"auto-complete-config":"false","min-compute-capabi |
    |                  | _preprocess/1/libtriton_image_preprocess.so                    | lity":"6.000000","backend-directory":"/workspace/backends/tens |
    |                  |                                                                | orrt_backend/","default-max-batch-size":"4"}}     |
    |                  |                                                                |                                                                |
    | pose_postprocess | /workspace/triton-models/debug/hrnet-pose-estimate-debug/pose_ | {"cmdline":{"auto-complete-config":"false","min-compute-capabi |
    |                  | postprocess/1/libtriton_pose_postprocess.so                    | lity":"6.000000","backend-directory":"/workspace/backends/tens |
    |                  |                                                                | orrt_backend/","default-max-batch-size":"4"}}     |
    |                  |                                                                |                                                                |
    | tensorrt         | /workspace/backends/tensorrt_backend/li              | {"cmdline":{"auto-complete-config":"false","min-compute-capabi |
    |                  | btriton_tensorrt.so                                            | lity":"6.000000","backend-directory":"/workspace/backends/tens |
    |                  |                                                                | orrt_backend/","default-max-batch-size":"4"}}     |
    |                  |                                                                |                                                                |
    +------------------+----------------------------------------------------------------+----------------------------------------------------------------+

    I1016 08:25:37.952252 51771 server.cc:630] 
    +---------------------+---------+--------+
    | Model               | Version | Status |
    +---------------------+---------+--------+
    | hrnet-trt-static    | 1       | READY  |
    | hrnet_pose_estimate | 1       | READY  |
    | image_preprocess    | 1       | READY  |
    | pose_postprocess    | 1       | READY  |
    +---------------------+---------+--------+

    I1016 08:25:38.051742 51771 metrics.cc:650] Collecting metrics for GPU 0: NVIDIA GeForce RTX 3080
    I1016 08:25:38.055197 51771 tritonserver.cc:2159] 
    +----------------------------------+------------------------------------------------------------------------------------------------------------------+
    | Option                           | Value                                                                                                            |
    +----------------------------------+------------------------------------------------------------------------------------------------------------------+
    | server_id                        | triton                                                                                                           |
    | server_version                   | 2.23.0                                                                                                           |
    | server_extensions                | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration |
    |                                  |  system_shared_memory cuda_shared_memory binary_tensor_data statistics trace                                     |
    | model_repository_path[0]         | /workspace/triton-models/debug/hrnet-pose-estimate-debug                                                         |
    | model_control_mode               | MODE_NONE                                                                                                        |
    | strict_model_config              | 1                                                                                                                |
    | rate_limit                       | OFF                                                                                                              |
    | pinned_memory_pool_byte_size     | 268435456                                                                                                        |
    | cuda_memory_pool_byte_size{0}    | 300021772                                                                                                        |
    | response_cache_byte_size         | 0                                                                                                                |
    | min_supported_compute_capability | 6.0                                                                                                              |
    | strict_readiness                 | 1                                                                                                                |
    | exit_timeout                     | 30                                                                                                               |
    +----------------------------------+------------------------------------------------------------------------------------------------------------------+

    I1016 08:25:38.055627 51771 http_server.cc:3303] Started HTTPService at 0.0.0.0:8000
    I1016 08:25:38.097213 51771 http_server.cc:178] Started Metrics Service at 0.0.0.0:8001

    加载好之后,我们开启了http端口,端口号为8000,另一个是metric接口,端口号8001

    此时可以使用http请求试一下。

    简单请求

    请求的话有http和grpc协议,我对http协议熟悉些,所以就搞http吧。

    官方也提供了客户端,C++和python的都可以有,可以直接使用官方的,也可以根据官方提供的http协议构造自己的客户端,只要会构造body,一切都很简单。

    请求协议可以参考官方:

    • https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md[15]

    这里我们用python简单构造一个body

    # 构造triton的输入body
    json_buf = b'{\"inputs\":[{\"name\":\"INPUT\",\"datatype\":\"BYTES\",\"shape\":[1],\"parameters\":{\"binary_data_size\":' + \
            bytes(str(len(data)), encoding = "utf8") + b'}}],\"outputs\":[{\"name\":\"RESULT\",\"parameters\":{\"binary_data\":true}}]}'
    push_data = json_buf + data

    print("Inference-Header-Content-Length ",str(len(json_buf)), " Content-Length ",str(len(data) + len(json_buf)))
    # 构造triton-header
    header = {"Content-Type": "application/octet-stream", "Accept": "*/*",
              "Inference-Header-Content-Length":str(len(json_buf)),
              "Content-Length":str(len(data) + len(json_buf))}

    server_url = "127.0.0.1:8000"
    model_name = "hrnet_pose_estimate"

    # 请求
    response = post('http://' + server_url + '/v2/models/' + model_name + '/infer', data=push_data, headers=header)

    就可以发送请求,结果也会传回response里。我们也可以使用curl命令,直接传递构造好的body(这个body将上述的push_data写到本地即可):

    [oldpan@have-fun client]$ curl -v --max-time 1 --request POST 'http://192.168.1.102:9006/v2/models/hrnet_pose_estimate/infer' --header 'Inference-Header-Content-Length: 230' --header 'Content-Type: application/octet-stream' --data-binary '@data.txt' --output temp_res
    Note: Unnecessary use of -X or --request, POST is already inferred.
    *   Trying 192.168.1.102:9006...
      % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                     Dload  Upload   Total   Spent    Left  Speed
      0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0* Connected to 172.29.210.105 (172.29.210.105) port 9006 (#0)
    > POST /v2/models/aocr_cnprint_trt8p/infer HTTP/1.1
    > Host: 192.168.1.102:9006
    > User-Agent: curl/7.71.1
    > Accept: */*
    > Inference-Header-Content-Length: 230
    > Content-Type: application/octet-stream
    > Content-Length: 1573102
    > Expect: 100-continue

    * Mark bundle as not supporting multiuse
    < HTTP/1.1 100 Continue
    } [56480 bytes data]
    * We are completely uploaded and fine
    * Mark bundle as not supporting multiuse
    < HTTP/1.1 200 OK
    < Content-Type: application/json
    < Inference-Header-Content-Length: 394
    < Content-Length: 6794

    { [6794 bytes data]
    100 1542k  100  6794  100 1536k   127k  28.8M --:--:-- --:--:-- --:--:-- 28.9M

    结果就不发了,验证没啥问题。

    关于如何使用curl直接请求triton,有一些相关链接可以参考:

    • https://github.com/triton-inference-server/server/issues/2563[16]
    • https://github.com/triton-inference-server/server/issues/1822[17]
    参考资料
    https://www.bilibili.com/video/BV1KS4y1v7zd/?spm_id_from=333.788&vd_source=eec038509607175d58cdfe2e824e8ba2[18]
    https://github.com/triton-inference-server/server/releases[19]
    https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/index.html[20]
    https://zhuanlan.zhihu.com/p/462817679[21]
    参考资料
    [1]
    triton: https://github.com/openai/triton
    
    [2]
    教程: https://www.zhihu.com/people/pwlazy/posts
    
    [3]
    视频教程: https://www.bilibili.com/video/BV1KS4y1v7zd/?spm_id_from=333.337.search-card.all.click&vd_source=eec038509607175d58cdfe2e824e8ba2
    
    [4]
    官方的镜像: https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/index.html
    
    [5]
    [server: https://github.com/triton-inference-server/server
    
    [6]
    [core: https://github.com/triton-inference-server/core
    
    [7]
    [common: https://github.com/triton-inference-server/common
    
    [8]
    [backend: https://github.com/triton-inference-server/backend
    
    [9]
    [third_party: https://github.com/triton-inference-server/third_party.git
    
    [10]
    [tensorrt_backend: https://github.com/triton-inference-server/tensorrt_backend
    
    [11]
    [pytorch_backend: https://github.com/triton-inference-server/pytorch_backend
    
    [12]
    triton镜像历史: https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/rel_22-06.html
    
    [13]
    https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html: https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html
    
    [14]
    https://github.com/triton-inference-server/backend: https://github.com/triton-inference-server/backend
    
    [15]
    https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md: https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md
    
    [16]
    https://github.com/triton-inference-server/server/issues/2563: https://github.com/triton-inference-server/server/issues/2563
    
    [17]
    https://github.com/triton-inference-server/server/issues/1822: https://github.com/triton-inference-server/server/issues/1822
    
    [18]
    https://www.bilibili.com/video/BV1KS4y1v7zd/?spm_id_from=333.788&vd_source=eec038509607175d58cdfe2e824e8ba2: https://www.bilibili.com/video/BV1KS4y1v7zd/?spm_id_from=333.788&vd_source=eec038509607175d58cdfe2e824e8ba2
    
    [19]
    https://github.com/triton-inference-server/server/releases: https://github.com/triton-inference-server/server/releases
    
    [20]
    https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/index.html: https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/index.html
    
    [21]
    https://zhuanlan.zhihu.com/p/462817679: https://zhuanlan.zhihu.com/p/462817679