"""Apollo AI Backend: FastAPI service exposing Qwen2-0.5B-Instruct with a LoRA adapter.

Provides an OpenAI-style /v1/chat/completions endpoint with two response modes:
"mentor" (guided questions) and "force" (direct answers).
"""

import os
import time

import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from peft import PeftModel
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(title="Apollo AI Backend - Qwen2-0.5B", version="3.1.0-FIXED")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
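# NOTE: wildcard origins combined with credentials effectively let any site make
# authenticated requests; tighten allow_origins for production use.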

API_KEY = os.getenv("API_KEY", "aigenapikey1234567890")
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"
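# The fallback key above is a development default; set the API_KEY environment
# variable in any real deployment.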
print("🔧 Loading tokenizer for Qwen2-0.5B...") |
|
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) |
|
if tokenizer.pad_token is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
print("🧠 Loading Qwen2-0.5B base model...") |
|
base_model = AutoModelForCausalLM.from_pretrained( |
|
BASE_MODEL, |
|
trust_remote_code=True, |
|
torch_dtype=torch.float32, |
|
device_map="cpu" |
|
) |
|
|
|
print("🔗 Applying LoRA adapter to Qwen2-0.5B...") |
|
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH) |
|
model.eval() |
|
|
|
print("✅ Qwen2-0.5B model ready!") |
|
|
|


def create_conversation_prompt(messages: list, is_force_mode: bool) -> str:
    """Create a conversation prompt with STRONG mode enforcement."""
    if is_force_mode:
        system_prompt = """FORCE MODE - DIRECT ANSWERS ONLY:
You MUST give direct, complete, factual answers. Do NOT ask questions. Provide exact solutions, working code, and clear explanations.

EXAMPLE FORCE RESPONSE:
Q: What does len() do in Python?
A: len() returns the number of items in an object. Examples:
- len([1,2,3]) returns 3
- len("hello") returns 5
- len({1,2,3}) returns 3

Always be direct and informative. Never ask "What do you think?" or similar questions."""
    else:
        system_prompt = """MENTOR MODE - GUIDED LEARNING ONLY:
You are a programming teacher. You MUST guide students to discover answers themselves. NEVER give direct answers or complete solutions. ALWAYS respond with guiding questions and hints.

EXAMPLE MENTOR RESPONSE:
Q: What does len() do in Python?
A: Great question! What do you think might happen if you run len([1,2,3]) in Python? Can you guess what number it would return? Try it and see! What pattern do you notice?

Always ask questions to guide learning. Never give direct answers."""

    conversation = f"System: {system_prompt}\n\n"

    # Keep only the six most recent turns so the prompt stays well within the
    # 1024-token truncation limit applied at tokenization time.
    recent_messages = messages[-6:] if len(messages) > 6 else messages

    for msg in recent_messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role == "user":
            conversation += f"Student: {content}\n"
        elif role == "assistant":
            conversation += f"Assistant: {content}\n"

    conversation += "Assistant:"
    return conversation
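
# Resulting prompt shape (illustrative):
#   System: <mode-specific instructions>
#
#   Student: What does len() do in Python?
#   Assistant: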


def validate_response_mode(response: str, is_force_mode: bool) -> str:
    """CRITICAL: Enforce mode compliance in responses."""
    response = response.strip()

    if is_force_mode:
        # Force mode must not answer with questions of its own.
        has_questioning = any(phrase in response.lower() for phrase in [
            "what do you think", "can you tell me", "what would happen",
            "try it", "guess", "what pattern", "how do you", "what's your"
        ])

        if has_questioning or response.count("?") > 1:
            print("🔧 Converting to direct answer for force mode")
            # Keep only the declarative sentences of meaningful length.
            direct_parts = []
            for sentence in response.split("."):
                if "?" not in sentence and len(sentence.strip()) > 10:
                    direct_parts.append(sentence.strip())

            if direct_parts:
                return ". ".join(direct_parts[:2]) + "."
            else:
                return "Here's the direct answer: " + response.split("?")[0].strip() + "."

    else:
        # Mentor mode must guide with questions rather than state the answer.
        has_questions = "?" in response
        has_guidance = any(phrase in response.lower() for phrase in [
            "what do you think", "can you", "try", "what would", "how do you", "what pattern"
        ])

        if not has_questions and not has_guidance:
            print("🔧 Adding guiding questions for mentor mode")
            return f"Interesting! {response} What do you think about this? Can you tell me what part makes most sense to you?"

    return response
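
# Example: a force-mode draft containing guiding questions is stripped down to
# its declarative sentences, while a mentor-mode draft with no questions gains
# trailing guiding questions.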


def generate_response(messages: list, is_force_mode: bool = False, max_tokens: int = 200, temperature: float = 0.7) -> str:
    """Generate a response using the AI model with STRONG mode enforcement."""
    try:
        prompt = create_conversation_prompt(messages, is_force_mode)

        print(f"🎯 Generating {'FORCE' if is_force_mode else 'MENTOR'} response with FIXED logic")
        print(f"🔍 DEBUG: force_mode = {is_force_mode}")
        print(f"📝 System prompt preview: {prompt.split('Student:')[0][:150]}...")

        # Mode-specific sampling settings deliberately override the caller's
        # temperature: force mode samples conservatively for factual answers.
        if is_force_mode:
            generation_temp = 0.2
            generation_tokens = min(max_tokens, 250)
        else:
            generation_temp = 0.4
            generation_tokens = min(max_tokens, 200)

        print(f"🎛️ Using temperature: {generation_temp}, max_tokens: {generation_tokens}")

        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=generation_tokens,
                temperature=generation_temp,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                top_p=0.9,
                repetition_penalty=1.1,
            )

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Strip the prompt from the decoded output, then remove any leaked role labels.
        response = full_response[len(prompt):].strip()
        response = response.replace("Student:", "").replace("Assistant:", "").replace("System:", "").strip()

        # Keep only the first line so the model cannot continue the dialogue itself.
        if "\n" in response:
            response = response.split("\n")[0].strip()

        print(f"✅ Raw generated response: {response[:100]}...")

        validated_response = validate_response_mode(response, is_force_mode)

        print(f"✅ Final validated response length: {len(validated_response)}")
        print(f"📝 Mode compliance: {'FORCE' if is_force_mode else 'MENTOR'}")

        # Fall back to a canned answer when generation produces nothing usable.
        if not validated_response or len(validated_response) < 10:
            if is_force_mode:
                return "len() returns the number of items in a sequence. For example: len([1,2,3]) returns 3, len('hello') returns 5."
            else:
                return "What do you think len() might do? Try running len([1,2,3]) and see what happens! What number do you get?"

        return validated_response

    except Exception as e:
        print(f"❌ Generation error: {e}")
        if is_force_mode:
            return "I need you to provide a more specific question so I can give you the exact answer you need."
        else:
            return "That's an interesting question! What do you think might be the answer? Can you break it down step by step?"
@app.get("/") |
|
def root(): |
|
return { |
|
"message": "🤖 Apollo AI Backend v3.1-FIXED - Qwen2-0.5B", |
|
"model": "Qwen/Qwen2-0.5B-Instruct with LoRA", |
|
"status": "ready", |
|
"modes": { |
|
"mentor": "Guides learning with questions - FIXED", |
|
"force": "Provides direct answers - FIXED" |
|
}, |
|
"fixes": "Strong mode enforcement, response validation" |
|
} |
|
|
|
@app.get("/health") |
|
def health(): |
|
return { |
|
"status": "healthy", |
|
"model_loaded": True, |
|
"model_size": "0.5B", |
|
"version": "3.1-FIXED" |
|
} |
|
|
|
@app.post("/v1/chat/completions") |
|
async def chat_completions(request: Request): |
|
|
|
auth_header = request.headers.get("Authorization", "") |
|
if not auth_header.startswith("Bearer "): |
|
return JSONResponse( |
|
status_code=401, |
|
content={"error": "Missing or invalid Authorization header"} |
|
) |
|
|
|
token = auth_header.replace("Bearer ", "").strip() |
|
if token != API_KEY: |
|
return JSONResponse( |
|
status_code=401, |
|
content={"error": "Invalid API key"} |
|
) |
|
|
|
|
|
try: |
|
body = await request.json() |
|
messages = body.get("messages", []) |
|
max_tokens = min(body.get("max_tokens", 200), 400) |
|
temperature = max(0.1, min(body.get("temperature", 0.7), 1.0)) |
|
|
|
|
|
is_force_mode = body.get("force_mode", False) |
|
|
|
print(f"🚨 RECEIVED REQUEST - force_mode from body: {is_force_mode}") |
|
print(f"🚨 Type of force_mode: {type(is_force_mode)}") |
|
|
|
if not messages or not isinstance(messages, list): |
|
raise ValueError("Messages field is required and must be a list") |
|
|
|
except Exception as e: |
|
return JSONResponse( |
|
status_code=400, |
|
content={"error": f"Invalid request body: {str(e)}"} |
|
) |
|
|
|
|
|
for i, msg in enumerate(messages): |
|
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: |
|
return JSONResponse( |
|
status_code=400, |
|
content={"error": f"Invalid message format at index {i}"} |
|
) |
|
|
|
try: |
|
print(f"📥 Processing request in {'FORCE' if is_force_mode else 'MENTOR'} mode - FIXED") |
|
print(f"📊 Total messages: {len(messages)}") |
|
print(f"🎯 CRITICAL - Mode flag received: {is_force_mode}") |
|
|
|
|
|
response_content = generate_response( |
|
messages=messages, |
|
is_force_mode=is_force_mode, |
|
max_tokens=max_tokens, |
|
temperature=temperature |
|
) |
|
|
|
print(f"✅ Generated response in {'FORCE' if is_force_mode else 'MENTOR'} mode") |
|
print(f"📝 Response preview: {response_content[:100]}...") |
|
|
|
return { |
|
"id": f"chatcmpl-apollo-{hash(str(messages)) % 10000}", |
|
"object": "chat.completion", |
|
"created": int(torch.tensor(0).item()), |
|
"model": f"qwen2-0.5b-{'force' if is_force_mode else 'mentor'}-fixed", |
|
"choices": [ |
|
{ |
|
"index": 0, |
|
"message": { |
|
"role": "assistant", |
|
"content": response_content |
|
}, |
|
"finish_reason": "stop" |
|
} |
|
], |
|
"usage": { |
|
"prompt_tokens": len(str(messages)), |
|
"completion_tokens": len(response_content), |
|
"total_tokens": len(str(messages)) + len(response_content) |
|
}, |
|
"apollo_mode": "force" if is_force_mode else "mentor", |
|
"mode_validation": "FIXED - Strong enforcement", |
|
"model_optimizations": "qwen2_0.5B_fixed" |
|
} |
|
|
|
except Exception as e: |
|
print(f"❌ Chat completion error: {e}") |
|
return JSONResponse( |
|
status_code=500, |
|
content={"error": f"Internal server error: {str(e)}"} |
|
) |
|
|
|


if __name__ == "__main__":
    import uvicorn

    print("🚀 Starting Apollo AI Backend v3.1-FIXED - Strong Mode Enforcement...")
    print("🧠 Model: Qwen/Qwen2-0.5B-Instruct (500M parameters)")
    print("🎯 Mentor Mode: FIXED - Always asks guiding questions")
    print("⚡ Force Mode: FIXED - Always gives direct answers")
    print("🔧 New: Response validation and mode enforcement")
    uvicorn.run(app, host="0.0.0.0", port=7860)
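
# Example request (assumes a local run with the default development key):
#
#   curl -X POST http://localhost:7860/v1/chat/completions \
#     -H "Authorization: Bearer aigenapikey1234567890" \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "What does len() do?"}], "force_mode": true}'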