|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
|
import jsonlines |
|
import json |
|
import copy |
|
import glob |
|
import random |
|
import fire |
|
|
|
|
|
class Converter(ABC):
    """Template-method base class for dataset converters.

    Subclasses implement :meth:`process` to read ``self.filepath`` and emit
    a converted dataset; :meth:`convert` wraps that with start/end logging.
    """

    def __init__(self, filepath) -> None:
        super().__init__()
        # Path to the raw dataset this converter will read.
        self.filepath = filepath

    def convert(self):
        """Run the full conversion: log start, run process(), log finish."""
        self.start()
        self.process()
        self.end()

    def start(self):
        """Log that processing of this converter's file has begun."""
        print(f'Start processing {self.__class__.__name__} at {self.filepath}')

    def end(self):
        """Log that processing of this converter's file has finished."""
        print(
            f'Finish processing {self.__class__.__name__} at {self.filepath}')

    @abstractmethod
    def process(self):
        """
        Implement your convert logics in this function
        """
|
|
|
|
|
class DSTC7Converter(Converter):
    '''
    Converter class for DSTC7 Grounded response generation
    '''

    def process(self):
        """Convert the DSTC7 tab-separated file to ../data/dstc7.jsonl.

        Each input line holds six tab-separated fields: hash, conversation
        id, score, facts, context, response. Conversations whose context is
        only the START token are skipped.
        """
        examples = []
        # `with` closes the input file deterministically (original leaked it).
        with open(self.filepath) as convs:
            for conv in convs:
                _, c_id, score, facts, context, response = conv.split('\t')
                if context.strip() == 'START':
                    # No real context before the first response -- skip.
                    continue
                context = context.replace('START EOS TIL ', '')
                # Dict is built fresh per line, so no deepcopy is needed.
                examples.append({
                    'Context': context.strip(),
                    # Strip the pseudo-HTML paragraph markers from the facts.
                    'Knowledge': facts.replace(
                        ' < p > ', '').replace(' < /p > ', '').strip(),
                    'Response': response.strip(),
                })

        with jsonlines.open('../data/dstc7.jsonl', mode='w') as writer:
            for example in examples:
                writer.write(example)
|
|
|
|
|
class MSMARCOConverter(Converter):
    '''
    Converter class for MS MARCO
    '''

    def process(self):
        """Convert the MS MARCO train json to ../data/msmarco.jsonl.

        Knowledge is the concatenation of the passages annotated as
        selected; Response joins the answer list into one string.
        """
        # `with` closes the input file deterministically (original leaked it).
        with open(self.filepath) as f:
            train_data = json.load(f)

        examples = []
        for ids in train_data['query'].keys():
            query = train_data['query'][ids]
            answer = train_data['answers'][ids]
            passage = train_data['passages'][ids]
            # Keep only the passages a judge marked as actually used.
            knowledge = [p['passage_text'] for p in passage if p['is_selected']]
            # Dict is built fresh per query, so no deepcopy is needed.
            examples.append({
                'Context': query.strip(),
                'Knowledge': ' '.join(knowledge),
                'Response': ' '.join(answer).strip(),
            })

        with jsonlines.open('../data/msmarco.jsonl', mode='w') as writer:
            for example in examples:
                writer.write(example)
|
|
|
|
|
class UnifiedQAConverter(Converter):
    '''
    Converter class for the UnifiedQA collection of QA datasets
    '''

    def process(self):
        """Convert every train.tsv/test.tsv under filepath to ../data/unifiedqa.jsonl.

        Each TSV line is "question<TAB>answer" where the question part embeds
        the grounding story after a literal '\\n' separator. Lines with the
        wrong number of fields are skipped.
        """
        examples = []
        for fname in glob.glob(f'{self.filepath}/*/*'):
            if 'train.tsv' not in fname and 'test.tsv' not in fname:
                continue
            # `with` closes each dataset file (the original leaked them all).
            with open(fname) as data:
                for line in data:
                    line = line.strip()
                    try:
                        question, answer = line.split('\t')
                        question, story = question.split('\\n')
                    except ValueError:
                        # Wrong field count -- skip the malformed line.
                        # (The original's bare `except:` also swallowed the
                        # NameError raised by an uninitialised `k += 1`
                        # counter on every line; that dead counter is gone.)
                        continue
                    examples.append({
                        'Context': question,
                        'Response': answer,
                        'Knowledge': story,
                    })

        # Context manager flushes/closes the writer (the original never did).
        with jsonlines.open('../data/unifiedqa.jsonl', mode='w') as writer:
            for example in examples:
                writer.write(example)
