File size: 1,857 Bytes
69497aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import os
from huggingface_hub import login

# Authenticate with Hugging Face
def init_huggingface():
    """Log in to the Hugging Face Hub when a token is configured.

    The token is read from the ``HUGGINGFACE_TOKEN`` environment variable.
    If that variable is unset or empty, no login is attempted and a message
    says anonymous access will be used. Any exception raised along the way
    is printed and swallowed, so callers can still try to load models.
    """
    try:
        token = os.getenv('HUGGINGFACE_TOKEN')
        # Guard clause: nothing to authenticate with, so report and bail out.
        if not token:
            print("No Hugging Face token found, trying anonymous access")
            return
        login(token)
        print("Successfully logged in to Hugging Face")
    except Exception as err:
        # Best-effort: an auth failure must not prevent model loading.
        print(f"Authentication error: {err}")

# Load model and tokenizer
def load_model(
    model_name="aayushraina/gpt2shakespeare",
    tokenizer_name="gpt2",
    fallback_name="gpt2",
):
    """Load the custom GPT-2 checkpoint and a tokenizer, with a base-GPT-2 fallback.

    Hugging Face authentication is attempted first via ``init_huggingface``.
    If loading the primary model or tokenizer raises, both are loaded from
    ``fallback_name`` instead.

    Args:
        model_name: Hub id (or local path) of the primary model weights.
        tokenizer_name: Hub id of the tokenizer paired with the primary model
            (by default the stock GPT-2 tokenizer).
        fallback_name: Hub id used for both model and tokenizer when the
            primary load fails.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode,
        or ``(None, None)`` if both the primary and the fallback load fail.
    """
    try:
        # Initialize Hugging Face authentication (best-effort, never raises).
        init_huggingface()

        print("Loading model...")
        # GPT2LMHeadModel is a concrete class, so no hub-hosted code is
        # needed; trust_remote_code is deliberately NOT enabled.
        model = GPT2LMHeadModel.from_pretrained(model_name)
        print("Model loaded successfully!")

        print("Loading tokenizer...")
        # Use the base GPT-2 tokenizer (by default)
        tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name)
        print("Tokenizer loaded successfully!")
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        return _load_fallback(fallback_name)
    else:
        model.eval()
        return model, tokenizer


def _load_fallback(name):
    """Load model and tokenizer from ``name``; return ``(None, None)`` on failure."""
    try:
        print("Attempting to load base GPT-2 model as fallback...")
        model = GPT2LMHeadModel.from_pretrained(name)
        tokenizer = GPT2Tokenizer.from_pretrained(name)
    except Exception as e:
        print(f"Fallback failed: {e}")
        return None, None
    # Match the primary path so callers always get a model in eval mode.
    model.eval()
    print("Fallback successful - loaded base GPT-2")
    return model, tokenizer

# Rest of the code remains the same...