victor-johnson committed on
Commit
11f5b1b
·
verified ·
1 Parent(s): 253aa25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -7,27 +7,27 @@ from transformers import (
7
  )
8
  from langchain_huggingface import HuggingFacePipeline
9
  from langchain_core.prompts import PromptTemplate
10
- from langchain.chains import LLMChain
11
 
12
  # — Model setup —
13
  MODEL_ID = "bigcode/starcoder2-3b"
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
15
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
16
 
17
- # — Pipeline setup (pass generation parameters directly) —
18
  pipe = pipeline(
19
  "text-generation",
20
  model=model,
21
  tokenizer=tokenizer,
22
- device_map="auto",
23
  max_new_tokens=64,
24
- temperature=0.2,
25
- top_p=0.95,
26
  do_sample=False,
27
  )
28
  llm = HuggingFacePipeline(pipeline=pipe)
29
 
30
- # — Prompt & chain —
31
  prompt = PromptTemplate(
32
  input_variables=["description"],
33
  template=(
@@ -36,7 +36,7 @@ prompt = PromptTemplate(
36
  "Emmet:"
37
  ),
38
  )
39
- chain = LLMChain(llm=llm, prompt=prompt)
40
 
41
  # — FastAPI app —
42
  app = FastAPI()
@@ -47,8 +47,12 @@ class Req(BaseModel):
47
  class Res(BaseModel):
48
  emmet: str
49
 
 
 
 
 
50
  @app.post("/generate-emmet", response_model=Res)
51
  async def generate_emmet(req: Req):
52
- raw = chain.invoke(req.description) # use .invoke() instead of deprecated .run()
53
  emmet = raw.strip().splitlines()[0]
54
  return {"emmet": emmet}
 
7
  )
8
  from langchain_huggingface import HuggingFacePipeline
9
  from langchain_core.prompts import PromptTemplate
10
+ from langchain_core.runnables import RunnableSequence
11
 
12
  # — Model setup —
13
  MODEL_ID = "bigcode/starcoder2-3b"
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
15
+ # Explicitly set pad_token_id to eos_token_id
16
+ tokenizer.pad_token_id = tokenizer.eos_token_id
17
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
18
 
19
+ # — Pipeline setup (remove unused parameters, set device explicitly) —
20
  pipe = pipeline(
21
  "text-generation",
22
  model=model,
23
  tokenizer=tokenizer,
24
+ device=-1, # Explicitly use CPU; change to 0 or "cuda" if GPU available
25
  max_new_tokens=64,
 
 
26
  do_sample=False,
27
  )
28
  llm = HuggingFacePipeline(pipeline=pipe)
29
 
30
+ # — Prompt & chain (using RunnableSequence) —
31
  prompt = PromptTemplate(
32
  input_variables=["description"],
33
  template=(
 
36
  "Emmet:"
37
  ),
38
  )
39
+ chain = RunnableSequence(prompt | llm)
40
 
41
  # — FastAPI app —
42
  app = FastAPI()
 
47
  class Res(BaseModel):
48
  emmet: str
49
 
50
+ @app.get("/")
51
+ async def root():
52
+ return {"message": "Welcome to the Emmet Generator API. Use POST /generate-emmet."}
53
+
54
  @app.post("/generate-emmet", response_model=Res)
55
  async def generate_emmet(req: Req):
56
+ raw = chain.invoke(req.description)
57
  emmet = raw.strip().splitlines()[0]
58
  return {"emmet": emmet}