6 – 模型训练

基于上文构建的模型,本篇使用开放数据进行预训练。

我们曾在“初识神经网络”篇章给出深度神经网络的训练框架,回顾一下:

optimizer = ...  # 初始化优化器
model = ...      # 定义模型

for epoch in range(训练轮数):  # 训练轮数
    for batch in 训练数据集:   # `训练数据集` 是一个数据加载器或其他类似的数据提供方式
        输入, 目标 = batch  # 每个批次返回一个输入和目标的元组
        
        # 清零参数梯度
        optimizer.zero_grad()
        # 前向传播 + 反向传播 + 更新权重
        输出 = model(输入)
        损失 = 损失函数(输出, 目标)  # 损失函数应该被定义
        损失.backward()  # 计算损失相对于模型参数的梯度
        optimizer.step()  # 根据计算出的梯度更新权重

上述框架涉及训练数据、优化器、模型、损失函数、训练轮数等参数。

  1. 训练数据准备

我们将使用 The Verdict 这部作品作文训练集数据。数据大小为9kB,包括 20479个字符。

伊迪丝·华顿的短篇小说《判决》首次发表于1908年。故事围绕着艺术家杰克·吉斯本的生活展开。杰克在事业巅峰时期放弃了绘画生涯,与一位富有的寡妇结婚,并搬到了里维埃拉。故事是从吉斯本的熟人瑞克汉姆先生的角度叙述的,瑞克汉姆拜访这对夫妇时,对吉斯本突然退出艺术界的原因产生了浓厚的兴趣。

伊迪丝·华顿是美国文学史上著名的作家,以其对社会阶层差异和道德约束的深刻洞察而闻名。她的作品常常探讨上流社会的规范及其对个人生活的影响。《判决》虽然不是她最知名的作品之一,但它依然体现了华顿对人性和社会动态的敏锐观察。

在Pytorch 快速入门章节,我们简单介绍过如何使用Dataset以及DataLoader 模块处理和加载数据。代码段6-1展示了二者配合使用处理和加载数据的过程。

import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler

class SimpleDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # 返回数据集中的第 idx 个元素
        print(f"Fetching item with index: {idx}")
        return self.data[idx]

# 示例数据集
data = [i for i in range(100)]
dataset = SimpleDataset(data)

# 初始化 DataLoader
batch_size = 16
sampler = RandomSampler(dataset)  # 使用随机采样器
# sampler = SequentialSampler(dataset)  # 使用顺序采样器
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=2,
    collate_fn=lambda x: x  # 自定义批处理函数(这里使用默认行为)
)
# 遍历 DataLoader
for batch_idx, batch in enumerate(data_loader):
    print(f"Batch {batch_idx + 1}: {batch}")

SimpleDataset 类实现了基本的数据集接口,提供了 lengetitem 方法。

DataLoader 通过传入的数据集实例和其他参数配置来初始化,并通过 enumerate 函数遍历 DataLoader,每次迭代返回一个批次的数据。

在这个例子中,DataLoader 按照指定的批次大小(batch_size=16)从数据集中读取数据,并且在每个 epoch 开始时都会重新混洗数据。通过设置 num_workers,可以启用多线程加载,从而加速数据读取过程。

代码段6-2给出了处理数据集 The Verdict的过程。

import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler
import tiktoken

class LargeTextDataset(Dataset):
    
    def __init__(self, file_path, tokenizer, max_length=128):
        super(LargeTextDataset, self).__init__()
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data=[]
        self.load_data()
    
    def load_data(self):
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = self.tokenizer.encode(line.strip())[:self.max_length]
                if len(tokens)>1:
                    self.data.append(tokens)
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return torch.tensor(self.data[index], dtype=torch.long)

def custom_collate_fn(batch):
    print(batch)
    max_len= max([t.shape[0] for t in batch])
    padded_batch = []
    for t in batch:
        padding = torch.zeros(max_len - t.shape[0], dtype = torch.long)
        padded_t = torch.cat([t, padding])
        padded_batch.append(padded_t)

    return  torch.stack(padded_batch)

