Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, Request | |
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from rag_system import build_rag_chain, ask_question | |
from vector_store import get_embeddings, load_vector_store | |
from llm_loader import load_llama_model | |
import uuid | |
import os | |
import shutil | |
from urllib.parse import urljoin, quote | |
from fastapi.responses import StreamingResponse | |
import json | |
import time | |
app = FastAPI()

# Static file serving: create the directory tree first, then expose it
# under the /static URL prefix.
os.makedirs("static/documents", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Module-level objects shared by the request handlers in this file.
embeddings = get_embeddings(device="cpu")
vectorstore = load_vector_store(embeddings, load_path="vector_db")
llm = load_llama_model()
qa_chain = build_rag_chain(llm, vectorstore, language="ko", k=7)

# Public base URL used to build absolute links to served documents.
# It depends on the deployment, so it can be overridden with the
# RAG_BASE_URL environment variable; an unset or empty variable falls back
# to the previously hard-coded address.
BASE_URL = os.environ.get("RAG_BASE_URL") or "http://220.124.155.35:8500"
class Question(BaseModel):
    """Payload model with a single ``question`` string, consumed by ``ask``."""

    # The user's question text; ``ask`` forwards it to ``ask_question``.
    question: str
def _static_copy_is_stale(src, dst):
    """Return True when *dst* is missing or no longer matches *src*."""
    try:
        dst_stat = os.stat(dst)
    except FileNotFoundError:
        return True
    src_stat = os.stat(src)
    # shutil.copy2 carries the source's modification time over to the copy,
    # so a newer source (or a different size) means the copy is outdated.
    return (
        src_stat.st_size != dst_stat.st_size
        or src_stat.st_mtime > dst_stat.st_mtime
    )


def get_document_url(source_path):
    """Publish a dataset document under ``static/documents`` and return its URL.

    The file is looked up by base name anywhere below ``<cwd>/dataset``,
    copied into ``static/documents`` (only when the existing copy is missing
    or outdated) and addressed relative to ``BASE_URL``.

    Args:
        source_path: Path recorded for the document; only its base name is
            used for the lookup. Falsy values and the ``'N/A'`` placeholder
            mean "no source".

    Returns:
        The absolute URL of the published copy with the file name
        percent-encoded, or ``None`` when there is no source or no file with
        that name exists in the dataset tree.
    """
    if not source_path or source_path == 'N/A':
        return None

    filename = os.path.basename(source_path)
    if not filename:
        # A path ending in a separator has no file name to look up.
        return None

    dataset_root = os.path.join(os.getcwd(), "dataset")
    # First directory (in os.walk order) containing a file with this name.
    found_path = next(
        (
            os.path.join(root, filename)
            for root, _dirs, files in os.walk(dataset_root)
            if filename in files
        ),
        None,
    )
    if found_path is None or not os.path.exists(found_path):
        return None

    # Do not rely on the directory created at startup still being there.
    os.makedirs("static/documents", exist_ok=True)
    static_path = f"static/documents/{filename}"
    # Avoid rewriting the same file on every request that cites it.
    if _static_copy_is_stale(found_path, static_path):
        shutil.copy2(found_path, static_path)

    encoded_filename = quote(filename)
    return urljoin(BASE_URL, f"/static/documents/{encoded_filename}")
def create_download_link(url, filename):
    """Return a Markdown line that links *filename* to *url*.

    The line is prefixed with the Korean label for "source". The label was
    previously emitted as mis-decoded bytes and is restored here.

    Args:
        url: Link target, inserted as-is.
        filename: Link text, inserted as-is.

    Returns:
        A string of the form ``출처: [filename](url)``.
    """
    return f'출처: [{filename}]({url})'
def ask(question: Question):
    """Answer one question with the RAG chain and describe the retrieved sources.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm how it is registered on ``app``.

    Args:
        question: Payload whose ``question`` attribute holds the user's text.

    Returns:
        A dict with two keys:
        ``answer`` - the chain output after its last ``"A:"`` marker (or the
        whole output when there is no marker), stripped of surrounding
        whitespace;
        ``sources`` - one dict per retrieved document with the keys
        ``source``, ``content``, ``page``, ``document_url`` and ``filename``.
    """
    result = ask_question(qa_chain, question.question)

    sources = []
    for doc in result["source_documents"]:
        source_path = doc.metadata.get('source', 'N/A')
        # Metadata may hold an explicit None; treat it like the 'N/A'
        # placeholder so os.path.basename is never called on None.
        has_source = source_path is not None and source_path != 'N/A'
        sources.append({
            "source": source_path,
            "content": doc.page_content,
            "page": doc.metadata.get('page', 'N/A'),
            "document_url": get_document_url(source_path) if has_source else None,
            "filename": os.path.basename(source_path) if has_source else None,
        })

    # rpartition yields the whole string as its last element when "A:" is
    # absent, and the text after the last "A:" otherwise.
    answer = result['result'].rpartition("A:")[2].strip()
    return {
        "answer": answer,
        "sources": sources,
    }
def list_models():
    """Return the model listing: a single entry whose id is ``"rag"``.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm how it is registered on ``app``.
    """
    rag_model = {
        "id": "rag",
        "object": "model",
        "owned_by": "local",
    }
    listing = {
        "object": "list",
        "data": [rag_model],
    }
    return JSONResponse(listing)
def _strip_answer_prefix(text):
    """Return the text after the last "A:" marker (or all of it), stripped."""
    return text.rpartition("A:")[2].strip()


def _source_links_markdown(documents):
    """Render the retrieved documents as a Markdown list of download links.

    Each distinct (file name, URL) pair yields one line. Documents without a
    usable source or without a published URL are skipped. Returns an empty
    string when no link could be produced, so callers can append the result
    unconditionally.
    """
    links = {}  # (filename, url) -> None; a dict keeps first-seen order
    checked = set()
    for doc in documents:
        source_path = doc.metadata.get('source', 'N/A')
        if source_path is None or source_path == 'N/A' or source_path in checked:
            continue
        # Resolve each source once even when several chunks share it.
        checked.add(source_path)
        filename = os.path.basename(source_path)
        document_url = get_document_url(source_path)
        if filename and document_url:
            links[(filename, document_url)] = None

    if not links:
        return ""
    # Header is Korean for "reference documents"; each line is "source: link".
    body = "".join(f"출처: [{filename}]({url})\n" for filename, url in links)
    return "\n참고 문서:\n" + body


def _completion_chunk(completion_id, delta, finish_reason=None):
    """Build one ``chat.completion.chunk`` payload for the SSE stream."""
    return {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason,
        }],
    }


