"""Generate a sample reply from Will to his son Billy with distilgpt2.

Loads the social graph from ``social_graph.json``, looks up Billy's entry,
and samples one continuation of a hard-coded first-person prompt, which is
then printed to stdout.
"""

import json

from transformers import pipeline

# Text-generation pipeline, created once at module level.
rag_pipeline = pipeline("text-generation", model="distilgpt2")

# JSON is UTF-8 by specification; do not depend on the platform's default
# text encoding when reading it.
with open("social_graph.json", "r", encoding="utf-8") as f:
    kg = json.load(f)

# NOTE(review): `context` and `query` are looked up here but are not yet fed
# into the prompt below, which is a hard-coded literal. They are kept so the
# lookup still fails loudly if Billy's entry is missing from the graph.
person = kg["people"]["billy"]
context = person["context"]

query = "What should I say to Billy?"

prompt = """I am Will, a 38-year-old father with MND (Motor Neuron Disease). I have a 7-year-old son named Billy who loves Manchester United football.

Billy just asked me: "Dad, did you see the United match last night?"

My response to Billy:"""

response = rag_pipeline(
    prompt,
    # `max_length` counts the prompt tokens as well, so the room left for the
    # reply shrinks as the prompt grows; `max_new_tokens` bounds only the
    # generated continuation.
    max_new_tokens=40,
    temperature=0.9,
    do_sample=True,
    num_return_sequences=1,
    top_p=0.92,
    top_k=50,
    # Let the pipeline drop the prompt itself instead of slicing the output
    # by len(prompt), which assumes the decoded text reproduces the prompt
    # character for character.
    return_full_text=False,
)

print("Generated response:")

generated_text = response[0]["generated_text"]
print(generated_text)
|