file_path = 'the-verdict.txt'  # 假设这是你的文本文件路径
tokenizer = tiktoken.get_encoding("gpt2")
dataset = LargeTextDataset(file_path, tokenizer)
sampler = RandomSampler(dataset)
data_loader = DataLoader(dataset,sampler=sampler,batch_size=16, collate_fn = custom_collate_fn)

for batch_idx, batch in enumerate(data_loader):
    print(f"Batch {batch_idx + 1}: {batch}")

上述代码在工程实践需要注意以下问题:

  • 一次性加载数据集到内存,如果数据集超大这是不可行的;
  • DataLoader 参数的设定,包括采样器、批大小、批处理函数、并行线程数;

我们会在工程优化部分详细介绍。

由于 The Verdict数据集比较小,此处使用7:1.5:1.5的比例将数据划分为训练集、交叉验证集和测试集。

代码段6-3展示了数据集的划分过程:使用 sklearn.model_selection.train_test_split 来逐步划分数据。首先,将数据划分为训练集和剩余集(剩余集包含交叉验证集和测试集),然后再将剩余集进一步划分为交叉验证集和测试集。

from sklearn.model_selection import train_test_split

# 初始化数据集
file_path = 'the-data.txt'  # 假设这是你的文本文件路径
tokenizer = tiktoken.get_encoding("gpt2")
dataset = LargeTextDataset(file_path, tokenizer)

# 将数据集划分为训练集和剩余集(交叉验证集 + 测试集)
train_size = 0.7
remaining_size = 0.3
train_dataset, remaining_dataset = train_test_split(dataset, train_size=train_size, random_state=42)

# 将剩余集进一步划分为交叉验证集和测试集
val_test_ratio = 0.5  # 交叉验证集和测试集的比例相等
val_dataset, test_dataset = train_test_split(remaining_dataset, train_size=val_test_ratio, random_state=42)

# 初始化数据加载器
train_loader = DataLoader(train_dataset, batch_size=16, sampler=RandomSampler(train_dataset), collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, sampler=SequentialSampler(val_dataset), collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, sampler=SequentialSampler(test_dataset), collate_fn=custom_collate_fn)

# 遍历数据加载器
for batch_idx, batch in enumerate(train_loader):
    print(f"Training Batch {batch_idx + 1}: {batch.shape}")

for batch_idx, batch in enumerate(val_loader):
    print(f"Validation Batch {batch_idx + 1}: {batch.shape}")

for batch_idx, batch in enumerate(test_loader):
    print(f"Test Batch {batch_idx + 1}: {batch.shape}")
  1. 训练过程评估

模型在训练的时候,需要通过损失函数计算训练集和交叉验证集合之间的误差。模型训练按批处理数据,也需要按批计算损失函数。此处我们选择使用交叉熵损失函数,它衡量的是模型预测的概率分布与实际标签的概率分布之间的差异,其梯度形式简单,便于优化算法收敛。

在PyTorch中,交叉熵损失函数可以通过

torch.nn.functional.cross_entropy 或者

torch.nn.CrossEntropyLoss 类来实现。这两个方法在功能上基本相同,但 torch.nn.CrossEntropyLoss 是一个类,可以作为模型的一部分保存状态,而 torch.nn.functional.cross_entropy 是一个函数,适合临时使用。cross_entropy函数内部会自动应用 softmax 函数,因此传入的 logits 不需要预先进行 softmax 处理。

示例见代码段6-4:

import torch
import torch.nn.functional as F

# 假设 logits 是模型的输出,target 是真实的标签
logits = torch.randn(3, 5)  # 3个样本,5个类别
target = torch.tensor([1, 0, 4])  # 真实标签

# 计算交叉熵损失
loss = F.cross_entropy(logits, target)
print(loss)

对于训练批次和交叉验证批次,代码段6-5:

def calc_loss_batch(feature_batch, target_batch, model):
    logits = model(feature_batch)
    loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, 1), target_batch.flatten()
    )
    return loss
    
with torch.no_grad(): 
    for i, (input_batch, target_batch) in enumerate(train_loader):
        train_loss = calc_loss_batch(input_batch,target_batch, model)
    for i, (input_batch, target_batch) in enumerate(val_loader):
        val_loss = calc_loss_batch(input_batch,target_batch, model)

print("Training loss:", train_loss)
print("Validation loss:", val_loss)
  1. 开启训练

