# socialboost/app/main.py
# Import required libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
#import redis
from transformers import pipeline
from dotenv import load_dotenv
import os
# Load environment variables from .env file
load_dotenv('../.env')
# Access environment variables
redis_port = os.getenv("REDIS_PORT")
fastapi_port = os.getenv("FASTAPI_PORT")
print("Redis port:", redis_port)
print("FastAPI port:", fastapi_port)
# Initialize FastAPI app (the Redis client is disabled while debugging)
app = FastAPI()
#redis_client = redis.Redis(host='redis', port=6379)
# Load the zero-shot classification pipeline (wraps the BART-MNLI model and its tokenizer)
#model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
#tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")
model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
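# Illustrative output shape (scores are made up): the pipeline returns, per input,
# a dict with "sequence", "labels" (sorted by descending score), and "scores", e.g.
#   model("new vaccine rollout", ["health", "politics"], multi_label=True)
#   -> {"sequence": "new vaccine rollout", "labels": ["health", "politics"], "scores": [0.97, 0.12]}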
def score_text_with_labels(model, text: list, labels: list, multi: bool = True):
    #candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
    # The pipeline yields one result dict per input text; each "scores" list is
    # sorted descending and paired with that result's "labels", not the input order.
    results = [result['scores'] for result in model(text, labels, multi_label=multi)]
    return results
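# Usage sketch for score_text_with_labels (hypothetical inputs and scores):
#   scores = score_text_with_labels(
#       model=model,
#       text=["a post about elections", "a post about cooking"],
#       labels=["politics", "something else"],
#   )
#   -> e.g. [[0.95, 0.08], [0.91, 0.12]]  # one score list per input text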
def smooth_sequence(tweets_scores, window_size):
    # NOTE: window_size is currently unused; the function orders tweets by ascending total score.
    # Calculate the sum of scores across labels for each tweet
    tweet_sum_scores = [(sum(scores), index) for index, scores in enumerate(tweets_scores)]
    # Sort tweets by their sum score, then by original index to stabilize ties
    sorted_tweets = sorted(tweet_sum_scores, key=lambda x: (x[0], x[1]))
    # Extract the original indices of the tweets after sorting
    sorted_indices = [index for _, index in sorted_tweets]
    # Rebuild the sequence of score lists in sorted order
    smoothed_sequence = [tweets_scores[index] for index in sorted_indices]
    return smoothed_sequence
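# Usage sketch (hypothetical scores; window_size is a placeholder for now):
#   smooth_sequence([[0.9, 0.8], [0.1, 0.2], [0.5, 0.4]], window_size=2)
#   -> [[0.1, 0.2], [0.5, 0.4], [0.9, 0.8]]  # ascending total score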
def rerank_on_label(label: str):
    # Placeholder stub; label-specific reranking is not implemented yet.
    return 200
# Define Pydantic models
class Item(BaseModel):
    id: str  # required by /rerank/, which returns the reranked ids
    #title: str = None
    text: str
    #type: str
    #engagements: dict

class RerankedItems(BaseModel):
    ranked_ids: List[str]
    new_items: List[dict]
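# Example request body for the endpoints below (illustrative):
#   [{"id": "42", "text": "Parliament passed the new healthcare bill today."}]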
# Define a health check endpoint
@app.get("/")
async def health_check():
    return {"status": "ok"}
# Define FastAPI routes and logic
@app.post("/rerank/")
async def rerank_items(items: List[Item]) -> RerankedItems:
    reranked_ids = []
    # Process each item
    for item in items:
        # Classify the item using the Hugging Face BART model
        labels = classify_item(item.text)
        # Save the item with labels in Redis (disabled while debugging)
        #redis_client.hset(item.id, mapping={"title": item.title, "text": item.text, "labels": ",".join(labels)})
        # Add the item id to the reranked list
        reranked_ids.append(item.id)
    # Sort the items based on model confidence (disabled while debugging)
    #reranked_ids.sort(key=lambda x: redis_client.zscore("classified_items", x), reverse=True)
    # Return the reranked items; "new_items" is ignored for now
    return {"ranked_ids": reranked_ids, "new_items": []}
# Define an endpoint to classify items (Redis persistence is disabled while debugging)
@app.post("/classify/")
async def classify_and_save(items: List[Item]) -> List[List[float]]:
    #labels = ["factful", "civic", "constructive", "politics", "health", "news"]
    #labels = ["health", "politics", "news", "non-health non-politics non-news"]
    labels = ["something else", "news feed, news articles, breaking news", "politics and politicians", "healthcare and health"]
    texts = [item.text for item in items]
    scores = score_text_with_labels(model=model, text=texts, labels=labels, multi=True)
    # Once Redis is wired back in, each item can be saved with its labels, e.g.:
    #redis_client.hset(item.id, mapping={"text": item.text, "labels": ",".join(labels)})
    return scores
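# Usage sketch (assumes the app is served on localhost:8000; scores are illustrative
# and come back sorted descending, paired with each result's "labels"):
#   curl -X POST http://localhost:8000/classify/ \
#        -H "Content-Type: application/json" \
#        -d '[{"id": "1", "text": "New vaccine guidance was published today."}]'
#   -> [[0.97, 0.55, 0.10, 0.03]]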
# Function to classify a single text with the zero-shot pipeline and return the top label.
# The candidate labels here are a stand-in; align them with the /classify/ label set as needed.
def classify_item(text: str) -> List[str]:
    candidate_labels = ["news", "politics", "health", "something else"]
    # The pipeline tokenizes internally; "labels" comes back sorted by descending score
    result = model(text, candidate_labels, multi_label=True)
    # Return the single best label
    return [result["labels"][0]]
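# Minimal local-run sketch (assumes uvicorn is installed and FASTAPI_PORT is set in ../.env):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(fastapi_port or 8000))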