Allahbux committed
Commit 928429b · verified · Parent: ac00ffa

Update app.py

Files changed (1):
  app.py (+37 -2)
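
In context: the Llama 3.1 checkpoints ship a config.json whose rope_scaling block uses an extended "llama3" schema, which older transformers releases reject at config load time (their validator only accepts a dictionary with the two fields type and factor). Because that validation runs while the config is being parsed, the previous workaround of passing rope_scaling=None to from_pretrained evidently did not avoid the error, so this commit rewrites the config entry into the legacy two-field form before the model is loaded. A sketch of that rewrite; the "before" values are assumed from the upstream Llama 3.1 checkpoints, not taken from this commit:

```python
# Shape of the rope_scaling rewrite performed by fix_model_config().
# The "before" dict mirrors what the Llama 3.1 checkpoints are believed
# to ship upstream (an assumption; these values are not in this commit).
before = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

# The legacy two-field schema that older transformers validators accept:
after = {
    "type": "linear",
    "factor": before.get("factor", 1.0),
}
```

Note that forcing "type": "linear" changes the RoPE scaling semantics relative to the llama3 schedule, so long-context behavior may differ from the upstream checkpoint; the trade-off here is simply getting the model to load on an older transformers release.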
app.py CHANGED

```diff
@@ -1,21 +1,56 @@
 import streamlit as st
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+import json
+import requests
+import os

 # Streamlit app configuration
 st.set_page_config(page_title="AI Chatbot", layout="centered")

+# Fix the model's configuration before loading
+def fix_model_config(model_name):
+    # Download the configuration file from the model repository
+    config_url = f"https://huggingface.co/{model_name}/resolve/main/config.json"
+    config_path = "config.json"
+
+    if not os.path.exists(config_path):
+        response = requests.get(config_url)
+        response.raise_for_status()  # Raise an error if the request fails
+        with open(config_path, "w") as f:
+            f.write(response.text)
+
+    # Load the configuration and modify rope_scaling if necessary
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    if "rope_scaling" in config:
+        config["rope_scaling"] = {
+            "type": "linear",  # Replace the problematic rope_scaling type
+            "factor": config["rope_scaling"].get("factor", 1.0)
+        }
+
+    # Save the modified configuration
+    with open(config_path, "w") as f:
+        json.dump(config, f)
+
+    return config_path
+
 # Load the model pipeline
 @st.cache_resource
 def load_pipeline():
     model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"

+    # Fix the model configuration
+    fixed_config_path = fix_model_config(model_name)
+
     # Load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        device_map="auto",  # Use GPU if available
-        rope_scaling=None  # Avoid issues with rope_scaling
+        config=fixed_config_path,
+        device_map="auto"  # Use GPU if available
     )
+
     return pipeline("text-generation", model=model, tokenizer=tokenizer)

 pipe = load_pipeline()
```
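
A quick way to sanity-check the patched file before committing to the full 8B download is to load the rewritten config on its own. A minimal sketch, assuming fix_model_config() has already written config.json into the working directory:

```python
# Sanity-check the patched config in isolation. Assumes config.json
# (as rewritten by fix_model_config) sits in the current directory.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./config.json")

# On transformers releases whose validator only accepts the two-field
# schema, this should now load cleanly and print something like:
# {'type': 'linear', 'factor': 8.0}
print(config.rope_scaling)
```

from_pretrained accepts config= as either a PretrainedConfig instance or a path, so passing the patched file path as the diff does should be equivalent to loading the config object first and passing it in.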