duee篇章级触发词模型训练预测

2024-07-11 1110阅读

import ast
import os
import json
import warnings

duee篇章级触发词模型训练预测
(图片来源网络,侵删)

import random
from functools import partial
import numpy as np
import paddle
import paddle.nn.functional as F

from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.transformers import AutoModelForTokenClassification, AutoTokenizer
from paddlenlp.metrics import ChunkEvaluator

def read_by_lines(path):
    result = list()
    with open(path, "r", encoding="utf8") as infile:
        for line in infile:
            result.append(line.strip())
    return result

def write_by_lines(path, data):
    with open(path, "w", encoding="utf8") as outfile:
        [outfile.write(d + "\n") for d in data]

def load_dict(dict_path):
    vocab = {}
    for line in open(dict_path, 'r', encoding='utf-8'):
        value, key = line.strip('\n').split('\t')
        vocab[key] = int(value)
    return vocab

num_epoch=3
learning_rate=5e-5
tag_path='./conf/DuEE-Fin/trigger_tag.dict'
train_data='./datasets/DuEE-Fin/trigger/train.tsv'
dev_data='./datasets/DuEE-Fin/trigger/dev.tsv'
test_data='./datasets/DuEE-Fin/trigger/test.tsv'
predict_data=None
warmup_proportion=0.0
batch_size=10
checkpoints='./checkpoints/Duee_extract/'
init_ckpt=None
predict_save_path=None
seed=1000
device='gpu'
weight_decay=0.0

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)

from paddlenlp.datasets import MapDataset

def load_dataset(datafiles):
    def read(data_path):
        with open(data_path, 'r', encoding='utf-8') as fp:
            next(fp)  # Skip header
            for line in fp.readlines():
                words, labels = line.strip('\n').split('\t')
                words = words.split('\002')
                labels = labels.split('\002')
                yield words, labels
    if isinstance(datafiles, str):
        return MapDataset(list(read(datafiles)))
    elif isinstance(datafiles, list) or isinstance(datafiles, tuple):
        return [MapDataset(list(read(datafile))) for datafile in datafiles]

paddle.set_device(device)

set_seed(seed)

no_entity_label = 'O'
ignore_label = -1

tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")

label_map = load_dict

VPS购买请点击我

免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们,邮箱:ciyunidc@ciyunshuju.com。本站只作为美观性配图使用,无任何非法侵犯第三方意图,一切解释权归图片著作权方,本站不承担任何责任。如有恶意碰瓷者,必当奉陪到底严惩不贷!

目录[+]