|
|
|
|
|
class SGDConverter(Converter):
    '''
    Converter class for SGD dataset
    '''

    def process(self):
        """Flatten SGD dialogues into (Context, Knowledge, Response) examples.

        For every SYSTEM turn: Context is the ' EOS '-joined history so far,
        Knowledge is the belief state rendered as
        'belief : <service> slot = value ; ...', and Response is the system
        utterance with annotated slot-value spans replaced by '[slot_name]'
        placeholders. All three splits (train/dev/test) go to
        ../data/sgd.jsonl.
        """
        examples = []
        for split in ['train', 'dev', 'test']:
            # Map service_name -> schema record for lookup below.
            # NOTE(review): these json files are opened without being closed.
            schema_info = json.load(
                open(f'{self.filepath}/{split}/schema.json'))
            schema_info = dict([(i['service_name'], i) for i in schema_info])
            for file in glob.glob(f'{self.filepath}/{split}/dialogues_*.json'):
                data = json.load(open(file))
                for dialogue in data:
                    dialogue_id = dialogue['dialogue_id']
                    # NOTE(review): only the first service of a (possibly
                    # multi-service) dialogue is used for the schema lookup.
                    services = dialogue['services'][0]
                    schema = schema_info[services]
                    description = schema['description']
                    task_slots = [s['name'] for s in schema['slots']]
                    task_intents = [s['name'] for s in schema['intents']]
                    task_intents_description = [
                        s['description'] for s in schema['intents']]
                    turns = dialogue['turns']
                    history = []
                    example = {}
                    for idx, turn in enumerate(turns):
                        if idx == 0:
                            # Dialogues are expected to open with a user turn,
                            # so slot_values_str is always set before the
                            # first SYSTEM turn reads it.
                            assert turn['speaker'] == 'USER'
                        frame = turn['frames'][0]
                        # Service prefix, e.g. 'Restaurants_1' -> 'restaurants'.
                        service = turn['frames'][0]['service'].split('_')[
                            0].lower()
                        if turn['speaker'] == 'USER':
                            user_utter = turn['utterance']
                            history.append(f'{user_utter}')
                            # Render the dialogue state as 'slot = value'
                            # pairs (first value only); carried over to the
                            # following SYSTEM turn.
                            belief_slot_values = frame['state']['slot_values']
                            slot_values_list = []
                            for slot_value in belief_slot_values.items():
                                slot, values = slot_value
                                value = values[0]
                                slot_values_list.append(f'{slot} = {value}')
                            slot_values_str = ' ; '.join(slot_values_list)

                        else:
                            # SYSTEM turn: emit one example using the belief
                            # string built on the preceding USER turn.
                            sys_utter = copy.copy(turn['utterance'])
                            slot_values_str = f'belief : {service} {slot_values_str}'

                            # Delexicalise the utterance in two passes: first
                            # overwrite each annotated value span with a
                            # unique repeated-digit pattern (keeps span
                            # lengths unchanged), then swap each pattern for
                            # its '[slot_name]' placeholder.
                            slots = frame['slots']
                            offset = 0
                            len_ = len(sys_utter)
                            candidates = []
                            # NOTE(review): this inner `idx` shadows the turn
                            # index above; harmless here because the turn
                            # index is only read at the top of each turn
                            # iteration, but worth renaming.
                            for idx, slot_info in enumerate(slots):
                                start, end, slot_name = slot_info['start'], slot_info['exclusive_end'], slot_info['slot']
                                sys_utter = sys_utter[:start+offset] + str(
                                    idx) * (end - start) + sys_utter[end+offset:]
                                candidates.append(
                                    (slot_name, str(idx) * (end - start)))
                            for idx, info in enumerate(candidates):
                                slotname, target = info
                                sys_utter = sys_utter.replace(
                                    target, f'[{slotname}]')

                            reply = f'{sys_utter}'
                            example['Context'] = ' EOS '.join(history)
                            example['Knowledge'] = slot_values_str
                            example['Response'] = reply
                            examples.append(copy.deepcopy(example))
                            history.append(reply)

        # NOTE(review): writer is never closed/flushed explicitly; output
        # relies on interpreter shutdown.
        train_writer = jsonlines.open('../data/sgd.jsonl', mode='w')
        for i in examples:
            train_writer.write(i)

        return
|
|
|
|
|
def merge_and_split():
    """Merge the four converted datasets and split them into train/valid.

    Reads the jsonl files produced by the converters (in a fixed order) and
    routes each example to ../data/grounded_data_valid.jsonl with
    probability 0.01, otherwise to ../data/grounded_data_train.jsonl.
    The RNG is seeded so the split is reproducible across runs.
    """
    examples = []
    # One loop replaces four copy-pasted read blocks; order is preserved.
    for filepath in ['../data/dstc7.jsonl', '../data/msmarco.jsonl',
                     '../data/sgd.jsonl', '../data/unifiedqa.jsonl']:
        with open(filepath, "r", encoding="utf-8") as reader:
            for item in jsonlines.Reader(reader):
                examples.append(item)

    random.seed(2021)
    # Context managers flush/close both writers (the original never did).
    with jsonlines.open('../data/grounded_data_train.jsonl', mode='w') as train_writer, \
            jsonlines.open('../data/grounded_data_valid.jsonl', mode='w') as valid_writer:
        for example in examples:
            if random.random() < 0.01:
                valid_writer.write(example)
            else:
                train_writer.write(example)

    print('Done!')
|
|
|
|
|
def process(
    msmarco_path,
    sgd_path,
    dstc7_path,
    unified_qa_path
):
    """Run every dataset converter against its source location.

    Paths point at the raw downloads; each converter writes its own
    jsonl file under ../data/.
    """
    converters = (
        MSMARCOConverter(f'{msmarco_path}/train_v2.1.json'),
        SGDConverter(f'{sgd_path}'),
        DSTC7Converter(f'{dstc7_path}'),
        UnifiedQAConverter(unified_qa_path),
    )
    for converter in converters:
        converter.convert()
|
|
|
|
|
def main():
    """CLI entry point: run the converters, then build the train/valid split."""
    # fire exposes process()'s parameters as command-line arguments.
    fire.Fire(process)

    merge_and_split()
|
|
|
|
|
# Run the CLI only when executed as a script, not on import.
if __name__ == '__main__':

    main()
|
|