@@ -43,16 +43,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/education/DATA/projects/ai/toolformer/env/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from typing import Optional, List\n",
@@ -63,23 +54,10 @@
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from torchtyping import TensorType\n",
"from einops import rearrange\n",
"\n",
"from toolformer.api import BaseAPI\n",
"from toolformer.utils import extract_api_request_content"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# detect\n",
"# wait for end of token\n",
"# extract\n",
"# excute\n",
"# add the result to the input\n",
"# continue"
"from toolformer.utils import extract_api_content, extract_api_name"
]
},
{
@@ -113,20 +91,37 @@
" self.api_start_token_id = tokenizer(f' {start_character}', return_tensors=\"pt\")[\"input_ids\"][0]\n",
" self.api_end_token_id = tokenizer(end_character, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" self.api_output_token_id = tokenizer(f'{output_character}', return_tensors=\"pt\")[\"input_ids\"][0]\n",
" \n",
"\n",
" self.eos_token_ids = tokenizer(\n",
" [\".\", \".\\n\\n\"],\n",
" return_tensors=\"pt\"\n",
" )[\"input_ids\"].squeeze()\n",
"\n",
" # TODO: support batch\n",
" self.api_request_content: torch.Tensor = torch.tensor([])\n",
" \n",
" def _sampling(self, probs: TensorType[\"batch_size\", \"seq_len\"]) -> TensorType[\"batch_size\", \"seq_len\"]:\n",
" return torch.argmax(probs, dim=-1)\n",
" \n",
" def execute_api(self, text_ids: TensorType[\"seq_len\"]) -> TensorType[\"seq_len\"]:\n",
" def execute_api(self, text_ids: TensorType[\"seq_len\"]) -> Optional[ TensorType[\"seq_len\"] ]:\n",
" \"\"\"Execute an API call.\"\"\"\n",
" # content_ids = extract_api_request_content(text_ids, self.apis)\n",
" pass\n",
" text = self.tokenizer.decode(text_ids, skip_special_tokens=True)\n",
" api_name = extract_api_name(text, is_end_token=False)\n",
"\n",
" if api_name is not None:\n",
"            # check whether any registered API matches the extracted api_name\n",
" for api in self.apis:\n",
" if api.name == api_name:\n",
" api_content = extract_api_content(text, api_name=api_name)\n",
" api_output = api(api_content)\n",
" return self.tokenizer(api_output, return_tensors=\"pt\")[\"input_ids\"][0]\n",
" return None\n",
" \n",
" def add_idx_to_api_request_content(self, idx: TensorType[1]):\n",
" self.api_request_content = torch.cat([self.api_request_content, idx.unsqueeze(0)], dim=0)\n",
" self.api_request_content = torch.cat([\n",
" self.api_request_content,\n",
" rearrange(idx, '... -> 1 ...')\n",
" ], dim=-1).long()\n",
" \n",
" def forward(\n",
" self,\n",
@@ -136,7 +131,6 @@
" **kwargs\n",
" ) -> TensorType[\"batch_size\", \"seq_len\"]:\n",
" # check padding to the left\n",
" \n",
" generated_ids = input_ids\n",
" \n",
" for _ in range(max_new_tokens):\n",
@@ -148,15 +142,23 @@
" \n",
" logits = output_ids.logits[:, -1, :]\n",
" probs = F.softmax(logits, dim=-1)\n",
" _, top_k_idx = torch.topk(probs, k=5, dim=-1)\n",
" # TODO: k should be a config\n",
" _, top_k_idx = torch.topk(probs, k=1, dim=-1)\n",
" \n",
" if self.is_calling_api is True:\n",
" if self.api_end_token_id in top_k_idx:\n",
" # if the api end token is in the top_k_idx, then we will execute the api\n",
" # and then add api_end_token_id to the generated_ids\n",
" self.add_idx_to_api_request_content(self.api_end_token_id)\n",
" api_output_ids = self.execute_api(self.api_request_content)\n",
" pred_ids = torch.tensor([self.api_end_token_id, api_output_ids])\n",
"                    # TODO: support batch\n",
" api_output_ids = self.execute_api(self.api_request_content[0])\n",
" if api_output_ids is not None:\n",
" pred_ids = torch.cat([\n",
" self.api_output_token_id,\n",
" api_output_ids,\n",
" self.api_end_token_id\n",
" ], dim=-1).long()\n",
" else:\n",
" pred_ids = self.api_end_token_id\n",
" self.is_calling_api = False\n",
" else:\n",
" pred_ids = self._sampling(probs)\n",
@@ -170,8 +172,19 @@
" else:\n",
" pred_ids = self._sampling(probs)\n",
" \n",
" generated_ids = torch.cat([generated_ids, pred_ids.unsqueeze(dim=1)], dim=1)\n",
" attention_mask = torch.cat([attention_mask, torch.ones_like(pred_ids).unsqueeze(dim=1)], dim=1)\n",
" generated_ids = torch.cat([\n",
" generated_ids,\n",
" rearrange(pred_ids, '... -> 1 ...')\n",
" ], dim=1)\n",
" \n",
" attention_mask = torch.cat([\n",
" attention_mask,\n",
" rearrange(torch.ones_like(pred_ids), '... -> 1 ...')\n",
" ], dim=1)\n",
" \n",
" # ignore the case that pred_ids contains api_output\n",
" if len(pred_ids) == 1 and pred_ids in self.eos_token_ids:\n",
" break\n",
" \n",
" return generated_ids"
]