File size: 4,540 Bytes
e145e85
 
 
 
8593270
e145e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8593270
e145e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8593270
e145e85
 
 
 
 
8593270
e145e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Import required libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
#import redis
from transformers import BartForSequenceClassification, BartTokenizer, AutoTokenizer, AutoConfig, pipeline
from dotenv import load_dotenv
import os

# Load environment variables from .env file.
# NOTE(review): the path is relative to the process working directory, not
# to this file — the app must be launched from the expected cwd.
load_dotenv('../.env')

# Access environment variables (each is None if unset in env/.env).
redis_port = os.getenv("REDIS_PORT")
fastapi_port = os.getenv("FASTAPI_PORT")


print("Redis port:", redis_port)
print("FastAPI port:", fastapi_port)


# Initialize FastAPI app and Redis client (Redis currently disabled).
app = FastAPI()
#redis_client = redis.Redis(host='redis', port=6379)

# Load BART model and tokenizer
#model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
#tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")

# Zero-shot classification pipeline; downloads model weights on first run.
model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

def score_text_with_labels(model, text: list, labels: list, multi: bool=True):
    """Score each input text against the candidate labels.

    Args:
        model: a zero-shot-classification callable (e.g. a Hugging Face
            pipeline) accepting ``(texts, labels, multi_label=...)`` and
            returning one dict per text, each with a ``'scores'`` key.
        text: list of strings to classify (a bare string would make the
            pipeline return a single dict and break the comprehension).
        labels: candidate label strings.
        multi: forwarded as ``multi_label``; True scores each label
            independently instead of softmax-normalizing across labels.

    Returns:
        A list with one per-label score list per input text.
    """
    return [result['scores'] for result in model(text, labels, multi_label=multi)]

def smooth_sequence(tweets_scores, window_size):
    """Reorder per-tweet score vectors by ascending total score.

    Args:
        tweets_scores: list of score lists, one per tweet.
        window_size: UNUSED — kept only for backward compatibility with
            existing callers. TODO: implement windowed smoothing or drop it.

    Returns:
        A new list of the same score vectors sorted by their sum; ties keep
        their original relative order.
    """
    # Python's sort is stable, so sorting by sum alone reproduces the
    # original (sum, original_index) tie-breaking exactly, without the
    # intermediate index bookkeeping.
    return sorted(tweets_scores, key=sum)

def rerank_on_label(label: str):
    """Placeholder for label-driven reranking.

    Currently ignores *label* and always yields the constant 200.
    TODO: replace with real reranking logic.
    """
    placeholder_score = 200
    return placeholder_score


# Define Pydantic models
class Item(BaseModel):
    """Request payload for a single content item; only raw text for now."""
    #id: str
    #title: str = None
    # NOTE(review): rerank_items() reads `item.id`, but no `id` field is
    # declared here — confirm the intended schema before using that route.
    text: str
    #type: str
    #engagements: dict

class RerankedItems(BaseModel):
    """Response payload for /rerank/."""
    ranked_ids: List[str]  # item ids in their new ranked order
    new_items: List[dict]  # items to inject into the feed; currently always empty

# Define a health check endpoint
@app.get("/")
async def health_check():
    """Liveness probe: always reports the service as up."""
    status_report = {"status": "ok"}
    return status_report

# Define FastAPI routes and logic
@app.post("/rerank/")
async def rerank_items(items: List[Item]) -> RerankedItems:
    """Rerank submitted items and return their ids in the new order.

    Classification currently runs for its side effects only; the original
    submission order is preserved because the confidence-based sort (and
    Redis persistence) is still commented out.
    """
    reranked_ids = []

    # Process each item
    for position, item in enumerate(items):
        # Classify the item using Hugging Face BART model (result unused
        # until the Redis persistence below is enabled).
        labels = classify_item(item.text)

        # Save the item with labels in Redis
        #redis_client.hset(item.id, mapping={"title": item.title, "text": item.text, "labels": ",".join(labels)})

        # BUG FIX: Item declares no `id` field, so `item.id` raised
        # AttributeError for every request. Fall back to the item's list
        # position until the schema grows an id.
        reranked_ids.append(str(getattr(item, "id", position)))

    # Sort the items based on model confidence
    #reranked_ids.sort(key=lambda x: redis_client.zscore("classified_items", x), reverse=True)

    # Return the reranked items
    return {"ranked_ids": reranked_ids, "new_items": []}  # Ignore "new_items" for now

# Define an endpoint to classify items and save them in Redis
@app.post("/classify/")
async def classify_and_save(items: List[Item]) -> List[List[float]]:
    """Zero-shot classify each item's text against fixed candidate topics.

    Returns one score list per item, aligned with ``candidate_labels``.
    NOTE: despite the name, nothing is saved — the Redis client is
    commented out at module level.
    """
    # Candidate topics; phrasing is tuned for the NLI-based zero-shot model.
    candidate_labels = ["something else", "news feed, news articles, breaking news", "politics and polititians", "healthcare and health"]
    texts = [item.text for item in items]

    # BUG FIX: the return annotation was `-> None` while the function
    # actually returned the scores, which misleads FastAPI's response-model
    # inference; the `labels` name was also reused for both the candidate
    # labels and the resulting scores.
    scores = score_text_with_labels(model=model, text=texts, labels=candidate_labels, multi=True)
    return scores

# Function to classify item text using Hugging Face BART model
def classify_item(text: str) -> List[str]:
    """Classify *text* with the zero-shot pipeline; return the top label.

    BUG FIX: the previous implementation referenced ``tokenizer``, whose
    initialization is commented out at module level (NameError at runtime),
    and called the pipeline object as a raw model via ``**inputs`` with a
    ``.logits`` access — none of which matches the pipeline API. The
    pipeline is now invoked directly.
    """
    candidate_labels = ["something else", "news feed, news articles, breaking news", "politics and polititians", "healthcare and health"]
    # The zero-shot pipeline returns a dict whose 'labels' list is sorted
    # by descending score, so the first entry is the best label.
    result = model(text, candidate_labels)
    return [result["labels"][0]]