# Import required libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
#import redis
from transformers import pipeline
from dotenv import load_dotenv
import os

# Load environment variables from .env file
load_dotenv('../.env')

# Access environment variables
redis_port = os.getenv("REDIS_PORT")
fastapi_port = os.getenv("FASTAPI_PORT")
print("Redis port:", redis_port)
print("FastAPI port:", fastapi_port)

# Initialize FastAPI app and Redis client
app = FastAPI()
#redis_client = redis.Redis(host='redis', port=6379)

# Load a zero-shot classification pipeline backed by BART fine-tuned on MNLI
model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


def score_text_with_labels(model, text: list, labels: list, multi: bool = True):
    """Score each text against every candidate label.

    Returns one list of per-label scores per input text.
    """
    results = [result['scores'] for result in model(text, labels, multi_label=multi)]
    return results


def smooth_sequence(tweets_scores, window_size):
    """Reorder tweets by the sum of their label scores (ascending).

    Note: window_size is currently unused.
    """
    # Sum the scores across all labels for each tweet
    tweet_sum_scores = [(sum(scores), index) for index, scores in enumerate(tweets_scores)]
    # Sort tweets by summed score, then by original index to stabilize ties
    sorted_tweets = sorted(tweet_sum_scores, key=lambda x: (x[0], x[1]))
    # Extract the original indices of the tweets after sorting
    sorted_indices = [index for _, index in sorted_tweets]
    # Build the reordered sequence from the sorted indices
    smoothed_sequence = [tweets_scores[index] for index in sorted_indices]
    return smoothed_sequence


def rerank_on_label(label: str):
    # Placeholder implementation
    return 200


# Define Pydantic models
class Item(BaseModel):
    id: Optional[str] = None  # used by /rerank/; optional so /classify/ can omit it
    #title: str = None
    text: str
    #type: str
    #engagements: dict


class RerankedItems(BaseModel):
    ranked_ids: List[str]
    new_items: List[dict]


# Define a health check endpoint
@app.get("/")
async def health_check():
    return {"status": "ok"}


# Define FastAPI routes and logic
@app.post("/rerank/")
async def rerank_items(items: List[Item]) -> RerankedItems:
    reranked_ids = []

    # Process each item
    for item in items:
        # Classify the item using the zero-shot BART model
        labels = classify_item(item.text)

        # Save the item with labels in Redis
        #redis_client.hset(item.id, mapping={"title": item.title, "text": item.text, "labels": ",".join(labels)})

        # Add the item id to the reranked list
        reranked_ids.append(item.id)

    # Sort the items based on model confidence
    #reranked_ids.sort(key=lambda x: redis_client.zscore("classified_items", x), reverse=True)

    # Return the reranked items; "new_items" is unused for now
    return {"ranked_ids": reranked_ids, "new_items": []}


# Define an endpoint to classify items and return their label scores
@app.post("/classify/")
async def classify_and_save(items: List[Item]) -> List[List[float]]:
    #labels = ["factful", "civic", "constructive", "politics", "health", "news"]
    labels = [
        "something else",
        "news feed, news articles, breaking news",
        "politics and politicians",
        "healthcare and health",
    ]
    texts = [item.text for item in items]
    scores = score_text_with_labels(model=model, text=texts, labels=labels, multi=True)
    return scores
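# Example request to /classify/ (a sketch; assumes the server is reachable on
# localhost at FASTAPI_PORT, here 8000, and that two items are posted):
#
#   curl -X POST http://localhost:8000/classify/ \
#        -H "Content-Type: application/json" \
#        -d '[{"text": "Senate passes the new budget bill"},
#             {"text": "10 tips for a better night of sleep"}]'
#
# The response is one list of scores per item, ordered like the candidate
# labels above, e.g. [[0.02, 0.91, 0.85, 0.04], [0.30, 0.05, 0.03, 0.88]]
# (values are illustrative only).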
# Function to classify item text using the zero-shot BART pipeline
def classify_item(text: str) -> List[str]:
    # Example label set (an assumption; adjust to your use case)
    candidate_labels = ["politics", "health", "news", "something else"]
    # The pipeline returns labels sorted by descending score; keep the top one
    result = model(text, candidate_labels, multi_label=False)
    return [result["labels"][0]]
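# Minimal entry point sketch: serve the app with uvicorn, assuming uvicorn is
# installed and FASTAPI_PORT is set in the .env file (falls back to 8000).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(fastapi_port or 8000))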