File size: 10,034 Bytes
f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e deb6f27 f5b302e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 |
import json
import random
from typing import Dict, List, Any, Optional, Tuple
from sentence_transformers import SentenceTransformer
import numpy as np
from transformers import pipeline
class SocialGraphManager:
    """Manages the social graph and provides context for the AAC system.

    The graph is loaded from a JSON object. The keys this class reads are
    ``"people"`` (person id -> person data dict), ``"topics"`` (list of
    strings) and, optionally, ``"common_utterances"`` (category -> list of
    strings).
    """

    def __init__(self, graph_path: str = "social_graph.json"):
        """Initialize the social graph manager.

        Args:
            graph_path: Path to the social graph JSON file
        """
        self.graph_path = graph_path
        self.graph = self._load_graph()
        # Text -> embedding cache. Created before the model is loaded so the
        # attribute always exists, even when semantic matching is unavailable.
        self.embeddings_cache: Dict[str, Any] = {}
        # Initialize sentence transformer for semantic matching
        try:
            self.sentence_model = SentenceTransformer(
                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
            )
            self._initialize_embeddings()
        except Exception as e:
            # Semantic matching is optional: report the problem and fall back
            # to random phrase selection instead of failing construction.
            print(f"Error loading sentence model: {e}")
            self.sentence_model = None

    def _load_graph(self) -> Dict[str, Any]:
        """Load the social graph from the JSON file.

        Returns:
            The parsed graph, or an empty graph if the file is missing,
            unreadable, not valid UTF-8 JSON, or not a JSON object.
        """
        try:
            with open(self.graph_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            data = None
        if not isinstance(data, dict):
            # The rest of the class calls self.graph.get(...), so anything
            # other than a JSON object is treated as "no graph".
            return {"people": {}, "places": [], "topics": []}
        return data

    def _get_embedding(self, text: str) -> np.ndarray:
        """Return the embedding for ``text``, encoding and caching it on a miss.

        Requires ``self.sentence_model`` to be set.
        """
        embedding = self.embeddings_cache.get(text)
        if embedding is None:
            embedding = np.asarray(self.sentence_model.encode(text))
            self.embeddings_cache[text] = embedding
        return embedding

    def _initialize_embeddings(self):
        """Pre-compute embeddings for topics, common phrases and utterances."""
        if not self.sentence_model:
            return
        texts: List[str] = list(self.graph.get("topics", []))
        for person_data in self.graph.get("people", {}).values():
            texts.extend(person_data.get("common_phrases", []))
        for utterances in self.graph.get("common_utterances", {}).values():
            texts.extend(utterances)
        for text in texts:
            self._get_embedding(text)

    def _resolve_person_id(self, person_id: str) -> Optional[str]:
        """Map a person id or a "Name (role)" display label to a person id.

        Returns:
            The matching person id, or None if nobody matches.
        """
        people = self.graph.get("people", {})
        if person_id in people:
            return person_id
        # Accept the display form, e.g. "Emma (wife)".
        for pid, pdata in people.items():
            label = f"{pdata.get('name', '')} ({pdata.get('role', '')})"
            if label == person_id:
                return pid
        return None

    def get_people_list(self) -> List[Dict[str, str]]:
        """Get a list of people from the social graph with their names and roles."""
        return [
            {
                "id": person_id,
                "name": person_data.get("name", person_id),
                "role": person_data.get("role", ""),
            }
            for person_id, person_data in self.graph.get("people", {}).items()
        ]

    def get_person_context(self, person_id: str) -> Dict[str, Any]:
        """Get context information for a specific person.

        Args:
            person_id: A person id, or a "Name (role)" display label.

        Returns:
            A shallow copy of the person's data, or ``{}`` if not found. The
            copy lets callers add keys (e.g. a selected topic) without
            mutating the shared graph.
        """
        resolved = self._resolve_person_id(person_id)
        if resolved is None:
            return {}
        return dict(self.graph["people"][resolved])

    def get_relevant_phrases(
        self, person_id: str, user_input: Optional[str] = None
    ) -> List[str]:
        """Get up to three relevant phrases for a person.

        Args:
            person_id: A person id, or a "Name (role)" display label.
            user_input: Text to rank phrases against. If it is empty or no
                sentence model is loaded, a random sample is returned instead.

        Returns:
            At most three phrases, ranked by cosine similarity to
            ``user_input`` when semantic ranking is available.
        """
        resolved = self._resolve_person_id(person_id)
        if resolved is None:
            return []
        phrases = self.graph["people"][resolved].get("common_phrases", [])
        # If no user input (or no model), return random phrases
        if not user_input or not self.sentence_model:
            return random.sample(phrases, min(3, len(phrases)))
        # Use semantic search to find relevant phrases
        user_embedding = np.asarray(self.sentence_model.encode(user_input))
        user_norm = np.linalg.norm(user_embedding)
        scored = []
        for phrase in phrases:
            phrase_embedding = self._get_embedding(phrase)
            denominator = user_norm * np.linalg.norm(phrase_embedding)
            # Guard against zero-norm vectors, which would yield NaN.
            similarity = (
                float(np.dot(user_embedding, phrase_embedding) / denominator)
                if denominator
                else 0.0
            )
            scored.append((phrase, similarity))
        # Stable sort: ties keep their original order in the graph.
        scored.sort(key=lambda item: item[1], reverse=True)
        return [phrase for phrase, _ in scored[:3]]

    def get_common_utterances(self, category: Optional[str] = None) -> List[str]:
        """Get common utterances, optionally filtered by category.

        Args:
            category: If given and present in the graph, return a copy of
                every utterance in that category.

        Returns:
            The category's utterances. Otherwise, up to two randomly sampled
            utterances from each category.
        """
        groups = self.graph.get("common_utterances", {})
        if category and category in groups:
            # Copy so callers cannot mutate the graph through the result.
            return list(groups[category])
        # If no (known) category specified, return a sample from each category
        sampled: List[str] = []
        for group in groups.values():
            sampled.extend(random.sample(group, min(2, len(group))))
        return sampled
class SuggestionGenerator:
    """Generates contextual suggestions for the AAC system."""

    def __init__(self, model_name: str = "distilgpt2"):
        """Initialize the suggestion generator.

        Args:
            model_name: Name of the HuggingFace model to use
        """
        self.model_name = model_name
        self.model_loaded = False
        # Defined up front so the attribute exists even if loading fails.
        self.generator = None
        try:
            print(f"Loading model: {model_name}")
            # Use a simpler approach with a pre-built pipeline
            self.generator = pipeline("text-generation", model=model_name)
            self.model_loaded = True
            print(f"Model loaded successfully: {model_name}")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.model_loaded = False
        # Fallback responses if model fails to load or generate
        self.fallback_responses = [
            "I'm not sure how to respond to that.",
            "That's interesting. Tell me more.",
            "I'd like to talk about that further.",
            "I appreciate you sharing that with me.",
        ]

    def test_model(self) -> str:
        """Test if the model is working correctly.

        Returns:
            A human-readable status string describing the test outcome.
        """
        if not self.model_loaded:
            return "Model not loaded"
        try:
            test_prompt = "I am Will. My son Billy asked about football. I respond:"
            print(f"Testing model with prompt: {test_prompt}")
            response = self.generator(test_prompt, max_length=30, do_sample=True)
            result = response[0]["generated_text"][len(test_prompt) :]
            print(f"Test response: {result}")
            return f"Model test successful: {result}"
        except Exception as e:
            print(f"Error testing model: {e}")
            return f"Model test failed: {str(e)}"

    @staticmethod
    def _build_prompt(person_context: Dict[str, Any], user_input: Optional[str]) -> str:
        """Build the generation prompt from the person context and user input."""
        name = person_context.get("name", "")
        role = person_context.get("role", "")
        topics = ", ".join(person_context.get("topics", []))
        context = person_context.get("context", "")
        selected_topic = person_context.get("selected_topic", "")

        prompt = (
            "I am Will, a person with MND (Motor Neuron Disease).\n"
            f"I'm talking to {name}, who is my {role}.\n"
        )
        if context:
            prompt += f"Context: {context}\n"
        if topics:
            prompt += f"Topics of interest: {topics}\n"
        if selected_topic:
            prompt += f"We're currently talking about: {selected_topic}\n"
        if user_input:
            prompt += f'\n{name} just said to me: "{user_input}"\n'
        prompt += "\nMy response:"
        return prompt

    def generate_suggestion(
        self,
        person_context: Dict[str, Any],
        user_input: Optional[str] = None,
        max_length: int = 50,
        temperature: float = 0.7,
    ) -> str:
        """Generate a contextually appropriate suggestion.

        Args:
            person_context: Context information about the person
            user_input: Optional user input to consider
            max_length: Maximum number of new tokens to generate for the
                suggestion (the prompt is not counted)
            temperature: Controls randomness in generation (higher = more random)

        Returns:
            The generated suggestion. A fallback response is returned when the
            model is not loaded or produces only whitespace, and an error
            message is returned if generation raises.
        """
        if not self.model_loaded:
            # Use fallback responses if model isn't loaded
            print("Model not loaded, using fallback responses")
            return random.choice(self.fallback_responses)

        prompt = self._build_prompt(person_context, user_input)

        try:
            print(f"Generating suggestion with prompt: {prompt}")
            response = self.generator(
                prompt,
                # Bound only the continuation. The previous
                # max_length=len(prompt.split()) + max_length mixed word and
                # token counts and could fall below the prompt's token length.
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.92,
                top_k=50,
                # Return only the generated text. Slicing off len(prompt)
                # chars assumed decoding reproduces the prompt exactly.
                return_full_text=False,
            )
            result = response[0]["generated_text"].strip()
            print(f"Generated response: {result}")
        except Exception as e:
            print(f"Error generating suggestion: {e}")
            return "Could not generate a suggestion. Please try again."

        if not result:
            # An empty suggestion is useless to the user; fall back instead.
            return random.choice(self.fallback_responses)
        return result
|