from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import os

HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "google/gemma-2b-it"

client = InferenceClient(token=HF_TOKEN)

app = FastAPI()

# Allow CORS for all origins (for development and Netlify)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify the frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatRequest(BaseModel):
    message: str


# Non-streaming version, kept for reference:
# @app.post("/chat")
# async def chat(req: ChatRequest):
#     try:
#         messages = [{"role": "user", "content": req.message}]
#         response = client.chat_completion(
#             model=MODEL_ID,
#             messages=messages,
#             temperature=0.7,
#         )
#         return {"response": response.choices[0].message.content}
#     except Exception as e:
#         return {"error": str(e)}


@app.post("/chat")
async def chat_endpoint(data: ChatRequest):
    # Yield generated text chunk by chunk so the client can render tokens as they arrive
    def stream():
        for chunk in client.text_generation(
            model=MODEL_ID,
            prompt=data.message,
            stream=True,
            max_new_tokens=512,
            temperature=0.7,
        ):
            yield chunk

    return StreamingResponse(stream(), media_type="text/plain")
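

# A minimal way to run and smoke-test the service locally (a sketch, assuming this
# file is saved as main.py and uvicorn is installed alongside fastapi):
#
#   uvicorn main:app --reload
#
# then stream a response from another terminal (-N disables curl's output buffering):
#
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)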