alpeshsonar committed on
Commit
0dab623
·
verified ·
1 Parent(s): 9739a1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
@@ -5,17 +7,16 @@ from threading import Thread
5
  import uvicorn
6
 
7
  # Initialize FastAPI
8
-
9
  app = FastAPI()
10
 
11
  # Load the tokenizer and model
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- model = model.to(device)
14
 
15
  # Define the function to generate text for Gradio
16
  def generate_text(input_text):
17
  input_text = "Extract lots from given text.\n" + input_text
18
- inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
19
  outputs = model.generate(inputs, max_new_tokens=1024)
20
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
  return result
@@ -30,10 +31,6 @@ iface = gr.Interface(
30
  )
31
 
32
  # Define a request body model for FastAPI
33
-
34
-
35
-
36
-
37
  class TextInput(BaseModel):
38
  input_text: str
39
 
@@ -41,12 +38,9 @@ class TextInput(BaseModel):
41
  @app.post("/generate")
42
  async def generate_text_api(input_data: TextInput):
43
  input_text = input_data.input_text
44
- inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
45
-
46
-
47
  outputs = model.generate(inputs, max_new_tokens=1024)
48
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
-
50
  return {"output": result}
51
 
52
  # Health check endpoint
 
1
+ import gradio as gr
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  import torch
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
 
7
  import uvicorn
8
 
9
  # Initialize FastAPI
 
10
  app = FastAPI()
11
 
12
  # Load the tokenizer and model
13
+ tokenizer = T5Tokenizer.from_pretrained("alpeshsonar/lot-t5-small-filter", legacy=False)
14
+ model = T5ForConditionalGeneration.from_pretrained("alpeshsonar/lot-t5-small-filter")
15
 
16
  # Define the function to generate text for Gradio
17
  def generate_text(input_text):
18
  input_text = "Extract lots from given text.\n" + input_text
19
+ inputs = tokenizer.encode(input_text, return_tensors="pt")
20
  outputs = model.generate(inputs, max_new_tokens=1024)
21
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
22
  return result
 
31
  )
32
 
33
  # Define a request body model for FastAPI
 
 
 
 
34
  class TextInput(BaseModel):
35
  input_text: str
36
 
 
38
  @app.post("/generate")
39
  async def generate_text_api(input_data: TextInput):
40
  input_text = input_data.input_text
41
+ inputs = tokenizer.encode(input_text, return_tensors="pt")
 
 
42
  outputs = model.generate(inputs, max_new_tokens=1024)
43
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
44
  return {"output": result}
45
 
46
  # Health check endpoint