# Source: Hugging Face Space "summariser", file app.py
# Last commit: "Update app.py" by kyserS09 (3d097bd, verified)
import gradio as gr
import os
import requests
import torch
from transformers import (
LEDTokenizer, LEDForConditionalGeneration,
BartTokenizer, BartForConditionalGeneration,
PegasusTokenizer, PegasusForConditionalGeneration,
AutoTokenizer, AutoModelForSeq2SeqLM
)
# OpenAI credential, taken from the process environment.
# The value is None when the OPENAI_API_KEY variable is not set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# Summarization checkpoints, listed in priority order (earlier entries first).
# Each entry records the checkpoint name together with the tokenizer and model
# classes used to load it.
MODELS = [
    {"name": checkpoint, "tokenizer_class": tokenizer_cls, "model_class": model_cls}
    for checkpoint, tokenizer_cls, model_cls in (
        ("allenai/led-large-16384", LEDTokenizer, LEDForConditionalGeneration),
        ("facebook/bart-large-cnn", BartTokenizer, BartForConditionalGeneration),
        ("Falconsai/text_summarization", AutoTokenizer, AutoModelForSeq2SeqLM),
        ("google/pegasus-xsum", PegasusTokenizer, PegasusForConditionalGeneration),
    )
]
# Load every configured checkpoint one after another when the module is
# imported. A checkpoint that cannot be loaded is reported and skipped, so
# loaded_models only ever contains usable tokenizer/model pairs.
loaded_models = []
for spec in MODELS:
    checkpoint = spec["name"]
    try:
        entry = {
            "name": checkpoint,
            "tokenizer": spec["tokenizer_class"].from_pretrained(checkpoint),
            "model": spec["model_class"].from_pretrained(checkpoint),
        }
    except Exception as err:
        print(f"Failed to load {checkpoint}: {err}")
        continue
    loaded_models.append(entry)
    print(f"Loaded model: {checkpoint}")
def summarize_with_transformers(text):
    """
    Summarize ``text`` with the first locally loaded model that succeeds.

    Models are tried in the order they appear in ``loaded_models``. Any
    exception raised while tokenizing, generating or decoding is printed and
    the next model is tried.

    Args:
        text: The article text to summarize.

    Returns:
        The decoded summary from the first model that succeeds, or ``None``
        when every model failed (or no model is loaded).
    """
    # Global upper bound on input tokens, applied on top of each tokenizer's
    # own limit.
    max_input_tokens = 16384

    for model_data in loaded_models:
        try:
            tokenizer = model_data["tokenizer"]
            model = model_data["model"]

            # Truncate to the limit this tokenizer reports instead of one
            # value for all models; otherwise an input longer than a model
            # supports is passed to generate() without being shortened.
            token_limit = min(
                getattr(tokenizer, "model_max_length", None) or max_input_tokens,
                max_input_tokens,
            )
            inputs = tokenizer(
                [text],
                max_length=token_limit,
                return_tensors="pt",
                truncation=True,
            )

            # Forward the attention mask explicitly so generate() does not
            # have to infer it from the input ids.
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                num_beams=4,
                max_length=512,
                min_length=100,
                early_stopping=True,
            )

            # The first model that gets this far wins.
            return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        except Exception as e:
            # Best effort: report the failure and fall through to the next model.
            print(f"Error using {model_data['name']}: {e}")

    return None  # Indicate failure
def summarize_with_chatgpt(text):
    """
    Summarize ``text`` through the OpenAI chat completions HTTP API.

    Used as the fallback when no local model produced a summary. Failures are
    reported as strings starting with "Error:" rather than raised, so the
    caller can display the result directly.

    Args:
        text: The article text to summarize.

    Returns:
        The summary text from the API response, or an "Error: ..." message
        when the API key is missing, the request fails, the API answers with
        a non-200 status, or the response body has an unexpected shape.
    """
    if not OPENAI_API_KEY:
        return "Error: No OpenAI API key provided."

    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": f"Summarize this article: {text}"}],
        "max_tokens": 512,
    }

    try:
        # Without a timeout a stalled connection would block this call forever.
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )
    except requests.RequestException as exc:
        # Connection errors, timeouts, etc. follow the same error-string
        # convention as the other failure paths instead of propagating.
        return f"Error: Failed to reach ChatGPT ({exc})"

    if response.status_code != 200:
        return f"Error: Failed to summarize with ChatGPT (status {response.status_code})"

    try:
        return response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # A 200 response whose body is not JSON or lacks the expected fields.
        return "Error: Unexpected response format from ChatGPT."
def summarize_text(text):
    """
    Summarize ``text``, preferring local Transformer models over ChatGPT.

    Args:
        text: The article text to summarize.

    Returns:
        The summary from the first local model that succeeds; otherwise the
        result of the ChatGPT fallback (which may itself be an "Error: ..."
        message). Empty or whitespace-only input yields an "Error: ..."
        message without invoking any model.
    """
    # Nothing to summarize: skip the models and the remote API entirely
    # instead of asking them to summarize an empty prompt.
    if not text or not text.strip():
        return "Error: No text provided to summarize."

    summary = summarize_with_transformers(text)
    if summary:
        return summary  # Successful summary from a Transformer model

    print("All Transformer models failed. Falling back to ChatGPT...")
    return summarize_with_chatgpt(text)  # Use ChatGPT as last resort
# Web UI: one text input wired to summarize_text, one text output.
iface = gr.Interface(
    summarize_text,
    "text",
    "text",
    title="Multi-Model Summarizer with Fallback",
    description="Tries multiple models for summarization, falling back to ChatGPT if needed.",
)

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()