|
|
|
|
|
|
|
|
|
|
|
import math

import fire
import jsonlines
|
|
|
|
|
def _norm_text(text): |
|
w, *toks = text.strip().split() |
|
try: |
|
w = float(w) |
|
except Exception: |
|
toks = [w] + toks |
|
w = 1.0 |
|
return w, ' '.join(toks) |
|
|
|
|
|
def _get_inputs_from_text(text):
    """Parse one tab-separated dialogue line into per-turn weights and texts.

    The stripped line must hold the source turns (joined by ' EOS ') and the
    target turn, separated by a single tab. Each turn is passed through
    ``_norm_text`` to split off its optional leading weight.

    Returns:
        A ``(weights, inputs)`` pair of parallel lists. Every source turn is
        included. The target turn is appended only when its weight is
        non-zero.

    Raises:
        ValueError: if the stripped line does not split into exactly two
            tab-separated fields.
    """
    sources, target = text.strip().split('\t')
    parsed = [_norm_text(turn) for turn in sources.split(' EOS ')]
    target_weight, target_text = _norm_text(target)
    # A zero-weight target is left out entirely (it contributes no weight).
    if target_weight != 0:
        parsed.append((target_weight, target_text))
    weights = [weight for weight, _ in parsed]
    inputs = [turn for _, turn in parsed]
    return weights, inputs
|
|
|
|
|
# Turn separator used when writing sessions and when splitting them back apart.
_TURN_SEP = ' EOS '


def _write_sessions(reddit_path, session_path):
    """Pass 1: write one ``{"text": ...}`` record per usable raw session.

    A session is skipped when any weight returned by ``_get_inputs_from_text``
    is 0.0. Progress is printed every 10,000 input lines.
    """
    # Both files live in ``with`` blocks, so the session file is fully flushed
    # and closed before pass 2 reads it back.
    with open(reddit_path, 'r', encoding='utf-8') as reader, \
            jsonlines.open(session_path, mode='w') as writer:
        for line_no, line in enumerate(reader, start=1):
            if line_no % 10000 == 0:
                print(line_no)
            weights, inputs = _get_inputs_from_text(line)
            if 0.0 in weights:
                continue
            writer.write({'text': _TURN_SEP.join(inputs)})


def _write_examples(session_path, output_path):
    """Pass 2: expand each session into one example per response turn.

    For a session with turns t0..tn, each turn ti (i >= 1) becomes a Response.
    Its Context is t0..t(i-1) re-joined with the turn separator. Progress is
    printed every 10,000 sessions.
    """
    with open(session_path, 'r', encoding='utf-8') as fp, \
            jsonlines.open(output_path, mode='w') as writer:
        # The session counter has its own name, so the per-turn loop below
        # cannot overwrite it.
        for session_no, item in enumerate(jsonlines.Reader(fp), start=1):
            if session_no % 10000 == 0:
                print(session_no)
            # Split on the exact separator written in pass 1. A bare 'EOS'
            # would also cut words such as "VIDEOS" and leave stray spaces.
            turns = item['text'].split(_TURN_SEP)
            for i in range(1, len(turns)):
                history = _TURN_SEP.join(turns[:i])
                if not history:
                    continue
                writer.write({
                    'Context': history,
                    'Knowledge': '',
                    'Response': turns[i].strip(),
                })


def process(reddit_path,
            session_path='../data/reddit_session_level.jsonl',
            output_path='../data/reddit.jsonl'):
    """Convert a raw Reddit dialogue dump into Context/Knowledge/Response data.

    Runs in two passes:
      1. Each tab-separated line of ``reddit_path`` is parsed into turns.
         Sessions containing a 0.0 weight are dropped. The rest are written to
         ``session_path`` as ``{"text": "turn1 EOS turn2 ..."}``.
      2. Every session in ``session_path`` is expanded into one example per
         response turn and written to ``output_path``.

    Args:
        reddit_path: path of the raw UTF-8 input file.
        session_path: intermediate session-level JSON Lines file (overwritten).
        output_path: final example-level JSON Lines file (overwritten).
    """
    _write_sessions(reddit_path, session_path)
    _write_examples(session_path, output_path)
|
|
|
|
|
def main():
    """Command-line entry point: expose ``process`` through Python Fire."""
    fire.Fire(component=process)
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|