linp / main.py
ka1kuk's picture
Update main.py
fd60e55
raw
history blame
1.49 kB
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import asyncio
from Linlada import Chatbot, ConversationStyle
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
# Application instance and cross-origin policy.
app = FastAPI()

# Wide-open CORS: any origin, any method, any header, credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
async def generate(prompt):
    """Ask the Linlada chatbot and return its reply.

    Creates a fresh ``Chatbot`` per call and queries it with the
    ``precise`` conversation style.
    """
    chatbot = await Chatbot.create()
    answer = await chatbot.ask(
        prompt=prompt,
        conversation_style=ConversationStyle.precise,
    )
    return answer
def dummy(images, **kwargs):
    """No-op replacement for the diffusers safety checker.

    Returns the images unchanged plus ``False`` (nothing flagged as NSFW),
    ignoring any extra keyword arguments the pipeline passes in.
    """
    return (images, False)
async def generate_image(prompt):
    """Generate one image from *prompt* with Stable Diffusion v1.5.

    Bug fix: ``StableDiffusionPipeline.from_pretrained`` and the pipeline
    call itself are synchronous — the original ``await``-ed both (and even
    ``await``-ed the resulting PIL image), which raises ``TypeError:
    object ... can't be used in 'await' expression`` at runtime.

    Returns the first PIL image produced by the pipeline.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    # NOTE(review): the pipeline is re-downloaded/re-loaded on every call,
    # which is very slow — consider caching it at module level.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    # Disable the NSFW safety checker; `dummy` passes images through unchanged.
    pipe.safety_checker = dummy
    image = pipe(prompt).images[0]
    return image
@app.get("/")
def read_root():
    """Root endpoint — simple liveness greeting."""
    greeting = "Hello, I'm Linlada"
    return greeting
@app.get("/test/{hello}")
def hi(hello: str):
    """Echo the path segment back as a JSON object."""
    return dict(text=hello)
@app.get("/image/{image}")
def img(image: str):
    """Run the Stable Diffusion coroutine for *image* (the prompt) and
    return the result.

    Fix: the original created an event loop by hand and only closed it on
    the success path — if ``generate_image`` raised, the loop leaked and a
    closed loop stayed installed as the thread's current loop.
    ``asyncio.run`` performs the same create/run/close cycle with
    guaranteed cleanup.
    """
    # NOTE(review): generate_image returns a PIL image, which FastAPI
    # cannot JSON-serialize directly — confirm callers expect this, or
    # wrap the bytes in a Response with an image media type.
    return asyncio.run(generate_image(image))
@app.get('/linlada/{prompt}')
def generate_image_route(prompt: str):
    """Run the chatbot coroutine for *prompt* and return its reply.

    Fix: the original hand-rolled loop (``new_event_loop`` /
    ``run_until_complete`` / ``close``) leaked the loop whenever
    ``generate`` raised, leaving a closed loop set as the thread's current
    loop. ``asyncio.run`` handles setup and teardown even on exceptions.
    """
    return asyncio.run(generate(prompt))