willwade committed
Commit b929813 · 1 Parent(s): deb6f27

fixing demo

Files changed (1)
  1. demo.py +706 -28
demo.py CHANGED
@@ -1,40 +1,718 @@
  from transformers import pipeline
  import json

- # Load model

- # Use a simpler approach with a pre-built pipeline
- rag_pipeline = pipeline("text-generation", model="distilgpt2")

- # Load KG
- with open("social_graph.json", "r") as f:
-     kg = json.load(f)

- # Build context
- person = kg["people"]["billy"]  # Using Billy instead of Bob
- context = person["context"]

- # User input
- query = "What should I say to Billy?"

- # RAG-style prompt
- prompt = """I am Will, a 38-year-old father with MND (Motor Neuron Disease). I have a 7-year-old son named Billy who loves Manchester United football.

- Billy just asked me: "Dad, did you see the United match last night?"

- My response to Billy:"""

- # Generate
- response = rag_pipeline(
      prompt,
-     max_length=100,  # Longer output
-     temperature=0.9,  # More creative
-     do_sample=True,
-     num_return_sequences=1,
-     top_p=0.92,  # More focused sampling
-     top_k=50,  # Limit vocabulary
- )
- print("Generated response:")
- # For text-generation models, we need to extract just the generated part (not the prompt)
- generated_text = response[0]["generated_text"][len(prompt) :]
- print(generated_text)
  from transformers import pipeline
  import json
+ import argparse
+ import os
+ import sys
+ import subprocess
+ import requests
+ from typing import List, Dict, Any, Optional, Union
+ import time


+ # Check for Hugging Face token
+ def check_hf_token():
+     """Check if a Hugging Face token is properly set up."""
+     token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
+
+     if not token:
+         print("\nWarning: No Hugging Face token found in environment variables.")
+         print(
+             "To use gated models like Gemma, you need to set up a token with the right permissions."
+         )
+         print("1. Create a token at https://huggingface.co/settings/tokens")
+         print("2. Make sure to enable 'Access to public gated repositories'")
+         print("3. Set it as an environment variable:")
+         print("   export HUGGING_FACE_HUB_TOKEN=your_token_here")
+         return False
+
+     return True


+ def load_social_graph(file_path="social_graph.json"):
+     """Load the social graph from a JSON file."""
+     with open(file_path, "r") as f:
+         return json.load(f)


+ def get_person_info(social_graph, person_id):
+     """Get information about a person from the social graph."""
+     if person_id in social_graph["people"]:
+         return social_graph["people"][person_id]
+     else:
+         available_people = ", ".join(social_graph["people"].keys())
+         raise ValueError(
+             f"Person '{person_id}' not found in social graph. Available people: {available_people}"
+         )

+
+ def build_enhanced_prompt(social_graph, person_id, topic=None, user_message=None):
+     """Build an enhanced prompt using social graph information."""
+     # Get AAC user information
+     aac_user = social_graph["aac_user"]
+
+     # Get conversation partner information
+     person = get_person_info(social_graph, person_id)
+
+     # Start building the prompt with AAC user information
+     prompt = f"""I am {aac_user['name']}, a {aac_user['age']}-year-old with MND (Motor Neuron Disease) from {aac_user['location']}.
+ {aac_user['background']}
+
+ My communication needs: {aac_user['communication_needs']}
+
+ I am talking to {person['name']}, who is my {person['role']}.
+ About {person['name']}: {person['context']}
+ We typically talk about: {', '.join(person['topics'])}
+ We communicate {person['frequency']}.
+ """
+
+     # Add places information if available
+     if "places" in social_graph:
+         relevant_places = social_graph["places"][
+             :3
+         ]  # Just use a few places for context
+         prompt += f"\nPlaces important to me: {', '.join(relevant_places)}\n"
+
+     # Add communication style based on relationship
+     if person["role"] in ["wife", "son", "daughter", "mother", "father"]:
+         prompt += "I communicate with my family in a warm, loving way, sometimes using inside jokes.\n"
+     elif person["role"] in ["doctor", "therapist", "nurse"]:
+         prompt += (
+             "I communicate with healthcare providers in a direct, informative way.\n"
+         )
+     elif person["role"] in ["best mate", "friend"]:
+         prompt += "I communicate with friends casually, often with humor and sometimes swearing.\n"
+     elif person["role"] in ["work colleague", "boss"]:
+         prompt += "I communicate with colleagues professionally but still friendly.\n"
+
+     # Add common utterances by category if available
+     if "common_utterances" in social_graph:
+         # Try to find relevant utterance category based on topic
+         utterance_category = None
+         if topic == "football" or topic == "sports":
+             utterance_category = "sports_talk"
+         elif topic == "programming" or topic == "tech news":
+             utterance_category = "tech_talk"
+         elif topic in ["family plans", "children's activities"]:
+             utterance_category = "family_talk"
+
+         # Add relevant utterances if category exists
+         if (
+             utterance_category
+             and utterance_category in social_graph["common_utterances"]
+         ):
+             utterances = social_graph["common_utterances"][utterance_category][:2]
+             prompt += f"\nI might say things like: {' or '.join(utterances)}\n"
+
+     # Add topic information if provided
+     if topic and topic in person["topics"]:
+         prompt += f"\nWe are currently discussing {topic}.\n"
+
+         # Add specific context about this topic with this person
+         if topic == "football" and "Manchester United" in person["context"]:
+             prompt += (
+                 "We both support Manchester United and often discuss recent matches.\n"
+             )
+         elif topic == "programming" and "software developer" in person["context"]:
+             prompt += (
+                 "We both work in software development and share technical interests.\n"
+             )
+         elif topic == "family plans" and person["role"] in ["wife", "husband"]:
+             prompt += "We make family decisions together, considering my condition.\n"
+         elif topic == "old scout adventures" and person["role"] == "best mate":
+             prompt += "We often reminisce about our Scout camping trips in South East London.\n"
+         elif topic == "cycling" and "cycling" in person["context"]:
+             prompt += "I miss being able to cycle but enjoy talking about past cycling adventures.\n"
+
+         # Add shared experiences based on relationship and topic
+         if person["role"] == "best mate" and topic in ["football", "pub quizzes"]:
+             prompt += (
+                 "We've watched many matches together and done countless pub quizzes.\n"
+             )
+         elif person["role"] == "wife" and topic in ["family plans", "weekend outings"]:
+             prompt += "Emma has been amazing at keeping family life as normal as possible despite my condition.\n"
+         elif person["role"] == "son" and topic == "football":
+             prompt += "I try to stay engaged with Billy's football enthusiasm even as my condition progresses.\n"
+
+     # Add the user's message if provided
+     if user_message:
+         prompt += f"\n{person['name']} just said to me: \"{user_message}\"\n"
+     else:
+         # Use a common phrase from the person if no message is provided
+         if person["common_phrases"]:
+             default_message = person["common_phrases"][0]
+             prompt += f"\n{person['name']} just said to me: \"{default_message}\"\n"
+
+     # Add the response prompt with specific guidance
+     prompt += f"""
+ I want to respond to {person['name']} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said. I'll use casual language with some humor since we're close friends.
+
+ My response to {person['name']}:"""
+
+     return prompt
+
+
+ class LLMInterface:
+     """Base interface for language model generation."""
+
+     def __init__(self, model_name, max_length=150, temperature=0.9):
+         """Initialize the LLM interface.
+
+         Args:
+             model_name: Name or path of the model
+             max_length: Maximum length of generated text
+             temperature: Controls randomness (higher = more random)
+         """
+         self.model_name = model_name
+         self.max_length = max_length
+         self.temperature = temperature
+
+     def generate(self, prompt, num_responses=3):
+         """Generate responses for the given prompt.
+
+         Args:
+             prompt: The prompt to generate responses for
+             num_responses: Number of responses to generate
+
+         Returns:
+             A list of generated responses
+         """
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def cleanup_response(self, text):
+         """Clean up a generated response.
+
+         Args:
+             text: The raw generated text
+
+         Returns:
+             Cleaned up text
+         """
+         # Make sure it's a complete sentence or phrase
+         # If it ends abruptly, add an ellipsis
+         if text and not any(text.endswith(end) for end in [".", "!", "?", '..."']):
+             if text.endswith('"'):
+                 text = text[:-1] + '..."'
+             else:
+                 text += "..."
+
+         return text
+
+
+ class HuggingFaceInterface(LLMInterface):
+     """Interface for Hugging Face Transformers models."""
+
+     def __init__(self, model_name="distilgpt2", max_length=150, temperature=0.9):
+         """Initialize the Hugging Face interface."""
+         super().__init__(model_name, max_length, temperature)
+         try:
+             # Check if we're dealing with a gated model
+             is_gated_model = any(
+                 name in model_name for name in ["gemma", "llama", "mistral"]
+             )
+
+             # Get token from environment
+             import os
+
+             token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
+                 "HF_TOKEN"
+             )
+
+             if is_gated_model and token:
+                 print(f"Using token for gated model: {model_name}")
+                 from huggingface_hub import login
+
+                 login(token=token, add_to_git_credential=False)
+
+                 # Explicitly pass the token when loading the tokenizer and model
+                 from transformers import AutoTokenizer, AutoModelForCausalLM
+
+                 tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
+                 model = AutoModelForCausalLM.from_pretrained(model_name, token=token)
+                 self.pipeline = pipeline(
+                     "text-generation", model=model, tokenizer=tokenizer
+                 )
+             else:
+                 self.pipeline = pipeline("text-generation", model=model_name)
+
+             print(f"Successfully loaded model: {model_name}")
+         except Exception as e:
+             print(f"Error loading model {model_name}: {e}")
+             if "gated" in str(e).lower() or "403" in str(e):
+                 print(
+                     "\nThis appears to be a gated model that requires authentication."
+                 )
+                 print("Please make sure you:")
+                 print("1. Have accepted the model license on the Hugging Face Hub")
+                 print(
+                     "2. Have created a token with 'Access to public gated repositories' permission"
+                 )
+                 print(
+                     "3. Have set the token as HUGGING_FACE_HUB_TOKEN environment variable"
+                 )
+                 print("\nAlternatively, try an Ollama model via the llm backend:")
+                 print(
+                     "python demo.py --backend llm --model ollama/gemma3:4b-it-qat [other args]"
+                 )
+             raise
+
+     def generate(self, prompt, num_responses=3):
+         """Generate responses using the Hugging Face pipeline."""
+         # Calculate prompt length in tokens (approximate)
+         prompt_length = len(prompt.split())
+
+         # Generate the responses
+         responses = self.pipeline(
+             prompt,
+             max_length=prompt_length + self.max_length,
+             temperature=self.temperature,
+             do_sample=True,
+             num_return_sequences=num_responses,
+             top_p=0.92,
+             top_k=50,
+             truncation=True,
+         )
+
+         # Extract just the generated parts (not the prompt)
+         generated_texts = []
+         for resp in responses:
+             # Get the text after the prompt
+             generated = resp["generated_text"][len(prompt) :].strip()
+
+             # Clean up the response
+             generated = self.cleanup_response(generated)
+
+             # Add to our list if it's not empty
+             if generated:
+                 generated_texts.append(generated)
+
+         return generated_texts
+
+
+ class OllamaInterface(LLMInterface):
+     """Interface for Ollama models."""
+
+     def __init__(self, model_name="gemma:7b", max_length=150, temperature=0.9):
+         """Initialize the Ollama interface."""
+         super().__init__(model_name, max_length, temperature)
+         # Check if Ollama is installed and the model is available
+         try:
+             import requests
+
+             response = requests.get("http://localhost:11434/api/tags")
+             if response.status_code == 200:
+                 models = [model["name"] for model in response.json()["models"]]
+                 if model_name not in models:
+                     print(
+                         f"Warning: Model {model_name} not found in Ollama. Available models: {', '.join(models)}"
+                     )
+                     print(f"You may need to run: ollama pull {model_name}")
+                 print(f"Ollama is available and will use model: {model_name}")
+         except Exception as e:
+             print(f"Warning: Ollama may not be installed or running: {e}")
+             print("You can install Ollama from https://ollama.ai/")
+
+     def generate(self, prompt, num_responses=3):
+         """Generate responses using the Ollama API."""
+         import requests
+
+         generated_texts = []
+         for _ in range(num_responses):
+             try:
+                 # Ask for a single non-streamed completion; generation settings
+                 # (temperature, num_predict) go in the "options" field of the payload
+                 response = requests.post(
+                     "http://localhost:11434/api/generate",
+                     json={
+                         "model": self.model_name,
+                         "prompt": prompt,
+                         "stream": False,
+                         "options": {
+                             "temperature": self.temperature,
+                             "num_predict": self.max_length,
+                         },
+                     },
+                 )
+
+                 if response.status_code == 200:
+                     # Extract the generated text
+                     generated = response.json().get("response", "").strip()
+
+                     # Clean up the response
+                     generated = self.cleanup_response(generated)
+
+                     # Add to our list if it's not empty
+                     if generated:
+                         generated_texts.append(generated)
+                 else:
+                     print(f"Error from Ollama API: {response.text}")
+             except Exception as e:
+                 print(f"Error generating with Ollama: {e}")
+
+         return generated_texts
+
+
+ class LLMToolInterface(LLMInterface):
+     """Interface for Simon Willison's LLM tool."""
+
+     def __init__(
+         self, model_name="gemini-1.5-pro-latest", max_length=150, temperature=0.9
+     ):
+         """Initialize the LLM tool interface."""
+         super().__init__(model_name, max_length, temperature)
+         # Check if the LLM tool is installed
+         try:
+             import subprocess
+
+             result = subprocess.run(["llm", "models"], capture_output=True, text=True)
+             if result.returncode == 0:
+                 models = [
+                     line.strip() for line in result.stdout.split("\n") if line.strip()
+                 ]
+                 print(f"LLM tool is available. Found {len(models)} models.")
+
+                 # Check for specific model types
+                 gemini_models = [
+                     m for m in models if "gemini" in m.lower() or "gemma" in m.lower()
+                 ]
+                 if gemini_models:
+                     print(f"Gemini models available: {', '.join(gemini_models[:3])}...")
+
+                 # Check for Ollama models
+                 ollama_models = [m for m in models if "ollama" in m.lower()]
+                 if ollama_models:
+                     print(f"Ollama models available: {', '.join(ollama_models[:3])}...")
+
+                 # Check for MLX models
+                 mlx_models = [m for m in models if "mlx" in m.lower()]
+                 if mlx_models:
+                     print(f"MLX models available: {', '.join(mlx_models[:3])}...")
+
+                 # Check if the specified model is available
+                 if not any(self.model_name in m for m in models):
+                     print(
+                         f"Warning: Model '{self.model_name}' not found in available models."
+                     )
+                     print("You may need to install the appropriate plugin:")
+                     if (
+                         "gemini" in self.model_name.lower()
+                         or "gemma" in self.model_name.lower()
+                     ):
+                         print("llm install llm-gemini")
+                     elif "mlx" in self.model_name.lower():
+                         print("llm install llm-mlx")
+                     elif "ollama" in self.model_name.lower():
+                         print("llm install llm-ollama")
+                         print("ollama pull " + self.model_name.replace("ollama/", ""))
+             else:
+                 print("Warning: LLM tool may be installed but returned an error.")
+         except Exception as e:
+             print(f"Warning: Simon Willison's LLM tool may not be installed: {e}")
+             print("You can install it with: pip install llm")
+
+     def generate(self, prompt, num_responses=3):
+         """Generate responses using the LLM tool."""
+         import subprocess
+         import os
+
+         # Check for required environment variables
+         if "gemini" in self.model_name.lower() or "gemma" in self.model_name.lower():
+             if not os.environ.get("GEMINI_API_KEY"):
+                 print("Warning: GEMINI_API_KEY environment variable not found.")
+                 print("Gemini API may not work without it.")
+         elif "ollama" in self.model_name.lower():
+             # Check if Ollama is running
+             try:
+                 import requests
+
+                 response = requests.get("http://localhost:11434/api/tags", timeout=2)
+                 if response.status_code != 200:
+                     print("Warning: Ollama server doesn't seem to be running.")
+                     print("Start Ollama with: ollama serve")
+             except Exception:
+                 print("Warning: Ollama server doesn't seem to be running.")
+                 print("Start Ollama with: ollama serve")
+
+         # Determine the appropriate parameter name for max tokens
+         if "gemini" in self.model_name.lower() or "gemma" in self.model_name.lower():
+             max_tokens_param = "max_output_tokens"
+         elif "ollama" in self.model_name.lower():
+             max_tokens_param = "num_predict"
+         else:
+             max_tokens_param = "max_tokens"
+
+         generated_texts = []
+         for _ in range(num_responses):
+             try:
+                 # Call the LLM tool, passing generation options as "-o NAME VALUE"
+                 result = subprocess.run(
+                     [
+                         "llm",
+                         "-m",
+                         self.model_name,
+                         "-o",
+                         "temperature",
+                         str(self.temperature),
+                         "-o",
+                         max_tokens_param,
+                         str(self.max_length),
+                         prompt,
+                     ],
+                     capture_output=True,
+                     text=True,
+                 )
+
+                 if result.returncode == 0:
+                     # Get the generated text
+                     generated = result.stdout.strip()
+
+                     # Clean up the response
+                     generated = self.cleanup_response(generated)
+
+                     # Add to our list if it's not empty
+                     if generated:
+                         generated_texts.append(generated)
+                 else:
+                     print(f"Error from LLM tool: {result.stderr}")
+             except Exception as e:
+                 print(f"Error generating with LLM tool: {e}")
+
+         return generated_texts
+
+
+ class MLXInterface(LLMInterface):
+     """Interface for MLX-powered models on Mac."""
+
+     def __init__(
+         self, model_name="mlx-community/gemma-7b-it", max_length=150, temperature=0.9
+     ):
+         """Initialize the MLX interface."""
+         super().__init__(model_name, max_length, temperature)
+         # Check if MLX is installed
+         try:
+             import importlib.util
+
+             if importlib.util.find_spec("mlx") is not None:
+                 print("MLX is available for optimized inference on Mac")
+             else:
+                 print("Warning: MLX is not installed. Install with: pip install mlx")
+         except Exception as e:
+             print(f"Warning: Error checking for MLX: {e}")
+
+     def generate(self, prompt, num_responses=3):
+         """Generate responses using MLX."""
+         try:
+             # Dynamically import MLX to avoid errors on non-Mac platforms
+             import mlx.core as mx
+             from transformers import AutoTokenizer, AutoModelForCausalLM
+
+             # Load the model and tokenizer
+             tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+             model = AutoModelForCausalLM.from_pretrained(
+                 self.model_name, trust_remote_code=True, mx_dtype=mx.float16
+             )
+
+             generated_texts = []
+             for _ in range(num_responses):
+                 # Tokenize the prompt
+                 inputs = tokenizer(prompt, return_tensors="np")
+
+                 # Generate
+                 outputs = model.generate(
+                     inputs["input_ids"],
+                     max_length=len(inputs["input_ids"][0]) + self.max_length,
+                     temperature=self.temperature,
+                     do_sample=True,
+                     top_p=0.92,
+                     top_k=50,
+                 )
+
+                 # Decode the generated tokens
+                 generated = tokenizer.decode(
+                     outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+                 )
+
+                 # Clean up the response
+                 generated = self.cleanup_response(generated)
+
+                 # Add to our list if it's not empty
+                 if generated:
+                     generated_texts.append(generated)
+
+             return generated_texts
+         except Exception as e:
+             print(f"Error generating with MLX: {e}")
+             return []
+
+
+ def create_llm_interface(backend, model_name, max_length=150, temperature=0.9):
+     """Create an appropriate LLM interface based on the backend.
+
+     Args:
+         backend: The backend to use ('hf', 'llm')
+         model_name: The name of the model to use
+         max_length: Maximum length of generated text
+         temperature: Controls randomness (higher = more random)
+
+     Returns:
+         An LLM interface instance
+     """
+     if backend == "hf":
+         return HuggingFaceInterface(model_name, max_length, temperature)
+     elif backend == "llm":
+         return LLMToolInterface(model_name, max_length, temperature)
+     else:
+         raise ValueError(f"Unknown backend: {backend}")
+
+
+ def generate_response(
      prompt,
+     model_name="distilgpt2",
+     max_length=150,
+     temperature=0.9,
+     num_responses=3,
+     backend="hf",
+ ):
+     """Generate multiple responses using the specified model and backend.
+
+     Args:
+         prompt: The prompt to generate responses for
+         model_name: The name of the model to use
+         max_length: Maximum number of new tokens to generate
+         temperature: Controls randomness (higher = more random)
+         num_responses: Number of different responses to generate
+         backend: The backend to use ('hf' or 'llm')
+
+     Returns:
+         A list of generated responses
+     """
+     # Create the appropriate interface
+     interface = create_llm_interface(backend, model_name, max_length, temperature)
+
+     # Generate responses
+     return interface.generate(prompt, num_responses)
+
+
+ def main():
+     # Set up argument parser
+     parser = argparse.ArgumentParser(
+         description="Generate AAC responses using social graph context"
+     )
+     parser.add_argument(
+         "--person", default="billy", help="Person ID from the social graph"
+     )
+     parser.add_argument("--topic", help="Topic of conversation")
+     parser.add_argument("--message", help="Message from the conversation partner")
+     parser.add_argument(
+         "--backend",
+         default="llm",
+         choices=["hf", "llm"],
+         help="Backend to use for generation (hf=HuggingFace, "
+         "llm=Simon Willison's LLM tool with support for Gemini/MLX/Ollama)",
+     )
+     parser.add_argument(
+         "--model",
+         default="gemini-1.5-pro-latest",
+         help="Model to use for generation. Recommended models by backend:\n"
+         "- hf: 'distilgpt2', 'gpt2-medium', 'google/gemma-2b-it'\n"
+         "- llm: 'gemini-1.5-pro-latest', 'gemma-3-27b-it' (requires llm-gemini plugin)\n"
+         "       'mlx-community/gemma-7b-it' (requires llm-mlx plugin)\n"
+         "       'ollama/gemma3:4b-it-qat', 'ollama/llama3:8b' (requires llm-ollama plugin)",
+     )
+     parser.add_argument(
+         "--num_responses", type=int, default=3, help="Number of responses to generate"
+     )
+     parser.add_argument(
+         "--max_length",
+         type=int,
+         default=150,
+         help="Maximum length of generated responses",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0.9,
+         help="Temperature for generation (higher = more creative)",
+     )
+     args = parser.parse_args()
+
+     # Check for token if using HF backend with gated models
+     if args.backend == "hf" and any(
+         name in args.model for name in ["gemma", "llama", "mistral"]
+     ):
+         if not check_hf_token():
+             print("\nSuggestion: Try using the LLM tool with Gemini API instead:")
+             print(
+                 f"python demo.py --backend llm --model gemini-1.5-pro-latest --person {args.person}"
+                 + (f' --topic "{args.topic}"' if args.topic else "")
+                 + (f' --message "{args.message}"' if args.message else "")
+             )
+             print("\nOr use a non-gated model:")
+             print(
+                 f"python demo.py --backend hf --model gpt2-medium --person {args.person}"
+                 + (f' --topic "{args.topic}"' if args.topic else "")
+                 + (f' --message "{args.message}"' if args.message else "")
+             )
+             print("\nContinuing anyway, but expect authentication errors...\n")
+
+     # Load the social graph
+     social_graph = load_social_graph()
+
+     # Build the prompt
+     prompt = build_enhanced_prompt(social_graph, args.person, args.topic, args.message)
+
+     print("\n=== PROMPT ===")
+     print(prompt)
+     print(
+         f"\n=== GENERATING RESPONSE USING {args.backend.upper()} BACKEND WITH MODEL {args.model} ==="
+     )
+
+     # Generate the responses
+     try:
+         responses = generate_response(
+             prompt,
+             args.model,
+             max_length=args.max_length,
+             num_responses=args.num_responses,
+             temperature=args.temperature,
+             backend=args.backend,
+         )
+
+         print("\n=== RESPONSES ===")
+         for i, response in enumerate(responses, 1):
+             print(f"{i}. {response}")
+             print()
+     except Exception as e:
+         print(f"\nError generating responses: {e}")
+
+         if args.backend == "hf" and any(
+             name in args.model for name in ["gemma", "llama", "mistral"]
+         ):
+             print("\nThis appears to be an authentication issue with a gated model.")
+             print("Try using the LLM tool with Gemini API instead:")
+             print(
+                 f"python demo.py --backend llm --model gemini-1.5-pro-latest --person {args.person}"
+                 + (f' --topic "{args.topic}"' if args.topic else "")
+                 + (f' --message "{args.message}"' if args.message else "")
+             )
+         # Ollama is now handled through the llm backend
+         elif args.backend == "llm":
+             if "gemini" in args.model.lower() or "gemma" in args.model.lower():
+                 print(
+                     "\nMake sure you have the GEMINI_API_KEY environment variable set:"
+                 )
+                 print("export GEMINI_API_KEY=your_api_key")
+                 print("\nAnd make sure llm-gemini is installed:")
+                 print("llm install llm-gemini")
+             elif "mlx" in args.model.lower():
+                 print("\nMake sure llm-mlx is installed:")
+                 print("llm install llm-mlx")
+             elif "ollama" in args.model.lower():
+                 print("\nMake sure Ollama is installed and running:")
+                 print("1. Install from https://ollama.ai/")
+                 print("2. Start Ollama with: ollama serve")
+                 print("3. Install the llm-ollama plugin: llm install llm-ollama")
+                 print(
+                     f"4. Pull the model: ollama pull {args.model.replace('ollama/', '')}"
+                 )
+             else:
+                 print("\nMake sure Simon Willison's LLM tool is installed:")
+                 print("pip install llm")
+
+
+ # If running as a script
+ if __name__ == "__main__":
+     main()
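
Note: build_enhanced_prompt above reads a social_graph.json with aac_user, people, places, and common_utterances keys, but that file is not part of this commit. The snippet below is only a minimal sketch of a structure that would satisfy those code paths; every name and value in it is a placeholder for illustration, not the project's real data.

# Minimal sketch (assumed structure, placeholder values) of a social_graph.json
# that covers the fields read by load_social_graph/build_enhanced_prompt above.
import json

social_graph = {
    "aac_user": {
        "name": "Will",
        "age": 38,
        "location": "the UK",
        "background": "Father of Billy, diagnosed with MND.",
        "communication_needs": "Short, easy-to-select responses.",
    },
    "people": {
        "billy": {
            "name": "Billy",
            "role": "son",
            "context": "7-year-old son who loves Manchester United football.",
            "topics": ["football", "school"],
            "frequency": "daily",
            "common_phrases": ["Dad, did you see the United match last night?"],
        }
    },
    "places": ["home", "school", "the park"],
    "common_utterances": {
        "sports_talk": ["Did you watch the game?", "What a match that was!"]
    },
}

# Write the placeholder graph so demo.py can load it
with open("social_graph.json", "w") as f:
    json.dump(social_graph, f, indent=2)

With a file like that in place, the demo can be run along the lines of its own help text, for example: python demo.py --backend llm --model gemini-1.5-pro-latest --person billy --topic football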