# bart-cls / main.py
# Hugging Face Space file by astro21 - commit 6c7afbb ("Update main.py"), 799 bytes.
import asyncio
import functools
import os

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client
app = FastAPI()

# Development-only CORS policy: every origin, HTTP method and request header
# is accepted. Narrow these lists before exposing the service publicly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Gradio Space that serves the model. The "--replicas/<id>/" segment appears to
# target one specific replica, and such URLs may change when the Space restarts
# (TODO confirm). The value can therefore be overridden through the
# GRADIO_API_URL environment variable without editing code.
API_ENDPOINT_URL = os.environ.get(
    "GRADIO_API_URL", "https://astro21-test-2.hf.space/--replicas/7592n/"
)


@functools.lru_cache(maxsize=None)
def _get_client(api_endpoint_url: str) -> Client:
    """Return a Gradio ``Client`` for *api_endpoint_url*, shared across requests.

    Constructing a ``Client`` fetches the Space's configuration over the
    network, so it is built once per URL and reused. If construction raises,
    nothing is cached and the next call retries.
    """
    return Client(api_endpoint_url)


def _run_prediction(text: str):
    """Blocking helper: send *text* to the Space's ``/predict`` API and return its output."""
    client = _get_client(API_ENDPOINT_URL)
    return client.predict(text, api_name="/predict")


# Define a route for the prediction using FastAPI
@app.post("/predict")
async def predict(text: str):
    """Run a prediction on *text* by forwarding it to the remote Gradio Space.

    Args:
        text: Input text; as a bare ``str`` parameter FastAPI reads it from
            the query string.

    Returns:
        ``{"result": ...}`` holding whatever the Space's ``/predict`` endpoint
        returned.
    """
    loop = asyncio.get_running_loop()
    # Client construction and Client.predict both do blocking network I/O.
    # Running them in the default thread pool keeps the event loop free to
    # serve other requests while this one waits on the Space.
    result = await loop.run_in_executor(None, _run_prediction, text)
    return {"result": result}