大语言模型的训练流程和现有的深度神经网络没有本质区别:

代码段6-6:

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    for batch_idx, input_batch in enumerate(train_loader):
        # 假设输入和目标是连续的,这里简单地将输入作为目标
        target_batch = input_batch[:, 1:]  # 目标序列(左移一位)
        input_batch = input_batch[:, :-1]  # 输入序列

        # 计算损失
        loss = calc_loss_batch(input_batch, target_batch, model, device)

        # 前向传播 + 反向传播 + 更新权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch+1} [{batch_idx*len(input_batch)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    avg_train_loss = total_train_loss / len(train_loader)
    print(f'Average Train Loss: {avg_train_loss:.6f}')

    # 验证模型
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch_idx, input_batch in enumerate(val_loader):
            # 假设输入和目标是连续的,这里简单地将输入作为目标
            target_batch = input_batch[:, 1:]  # 目标序列(左移一位)
            input_batch = input_batch[:, :-1]  # 输入序列

            # 计算损失
            loss = calc_loss_batch(input_batch, target_batch, model, device)

            total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f'Average Validation Loss: {avg_val_loss:.6f}')

当调用 model.train() 时,模型进入训练模式。这意味着模型中的所有组件都会准备好进行训练。

  • Batch Normalization:在训练模式下,BatchNorm 层会根据当前批次的数据更新其统计信息(均值和方差),并在前向传播中使用这些统计信息进行标准化。
  • Dropout:在训练模式下,Dropout 层会按照一定的概率随机关闭一部分神经元,从而起到正则化的作用。

调用 model.eval() 时,模型进入评估模式。这意味着模型中的所有组件都将处于评估状态,这对于验证和测试非常有用,因为此时不需要模型的行为发生改变。

  • Batch Normalization:在评估模式下,BatchNorm 层会使用整个训练过程中积累的统计信息来进行标准化,而不是当前批次的数据。
  • Dropout:在评估模式下,Dropout 层不会随机关闭神经元,而是保持所有神经元开启,并且权重乘以 dropout 概率,以此来模拟训练时的效果。

一切都是那么的丝滑。

  1. 训练暂停和重启

暂停和重启训练一个深度学习模型通常涉及到两个关键步骤:

保存模型的状态(checkpoint):在训练过程中保存模型的权重、优化器的状态以及其他相关信息(如训练的轮次、损失等)。

恢复模型的状态:在继续训练之前,加载之前保存的checkpoint,使模型回到之前的状态继续训练。

保存模型的状态

在训练过程中,你可以选择在某个时刻保存模型的状态。这通常是在每个epoch结束时执行,或者在训练过程中达到某个特定条件时执行(例如,每训练一定数量的epochs保存一次)

def save_checkpoint(model, optimizer, epoch, path):
    """
    保存模型的状态,包括模型的权重、优化器的状态以及当前epoch。
    
    参数:
    - model: 模型对象
    - optimizer: 优化器对象
    - epoch: 当前epoch
    - path: checkpoint保存的路径
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_train_loss  # 如果需要的话,也可以保存当前epoch的损失等信息
    }, path)
    print(f'Model checkpoint saved to {path}')

# 在每个epoch结束时保存模型
for epoch in range(num_epochs):
    ...
    # 每个epoch结束时保存模型
    if (epoch + 1) % 5 == 0:  # 每5个epoch保存一次
        checkpoint_path = f'checkpoint_epoch_{epoch + 1}.pth'
        save_checkpoint(model, optimizer, epoch + 1, checkpoint_path)

恢复模型状态

要继续训练,你需要加载之前保存的checkpoint。这可以通过以下方式完成:

def load_checkpoint(model, optimizer, checkpoint_path, device):
    """
    加载模型的状态,包括模型的权重、优化器的状态以及当前epoch。
    
    参数:
    - model: 模型对象
    - optimizer: 优化器对象
    - checkpoint_path: checkpoint的路径
    - device: 设备(CPU或GPU)
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f'Model loaded from checkpoint {checkpoint_path}, starting from epoch {start_epoch}')
    return start_epoch

# 加载checkpoint并继续训练
checkpoint_path = 'checkpoint_epoch_10.pth'  # 假设这是你保存的checkpoint路径
start_epoch = load_checkpoint(model, optimizer, checkpoint_path, device)

