MCPcopy
hub / github.com/thu-coai/CDial-GPT

github.com/thu-coai/CDial-GPT @main sqlite

repository ↗ · DeepWiki ↗
60 symbols 206 edges 10 files 9 documented · 15%
README

CDial-GPT

  • 本项目提供了一个大规模中文对话数据集,并提供了在此数据集上的中文对话预训练模型(中文GPT模型),更多信息可参考我们的论文

  • 本项目代码修改自TransferTransfo,使用了HuggingFace Pytorch版的Transformers库, 可用于预训练与微调。

目录

News

  • 2022-06-09: LCCC数据集现在可以通过huggingface的datasets库加载:
from datasets import load_dataset

dataset = load_dataset("lccc", "base")  # or "large"
  • 2022-04-26: 一个新的多模态对话数据集MMChat,欢迎大家使用。
  • 2021-02-28: 一个对话数据清洗框架,欢迎大家提bug和加速优化算法,以及新的清洗功能等等。
  • 2021-01-09: 实验室出版新书《现代自然语言生成》,欢迎大家阅读购买。
  • 2020-11-20: 预训练模型新工作SentiLARE。本工作将词级别的语言学知识(包括词性和词的情感极性)引入预训练语言模型中,提出了一种适用于情感分析任务的语言表示模型SentiLARE,欢迎大家使用。
  • 2020-10-18: 我们的论文《A Large-Scale Chinese Short-Text Conversation Dataset》获得了NLPCC2020 Best Student Paper Award。 🎉 🎉 🎉
  • 2020-09-08: 感谢@xiejiachen所提供的可视化Web界面
  • 2020-09-02: 可用bert4keras加载TF版本的CDial-GPT模型,感谢苏剑林@bojone提供代码。

数据集概况

我们所提供的数据集LCCC(Large-scale Cleaned Chinese Conversation)主要包含两部分: LCCC-base (百度网盘, Google Drive) 和 LCCC-large (百度网盘, Google Drive). 我们设计了一套严格的数据过滤流程来确保该数据集中对话数据的质量。 这一数据过滤流程中包括一系列手工规则以及若干基于机器学习算法所构建的分类器。 我们所过滤掉的噪声包括:脏字脏词、特殊字符、颜表情、语法不通的语句、上下文不相关的对话等。

该数据集的统计信息如下表所示。 其中,我们将仅包含两个语句的对话称为“单轮对话”,我们将包含两个以上语句的对话称为“多轮对话”。 统计词表大小时使用 Jieba 分词。

| LCCC-base

(百度网盘, Google Drive) 单轮对话 多轮对话
总对话轮次 3,354,232 3,466,274
总对话语句 6,708,464 13,365,256
总字符数 68,559,367 163,690,569
词表大小 372,063 666,931
对话语句的平均词数 6.79 8.32
每轮对话的平均语句数 2 3.86

请注意, LCCC-base 数据集的清洗过程比 LCCC-large 更为严格,因此其规模也更小。

| LCCC-large

(百度网盘, Google Drive) 单轮对话 多轮对话
总对话轮次 7,273,804 4,733,955
总对话语句 14,547,608 18,341,167
总字符数 162,301,556 217,776,649
词表大小 662,514 690,027
对话语句的评价词数 7.45 8.14
每轮对话的平均语句数 2 3.87

LCCC-base 数据集中的原始对话数据来自于微博对话,LCCC-large 数据集中的原始对话数据在这些微博对话的基础上融合了其他开源对话数据集:

数据集 总对话轮次 对话示例
Weibo Corpus 79M Q:火锅我在重庆成都吃了七八顿火锅 A: 哈哈哈哈!那我的嘴巴 可能要烂掉!
PTT Gossiping Corpus 0.4M Q:为什么乡民总是欺负国高中生呢QQ A:如果以为选好科系就会变成比尔盖兹那不如退学吧
Subtitle Corpus 2.74M Q:京戏里头的人都是不自由的 A:他们让人拿笼子给套起来了了
Xiaohuangji Corpus 0.45M Q:你谈过恋爱么 A:谈过,哎,别提了,伤心..
Tieba Corpus 2.32M Q:前排,鲁迷们都起床了吧 A:标题说助攻,但是看了那球,真是活生生的讽刺了
Qingyun Corpus 0.1M Q:看来你很爱钱 A:噢是吗?那么你也差不多了
Douban Conversation Corpus 0.5M Q:看原版英文电影学纯正英语 A:大爱老友记反复看了好多次 了 Q:一样光盘都快被我看花了 A:那你现在的英语应该不错了
E-commerical Conversation Corpus 0.5M Q:这个会不会聚划算 A:暂时没有哦 Q:后期会不会有 A:不一定哦亲多多关注我们哦
Chinese Chat Corpus 0.5M Q: 我今天腿都废了,你们过节,我搬砖 A: 辛苦啊,圣诞节还去赚大钱了加油 Q: 毕竟是没男朋友的人,什么节都是一样的

预训练模型概况

模型

我们同时提供了一系列中文预训练模型(中文GPT模型),这些模型的预训练过程分为两步,首先在一个中文小说数据上预训练,然后在LCCC数据集上预训练。

我们沿用了 TransferTransfo 中的数据预处理设定,既将所有的对话历史拼接为一个句子,然后使用这个句子作为模型的输入,预测对话回复。我们模型的输入除了各个词的向量表示外,还包括发话人向量表示和位置向量表示。

模型输入

预训练模型 参数数量 预训练所使用数据 描述
GPTNovel 95.5M 中文小说数据 基于中文小说数据所构建中文预训练GPT模型 (该小说数据中共包括1.3B个字)
CDial-GPTLCCC-base 95.5M LCCC-base 在GPTNovel的基础上,使用 LCCC-base 训练得到的中文预训练GPT模型
CDial-GPT2LCCC-base 95.5M LCCC-base 在GPTNovel的基础上,使用 LCCC-base 训练得到的中文预训练GPT2模型
CDial-GPTLCCC-large 95.5M LCCC-large 在GPTNovel的基础上,使用 LCCC-large 训练得到的中文预训练GPT模型

安装

从源代码直接安装:

git clone https://github.com/thu-coai/CDial-GPT.git
cd CDial-GPT
pip install -r requirements.txt

快速开始

Step 1: 准备预训练模型和 fine-tuning 所需使用的数据集(如 STC dataset 或该项目目录中的toy数据 "data/toy_data.json", 请注意如数据中包含英文需按字母分割如:h e l l o)

# 下载 STC 数据集 中的训练集和验证集 并将其解压至 "data_path" 目录 (如果微调所使用的数据集为 STC)
git lfs install
git clone https://huggingface.co/thu-coai/CDial-GPT_LCCC-large  # 您可自行下载模型或者OpenAIGPTLMHeadModel.from_pretrained("thu-coai/CDial-GPT_LCCC-large")

ps:可以使用如下链接下载STC的训练集和验证集 (百度网盘, Google Drive)

Step 2: 训练模型

python train.py --pretrained --model_checkpoint thu-coai/CDial-GPT_LCCC-large --data_path data/STC.json --scheduler linear  # 使用单个GPU进行训练

或者

python -m torch.distributed.launch --nproc_per_node=8 train.py --pretrained --model_checkpoint thu-coai/CDial-GPT_LCCC-large --data_path data/STC.json --scheduler linear  # 以分布式的方式在8块GPU上训练

我们的训练脚本中还提供了 train_path 参数,用户可使用该参数以切片地形式读取纯文本文件。如果您所使用的的系统中内存有限,可以考虑使用该参数读入训练数据。 如果您使用 train_path 则需要将 data_path 置空。

Step 3: 生成文本

