Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
3408aae
1
Parent(s):
5d6ca81
rename
Browse files
main.py
CHANGED
|
@@ -9,9 +9,13 @@ from httpx import AsyncClient
|
|
| 9 |
from huggingface_hub import DatasetCard
|
| 10 |
from pydantic import BaseModel
|
| 11 |
from starlette.responses import RedirectResponse
|
| 12 |
-
from starlette.status import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
from
|
| 15 |
|
| 16 |
# Set up logging
|
| 17 |
logging.basicConfig(
|
|
@@ -97,6 +101,14 @@ class DatasetCardNotFoundError(HTTPException):
|
|
| 97 |
)
|
| 98 |
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
@app.get("/similar", response_model=QueryResponse)
|
| 101 |
@cache(ttl="1h")
|
| 102 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
@@ -115,7 +127,9 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 115 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
| 116 |
logger.info(f"Dataset {dataset_id} added to collection")
|
| 117 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
| 118 |
-
|
|
|
|
|
|
|
| 119 |
raise
|
| 120 |
except Exception as e:
|
| 121 |
logger.error(
|
|
@@ -157,6 +171,44 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 157 |
) from e
|
| 158 |
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
if __name__ == "__main__":
|
| 161 |
import uvicorn
|
| 162 |
|
|
|
|
| 9 |
from huggingface_hub import DatasetCard
|
| 10 |
from pydantic import BaseModel
|
| 11 |
from starlette.responses import RedirectResponse
|
| 12 |
+
from starlette.status import (
|
| 13 |
+
HTTP_404_NOT_FOUND,
|
| 14 |
+
HTTP_500_INTERNAL_SERVER_ERROR,
|
| 15 |
+
HTTP_403_FORBIDDEN,
|
| 16 |
+
)
|
| 17 |
|
| 18 |
+
from load_card_data import get_embedding_function, get_save_path, refresh_data
|
| 19 |
|
| 20 |
# Set up logging
|
| 21 |
logging.basicConfig(
|
|
|
|
| 101 |
)
|
| 102 |
|
| 103 |
|
| 104 |
+
class DatasetNotForAllAudiencesError(HTTPException):
|
| 105 |
+
def __init__(self, dataset_id: str):
|
| 106 |
+
super().__init__(
|
| 107 |
+
status_code=HTTP_403_FORBIDDEN,
|
| 108 |
+
detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
@app.get("/similar", response_model=QueryResponse)
|
| 113 |
@cache(ttl="1h")
|
| 114 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
|
|
| 127 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
| 128 |
logger.info(f"Dataset {dataset_id} added to collection")
|
| 129 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
| 130 |
+
if result.get("not-for-all-audiences"):
|
| 131 |
+
raise DatasetNotForAllAudiencesError(dataset_id)
|
| 132 |
+
except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
|
| 133 |
raise
|
| 134 |
except Exception as e:
|
| 135 |
logger.error(
|
|
|
|
| 171 |
) from e
|
| 172 |
|
| 173 |
|
| 174 |
+
@app.post("/similar_by_text", response_model=QueryResponse)
|
| 175 |
+
@cache(ttl="1h")
|
| 176 |
+
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
|
| 177 |
+
try:
|
| 178 |
+
logger.info(f"Querying datasets by text: {query}")
|
| 179 |
+
collection = client.get_collection(
|
| 180 |
+
name="dataset_cards", embedding_function=get_embedding_function()
|
| 181 |
+
)
|
| 182 |
+
print(query)
|
| 183 |
+
query_result = collection.query(
|
| 184 |
+
query_texts=query, n_results=n, include=["distances"]
|
| 185 |
+
)
|
| 186 |
+
print(query_result)
|
| 187 |
+
|
| 188 |
+
if not query_result["ids"]:
|
| 189 |
+
logger.info(f"No similar datasets found for query: {query}")
|
| 190 |
+
raise HTTPException(
|
| 191 |
+
status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Prepare the response
|
| 195 |
+
results = [
|
| 196 |
+
QueryResult(dataset_id=str(id), similarity=float(1 - distance))
|
| 197 |
+
for id, distance in zip(
|
| 198 |
+
query_result["ids"][0], query_result["distances"][0]
|
| 199 |
+
)
|
| 200 |
+
]
|
| 201 |
+
logger.info(f"Found {len(results)} similar datasets for query: {query}")
|
| 202 |
+
return QueryResponse(results=results)
|
| 203 |
+
|
| 204 |
+
except Exception as e:
|
| 205 |
+
logger.error(f"Error querying datasets by text {query}: {str(e)}")
|
| 206 |
+
raise HTTPException(
|
| 207 |
+
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
| 208 |
+
detail="An unexpected error occurred.",
|
| 209 |
+
) from e
|
| 210 |
+
|
| 211 |
+
|
| 212 |
if __name__ == "__main__":
|
| 213 |
import uvicorn
|
| 214 |
|