BERT数据处理,模型,预训练_bert怎么处理数值信息-程序员宅基地

技术标签: 深度学习  人工智能  bert  

代码来自李沐老师《动手学pytorch》
在数据处理时,首先执行以下代码
def load_data_wiki(batch_size, max_len):
    """加载WikiText-2数据集"""
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    以上两句代码,不再说明
    paragraphs = _read_wiki(data_dir)
    train_set = _WikiTextDataset(paragraphs, max_len)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                         shuffle=True)
    return train_iter, train_set.vocab

d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

#@save
def _read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r',encoding='utf-8') as f:
        lines = f.readlines()
    # 大写字母转换为小写字母 ,每行文本中包含两个句子,才进行处理,否则舍去文本
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

首先读取文本,每个文本必须包含两个以上句子(为了第二个预训练任务:判断两个句子,是否连续)。paragraphs 其中一部分结果如下所示

文本中包含了三个句子,每个’‘里面,代表一个句子
['common starlings are trapped for food in some mediterranean countries'
, 'the meat is tough and of low quality , so it is <unk> or made into <unk>'
, 'one recipe said it should be <unk> " until tender , however long that may be "'
, 'even when correctly prepared , it may still be seen as an acquired taste .']

class _WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self, paragraphs, max_len):
        '''
        每一个paragraph就是上面的包含多个句子的列表,将其进行分词处理。下面是一个分词的例子
        [['common', 'starlings', 'are', 'trapped', 'for', 'food', 'in', 'some', 'mediterranean', 'countries']
        , ['the', 'meat', 'is', 'tough', 'and', 'of', 'low', 'quality', ',', 'so', 'it', 'is', '<unk>', 'or', 'made', 'into', '<unk>'], ['one', 'recipe', 'said', 'it', 'should', 'be', '<unk>', '"', 'until', 'tender', ',', 'however', 'long', 'that', 'may', 'be', '"']
        , ['even', 'when', 'correctly', 'prepared', ',', 'it', 'may', 'still', 'be', 'seen', 'as', 'an', 'acquired', 'taste', '.']]
        '''
        paragraphs = [d2l.tokenize(
            paragraph, token='word') for paragraph in paragraphs]
        #将词提取处理,保存
        sentences = [sentence for paragraph in paragraphs
                     for sentence in paragraph]
        #形成一个词典,min_freq为词最少出现的次数,少于5次,则不保存进词典中
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # 获取下一句子预测任务的数据
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))
            '''
def _get_nsp_data_from_paragraph(paragraph,paragraphs,vocab,max_len):
    nsp_data_from_paragraph=[]
    for i in range(len(paragraph)-1):
    
_get_next_sentence函数传入的是相邻的句子a,b。函数中b会有一定概率替换为其他的句子

        tokens_a, tokens_b, is_next = _get_next_sentence(
            paragraph[i], paragraph[i + 1], paragraphs)
            
句子长度大于bert限制的长度,则舍去。
        if len(tokens_a)+len(tokens_b)+3>max_len:
            continue
            
        #加上<cls>和<sep>,segments用于区token在哪个句子中
        
        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segments, is_next))
    return nsp_data_from_paragraph
  
  token和segments的例子: True表示两个句子相邻,False表示b被随机替换,a,b不相邻。
            (['<cls>', 'mushrooms', 'grow', '<unk>', 'or', 'in', '"', '<unk>', 'groups', '"', 'in', 'late', 'summer', 'and', 'throughout', 
            'autumn', ',', 'though', 'it', 'is', 'not', 'commonly', 'encountered', 'species', '<sep>', 'it',
             'can', 'be', 'found', 'in', 'europe', ',', 'asia', 'and', 'north', 'america', '.', '<sep>'], 
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1], True),
            '''


        # 获取遮蔽语言模型任务的数据
        '''
     在这里我们会将句子中单词,替换为在词典中的索引。13意思为,句子的第13个词,进行了处理,可能不变,可能替换为其他词,可能替换为mask。在这里这个词没有替换。0与1区分两个句子,False代表两个句子不相邻。
        examples中的结果;
        ([3, 2510, 31, 337, 9, 0, 6, 6891, 8, 11621, 6, 21, 11, 60, 3405, 14, 1542, 9546, 4, 2524,
         21, 185, 4421, 649, 38, 277, 2872, 13233, 4], [13], [60], 
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
         False)
        '''
        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                      + (segments, is_next))
                     for tokens, segments, is_next in examples]
                     
        #_pad_bert_inputs对数据进行填充,all_mlm_weights中1为需要预测,0为填充
   #    all_mlm_weights= tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(
            examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)

