Spaces:
Running
Running
import streamlit as st | |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
import json | |
import os | |
import requests | |
import torch | |
# Page-level settings for the app: browser tab title and a centered layout.
st.set_page_config(
    page_title="AI Chatbot",
    layout="centered",
)
# Fix and modify the model configuration dynamically | |
def fix_model_config(model_name, timeout=30):
    """Download the model's ``config.json`` and cache a patched copy locally.

    The patched copy rewrites ``rope_scaling`` to the two-key form
    ``{"type": "linear", "factor": ...}``, keeping the upstream factor when
    one is present and falling back to 1.0 otherwise.

    The cache file is a fixed path in the current working directory and is
    not keyed by ``model_name``: once it exists it is returned as-is,
    whichever model is requested.

    Args:
        model_name: Hugging Face Hub repository id, e.g. ``"org/model"``.
        timeout: Seconds to wait on the Hub request before giving up.

    Returns:
        Path of the patched config file (``"fixed_config.json"``).

    Raises:
        requests.HTTPError: If the Hub answers with an error status.
        requests.Timeout: If the Hub does not answer within ``timeout``.
    """
    config_url = f"https://huggingface.co/{model_name}/resolve/main/config.json"
    fixed_config_path = "fixed_config.json"

    # Reuse the cached copy when present; no network access is needed then.
    if os.path.exists(fixed_config_path):
        return fixed_config_path

    # A timeout keeps a stalled connection from hanging app startup forever.
    response = requests.get(config_url, timeout=timeout)
    response.raise_for_status()
    config = response.json()

    # Only rewrite a real mapping: a JSON ``null`` value would otherwise pass
    # the membership check and then fail on ``.get``.
    rope_scaling = config.get("rope_scaling")
    if isinstance(rope_scaling, dict):
        # NOTE(review): every other upstream scaling key is dropped here --
        # confirm that is intended for the installed transformers version.
        config["rope_scaling"] = {
            "type": "linear",
            "factor": rope_scaling.get("factor", 1.0),
        }

    # Write to a sibling temp file and rename it into place, so an interrupted
    # write never leaves a truncated cache file that later runs would trust.
    tmp_path = f"{fixed_config_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(config, f)
    os.replace(tmp_path, fixed_config_path)

    return fixed_config_path
# Load the pipeline | |
def load_pipeline():
    """Build the text-generation pipeline for the Lexi Llama-3.1 8B model.

    The patched config written by ``fix_model_config`` is parsed into a
    config object before it is handed to the model loader. The Auto* model
    loaders document ``config`` as a ``PretrainedConfig`` instance, so a bare
    JSON path string is not guaranteed to be honored there and the patched
    ``rope_scaling`` could be silently skipped.

    Returns:
        A transformers ``text-generation`` pipeline wrapping the loaded
        model and its tokenizer.
    """
    # Local import: the file's top-level transformers import does not
    # bring AutoConfig into scope.
    from transformers import AutoConfig

    model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"

    # Download (or reuse) the patched config, then load it as a config object.
    fixed_config_path = fix_model_config(model_name)
    config = AutoConfig.from_pretrained(fixed_config_path)

    # Tokenizer and weights still come from the Hub repository itself.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    return pipeline("text-generation", model=model, tokenizer=tokenizer)
# Streamlit re-executes this script on every widget interaction. Wrapping the
# loader in st.cache_resource keeps one loaded pipeline per server process
# instead of reloading the model on each rerun.
pipe = st.cache_resource(load_pipeline)()
# Streamlit App UI
st.title("🤖 AI Chatbot")
st.markdown(
    """
    Welcome to the **AI Chatbot** powered by Hugging Face's **Llama-3.1-8B-Lexi-Uncensored-V2** model.
    Type your message below and interact with the AI!
    """
)

# User input area
user_input = st.text_area(
    "Your Message",
    placeholder="Type your message here...",
    height=100,
)

# Button to generate response
if st.button("Generate Response"):
    if user_input.strip():
        with st.spinner("Generating response..."):
            try:
                # max_new_tokens bounds only the generated continuation.
                # max_length would also count the prompt tokens, so a long
                # message could leave no room for a reply.
                response = pipe(user_input, max_new_tokens=150, num_return_sequences=1)
                st.text_area("Response", value=response[0]["generated_text"], height=200)
            except Exception as e:
                # UI boundary: show the failure to the user rather than
                # letting the page crash with a traceback.
                st.error(f"An error occurred: {e}")
    else:
        # Empty or whitespace-only input: ask for a message instead of generating.
        st.warning("Please enter a message before clicking the button.")

# Footer
st.markdown("---")
st.markdown("Made with ❤️ using [Streamlit](https://streamlit.io) and [Hugging Face](https://huggingface.co).")