Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
d849643
1
Parent(s):
a0c28a9
add trending models and datasets fetching endpoints with summaries
Browse files
main.py
CHANGED
@@ -14,12 +14,15 @@ from huggingface_hub import HfApi
|
|
14 |
from transformers import AutoTokenizer
|
15 |
import torch
|
16 |
import dateutil.parser
|
|
|
|
|
17 |
|
18 |
# Configuration constants
|
19 |
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
|
20 |
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
|
21 |
BATCH_SIZE = 2000
|
22 |
CACHE_TTL = "60"
|
|
|
23 |
|
24 |
if torch.cuda.is_available():
|
25 |
DEVICE = "cuda"
|
@@ -463,6 +466,156 @@ def process_search_results(results, id_field, k, sort_by, exclude_id=None):
|
|
463 |
return query_results
|
464 |
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
if __name__ == "__main__":
|
467 |
import uvicorn
|
468 |
|
|
|
14 |
from transformers import AutoTokenizer
|
15 |
import torch
|
16 |
import dateutil.parser
|
17 |
+
import httpx
|
18 |
+
from datetime import datetime
|
19 |
|
20 |
# Configuration constants
|
21 |
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
|
22 |
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
|
23 |
BATCH_SIZE = 2000
|
24 |
CACHE_TTL = "60"
|
25 |
+
TRENDING_CACHE_TTL = "900" # 15 minutes cache for trending data
|
26 |
|
27 |
if torch.cuda.is_available():
|
28 |
DEVICE = "cuda"
|
|
|
466 |
return query_results
|
467 |
|
468 |
|
469 |
+
async def fetch_trending_models():
|
470 |
+
"""Fetch trending models from HuggingFace API"""
|
471 |
+
async with httpx.AsyncClient() as client:
|
472 |
+
response = await client.get("https://huggingface.co/api/models")
|
473 |
+
response.raise_for_status()
|
474 |
+
return response.json()
|
475 |
+
|
476 |
+
|
477 |
+
@cache(ttl=TRENDING_CACHE_TTL)
|
478 |
+
async def get_trending_models_with_summaries(
|
479 |
+
limit: int = 10,
|
480 |
+
min_likes: int = 0,
|
481 |
+
min_downloads: int = 0,
|
482 |
+
) -> List[ModelQueryResult]:
|
483 |
+
"""Fetch trending models and combine with summaries from database"""
|
484 |
+
try:
|
485 |
+
# Fetch trending models
|
486 |
+
trending_models = await fetch_trending_models()
|
487 |
+
|
488 |
+
# Filter by minimum likes/downloads
|
489 |
+
trending_models = [
|
490 |
+
model
|
491 |
+
for model in trending_models
|
492 |
+
if model.get("likes", 0) >= min_likes
|
493 |
+
and model.get("downloads", 0) >= min_downloads
|
494 |
+
]
|
495 |
+
|
496 |
+
# Sort by trending score and limit
|
497 |
+
trending_models = sorted(
|
498 |
+
trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
|
499 |
+
)[:limit]
|
500 |
+
|
501 |
+
# Get model IDs
|
502 |
+
model_ids = [model["modelId"] for model in trending_models]
|
503 |
+
|
504 |
+
# Fetch summaries from ChromaDB
|
505 |
+
collection = client.get_collection("model_cards")
|
506 |
+
summaries = collection.get(ids=model_ids, include=["documents"])
|
507 |
+
|
508 |
+
# Create mapping of model_id to summary
|
509 |
+
id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
|
510 |
+
|
511 |
+
# Combine data
|
512 |
+
results = []
|
513 |
+
for model in trending_models:
|
514 |
+
if model["modelId"] in id_to_summary:
|
515 |
+
result = ModelQueryResult(
|
516 |
+
model_id=model["modelId"],
|
517 |
+
similarity=1.0, # Not applicable for trending
|
518 |
+
summary=id_to_summary[model["modelId"]],
|
519 |
+
likes=model.get("likes", 0),
|
520 |
+
downloads=model.get("downloads", 0),
|
521 |
+
)
|
522 |
+
results.append(result)
|
523 |
+
|
524 |
+
return results
|
525 |
+
|
526 |
+
except Exception as e:
|
527 |
+
logger.error(f"Error fetching trending models: {str(e)}")
|
528 |
+
raise HTTPException(status_code=500, detail="Failed to fetch trending models")
|
529 |
+
|
530 |
+
|
531 |
+
@app.get("/trending/models", response_model=ModelQueryResponse)
|
532 |
+
async def get_trending_models(
|
533 |
+
limit: int = Query(default=10, ge=1, le=100),
|
534 |
+
min_likes: int = Query(default=0, ge=0),
|
535 |
+
min_downloads: int = Query(default=0, ge=0),
|
536 |
+
):
|
537 |
+
"""Get trending models with their summaries"""
|
538 |
+
results = await get_trending_models_with_summaries(
|
539 |
+
limit=limit, min_likes=min_likes, min_downloads=min_downloads
|
540 |
+
)
|
541 |
+
return ModelQueryResponse(results=results)
|
542 |
+
|
543 |
+
|
544 |
+
async def fetch_trending_datasets():
|
545 |
+
"""Fetch trending datasets from HuggingFace API"""
|
546 |
+
async with httpx.AsyncClient() as client:
|
547 |
+
response = await client.get("https://huggingface.co/api/datasets")
|
548 |
+
response.raise_for_status()
|
549 |
+
return response.json()
|
550 |
+
|
551 |
+
|
552 |
+
@cache(ttl=TRENDING_CACHE_TTL)
|
553 |
+
async def get_trending_datasets_with_summaries(
|
554 |
+
limit: int = 10,
|
555 |
+
min_likes: int = 0,
|
556 |
+
min_downloads: int = 0,
|
557 |
+
) -> List[QueryResult]:
|
558 |
+
"""Fetch trending datasets and combine with summaries from database"""
|
559 |
+
try:
|
560 |
+
# Fetch trending datasets
|
561 |
+
trending_datasets = await fetch_trending_datasets()
|
562 |
+
|
563 |
+
# Filter by minimum likes/downloads
|
564 |
+
trending_datasets = [
|
565 |
+
dataset
|
566 |
+
for dataset in trending_datasets
|
567 |
+
if dataset.get("likes", 0) >= min_likes
|
568 |
+
and dataset.get("downloads", 0) >= min_downloads
|
569 |
+
]
|
570 |
+
|
571 |
+
# Sort by trending score and limit
|
572 |
+
trending_datasets = sorted(
|
573 |
+
trending_datasets, key=lambda x: x.get("trendingScore", 0), reverse=True
|
574 |
+
)[:limit]
|
575 |
+
|
576 |
+
# Get dataset IDs
|
577 |
+
dataset_ids = [dataset["id"] for dataset in trending_datasets]
|
578 |
+
|
579 |
+
# Fetch summaries from ChromaDB
|
580 |
+
collection = client.get_collection("dataset_cards")
|
581 |
+
summaries = collection.get(ids=dataset_ids, include=["documents"])
|
582 |
+
|
583 |
+
# Create mapping of dataset_id to summary
|
584 |
+
id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
|
585 |
+
|
586 |
+
# Combine data
|
587 |
+
results = []
|
588 |
+
for dataset in trending_datasets:
|
589 |
+
if dataset["id"] in id_to_summary:
|
590 |
+
result = QueryResult(
|
591 |
+
dataset_id=dataset["id"],
|
592 |
+
similarity=1.0, # Not applicable for trending
|
593 |
+
summary=id_to_summary[dataset["id"]],
|
594 |
+
likes=dataset.get("likes", 0),
|
595 |
+
downloads=dataset.get("downloads", 0),
|
596 |
+
)
|
597 |
+
results.append(result)
|
598 |
+
|
599 |
+
return results
|
600 |
+
|
601 |
+
except Exception as e:
|
602 |
+
logger.error(f"Error fetching trending datasets: {str(e)}")
|
603 |
+
raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")
|
604 |
+
|
605 |
+
|
606 |
+
@app.get("/trending/datasets", response_model=QueryResponse)
|
607 |
+
async def get_trending_datasets(
|
608 |
+
limit: int = Query(default=10, ge=1, le=100),
|
609 |
+
min_likes: int = Query(default=0, ge=0),
|
610 |
+
min_downloads: int = Query(default=0, ge=0),
|
611 |
+
):
|
612 |
+
"""Get trending datasets with their summaries"""
|
613 |
+
results = await get_trending_datasets_with_summaries(
|
614 |
+
limit=limit, min_likes=min_likes, min_downloads=min_downloads
|
615 |
+
)
|
616 |
+
return QueryResponse(results=results)
|
617 |
+
|
618 |
+
|
619 |
if __name__ == "__main__":
|
620 |
import uvicorn
|
621 |
|