上述已经将数据处理完,最后看一下处理后的例子:

将原来的句子列表填充1,一直到到大小为64
tensor([[    3,     5,     0, 18306,    23,    11,  2659,   156,  5779,   382,
          1296,   110,   158,    22,     5,  1771,   496,     0,  3398,     2,
             5,  3496,   110,  5038,   179,     4,    16,    11, 19837,     6,
            58,    13,     5,   685,     7,    66,   156,     0,  3063,    77,
          3842,    19,     4,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1]])
segments用于区分两个句子,0为第一个句子中的词,1为第二个句子中的词,后面的0为填充
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
valid_lens表示句子列表的有效长度
tensor([43.])
pred_positions需要预测的位置,0为填充
tensor([[19,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
mlm_weights需要预测多少个词,0为填充
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
预测位置的真实标签,0为填充
tensor([[22,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
两句话是否相邻
tensor([0])

随后就是把处理好的数据,送入bert中。在 BERTEncoder 中,执行如下代码:

 def forward(self, tokens, segments, valid_lens):
        # Shape of `X` remains unchanged in the following code snippet:
        # (batch size, max sequence length, `num_hiddens`)
      #  将token和segment分别进行embedding,
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
      #加入位置编码
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

将编码完后的数据,进行多头注意力和残差化

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

将结果返回到如下代码中:其中encoded_X .shape=torch.Size([1, 64, 128]),1代表批次大小为1,我们设置的每个批次只有行文本,每行文本由64个词组成,bert提取128维的向量来表示每个词。随后进行两个任务,一个是预测被掩盖的单词,另一个为判断两个句子是否为相邻。

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # The hidden layer of the MLP classifier for next sentence prediction.
        # 0 is the index of the '<cls>' token
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

第一个任务为预测被mask的单词:

'''
例如:batch为1,X为1*64*128,其中num_pred_positions =10,batch_idx 会重复为[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],pred_positions为[ 3,  6, 10, 12, 15, 20,  0,  0,  0,  0],X[batch_idx, pred_positions]会将需要预测的向量取出。然后reshape为1*10*128的矩阵。最后连接一个mlp,经过规范化后接nn.Linear(num_hiddens, vocab_size)),会生成再vocab上的预测

'''
 def forward(self, X, pred_positions):
        num_pred_positions = pred_positions.shape[1]
        pred_positions = pred_positions.reshape(-1)
        batch_size = X.shape[0]
        batch_idx = torch.arange(0, batch_size)
        # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
        # `batch_idx` is `torch.tensor([0, 0, 0, 1, 1, 1])`
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
        masked_X = X[batch_idx, pred_positions]
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat

结束后,会返回到上层的代码中:

def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # The hidden layer of the MLP classifier for next sentence prediction.
        # 0 is the index of the '<cls>' token
        判断句子是否连续,将<cls>的向量,放入mlp中,接一个nn.Linear(num_inputs, 2),最后变成一个二分类问题。
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

后面就是计算损失:

将mlm_Y_hat进行reshap,与mlm_Y求loss,最后需要乘mlm_weights_X,将填充的无用数据进行去除。
 mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
 取平均loss
 mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
 nsp_l = loss(nsp_Y_hat, nsp_y)
 l = mlm_l + nsp_l
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/m0_46312382/article/details/132260169

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签