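# Streamlit app: generate engaging article headlines with a Llama 3.2 3B Instruct
# model fine-tuned via a PEFT adapter; overly long articles are condensed with GPT-4o first.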
import streamlit as st
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
from openai import OpenAI

base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "HiGenius/Headline-Generation-Model"

hf_token = os.environ.get('HF_TOKEN')
openai_api_key = os.environ.get('OPENAI_API_KEY')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
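
# Cache the heavyweight model/tokenizer load so it runs once per process instead of on every Streamlit rerun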

@st.cache_resource
def load_model():
    # Load the base model, attach the fine-tuned PEFT adapter, and move it to the target device
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, token=hf_token)
    model = PeftModel.from_pretrained(base_model, model_id, token=hf_token).to(device)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
    # Llama ships without a pad token; reuse EOS, and pad/truncate from the left so the
    # end of long prompts (closest to the generation) is preserved
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    return tokenizer, model
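
# Long articles are condensed with GPT-4o so the final prompt fits the headline model's token budget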

def summarize_content(content):
    client = OpenAI(api_key=openai_api_key)
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Summarize the following article content concisely while preserving key information:"},
            {"role": "user", "content": content}
        ],
        max_tokens=600,
        temperature=0.3
    )
    return response.choices[0].message.content
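
# Load the model once at startup and read the headline-writing guidelines bundled with the app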

tokenizer, model = load_model()

guideline_path = "./guidelines.txt"
with open(guideline_path, 'r', encoding='utf-8') as f:
    guidelines = f.read()
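
# Build the chat prompt: summarize overly long content, optionally prepend the guidelines
# to the system message, and render everything with the tokenizer's chat template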

def process_prompt(tokenizer, content, video_summary='', guidelines=None):
    # Check token lengths
    content_tokens = len(tokenizer.encode(content))
    total_tokens = content_tokens
    if video_summary:
        total_tokens += len(tokenizer.encode(video_summary))

    if content_tokens > 850 or total_tokens > 900:
        content = summarize_content(content)
    
    if guidelines:
        system_prompt = "You are a helpful assistant that writes engaging headlines. To maximize engagement, you may follow these proven guidelines:\n" + guidelines
    else:
        system_prompt = "You are a helpful assistant that writes engaging headlines."

    user_prompt = (
        f"Below is an article and its accompanying video summary:\n\n"
        f"Article Content:\n{content}\n\n"
        f"Video Summary:\n{'None' if video_summary == '' else video_summary}\n\n"
        f"Write ONLY a single engaging headline that accurately reflects the article. Do not include any additional text, explanations, or options."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt


st.title("Article Headline Writer")
st.write("Generate a catchy headline from the article content and an optional video summary.")

# Inputs for content and video summary
content = st.text_area("Enter the article content:", placeholder="Type the main content of the article here...")
video_summary = st.text_area("Enter the summary of the article's accompanying video (optional):", placeholder="Type the summary of the video related to the article...")

if st.button("Generate Headline"):
    if content.strip():
        if not video_summary.strip():
            video_summary = ''
        # prompt = process_prompt(tokenizer, content, video_summary, guidelines)
        prompt = process_prompt(tokenizer, content, video_summary)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)
        
        st.write("### 5 Potential Headlines:")
        for i in range(5):
            st.write(f"### Headline {i+1}")
            # Sample a fresh headline on each iteration; temperature adds variety between candidates
            outputs = model.generate(
                **inputs,
                max_new_tokens=60,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
            )
            # Decode only the newly generated tokens, dropping the prompt portion
            response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            response = response.replace('"', '')
            st.write(response)
    else:
        st.write("Please enter the article content before generating a headline.")