async def openai_compatible_chat(request: Request):
    """Answer the last chat message with the RAG chain in OpenAI chat format.

    Reads ``messages`` and ``stream`` from the JSON body, runs the content of
    the last message through ``ask_question`` and appends Markdown download
    links for the retrieved documents to the answer.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm how it is registered on ``app``.
    NOTE(review): ``ask_question`` is called without ``await`` inside this
    coroutine; if it blocks for long, other requests on the same event loop
    wait for it. Confirm whether that serialization is intended before
    moving it to a worker thread.

    Args:
        request: Incoming request whose body is a JSON object.

    Returns:
        A ``JSONResponse`` holding a single ``chat.completion`` object when
        ``stream`` is falsy; otherwise a ``StreamingResponse`` of server-sent
        events that emits the answer one character per chunk, then the source
        links (if any), a final chunk with ``finish_reason`` ``"stop"`` and
        the ``[DONE]`` terminator.
    """
    payload = await request.json()
    messages = payload.get("messages", [])
    # A last message without "content" (or with a null one) is treated as empty.
    user_input = (messages[-1].get("content") or "") if messages else ""
    stream = payload.get("stream", False)

    result = ask_question(qa_chain, user_input)
    answer_main = _strip_answer_prefix(result['result'])
    sources_md = _source_links_markdown(result["source_documents"])
    # One id per completion, shared by every chunk of a streamed response.
    completion_id = f"chatcmpl-{uuid.uuid4()}"

    if not stream:
        return JSONResponse({
            "id": completion_id,
            "object": "chat.completion",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": answer_main + sources_md
                },
                "finish_reason": "stop"
            }],
            "model": "rag",
        })

    def event_stream():
        """Yield the answer, the source links and the end markers as SSE lines."""
        # Stream the answer body first, one character per chunk.
        for char in answer_main:
            chunk = _completion_chunk(completion_id, {"content": char})
            yield f"data: {json.dumps(chunk)}\n\n"
            time.sleep(0.005)  # small pause between chunks, as before

        # The download links are sent in one piece after the answer.
        if sources_md:
            chunk = _completion_chunk(completion_id, {"content": sources_md})
            yield f"data: {json.dumps(chunk)}\n\n"

        done = _completion_chunk(completion_id, {}, finish_reason="stop")
        yield f"data: {json.dumps(done)}\n\n"
        # OpenAI-style streams are terminated by a literal [DONE] event.
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")