# 继续训练
for epoch in range(start_epoch, num_epochs):
    model.train()
    ...
    # 每个epoch结束时保存模型
    if (epoch + 1) % 5 == 0:  # 每5个epoch保存一次
        checkpoint_path = f'checkpoint_epoch_{epoch + 1}.pth'
        save_checkpoint(model, optimizer, epoch + 1, checkpoint_path)
  1. 模型保存和模型加载

完整模型的保存和加载,和checkpoint类似。

在 PyTorch 中存储模型主要有几种不同的文件格式,其中包括 pkl、pt 和 pth。这些格式主要用于保存和加载模型的状态字典(state dict),即模型的权重和其他可学习的参数。下面详细介绍这几种格式的区别和使用方法。

  • .pth

.pth 是最常用的格式,它实际上是一个 Python 的 pickle 文件,用于保存模型的状态字典(state dict)。主要用于保存模型的权重和优化器的状态等信息。如果模型包含大量的元数据或其他非标准数据类型,可能会导致文件体积较大。

  • .pt

.pt 本质上也是 Python 的 pickle 文件,与 .pth 文件没有本质区别,只是扩展名不同。

  • .pkl

.pkl 也是一种 Python 的 pickle 文件,可以保存任何 Python 对象,不仅仅是模型的状态字典。除了保存模型的状态字典外,还可以保存更多的信息,如模型的架构定义、优化器的状态等。

import torch
model = Model()

# 保存模型的状态字典到.pth文件
torch.save(model.state_dict(), 'model.pth')

# 保存整个模型到.pth文件(包括模型架构和状态字典)
torch.save(model, 'model.pth')

# 保存到.pt文件
torch.save(model.state_dict(), 'model.pt')

# 保存到.pkl文件
torch.save(model.state_dict(), 'model.pkl')
  1. 附录

完整训练代码,本代码仅供参考,用于建立LLM的框架全局观,后续文章会在工程实现上进行优化和改进。

import torch
import torch.nn as nn

GPT_CONFIG = {
    "vocab_size": 50257,  # 模型词汇表的大小
    "context_length": 1024,      # 上下文长度
    "emb_dim": 768,       # 嵌入层的维度
    "n_heads": 12,        # 多头注意力机制中的头数量
    "n_layers": 12,       # transformer 堆叠层数
    "drop_rate": 0.1,     # dropout比率
    "bias": False     # 计算查询(Query)、键(Key)和值(Value)时是否使用偏置项
}
 
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super(GPTModel, self).__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg, device) for _ in range(cfg["n_layers"])])
        self.layer_norm = LayerNorm(cfg["emb_dim"], device)
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.layer_norm(x)
        logits = self.out_head(x)
        return logits

class TransformerBlock(nn.Module):
    def __init__(self, cfg, device):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(cfg, device)
        self.ff = FeedForward(cfg)
        self.norm = LayerNorm(cfg['emb_dim'], device)
        self.dropout = nn.Dropout(cfg['drop_rate'])

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        x = self.attention(x)
        x = self.dropout(x)
        x = x + shortcut

        shortcut = x
        x = self.norm(x)
        x = self.ff(x)
        x = self.dropout(x)
        x = x + shortcut
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, cfg, device):
        super(MultiHeadAttention, self).__init__()
        self.input_dim = cfg["emb_dim"]
        self.output_dim = cfg["emb_dim"]
        self.num_heads = cfg["n_heads"]
        self.head_dim = self.input_dim // self.num_heads
        self.W_query = nn.Linear(self.input_dim, self.output_dim, bias=cfg["bias"])
        self.W_key = nn.Linear(self.input_dim, self.output_dim, bias=cfg["bias"])
        self.W_value = nn.Linear(self.input_dim, self.output_dim, bias=cfg["bias"])
        self.dropout = nn.Dropout(cfg["drop_rate"])
        self.device = device

    def forward(self, inputs):
        batch_size, seq_len, _ = inputs.size()
        queries = self.W_query(inputs)
        keys = self.W_key(inputs)
        values = self.W_value(inputs)

        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=self.device)).view(1, 1, seq_len, seq_len)
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = torch.matmul(attn_weights, values)
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, seq_len, self.input_dim)
        return context_vec


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super(FeedForward, self).__init__()
        self.linear_layer_1 = nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"])
        self.dropout = nn.Dropout(cfg["drop_rate"])
        self.linear_layer_2 = nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"])
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear_layer_1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear_layer_2(x)
        return x

