# (removed page-scrape artifact: Hugging Face Spaces status lines "Spaces / Sleeping")
# Import required libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
#import redis
from transformers import BartForSequenceClassification, BartTokenizer, AutoTokenizer, AutoConfig, pipeline
from dotenv import load_dotenv
import os
# Load environment variables from the parent directory's .env file.
# NOTE(review): the path is relative to the process working directory, not to
# this file — confirm the app is always launched one level below the .env.
load_dotenv('../.env')
# Access environment variables (os.getenv returns None when a key is unset).
redis_port = os.getenv("REDIS_PORT")
fastapi_port = os.getenv("FASTAPI_PORT")
print("Redis port:", redis_port)
print("FastAPI port:", fastapi_port)
# Initialize FastAPI app and Redis client (Redis is currently disabled).
app = FastAPI()
#redis_client = redis.Redis(host='redis', port=6379)
# Load BART model and tokenizer
#model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
#tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")
# Zero-shot classification pipeline; downloads/loads the model at import time,
# so importing this module is slow and requires network/model cache.
model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
def score_text_with_labels(model, text: list, labels: list, multi: bool=True):
    """Run zero-shot classification over a batch of texts.

    Args:
        model: a zero-shot-classification callable (e.g. a transformers
            pipeline) invoked as ``model(text, labels, multi_label=multi)``
            and returning one result dict (with a ``'scores'`` list) per text.
        text: list of input strings to classify.
        labels: candidate label strings.
        multi: when True, labels are scored independently (multi-label).

    Returns:
        A list with one per-label score list per input text, in input order.
    """
    outputs = model(text, labels, multi_label=multi)
    scores_per_text = []
    for output in outputs:
        scores_per_text.append(output['scores'])
    return scores_per_text
def smooth_sequence(tweets_scores, window_size):
    """Reorder score rows by ascending total score.

    Each element of *tweets_scores* is a sequence of per-label scores; rows
    are ordered by their score sum, with the original position as the
    tie-breaker, so rows with equal sums keep their relative order.

    NOTE(review): *window_size* is currently unused — presumably intended
    for a sliding-window smoothing pass that was never implemented; confirm.

    Args:
        tweets_scores: list of per-tweet score sequences.
        window_size: unused.

    Returns:
        A new list containing the same rows, reordered by total score.
    """
    # Pair each row's total with its index; tuples sort by (total, index),
    # which reproduces the sum-then-index ordering.
    keyed = [(sum(row), pos) for pos, row in enumerate(tweets_scores)]
    keyed.sort()
    return [tweets_scores[pos] for _, pos in keyed]
def rerank_on_label(label: str):
    """Stub reranker: always reports success.

    NOTE(review): placeholder — returns 200 regardless of *label*.
    """
    status_ok = 200
    return status_ok
# Define Pydantic models
class Item(BaseModel):
    """Request payload for a single feed item.

    Only the raw text is currently used; the remaining fields are disabled
    (commented out) while the classifier is text-only.
    """
    #id: str
    #title: str = None
    # Raw text of the item to classify.
    text: str
    #type: str
    #engagements: dict
class RerankedItems(BaseModel):
    """Response payload for reranking."""
    # Item ids in their new ranked order.
    ranked_ids: List[str]
    # Newly injected items; rerank_items currently always returns this empty.
    new_items: List[dict]
# Define a health check endpoint
# NOTE(review): no route decorator is attached here — presumably meant to be
# registered via @app.get(...); confirm how routes are wired up.
async def health_check():
    """Report service liveness."""
    return {"status": "ok"}
# Define FastAPI routes and logic
# NOTE(review): no route decorator attached — confirm how this is registered.
async def rerank_items(items: List[Item]) -> RerankedItems:
    """Classify each item and return their ids in (currently unranked) order.

    NOTE(review): Item declares only ``text`` — its ``id`` field is commented
    out, so ``item.id`` below raises AttributeError at runtime. Either
    re-enable Item.id or stop appending it here.

    Args:
        items: batch of items to classify.

    Returns:
        Dict matching RerankedItems: item ids in input order (score-based
        sorting is disabled along with Redis) and an empty new_items list.
    """
    reranked_ids = []
    # Process each item
    for item in items:
        # Classify the item using Hugging Face BART model.
        # (Result is currently unused — the Redis save and the
        # confidence-based sort below are both disabled.)
        labels = classify_item(item.text)
        # Save the item with labels in Redis
        #redis_client.hset(item.id, mapping={"title": item.title, "text": item.text, "labels": ",".join(labels)})
        # Add the item id to the reranked list
        reranked_ids.append(item.id)
    # Sort the items based on model confidence
    #reranked_ids.sort(key=lambda x: redis_client.zscore("classified_items", x), reverse=True)
    # Return the reranked items
    return {"ranked_ids": reranked_ids, "new_items": []}  # Ignore "new_items" for now
# Define an endpoint to classify items and save them in Redis
# NOTE(review): no route decorator attached, and the Redis save is disabled —
# confirm how this is registered and whether persistence should be restored.
async def classify_and_save(items: List[Item]) -> List[List[float]]:
    """Zero-shot score each item's text against the candidate label set.

    Fixes: the return annotation previously said ``-> None`` although the
    function returns the score lists, and the candidate-label list was
    overwritten by the scores; the result now lives in its own variable.

    Args:
        items: batch of items whose ``text`` fields are classified.

    Returns:
        One list of per-label scores per input item, in input order; score
        order follows the pipeline's 'scores' output for each text.
    """
    print("new 1")
    # Candidate labels for zero-shot classification (phrasing tuned for MNLI).
    candidate_labels = ["something else", "news feed, news articles, breaking news", "politics and polititians", "healthcare and health"]
    texts = [item.text for item in items]
    print(texts)
    # Uses the module-level zero-shot pipeline; multi=True scores each label
    # independently rather than as a softmax over the set.
    scores = score_text_with_labels(model=model, text=texts, labels=candidate_labels, multi=True)
    print(scores)
    return scores
# Function to classify item text using Hugging Face BART model
def classify_item(text: str) -> List[str]:
    """Return the single top predicted label for *text*.

    NOTE(review): broken as written — ``tokenizer`` is never defined (its
    creation at module top is commented out), so this raises NameError; and
    the module-level ``model`` is a zero-shot pipeline, which cannot be
    called as ``model(**inputs)``. Restore the raw model/tokenizer or
    rewrite this against the pipeline before use.
    """
    # Tokenize input text
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    print(1)
    # Perform inference
    outputs = model(**inputs)
    print(2)
    # Get predicted label
    # NOTE(review): decoding the argmax over the flattened logits yields a
    # vocabulary token, not an NLI class label — confirm intended logic.
    predicted_label = tokenizer.decode(outputs.logits.argmax())
    return [predicted_label]