|
import json
import logging
import os
import time
from collections.abc import Iterator

import requests
|
|
|
# Configure root logging once at import time and create this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
# Upper-cased provider name -> name of the environment variable that holds
# that provider's API key (read by _get_api_key via os.getenv).
API_KEYS = dict(
    HUGGINGFACE='HF_TOKEN',
    GROQ='GROQ_API_KEY',
    OPENROUTER='OPENROUTER_API_KEY',
    TOGETHERAI='TOGETHERAI_API_KEY',
    COHERE='COHERE_API_KEY',
    XAI='XAI_API_KEY',
    OPENAI='OPENAI_API_KEY',
    GOOGLE='GOOGLE_API_KEY',
)
|
|
|
# Upper-cased provider name -> endpoint used for chat requests. The
# HUGGINGFACE and GOOGLE entries are prefixes; the model id is appended
# to them when the request URL is built.
API_URLS = dict(
    HUGGINGFACE='https://api-inference.huggingface.co/models/',
    GROQ='https://api.groq.com/openai/v1/chat/completions',
    OPENROUTER='https://openrouter.ai/api/v1/chat/completions',
    TOGETHERAI='https://api.together.ai/v1/chat/completions',
    COHERE='https://api.cohere.ai/v1/chat',
    XAI='https://api.x.ai/v1/chat/completions',
    OPENAI='https://api.openai.com/v1/chat/completions',
    GOOGLE='https://generativelanguage.googleapis.com/v1beta/models/',
)
|
|
|
# Hardcoded catalogue used when models.json cannot be loaded.
# Layout: lower-case provider -> {"default": model_id,
#                                 "models": {display_name: model_id}}.
_FALLBACK_MODELS_BY_PROVIDER = {
    "groq": {
        "default": "llama3-8b-8192",
        "models": {
            "Llama 3 8B (Groq)": "llama3-8b-8192",
            "Llama 3 70B (Groq)": "llama3-70b-8192",
            "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
            "Gemma 7B (Groq)": "gemma-7b-it",
        },
    },
    "openrouter": {
        "default": "nousresearch/llama-3-8b-instruct",
        "models": {
            "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
            "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
            "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
        },
    },
    "google": {
        "default": "gemini-1.5-flash-latest",
        "models": {
            "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
            "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
        },
    },
    "openai": {
        "default": "gpt-3.5-turbo",
        "models": {
            "GPT-4o mini (OpenAI)": "gpt-4o-mini",
            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
        },
    },
}


def _load_models_by_provider(path: str = "models.json") -> dict:
    """Load the provider/model catalogue from *path*, or return the fallback.

    *path* is relative to the current working directory.

    Falls back to ``_FALLBACK_MODELS_BY_PROVIDER`` in three cases:
      * any I/O error (missing file, permission denied, path is a directory);
      * any decoding error (invalid UTF-8 or invalid JSON);
      * a top-level JSON value that is not an object.

    In each case the error is logged and the module import does not crash.
    """
    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        logger.error(f"Error loading models.json: {e}. Using hardcoded fallback models.")
        return _FALLBACK_MODELS_BY_PROVIDER
    if not isinstance(data, dict):
        # The getters below call .get() on this value, so a list or scalar
        # would break every lookup later on.
        logger.error(
            f"Error loading models.json: expected a JSON object at the top level, "
            f"got {type(data).__name__}. Using hardcoded fallback models."
        )
        return _FALLBACK_MODELS_BY_PROVIDER
    logger.info("models.json loaded successfully.")
    return data


MODELS_BY_PROVIDER = _load_models_by_provider()
|
|
|
|
|
def _get_api_key(provider: str, ui_api_key_override: str | None = None) -> str | None:
    """Resolve the API key to use for *provider*.

    Lookup order:
      1. ``ui_api_key_override``, if it is not blank.
      2. The environment variable that ``API_KEYS`` names for the upper-cased
         provider, if it is set and not blank. For Hugging Face this is
         ``HF_TOKEN``.

    Surrounding whitespace is stripped from whichever value is used. A
    whitespace-only value counts as absent, so an accidentally blank UI field
    no longer hides a valid environment variable.

    Returns:
        The stripped key, or ``None`` when no usable key was found.
    """
    override = (ui_api_key_override or "").strip()
    if override:
        logger.debug(f"Using UI API key override for {provider}")
        return override

    env_var_name = API_KEYS.get(provider.upper())
    if env_var_name:
        env_key = (os.getenv(env_var_name) or "").strip()
        if env_key:
            logger.debug(f"Using env var {env_var_name} for {provider}")
            return env_key

    logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
    return None