class LayerNorm(nn.Module):
    def __init__(self, emb_dim, device):
        super(LayerNorm, self).__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim, device=device))
        self.shift = nn.Parameter(torch.zeros(emb_dim, device=device))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler
import tiktoken
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

# 定义一个简单的数据集类,用于加载和处理数据
class LargeTextDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        super(LargeTextDataset, self).__init__()
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data=[]
        self.load_data()
    
    def load_data(self):
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = self.tokenizer.encode(line.strip())[:self.max_length]
                if len(tokens)>1:
                    self.data.append(tokens)
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return torch.tensor(self.data[index], dtype=torch.long)

# 自定义批处理函数
def custom_collate_fn(batch):
    max_len = max([t.shape[0] for t in batch])
    padded_batch = []
    for t in batch:
        padding = torch.zeros(max_len - t.shape[0], dtype=torch.long)
        padded_t = torch.cat([t, padding])
        padded_batch.append(padded_t)
    return torch.stack(padded_batch)

# 计算一个批次的损失
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, 1), target_batch.flatten()
    )
    return loss

# 初始化数据集
file_path = 'the-verdict.txt'  # 假设这是你的文本文件路径
tokenizer = tiktoken.get_encoding("gpt2")
dataset = LargeTextDataset(file_path, tokenizer)

# 将数据集划分为训练集、交叉验证集和测试集
train_size = 0.7
remaining_size = 0.3
train_dataset, remaining_dataset = train_test_split(dataset, train_size=train_size, random_state=42)

val_test_ratio = 0.5  # 交叉验证集和测试集的比例相等
val_dataset, test_dataset = train_test_split(remaining_dataset, train_size=val_test_ratio, random_state=42)

# 初始化数据加载器
train_loader = DataLoader(train_dataset, batch_size=16, sampler=RandomSampler(train_dataset), collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, sampler=SequentialSampler(val_dataset), collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, sampler=SequentialSampler(test_dataset), collate_fn=custom_collate_fn)

# 设定设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 初始化模型、损失函数和优化器
model = GPTModel(GPT_CONFIG)  # 初始化你的模型
model = model.to(device)  # 确保模型在正确的设备上

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def generate_text (model, input_sequence):
    model.eval()
    encoded = tokenizer.encode(input_sequence)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0).to(device)
    generation_length = 10
    generated_text = encoded_tensor
    generation_length = 10
    generated_text = encoded_tensor
    for _ in range(generation_length - generated_text.shape[1]):
        with torch.no_grad():
            output_logits = model(generated_text)
        last_position_probabilities = F.softmax(output_logits[:, -1, :], dim=-1)
        next_word = torch.multinomial(last_position_probabilities, 1)
        generated_text = torch.cat([generated_text, next_word], dim=1) 
    decoded_text = tokenizer.decode(generated_text.squeeze(0).tolist())
    print(decoded_text)
    model.train()

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    for batch_idx, input_batch in enumerate(train_loader):
        # 假设输入和目标是连续的,这里简单地将输入作为目标
        target_batch = input_batch[:, 1:]  # 目标序列(左移一位)
        input_batch = input_batch[:, :-1]  # 输入序列

        # 计算损失
        loss = calc_loss_batch(input_batch, target_batch, model, device)

        # 前向传播 + 反向传播 + 更新权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch+1} [{batch_idx*len(input_batch)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    avg_train_loss = total_train_loss / len(train_loader)
    print(f'Average Train Loss: {avg_train_loss:.6f}')

    # 验证模型
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch_idx, input_batch in enumerate(val_loader):
            # 假设输入和目标是连续的,这里简单地将输入作为目标
            target_batch = input_batch[:, 1:]  # 目标序列(左移一位)
            input_batch = input_batch[:, :-1]  # 输入序列
            # 计算损失
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        print(f'Average Validation Loss: {avg_val_loss:.6f}')
    generate_text(model, "hello she said,")
print('Finished Training')