In [133]:
import pandas as pd
import os
import tqdm
In [134]:
# Load train.csv into a DataFrame; later cells read its rows and columns.
df = pd.read_csv('/root/.cache/data/train.csv')
In [135]:
# Preview the first rows of the loaded table.
df.head()
Out[135]:
In [136]:
# Distinct values found in the discourse_type column.
df['discourse_type'].unique()
Out[136]:
In [137]:
# Distinct values of the id column, followed by how many there are.
ids = df['id'].unique()
ids.size
Out[137]:
In [138]:
# bad case
# 2726E31ECDC6
In [143]:
# Select every row whose id equals 'FFFD0AF13501' and display the result.
an_df = df[df['id'] == 'FFFD0AF13501']
an_df
Out[143]:
In [140]:
def get_instance_with_row(df, idx: int, data_dir: str = '/root/.cache/data/train'):
    """Return one row of ``df`` as a dict, enriched with its essay text and span.

    Args:
        df: DataFrame whose rows provide at least the columns ``id``,
            ``predictionstring``, ``discourse_start`` and ``discourse_end``.
        idx: Index label of the row to fetch (looked up with ``df.loc``).
        data_dir: Directory that contains the ``<id>.txt`` essay files.
            Defaults to the location that used to be hard-coded here, so
            existing two-argument calls keep working.

    Returns:
        dict: the row's own columns plus

        * ``text`` - full content of the essay file,
        * ``start_word`` / ``end_word`` - first and last integer listed in
          ``predictionstring``,
        * ``discourse_words`` - whitespace-split words of ``text`` from
          ``start_word`` to ``end_word`` inclusive, joined with single spaces,
        * ``discourse_chars`` - ``text[discourse_start:discourse_end]``.

    Raises:
        KeyError: if ``idx`` is not in the index or a required column is missing.
        OSError: if the essay file cannot be opened.
    """
    row = df.loc[idx].to_dict()
    file_name = os.path.join(data_dir, row['id'] + '.txt')
    with open(file_name, 'r') as fp:
        text = fp.read()
    row['text'] = text

    # split() without an argument tolerates repeated or trailing whitespace
    # in predictionstring, which split(' ') would turn into empty tokens.
    pred_ls = row['predictionstring'].split()
    row['start_word'], row['end_word'] = int(pred_ls[0]), int(pred_ls[-1])

    # Word view of the span: indices refer to text.split(), end inclusive.
    words = text.split()
    row['discourse_words'] = ' '.join(words[row['start_word']:row['end_word'] + 1])

    # Character view of the span: offsets are cast to int before slicing.
    row['discourse_chars'] = text[int(row['discourse_start']):int(row['discourse_end'])]
    return row
In [142]:
# Build the enriched record for the row with index label 39597 and display it.
instance = get_instance_with_row(df, 39597)
instance
Out[142]:
In [20]:
# BigBirdTokenizerFast stays imported in case other cells still use the name.
from transformers import AutoTokenizer, BigBirdTokenizerFast

# 'allenai/longformer-large-4096' is a Longformer checkpoint; loading it through
# BigBirdTokenizerFast mismatches tokenizer class and checkpoint. AutoTokenizer
# resolves the tokenizer class from the checkpoint itself.
# NOTE(review): if BigBird was the intended model, keep BigBirdTokenizerFast and
# switch the checkpoint name to a BigBird one instead.
tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-large-4096')
def mapping_word_to_token(word_ids, word_start, word_end):
    """Translate a word-level span into token positions.

    Args:
        word_ids: Sequence giving, for each token position, the id of the
            word that token belongs to.
        word_start: Word id whose first token marks the span start.
        word_end: Word id whose last token marks the span end.

    Returns:
        tuple: ``(token_start, token_end)`` where ``token_start`` is the first
        position whose word id equals ``word_start`` and ``token_end`` is the
        last position whose word id equals ``word_end``. Either value is
        ``-1`` when the corresponding word id never occurs.
    """
    # First position that matches the start word, or -1 if there is none.
    first_pos = next(
        (pos for pos, wid in enumerate(word_ids) if wid == word_start),
        -1,
    )
    # Positions increase monotonically, so the largest match is the last one.
    last_pos = max(
        (pos for pos, wid in enumerate(word_ids) if wid == word_end),
        default=-1,
    )
    return first_pos, last_pos
