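# Hugging Face Space: generates engaging article headlines with a LoRA adapter
# (HiGenius/Headline-Generation-Model) on top of Llama 3.2 3B Instruct,
# summarizing over-long articles through the OpenAI API first.
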
import os

import gradio as gr
import torch
from openai import OpenAI
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "HiGenius/Headline-Generation-Model"
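
# Secrets configured on the Space: an HF token for the gated Llama weights
# and an OpenAI key for the summarization fallback.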
hf_token = os.environ.get('HF_TOKEN')
openai_api_key = os.environ.get('OPENAI_API_KEY')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_model():
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, token=hf_token)
    model = PeftModel.from_pretrained(base_model, model_id, token=hf_token).to(device)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
    # Llama ships without a pad token; reuse EOS, and pad/truncate on the left
    # so the prompt always ends immediately before the generation position.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    tokenizer.truncation_side = 'left'
    return tokenizer, model
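
# A minimal smoke test for load_model() (a sketch; assumes HF_TOKEN grants
# access to the gated Llama 3.2 weights; uncomment to try locally):
#
#   tokenizer, model = load_model()
#   print(model.device, tokenizer.padding_side)  # e.g. cuda:0, left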

def summarize_content(content):
    # Condense over-long articles with GPT-4o so the final prompt fits the
    # 1024-token generation budget used below.
    client = OpenAI(api_key=openai_api_key)
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Summarize the following article content concisely while preserving key information:"},
            {"role": "user", "content": content}
        ],
        max_tokens=600,
        temperature=0.3
    )
    return response.choices[0].message.content
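
# Note: errors from the OpenAI call (missing key, rate limit, network) will
# propagate to the Gradio handler; wrapping the call in try/except and falling
# back to the raw article text would be one way to harden this.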

tokenizer, model = load_model()

# Headline-writing guidelines bundled with the Space; injected into the
# system prompt when available.
guideline_path = "./guidelines.txt"
with open(guideline_path, 'r', encoding='utf-8') as f:
    guidelines = f.read()

def process_prompt(tokenizer, content, video_summary='', guidelines=None):
    # Budget check: generation input is truncated at 1024 tokens, so summarize
    # the article when it (or article plus video summary) runs long.
    content_tokens = len(tokenizer.encode(content))
    total_tokens = content_tokens
    if video_summary:
        total_tokens += len(tokenizer.encode(video_summary))
    if content_tokens > 850 or total_tokens > 900:
        content = summarize_content(content)

    if guidelines:
        system_prompt = ("You are a helpful assistant that writes engaging headlines. "
                         "To maximize engagement, you may follow these proven guidelines:\n" + guidelines)
    else:
        system_prompt = "You are a helpful assistant that writes engaging headlines."

    user_prompt = (
        f"Below is an article and its accompanying video summary:\n\n"
        f"Article Content:\n{content}\n\n"
        f"Video Summary:\n{'None' if video_summary == '' else video_summary}\n\n"
        "Write ONLY a single engaging headline that accurately reflects the article. "
        "Do not include any additional text, explanations, or options."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt
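
# For reference, Llama 3's chat template renders the messages roughly as:
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   {system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   {user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>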

def generate_headlines(content, video_summary):
    if not content.strip():
        return "Please enter valid article content."
    if not video_summary.strip():
        video_summary = ''

    prompt = process_prompt(tokenizer, content, video_summary, guidelines)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    # Sample five candidate headlines; temperature 0.7 keeps them varied
    # while staying on-topic.
    headlines = []
    for i in range(5):
        outputs = model.generate(
            **inputs,
            max_new_tokens=60,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        response = response.replace('"', '').strip()
        headlines.append(f"Headline {i+1}: {response}")

    return "\n\n".join(headlines)
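
# Note: a single batched call is an alternative to the loop above (a sketch;
# it trades peak memory for fewer forward passes):
#
#   outputs = model.generate(**inputs, max_new_tokens=60, do_sample=True,
#                            temperature=0.7, num_return_sequences=5)
#   texts = tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:],
#                                  skip_special_tokens=True)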

# Create the Gradio interface.
demo = gr.Interface(
    fn=generate_headlines,
    inputs=[
        gr.Textbox(label="Article Content", placeholder="Type the main content of the article here..."),
        gr.Textbox(label="Video Summary (Optional)", placeholder="Type the summary of the video related to the article...")
    ],
    outputs=gr.Textbox(label="Generated Headlines"),
    title="Article Headline Writer",
    description="Generate five catchy headlines from article content and an optional video summary."
)

if __name__ == "__main__":
    demo.launch()
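
# On Spaces, launch() serves on the port the platform expects; when running
# locally, passing share=True gives a temporary public link.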