import os
import re

from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from groq import Groq
from pydantic import BaseModel

# from huggingface_hub import InferenceClient  # only needed for the commented-out HF path at the bottom
# Initialize the FastAPI app
app = FastAPI()

# Serve static assets (frontend HTML, CSS, JS)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize the Groq client; it reads the GROQ_API_KEY environment variable by default
# client = InferenceClient()  # earlier Hugging Face path, kept for reference
client = Groq()
# Request body for the /generate endpoint
class InfographicRequest(BaseModel):
    description: str

# Load the system instruction and prompt template from environment variables
SYSTEM_INSTRUCT = os.getenv("SYSTEM_INSTRUCTOR")
PROMPT_TEMPLATE = os.getenv("PROMPT_TEMPLATE")
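# PROMPT_TEMPLATE is expected to contain a "{description}" placeholder that the
# /generate route fills in, e.g. (illustrative value, not the deployed one):
# PROMPT_TEMPLATE="Create a single-page HTML infographic about: {description}"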
def extract_code_blocks(markdown_text):
    """
    Extract fenced code blocks from the given Markdown text.

    Args:
        markdown_text (str): The Markdown content as a string.

    Returns:
        list: The contents of each triple-backtick code block, in order.
    """
    # Match fenced code blocks (```...```), skipping the optional language tag
    code_block_pattern = re.compile(r'```.*?\n(.*?)```', re.DOTALL)
    return code_block_pattern.findall(markdown_text)
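# Illustrative example (hypothetical model output):
#   extract_code_blocks("Here it is:\n```html\n<div>Hello</div>\n```")
#   -> ['<div>Hello</div>\n']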
# Helper that calls the LLM; named distinctly from the /generate route handler
# below so the two do not shadow each other
async def generate_infographic_html(prompt):
    """Ask the LLM for an infographic and return any code blocks in its reply."""
    generated_completion = client.chat.completions.create(
        model="llama-3.1-70b-versatile",
        messages=[
            {"role": "system", "content": SYSTEM_INSTRUCT},
            {"role": "user", "content": prompt},
        ],
        temperature=0.5,
        max_tokens=5000,
        top_p=1,
        stream=False,
        stop=None,
    )
    generated_text = generated_completion.choices[0].message.content
    print(generated_text)  # log the raw model output for debugging
    return extract_code_blocks(generated_text)
# Route to serve the HTML template
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    with open("static/infographic_gen.html") as f:
        return HTMLResponse(f.read())

# Route to handle infographic generation
@app.post("/generate")
async def generate_infographic(request: InfographicRequest):
    description = request.description
    prompt = PROMPT_TEMPLATE.format(description=description)
    code_blocks = await generate_infographic_html(prompt)
    if code_blocks:
        # The first code block is expected to be the generated infographic HTML
        return JSONResponse(content={"html": code_blocks[0]})
    else:
        return JSONResponse(content={"error": "No generation"}, status_code=500)
# Previous implementation using the Hugging Face InferenceClient with streaming,
# kept for reference:
#
# try:
#     messages = [{"role": "user", "content": prompt}]
#     stream = client.chat.completions.create(
#         model="Qwen/Qwen2.5-Coder-32B-Instruct",
#         messages=messages,
#         temperature=0.4,
#         max_tokens=6000,
#         top_p=0.7,
#         stream=True,
#     )
#     generated_text = ""
#     for chunk in stream:
#         generated_text += chunk.choices[0].delta.content
#     print(generated_text)
#     code_blocks = extract_code_blocks(generated_text)
#     if code_blocks:
#         return JSONResponse(content={"html": code_blocks[0]})
#     else:
#         return JSONResponse(content={"error": "No generation"}, status_code=500)
# except Exception as e:
#     return JSONResponse(content={"error": str(e)}, status_code=500)