barathm2001 committed on
Commit
292c995
·
verified ·
1 Parent(s): fbd8767

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +16 -19
  2. requirements.txt +5 -3
app.py CHANGED
@@ -1,8 +1,9 @@
1
- import os
2
  import logging
3
  from fastapi import FastAPI, HTTPException
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
5
  from peft import PeftModel, PeftConfig
 
 
6
 
7
  # Set up logging
8
  logging.basicConfig(level=logging.INFO)
@@ -15,39 +16,35 @@ app = FastAPI()
15
  model = None
16
  tokenizer = None
17
  pipe = None
 
18
 
19
  @app.on_event("startup")
20
  async def load_model():
21
- global model, tokenizer, pipe
22
 
23
  try:
24
- # Get Hugging Face token from environment variable
25
- hf_token = os.environ.get("HUGGINGFACE_TOKEN")
26
-
27
  logger.info("Loading PEFT configuration...")
28
  config = PeftConfig.from_pretrained("frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
29
 
30
  logger.info("Loading base model...")
31
- base_model = AutoModelForCausalLM.from_pretrained(
32
- "mistralai/Mistral-7B-Instruct-v0.3",
33
- token=hf_token if hf_token else None,
34
- use_auth_token=True if not hf_token else None
35
- )
36
 
37
  logger.info("Loading PEFT model...")
38
  model = PeftModel.from_pretrained(base_model, "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
39
 
40
  logger.info("Loading tokenizer...")
41
- tokenizer = AutoTokenizer.from_pretrained(
42
- "mistralai/Mistral-7B-Instruct-v0.3",
43
- token=hf_token if hf_token else None,
44
- use_auth_token=True if not hf_token else None
45
- )
46
 
47
  logger.info("Creating pipeline...")
48
  pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
49
 
50
  logger.info("Model, tokenizer, and pipeline loaded successfully.")
 
 
 
51
  except Exception as e:
52
  logger.error(f"Error loading model or creating pipeline: {e}")
53
  raise
@@ -58,12 +55,12 @@ def home():
58
 
59
  @app.get("/generate")
60
  async def generate(text: str):
61
- if not pipe:
62
  raise HTTPException(status_code=503, detail="Model not loaded")
63
 
64
  try:
65
- output = pipe(text, max_length=100, num_return_sequences=1)
66
- return {"output": output[0]['generated_text']}
67
  except Exception as e:
68
  logger.error(f"Error during text generation: {e}")
69
  raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
 
 
1
  import logging
2
  from fastapi import FastAPI, HTTPException
3
+ from transformers import AutoModelForCausalLM, pipeline
4
  from peft import PeftModel, PeftConfig
5
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
6
+ from mistral_common.client import MistralChain
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
 
16
  model = None
17
  tokenizer = None
18
  pipe = None
19
+ mistral_chain = None
20
 
21
  @app.on_event("startup")
22
  async def load_model():
23
+ global model, tokenizer, pipe, mistral_chain
24
 
25
  try:
 
 
 
26
  logger.info("Loading PEFT configuration...")
27
  config = PeftConfig.from_pretrained("frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
28
 
29
  logger.info("Loading base model...")
30
+ base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
 
 
 
 
31
 
32
  logger.info("Loading PEFT model...")
33
  model = PeftModel.from_pretrained(base_model, "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
34
 
35
  logger.info("Loading tokenizer...")
36
+ tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
37
+
38
+ logger.info("Creating MistralChain...")
39
+ mistral_chain = MistralChain(model, tokenizer)
 
40
 
41
  logger.info("Creating pipeline...")
42
  pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
43
 
44
  logger.info("Model, tokenizer, and pipeline loaded successfully.")
45
+ except ImportError as e:
46
+ logger.error(f"Error importing required modules. Please check your installation: {e}")
47
+ raise
48
  except Exception as e:
49
  logger.error(f"Error loading model or creating pipeline: {e}")
50
  raise
 
55
 
56
  @app.get("/generate")
57
  async def generate(text: str):
58
+ if not mistral_chain:
59
  raise HTTPException(status_code=503, detail="Model not loaded")
60
 
61
  try:
62
+ output = mistral_chain.generate(text, max_tokens=100)
63
+ return {"output": output}
64
  except Exception as e:
65
  logger.error(f"Error during text generation: {e}")
66
  raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
requirements.txt CHANGED
@@ -1,9 +1,11 @@
1
  fastapi==0.103.0
 
2
  uvicorn[standard]==0.17.*
3
  torch>=1.13.0
4
- transformers>=4.34.0,<4.35.0
5
  numpy<2
6
- peft>=0.6.0,<0.7.0
7
  accelerate>=0.24.1,<0.25.0
8
  huggingface_hub>=0.16.4,<0.18.0
9
- tokenizers>=0.14.0,<0.15.0
 
 
1
  fastapi==0.103.0
2
+ requests==2.27.*
3
  uvicorn[standard]==0.17.*
4
  torch>=1.13.0
5
+ transformers>=4.36.0,<5.0.0
6
  numpy<2
7
+ peft>=0.8.0
8
  accelerate>=0.24.1,<0.25.0
9
  huggingface_hub>=0.16.4,<0.18.0
10
+ tokenizers>=0.14.0,<0.15.0
11
+ git+https://github.com/mistralai/mistral-common.git@main