Gopikanth123 commited on
Commit
4e5040c
·
verified ·
1 Parent(s): da79ef9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -13
main.py CHANGED
@@ -7,6 +7,10 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
7
  from huggingface_hub import InferenceClient
8
  from transformers import AutoTokenizer, AutoModel
9
  from deep_translator import GoogleTranslator
 
 
 
 
10
 
11
 
12
  # Ensure HF_TOKEN is set
@@ -31,14 +35,14 @@ llm_client = InferenceClient(
31
  # generate_kwargs={"temperature": 0.1},
32
  # )
33
  # Configure Llama index settings with the new model
34
- Settings.llm = HuggingFaceInferenceAPI(
35
- model_name=repo_id,
36
- tokenizer_name=repo_id, # Use the same tokenizer as the model
37
- context_window=3000,
38
- token=HF_TOKEN,
39
- max_new_tokens=512,
40
- generate_kwargs={"temperature": 0.1},
41
- )
42
  # Settings.embed_model = HuggingFaceEmbedding(
43
  # model_name="BAAI/bge-small-en-v1.5"
44
  # )
@@ -46,17 +50,35 @@ Settings.llm = HuggingFaceInferenceAPI(
46
  # Settings.embed_model = HuggingFaceEmbedding(
47
  # model_name="xlm-roberta-base" # XLM-RoBERTa model for multilingual support
48
  # )
49
- Settings.embed_model = HuggingFaceEmbedding(
50
- model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
51
- )
52
 
53
  # # Configure tokenizer and model if required
54
  # tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
55
  # model = AutoModel.from_pretrained("xlm-roberta-base")
56
  # Configure tokenizer and model if required
57
  tokenizer = AutoTokenizer.from_pretrained(repo_id) # Use the tokenizer from the new model
58
- model = AutoModel.from_pretrained(repo_id) # Load the new model
59
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  PERSIST_DIR = "db"
61
  PDF_DIRECTORY = 'data'
62
 
 
7
  from huggingface_hub import InferenceClient
8
  from transformers import AutoTokenizer, AutoModel
9
  from deep_translator import GoogleTranslator
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ import torch
12
+ from accelerate import infer_auto_device_map
13
+
14
 
15
 
16
  # Ensure HF_TOKEN is set
 
35
  # generate_kwargs={"temperature": 0.1},
36
  # )
37
  # Configure Llama index settings with the new model
38
+ # Settings.llm = HuggingFaceInferenceAPI(
39
+ # model_name=repo_id,
40
+ # tokenizer_name=repo_id, # Use the same tokenizer as the model
41
+ # context_window=3000,
42
+ # token=HF_TOKEN,
43
+ # max_new_tokens=512,
44
+ # generate_kwargs={"temperature": 0.1},
45
+ # )
46
  # Settings.embed_model = HuggingFaceEmbedding(
47
  # model_name="BAAI/bge-small-en-v1.5"
48
  # )
 
50
  # Settings.embed_model = HuggingFaceEmbedding(
51
  # model_name="xlm-roberta-base" # XLM-RoBERTa model for multilingual support
52
  # )
53
+ # Settings.embed_model = HuggingFaceEmbedding(
54
+ # model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
55
+ # )
56
 
57
  # # Configure tokenizer and model if required
58
  # tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
59
  # model = AutoModel.from_pretrained("xlm-roberta-base")
60
  # Configure tokenizer and model if required
61
  tokenizer = AutoTokenizer.from_pretrained(repo_id) # Use the tokenizer from the new model
62
+ # model = AutoModel.from_pretrained(repo_id) # Load the new model
63
+ model = AutoModelForCausalLM.from_pretrained(
64
+ repo_id,
65
+ load_in_4bit=True, # Load in 4-bit quantization
66
+ torch_dtype=torch.float16,
67
+ device_map="auto",
68
+ )
69
+ # Configure Llama index settings
70
+ Settings.llm = HuggingFaceInferenceAPI(
71
+ model_name=repo_id,
72
+ tokenizer_name=repo_id, # Use the same tokenizer as the model
73
+ context_window=2048, # Reduce context window to save memory
74
+ token=HF_TOKEN,
75
+ max_new_tokens=256, # Reduce max tokens to save memory
76
+ generate_kwargs={"temperature": 0.1},
77
+ )
78
+ # Use a smaller embedding model
79
+ Settings.embed_model = HuggingFaceEmbedding(
80
+ model_name="sentence-transformers/all-MiniLM-L6-v2" # Smaller and faster
81
+ )
82
  PERSIST_DIR = "db"
83
  PDF_DIRECTORY = 'data'
84