import json
import random
from typing import Any, Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import pipeline


class SocialGraphManager:
    """Manages the social graph and provides context for the AAC system."""

    def __init__(self, graph_path: str = "social_graph.json"):
        """Initialize the social graph manager.

        Args:
            graph_path: Path to the social graph JSON file
        """
        self.graph_path = graph_path
        self.graph = self._load_graph()

        # Initialize sentence transformer for semantic matching
        self.embeddings_cache = {}
        try:
            self.sentence_model = SentenceTransformer(
                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
            )
            self._initialize_embeddings()
        except Exception as e:
            print(f"Warning: could not load sentence transformer model: {e}")
            self.sentence_model = None

    def _load_graph(self) -> Dict[str, Any]:
        """Load the social graph from the JSON file."""
        try:
            with open(self.graph_path, "r") as f:
                return json.load(f)
        except Exception as e:
            print(f"Warning: could not load {self.graph_path}: {e}")
            return {"people": {}, "places": [], "topics": []}

    def _initialize_embeddings(self):
        """Initialize embeddings for topics and phrases in the social graph."""
        if not self.sentence_model:
            return

        # Create embeddings for topics
        topics = self.graph.get("topics", [])
        for topic in topics:
            if topic not in self.embeddings_cache:
                self.embeddings_cache[topic] = self.sentence_model.encode(topic)

        # Create embeddings for common phrases
        for person_data in self.graph.get("people", {}).values():
            for phrase in person_data.get("common_phrases", []):
                if phrase not in self.embeddings_cache:
                    self.embeddings_cache[phrase] = self.sentence_model.encode(phrase)

        # Create embeddings for common utterances
        for utterances in self.graph.get("common_utterances", {}).values():
            for utterance in utterances:
                if utterance not in self.embeddings_cache:
                    self.embeddings_cache[utterance] = self.sentence_model.encode(
                        utterance
                    )

    def get_people_list(self) -> List[Dict[str, str]]:
        """Get a list of people from the social graph with their names and roles."""
        people = []
        for person_id, person_data in self.graph.get("people", {}).items():
            people.append(
                {
                    "id": person_id,
                    "name": person_data.get("name", person_id),
                    "role": person_data.get("role", ""),
                }
            )
        return people

    def get_person_context(self, person_id: str) -> Dict[str, Any]:
        """Get context information for a specific person."""
        # Check if the person_id contains a display name (e.g., "Emma (wife)")
        # and try to extract the actual ID
        if person_id not in self.graph.get("people", {}):
            # Try to find the person by name
            for pid, pdata in self.graph.get("people", {}).items():
                name = pdata.get("name", "")
                role = pdata.get("role", "")
                if f"{name} ({role})" == person_id:
                    person_id = pid
                    break

        # If still not found, return empty dict
        if person_id not in self.graph.get("people", {}):
            return {}

        person_data = self.graph["people"][person_id]
        return person_data

    def get_relevant_phrases(
        self, person_id: str, user_input: Optional[str] = None
    ) -> List[str]:
        """Get relevant phrases for a specific person based on user input."""
        if person_id not in self.graph.get("people", {}):
            return []

        person_data = self.graph["people"][person_id]
        phrases = person_data.get("common_phrases", [])

        # Without user input (or a loaded model), fall back to a random sample
        if not user_input or not self.sentence_model:
            return random.sample(phrases, min(3, len(phrases)))

        # Use semantic search to find relevant phrases
        user_embedding = self.sentence_model.encode(user_input)
        phrase_scores = []

        for phrase in phrases:
            if phrase in self.embeddings_cache:
                phrase_embedding = self.embeddings_cache[phrase]
            else:
                phrase_embedding = self.sentence_model.encode(phrase)
                self.embeddings_cache[phrase] = phrase_embedding

            # Cosine similarity: dot product normalized by the vector norms,
            # yielding a score in [-1, 1]
            similarity = np.dot(user_embedding, phrase_embedding) / (
                np.linalg.norm(user_embedding) * np.linalg.norm(phrase_embedding)
            )
            phrase_scores.append((phrase, similarity))

        # Sort by similarity score and return top phrases
        phrase_scores.sort(key=lambda x: x[1], reverse=True)
        return [phrase for phrase, _ in phrase_scores[:3]]

    def get_common_utterances(self, category: Optional[str] = None) -> List[str]:
        """Get common utterances from the social graph, optionally filtered by category."""
        utterances = []

        if "common_utterances" not in self.graph:
            return utterances

        if category and category in self.graph["common_utterances"]:
            return self.graph["common_utterances"][category]

        # If no category specified, return a sample from each category
        for category_utterances in self.graph["common_utterances"].values():
            utterances.extend(
                random.sample(category_utterances, min(2, len(category_utterances)))
            )

        return utterances