In [ ]:
In [ ]:
In [10]:
# Full essay text stored on the instance dict under the 'text' key.
text = instance['text']
In [ ]:
# NOTE(review): `ss` is not defined in any visible cell, so this raises NameError
# on a fresh kernel; possibly `text` was meant - confirm or delete this cell.
ss.split()
In [ ]:
(discourse_start, discourse_end) 与 predictionstring 是什么关系?¶
In [ ]:
def stat_relation(df, data_dir: str = '/root/.cache/data/train',
                  max_rows: int = 6000000, return_errors: bool = False):
    """Count rows whose text column matches the essay slice given by the offsets.

    Each row must unpack into exactly eight values, in this order:
    ``id, discourse_id, start, end, discourse_text, type, type_num, pred_str``.
    A row counts as a match when ``discourse_text`` equals
    ``text[int(start):int(end)]``, or when both sides are equal after dropping
    the final character (``discourse_text[:-1]`` vs
    ``text[int(start):int(end) - 1]``).

    Args:
        df: DataFrame with the eight columns described above.
        data_dir: Directory containing the ``<id>.txt`` essay files. Defaults
            to the path that used to be hard-coded.
        max_rows: Non-negative upper bound on the number of rows examined, or
            ``None`` for no bound. Defaults to the limit that used to be
            hard-coded.
        return_errors: When true, also return the index labels of the rows
            that did not match.

    Returns:
        dict: ``{'char': <number of matching rows>}`` (the key is absent when
        nothing matched). With ``return_errors=True`` a tuple
        ``(stat, err)`` is returned instead, ``err`` being the list of
        non-matching index labels.
    """
    stat = {}
    err = []
    # Several rows can share one essay id; read each essay file only once.
    text_cache = {}

    limit = len(df) if max_rows is None else min(len(df), max_rows)
    # total= lets tqdm show a real progress bar for the iterrows() generator.
    for idx, item in tqdm.tqdm(df.head(limit).iterrows(), total=limit):
        id_, _discourse_id, start, end, discourse_text, _type, _type_num, _pred_str = item

        text = text_cache.get(id_)
        if text is None:
            file_name = os.path.join(data_dir, id_ + '.txt')
            with open(file_name, 'r') as fp:
                text = fp.read()
            text_cache[id_] = text

        # Compare against the character slice of the essay.
        text_with_char = text[int(start):int(end)]
        if discourse_text != text_with_char:
            # Retry with the last character dropped on both sides.
            text_with_char = text[int(start):int(end) - 1]
            if discourse_text[:-1] == text_with_char:
                stat['char'] = stat.get('char', 0) + 1
            else:
                err.append(idx)
        else:
            stat['char'] = stat.get('char', 0) + 1

    if return_errors:
        return stat, err
    return stat
# Run the character-offset check on df and display the resulting counts.
stat_relation(df)
统计predstring正确数¶
In [58]:
def stat_predstring(df, data_dir: str = '/root/.cache/data/train', max_rows: int = 50):
    """Count rows whose text column matches the word span named by the last column.

    Each row must unpack into exactly eight values, in this order:
    ``id, discourse_id, start, end, discourse_text, type, type_num, pred_str``.
    ``pred_str`` is split on single spaces; its first and last entries are
    taken as inclusive word indices into ``text.split()``. The selected words
    are joined with single spaces and compared with ``discourse_text``.
    Mismatching pairs are printed between separator lines for inspection.

    Args:
        df: DataFrame with the eight columns described above.
        data_dir: Directory containing the ``<id>.txt`` essay files. Defaults
            to the path that used to be hard-coded.
        max_rows: Non-negative upper bound on the number of rows examined, or
            ``None`` for no bound. Defaults to the limit that used to be
            hard-coded.

    Returns:
        dict: ``{'word': <number of matching rows>}`` (the key is absent when
        nothing matched).
    """
    stat = {}
    # Several rows can share one essay id; read each essay file only once.
    text_cache = {}

    limit = len(df) if max_rows is None else min(len(df), max_rows)
    # total= lets tqdm show a real progress bar for the iterrows() generator.
    for _idx, item in tqdm.tqdm(df.head(limit).iterrows(), total=limit):
        id_, _discourse_id, _start, _end, discourse_text, _type, _type_num, pred_str = item

        text = text_cache.get(id_)
        if text is None:
            file_name = os.path.join(data_dir, id_ + '.txt')
            with open(file_name, 'r') as fp:
                text = fp.read()
            text_cache[id_] = text

        pred_ls = pred_str.split(' ')
        start, end = int(pred_ls[0]), int(pred_ls[-1])
        # Word indices are inclusive on both ends, hence end + 1.
        text_pred = ' '.join(text.split()[start:end + 1])

        if discourse_text == text_pred:
            stat['word'] = stat.get('word', 0) + 1
        else:
            # Show the expected text and the reconstructed text side by side.
            print('='*200)
            print(discourse_text)
            print()
            print(text_pred)
            print('='*200)
    return stat
Comments