# YOUR_MODEL_PATH: 你要使用的模型的路径,每次微调后的模型目录保存在./runs/中
python infer.py --model_checkpoint YOUR_MODEL_PATH --datapath data/STC_test.json --out_path STC_result.txt  # 在测试数据上生成回复
python interact.py --model_checkpoint YOUR_MODEL_PATH  # 在命令行中与模型进行交互

ps:可以使用如下链接下载STC的测试集 (百度网盘, Google Drive)

训练脚本参数

参数 类型 默认值 描述
model_checkpoint str "" Path or URL of model files (Directory of pre-training model and config/vocab files)
pretrained bool False If False, then train the model from scratch
data_path str "" Path of the dataset
dataset_cache str default="dataset_cache" Path or url of the dataset cache
train_path str "" Path of the training set for distributed dataset
valid_path str "" Path of the validation set for distributed dataset
log_file str "" Output logs to a file under this path
num_workers int 1 Number of subprocesses for data loading
n_epochs int 70 Number of training epochs
train_batch_size int 8 Batch size for training
valid_batch_size int 8 Batch size for validation
max_history int 15 Number of previous exchanges to keep in history
scheduler str "noam" Method of optimizer
n_emd int 768 Number of n_emd in config file (for noam)
eval_before_start bool False If true, start evaluation before training
warmup_steps int 5000 Warm up steps
valid_steps int 0 Perform validation every X steps, if is not 0
gradient_accumulation_steps int 64 Accumulate gradients on several steps
max_norm float 1.0 Clipping gradient norm
device str "cuda" if torch.cuda.is_available() else "cpu" Device (cuda or cpu)
fp16 str "" Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)
local_rank int -1 Local rank for distributed training (-1: not distributed)

评测结果

我们评测了使用 STC数据集 (训练集/验证集 (百度网盘, Google Drive), 测试集 (百度网盘, Google Drive)) 微调后的对话预训练模型。 所有的回复均使用 Nucleus Sampling 的方法采样得到 (p=0.9, temperature=0.7)。

自动评价指标

模型 模型大小 PPL BLEU-2 BLEU-4 Dist-1 Dist-2 Greedy Matching Embedding Average
Attn-Seq2seq 73M 34.20 3.93 0.90 8.5 11.91 65.84 83.38
Transformer 113M 22.10 6.72 3.14 8.8 13.97 66.06 83.55
GPT2-chitchat 88M - 2.28 0.54 10.3 16.25 61.54 78.94
GPTNovel 95.5M 21.27 5.96 2.71 8.0 11.72 66.12 83.34
GPTLCCC-base 95.5M 18.38 6.48 3.08 8.3 12.68 66.21 83.54
GPT2LCCC-base 95.5M 22.76 5.69 2.50 7.7 10.87 66.24 83.46
GPTLCCC-large 95.5M 18.23 6.63 3.20 8.3 12.71 66.25 83.63

人工评价指标

我们为每个模型采样了200个回复,并在外包平台上邀请了3位标注员对这些回复进行人工评测,评测维度为对话流畅性、上下文相关性和回复多样性。每个维度的打分范围为 2/1/0。更多细节请参考我们的 论文

模型 +2 +1 +0 Kappa
Attn-Seq2Seq 27.1% 21.4% 51.4% 0.4544
Transformer 42.4% 23.6% 34.0% 0.4429
GPT2-chitchat 24.3% 19,1% 56.6% 0.3941
GPTLCCC-base 46.7% 21.8% 31.5% 0.3954
GPTLCCC-large 48.6% 24.5% 27.0% 0.4203

不良回复统计

我们随机采样了1000个对话回复,并统计了这些回复中的不良回复:

包含重复语句 包含脏字 其他不良回复
23 9 2

GPTLCCC-large模型的生成示例

我们在STC数据集上微调了GPTLCCC-large模型的参数,并使用这一微调后的模型生成了如下对话示例:

