|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
import json |
|
|
|
|
|
# Build the generation pipeline once at module load.
# flan-t5-base is a small instruction-tuned seq2seq model; the
# "text2text-generation" task wraps encode -> generate -> decode.
model_name = "google/flan-t5-base"

task = "text2text-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Reuse the already-loaded model/tokenizer rather than letting
# pipeline() download and construct its own copies.
rag_pipeline = pipeline(task, tokenizer=tokenizer, model=model)
|
|
|
|
|
# Load the social knowledge graph used to personalize prompts.
# Explicit UTF-8: json files are UTF-8 by spec, and the platform default
# encoding (e.g. cp1252 on Windows) would otherwise be used.
with open("social_graph.json", "r", encoding="utf-8") as f:
    kg = json.load(f)
|
|
|
|
|
# Look up Bob's node in the graph.
# NOTE(review): `person` was fetched but never used — the context below is
# hard-coded. Ideally the context should be derived from `person`'s fields;
# the schema of kg["people"] is not visible here, so it is left as a lookup
# plus a static summary. TODO: build `context` from `person` once the
# schema is confirmed.
person = kg["people"]["bob"]

# Plain string (no placeholders), so no f-prefix needed.
context = "Bob is the user's son. They talk about football weekly. Last conversation was about coaching changes."

query = "What should I say to Bob?"

# Include the user's actual question in the prompt — previously `query`
# was assigned but never used, so the model never saw it.
prompt = f"""Context: {context}

Question: {query}

User wants to say something appropriate to Bob. Suggest a phrase:"""

# max_length bounds the total generated sequence; for newer transformers
# versions, max_new_tokens is the preferred, unambiguous parameter.
response = rag_pipeline(prompt, max_length=50)
print(response[0]["generated_text"])
|
|