class SuggestionGenerator:
    """Generates contextual suggestions for the AAC system."""

    def __init__(self, model_name: str = "distilgpt2"):
        """Initialize the suggestion generator.

        Args:
            model_name: Name of the HuggingFace model to use
        """
        self.model_name = model_name
        self.model_loaded = False

        try:
            print(f"Loading model: {model_name}")
            # Use a simpler approach with a pre-built pipeline
            self.generator = pipeline("text-generation", model=model_name)
            self.model_loaded = True
            print(f"Model loaded successfully: {model_name}")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.model_loaded = False

        # Fallback responses if model fails to load or generate
        self.fallback_responses = [
            "I'm not sure how to respond to that.",
            "That's interesting. Tell me more.",
            "I'd like to talk about that further.",
            "I appreciate you sharing that with me.",
        ]

    def test_model(self) -> str:
        """Test if the model is working correctly."""
        if not self.model_loaded:
            return "Model not loaded"

        try:
            test_prompt = "I am Will. My son Billy asked about football. I respond:"
            print(f"Testing model with prompt: {test_prompt}")
            response = self.generator(test_prompt, max_new_tokens=20, do_sample=True)
            result = response[0]["generated_text"][len(test_prompt) :]
            print(f"Test response: {result}")
            return f"Model test successful: {result}"
        except Exception as e:
            print(f"Error testing model: {e}")
            return f"Model test failed: {str(e)}"

    def generate_suggestion(
        self,
        person_context: Dict[str, Any],
        user_input: Optional[str] = None,
        max_length: int = 50,
        temperature: float = 0.7,
    ) -> str:
        """Generate a contextually appropriate suggestion.

        Args:
            person_context: Context information about the person
            user_input: Optional user input to consider
            max_length: Maximum number of new tokens to generate for the suggestion
            temperature: Controls randomness in generation (higher = more random)

        Returns:
            A generated suggestion string
        """
        if not self.model_loaded:
            # Use a canned fallback response if the model isn't loaded
            print("Model not loaded, using fallback responses")
            return random.choice(self.fallback_responses)

        # Extract context information
        name = person_context.get("name", "")
        role = person_context.get("role", "")
        topics = ", ".join(person_context.get("topics", []))
        context = person_context.get("context", "")
        selected_topic = person_context.get("selected_topic", "")

        # Build prompt
        prompt = f"""I am Will, a person with MND (Motor Neuron Disease).
I'm talking to {name}, who is my {role}.
"""

        if context:
            prompt += f"Context: {context}\n"

        if topics:
            prompt += f"Topics of interest: {topics}\n"

        if selected_topic:
            prompt += f"We're currently talking about: {selected_topic}\n"

        if user_input:
            prompt += f'\n{name} just said to me: "{user_input}"\n'

        prompt += "\nMy response:"

        # Generate suggestion
        try:
            print(f"Generating suggestion with prompt: {prompt}")
            response = self.generator(
                prompt,
                max_new_tokens=max_length,  # limit applies to generated tokens, not the prompt
                temperature=temperature,
                do_sample=True,
                top_p=0.92,
                top_k=50,
            )
            # Extract only the generated part, not the prompt
            result = response[0]["generated_text"][len(prompt) :]
            print(f"Generated response: {result}")
            return result.strip()
        except Exception as e:
            print(f"Error generating suggestion: {e}")
            return "Could not generate a suggestion. Please try again."