|
|
|
def get_available_providers() -> list[str]:
    """Return every provider key in ``MODELS_BY_PROVIDER``, alphabetically sorted."""
    providers = list(MODELS_BY_PROVIDER)
    providers.sort()
    return providers
|
|
|
def get_models_for_provider(provider: str) -> list[str]:
    """Return the model display names for *provider*, alphabetically sorted.

    The provider is matched case-insensitively. Two cases yield an empty list:
    an unknown provider, and a provider entry without a ``"models"`` mapping.
    """
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    return sorted(provider_entry.get("models", {}))
|
|
|
def get_default_model_for_provider(provider: str) -> str | None:
    """Return the display name that should be preselected for *provider*.

    The provider is matched case-insensitively. The result is chosen in order:
      1. the display name whose model id equals the entry's ``"default"``;
      2. otherwise, the alphabetically first display name;
      3. ``None``, when the provider has no models.
    """
    entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    models = entry.get("models", {})
    wanted_id = entry.get("default")

    if wanted_id:
        match = next((name for name, mid in models.items() if mid == wanted_id), None)
        if match is not None:
            return match

    return min(models) if models else None
|
|
|
def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Map a model display name to the provider's model id.

    The provider is matched case-insensitively. Returns ``None`` when either
    the provider or the display name is unknown.
    """
    return (
        MODELS_BY_PROVIDER
        .get(provider.lower(), {})
        .get("models", {})
        .get(display_name)
    )
|
|
|
# Providers that use the OpenAI chat-completions protocol: Bearer auth,
# a "messages" payload, and SSE "data:" lines carrying choices[0].delta.content.
_OPENAI_COMPATIBLE_PROVIDERS = ("groq", "openrouter", "togetherai", "openai", "xai")


def _open_stream(url: str, headers: dict, payload: dict, timeout: float):
    """POST *payload* as JSON with ``stream=True`` and return the response.

    Callers use the result as a context manager (``with ... as response``), so
    the connection is released even when the consumer stops iterating early.

    On an HTTP error status, the error body is read into memory and the
    connection closed before ``HTTPError`` is raised. An exception handler can
    therefore still read ``e.response.text``.
    """
    response = requests.post(url, headers=headers, json=payload, stream=True, timeout=timeout)
    if not response.ok:
        try:
            _ = response.content  # buffer the error body; it stays readable after close()
        finally:
            response.close()
        response.raise_for_status()
    return response


def _iter_stream_lines(response) -> Iterator[str]:
    """Yield the streamed body of *response* line by line, decoded as UTF-8.

    Splitting happens on ``\\n`` bytes, which never occur inside a multi-byte
    UTF-8 sequence. A trailing ``\\r`` is removed, so CRLF streams behave like
    LF streams. A final unterminated line is also yielded. Undecodable bytes
    are dropped.
    """
    buffer = b""
    for chunk in response.iter_content(chunk_size=8192):
        buffer += chunk
        *complete_lines, buffer = buffer.split(b"\n")
        for raw_line in complete_lines:
            yield raw_line.rstrip(b"\r").decode("utf-8", errors="ignore")
    if buffer:
        yield buffer.rstrip(b"\r").decode("utf-8", errors="ignore")


def _iter_sse_data(response) -> Iterator[str]:
    """Yield the non-empty payload of each ``data:`` line of an SSE stream.

    Iteration stops at the OpenAI-style ``[DONE]`` sentinel.
    """
    for line in _iter_stream_lines(response):
        if not line.startswith("data:"):
            continue  # blank separators, ":" comments/keep-alives, event:/id: fields
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            return
        if data:
            yield data


def _parse_stream_json(data: str, source: str) -> dict | None:
    """Decode one streamed JSON object; log and return ``None`` if it is not one."""
    try:
        obj = json.loads(data)
    except json.JSONDecodeError:
        logger.warning(f"{source}: Failed to decode JSON from stream data: {data}")
        return None
    if not isinstance(obj, dict):
        logger.warning(f"{source}: Ignoring non-object stream data: {data}")
        return None
    return obj


def _stream_error_message(error) -> str:
    """Extract a readable message from an error value embedded in stream data."""
    if isinstance(error, dict):
        return str(error.get("message", error))
    return str(error)


def _split_system_messages(messages: list[dict]) -> tuple[str | None, list[dict]]:
    """Separate system messages from the rest of the chat history.

    Returns ``(system_text, others)``:
      * ``system_text`` holds the system messages' contents joined with
        newlines, or ``None`` if there were none;
      * ``others`` holds all non-system messages, in their original order.
    """
    system_text = None
    others = []
    for msg in messages:
        if msg["role"] == "system":
            system_text = f"{system_text}\n{msg['content']}" if system_text else msg["content"]
        else:
            others.append(msg)
    return system_text, others


def _stream_openai_compatible(provider: str, url: str, api_key: str, model_id: str,
                              messages: list[dict], timeout: float) -> Iterator[str]:
    """Stream text from an OpenAI-compatible chat-completions endpoint."""
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    if provider.lower() == "openrouter":
        # Optional attribution headers sent only to OpenRouter.
        headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/huggingface/ai-space-commander"
        headers["X-Title"] = "Hugging Face Space Commander"
    payload = {
        "model": model_id,
        "messages": messages,
        "stream": True,
        "temperature": 0.7,
        "max_tokens": 4096,
    }

    with _open_stream(url, headers, payload, timeout) as response:
        for data in _iter_sse_data(response):
            event = _parse_stream_json(data, provider)
            if event is None:
                continue
            if event.get("error"):
                # Error objects can arrive mid-stream after a 200 status line.
                error_msg = _stream_error_message(event["error"])
                logger.error(f"{provider} returned error in stream data: {error_msg}")
                yield f"API Error ({provider}): {error_msg}"
                return
            choices = event.get("choices")
            if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
                continue
            delta = choices[0].get("delta")
            if isinstance(delta, dict) and delta.get("content"):
                yield delta["content"]


def _google_chunk_text(event: dict) -> str:
    """Concatenate the text parts of the first candidate in a Gemini chunk."""
    candidates = event.get("candidates")
    if not isinstance(candidates, list) or not candidates or not isinstance(candidates[0], dict):
        return ""
    parts = (candidates[0].get("content") or {}).get("parts") or []
    return "".join(part.get("text", "") for part in parts if isinstance(part, dict))


def _stream_google(base_url: str, api_key: str, model_id: str,
                   messages: list[dict], timeout: float) -> Iterator[str]:
    """Stream text from the Gemini ``streamGenerateContent`` endpoint."""
    system_instruction, others = _split_system_messages(messages)
    contents = [
        {"role": "model" if m["role"] == "assistant" else m["role"], "parts": [{"text": m["content"]}]}
        for m in others
    ]
    for previous, current in zip(contents, contents[1:]):
        if previous["role"] == current["role"]:
            yield "Error: Google API requires alternating user/model roles in chat history. Please check prompt or history format."
            return

    payload = {
        "contents": contents,
        "safetySettings": [
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ],
        "generationConfig": {
            "temperature": 0.7,
            "maxOutputTokens": 4096,
        },
    }
    if system_instruction:
        payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

    # alt=sse makes the endpoint emit one "data: {...}" line per chunk instead
    # of a single JSON array that only becomes parseable at the end.
    request_url = f"{base_url}{model_id}:streamGenerateContent?alt=sse"
    # The key goes in a header, not the query string. requests embeds the URL
    # in HTTPError messages, which are logged and shown to the user.
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}

    with _open_stream(request_url, headers, payload, timeout) as response:
        for data in _iter_sse_data(response):
            event = _parse_stream_json(data, "Google")
            if event is None:
                continue
            if event.get("error"):
                error_details = _stream_error_message(event["error"])
                logger.error(f"Google API returned error in stream data: {error_details}")
                yield f"API Error (Google): {error_details}"
                return
            text = _google_chunk_text(event)
            if text:
                yield text


def _stream_cohere(base_url: str, api_key: str, model_id: str,
                   messages: list[dict], timeout: float) -> Iterator[str]:
    """Stream text from the Cohere v1 chat endpoint.

    Two stream formats are accepted:
      * newline-delimited JSON, where each object carries ``event_type``;
      * SSE, with ``event:`` and ``data:`` lines.
    """
    preamble, others = _split_system_messages(messages)
    turns = [m for m in others if m["role"] in ("user", "assistant")]
    if not turns:
        yield "Error: No user message found for Cohere API call."
        return
    if turns[-1]["role"] != "user":
        yield "Error: Last message must be from user for Cohere API call."
        return

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": model_id,
        "message": turns[-1]["content"],
        "stream": True,
        "temperature": 0.7,
        "max_tokens": 4096,
    }
    chat_history = [
        {"role": "CHATBOT" if m["role"] == "assistant" else "USER", "message": m["content"]}
        for m in turns[:-1]
    ]
    if chat_history:
        payload["chat_history"] = chat_history
    if preamble:
        payload["preamble"] = preamble

    with _open_stream(base_url, headers, payload, timeout) as response:
        sse_event_type = None
        for line in _iter_stream_lines(response):
            text_line = line.strip()
            if not text_line or text_line.startswith((":", "id:", "retry:")):
                continue
            if text_line.startswith("event:"):
                sse_event_type = text_line[len("event:"):].strip()
                continue
            if text_line.startswith("data:"):
                text_line = text_line[len("data:"):].strip()
                if not text_line:
                    continue
            event = _parse_stream_json(text_line, "Cohere")
            # NDJSON carries the type inside the object; SSE carries it in the
            # preceding "event:" line.
            event_type = (event or {}).get("event_type") or sse_event_type
            sse_event_type = None

            if event_type == "text-generation":
                if event and event.get("text"):
                    yield event["text"]
            elif event_type == "stream-end":
                logger.debug("Cohere stream-end event received.")
                return
            elif event_type == "error":
                error_msg = _stream_error_message(event) if event else "Unknown Cohere stream error"
                logger.error(f"Cohere stream error event: {error_msg}")
                yield f"API Error (Cohere stream): {error_msg}"
                return


def generate_stream(provider: str, model_display_name: str, api_key_override: str | None,
                    messages: list[dict]) -> Iterator[str]:
    """Stream a chat completion from *provider* as a sequence of text chunks.

    Args:
        provider: Provider name, matched case-insensitively.
        model_display_name: Display name as returned by
            ``get_models_for_provider``. It is translated to the provider's
            model id.
        api_key_override: Optional key that takes precedence over the
            environment variable.
        messages: Chat history as dicts with ``"role"`` and ``"content"``.
            The code handles the roles ``"system"``, ``"user"`` and
            ``"assistant"`` specially.

    Yields:
        Generated text fragments. Problems are yielded as human-readable
        messages rather than raised, so the caller can show them inline.
        This covers unknown provider/model, missing key, HTTP or network
        failures, provider-reported errors and unexpected exceptions.
    """
    provider_lower = provider.lower()
    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)
    timeout_seconds = 180

    # Validate configuration first, so an unknown provider is reported as
    # such rather than as a missing API key.
    if not base_url:
        yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
        return
    if not model_id:
        yield f"Error: Unknown model '{model_display_name}' for provider '{provider}'. Please select a valid model."
        return
    api_key = _get_api_key(provider_lower, api_key_override)
    if not api_key:
        env_var_name = API_KEYS.get(provider.upper(), 'N/A')
        yield f"Error: API Key not found for {provider}. Please set it in the UI override or environment variable '{env_var_name}'."
        return

    logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

    try:
        if provider_lower in _OPENAI_COMPATIBLE_PROVIDERS:
            yield from _stream_openai_compatible(provider, base_url, api_key, model_id, messages, timeout_seconds)
        elif provider_lower == "google":
            yield from _stream_google(base_url, api_key, model_id, messages, timeout_seconds)
        elif provider_lower == "cohere":
            yield from _stream_cohere(base_url, api_key, model_id, messages, timeout_seconds)
        elif provider_lower == "huggingface":
            yield "Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there."
        else:
            yield f"Error: Unsupported provider '{provider}' for streaming chat."

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 'N/A'
        error_text = e.response.text if e.response is not None else 'No response text'
        logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
        yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
    except requests.exceptions.Timeout:
        logger.error(f"Request Timeout after {timeout_seconds} seconds for {provider}/{model_id}.")
        yield f"API Request Timeout: The request took too long to complete ({timeout_seconds} seconds)."
    except requests.exceptions.RequestException as e:
        logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
        yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
    except Exception as e:
        logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
        yield f"An unexpected error occurred during streaming: {e}"