Post Response
这 居 然 是 北 京 不 是 哈 尔 滨 。 哈 尔 滨 的 天 气 好 像 比 北 京 好 点 , 我 这 里 就 是 小 雨 , 白 天 都 是 大 风
运 营 商 校 园 广 告 大 P K , 太 欢 乐 了 ! 哈 哈 哈 。 我 以 为 是 马 云 的 广 告 。
D e l s a S p o s a 糖 果 色 系 列 婚 纱 , 粉 蓝 红 紫 , 俏 皮 又 清 新 , 你 喜 欢 吗 ? 最 后 一 件 太 美 了 , 美 的 不 像 话 !

我们还使用上述微调后的模型得到了如下交互结果:

Roel Response
User 别爱我没结果
Bot 没 结 果 的 爱 是 多 么 的 痛 。
User 但是爱过就可以了
Bot 有 时 候 ,爱 情 就 是 一 场 游 戏 。
User 输的总是我
Bot 我 爱 你 , 就 像 我 爱 你 一 样 。
User 你可拉到吧
Bot 你 可 拉 到 吧 。
User 别学我说话
Bot 我 说 话 就 是 你 不 行 。
User 逼我动手是吧
Bot 不 , 是 逼 你 动 口 是 吧

作者

Yida Wang, Pei Ke, Yinhe Zheng, Kaili Huang, Yong Jiang, Xiaoyan Zhu, Minlie Huang

致谢

Zuoxian Ye, Yao Wang, Yifan Pan

免责声明

本项目所提供的 LCCC 数据集和预训练对话模型仅限科研用途。LCCC数据集中的对话收集自不同的来源,虽然我们设计了一套严格的数据清洗流程,但是我们并不保证所有不当内容均已被过滤。该数据中所包含的所有内容和意见与本项目作者无关。 本项目所提供的模型和代码仅为完整对话系统的一个组成部分,我们所提供的解码脚本仅限科研用途,使用本项目中的模型和脚本所生成的一切对话内容与本项目作者无关。

引用

如果您觉得我们的项目对您有帮助,请引用我们的论文

@inproceedings{wang2020chinese,
  title={A Large-Scale Chinese Short-Text Conversation Dataset},
  author={Wang, Yida and Ke, Pei and Zheng, Yinhe and Huang, Kaili and Jiang, Yong and Zhu, Xiaoyan and Huang, Minlie},
  booktitle={NLPCC},
  year={2020},
  url={https://arxiv.org/abs/2008.03946}
}

CDial-GPT

  • This project provides a large-scale cleaned Chinese conversation dataset and a Chinese GPT model pre-trained on this dataset. Please refer to our paper for more details.

  • Our code used for the pre-training is adapted from the TransferTransfo model based on the Transformers library. The codes used for both pre-tra

Core symbols most depended-on inside this repo

tokenize
called by 5
od/inputters/dataset_wb.py
tokenize
called by 4
od/inputters/inputter.py
sample_sequence
called by 2
contrib/dash_app/interact.py
setup_seed
called by 1
train.py
train
called by 1
train.py
top_filtering
called by 1
infer.py
build_input_from_segments
called by 1
infer.py
sample_sequence
called by 1
infer.py

Shape

Function 39
Method 14
Route 4
Class 3

Languages

Python100%

Modules by API surface

od/inputters/dataset_wb.py18 symbols
contrib/dash_app/app.py8 symbols
train.py6 symbols
infer.py6 symbols
od/utils/data_utils.py5 symbols
interact.py5 symbols
contrib/dash_app/interact.py5 symbols
od/inputters/inputter.py4 symbols
contrib/dash_app/chat_res.py3 symbols

Dependencies from manifests, versioned

pytorch-ignite0.2.1 · 1×
tensorboardX1.8 · 1×
torch1.4.0 · 1×
transformers2.1.1 · 1×

For agents

$ claude mcp add CDial-GPT \
  -- python -m otcore.mcp_server <graph>

⬇ download graph artifact