import gradio as gr
import os
import requests
from transformers import (
    LEDTokenizer, LEDForConditionalGeneration,
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    AutoTokenizer, AutoModelForSeq2SeqLM
)
# OpenAI API key (set OPENAI_API_KEY in your environment before launching)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Models to try, in priority order
MODELS = [
    {
        "name": "allenai/led-large-16384",
        "tokenizer_class": LEDTokenizer,
        "model_class": LEDForConditionalGeneration
    },
    {
        "name": "facebook/bart-large-cnn",
        "tokenizer_class": BartTokenizer,
        "model_class": BartForConditionalGeneration
    },
    {
        "name": "Falconsai/text_summarization",
        "tokenizer_class": AutoTokenizer,
        "model_class": AutoModelForSeq2SeqLM
    },
    {
        "name": "google/pegasus-xsum",
        "tokenizer_class": PegasusTokenizer,
        "model_class": PegasusForConditionalGeneration
    }
]
# Load models sequentially; failures are logged and skipped
loaded_models = []
for model_info in MODELS:
    try:
        tokenizer = model_info["tokenizer_class"].from_pretrained(model_info["name"])
        model = model_info["model_class"].from_pretrained(model_info["name"])
        loaded_models.append({"name": model_info["name"], "tokenizer": tokenizer, "model": model})
        print(f"Loaded model: {model_info['name']}")
    except Exception as e:
        print(f"Failed to load {model_info['name']}: {e}")
def summarize_with_transformers(text):
    """
    Try summarizing with locally loaded Transformer models in priority order.
    """
    for model_data in loaded_models:
        try:
            tokenizer = model_data["tokenizer"]
            model = model_data["model"]
            # Tokenize with truncation, respecting each model's own input limit
            # (LED accepts up to 16384 tokens; BART and Pegasus far fewer)
            max_input = min(tokenizer.model_max_length, 16384)
            inputs = tokenizer([text], max_length=max_input, return_tensors="pt", truncation=True)
            # Generate summary
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                num_beams=4,
                max_length=512,
                min_length=100,
                early_stopping=True
            )
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            return summary  # Return the first successful result
        except Exception as e:
            print(f"Error using {model_data['name']}: {e}")
    return None  # All local models failed
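
# Note: allenai/led-large-16384 usually benefits from marking the first token for
# global attention. A minimal sketch of that variant (not wired in above; it would
# also need `import torch`):
#
#   global_attention_mask = torch.zeros_like(inputs["input_ids"])
#   global_attention_mask[:, 0] = 1  # global attention on the first token
#   summary_ids = model.generate(inputs["input_ids"],
#                                global_attention_mask=global_attention_mask,
#                                num_beams=4, max_length=512)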
def summarize_with_chatgpt(text):
    """
    Fall back to the OpenAI ChatGPT API if all local models fail.
    """
    if not OPENAI_API_KEY:
        return "Error: No OpenAI API key provided."
    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": f"Summarize this article: {text}"}],
        "max_tokens": 512
    }
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers, json=payload, timeout=60
        )
    except requests.RequestException as e:
        return f"Error: ChatGPT request failed ({e})"
    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]
    return f"Error: Failed to summarize with ChatGPT (status {response.status_code})"
def summarize_text(text):
    """
    Main entry point: try the Transformer models first, then ChatGPT if needed.
    """
    summary = summarize_with_transformers(text)
    if summary:
        return summary  # Successful summary from a local model
    print("All Transformer models failed. Falling back to ChatGPT...")
    return summarize_with_chatgpt(text)  # Last resort
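
# Example usage outside Gradio (assuming at least one model loaded, or an API key set):
#   print(summarize_text("Long article text ..."))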
# Gradio interface
iface = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(lines=12, label="Article text"),
    outputs=gr.Textbox(label="Summary"),
    title="Multi-Model Summarizer with Fallback",
    description="Tries multiple local models for summarization, falling back to ChatGPT if needed."
)
if __name__ == "__main__":
    iface.launch()