什么是“小说机器人”?
我们需要明确“小说机器人”的具体任务,它不是一个能独立构思完整故事大纲和角色的AI,而更像是一个强大的写作辅助工具,它可以完成以下任务:

- 文本续写: 给定一段开头,机器人可以续写后续内容。
- 风格模仿: 学习某位作家的文风,并生成具有相似风格的新文本。
- 情节生成: 根据关键词或简单的情节提示,生成故事片段。
- 对话生成: 生成符合角色性格的对话。
我们的核心目标是生成连贯、有逻辑、且具有一定文学性的中文文本。
技术选型与模型架构
对于这类生成任务,最强大的模型架构是 Transformer,特别是其中的 Decoder 部分,像 GPT (Generative Pre-trained Transformer) 系列模型就是基于这个原理。
为什么选择 Transformer?
- 长距离依赖: Transformer 的自注意力机制能很好地捕捉文本中相距很远的词语之间的关系,这对于保持故事的连贯性至关重要。
- 并行计算: 相比于传统的 RNN,Transformer 训练速度更快,更容易处理大规模数据。
模型选择方案
对于大多数开发者,直接从头训练一个庞大的 Transformer 模型(如 GPT-3)是不现实的,因为它需要海量的计算资源和数据,我们通常采用以下更可行的方案:
微调预训练模型 (强烈推荐)

这是目前最主流、效果最好的方法,我们可以利用一个已经在大规模文本上预训练好的中文模型,然后在特定的小说数据集上进行“微调”(Fine-tuning)。
- 优点:
- 效果好:预训练模型已经具备了强大的语言理解和生成能力。
- 成本低:只需要相对较少的数据和计算资源就能达到不错的效果。
- 推荐的预训练模型:
- GPT-2 中文版: 有很多开源的中文 GPT-2 模型,如
uer/gpt2-chinese-cluecorpussmall。 - ChatGLM-6B: 一个开源的、性能强大的中英双语对话模型,基于 General Language Model (GLM) 架构,非常适合进行微调,它对中文的理解和生成能力非常出色。
- LLaMA 中文版: Meta 的 LLaMA 模型也有很多优秀的中文版本和微调版本。
- GPT-2 中文版: 有很多开源的中文 GPT-2 模型,如
使用 TensorFlow Hub 或 KerasNLP 的预生成模型
如果你不想进行训练,只想直接使用,可以查找 TensorFlow Hub 上是否有合适的文本生成模型,但这类模型通常针对通用任务,专门用于小说生成的较少。
从头构建一个简单的 Transformer (学习目的)

如果你想深入理解 Transformer 的工作原理,可以尝试用 TensorFlow/Keras 从头搭建一个简化版的 Transformer Decoder,这对于学习非常有帮助,但生成效果可能不如微调预训练模型。
核心步骤:以微调 ChatGLM-6B 为例
我们将重点介绍方案一,因为它最具实用性,假设你已经有一台有 GPU 的机器(可以使用 Google Colab 免费版)。
步骤 1: 环境准备
安装必要的库。transformers 库提供了 Hugging Face 的模型和工具,accelerate 用于简化 GPU 训练。
pip install transformers torch accelerate pip install protobuf==3.20.3 # 有时需要指定版本以避免兼容性问题 pip install datasets # 用于方便地加载数据
步骤 2: 数据准备
这是最关键的一步,你需要一个高质量的中文小说数据集,数据集越大、质量越高,效果越好。
- 数据来源:
- 网络公开的中文小说文本(注意版权问题,仅用于研究)。
- 古籍、现代小说等。
- 数据格式: 通常是一个巨大的
.txt文件,或者多个.txt文件。 - 数据预处理:
- 清洗: 去除无关的 HTML 标签、广告、特殊字符等。
- 分词: 中文需要分词,虽然很多现代模型(如 ChatGLM)已经内置了分词器,但在准备数据时,有时会先进行分词,然后再用模型的分词器处理。
- 格式化: 将文本分割成适合模型训练的样本,将长文本切分成固定长度(如 512 或 1024 tokens)的序列。
假设你有一个 novels.txt 文件,里面包含了多部小说的文本。
步骤 3: 加载模型和分词器
我们将使用 Hugging Face transformers 库来加载 ChatGLM-6B 模型。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
# 指定模型名称
model_name = "THUDM/chatglm-6b" # ChatGLM-6B 的 Hugging Face 模型 ID
# 加载分词器和模型
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
torch_dtype=torch.float16 # 使用半精度以节省显存
).cuda() # 将模型加载到 GPU
# 打印模型信息,看看是否成功加载
print("Model and tokenizer loaded successfully!")
步骤 4: 数据处理和创建数据集
我们需要将文本文件转换成模型可以训练的格式。
from datasets import load_dataset
# 加载你的文本文件
# 假设文件是 utf-8 编码
dataset = load_dataset('text', data_files={'train': 'novels.txt'})
# 定义一个处理函数,将文本转换为模型的输入格式
def tokenize_function(examples):
# ChatGLM 使用特殊的格式,通常以 "问:\n答:" 开始
# 对于小说生成,我们可以让模型学习续写,所以可以直接用文本本身
# 我们将文本截断到模型的最大长度
return tokenizer(examples["text"], truncation=True, max_length=1024)
# 对数据集进行分词处理
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
# 设置数据格式为 PyTorch
tokenized_datasets.set_format("torch")
步骤 5: 配置训练参数并开始微调
TrainingArguments 和 Trainer 是 Hugging Face 提供的高级训练 API,可以极大简化训练过程。
# 定义训练参数
training_args = TrainingArguments(
output_dir="./results", # 输出目录
num_train_epochs=3, # 训练轮数
per_device_train_batch_size=4, # 每个设备的批次大小 (根据你的显存调整)
gradient_accumulation_steps=4, # 梯度累加步数,相当于增大 batch size
optim="adamw_torch", # 优化器
save_steps=500, # 每 500 步保存一次模型
save_total_limit=2, # 最多保存 2 个模型
logging_steps=10, # 每 10 步打印一次日志
learning_rate=5e-5, # 学习率
fp16=True, # 使用混合精度训练
)
# 创建 Trainer 实例
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
'attention_mask': torch.stack([f['attention_mask'] for f in data]),
'labels': torch.stack([f['input_ids'] for f in data])}
)
# 开始训练!
trainer.train()
注意: 训练 ChatGLM-6B 这样的模型需要至少 12GB 的显存,在 Google Colab 的 A100 或 T4 上可以运行,但速度可能较慢。per_device_train_batch_size 和 gradient_accumulation_steps 需要根据你的显存大小进行调整。
步骤 6: 生成文本
训练完成后,我们可以使用模型来生成小说。
# 使用训练好的模型进行推理
prompt = "在一个月黑风高的夜晚,古堡里传来了一声尖叫。"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# 生成文本
# max_length: 生成的最大长度
# num_beams: beam search 的宽度,值越大结果越多样但越慢
# temperature: 温度,控制随机性,越高越随机
# top_p: top-p 核采样,控制随机性
output = model.generate(
**inputs,
max_length=500,
num_beams=5 标签: TensorFlow小说机器人创作原理 TensorFlow小说机器人写作技巧 TensorFlow小说机器人应用案例