{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30699,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import numpy as np\n\nnp.random.seed(0) # 设置随机种子以获得可重复的结果\n\nX = np.random.randn(4, 8) # 假设我们的向量维度是8即从768变成8还是\"LLM with me\"的4个Token\nW = np.random.randn(8, 1) # 权重矩阵W形状为[8, 1]\nb = np.random.randn(1) # 偏置向量b形状为[1]\n# 线性变换Y = XW + b\n# 这里使用np.dot进行矩阵乘法然后加上偏置\nY = np.dot(X, W) + b\n# 输出结果Y形状为[4, 1] 为了得到形状[4,]的输出,我们可以将结果压缩到一维\nY = np.squeeze(Y)\n\nprint(\"Input X shape:\", X.shape)\nprint(\"Weight W shape:\", W.shape)\nprint(\"Bias b shape:\", b.shape)\nprint(\"Output Y shape:\", Y.shape)\nprint(\"Output Y:\", Y)","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2024-05-03T10:10:29.150978Z","iopub.execute_input":"2024-05-03T10:10:29.151830Z","iopub.status.idle":"2024-05-03T10:10:29.159883Z","shell.execute_reply.started":"2024-05-03T10:10:29.151794Z","shell.execute_reply":"2024-05-03T10:10:29.158888Z"},"trusted":true},"execution_count":15,"outputs":[{"name":"stdout","text":"Input X shape: (4, 8)\nWeight W shape: (8, 1)\nBias b shape: (1,)\nOutput Y shape: (4,)\nOutput Y: [-2.59709604 -0.78316274 -4.6765379 3.25016417]\n","output_type":"stream"}]},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n# 假设的词汇表和词嵌入\nvocab = {'LLM': 0, 'with': 1, 'me': 2, '<PAD>': 3} # 一个简化的词汇表\nvocab_size = len(vocab) # 词汇表大小\nembedding_dim = 768 # 嵌入维度与GPT-2的小型版本相同\n\ntext = \"LLM with me\"\ninput_ids = torch.tensor([[vocab[word] for word in text.split()]], dtype=torch.long)\n\n# 模拟Transformer流程\nembedding_layer = nn.Embedding(vocab_size, embedding_dim)\nembedded = embedding_layer(input_ids)\ntransformer_output = torch.rand(embedded.size()) # 假设的Transformer输出\n# 创建一个线性层将Transformer输出映射到词汇表空间\nlinear_layer = nn.Linear(embedding_dim, vocab_size)\nvocab_space_scores = linear_layer(transformer_output)\n# 输出概率分布\nprobabilities = F.softmax(vocab_space_scores, dim=-1)\nprint(probabilities)","metadata":{"execution":{"iopub.status.busy":"2024-05-03T09:49:56.395154Z","iopub.execute_input":"2024-05-03T09:49:56.395831Z","iopub.status.idle":"2024-05-03T09:49:56.407684Z","shell.execute_reply.started":"2024-05-03T09:49:56.395797Z","shell.execute_reply":"2024-05-03T09:49:56.406645Z"},"trusted":true},"execution_count":7,"outputs":[{"name":"stdout","text":"tensor([[[0.2306, 0.2478, 0.2688, 0.2528],\n [0.1928, 0.3077, 0.2768, 0.2227],\n [0.2562, 0.2568, 0.2837, 0.2033]]], grad_fn=<SoftmaxBackward0>)\n","output_type":"stream"}]},{"cell_type":"code","source":"vocab_space_scores","metadata":{"execution":{"iopub.status.busy":"2024-05-03T09:50:05.253992Z","iopub.execute_input":"2024-05-03T09:50:05.254936Z","iopub.status.idle":"2024-05-03T09:50:05.261793Z","shell.execute_reply.started":"2024-05-03T09:50:05.254901Z","shell.execute_reply":"2024-05-03T09:50:05.260661Z"},"trusted":true},"execution_count":8,"outputs":[{"execution_count":8,"output_type":"execute_result","data":{"text/plain":"tensor([[[-0.1818, -0.1099, -0.0287, -0.0901],\n [-0.6043, -0.1369, 
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# A toy vocabulary and word embeddings
vocab = {'LLM': 0, 'with': 1, 'me': 2, '<PAD>': 3}  # A simplified vocabulary
vocab_size = len(vocab)  # Vocabulary size
embedding_dim = 768      # Embedding dimension, same as the small version of GPT-2

text = "LLM with me"
input_ids = torch.tensor([[vocab[word] for word in text.split()]], dtype=torch.long)

# Simulate the Transformer pipeline
embedding_layer = nn.Embedding(vocab_size, embedding_dim)
embedded = embedding_layer(input_ids)
transformer_output = torch.rand(embedded.size())  # Stand-in for the Transformer output
# Create a linear layer that maps the Transformer output into vocabulary space
linear_layer = nn.Linear(embedding_dim, vocab_size)
vocab_space_scores = linear_layer(transformer_output)
# Turn the scores into a probability distribution
probabilities = F.softmax(vocab_space_scores, dim=-1)
print(probabilities)
```

Output:
```
tensor([[[0.2306, 0.2478, 0.2688, 0.2528],
         [0.1928, 0.3077, 0.2768, 0.2227],
         [0.2562, 0.2568, 0.2837, 0.2033]]], grad_fn=<SoftmaxBackward0>)
```

```python
vocab_space_scores
```

Output:
```
tensor([[[-0.1818, -0.1099, -0.0287, -0.0901],
         [-0.6043, -0.1369, -0.2430, -0.4603],
         [-0.1816, -0.1795, -0.0799, -0.4130]]], grad_fn=<ViewBackward0>)
```
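To illustrate what the probability distribution above would be used for, here is a sketch of greedy decoding that continues from the variables defined in the previous cells (it is not part of the original notebook): taking the argmax over the last dimension picks the most likely vocabulary id at each position, and an inverse of the toy vocab dictionary, introduced here only for illustration, maps the ids back to words. Since the embedding and linear layers are untrained and randomly initialized, the distribution is close to uniform and the chosen tokens carry no real meaning.

```python
# Continues from the cell above: probabilities has shape [1, 3, vocab_size].
id_to_word = {idx: word for word, idx in vocab.items()}  # inverse of the toy vocabulary

predicted_ids = probabilities.argmax(dim=-1)              # shape [1, 3], one id per position
predicted_words = [id_to_word[i.item()] for i in predicted_ids[0]]

print(predicted_ids)    # predicted vocabulary ids; values depend on the random layers
print(predicted_words)  # the corresponding tokens from the toy vocabulary
```

For the probabilities printed above, the argmax indices would be 2, 1 and 2, i.e. 'me', 'with', 'me'.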