fixing demo
Browse files
demo.py
CHANGED
@@ -1,40 +1,718 @@
|
|
1 |
from transformers import pipeline
|
2 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
# Load model
|
5 |
|
6 |
-
#
|
7 |
-
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
person = kg["people"]["billy"] # Using Billy instead of Bob
|
15 |
-
context = person["context"]
|
16 |
|
17 |
-
# User input
|
18 |
-
query = "What should I say to Billy?"
|
19 |
|
20 |
-
|
21 |
-
|
|
|
|
|
22 |
|
23 |
-
Billy just asked me: "Dad, did you see the United match last night?"
|
24 |
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
prompt,
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import pipeline
|
2 |
import json
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import subprocess
|
7 |
+
import requests
|
8 |
+
from typing import List, Dict, Any, Optional, Union
|
9 |
+
import time
|
10 |
|
|
|
11 |
|
12 |
+
# Check for Hugging Face token
|
13 |
+
def check_hf_token():
|
14 |
+
"""Check if a Hugging Face token is properly set up."""
|
15 |
+
token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
|
16 |
|
17 |
+
if not token:
|
18 |
+
print("\nWarning: No Hugging Face token found in environment variables.")
|
19 |
+
print(
|
20 |
+
"To use gated models like Gemma, you need to set up a token with the right permissions."
|
21 |
+
)
|
22 |
+
print("1. Create a token at https://huggingface.co/settings/tokens")
|
23 |
+
print("2. Make sure to enable 'Access to public gated repositories'")
|
24 |
+
print("3. Set it as an environment variable:")
|
25 |
+
print(" export HUGGING_FACE_HUB_TOKEN=your_token_here")
|
26 |
+
return False
|
27 |
|
28 |
+
return True
|
|
|
|
|
29 |
|
|
|
|
|
30 |
|
31 |
+
def load_social_graph(file_path="social_graph.json"):
|
32 |
+
"""Load the social graph from a JSON file."""
|
33 |
+
with open(file_path, "r") as f:
|
34 |
+
return json.load(f)
|
35 |
|
|
|
36 |
|
37 |
+
def get_person_info(social_graph, person_id):
|
38 |
+
"""Get information about a person from the social graph."""
|
39 |
+
if person_id in social_graph["people"]:
|
40 |
+
return social_graph["people"][person_id]
|
41 |
+
else:
|
42 |
+
available_people = ", ".join(social_graph["people"].keys())
|
43 |
+
raise ValueError(
|
44 |
+
f"Person '{person_id}' not found in social graph. Available people: {available_people}"
|
45 |
+
)
|
46 |
|
47 |
+
|
48 |
+
def build_enhanced_prompt(social_graph, person_id, topic=None, user_message=None):
|
49 |
+
"""Build an enhanced prompt using social graph information."""
|
50 |
+
# Get AAC user information
|
51 |
+
aac_user = social_graph["aac_user"]
|
52 |
+
|
53 |
+
# Get conversation partner information
|
54 |
+
person = get_person_info(social_graph, person_id)
|
55 |
+
|
56 |
+
# Start building the prompt with AAC user information
|
57 |
+
prompt = f"""I am {aac_user['name']}, a {aac_user['age']}-year-old with MND (Motor Neuron Disease) from {aac_user['location']}.
|
58 |
+
{aac_user['background']}
|
59 |
+
|
60 |
+
My communication needs: {aac_user['communication_needs']}
|
61 |
+
|
62 |
+
I am talking to {person['name']}, who is my {person['role']}.
|
63 |
+
About {person['name']}: {person['context']}
|
64 |
+
We typically talk about: {', '.join(person['topics'])}
|
65 |
+
We communicate {person['frequency']}.
|
66 |
+
"""
|
67 |
+
|
68 |
+
# Add places information if available
|
69 |
+
if "places" in social_graph:
|
70 |
+
relevant_places = social_graph["places"][
|
71 |
+
:3
|
72 |
+
] # Just use a few places for context
|
73 |
+
prompt += f"\nPlaces important to me: {', '.join(relevant_places)}\n"
|
74 |
+
|
75 |
+
# Add communication style based on relationship
|
76 |
+
if person["role"] in ["wife", "son", "daughter", "mother", "father"]:
|
77 |
+
prompt += "I communicate with my family in a warm, loving way, sometimes using inside jokes.\n"
|
78 |
+
elif person["role"] in ["doctor", "therapist", "nurse"]:
|
79 |
+
prompt += (
|
80 |
+
"I communicate with healthcare providers in a direct, informative way.\n"
|
81 |
+
)
|
82 |
+
elif person["role"] in ["best mate", "friend"]:
|
83 |
+
prompt += "I communicate with friends casually, often with humor and sometimes swearing.\n"
|
84 |
+
elif person["role"] in ["work colleague", "boss"]:
|
85 |
+
prompt += "I communicate with colleagues professionally but still friendly.\n"
|
86 |
+
|
87 |
+
# Add common utterances by category if available
|
88 |
+
if "common_utterances" in social_graph:
|
89 |
+
# Try to find relevant utterance category based on topic
|
90 |
+
utterance_category = None
|
91 |
+
if topic == "football" or topic == "sports":
|
92 |
+
utterance_category = "sports_talk"
|
93 |
+
elif topic == "programming" or topic == "tech news":
|
94 |
+
utterance_category = "tech_talk"
|
95 |
+
elif topic in ["family plans", "children's activities"]:
|
96 |
+
utterance_category = "family_talk"
|
97 |
+
|
98 |
+
# Add relevant utterances if category exists
|
99 |
+
if (
|
100 |
+
utterance_category
|
101 |
+
and utterance_category in social_graph["common_utterances"]
|
102 |
+
):
|
103 |
+
utterances = social_graph["common_utterances"][utterance_category][:2]
|
104 |
+
prompt += f"\nI might say things like: {' or '.join(utterances)}\n"
|
105 |
+
|
106 |
+
# Add topic information if provided
|
107 |
+
if topic and topic in person["topics"]:
|
108 |
+
prompt += f"\nWe are currently discussing {topic}.\n"
|
109 |
+
|
110 |
+
# Add specific context about this topic with this person
|
111 |
+
if topic == "football" and "Manchester United" in person["context"]:
|
112 |
+
prompt += (
|
113 |
+
"We both support Manchester United and often discuss recent matches.\n"
|
114 |
+
)
|
115 |
+
elif topic == "programming" and "software developer" in person["context"]:
|
116 |
+
prompt += (
|
117 |
+
"We both work in software development and share technical interests.\n"
|
118 |
+
)
|
119 |
+
elif topic == "family plans" and person["role"] in ["wife", "husband"]:
|
120 |
+
prompt += "We make family decisions together, considering my condition.\n"
|
121 |
+
elif topic == "old scout adventures" and person["role"] == "best mate":
|
122 |
+
prompt += "We often reminisce about our Scout camping trips in South East London.\n"
|
123 |
+
elif topic == "cycling" and "cycling" in person["context"]:
|
124 |
+
prompt += "I miss being able to cycle but enjoy talking about past cycling adventures.\n"
|
125 |
+
|
126 |
+
# Add shared experiences based on relationship and topic
|
127 |
+
if person["role"] == "best mate" and topic in ["football", "pub quizzes"]:
|
128 |
+
prompt += (
|
129 |
+
"We've watched many matches together and done countless pub quizzes.\n"
|
130 |
+
)
|
131 |
+
elif person["role"] == "wife" and topic in ["family plans", "weekend outings"]:
|
132 |
+
prompt += "Emma has been amazing at keeping family life as normal as possible despite my condition.\n"
|
133 |
+
elif person["role"] == "son" and topic == "football":
|
134 |
+
prompt += "I try to stay engaged with Billy's football enthusiasm even as my condition progresses.\n"
|
135 |
+
|
136 |
+
# Add the user's message if provided
|
137 |
+
if user_message:
|
138 |
+
prompt += f"\n{person['name']} just said to me: \"{user_message}\"\n"
|
139 |
+
else:
|
140 |
+
# Use a common phrase from the person if no message is provided
|
141 |
+
if person["common_phrases"]:
|
142 |
+
default_message = person["common_phrases"][0]
|
143 |
+
prompt += f"\n{person['name']} just said to me: \"{default_message}\"\n"
|
144 |
+
|
145 |
+
# Add the response prompt with specific guidance
|
146 |
+
prompt += f"""
|
147 |
+
I want to respond to {person['name']} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said. I'll use casual language with some humor since we're close friends.
|
148 |
+
|
149 |
+
My response to {person['name']}:"""
|
150 |
+
|
151 |
+
return prompt
|
152 |
+
|
153 |
+
|
154 |
+
class LLMInterface:
|
155 |
+
"""Base interface for language model generation."""
|
156 |
+
|
157 |
+
def __init__(self, model_name, max_length=150, temperature=0.9):
|
158 |
+
"""Initialize the LLM interface.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
model_name: Name or path of the model
|
162 |
+
max_length: Maximum length of generated text
|
163 |
+
temperature: Controls randomness (higher = more random)
|
164 |
+
"""
|
165 |
+
self.model_name = model_name
|
166 |
+
self.max_length = max_length
|
167 |
+
self.temperature = temperature
|
168 |
+
|
169 |
+
def generate(self, prompt, num_responses=3):
|
170 |
+
"""Generate responses for the given prompt.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
prompt: The prompt to generate responses for
|
174 |
+
num_responses: Number of responses to generate
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
A list of generated responses
|
178 |
+
"""
|
179 |
+
raise NotImplementedError("Subclasses must implement this method")
|
180 |
+
|
181 |
+
def cleanup_response(self, text):
|
182 |
+
"""Clean up a generated response.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
text: The raw generated text
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
Cleaned up text
|
189 |
+
"""
|
190 |
+
# Make sure it's a complete sentence or phrase
|
191 |
+
# If it ends abruptly, add an ellipsis
|
192 |
+
if text and not any(text.endswith(end) for end in [".", "!", "?", '..."']):
|
193 |
+
if text.endswith('"'):
|
194 |
+
text = text[:-1] + '..."'
|
195 |
+
else:
|
196 |
+
text += "..."
|
197 |
+
|
198 |
+
return text
|
199 |
+
|
200 |
+
|
201 |
+
class HuggingFaceInterface(LLMInterface):
|
202 |
+
"""Interface for Hugging Face Transformers models."""
|
203 |
+
|
204 |
+
def __init__(self, model_name="distilgpt2", max_length=150, temperature=0.9):
|
205 |
+
"""Initialize the Hugging Face interface."""
|
206 |
+
super().__init__(model_name, max_length, temperature)
|
207 |
+
try:
|
208 |
+
# Check if we're dealing with a gated model
|
209 |
+
is_gated_model = any(
|
210 |
+
name in model_name for name in ["gemma", "llama", "mistral"]
|
211 |
+
)
|
212 |
+
|
213 |
+
# Get token from environment
|
214 |
+
import os
|
215 |
+
|
216 |
+
token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
|
217 |
+
"HF_TOKEN"
|
218 |
+
)
|
219 |
+
|
220 |
+
if is_gated_model and token:
|
221 |
+
print(f"Using token for gated model: {model_name}")
|
222 |
+
from huggingface_hub import login
|
223 |
+
|
224 |
+
login(token=token, add_to_git_credential=False)
|
225 |
+
|
226 |
+
# Explicitly pass token to pipeline
|
227 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
228 |
+
|
229 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
|
230 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, token=token)
|
231 |
+
self.pipeline = pipeline(
|
232 |
+
"text-generation", model=model, tokenizer=tokenizer
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
self.pipeline = pipeline("text-generation", model=model_name)
|
236 |
+
|
237 |
+
print(f"Successfully loaded model: {model_name}")
|
238 |
+
except Exception as e:
|
239 |
+
print(f"Error loading model {model_name}: {e}")
|
240 |
+
if "gated" in str(e).lower() or "403" in str(e):
|
241 |
+
print(
|
242 |
+
"\nThis appears to be a gated model that requires authentication."
|
243 |
+
)
|
244 |
+
print("Please make sure you:")
|
245 |
+
print("1. Have accepted the model license on the Hugging Face Hub")
|
246 |
+
print(
|
247 |
+
"2. Have created a token with 'Access to public gated repositories' permission"
|
248 |
+
)
|
249 |
+
print(
|
250 |
+
"3. Have set the token as HUGGING_FACE_HUB_TOKEN environment variable"
|
251 |
+
)
|
252 |
+
print("\nAlternatively, try using the Ollama backend:")
|
253 |
+
print(
|
254 |
+
f"python demo.py --backend ollama --model gemma:7b-it [other args]"
|
255 |
+
)
|
256 |
+
raise
|
257 |
+
|
258 |
+
def generate(self, prompt, num_responses=3):
|
259 |
+
"""Generate responses using the Hugging Face pipeline."""
|
260 |
+
# Calculate prompt length in tokens (approximate)
|
261 |
+
prompt_length = len(prompt.split())
|
262 |
+
|
263 |
+
# Generate the responses
|
264 |
+
responses = self.pipeline(
|
265 |
+
prompt,
|
266 |
+
max_length=prompt_length + self.max_length,
|
267 |
+
temperature=self.temperature,
|
268 |
+
do_sample=True,
|
269 |
+
num_return_sequences=num_responses,
|
270 |
+
top_p=0.92,
|
271 |
+
top_k=50,
|
272 |
+
truncation=True,
|
273 |
+
)
|
274 |
+
|
275 |
+
# Extract just the generated parts (not the prompt)
|
276 |
+
generated_texts = []
|
277 |
+
for resp in responses:
|
278 |
+
# Get the text after the prompt
|
279 |
+
generated = resp["generated_text"][len(prompt) :].strip()
|
280 |
+
|
281 |
+
# Clean up the response
|
282 |
+
generated = self.cleanup_response(generated)
|
283 |
+
|
284 |
+
# Add to our list if it's not empty
|
285 |
+
if generated:
|
286 |
+
generated_texts.append(generated)
|
287 |
+
|
288 |
+
return generated_texts
|
289 |
+
|
290 |
+
|
291 |
+
class OllamaInterface(LLMInterface):
|
292 |
+
"""Interface for Ollama models."""
|
293 |
+
|
294 |
+
def __init__(self, model_name="gemma:7b", max_length=150, temperature=0.9):
|
295 |
+
"""Initialize the Ollama interface."""
|
296 |
+
super().__init__(model_name, max_length, temperature)
|
297 |
+
# Check if Ollama is installed and the model is available
|
298 |
+
try:
|
299 |
+
import requests
|
300 |
+
|
301 |
+
response = requests.get("http://localhost:11434/api/tags")
|
302 |
+
if response.status_code == 200:
|
303 |
+
models = [model["name"] for model in response.json()["models"]]
|
304 |
+
if model_name not in models:
|
305 |
+
print(
|
306 |
+
f"Warning: Model {model_name} not found in Ollama. Available models: {', '.join(models)}"
|
307 |
+
)
|
308 |
+
print(f"You may need to run: ollama pull {model_name}")
|
309 |
+
print(f"Ollama is available and will use model: {model_name}")
|
310 |
+
except Exception as e:
|
311 |
+
print(f"Warning: Ollama may not be installed or running: {e}")
|
312 |
+
print("You can install Ollama from https://ollama.ai/")
|
313 |
+
|
314 |
+
def generate(self, prompt, num_responses=3):
|
315 |
+
"""Generate responses using Ollama API."""
|
316 |
+
import requests
|
317 |
+
|
318 |
+
generated_texts = []
|
319 |
+
for _ in range(num_responses):
|
320 |
+
try:
|
321 |
+
response = requests.post(
|
322 |
+
"http://localhost:11434/api/generate",
|
323 |
+
json={
|
324 |
+
"model": self.model_name,
|
325 |
+
"prompt": prompt,
|
326 |
+
"temperature": self.temperature,
|
327 |
+
"max_tokens": self.max_length,
|
328 |
+
},
|
329 |
+
stream=False,
|
330 |
+
)
|
331 |
+
|
332 |
+
if response.status_code == 200:
|
333 |
+
# Extract the generated text
|
334 |
+
generated = response.json().get("response", "").strip()
|
335 |
+
|
336 |
+
# Clean up the response
|
337 |
+
generated = self.cleanup_response(generated)
|
338 |
+
|
339 |
+
# Add to our list if it's not empty
|
340 |
+
if generated:
|
341 |
+
generated_texts.append(generated)
|
342 |
+
else:
|
343 |
+
print(f"Error from Ollama API: {response.text}")
|
344 |
+
except Exception as e:
|
345 |
+
print(f"Error generating with Ollama: {e}")
|
346 |
+
|
347 |
+
return generated_texts
|
348 |
+
|
349 |
+
|
350 |
+
class LLMToolInterface(LLMInterface):
|
351 |
+
"""Interface for Simon Willison's LLM tool."""
|
352 |
+
|
353 |
+
def __init__(
|
354 |
+
self, model_name="gemini-1.5-pro-latest", max_length=150, temperature=0.9
|
355 |
+
):
|
356 |
+
"""Initialize the LLM tool interface."""
|
357 |
+
super().__init__(model_name, max_length, temperature)
|
358 |
+
# Check if LLM tool is installed
|
359 |
+
try:
|
360 |
+
import subprocess
|
361 |
+
|
362 |
+
result = subprocess.run(["llm", "models"], capture_output=True, text=True)
|
363 |
+
if result.returncode == 0:
|
364 |
+
models = [
|
365 |
+
line.strip() for line in result.stdout.split("\n") if line.strip()
|
366 |
+
]
|
367 |
+
print(f"LLM tool is available. Found {len(models)} models.")
|
368 |
+
|
369 |
+
# Check for specific model types
|
370 |
+
gemini_models = [
|
371 |
+
m for m in models if "gemini" in m.lower() or "gemma" in m.lower()
|
372 |
+
]
|
373 |
+
if gemini_models:
|
374 |
+
print(f"Gemini models available: {', '.join(gemini_models[:3])}...")
|
375 |
+
|
376 |
+
# Check for Ollama models
|
377 |
+
ollama_models = [m for m in models if "ollama" in m.lower()]
|
378 |
+
if ollama_models:
|
379 |
+
print(f"Ollama models available: {', '.join(ollama_models[:3])}...")
|
380 |
+
|
381 |
+
# Check for MLX models
|
382 |
+
mlx_models = [m for m in models if "mlx" in m.lower()]
|
383 |
+
if mlx_models:
|
384 |
+
print(f"MLX models available: {', '.join(mlx_models[:3])}...")
|
385 |
+
|
386 |
+
# Check if the specified model is available
|
387 |
+
if not any(self.model_name in m for m in models):
|
388 |
+
print(
|
389 |
+
f"Warning: Model '{self.model_name}' not found in available models."
|
390 |
+
)
|
391 |
+
print("You may need to install the appropriate plugin:")
|
392 |
+
if (
|
393 |
+
"gemini" in self.model_name.lower()
|
394 |
+
or "gemma" in self.model_name.lower()
|
395 |
+
):
|
396 |
+
print("llm install llm-gemini")
|
397 |
+
elif "mlx" in self.model_name.lower():
|
398 |
+
print("llm install llm-mlx")
|
399 |
+
elif "ollama" in self.model_name.lower():
|
400 |
+
print("llm install llm-ollama")
|
401 |
+
print("ollama pull " + self.model_name.replace("ollama/", ""))
|
402 |
+
else:
|
403 |
+
print("Warning: LLM tool may be installed but returned an error.")
|
404 |
+
except Exception as e:
|
405 |
+
print(f"Warning: Simon Willison's LLM tool may not be installed: {e}")
|
406 |
+
print("You can install it with: pip install llm")
|
407 |
+
|
408 |
+
def generate(self, prompt, num_responses=3):
|
409 |
+
"""Generate responses using the LLM tool."""
|
410 |
+
import subprocess
|
411 |
+
import os
|
412 |
+
|
413 |
+
# Check for required environment variables
|
414 |
+
if "gemini" in self.model_name.lower() or "gemma" in self.model_name.lower():
|
415 |
+
if not os.environ.get("GEMINI_API_KEY"):
|
416 |
+
print("Warning: GEMINI_API_KEY environment variable not found.")
|
417 |
+
print("Gemini API may not work without it.")
|
418 |
+
elif "ollama" in self.model_name.lower():
|
419 |
+
# Check if Ollama is running
|
420 |
+
try:
|
421 |
+
import requests
|
422 |
+
|
423 |
+
response = requests.get("http://localhost:11434/api/tags", timeout=2)
|
424 |
+
if response.status_code != 200:
|
425 |
+
print("Warning: Ollama server doesn't seem to be running.")
|
426 |
+
print("Start Ollama with: ollama serve")
|
427 |
+
except Exception:
|
428 |
+
print("Warning: Ollama server doesn't seem to be running.")
|
429 |
+
print("Start Ollama with: ollama serve")
|
430 |
+
|
431 |
+
# Determine the appropriate parameter name for max tokens
|
432 |
+
if "gemini" in self.model_name.lower() or "gemma" in self.model_name.lower():
|
433 |
+
max_tokens_param = "max_output_tokens"
|
434 |
+
elif "ollama" in self.model_name.lower():
|
435 |
+
max_tokens_param = "num_predict"
|
436 |
+
else:
|
437 |
+
max_tokens_param = "max_tokens"
|
438 |
+
|
439 |
+
generated_texts = []
|
440 |
+
for _ in range(num_responses):
|
441 |
+
try:
|
442 |
+
# Call the LLM tool
|
443 |
+
result = subprocess.run(
|
444 |
+
[
|
445 |
+
"llm",
|
446 |
+
"-m",
|
447 |
+
self.model_name,
|
448 |
+
"-s",
|
449 |
+
f"temperature={self.temperature}",
|
450 |
+
"-s",
|
451 |
+
f"{max_tokens_param}={self.max_length}",
|
452 |
+
prompt,
|
453 |
+
],
|
454 |
+
capture_output=True,
|
455 |
+
text=True,
|
456 |
+
)
|
457 |
+
|
458 |
+
if result.returncode == 0:
|
459 |
+
# Get the generated text
|
460 |
+
generated = result.stdout.strip()
|
461 |
+
|
462 |
+
# Clean up the response
|
463 |
+
generated = self.cleanup_response(generated)
|
464 |
+
|
465 |
+
# Add to our list if it's not empty
|
466 |
+
if generated:
|
467 |
+
generated_texts.append(generated)
|
468 |
+
else:
|
469 |
+
print(f"Error from LLM tool: {result.stderr}")
|
470 |
+
except Exception as e:
|
471 |
+
print(f"Error generating with LLM tool: {e}")
|
472 |
+
|
473 |
+
return generated_texts
|
474 |
+
|
475 |
+
|
476 |
+
class MLXInterface(LLMInterface):
|
477 |
+
"""Interface for MLX-powered models on Mac."""
|
478 |
+
|
479 |
+
def __init__(
|
480 |
+
self, model_name="mlx-community/gemma-7b-it", max_length=150, temperature=0.9
|
481 |
+
):
|
482 |
+
"""Initialize the MLX interface."""
|
483 |
+
super().__init__(model_name, max_length, temperature)
|
484 |
+
# Check if MLX is installed
|
485 |
+
try:
|
486 |
+
import importlib.util
|
487 |
+
|
488 |
+
if importlib.util.find_spec("mlx") is not None:
|
489 |
+
print("MLX is available for optimized inference on Mac")
|
490 |
+
else:
|
491 |
+
print("Warning: MLX is not installed. Install with: pip install mlx")
|
492 |
+
except Exception as e:
|
493 |
+
print(f"Warning: Error checking for MLX: {e}")
|
494 |
+
|
495 |
+
def generate(self, prompt, num_responses=3):
|
496 |
+
"""Generate responses using MLX."""
|
497 |
+
try:
|
498 |
+
# Dynamically import MLX to avoid errors on non-Mac platforms
|
499 |
+
import mlx.core as mx
|
500 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
501 |
+
|
502 |
+
# Load the model and tokenizer
|
503 |
+
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
504 |
+
model = AutoModelForCausalLM.from_pretrained(
|
505 |
+
self.model_name, trust_remote_code=True, mx_dtype=mx.float16
|
506 |
+
)
|
507 |
+
|
508 |
+
generated_texts = []
|
509 |
+
for _ in range(num_responses):
|
510 |
+
# Tokenize the prompt
|
511 |
+
inputs = tokenizer(prompt, return_tensors="np")
|
512 |
+
|
513 |
+
# Generate
|
514 |
+
outputs = model.generate(
|
515 |
+
inputs["input_ids"],
|
516 |
+
max_length=len(inputs["input_ids"][0]) + self.max_length,
|
517 |
+
temperature=self.temperature,
|
518 |
+
do_sample=True,
|
519 |
+
top_p=0.92,
|
520 |
+
top_k=50,
|
521 |
+
)
|
522 |
+
|
523 |
+
# Decode the generated tokens
|
524 |
+
generated = tokenizer.decode(
|
525 |
+
outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
526 |
+
)
|
527 |
+
|
528 |
+
# Clean up the response
|
529 |
+
generated = self.cleanup_response(generated)
|
530 |
+
|
531 |
+
# Add to our list if it's not empty
|
532 |
+
if generated:
|
533 |
+
generated_texts.append(generated)
|
534 |
+
|
535 |
+
return generated_texts
|
536 |
+
except Exception as e:
|
537 |
+
print(f"Error generating with MLX: {e}")
|
538 |
+
return []
|
539 |
+
|
540 |
+
|
541 |
+
def create_llm_interface(backend, model_name, max_length=150, temperature=0.9):
|
542 |
+
"""Create an appropriate LLM interface based on the backend.
|
543 |
+
|
544 |
+
Args:
|
545 |
+
backend: The backend to use ('hf', 'llm')
|
546 |
+
model_name: The name of the model to use
|
547 |
+
max_length: Maximum length of generated text
|
548 |
+
temperature: Controls randomness (higher = more random)
|
549 |
+
|
550 |
+
Returns:
|
551 |
+
An LLM interface instance
|
552 |
+
"""
|
553 |
+
if backend == "hf":
|
554 |
+
return HuggingFaceInterface(model_name, max_length, temperature)
|
555 |
+
elif backend == "llm":
|
556 |
+
return LLMToolInterface(model_name, max_length, temperature)
|
557 |
+
else:
|
558 |
+
raise ValueError(f"Unknown backend: {backend}")
|
559 |
+
|
560 |
+
|
561 |
+
def generate_response(
|
562 |
prompt,
|
563 |
+
model_name="distilgpt2",
|
564 |
+
max_length=150,
|
565 |
+
temperature=0.9,
|
566 |
+
num_responses=3,
|
567 |
+
backend="hf",
|
568 |
+
):
|
569 |
+
"""Generate multiple responses using the specified model and backend.
|
570 |
+
|
571 |
+
Args:
|
572 |
+
prompt: The prompt to generate responses for
|
573 |
+
model_name: The name of the model to use
|
574 |
+
max_length: Maximum number of new tokens to generate
|
575 |
+
temperature: Controls randomness (higher = more random)
|
576 |
+
num_responses: Number of different responses to generate
|
577 |
+
backend: The backend to use ('hf', 'ollama', 'llm', 'mlx')
|
578 |
+
|
579 |
+
Returns:
|
580 |
+
A list of generated responses
|
581 |
+
"""
|
582 |
+
# Create the appropriate interface
|
583 |
+
interface = create_llm_interface(backend, model_name, max_length, temperature)
|
584 |
+
|
585 |
+
# Generate responses
|
586 |
+
return interface.generate(prompt, num_responses)
|
587 |
+
|
588 |
+
|
589 |
+
def main():
|
590 |
+
# Set up argument parser
|
591 |
+
parser = argparse.ArgumentParser(
|
592 |
+
description="Generate AAC responses using social graph context"
|
593 |
+
)
|
594 |
+
parser.add_argument(
|
595 |
+
"--person", default="billy", help="Person ID from the social graph"
|
596 |
+
)
|
597 |
+
parser.add_argument("--topic", help="Topic of conversation")
|
598 |
+
parser.add_argument("--message", help="Message from the conversation partner")
|
599 |
+
parser.add_argument(
|
600 |
+
"--backend",
|
601 |
+
default="llm",
|
602 |
+
choices=["hf", "llm"],
|
603 |
+
help="Backend to use for generation (hf=HuggingFace, "
|
604 |
+
"llm=Simon Willison's LLM tool with support for Gemini/MLX/Ollama)",
|
605 |
+
)
|
606 |
+
parser.add_argument(
|
607 |
+
"--model",
|
608 |
+
default="gemini-1.5-pro-latest",
|
609 |
+
help="Model to use for generation. Recommended models by backend:\n"
|
610 |
+
"- hf: 'distilgpt2', 'gpt2-medium', 'google/gemma-2b-it'\n"
|
611 |
+
"- llm: 'gemini-1.5-pro-latest', 'gemma-3-27b-it' (requires llm-gemini plugin)\n"
|
612 |
+
" 'mlx-community/gemma-7b-it' (requires llm-mlx plugin)\n"
|
613 |
+
" 'ollama/gemma3:4b-it-qat', 'ollama/llama3:8b' (requires llm-ollama plugin)",
|
614 |
+
)
|
615 |
+
parser.add_argument(
|
616 |
+
"--num_responses", type=int, default=3, help="Number of responses to generate"
|
617 |
+
)
|
618 |
+
parser.add_argument(
|
619 |
+
"--max_length",
|
620 |
+
type=int,
|
621 |
+
default=150,
|
622 |
+
help="Maximum length of generated responses",
|
623 |
+
)
|
624 |
+
parser.add_argument(
|
625 |
+
"--temperature",
|
626 |
+
type=float,
|
627 |
+
default=0.9,
|
628 |
+
help="Temperature for generation (higher = more creative)",
|
629 |
+
)
|
630 |
+
args = parser.parse_args()
|
631 |
+
|
632 |
+
# Check for token if using HF backend with gated models
|
633 |
+
if args.backend == "hf" and any(
|
634 |
+
name in args.model for name in ["gemma", "llama", "mistral"]
|
635 |
+
):
|
636 |
+
if not check_hf_token():
|
637 |
+
print("\nSuggestion: Try using the LLM tool with Gemini API instead:")
|
638 |
+
print(
|
639 |
+
f"python demo.py --backend llm --model gemini-1.5-pro-latest --person {args.person}"
|
640 |
+
+ (f' --topic "{args.topic}"' if args.topic else "")
|
641 |
+
+ (f' --message "{args.message}"' if args.message else "")
|
642 |
+
)
|
643 |
+
print("\nOr use a non-gated model:")
|
644 |
+
print(
|
645 |
+
f"python demo.py --backend hf --model gpt2-medium --person {args.person}"
|
646 |
+
+ (f' --topic "{args.topic}"' if args.topic else "")
|
647 |
+
+ (f' --message "{args.message}"' if args.message else "")
|
648 |
+
)
|
649 |
+
print("\nContinuing anyway, but expect authentication errors...\n")
|
650 |
+
|
651 |
+
# Load the social graph
|
652 |
+
social_graph = load_social_graph()
|
653 |
+
|
654 |
+
# Build the prompt
|
655 |
+
prompt = build_enhanced_prompt(social_graph, args.person, args.topic, args.message)
|
656 |
+
|
657 |
+
print("\n=== PROMPT ===")
|
658 |
+
print(prompt)
|
659 |
+
print(
|
660 |
+
f"\n=== GENERATING RESPONSE USING {args.backend.upper()} BACKEND WITH MODEL {args.model} ==="
|
661 |
+
)
|
662 |
+
|
663 |
+
# Generate the responses
|
664 |
+
try:
|
665 |
+
responses = generate_response(
|
666 |
+
prompt,
|
667 |
+
args.model,
|
668 |
+
max_length=args.max_length,
|
669 |
+
num_responses=args.num_responses,
|
670 |
+
temperature=args.temperature,
|
671 |
+
backend=args.backend,
|
672 |
+
)
|
673 |
+
|
674 |
+
print("\n=== RESPONSES ===")
|
675 |
+
for i, response in enumerate(responses, 1):
|
676 |
+
print(f"{i}. {response}")
|
677 |
+
print()
|
678 |
+
except Exception as e:
|
679 |
+
print(f"\nError generating responses: {e}")
|
680 |
+
|
681 |
+
if args.backend == "hf" and any(
|
682 |
+
name in args.model for name in ["gemma", "llama", "mistral"]
|
683 |
+
):
|
684 |
+
print("\nThis appears to be an authentication issue with a gated model.")
|
685 |
+
print("Try using the LLM tool with Gemini API instead:")
|
686 |
+
print(
|
687 |
+
f"python demo.py --backend llm --model gemini-1.5-pro-latest --person {args.person}"
|
688 |
+
+ (f' --topic "{args.topic}"' if args.topic else "")
|
689 |
+
+ (f' --message "{args.message}"' if args.message else "")
|
690 |
+
)
|
691 |
+
# Ollama is now handled through the llm backend
|
692 |
+
elif args.backend == "llm":
|
693 |
+
if "gemini" in args.model.lower() or "gemma" in args.model.lower():
|
694 |
+
print(
|
695 |
+
"\nMake sure you have the GEMINI_API_KEY environment variable set:"
|
696 |
+
)
|
697 |
+
print("export GEMINI_API_KEY=your_api_key")
|
698 |
+
print("\nAnd make sure llm-gemini is installed:")
|
699 |
+
print("llm install llm-gemini")
|
700 |
+
elif "mlx" in args.model.lower():
|
701 |
+
print("\nMake sure llm-mlx is installed:")
|
702 |
+
print("llm install llm-mlx")
|
703 |
+
elif "ollama" in args.model.lower():
|
704 |
+
print("\nMake sure Ollama is installed and running:")
|
705 |
+
print("1. Install from https://ollama.ai/")
|
706 |
+
print("2. Start Ollama with: ollama serve")
|
707 |
+
print("3. Install the llm-ollama plugin: llm install llm-ollama")
|
708 |
+
print(
|
709 |
+
f"4. Pull the model: ollama pull {args.model.replace('ollama/', '')}"
|
710 |
+
)
|
711 |
+
else:
|
712 |
+
print("\nMake sure Simon Willison's LLM tool is installed:")
|
713 |
+
print("pip install llm")
|
714 |
+
|
715 |
+
|
716 |
+
# If running as a script
|
717 |
+
if __name__ == "__main__":
|
718 |
+
main()
|