import gradio as gr
# AsyncInferenceClient is required because generation below uses `await` / `async for`
from huggingface_hub import AsyncInferenceClient
import os
from typing import Optional, List, Tuple, AsyncGenerator
import logging
from tenacity import retry, stop_after_attempt, wait_exponential

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ChatInterface:
    def __init__(self, text_model: str, image_model: str, hf_token: str):
        """Initialize the chat interface with the specified models and token."""
        # Timeouts are set on the clients because the task methods
        # (chat_completion, text_to_image) take no per-call timeout argument.
        self.text_client = AsyncInferenceClient(text_model, token=hf_token, timeout=30)
        self.image_client = AsyncInferenceClient(image_model, token=hf_token, timeout=60)
        self.custom_responses = self._initialize_custom_responses()
        self.system_prompt = self._initialize_system_prompt()

    def _initialize_system_prompt(self) -> str:
        """Initialize the system prompt for the AI assistant."""
        return """# Xylaria AI Assistant (v1.3.0)

## Core Identity
- Name: Xylaria
- Version: 1.3.0
- Base Model: Mistral-Nemo-Instruct
- Knowledge Cutoff: April 2024

## Primary Directives
1. Provide accurate, well-researched information
2. Maintain ethical standards in all interactions
3. Adapt communication style to user needs
4. Acknowledge limitations and uncertainties
5. Prioritize user safety and wellbeing

## Technical Capabilities
- Programming & Software Development
- Mathematical Analysis & Computation
- Scientific Research & Explanation
- Data Analysis & Visualization
- Technical Writing & Documentation
- Problem-Solving & Debugging
- Educational Content Creation

## Communication Guidelines
- Use clear, precise language
- Adapt technical depth to user expertise
- Provide step-by-step explanations when needed
- Ask for clarification when necessary
- Maintain professional yet approachable tone

## Domain Expertise
1. Computer Science & Technology
   - Multiple programming languages
   - Software architecture & design
   - Data structures & algorithms
   - Best practices & patterns
2. Mathematics & Statistics
   - Advanced mathematical concepts
   - Statistical analysis
   - Probability theory
   - Data interpretation
3. Sciences
   - Physics & Chemistry
   - Biology & Life Sciences
   - Environmental Science
   - Engineering Principles
4. Humanities & Arts
   - Technical Writing
   - Documentation
   - Creative Problem-Solving
   - Research Methodology

## Response Framework
1. Analyze user query thoroughly
2. Consider context and background
3. Structure response logically
4. Provide examples when helpful
5. Verify accuracy of information
6. Include relevant caveats or limitations

## Ethical Guidelines
- Prioritize user safety
- Maintain data privacy
- Avoid harmful content
- Acknowledge uncertainties
- Provide balanced perspectives
- Respect intellectual property

## Limitations
- No real-time data access
- No persistent memory between sessions
- Cannot verify external sources
- No capability to execute code
- Limited to text and basic image generation

## Version-Specific Features
- Enhanced error handling
- Improved response consistency
- Better context awareness
- Advanced technical explanation capabilities
- Robust ethical framework"""

    def _initialize_custom_responses(self) -> dict:
        """Initialize custom response patterns in a maintainable way."""
        name_variations = [
            "what is ur name", "what's ur name", "whats ur name",
            "what is your name", "wat is ur name", "wut is ur name",
        ]
        dev_variations = [
            "who is your developer", "who is ur developer", "who is ur dev",
            "who's your developer", "who's ur dev",
        ]
        strawberry_variations = [
            "how many 'r' is in strawberry", "how many r is in strawberry",
            "how many r's are in strawberry",
        ]
        # respond() matches patterns against the lowercased message, so only
        # lowercase keys are needed.
        patterns = {}
        for pattern in name_variations:
            patterns[pattern] = "xylaria"
        for pattern in dev_variations:
            patterns[pattern] = "sk md saad amin"
        for pattern in strawberry_variations:
            patterns[pattern] = "3"
        return patterns
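
    # The result is a flat dict from lowercase phrase to canned reply; a
    # hypothetical lookup (not from the original source) to illustrate:
    #
    #   chat.custom_responses["what is your name"]  # -> "xylaria"
    #   "who is ur dev" in "hey, who is ur dev?"    # -> True (substring match)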

    async def _generate_text_response(
        self,
        messages: List[dict],
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> AsyncGenerator[str, None]:
        """Stream a text response, retrying when the stream fails before any
        token has been produced."""
        last_error: Optional[Exception] = None
        for attempt in range(3):
            response = ""
            try:
                stream = await self.text_client.chat_completion(
                    messages,
                    max_tokens=max_tokens,
                    stream=True,
                    temperature=temperature,
                    top_p=top_p,
                )
                async for chunk in stream:
                    token = chunk.choices[0].delta.content
                    if token:  # delta.content can be None on some chunks
                        response += token
                        yield response
                return
            except Exception as e:
                last_error = e
                logger.warning(f"Streaming attempt {attempt + 1} failed: {e}")
                if response:
                    break  # tokens already reached the user; don't restart mid-reply
        logger.error(f"Error generating text response: {last_error}")
        yield "I apologize, but I'm having trouble generating a response right now. Please try again in a moment."

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    async def _call_image_model(self, prompt: str):
        """Single text-to-image call, retried with exponential backoff."""
        # Only arguments supported by text_to_image are passed; the call
        # returns a PIL.Image.Image on success.
        return await self.image_client.text_to_image(
            prompt,
            negative_prompt=(
                "(worst quality, low quality, illustration, 3d, 2d, painting, "
                "cartoons, sketch), open mouth"
            ),
            num_inference_steps=30,
            guidance_scale=7.5,
        )

    async def _generate_image(self, prompt: str):
        """Generate an image, returning None once all retries are exhausted."""
        try:
            return await self._call_image_model(prompt)
        except Exception as e:
            logger.error(f"Error generating image: {e}")
            return None

    def is_image_request(self, message: str) -> bool:
        """Detect whether the message is requesting image generation."""
        image_triggers = {
            "generate an image", "create an image", "draw",
            "make a picture", "generate a picture", "create a picture",
            "generate art", "create art", "make art", "visualize",
            "show me",
        }
        return any(trigger in message.lower() for trigger in image_triggers)
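
    # Hypothetical examples of the substring matching above; the trigger list
    # is deliberately broad, so some phrasings are false positives:
    #
    #   chat.is_image_request("please draw a red fox")  # -> True
    #   chat.is_image_request("show me the error log")  # -> True (false positive)
    #   chat.is_image_request("what is recursion?")     # -> False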

    async def respond(
        self,
        message: str,
        history: List[Tuple[str, str]],
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> AsyncGenerator[str, None]:
        """Main response handler with improved error handling."""
        try:
            # Check for custom responses first
            message_lower = message.lower()
            for pattern, response in self.custom_responses.items():
                if pattern in message_lower:
                    yield response
                    return
            # Handle image generation requests
            if self.is_image_request(message):
                image = await self._generate_image(message)
                if image is not None:
                    # Note: gr.ChatInterface streams text; displaying the PIL
                    # image itself would require saving and embedding it.
                    yield f"Here's your generated image based on: {message}"
                else:
                    yield "I apologize, but I couldn't generate the image. Please try again."
                return
            # Prepare conversation history with the system prompt first
            messages = [{"role": "system", "content": self.system_prompt}]
            for user_msg, assistant_msg in history:
                if user_msg:
                    messages.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": assistant_msg})
            messages.append({"role": "user", "content": message})
            # Stream the text response
            async for response in self._generate_text_response(
                messages, max_tokens, temperature, top_p
            ):
                yield response
        except Exception as e:
            logger.error(f"Error in respond function: {e}")
            yield "I encountered an error. Please try again or contact support if the issue persists."

def create_interface(hf_token: str):
    """Create and configure the Gradio interface."""
    chat = ChatInterface(
        text_model="mistralai/Mistral-Nemo-Instruct-2407",
        image_model="SG161222/RealVisXL_V3.0",
        hf_token=hf_token,
    )
    return gr.ChatInterface(
        chat.respond,  # bound method passed directly; no partial() needed
        additional_inputs=[
            gr.Slider(
                minimum=1,
                maximum=16384,
                value=16384,
                step=1,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
        css="""
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
        body, .gradio-container {
            font-family: 'Inter', sans-serif;
        }
        """,
    )
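
# Note on streaming: chat.respond is an async generator, which current Gradio
# releases stream without extra setup; on older Gradio 3.x versions, queuing
# must be enabled explicitly (e.g. `demo.queue()` before `demo.launch()`).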

if __name__ == "__main__":
    # Read the token from the environment; HF_TOKEN is the conventional name,
    # with the lowercase spelling accepted as a fallback
    hf_token = os.getenv("HF_TOKEN") or os.getenv("hf_token")
    if not hf_token:
        raise ValueError("Please set the HF_TOKEN environment variable")
    # Create and launch the interface
    demo = create_interface(hf_token)
    demo.launch()
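
# Example invocation (hypothetical shell session; the token is a placeholder,
# and the filename assumes this module is saved as app.py):
#
#   HF_TOKEN=hf_xxxxxxxxxxxx python app.py
#
# On Hugging Face Spaces, set HF_TOKEN as a repository secret rather than
# exporting it in the shell.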